You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
6.9 KiB
Python
210 lines
6.9 KiB
Python
import torch
|
|
from torch import Tensor
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from typing import Tuple, Optional
|
|
|
|
class RecurrentDecoder(nn.Module):
    """Multi-stage decoder that threads a recurrent state through each stage.

    The decoder is a chain of one ``BottleneckBlock``, three
    ``UpsamplingBlock`` stages and a final ``OutputBlock``. Every stage except
    the last one owns a ``ConvGRU`` and therefore consumes and produces a
    recurrent state.

    Args:
        feature_channels: Indexable holding the channel counts of the four
            feature inputs ``f1..f4`` (indices 0..3).
        decoder_channels: Indexable holding the output channel counts of the
            three upsampling stages (indices 0..2) and of the output block
            (index 3).
    """

    def __init__(self, feature_channels, decoder_channels):
        super().__init__()
        # The source input concatenated at every stage is hard-coded to
        # 3 channels, exactly as in the individual constructor calls below.
        src_channels = 3
        f1_ch = feature_channels[0]
        f2_ch = feature_channels[1]
        f3_ch = feature_channels[2]
        f4_ch = feature_channels[3]
        d3_ch = decoder_channels[0]
        d2_ch = decoder_channels[1]
        d1_ch = decoder_channels[2]
        d0_ch = decoder_channels[3]

        self.avgpool = AvgPool()
        self.decode4 = BottleneckBlock(f4_ch)
        self.decode3 = UpsamplingBlock(f4_ch, f3_ch, src_channels, d3_ch)
        self.decode2 = UpsamplingBlock(d3_ch, f2_ch, src_channels, d2_ch)
        self.decode1 = UpsamplingBlock(d2_ch, f1_ch, src_channels, d1_ch)
        self.decode0 = OutputBlock(d1_ch, src_channels, d0_ch)

    def forward(self,
                s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
                r1: Optional[Tensor], r2: Optional[Tensor],
                r3: Optional[Tensor], r4: Optional[Tensor]):
        """Decode the features from the deepest (``f4``) to the shallowest stage.

        Args:
            s0: Source tensor. It is average-pooled three times to obtain the
                inputs ``s1``, ``s2``, ``s3`` of the upsampling stages and is
                used directly by the output block.
            f1, f2, f3, f4: Feature tensors, one per decoder stage.
            r1, r2, r3, r4: Recurrent states of the matching stages, or
                ``None`` (``ConvGRU`` then starts from an all-zero state).

        Returns:
            Tuple ``(x0, r1, r2, r3, r4)``: the output block result followed
            by the updated recurrent states.
        """
        src1, src2, src3 = self.avgpool(s0)

        # Walk the stages from the bottleneck up to full resolution; each
        # stage returns its output together with its new recurrent state.
        out4, state4 = self.decode4(f4, r4)
        out3, state3 = self.decode3(out4, f3, src3, r3)
        out2, state2 = self.decode2(out3, f2, src2, r2)
        out1, state1 = self.decode1(out2, f1, src1, r1)
        out0 = self.decode0(out1, s0)

        return out0, state1, state2, state3, state4
|
|
|
|
|
|
class AvgPool(nn.Module):
    """Produce three successively average-pooled versions of an input.

    A single 2x2, stride-2 average pooling layer (``ceil_mode=True``,
    ``count_include_pad=False``) is applied three times in a row.
    """

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AvgPool2d(
            kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)

    def forward_single_frame(self, s0):
        """Pool ``s0`` three times and return every intermediate result."""
        once = self.avgpool(s0)
        twice = self.avgpool(once)
        thrice = self.avgpool(twice)
        return once, twice, thrice

    def forward_time_series(self, s0):
        """Pool a 5-D input by merging its two leading dims during pooling.

        The first two dims are flattened into one, the frames are pooled,
        and the two leading dims are restored on each of the three results.
        """
        batch, steps = s0.shape[:2]
        once, twice, thrice = self.forward_single_frame(s0.flatten(0, 1))
        once = once.unflatten(0, (batch, steps))
        twice = twice.unflatten(0, (batch, steps))
        thrice = thrice.unflatten(0, (batch, steps))
        return once, twice, thrice

    def forward(self, s0):
        """Dispatch on rank: 5-D inputs are time series, anything else a frame."""
        if s0.ndim != 5:
            return self.forward_single_frame(s0)
        return self.forward_time_series(s0)
