You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # Newer torchvision releases no longer ship torchvision.models.utils;
    # torch.hub exposes the same helper.
    from torch.hub import load_state_dict_from_url
|
|
|
|
class ResNet50Encoder(ResNet):
    """ResNet-50 backbone that returns intermediate feature maps.

    The network is built from torchvision's ``ResNet`` with ``Bottleneck``
    blocks in a ``[3, 4, 6, 3]`` layout. The average-pool and fully-connected
    head are deleted after construction, so instances only serve as a
    feature extractor through :meth:`forward`.

    Args:
        pretrained: When true, download and load the ResNet-50 checkpoint
            from download.pytorch.org before the head is removed.
    """

    def __init__(self, pretrained: bool = False):
        super().__init__(
            block=Bottleneck,
            layers=[3, 4, 6, 3],
            # Dilate the last stage instead of striding it, so layer4 keeps
            # the spatial resolution of layer3.
            replace_stride_with_dilation=[False, False, True],
            norm_layer=None)

        if pretrained:
            # Resolve the download helper from torch.hub at the point of use:
            # the torchvision.models.utils re-export is gone from newer
            # torchvision releases, and the import is only needed when
            # weights are actually requested.
            from torch.hub import load_state_dict_from_url

            self.load_state_dict(load_state_dict_from_url(
                'https://download.pytorch.org/models/resnet50-0676ba61.pth'))

        # Remove the classification head only after the weights are loaded:
        # load_state_dict matches keys strictly by default, so the head must
        # still exist while the full checkpoint is being loaded.
        del self.avgpool
        del self.fc

    def forward_single_frame(self, x):
        """Run the backbone on one batch of frames.

        Args:
            x: Input passed straight to ``conv1`` (presumably a
                ``(B, C, H, W)`` image batch -- confirm against callers).

        Returns:
            ``[f1, f2, f3, f4]``: the activations after the stem, ``layer1``,
            ``layer2`` and ``layer4`` respectively.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        f1 = x  # 1/2
        x = self.maxpool(x)
        x = self.layer1(x)
        f2 = x  # 1/4
        x = self.layer2(x)
        f3 = x  # 1/8
        x = self.layer3(x)
        x = self.layer4(x)
        f4 = x  # 1/16
        return [f1, f2, f3, f4]

    def forward_time_series(self, x):
        """Run the backbone on a sequence of frames.

        The two leading dimensions (batch ``B`` and time ``T``) are merged so
        every frame goes through :meth:`forward_single_frame` in one pass,
        then each returned feature map is split back into ``(B, T, ...)``.

        Args:
            x: Input whose first two dimensions are batch and time.

        Returns:
            List of four feature maps, each with leading dimensions
            ``(B, T)``.
        """
        B, T = x.shape[:2]
        features = self.forward_single_frame(x.flatten(0, 1))
        features = [f.unflatten(0, (B, T)) for f in features]
        return features

    def forward(self, x):
        """Dispatch on input rank.

        A 5-D input is handled as a time series; any other rank is passed to
        the single-frame path unchanged.
        """
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)
|