Add CoreML export code
commit b485090534

README.md
@@ -0,0 +1,29 @@

# Export CoreML

## Overview

This branch contains our code for exporting CoreML models. The `/model` folder is identical to the one on the `master` branch. The main export logic is in `export_coreml.py`.

At the time of this writing, CoreML's `ResizeBilinear` and `Upsample` ops do not support dynamic scale parameters, so the `downsample_ratio` hyperparameter must be hardcoded (see the `Wrapper` class in `export_coreml.py`).

Our export script also fixes the input size. The output CoreML models require iOS 14+ and macOS 11+. If you have other requirements, feel free to modify the export script; a sketch of one possible tweak follows. Contributions are welcome.
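
If you need to target a different OS version, one knob to try is coremltools' `minimum_deployment_target` conversion option. A minimal, hypothetical sketch (this is not part of the export script, and a lower target only works if every op in the traced graph is supported on that OS version):

```python
import coremltools as ct

# Hypothetical tweak: request a minimum deployment target at conversion time.
# `model_traced` and the inputs list are assumed to be the same as in export_coreml.py.
model_coreml = ct.convert(
    model_traced,
    inputs=[...],  # the same ct.ImageType / ct.TensorType inputs as in export_coreml.py
    minimum_deployment_target=ct.target.iOS14,
)
```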

## Export Yourself

The following procedure was used to generate our CoreML models.

1. Install dependencies:

```sh
pip install -r requirements.txt
```

2. Run the export script. You can change `resolution` and `downsample-ratio` to fit your needs, and set the quantization to one of `[8, 16, 32]`, denoting `int8`, `fp16`, and `fp32` weights respectively. A sketch of running the exported model follows the command.

```sh
python export_coreml.py \
    --model-variant mobilenetv3 \
    --checkpoint rvm_mobilenetv3.pth \
    --resolution 1920 1080 \
    --downsample-ratio 0.25 \
    --quantize-nbits 16 \
    --output model.mlmodel
```
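
For reference, here is a minimal sketch of driving the exported model frame by frame from Python. It leans on a few assumptions: CoreML prediction only runs on macOS, `frames` is a hypothetical iterable of `PIL.Image` frames at the export resolution, and the input/output names match those set in `export_coreml.py`.

```python
import coremltools as ct

model = ct.models.MLModel('model.mlmodel')

# r1i..r4i are optional and omitted on the first frame;
# per the model description, the initial state is all zeros.
rec = {}
for frame in frames:
    out = model.predict({'src': frame, **rec})
    fgr, pha = out['fgr'], out['pha']  # foreground and alpha as PIL images
    rec = {'r1i': out['r1o'], 'r2i': out['r2o'],
           'r3i': out['r3o'], 'r4i': out['r4o']}
```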

export_coreml.py
@@ -0,0 +1,192 @@
"""
|
||||
python export_coreml.py \
|
||||
--model-variant mobilenetv3 \
|
||||
--checkpoint rvm_mobilenetv3.pth \
|
||||
--resolution 1920 1080 \
|
||||
--downsample-ratio 0.25 \
|
||||
--quantize-nbits 16 \
|
||||
--output model.mlmodel
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import coremltools as ct
|
||||
import torch
|
||||
|
||||
from coremltools.models.neural_network.quantization_utils import quantize_weights
|
||||
from coremltools.converters.mil.mil import Builder as mb
|
||||
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op
|
||||
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
|
||||
from coremltools.proto import FeatureTypes_pb2 as ft
|
||||
|
||||
from model import MattingNetwork
|
||||
|
||||
class Exporter:
|
||||
def __init__(self):
|
||||
self.parse_args()
|
||||
self.init_model()
|
||||
self.register_custom_ops()
|
||||
self.export()
|
||||
|
||||
def parse_args(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
|
||||
parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
|
||||
parser.add_argument('--checkpoint', type=str, required=False)
|
||||
parser.add_argument('--resolution', type=int, required=True, nargs=2)
|
||||
parser.add_argument('--downsample-ratio', type=float, required=True)
|
||||
parser.add_argument('--quantize-nbits', type=int, required=True, choices=[8, 16, 32])
|
||||
parser.add_argument('--output', type=str, required=True)
|
||||
self.args = parser.parse_args()
|
||||
|
||||
def init_model(self):
|
||||
downsample_ratio = self.args.downsample_ratio
|
||||
class Wrapper(MattingNetwork):
|
||||
def forward(self, src, r1=None, r2=None, r3=None, r4=None):
|
||||
# Hardcode downsample_ratio into the network instead of taking it as input. This is needed for torchscript tracing.
|
||||
# Also, we are multiply result by 255 to convert them to CoreML image format
|
||||
fgr, pha, r1, r2, r3, r4 = super().forward(src, r1, r2, r3, r4, downsample_ratio)
|
||||
return fgr.mul(255), pha.mul(255), r1, r2, r3, r4
|
||||
|
||||
self.model = Wrapper(self.args.model_variant, self.args.model_refiner).eval()
|
||||
if self.args.checkpoint is not None:
|
||||
self.model.load_state_dict(torch.load(self.args.checkpoint, map_location='cpu'), strict=False)
|
||||
|
||||
def register_custom_ops(self):
|
||||
@register_torch_op(override=True)
|
||||
def hardswish_(context, node):
|
||||
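            # hardswish(x) = x * hardsigmoid(x): decompose into CoreML's sigmoid_hard
            # (alpha=1/6, beta=0.5) followed by an elementwise multiply.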
            inputs = _get_inputs(context, node, expected=1)
            x = inputs[0]
            y = mb.sigmoid_hard(x=inputs[0], alpha=1.0/6, beta=0.5)
            z = mb.mul(x=x, y=y, name=node.name)
            context.add(z)

        @register_torch_op(override=True)
        def hardsigmoid_(context, node):
            inputs = _get_inputs(context, node, expected=1)
            res = mb.sigmoid_hard(x=inputs[0], alpha=1.0/6, beta=0.5, name=node.name)
            context.add(res)

        @register_torch_op(override=True)
        def type_as(context, node):
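            # Replace type_as with an unconditional cast to fp32.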
            inputs = _get_inputs(context, node)
            context.add(mb.cast(x=inputs[0], dtype='fp32'), node.name)

        @register_torch_op(override=True)
        def upsample_bilinear2d(context, node):
            # Use `resize_bilinear` instead of `upsample` so the op is also available on iOS 13.
            inputs = _get_inputs(context, node)
            x = inputs[0]
            output_size = inputs[1]
            align_corners = bool(inputs[2].val)
            scale_factors = inputs[3]

            if scale_factors is not None and scale_factors.val is not None \
                    and scale_factors.rank == 1 and scale_factors.shape[0] == 2:
                scale_factors = scale_factors.val
                resize = mb.resize_bilinear(
                    x=x,
                    target_size_height=int(x.shape[-2] * scale_factors[0]),
                    target_size_width=int(x.shape[-1] * scale_factors[1]),
                    sampling_mode='ALIGN_CORNERS',
                    name=node.name,
                )
                context.add(resize)
            else:
                resize = mb.resize_bilinear(
                    x=x,
                    target_size_height=output_size.val[0],
                    target_size_width=output_size.val[1],
                    sampling_mode='ALIGN_CORNERS',
                    name=node.name,
                )
                context.add(resize)

    def export(self):
        src = torch.zeros([1, 3, *self.args.resolution[::-1]])
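        # Run the model once to materialize recurrent states with concrete shapes for tracing.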
        _, _, r1, r2, r3, r4 = self.model(src)

        model_traced = torch.jit.trace(self.model, (src, r1, r2, r3, r4))
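        # Only the batch dimension is flexible (ct.RangeDim); height and width are fixed at export time.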
        model_coreml = ct.convert(
            model_traced,
            inputs=[
                ct.ImageType(name='src', shape=(ct.RangeDim(), *src.shape[1:]), channel_first=True, scale=1/255),
                ct.TensorType(name='r1i', shape=(ct.RangeDim(), *r1.shape[1:])),
                ct.TensorType(name='r2i', shape=(ct.RangeDim(), *r2.shape[1:])),
                ct.TensorType(name='r3i', shape=(ct.RangeDim(), *r3.shape[1:])),
                ct.TensorType(name='r4i', shape=(ct.RangeDim(), *r4.shape[1:])),
            ],
        )

        if self.args.quantize_nbits in [8, 16]:
            out = quantize_weights(model_coreml, nbits=self.args.quantize_nbits)
            if isinstance(out, ct.models.model.MLModel):
                # When the export is done on macOS, the return value is an MLModel.
                spec = out.get_spec()
            else:
                # When the export is done on Linux, the return value is a spec.
                spec = out
        else:
            spec = model_coreml.get_spec()

        # Some internal outputs are also named 'fgr' and 'pha'.
        # We rename them to avoid conflicts.
        for layer in spec.neuralNetwork.layers:
            for i in range(len(layer.input)):
                if layer.input[i] == 'fgr':
                    layer.input[i] = 'fgr_internal'
                if layer.input[i] == 'pha':
                    layer.input[i] = 'pha_internal'
            for i in range(len(layer.output)):
                if layer.output[i] == 'fgr':
                    layer.output[i] = 'fgr_internal'
                if layer.output[i] == 'pha':
                    layer.output[i] = 'pha_internal'

        # Update output names
        ct.utils.rename_feature(spec, spec.description.output[0].name, 'fgr')
        ct.utils.rename_feature(spec, spec.description.output[1].name, 'pha')
        ct.utils.rename_feature(spec, spec.description.output[2].name, 'r1o')
        ct.utils.rename_feature(spec, spec.description.output[3].name, 'r2o')
        ct.utils.rename_feature(spec, spec.description.output[4].name, 'r3o')
        ct.utils.rename_feature(spec, spec.description.output[5].name, 'r4o')

        # Update model description
        spec.description.metadata.author = 'Shanchuan Lin'
        spec.description.metadata.shortDescription = 'A robust human video matting model with recurrent architecture. The model has recurrent states that must be passed to subsequent frames. Please refer to the paper "Robust High-Resolution Video Matting with Temporal Guidance" for more details.'
        spec.description.metadata.license = 'Apache License 2.0'
        spec.description.metadata.versionString = '1.0.0'
        spec.description.input[0].shortDescription = 'Source frame'
        spec.description.input[1].shortDescription = 'Recurrent state 1. The initial state is an all-zero tensor. Subsequent states are received from r1o.'
        spec.description.input[2].shortDescription = 'Recurrent state 2. The initial state is an all-zero tensor. Subsequent states are received from r2o.'
        spec.description.input[3].shortDescription = 'Recurrent state 3. The initial state is an all-zero tensor. Subsequent states are received from r3o.'
        spec.description.input[4].shortDescription = 'Recurrent state 4. The initial state is an all-zero tensor. Subsequent states are received from r4o.'
        spec.description.output[0].shortDescription = 'Foreground prediction'
        spec.description.output[1].shortDescription = 'Alpha prediction'
        spec.description.output[2].shortDescription = 'Recurrent state 1. Needs to be passed as the r1i input in the next time step.'
        spec.description.output[3].shortDescription = 'Recurrent state 2. Needs to be passed as the r2i input in the next time step.'
        spec.description.output[4].shortDescription = 'Recurrent state 3. Needs to be passed as the r3i input in the next time step.'
        spec.description.output[5].shortDescription = 'Recurrent state 4. Needs to be passed as the r4i input in the next time step.'

        # Update output types
        spec.description.output[0].type.imageType.colorSpace = ft.ImageFeatureType.RGB
        spec.description.output[0].type.imageType.width = src.size(3)
        spec.description.output[0].type.imageType.height = src.size(2)
        spec.description.output[1].type.imageType.colorSpace = ft.ImageFeatureType.GRAYSCALE
        spec.description.output[1].type.imageType.width = src.size(3)
        spec.description.output[1].type.imageType.height = src.size(2)

        # Mark recurrent states as optional inputs
        spec.description.input[1].type.isOptional = True
        spec.description.input[2].type.isOptional = True
        spec.description.input[3].type.isOptional = True
        spec.description.input[4].type.isOptional = True

        # Save output
        ct.utils.save_spec(spec, self.args.output)


if __name__ == '__main__':
    Exporter()

model/__init__.py
@@ -0,0 +1 @@
from .model import MattingNetwork

model/decoder.py
@@ -0,0 +1,217 @@
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Tuple, Optional


class RecurrentDecoder(nn.Module):
    def __init__(self, feature_channels, decoder_channels):
        super().__init__()
        self.avgpool = AvgPool()
        self.decode4 = BottleneckBlock(feature_channels[3])
        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
        self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])

    def forward(self,
                s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
                r1: Optional[Tensor], r2: Optional[Tensor],
                r3: Optional[Tensor], r4: Optional[Tensor]):
        s1, s2, s3 = self.avgpool(s0)
        x4, r4 = self.decode4(f4, r4)
        x3, r3 = self.decode3(x4, f3, s3, r3)
        x2, r2 = self.decode2(x3, f2, s2, r2)
        x1, r1 = self.decode1(x2, f1, s1, r1)
        x0 = self.decode0(x1, s0)
        return x0, r1, r2, r3, r4


class AvgPool(nn.Module):
    def __init__(self):
        super().__init__()
        self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True)

    def forward_single_frame(self, s0):
        s1 = self.avgpool(s0)
        s2 = self.avgpool(s1)
        s3 = self.avgpool(s2)
        return s1, s2, s3

    def forward_time_series(self, s0):
        B, T = s0.shape[:2]
        s0 = s0.flatten(0, 1)
        s1, s2, s3 = self.forward_single_frame(s0)
        s1 = s1.unflatten(0, (B, T))
        s2 = s2.unflatten(0, (B, T))
        s3 = s3.unflatten(0, (B, T))
        return s1, s2, s3

    def forward(self, s0):
        if s0.ndim == 5:
            return self.forward_time_series(s0)
        else:
            return self.forward_single_frame(s0)


class BottleneckBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.gru = ConvGRU(channels // 2)

    def forward(self, x, r: Optional[Tensor]):
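        # Split the channels in half: run the ConvGRU on one half and pass the other half through unchanged.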
        a, b = x.split(self.channels // 2, dim=-3)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=-3)
        return x, r


class UpsamplingBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, src_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        self.gru = ConvGRU(out_channels // 2)

    def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
        x = self.upsample(x)
        # Crop the upsampled tensor to the skip connection's size (handles odd
        # resolutions). Written with explicit size checks for CoreML export.
        if x.size(2) != s.size(2):
            x = x[:, :, :s.size(2), :]
        if x.size(3) != s.size(3):
            x = x[:, :, :, :s.size(3)]
        x = torch.cat([x, f, s], dim=1)
        x = self.conv(x)
        a, b = x.split(self.out_channels // 2, dim=1)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=1)
        return x, r

    def forward_time_series(self, x, f, s, r: Optional[Tensor]):
        B, T, _, H, W = s.shape
        x = x.flatten(0, 1)
        f = f.flatten(0, 1)
        s = s.flatten(0, 1)
        x = self.upsample(x)
        x = x[:, :, :H, :W]
        x = torch.cat([x, f, s], dim=1)
        x = self.conv(x)
        x = x.unflatten(0, (B, T))
        a, b = x.split(self.out_channels // 2, dim=2)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=2)
        return x, r

    def forward(self, x, f, s, r: Optional[Tensor]):
        if x.ndim == 5:
            return self.forward_time_series(x, f, s, r)
        else:
            return self.forward_single_frame(x, f, s, r)


class OutputBlock(nn.Module):
    def __init__(self, in_channels, src_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward_single_frame(self, x, s):
        x = self.upsample(x)
        if x.size(2) != s.size(2):
            x = x[:, :, :s.size(2), :]
        if x.size(3) != s.size(3):
            x = x[:, :, :, :s.size(3)]
        x = torch.cat([x, s], dim=1)
        x = self.conv(x)
        return x

    def forward_time_series(self, x, s):
        B, T, _, H, W = s.shape
        x = x.flatten(0, 1)
        s = s.flatten(0, 1)
        x = self.upsample(x)
        x = x[:, :, :H, :W]
        x = torch.cat([x, s], dim=1)
        x = self.conv(x)
        x = x.unflatten(0, (B, T))
        return x

    def forward(self, x, s):
        if x.ndim == 5:
            return self.forward_time_series(x, s)
        else:
            return self.forward_single_frame(x, s)


class ConvGRU(nn.Module):
    def __init__(self,
                 channels: int,
                 kernel_size: int = 3,
                 padding: int = 1):
        super().__init__()
        self.channels = channels
        self.ih = nn.Sequential(
            nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding),
            nn.Sigmoid()
        )
        self.hh = nn.Sequential(
            nn.Conv2d(channels * 2, channels, kernel_size, padding=padding),
            nn.Tanh()
        )

    def forward_single_frame(self, x, h):
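        # Standard GRU update with convolutional gates: r (reset) and z (update)
        # come from self.ih; the candidate state c comes from self.hh.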
        r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
        c = self.hh(torch.cat([x, r * h], dim=1))
        h = (1 - z) * h + z * c
        return h, h

    def forward_time_series(self, x, h):
        o = []
        for xt in x.unbind(dim=1):
            ot, h = self.forward_single_frame(xt, h)
            o.append(ot)
        o = torch.stack(o, dim=1)
        return o, h

    def forward(self, x, h: Optional[Tensor]):
        if h is None:
            h = torch.zeros((x.size(0), x.size(-3), x.size(-2), x.size(-1)),
                            device=x.device, dtype=x.dtype)

        if x.ndim == 5:
            return self.forward_time_series(x, h)
        else:
            return self.forward_single_frame(x, h)


class Projection(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward_single_frame(self, x):
        return self.conv(x)

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

model/deep_guided_filter.py
@@ -0,0 +1,61 @@
import torch
from torch import nn
from torch.nn import functional as F

"""
Adapted from <https://github.com/wuhuikai/DeepGuidedFilter/>
"""

class DeepGuidedFilterRefiner(nn.Module):
    def __init__(self, hid_channels=16):
        super().__init__()
        self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        self.box_filter.weight.data[...] = 1 / 9
        self.conv = nn.Sequential(
            nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
        )

    def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid):
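        # Guided filter: estimate per-pixel affine coefficients A and b at low
        # resolution, upsample them, then apply out = A * x + b at full resolution.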
        fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
        base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
        base_y = torch.cat([base_fgr, base_pha], dim=1)

        mean_x = self.box_filter(base_x)
        mean_y = self.box_filter(base_y)
        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
        var_x = self.box_filter(base_x * base_x) - mean_x * mean_x

        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
        b = mean_y - A * mean_x

        H, W = fine_src.shape[2:]
        A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
        b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)

        out = A * fine_x + b
        fgr, pha = out.split([3, 1], dim=1)
        return fgr, pha

    def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        B, T = fine_src.shape[:2]
        fgr, pha = self.forward_single_frame(
            fine_src.flatten(0, 1),
            base_src.flatten(0, 1),
            base_fgr.flatten(0, 1),
            base_pha.flatten(0, 1),
            base_hid.flatten(0, 1))
        fgr = fgr.unflatten(0, (B, T))
        pha = pha.unflatten(0, (B, T))
        return fgr, pha

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        if fine_src.ndim == 5:
            return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid)
        else:
            return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid)

model/fast_guided_filter.py
@@ -0,0 +1,76 @@
import torch
from torch import nn
from torch.nn import functional as F

"""
Adapted from <https://github.com/wuhuikai/DeepGuidedFilter/>
"""

class FastGuidedFilterRefiner(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.guided_filter = FastGuidedFilter(1)

    def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha):
        fine_src_gray = fine_src.mean(1, keepdim=True)
        base_src_gray = base_src.mean(1, keepdim=True)

        fgr, pha = self.guided_filter(
            torch.cat([base_src, base_src_gray], dim=1),
            torch.cat([base_fgr, base_pha], dim=1),
            torch.cat([fine_src, fine_src_gray], dim=1)).split([3, 1], dim=1)

        return fgr, pha

    def forward_time_series(self, fine_src, base_src, base_fgr, base_pha):
        B, T = fine_src.shape[:2]
        fgr, pha = self.forward_single_frame(
            fine_src.flatten(0, 1),
            base_src.flatten(0, 1),
            base_fgr.flatten(0, 1),
            base_pha.flatten(0, 1))
        fgr = fgr.unflatten(0, (B, T))
        pha = pha.unflatten(0, (B, T))
        return fgr, pha

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
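        # base_hid is unused here; the argument is kept for interface parity with DeepGuidedFilterRefiner.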
        if fine_src.ndim == 5:
            return self.forward_time_series(fine_src, base_src, base_fgr, base_pha)
        else:
            return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha)


class FastGuidedFilter(nn.Module):
    def __init__(self, r: int, eps: float = 1e-5):
        super().__init__()
        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        mean_x = self.boxfilter(lr_x)
        mean_y = self.boxfilter(lr_y)
        cov_xy = self.boxfilter(lr_x * lr_y) - mean_x * mean_y
        var_x = self.boxfilter(lr_x * lr_x) - mean_x * mean_x
        A = cov_xy / (var_x + self.eps)
        b = mean_y - A * mean_x
        A = F.interpolate(A, hr_x.shape[2:], mode='bilinear', align_corners=False)
        b = F.interpolate(b, hr_x.shape[2:], mode='bilinear', align_corners=False)
        return A * hr_x + b


class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()
        self.r = r

    def forward(self, x):
        # Note: The original implementation at <https://github.com/wuhuikai/DeepGuidedFilter/>
        # uses a faster box blur, but it may not be friendly to ONNX export.
        # We switch to a simple convolution for the box blur.
        kernel_size = 2 * self.r + 1
        kernel_x = torch.full((x.data.shape[1], 1, 1, kernel_size), 1 / kernel_size, device=x.device, dtype=x.dtype)
        kernel_y = torch.full((x.data.shape[1], 1, kernel_size, 1), 1 / kernel_size, device=x.device, dtype=x.dtype)
        x = F.conv2d(x, kernel_x, padding=(0, self.r), groups=x.data.shape[1])
        x = F.conv2d(x, kernel_y, padding=(self.r, 0), groups=x.data.shape[1])
        return x

model/lraspp.py
@@ -0,0 +1,29 @@
from torch import nn

class LRASPP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )
        self.aspp2 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.Sigmoid()
        )

    def forward_single_frame(self, x):
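        # Lite R-ASPP head: a 1x1-conv branch modulated by a global average-pooled attention branch.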
        return self.aspp1(x) * self.aspp2(x)

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T))
        return x

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

model/mobilenetv3.py
@@ -0,0 +1,72 @@
from torch import nn
from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
from torchvision.models.utils import load_state_dict_from_url
from torchvision.transforms.functional import normalize

class MobileNetV3LargeEncoder(MobileNetV3):
    def __init__(self, pretrained: bool = False):
        super().__init__(
            inverted_residual_setting=[
                InvertedResidualConfig( 16, 3,  16,  16, False, "RE", 1, 1, 1),
                InvertedResidualConfig( 16, 3,  64,  24, False, "RE", 2, 1, 1),  # C1
                InvertedResidualConfig( 24, 3,  72,  24, False, "RE", 1, 1, 1),
                InvertedResidualConfig( 24, 5,  72,  40, True, "RE", 2, 1, 1),  # C2
                InvertedResidualConfig( 40, 5, 120,  40, True, "RE", 1, 1, 1),
                InvertedResidualConfig( 40, 5, 120,  40, True, "RE", 1, 1, 1),
                InvertedResidualConfig( 40, 3, 240,  80, False, "HS", 2, 1, 1),  # C3
                InvertedResidualConfig( 80, 3, 200,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 184,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 184,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 480, 112, True, "HS", 1, 1, 1),
                InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1, 1, 1),
                InvertedResidualConfig(112, 5, 672, 160, True, "HS", 2, 2, 1),  # C4
                InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1),
                InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1),
            ],
            last_channel=1280
        )

        if pretrained:
            self.load_state_dict(load_state_dict_from_url(
                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))

        del self.avgpool
        del self.classifier

    def forward_single_frame(self, x):
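        # Returns skip features at 1/2, 1/4, 1/8, and 1/16 resolution;
        # the last stage is dilated, so the output stride stays at 16.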
        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        x = self.features[0](x)
        x = self.features[1](x)
        f1 = x
        x = self.features[2](x)
        x = self.features[3](x)
        f2 = x
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        f3 = x
        x = self.features[7](x)
        x = self.features[8](x)
        x = self.features[9](x)
        x = self.features[10](x)
        x = self.features[11](x)
        x = self.features[12](x)
        x = self.features[13](x)
        x = self.features[14](x)
        x = self.features[15](x)
        x = self.features[16](x)
        f4 = x
        return [f1, f2, f3, f4]

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        features = self.forward_single_frame(x.flatten(0, 1))
        features = [f.unflatten(0, (B, T)) for f in features]
        return features

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

model/model.py
@@ -0,0 +1,79 @@
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Optional, List

from .mobilenetv3 import MobileNetV3LargeEncoder
from .resnet import ResNet50Encoder
from .lraspp import LRASPP
from .decoder import RecurrentDecoder, Projection
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner

class MattingNetwork(nn.Module):
    def __init__(self,
                 variant: str = 'mobilenetv3',
                 refiner: str = 'deep_guided_filter',
                 pretrained_backbone: bool = False):
        super().__init__()
        assert variant in ['mobilenetv3', 'resnet50']
        assert refiner in ['fast_guided_filter', 'deep_guided_filter']

        if variant == 'mobilenetv3':
            self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
            self.aspp = LRASPP(960, 128)
            self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
        else:
            self.backbone = ResNet50Encoder(pretrained_backbone)
            self.aspp = LRASPP(2048, 256)
            self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])

        self.project_mat = Projection(16, 4)
        self.project_seg = Projection(16, 1)

        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner()
        else:
            self.refiner = FastGuidedFilterRefiner()

    def forward(self,
                src: Tensor,
                r1: Optional[Tensor] = None,
                r2: Optional[Tensor] = None,
                r3: Optional[Tensor] = None,
                r4: Optional[Tensor] = None,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):

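        # Run the encoder-decoder on a (possibly) downsampled frame; the refiner
        # then restores the matte to full resolution.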
        if downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src

        f1, f2, f3, f4 = self.backbone(src_sm)
        f4 = self.aspp(f4)
        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

        if not segmentation_pass:
            fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
            if downsample_ratio != 1:
                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
            fgr = fgr_residual + src
            fgr = fgr.clamp(0., 1.)
            pha = pha.clamp(0., 1.)
            return [fgr, pha, *rec]
        else:
            seg = self.project_seg(hid)
            return [seg, *rec]

    def _interpolate(self, x: Tensor, scale_factor: float):
        if x.ndim == 5:
            B, T = x.shape[:2]
            x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
                              mode='bilinear', align_corners=False, recompute_scale_factor=False)
            x = x.unflatten(0, (B, T))
        else:
            x = F.interpolate(x, scale_factor=scale_factor,
                              mode='bilinear', align_corners=False, recompute_scale_factor=False)
        return x

model/resnet.py
@@ -0,0 +1,45 @@
from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck
from torchvision.models.utils import load_state_dict_from_url

class ResNet50Encoder(ResNet):
    def __init__(self, pretrained: bool = False):
        super().__init__(
            block=Bottleneck,
            layers=[3, 4, 6, 3],
            replace_stride_with_dilation=[False, False, True],
            norm_layer=None)

        if pretrained:
            self.load_state_dict(load_state_dict_from_url(
                'https://download.pytorch.org/models/resnet50-0676ba61.pth'))

        del self.avgpool
        del self.fc

    def forward_single_frame(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        f1 = x  # 1/2
        x = self.maxpool(x)
        x = self.layer1(x)
        f2 = x  # 1/4
        x = self.layer2(x)
        f3 = x  # 1/8
        x = self.layer3(x)
        x = self.layer4(x)
        f4 = x  # 1/16
        return [f1, f2, f3, f4]

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        features = self.forward_single_frame(x.flatten(0, 1))
        features = [f.unflatten(0, (B, T)) for f in features]
        return features

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

requirements.txt
@@ -0,0 +1,3 @@
torch==1.8.1
torchvision==0.9.1
coremltools==5.0b1