Add CoreML export code

coreml
Peter Lin 3 years ago
commit b485090534

README.md
@ -0,0 +1,29 @@
# Export CoreML
## Overview
This branch contains our code for exporting CoreML models. The `/model` folder is the same as on the `master` branch. The main export logic is in `export_coreml.py`.
At the time of writing, CoreML's `ResizeBilinear` and `Upsample` ops do not support dynamic scale parameters, so the `downsample_ratio` hyperparameter must be hardcoded.
Our export script uses a fixed input size. The exported CoreML models require iOS 14+ or macOS 11+. If you have other requirements, feel free to modify the export script. Contributions are welcome.
## Export Yourself
The following procedures were used to generate our CoreML models.
1. Install dependencies
```sh
pip install -r requirements.txt
```
2. Run the export script. You can change `resolution` and `downsample-ratio` to fit your needs, and set `quantize-nbits` to one of `[8, 16, 32]`, denoting `int8`, `fp16`, and `fp32` respectively.
```sh
python export_coreml.py \
--model-variant mobilenetv3 \
--checkpoint rvm_mobilenetv3.pth \
--resolution 1920 1080 \
--downsample-ratio 0.25 \
--quantize-nbits 16 \
--output model.mlmodel
```
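Once exported, the model can be sanity-checked with `coremltools` on macOS (prediction is only available there). The sketch below is illustrative and not part of the export script: the file name, resolution, and frame paths are placeholders, while the input/output names (`src`, `r1i`-`r4i`, `fgr`, `pha`, `r1o`-`r4o`) match what the export script assigns.
```python
# Minimal usage sketch (assumes macOS and a 1920x1080 export saved as model.mlmodel).
import coremltools as ct
from PIL import Image

model = ct.models.MLModel('model.mlmodel')

rec = {}  # recurrent states are optional inputs; omit them for the first frame
for path in ['frame_0000.png', 'frame_0001.png']:  # placeholder frame paths
    src = Image.open(path).convert('RGB').resize((1920, 1080))
    out = model.predict({'src': src, **rec})
    fgr, pha = out['fgr'], out['pha']  # foreground (RGB) and alpha (grayscale) images
    # Pass the recurrent outputs back in as the next frame's recurrent inputs.
    rec = {'r1i': out['r1o'], 'r2i': out['r2o'], 'r3i': out['r3o'], 'r4i': out['r4o']}
```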

export_coreml.py
@ -0,0 +1,192 @@
"""
python export_coreml.py \
--model-variant mobilenetv3 \
--checkpoint rvm_mobilenetv3.pth \
--resolution 1920 1080 \
--downsample-ratio 0.25 \
--quantize-nbits 16 \
--output model.mlmodel
"""
import argparse
import coremltools as ct
import torch
from coremltools.models.neural_network.quantization_utils import quantize_weights
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.proto import FeatureTypes_pb2 as ft
from model import MattingNetwork
class Exporter:
def __init__(self):
self.parse_args()
self.init_model()
self.register_custom_ops()
self.export()
def parse_args(self):
parser = argparse.ArgumentParser()
parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
parser.add_argument('--checkpoint', type=str, required=False)
parser.add_argument('--resolution', type=int, required=True, nargs=2)
parser.add_argument('--downsample-ratio', type=float, required=True)
parser.add_argument('--quantize-nbits', type=int, required=True, choices=[8, 16, 32])
parser.add_argument('--output', type=str, required=True)
self.args = parser.parse_args()
def init_model(self):
downsample_ratio = self.args.downsample_ratio
class Wrapper(MattingNetwork):
def forward(self, src, r1=None, r2=None, r3=None, r4=None):
                # Hardcode downsample_ratio into the network instead of taking it as input. This is needed for TorchScript tracing.
                # Also, multiply the results by 255 to convert them to the CoreML image format.
fgr, pha, r1, r2, r3, r4 = super().forward(src, r1, r2, r3, r4, downsample_ratio)
return fgr.mul(255), pha.mul(255), r1, r2, r3, r4
self.model = Wrapper(self.args.model_variant, self.args.model_refiner).eval()
if self.args.checkpoint is not None:
self.model.load_state_dict(torch.load(self.args.checkpoint, map_location='cpu'), strict=False)
def register_custom_ops(self):
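        # Override the default conversions for a few ops hit during tracing:
        # in-place hard-swish / hard-sigmoid (used by MobileNetV3), type_as, and
        # bilinear upsampling (mapped to resize_bilinear).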
@register_torch_op(override=True)
def hardswish_(context, node):
inputs = _get_inputs(context, node, expected=1)
x = inputs[0]
y = mb.sigmoid_hard(x=inputs[0], alpha=1.0/6, beta=0.5)
z = mb.mul(x=x, y=y, name=node.name)
context.add(z)
@register_torch_op(override=True)
def hardsigmoid_(context, node):
inputs = _get_inputs(context, node, expected=1)
res = mb.sigmoid_hard(x=inputs[0], alpha=1.0/6, beta=0.5, name=node.name)
context.add(res)
@register_torch_op(override=True)
def type_as(context, node):
inputs = _get_inputs(context, node)
context.add(mb.cast(x=inputs[0], dtype='fp32'), node.name)
@register_torch_op(override=True)
def upsample_bilinear2d(context, node):
# Change to use `resize_bilinear` instead to support iOS 13.
inputs = _get_inputs(context, node)
x = inputs[0]
output_size = inputs[1]
align_corners = bool(inputs[2].val)
scale_factors = inputs[3]
if scale_factors is not None and scale_factors.val is not None \
and scale_factors.rank == 1 and scale_factors.shape[0] == 2:
scale_factors = scale_factors.val
resize = mb.resize_bilinear(
x=x,
target_size_height=int(x.shape[-2] * scale_factors[0]),
target_size_width=int(x.shape[-1] * scale_factors[1]),
sampling_mode='ALIGN_CORNERS',
name=node.name,
)
context.add(resize)
else:
resize = mb.resize_bilinear(
x=x,
target_size_height=output_size.val[0],
target_size_width=output_size.val[1],
sampling_mode='ALIGN_CORNERS',
name=node.name,
)
context.add(resize)
def export(self):
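        # Trace the wrapped model on dummy inputs, convert to CoreML, optionally
        # quantize the weights, then patch the spec: output names, descriptions,
        # output image types, and optional recurrent-state inputs.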
src = torch.zeros([1, 3, *self.args.resolution[::-1]])
_, _, r1, r2, r3, r4 = self.model(src)
model_traced = torch.jit.trace(self.model, (src, r1, r2, r3, r4))
model_coreml = ct.convert(
model_traced,
inputs=[
ct.ImageType(name='src', shape=(ct.RangeDim(), *src.shape[1:]), channel_first=True, scale=1/255),
ct.TensorType(name='r1i', shape=(ct.RangeDim(), *r1.shape[1:])),
ct.TensorType(name='r2i', shape=(ct.RangeDim(), *r2.shape[1:])),
ct.TensorType(name='r3i', shape=(ct.RangeDim(), *r3.shape[1:])),
ct.TensorType(name='r4i', shape=(ct.RangeDim(), *r4.shape[1:])),
],
)
if self.args.quantize_nbits in [8, 16]:
out = quantize_weights(model_coreml, nbits=self.args.quantize_nbits)
if isinstance(out, ct.models.model.MLModel):
                # When the export is done on macOS, the return is an MLModel.
spec = out.get_spec()
else:
# When the export is done on Linux, the return is a spec.
spec = out
else:
spec = model_coreml.get_spec()
# Some internal outputs are also named 'fgr' and 'pha'.
# We change them to avoid conflicts.
for layer in spec.neuralNetwork.layers:
for i in range(len(layer.input)):
if layer.input[i] == 'fgr':
layer.input[i] = 'fgr_internal'
if layer.input[i] == 'pha':
layer.input[i] = 'pha_internal'
for i in range(len(layer.output)):
if layer.output[i] == 'fgr':
layer.output[i] = 'fgr_internal'
if layer.output[i] == 'pha':
layer.output[i] = 'pha_internal'
# Update output names
ct.utils.rename_feature(spec, spec.description.output[0].name, 'fgr')
ct.utils.rename_feature(spec, spec.description.output[1].name, 'pha')
ct.utils.rename_feature(spec, spec.description.output[2].name, 'r1o')
ct.utils.rename_feature(spec, spec.description.output[3].name, 'r2o')
ct.utils.rename_feature(spec, spec.description.output[4].name, 'r3o')
ct.utils.rename_feature(spec, spec.description.output[5].name, 'r4o')
# Update model description
spec.description.metadata.author = 'Shanchuan Lin'
spec.description.metadata.shortDescription = 'A robust human video matting model with recurrent architecture. The model has recurrent states that must be passed to subsequent frames. Please refer to paper "Robust High-Resolution Video Matting with Temporal Guidance" for more details.'
spec.description.metadata.license = 'Apache License 2.0'
spec.description.metadata.versionString = '1.0.0'
spec.description.input[0].shortDescription = 'Source frame'
spec.description.input[1].shortDescription = 'Recurrent state 1. Initial state is an all zero tensor. Subsequent state is received from r1o.'
spec.description.input[2].shortDescription = 'Recurrent state 2. Initial state is an all zero tensor. Subsequent state is received from r2o.'
spec.description.input[3].shortDescription = 'Recurrent state 3. Initial state is an all zero tensor. Subsequent state is received from r3o.'
spec.description.input[4].shortDescription = 'Recurrent state 4. Initial state is an all zero tensor. Subsequent state is received from r4o.'
spec.description.output[0].shortDescription = 'Foreground prediction'
spec.description.output[1].shortDescription = 'Alpha prediction'
spec.description.output[2].shortDescription = 'Recurrent state 1. Needs to be passed as r1i input in the next time step.'
spec.description.output[3].shortDescription = 'Recurrent state 2. Needs to be passed as r2i input in the next time step.'
spec.description.output[4].shortDescription = 'Recurrent state 3. Needs to be passed as r3i input in the next time step.'
spec.description.output[5].shortDescription = 'Recurrent state 4. Needs to be passed as r4i input in the next time step.'
# Update output types
spec.description.output[0].type.imageType.colorSpace = ft.ImageFeatureType.RGB
spec.description.output[0].type.imageType.width = src.size(3)
spec.description.output[0].type.imageType.height = src.size(2)
spec.description.output[1].type.imageType.colorSpace = ft.ImageFeatureType.GRAYSCALE
spec.description.output[1].type.imageType.width = src.size(3)
spec.description.output[1].type.imageType.height = src.size(2)
# Set recurrent states as optional inputs
spec.description.input[1].type.isOptional = True
spec.description.input[2].type.isOptional = True
spec.description.input[3].type.isOptional = True
spec.description.input[4].type.isOptional = True
# Save output
ct.utils.save_spec(spec, self.args.output)
if __name__ == '__main__':
Exporter()

model/__init__.py
@ -0,0 +1 @@
from .model import MattingNetwork

model/decoder.py
@ -0,0 +1,217 @@
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Tuple, Optional
class RecurrentDecoder(nn.Module):
def __init__(self, feature_channels, decoder_channels):
super().__init__()
self.avgpool = AvgPool()
self.decode4 = BottleneckBlock(feature_channels[3])
self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
def forward(self,
s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
r1: Optional[Tensor], r2: Optional[Tensor],
r3: Optional[Tensor], r4: Optional[Tensor]):
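        # Coarse-to-fine decoding: each stage fuses upsampled features with an encoder
        # skip connection and an average-pooled copy of the source, and carries its own
        # ConvGRU recurrent state.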
s1, s2, s3 = self.avgpool(s0)
x4, r4 = self.decode4(f4, r4)
x3, r3 = self.decode3(x4, f3, s3, r3)
x2, r2 = self.decode2(x3, f2, s2, r2)
x1, r1 = self.decode1(x2, f1, s1, r1)
x0 = self.decode0(x1, s0)
return x0, r1, r2, r3, r4
class AvgPool(nn.Module):
def __init__(self):
super().__init__()
self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True)
def forward_single_frame(self, s0):
s1 = self.avgpool(s0)
s2 = self.avgpool(s1)
s3 = self.avgpool(s2)
return s1, s2, s3
def forward_time_series(self, s0):
B, T = s0.shape[:2]
s0 = s0.flatten(0, 1)
s1, s2, s3 = self.forward_single_frame(s0)
s1 = s1.unflatten(0, (B, T))
s2 = s2.unflatten(0, (B, T))
s3 = s3.unflatten(0, (B, T))
return s1, s2, s3
def forward(self, s0):
if s0.ndim == 5:
return self.forward_time_series(s0)
else:
return self.forward_single_frame(s0)
class BottleneckBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.gru = ConvGRU(channels // 2)
def forward(self, x, r: Optional[Tensor]):
a, b = x.split(self.channels // 2, dim=-3)
b, r = self.gru(b, r)
x = torch.cat([a, b], dim=-3)
return x, r
class UpsamplingBlock(nn.Module):
def __init__(self, in_channels, skip_channels, src_channels, out_channels):
super().__init__()
self.out_channels = out_channels
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.conv = nn.Sequential(
nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
self.gru = ConvGRU(out_channels // 2)
def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
x = self.upsample(x)
        # Optimized for CoreML export: crop the upsampled tensor so it matches the
        # skip connection's spatial size, instead of resizing to a dynamic target shape.
if x.size(2) != s.size(2):
x = x[:, :, :s.size(2), :]
if x.size(3) != s.size(3):
x = x[:, :, :, :s.size(3)]
x = torch.cat([x, f, s], dim=1)
x = self.conv(x)
a, b = x.split(self.out_channels // 2, dim=1)
b, r = self.gru(b, r)
x = torch.cat([a, b], dim=1)
return x, r
def forward_time_series(self, x, f, s, r: Optional[Tensor]):
B, T, _, H, W = s.shape
x = x.flatten(0, 1)
f = f.flatten(0, 1)
s = s.flatten(0, 1)
x = self.upsample(x)
x = x[:, :, :H, :W]
x = torch.cat([x, f, s], dim=1)
x = self.conv(x)
x = x.unflatten(0, (B, T))
a, b = x.split(self.out_channels // 2, dim=2)
b, r = self.gru(b, r)
x = torch.cat([a, b], dim=2)
return x, r
def forward(self, x, f, s, r: Optional[Tensor]):
if x.ndim == 5:
return self.forward_time_series(x, f, s, r)
else:
return self.forward_single_frame(x, f, s, r)
class OutputBlock(nn.Module):
def __init__(self, in_channels, src_channels, out_channels):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.conv = nn.Sequential(
nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
def forward_single_frame(self, x, s):
x = self.upsample(x)
if x.size(2) != s.size(2):
x = x[:, :, :s.size(2), :]
if x.size(3) != s.size(3):
x = x[:, :, :, :s.size(3)]
x = torch.cat([x, s], dim=1)
x = self.conv(x)
return x
def forward_time_series(self, x, s):
B, T, _, H, W = s.shape
x = x.flatten(0, 1)
s = s.flatten(0, 1)
x = self.upsample(x)
x = x[:, :, :H, :W]
x = torch.cat([x, s], dim=1)
x = self.conv(x)
x = x.unflatten(0, (B, T))
return x
def forward(self, x, s):
if x.ndim == 5:
return self.forward_time_series(x, s)
else:
return self.forward_single_frame(x, s)
class ConvGRU(nn.Module):
def __init__(self,
channels: int,
kernel_size: int = 3,
padding: int = 1):
super().__init__()
self.channels = channels
self.ih = nn.Sequential(
nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding),
nn.Sigmoid()
)
self.hh = nn.Sequential(
nn.Conv2d(channels * 2, channels, kernel_size, padding=padding),
nn.Tanh()
)
def forward_single_frame(self, x, h):
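        # Standard ConvGRU update: r is the reset gate, z the update gate, c the
        # candidate state; h serves as both the output and the next hidden state.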
r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
c = self.hh(torch.cat([x, r * h], dim=1))
h = (1 - z) * h + z * c
return h, h
def forward_time_series(self, x, h):
o = []
for xt in x.unbind(dim=1):
ot, h = self.forward_single_frame(xt, h)
o.append(ot)
o = torch.stack(o, dim=1)
return o, h
def forward(self, x, h: Optional[Tensor]):
if h is None:
h = torch.zeros((x.size(0), x.size(-3), x.size(-2), x.size(-1)),
device=x.device, dtype=x.dtype)
if x.ndim == 5:
return self.forward_time_series(x, h)
else:
return self.forward_single_frame(x, h)
class Projection(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1)
def forward_single_frame(self, x):
return self.conv(x)
def forward_time_series(self, x):
B, T = x.shape[:2]
return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))
def forward(self, x):
if x.ndim == 5:
return self.forward_time_series(x)
else:
return self.forward_single_frame(x)

model/deep_guided_filter.py
@ -0,0 +1,61 @@
import torch
from torch import nn
from torch.nn import functional as F
"""
Adapted from <https://github.com/wuhuikai/DeepGuidedFilter/>
"""
class DeepGuidedFilterRefiner(nn.Module):
def __init__(self, hid_channels=16):
super().__init__()
self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
self.box_filter.weight.data[...] = 1 / 9
self.conv = nn.Sequential(
nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(hid_channels),
nn.ReLU(True),
nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(hid_channels),
nn.ReLU(True),
nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
)
def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid):
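        # Guided filter on the low-resolution prediction: a small conv net predicts the
        # coefficients A from [cov_xy, var_x, hidden features], b = mean_y - A * mean_x,
        # and the affine map A * x + b is applied to the full-resolution source.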
fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
base_y = torch.cat([base_fgr, base_pha], dim=1)
mean_x = self.box_filter(base_x)
mean_y = self.box_filter(base_y)
cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
b = mean_y - A * mean_x
H, W = fine_src.shape[2:]
A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
out = A * fine_x + b
fgr, pha = out.split([3, 1], dim=1)
return fgr, pha
def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid):
B, T = fine_src.shape[:2]
fgr, pha = self.forward_single_frame(
fine_src.flatten(0, 1),
base_src.flatten(0, 1),
base_fgr.flatten(0, 1),
base_pha.flatten(0, 1),
base_hid.flatten(0, 1))
fgr = fgr.unflatten(0, (B, T))
pha = pha.unflatten(0, (B, T))
return fgr, pha
def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
if fine_src.ndim == 5:
return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid)
else:
return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid)

model/fast_guided_filter.py
@ -0,0 +1,76 @@
import torch
from torch import nn
from torch.nn import functional as F
"""
Adapted from <https://github.com/wuhuikai/DeepGuidedFilter/>
"""
class FastGuidedFilterRefiner(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.guilded_filter = FastGuidedFilter(1)
def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha):
fine_src_gray = fine_src.mean(1, keepdim=True)
base_src_gray = base_src.mean(1, keepdim=True)
fgr, pha = self.guilded_filter(
torch.cat([base_src, base_src_gray], dim=1),
torch.cat([base_fgr, base_pha], dim=1),
torch.cat([fine_src, fine_src_gray], dim=1)).split([3, 1], dim=1)
return fgr, pha
def forward_time_series(self, fine_src, base_src, base_fgr, base_pha):
B, T = fine_src.shape[:2]
fgr, pha = self.forward_single_frame(
fine_src.flatten(0, 1),
base_src.flatten(0, 1),
base_fgr.flatten(0, 1),
base_pha.flatten(0, 1))
fgr = fgr.unflatten(0, (B, T))
pha = pha.unflatten(0, (B, T))
return fgr, pha
def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
if fine_src.ndim == 5:
return self.forward_time_series(fine_src, base_src, base_fgr, base_pha)
else:
return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha)
class FastGuidedFilter(nn.Module):
def __init__(self, r: int, eps: float = 1e-5):
super().__init__()
self.r = r
self.eps = eps
self.boxfilter = BoxFilter(r)
def forward(self, lr_x, lr_y, hr_x):
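        # Classic fast guided filter: solve for A and b on the low-resolution pair,
        # upsample the coefficients, and apply A * hr_x + b at high resolution.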
mean_x = self.boxfilter(lr_x)
mean_y = self.boxfilter(lr_y)
cov_xy = self.boxfilter(lr_x * lr_y) - mean_x * mean_y
var_x = self.boxfilter(lr_x * lr_x) - mean_x * mean_x
A = cov_xy / (var_x + self.eps)
b = mean_y - A * mean_x
A = F.interpolate(A, hr_x.shape[2:], mode='bilinear', align_corners=False)
b = F.interpolate(b, hr_x.shape[2:], mode='bilinear', align_corners=False)
return A * hr_x + b
class BoxFilter(nn.Module):
def __init__(self, r):
super(BoxFilter, self).__init__()
self.r = r
def forward(self, x):
        # Note: the original implementation at <https://github.com/wuhuikai/DeepGuidedFilter/>
        # uses a faster box blur; however, it is not friendly to ONNX export,
        # so we use a simple separable convolution for the box blur instead.
kernel_size = 2 * self.r + 1
kernel_x = torch.full((x.data.shape[1], 1, 1, kernel_size), 1 / kernel_size, device=x.device, dtype=x.dtype)
kernel_y = torch.full((x.data.shape[1], 1, kernel_size, 1), 1 / kernel_size, device=x.device, dtype=x.dtype)
x = F.conv2d(x, kernel_x, padding=(0, self.r), groups=x.data.shape[1])
x = F.conv2d(x, kernel_y, padding=(self.r, 0), groups=x.data.shape[1])
return x

model/lraspp.py
@ -0,0 +1,29 @@
from torch import nn
class LRASPP(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.aspp1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True)
)
self.aspp2 = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.Sigmoid()
)
def forward_single_frame(self, x):
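        # Lite R-ASPP head: 1x1-conv features gated by a globally pooled attention branch.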
return self.aspp1(x) * self.aspp2(x)
def forward_time_series(self, x):
B, T = x.shape[:2]
x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T))
return x
def forward(self, x):
if x.ndim == 5:
return self.forward_time_series(x)
else:
return self.forward_single_frame(x)

model/mobilenetv3.py
@ -0,0 +1,72 @@
from torch import nn
from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
from torchvision.models.utils import load_state_dict_from_url
from torchvision.transforms.functional import normalize
class MobileNetV3LargeEncoder(MobileNetV3):
def __init__(self, pretrained: bool = False):
super().__init__(
inverted_residual_setting=[
InvertedResidualConfig( 16, 3, 16, 16, False, "RE", 1, 1, 1),
InvertedResidualConfig( 16, 3, 64, 24, False, "RE", 2, 1, 1), # C1
InvertedResidualConfig( 24, 3, 72, 24, False, "RE", 1, 1, 1),
InvertedResidualConfig( 24, 5, 72, 40, True, "RE", 2, 1, 1), # C2
InvertedResidualConfig( 40, 5, 120, 40, True, "RE", 1, 1, 1),
InvertedResidualConfig( 40, 5, 120, 40, True, "RE", 1, 1, 1),
InvertedResidualConfig( 40, 3, 240, 80, False, "HS", 2, 1, 1), # C3
InvertedResidualConfig( 80, 3, 200, 80, False, "HS", 1, 1, 1),
InvertedResidualConfig( 80, 3, 184, 80, False, "HS", 1, 1, 1),
InvertedResidualConfig( 80, 3, 184, 80, False, "HS", 1, 1, 1),
InvertedResidualConfig( 80, 3, 480, 112, True, "HS", 1, 1, 1),
InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1, 1, 1),
InvertedResidualConfig(112, 5, 672, 160, True, "HS", 2, 2, 1), # C4
InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1),
InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1),
],
last_channel=1280
)
if pretrained:
self.load_state_dict(load_state_dict_from_url(
'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
del self.avgpool
del self.classifier
def forward_single_frame(self, x):
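        # Collect pyramid features f1-f4 at 1/2, 1/4, 1/8 and 1/16 of the input resolution
        # (the last stage uses dilation, so its stride matches the ResNet encoder's 1/16).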
x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
x = self.features[0](x)
x = self.features[1](x)
f1 = x
x = self.features[2](x)
x = self.features[3](x)
f2 = x
x = self.features[4](x)
x = self.features[5](x)
x = self.features[6](x)
f3 = x
x = self.features[7](x)
x = self.features[8](x)
x = self.features[9](x)
x = self.features[10](x)
x = self.features[11](x)
x = self.features[12](x)
x = self.features[13](x)
x = self.features[14](x)
x = self.features[15](x)
x = self.features[16](x)
f4 = x
return [f1, f2, f3, f4]
def forward_time_series(self, x):
B, T = x.shape[:2]
features = self.forward_single_frame(x.flatten(0, 1))
features = [f.unflatten(0, (B, T)) for f in features]
return features
def forward(self, x):
if x.ndim == 5:
return self.forward_time_series(x)
else:
return self.forward_single_frame(x)

model/model.py
@ -0,0 +1,79 @@
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Optional, List
from .mobilenetv3 import MobileNetV3LargeEncoder
from .resnet import ResNet50Encoder
from .lraspp import LRASPP
from .decoder import RecurrentDecoder, Projection
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner
class MattingNetwork(nn.Module):
def __init__(self,
variant: str = 'mobilenetv3',
refiner: str = 'deep_guided_filter',
pretrained_backbone: bool = False):
super().__init__()
assert variant in ['mobilenetv3', 'resnet50']
assert refiner in ['fast_guided_filter', 'deep_guided_filter']
if variant == 'mobilenetv3':
self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
self.aspp = LRASPP(960, 128)
self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
else:
self.backbone = ResNet50Encoder(pretrained_backbone)
self.aspp = LRASPP(2048, 256)
self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])
self.project_mat = Projection(16, 4)
self.project_seg = Projection(16, 1)
if refiner == 'deep_guided_filter':
self.refiner = DeepGuidedFilterRefiner()
else:
self.refiner = FastGuidedFilterRefiner()
def forward(self,
src: Tensor,
r1: Optional[Tensor] = None,
r2: Optional[Tensor] = None,
r3: Optional[Tensor] = None,
r4: Optional[Tensor] = None,
downsample_ratio: float = 1,
segmentation_pass: bool = False):
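        # Pipeline: optionally downsample the source, run the encoder and LR-ASPP,
        # decode with recurrent states, then output either matting results (refined
        # back to full resolution when downsampled) or a segmentation map.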
if downsample_ratio != 1:
src_sm = self._interpolate(src, scale_factor=downsample_ratio)
else:
src_sm = src
f1, f2, f3, f4 = self.backbone(src_sm)
f4 = self.aspp(f4)
hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)
if not segmentation_pass:
fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
if downsample_ratio != 1:
fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
fgr = fgr_residual + src
fgr = fgr.clamp(0., 1.)
pha = pha.clamp(0., 1.)
return [fgr, pha, *rec]
else:
seg = self.project_seg(hid)
return [seg, *rec]
def _interpolate(self, x: Tensor, scale_factor: float):
if x.ndim == 5:
B, T = x.shape[:2]
x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
mode='bilinear', align_corners=False, recompute_scale_factor=False)
x = x.unflatten(0, (B, T))
else:
x = F.interpolate(x, scale_factor=scale_factor,
mode='bilinear', align_corners=False, recompute_scale_factor=False)
return x

model/resnet.py
@ -0,0 +1,45 @@
from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck
from torchvision.models.utils import load_state_dict_from_url
class ResNet50Encoder(ResNet):
def __init__(self, pretrained: bool = False):
super().__init__(
block=Bottleneck,
layers=[3, 4, 6, 3],
replace_stride_with_dilation=[False, False, True],
norm_layer=None)
if pretrained:
self.load_state_dict(load_state_dict_from_url(
'https://download.pytorch.org/models/resnet50-0676ba61.pth'))
del self.avgpool
del self.fc
def forward_single_frame(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
f1 = x # 1/2
x = self.maxpool(x)
x = self.layer1(x)
f2 = x # 1/4
x = self.layer2(x)
f3 = x # 1/8
x = self.layer3(x)
x = self.layer4(x)
f4 = x # 1/16
return [f1, f2, f3, f4]
def forward_time_series(self, x):
B, T = x.shape[:2]
features = self.forward_single_frame(x.flatten(0, 1))
features = [f.unflatten(0, (B, T)) for f in features]
return features
def forward(self, x):
if x.ndim == 5:
return self.forward_time_series(x)
else:
return self.forward_single_frame(x)

requirements.txt
@ -0,0 +1,3 @@
torch==1.8.1
torchvision==0.9.1
coremltools==5.0b1