You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
"""
|
|
python export_tensorflow.py \
|
|
--model-variant mobilenetv3 \
|
|
--model-variant deep_guided_filter \
|
|
--pytorch-checkpoint rvm_mobilenetv3.pth \
|
|
--tensorflow-output rvm_mobilenetv3_tf
|
|
"""
|
|
|
|
import argparse
|
|
import torch
|
|
import tensorflow as tf
|
|
|
|
from model import MattingNetwork, load_torch_weights
|
|
|
|
|
|
# Add input output names and shapes
|
|
class MattingNetworkWrapper(MattingNetwork):
    """MattingNetwork whose `call` has named inputs and named outputs.

    `call` is wrapped in a `tf.function` with an explicit input signature,
    giving every input tensor a name, and the six values returned by the
    parent `call` are repackaged as a dict keyed by output name.
    """

    @tf.function(input_signature=[[
        # Rank-4 float input whose last dimension is 3.
        # NOTE(review): presumably (batch, height, width, RGB) -- confirm
        # against MattingNetwork.
        tf.TensorSpec(tf.TensorShape([None, None, None, 3]), tf.float32, 'src'),
        # TensorShape(None) leaves the rank unspecified for the remaining inputs.
        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r1i'),
        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r2i'),
        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r3i'),
        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r4i'),
        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'downsample_ratio'),
    ]])
    def call(self, inputs):
        """Run the parent network and return its outputs as a named dict.

        Args:
            inputs: List of six tensors in signature order:
                src, r1i, r2i, r3i, r4i, downsample_ratio.

        Returns:
            Dict with keys 'fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o' holding
            the six values returned by `MattingNetwork.call`, in order.
        """
        fgr, pha, r1o, r2o, r3o, r4o = super().call(inputs)
        return {'fgr': fgr, 'pha': pha, 'r1o': r1o, 'r2o': r2o, 'r3o': r3o, 'r4o': r4o}
|
|
|
|
|
|
class Exporter:
    """Export a PyTorch matting checkpoint through the TensorFlow model.

    Instantiating the class runs the whole pipeline: parse the command-line
    arguments, build the TensorFlow model, load the PyTorch weights into
    it, and save the model to the requested output path.
    """

    def __init__(self, argv=None):
        """Run the export.

        Args:
            argv: Optional list of argument strings. When None (the
                default), arguments are read from `sys.argv`.
        """
        self.parse_args(argv)
        self.init_model()
        self.export()

    def parse_args(self, argv=None):
        """Parse the command-line options and store them on `self.args`.

        Args:
            argv: Optional list of argument strings. When None (the
                default), argparse falls back to `sys.argv[1:]`.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
        parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
        parser.add_argument('--pytorch-checkpoint', type=str, required=True)
        parser.add_argument('--tensorflow-output', type=str, required=True)
        self.args = parser.parse_args(argv)

    def init_model(self):
        """Construct the TensorFlow model and load the PyTorch weights."""
        # Construct model
        self.model = MattingNetworkWrapper(self.args.model_variant, self.args.model_refiner)

        # Build model: call it once on dummy inputs before loading weights.
        # NOTE(review): presumably this creates the model's variables so
        # load_torch_weights has something to assign into -- confirm.
        src = tf.random.normal([1, 1080, 1920, 3])
        rec = [tf.constant(0.)] * 4
        downsample_ratio = tf.constant(0.25)
        self.model([src, *rec, downsample_ratio])

        # Load PyTorch checkpoint
        # NOTE(review): torch.load unpickles the file; only pass checkpoints
        # from a trusted source.
        load_torch_weights(self.model, torch.load(self.args.pytorch_checkpoint, map_location='cpu'))

    def export(self):
        """Save the model to the path given by `--tensorflow-output`."""
        self.model.save(self.args.tensorflow_output)
|
|
|
|
|
|
if __name__ == '__main__':
    # Constructing Exporter parses the CLI arguments and runs the export.
    Exporter()