|
|
|
|
|
|
class BottleneckBlock(nn.Module):
    """Run a ``ConvGRU`` over one half of the channels, pass the other half through.

    Args:
        channels: Number of channels of the input; the GRU is built for
            ``channels // 2`` channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.gru = ConvGRU(channels // 2)

    def forward(self, x, r: Optional[Tensor]):
        """Update the recurrent half of ``x``.

        Args:
            x: Input tensor; the channel axis is addressed as dim ``-3`` so
                the same code serves inputs with or without a time dim.
            r: Recurrent state for the GRU, or ``None``.

        Returns:
            Tuple of the re-assembled tensor and the new recurrent state.
        """
        chunk = self.channels // 2
        # First chunk bypasses the GRU, second chunk is fed through it.
        untouched, recurrent = x.split(chunk, dim=-3)
        recurrent, state = self.gru(recurrent, r)
        merged = torch.cat([untouched, recurrent], dim=-3)
        return merged, state
|
|
|
|
|
|
class UpsamplingBlock(nn.Module):
    """Upsample by 2, fuse with two extra inputs, then apply a half-channel GRU.

    The upsampled input is concatenated with ``f`` and ``s`` along the
    channel axis, reduced to ``out_channels`` by conv + batch-norm + ReLU,
    and the second half of the resulting channels is passed through a
    ``ConvGRU``.

    Args:
        in_channels: Channels of the tensor being upsampled.
        skip_channels: Channels of ``f``.
        src_channels: Channels of ``s``.
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels, skip_channels, src_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        fused_channels = in_channels + skip_channels + src_channels
        self.conv = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        self.gru = ConvGRU(out_channels // 2)

    def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
        """Process 4-D inputs; returns ``(output, new_state)``."""
        upsampled = self.upsample(x)
        # Crop to the spatial size of `s`: doubling always yields an even
        # size, which can exceed `s` by one row/column.
        upsampled = upsampled[:, :, :s.size(2), :s.size(3)]
        fused = self.conv(torch.cat([upsampled, f, s], dim=1))
        # Only the second half of the channels goes through the GRU.
        static_part, recurrent_part = fused.split(self.out_channels // 2, dim=1)
        recurrent_part, state = self.gru(recurrent_part, r)
        return torch.cat([static_part, recurrent_part], dim=1), state

    def forward_time_series(self, x, f, s, r: Optional[Tensor]):
        """Process 5-D inputs; returns ``(output, new_state)``.

        The two leading dims are merged for the upsample/conv part and
        restored before the GRU, which receives the 5-D tensor.
        """
        batch, steps, _, height, width = s.shape
        upsampled = self.upsample(x.flatten(0, 1))
        upsampled = upsampled[:, :, :height, :width]
        fused = torch.cat([upsampled, f.flatten(0, 1), s.flatten(0, 1)], dim=1)
        fused = self.conv(fused).unflatten(0, (batch, steps))
        # Channel axis is dim 2 once the two leading dims are restored.
        static_part, recurrent_part = fused.split(self.out_channels // 2, dim=2)
        recurrent_part, state = self.gru(recurrent_part, r)
        return torch.cat([static_part, recurrent_part], dim=2), state

    def forward(self, x, f, s, r: Optional[Tensor]):
        """Dispatch on the rank of ``x``: 5-D is a time series, else one frame."""
        if x.ndim != 5:
            return self.forward_single_frame(x, f, s, r)
        return self.forward_time_series(x, f, s, r)
|
|
|
|
|
|
class OutputBlock(nn.Module):
    """Final stage: upsample by 2, concatenate ``s``, apply two conv stages.

    Each conv stage is a 3x3 convolution without bias followed by batch
    normalisation and an in-place ReLU.

    Args:
        in_channels: Channels of the tensor being upsampled.
        src_channels: Channels of ``s``.
        out_channels: Channels produced by both conv stages.
    """

    def __init__(self, in_channels, src_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        fused_channels = in_channels + src_channels
        self.conv = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward_single_frame(self, x, s):
        """Process 4-D inputs and return the convolved tensor."""
        upsampled = self.upsample(x)
        # Crop to the spatial size of `s`: doubling always yields an even
        # size, which can exceed `s` by one row/column.
        upsampled = upsampled[:, :, :s.size(2), :s.size(3)]
        return self.conv(torch.cat([upsampled, s], dim=1))

    def forward_time_series(self, x, s):
        """Process 5-D inputs by merging the two leading dims for the convs."""
        batch, steps, _, height, width = s.shape
        upsampled = self.upsample(x.flatten(0, 1))
        upsampled = upsampled[:, :, :height, :width]
        fused = torch.cat([upsampled, s.flatten(0, 1)], dim=1)
        return self.conv(fused).unflatten(0, (batch, steps))

    def forward(self, x, s):
        """Dispatch on the rank of ``x``: 5-D is a time series, else one frame."""
        if x.ndim != 5:
            return self.forward_single_frame(x, s)
        return self.forward_time_series(x, s)
|
|
|
|
|
|
class ConvGRU(nn.Module):
    """Convolutional gated recurrent unit.

    ``ih`` computes the two sigmoid gates from the concatenated input and
    state; ``hh`` computes the tanh candidate from the input and the gated
    state.

    Args:
        channels: Channels of both the input and the hidden state.
        kernel_size: Kernel size of both convolutions.
        padding: Padding of both convolutions.
    """

    def __init__(self,
                 channels: int,
                 kernel_size: int = 3,
                 padding: int = 1):
        super().__init__()
        self.channels = channels
        stacked = channels * 2  # input and state concatenated on the channel axis
        self.ih = nn.Sequential(
            nn.Conv2d(stacked, stacked, kernel_size, padding=padding),
            nn.Sigmoid()
        )
        self.hh = nn.Sequential(
            nn.Conv2d(stacked, channels, kernel_size, padding=padding),
            nn.Tanh()
        )

    def forward_single_frame(self, x, h):
        """Advance the state by one step.

        Returns:
            The new state twice, as ``(output, state)``.
        """
        gates = self.ih(torch.cat([x, h], dim=1))
        reset, update = gates.split(self.channels, dim=1)
        candidate = self.hh(torch.cat([x, reset * h], dim=1))
        # Blend the previous state and the candidate with the update gate.
        h = (1 - update) * h + update * candidate
        return h, h

    def forward_time_series(self, x, h):
        """Step through dim 1 of ``x`` one frame at a time.

        Returns:
            Tuple of the per-step outputs stacked along dim 1 and the state
            after the last step.
        """
        outputs = []
        for frame in x.unbind(dim=1):
            step_out, h = self.forward_single_frame(frame, h)
            outputs.append(step_out)
        return torch.stack(outputs, dim=1), h

    def forward(self, x, h: Optional[Tensor]):
        """Run the GRU on a frame (4-D) or a time series (5-D).

        Args:
            x: Input tensor.
            h: Previous state, or ``None`` to start from zeros.

        Returns:
            Tuple ``(output, state)``.
        """
        if h is None:
            # The state has no time dim: take the first dim plus the last
            # three dims of `x`, which works for both 4-D and 5-D inputs.
            h = torch.zeros(
                (x.shape[0], x.shape[-3], x.shape[-2], x.shape[-1]),
                dtype=x.dtype, device=x.device)

        if x.ndim != 5:
            return self.forward_single_frame(x, h)
        return self.forward_time_series(x, h)
|
|
|
|
|
|
class Projection(nn.Module):
    """Map ``in_channels`` to ``out_channels`` with a single 1x1 convolution.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward_single_frame(self, x):
        """Apply the convolution to a 4-D input."""
        projected = self.conv(x)
        return projected

    def forward_time_series(self, x):
        """Apply the convolution to a 5-D input.

        The two leading dims are merged for the convolution and restored
        on the result.
        """
        batch, steps = x.shape[:2]
        merged = x.flatten(0, 1)
        projected = self.conv(merged)
        return projected.unflatten(0, (batch, steps))

    def forward(self, x):
        """Dispatch on rank: 5-D inputs are time series, anything else a frame."""
        if x.ndim != 5:
            return self.forward_single_frame(x)
        return self.forward_time_series(x)
|
|
|