commit b4850905347f4fcc588b5f7ed7cbfd34ae206436
Author: Peter Lin
Date:   Fri Sep 17 00:00:37 2021 -0700

    Add CoreML export code

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..11448bd
--- /dev/null
+++ b/README.md
@@ -0,0 +1,29 @@
+# Export CoreML
+
+## Overview
+
+This branch contains our code for exporting CoreML models. The `/model` folder is the same as the `master` branch. The main export logic is in `export_coreml.py`.
+
+At the time of this writing, CoreML's `ResizeBilinear` and `Upsample` ops do not support dynamic scale parameters, so the `downsample_ratio` hyperparameter must be hardcoded.
+
+Our export script produces models with a fixed input size. The output CoreML models require iOS 14+ or macOS 11+. If you have other requirements, feel free to modify the export script. Contributions are welcome.
+
+## Export Yourself
+
+The following procedure was used to generate our CoreML models.
+
+1. Install dependencies:
+```sh
+pip install -r requirements.txt
+```
+
+2. Use the export script. You can change the `resolution` and `downsample-ratio` arguments to fit your needs. You can set the quantization to one of `[8, 16, 32]`, denoting `int8`, `fp16`, and `fp32` respectively.
+```sh
+python export_coreml.py \
+    --model-variant mobilenetv3 \
+    --checkpoint rvm_mobilenetv3.pth \
+    --resolution 1920 1080 \
+    --downsample-ratio 0.25 \
+    --quantize-nbits 16 \
+    --output model.mlmodel
+```
\ No newline at end of file
diff --git a/export_coreml.py b/export_coreml.py
new file mode 100644
index 0000000..953b214
--- /dev/null
+++ b/export_coreml.py
@@ -0,0 +1,192 @@
+"""
+python export_coreml.py \
+    --model-variant mobilenetv3 \
+    --checkpoint rvm_mobilenetv3.pth \
+    --resolution 1920 1080 \
+    --downsample-ratio 0.25 \
+    --quantize-nbits 16 \
+    --output model.mlmodel
+"""
+
+
+import argparse
+import coremltools as ct
+import torch
+
+from coremltools.models.neural_network.quantization_utils import quantize_weights
+from coremltools.converters.mil.mil import Builder as mb
+from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op
+from coremltools.converters.mil.frontend.torch.ops import _get_inputs
+from coremltools.proto import FeatureTypes_pb2 as ft
+
+from model import MattingNetwork
+
+class Exporter:
+    def __init__(self):
+        self.parse_args()
+        self.init_model()
+        self.register_custom_ops()
+        self.export()
+
+    def parse_args(self):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
+        parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
+        parser.add_argument('--checkpoint', type=str, required=False)
+        parser.add_argument('--resolution', type=int, required=True, nargs=2)
+        parser.add_argument('--downsample-ratio', type=float, required=True)
+        parser.add_argument('--quantize-nbits', type=int, required=True, choices=[8, 16, 32])
+        parser.add_argument('--output', type=str, required=True)
+        self.args = parser.parse_args()
+
+    def init_model(self):
+        downsample_ratio = self.args.downsample_ratio
+        class Wrapper(MattingNetwork):
+            def forward(self, src, r1=None, r2=None, r3=None, r4=None):
+                # Hardcode downsample_ratio into the network instead of taking it as input. This is needed for TorchScript tracing.
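+                # Changing downsample_ratio therefore requires re-exporting the model with a different --downsample-ratio value.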
+ # Also, we are multiply result by 255 to convert them to CoreML image format + fgr, pha, r1, r2, r3, r4 = super().forward(src, r1, r2, r3, r4, downsample_ratio) + return fgr.mul(255), pha.mul(255), r1, r2, r3, r4 + + self.model = Wrapper(self.args.model_variant, self.args.model_refiner).eval() + if self.args.checkpoint is not None: + self.model.load_state_dict(torch.load(self.args.checkpoint, map_location='cpu'), strict=False) + + def register_custom_ops(self): + @register_torch_op(override=True) + def hardswish_(context, node): + inputs = _get_inputs(context, node, expected=1) + x = inputs[0] + y = mb.sigmoid_hard(x=inputs[0], alpha=1.0/6, beta=0.5) + z = mb.mul(x=x, y=y, name=node.name) + context.add(z) + + @register_torch_op(override=True) + def hardsigmoid_(context, node): + inputs = _get_inputs(context, node, expected=1) + res = mb.sigmoid_hard(x=inputs[0], alpha=1.0/6, beta=0.5, name=node.name) + context.add(res) + + @register_torch_op(override=True) + def type_as(context, node): + inputs = _get_inputs(context, node) + context.add(mb.cast(x=inputs[0], dtype='fp32'), node.name) + + @register_torch_op(override=True) + def upsample_bilinear2d(context, node): + # Change to use `resize_bilinear` instead to support iOS 13. + inputs = _get_inputs(context, node) + x = inputs[0] + output_size = inputs[1] + align_corners = bool(inputs[2].val) + scale_factors = inputs[3] + + if scale_factors is not None and scale_factors.val is not None \ + and scale_factors.rank == 1 and scale_factors.shape[0] == 2: + scale_factors = scale_factors.val + resize = mb.resize_bilinear( + x=x, + target_size_height=int(x.shape[-2] * scale_factors[0]), + target_size_width=int(x.shape[-1] * scale_factors[1]), + sampling_mode='ALIGN_CORNERS', + name=node.name, + ) + context.add(resize) + else: + resize = mb.resize_bilinear( + x=x, + target_size_height=output_size.val[0], + target_size_width=output_size.val[1], + sampling_mode='ALIGN_CORNERS', + name=node.name, + ) + context.add(resize) + + def export(self): + src = torch.zeros([1, 3, *self.args.resolution[::-1]]) + _, _, r1, r2, r3, r4 = self.model(src) + + model_traced = torch.jit.trace(self.model, (src, r1, r2, r3, r4)) + model_coreml = ct.convert( + model_traced, + inputs=[ + ct.ImageType(name='src', shape=(ct.RangeDim(), *src.shape[1:]), channel_first=True, scale=1/255), + ct.TensorType(name='r1i', shape=(ct.RangeDim(), *r1.shape[1:])), + ct.TensorType(name='r2i', shape=(ct.RangeDim(), *r2.shape[1:])), + ct.TensorType(name='r3i', shape=(ct.RangeDim(), *r3.shape[1:])), + ct.TensorType(name='r4i', shape=(ct.RangeDim(), *r4.shape[1:])), + ], + ) + + if self.args.quantize_nbits in [8, 16]: + out = quantize_weights(model_coreml, nbits=self.args.quantize_nbits) + if isinstance(out, ct.models.model.MLModel): + # When the export is done on OSX, return is an mlmodel. + spec = out.get_spec() + else: + # When the export is done on Linux, the return is a spec. + spec = out + else: + spec = model_coreml.get_spec() + + # Some internal outputs are also named 'fgr' and 'pha'. + # We change them to avoid conflicts. 
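+        # (The model's final outputs are renamed to 'fgr' and 'pha' further below.)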
+ for layer in spec.neuralNetwork.layers: + for i in range(len(layer.input)): + if layer.input[i] == 'fgr': + layer.input[i] = 'fgr_internal' + if layer.input[i] == 'pha': + layer.input[i] = 'pha_internal' + for i in range(len(layer.output)): + if layer.output[i] == 'fgr': + layer.output[i] = 'fgr_internal' + if layer.output[i] == 'pha': + layer.output[i] = 'pha_internal' + + # Update output names + ct.utils.rename_feature(spec, spec.description.output[0].name, 'fgr') + ct.utils.rename_feature(spec, spec.description.output[1].name, 'pha') + ct.utils.rename_feature(spec, spec.description.output[2].name, 'r1o') + ct.utils.rename_feature(spec, spec.description.output[3].name, 'r2o') + ct.utils.rename_feature(spec, spec.description.output[4].name, 'r3o') + ct.utils.rename_feature(spec, spec.description.output[5].name, 'r4o') + + # Update model description + spec.description.metadata.author = 'Shanchuan Lin' + spec.description.metadata.shortDescription = 'A robust human video matting model with recurrent architecture. The model has recurrent states that must be passed to subsequent frames. Please refer to paper "Robust High-Resolution Video Matting with Temporal Guidance" for more details.' + spec.description.metadata.license = 'Apache License 2.0' + spec.description.metadata.versionString = '1.0.0' + spec.description.input[0].shortDescription = 'Source frame' + spec.description.input[1].shortDescription = 'Recurrent state 1. Initial state is an all zero tensor. Subsequent state is received from r1o.' + spec.description.input[2].shortDescription = 'Recurrent state 2. Initial state is an all zero tensor. Subsequent state is received from r2o.' + spec.description.input[3].shortDescription = 'Recurrent state 3. Initial state is an all zero tensor. Subsequent state is received from r3o.' + spec.description.input[4].shortDescription = 'Recurrent state 4. Initial state is an all zero tensor. Subsequent state is received from r4o.' + spec.description.output[0].shortDescription = 'Foreground prediction' + spec.description.output[1].shortDescription = 'Alpha prediction' + spec.description.output[2].shortDescription = 'Recurrent state 1. Needs to be passed as r1i input in the next time step.' + spec.description.output[3].shortDescription = 'Recurrent state 2. Needs to be passed as r2i input in the next time step.' + spec.description.output[4].shortDescription = 'Recurrent state 3. Needs to be passed as r3i input in the next time step.' + spec.description.output[5].shortDescription = 'Recurrent state 4. Needs to be passed as r4i input in the next time step.' 
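+
+        # The Wrapper scales fgr and pha to the 0-255 range, so they are exposed as image outputs below.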
+ + # Update output types + spec.description.output[0].type.imageType.colorSpace = ft.ImageFeatureType.RGB + spec.description.output[0].type.imageType.width = src.size(3) + spec.description.output[0].type.imageType.height = src.size(2) + spec.description.output[1].type.imageType.colorSpace = ft.ImageFeatureType.GRAYSCALE + spec.description.output[1].type.imageType.width = src.size(3) + spec.description.output[1].type.imageType.height = src.size(2) + + # Set recurrent states as optional inputs + spec.description.input[1].type.isOptional = True + spec.description.input[2].type.isOptional = True + spec.description.input[3].type.isOptional = True + spec.description.input[4].type.isOptional = True + + # Save output + ct.utils.save_spec(spec, self.args.output) + + + + +if __name__ == '__main__': + Exporter() \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..ac047a1 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1 @@ +from .model import MattingNetwork \ No newline at end of file diff --git a/model/decoder.py b/model/decoder.py new file mode 100644 index 0000000..b596acf --- /dev/null +++ b/model/decoder.py @@ -0,0 +1,217 @@ +import torch +from torch import Tensor +from torch import nn +from torch.nn import functional as F +from typing import Tuple, Optional + +class RecurrentDecoder(nn.Module): + def __init__(self, feature_channels, decoder_channels): + super().__init__() + self.avgpool = AvgPool() + self.decode4 = BottleneckBlock(feature_channels[3]) + self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0]) + self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1]) + self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2]) + self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3]) + + def forward(self, + s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor, + r1: Optional[Tensor], r2: Optional[Tensor], + r3: Optional[Tensor], r4: Optional[Tensor]): + s1, s2, s3 = self.avgpool(s0) + x4, r4 = self.decode4(f4, r4) + x3, r3 = self.decode3(x4, f3, s3, r3) + x2, r2 = self.decode2(x3, f2, s2, r2) + x1, r1 = self.decode1(x2, f1, s1, r1) + x0 = self.decode0(x1, s0) + return x0, r1, r2, r3, r4 + + +class AvgPool(nn.Module): + def __init__(self): + super().__init__() + self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True) + + def forward_single_frame(self, s0): + s1 = self.avgpool(s0) + s2 = self.avgpool(s1) + s3 = self.avgpool(s2) + return s1, s2, s3 + + def forward_time_series(self, s0): + B, T = s0.shape[:2] + s0 = s0.flatten(0, 1) + s1, s2, s3 = self.forward_single_frame(s0) + s1 = s1.unflatten(0, (B, T)) + s2 = s2.unflatten(0, (B, T)) + s3 = s3.unflatten(0, (B, T)) + return s1, s2, s3 + + def forward(self, s0): + if s0.ndim == 5: + return self.forward_time_series(s0) + else: + return self.forward_single_frame(s0) + + +class BottleneckBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.gru = ConvGRU(channels // 2) + + def forward(self, x, r: Optional[Tensor]): + a, b = x.split(self.channels // 2, dim=-3) + b, r = self.gru(b, r) + x = torch.cat([a, b], dim=-3) + return x, r + + +class UpsamplingBlock(nn.Module): + def __init__(self, in_channels, skip_channels, src_channels, out_channels): + super().__init__() + self.out_channels = out_channels + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + 
self.conv = nn.Sequential( + nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(True), + ) + self.gru = ConvGRU(out_channels // 2) + + def forward_single_frame(self, x, f, s, r: Optional[Tensor]): + x = self.upsample(x) + # Optimized for CoreML export + if x.size(2) != s.size(2): + x = x[:, :, :s.size(2), :] + if x.size(3) != s.size(3): + x = x[:, :, :, :s.size(3)] + x = torch.cat([x, f, s], dim=1) + x = self.conv(x) + a, b = x.split(self.out_channels // 2, dim=1) + b, r = self.gru(b, r) + x = torch.cat([a, b], dim=1) + return x, r + + def forward_time_series(self, x, f, s, r: Optional[Tensor]): + B, T, _, H, W = s.shape + x = x.flatten(0, 1) + f = f.flatten(0, 1) + s = s.flatten(0, 1) + x = self.upsample(x) + x = x[:, :, :H, :W] + x = torch.cat([x, f, s], dim=1) + x = self.conv(x) + x = x.unflatten(0, (B, T)) + a, b = x.split(self.out_channels // 2, dim=2) + b, r = self.gru(b, r) + x = torch.cat([a, b], dim=2) + return x, r + + def forward(self, x, f, s, r: Optional[Tensor]): + if x.ndim == 5: + return self.forward_time_series(x, f, s, r) + else: + return self.forward_single_frame(x, f, s, r) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, src_channels, out_channels): + super().__init__() + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.conv = nn.Sequential( + nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(True), + nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(True), + ) + + def forward_single_frame(self, x, s): + x = self.upsample(x) + if x.size(2) != s.size(2): + x = x[:, :, :s.size(2), :] + if x.size(3) != s.size(3): + x = x[:, :, :, :s.size(3)] + x = torch.cat([x, s], dim=1) + x = self.conv(x) + return x + + def forward_time_series(self, x, s): + B, T, _, H, W = s.shape + x = x.flatten(0, 1) + s = s.flatten(0, 1) + x = self.upsample(x) + x = x[:, :, :H, :W] + x = torch.cat([x, s], dim=1) + x = self.conv(x) + x = x.unflatten(0, (B, T)) + return x + + def forward(self, x, s): + if x.ndim == 5: + return self.forward_time_series(x, s) + else: + return self.forward_single_frame(x, s) + + +class ConvGRU(nn.Module): + def __init__(self, + channels: int, + kernel_size: int = 3, + padding: int = 1): + super().__init__() + self.channels = channels + self.ih = nn.Sequential( + nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding), + nn.Sigmoid() + ) + self.hh = nn.Sequential( + nn.Conv2d(channels * 2, channels, kernel_size, padding=padding), + nn.Tanh() + ) + + def forward_single_frame(self, x, h): + r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1) + c = self.hh(torch.cat([x, r * h], dim=1)) + h = (1 - z) * h + z * c + return h, h + + def forward_time_series(self, x, h): + o = [] + for xt in x.unbind(dim=1): + ot, h = self.forward_single_frame(xt, h) + o.append(ot) + o = torch.stack(o, dim=1) + return o, h + + def forward(self, x, h: Optional[Tensor]): + if h is None: + h = torch.zeros((x.size(0), x.size(-3), x.size(-2), x.size(-1)), + device=x.device, dtype=x.dtype) + + if x.ndim == 5: + return self.forward_time_series(x, h) + else: + return self.forward_single_frame(x, h) + + +class Projection(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, 1) + + def forward_single_frame(self, x): + return self.conv(x) + + def 
forward_time_series(self, x): + B, T = x.shape[:2] + return self.conv(x.flatten(0, 1)).unflatten(0, (B, T)) + + def forward(self, x): + if x.ndim == 5: + return self.forward_time_series(x) + else: + return self.forward_single_frame(x) + \ No newline at end of file diff --git a/model/deep_guided_filter.py b/model/deep_guided_filter.py new file mode 100644 index 0000000..a24b8c5 --- /dev/null +++ b/model/deep_guided_filter.py @@ -0,0 +1,61 @@ +import torch +from torch import nn +from torch.nn import functional as F + +""" +Adopted from +""" + +class DeepGuidedFilterRefiner(nn.Module): + def __init__(self, hid_channels=16): + super().__init__() + self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4) + self.box_filter.weight.data[...] = 1 / 9 + self.conv = nn.Sequential( + nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(hid_channels), + nn.ReLU(True), + nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(hid_channels), + nn.ReLU(True), + nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True) + ) + + def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid): + fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1) + base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1) + base_y = torch.cat([base_fgr, base_pha], dim=1) + + mean_x = self.box_filter(base_x) + mean_y = self.box_filter(base_y) + cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y + var_x = self.box_filter(base_x * base_x) - mean_x * mean_x + + A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1)) + b = mean_y - A * mean_x + + H, W = fine_src.shape[2:] + A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False) + b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False) + + out = A * fine_x + b + fgr, pha = out.split([3, 1], dim=1) + return fgr, pha + + def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid): + B, T = fine_src.shape[:2] + fgr, pha = self.forward_single_frame( + fine_src.flatten(0, 1), + base_src.flatten(0, 1), + base_fgr.flatten(0, 1), + base_pha.flatten(0, 1), + base_hid.flatten(0, 1)) + fgr = fgr.unflatten(0, (B, T)) + pha = pha.unflatten(0, (B, T)) + return fgr, pha + + def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid): + if fine_src.ndim == 5: + return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid) + else: + return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid) diff --git a/model/fast_guided_filter.py b/model/fast_guided_filter.py new file mode 100644 index 0000000..df9b4b2 --- /dev/null +++ b/model/fast_guided_filter.py @@ -0,0 +1,76 @@ +import torch +from torch import nn +from torch.nn import functional as F + +""" +Adopted from +""" + +class FastGuidedFilterRefiner(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.guilded_filter = FastGuidedFilter(1) + + def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha): + fine_src_gray = fine_src.mean(1, keepdim=True) + base_src_gray = base_src.mean(1, keepdim=True) + + fgr, pha = self.guilded_filter( + torch.cat([base_src, base_src_gray], dim=1), + torch.cat([base_fgr, base_pha], dim=1), + torch.cat([fine_src, fine_src_gray], dim=1)).split([3, 1], dim=1) + + return fgr, pha + + def forward_time_series(self, fine_src, base_src, base_fgr, base_pha): + B, T = fine_src.shape[:2] + fgr, pha = self.forward_single_frame( + fine_src.flatten(0, 1), + 
base_src.flatten(0, 1), + base_fgr.flatten(0, 1), + base_pha.flatten(0, 1)) + fgr = fgr.unflatten(0, (B, T)) + pha = pha.unflatten(0, (B, T)) + return fgr, pha + + def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid): + if fine_src.ndim == 5: + return self.forward_time_series(fine_src, base_src, base_fgr, base_pha) + else: + return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha) + + +class FastGuidedFilter(nn.Module): + def __init__(self, r: int, eps: float = 1e-5): + super().__init__() + self.r = r + self.eps = eps + self.boxfilter = BoxFilter(r) + + def forward(self, lr_x, lr_y, hr_x): + mean_x = self.boxfilter(lr_x) + mean_y = self.boxfilter(lr_y) + cov_xy = self.boxfilter(lr_x * lr_y) - mean_x * mean_y + var_x = self.boxfilter(lr_x * lr_x) - mean_x * mean_x + A = cov_xy / (var_x + self.eps) + b = mean_y - A * mean_x + A = F.interpolate(A, hr_x.shape[2:], mode='bilinear', align_corners=False) + b = F.interpolate(b, hr_x.shape[2:], mode='bilinear', align_corners=False) + return A * hr_x + b + + +class BoxFilter(nn.Module): + def __init__(self, r): + super(BoxFilter, self).__init__() + self.r = r + + def forward(self, x): + # Note: The original implementation at + # uses faster box blur. However, it may not be friendly for ONNX export. + # We are switching to use simple convolution for box blur. + kernel_size = 2 * self.r + 1 + kernel_x = torch.full((x.data.shape[1], 1, 1, kernel_size), 1 / kernel_size, device=x.device, dtype=x.dtype) + kernel_y = torch.full((x.data.shape[1], 1, kernel_size, 1), 1 / kernel_size, device=x.device, dtype=x.dtype) + x = F.conv2d(x, kernel_x, padding=(0, self.r), groups=x.data.shape[1]) + x = F.conv2d(x, kernel_y, padding=(self.r, 0), groups=x.data.shape[1]) + return x \ No newline at end of file diff --git a/model/lraspp.py b/model/lraspp.py new file mode 100644 index 0000000..5fc7079 --- /dev/null +++ b/model/lraspp.py @@ -0,0 +1,29 @@ +from torch import nn + +class LRASPP(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.aspp1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(True) + ) + self.aspp2 = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.Sigmoid() + ) + + def forward_single_frame(self, x): + return self.aspp1(x) * self.aspp2(x) + + def forward_time_series(self, x): + B, T = x.shape[:2] + x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T)) + return x + + def forward(self, x): + if x.ndim == 5: + return self.forward_time_series(x) + else: + return self.forward_single_frame(x) \ No newline at end of file diff --git a/model/mobilenetv3.py b/model/mobilenetv3.py new file mode 100644 index 0000000..5f4b082 --- /dev/null +++ b/model/mobilenetv3.py @@ -0,0 +1,72 @@ +from torch import nn +from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig +from torchvision.models.utils import load_state_dict_from_url +from torchvision.transforms.functional import normalize + +class MobileNetV3LargeEncoder(MobileNetV3): + def __init__(self, pretrained: bool = False): + super().__init__( + inverted_residual_setting=[ + InvertedResidualConfig( 16, 3, 16, 16, False, "RE", 1, 1, 1), + InvertedResidualConfig( 16, 3, 64, 24, False, "RE", 2, 1, 1), # C1 + InvertedResidualConfig( 24, 3, 72, 24, False, "RE", 1, 1, 1), + InvertedResidualConfig( 24, 5, 72, 40, True, "RE", 2, 1, 1), # C2 + InvertedResidualConfig( 40, 5, 120, 40, True, "RE", 1, 1, 1), + 
InvertedResidualConfig( 40, 5, 120, 40, True, "RE", 1, 1, 1), + InvertedResidualConfig( 40, 3, 240, 80, False, "HS", 2, 1, 1), # C3 + InvertedResidualConfig( 80, 3, 200, 80, False, "HS", 1, 1, 1), + InvertedResidualConfig( 80, 3, 184, 80, False, "HS", 1, 1, 1), + InvertedResidualConfig( 80, 3, 184, 80, False, "HS", 1, 1, 1), + InvertedResidualConfig( 80, 3, 480, 112, True, "HS", 1, 1, 1), + InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1, 1, 1), + InvertedResidualConfig(112, 5, 672, 160, True, "HS", 2, 2, 1), # C4 + InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1), + InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1), + ], + last_channel=1280 + ) + + if pretrained: + self.load_state_dict(load_state_dict_from_url( + 'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')) + + del self.avgpool + del self.classifier + + def forward_single_frame(self, x): + x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + + x = self.features[0](x) + x = self.features[1](x) + f1 = x + x = self.features[2](x) + x = self.features[3](x) + f2 = x + x = self.features[4](x) + x = self.features[5](x) + x = self.features[6](x) + f3 = x + x = self.features[7](x) + x = self.features[8](x) + x = self.features[9](x) + x = self.features[10](x) + x = self.features[11](x) + x = self.features[12](x) + x = self.features[13](x) + x = self.features[14](x) + x = self.features[15](x) + x = self.features[16](x) + f4 = x + return [f1, f2, f3, f4] + + def forward_time_series(self, x): + B, T = x.shape[:2] + features = self.forward_single_frame(x.flatten(0, 1)) + features = [f.unflatten(0, (B, T)) for f in features] + return features + + def forward(self, x): + if x.ndim == 5: + return self.forward_time_series(x) + else: + return self.forward_single_frame(x) diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000..71fc684 --- /dev/null +++ b/model/model.py @@ -0,0 +1,79 @@ +import torch +from torch import Tensor +from torch import nn +from torch.nn import functional as F +from typing import Optional, List + +from .mobilenetv3 import MobileNetV3LargeEncoder +from .resnet import ResNet50Encoder +from .lraspp import LRASPP +from .decoder import RecurrentDecoder, Projection +from .fast_guided_filter import FastGuidedFilterRefiner +from .deep_guided_filter import DeepGuidedFilterRefiner + +class MattingNetwork(nn.Module): + def __init__(self, + variant: str = 'mobilenetv3', + refiner: str = 'deep_guided_filter', + pretrained_backbone: bool = False): + super().__init__() + assert variant in ['mobilenetv3', 'resnet50'] + assert refiner in ['fast_guided_filter', 'deep_guided_filter'] + + if variant == 'mobilenetv3': + self.backbone = MobileNetV3LargeEncoder(pretrained_backbone) + self.aspp = LRASPP(960, 128) + self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16]) + else: + self.backbone = ResNet50Encoder(pretrained_backbone) + self.aspp = LRASPP(2048, 256) + self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16]) + + self.project_mat = Projection(16, 4) + self.project_seg = Projection(16, 1) + + if refiner == 'deep_guided_filter': + self.refiner = DeepGuidedFilterRefiner() + else: + self.refiner = FastGuidedFilterRefiner() + + def forward(self, + src: Tensor, + r1: Optional[Tensor] = None, + r2: Optional[Tensor] = None, + r3: Optional[Tensor] = None, + r4: Optional[Tensor] = None, + downsample_ratio: float = 1, + segmentation_pass: bool = False): + + if downsample_ratio != 1: + src_sm = self._interpolate(src, 
scale_factor=downsample_ratio) + else: + src_sm = src + + f1, f2, f3, f4 = self.backbone(src_sm) + f4 = self.aspp(f4) + hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4) + + if not segmentation_pass: + fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3) + if downsample_ratio != 1: + fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid) + fgr = fgr_residual + src + fgr = fgr.clamp(0., 1.) + pha = pha.clamp(0., 1.) + return [fgr, pha, *rec] + else: + seg = self.project_seg(hid) + return [seg, *rec] + + def _interpolate(self, x: Tensor, scale_factor: float): + if x.ndim == 5: + B, T = x.shape[:2] + x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor, + mode='bilinear', align_corners=False, recompute_scale_factor=False) + x = x.unflatten(0, (B, T)) + else: + x = F.interpolate(x, scale_factor=scale_factor, + mode='bilinear', align_corners=False, recompute_scale_factor=False) + return x diff --git a/model/resnet.py b/model/resnet.py new file mode 100644 index 0000000..7634fcb --- /dev/null +++ b/model/resnet.py @@ -0,0 +1,45 @@ +from torch import nn +from torchvision.models.resnet import ResNet, Bottleneck +from torchvision.models.utils import load_state_dict_from_url + +class ResNet50Encoder(ResNet): + def __init__(self, pretrained: bool = False): + super().__init__( + block=Bottleneck, + layers=[3, 4, 6, 3], + replace_stride_with_dilation=[False, False, True], + norm_layer=None) + + if pretrained: + self.load_state_dict(load_state_dict_from_url( + 'https://download.pytorch.org/models/resnet50-0676ba61.pth')) + + del self.avgpool + del self.fc + + def forward_single_frame(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + f1 = x # 1/2 + x = self.maxpool(x) + x = self.layer1(x) + f2 = x # 1/4 + x = self.layer2(x) + f3 = x # 1/8 + x = self.layer3(x) + x = self.layer4(x) + f4 = x # 1/16 + return [f1, f2, f3, f4] + + def forward_time_series(self, x): + B, T = x.shape[:2] + features = self.forward_single_frame(x.flatten(0, 1)) + features = [f.unflatten(0, (B, T)) for f in features] + return features + + def forward(self, x): + if x.ndim == 5: + return self.forward_time_series(x) + else: + return self.forward_single_frame(x) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..df3f2a5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +torch==1.8.1 +torchvision==0.9.1 +coremltools==5.0b1 \ No newline at end of file
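For reference, the exported model is meant to be driven frame by frame, feeding each recurrent output (`r1o`–`r4o`) back in as the corresponding input (`r1i`–`r4i`), as the spec descriptions above state. Below is a minimal sketch of that loop, assuming macOS with `coremltools` installed, a `model.mlmodel` exported with the README's example settings, and placeholder frame paths and counts:

```python
import coremltools as ct
from PIL import Image

model = ct.models.MLModel('model.mlmodel')   # path from the README example

states = {}  # r1i-r4i are optional; when omitted, the initial states are presumably all zeros
for i in range(100):                          # placeholder frame count
    # Frames must match the --resolution used at export time (width 1920, height 1080 here).
    frame = Image.open(f'frames/{i:04d}.png').resize((1920, 1080))
    out = model.predict({'src': frame, **states})
    fgr, pha = out['fgr'], out['pha']         # RGB foreground and grayscale alpha images
    # Pass the recurrent states to the next frame.
    states = {'r1i': out['r1o'], 'r2i': out['r2o'], 'r3i': out['r3o'], 'r4i': out['r4o']}
    pha.save(f'output/pha_{i:04d}.png')       # placeholder output path
```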