|
|
|
@ -1,6 +1,6 @@
|
|
|
|
|
import torch
|
|
|
|
|
from torch import nn
|
|
|
|
|
from torchvision.models.resnet import ResNet, Bottleneck
|
|
|
|
|
from torchvision.models.utils import load_state_dict_from_url
|
|
|
|
|
|
|
|
|
|
class ResNet50Encoder(ResNet):
|
|
|
|
|
def __init__(self, pretrained: bool = False):
|
|
|
|
@ -11,7 +11,7 @@ class ResNet50Encoder(ResNet):
|
|
|
|
|
norm_layer=None)
|
|
|
|
|
|
|
|
|
|
if pretrained:
|
|
|
|
|
self.load_state_dict(load_state_dict_from_url(
|
|
|
|
|
self.load_state_dict(torch.hub.load_state_dict_from_url(
|
|
|
|
|
'https://download.pytorch.org/models/resnet50-0676ba61.pth'))
|
|
|
|
|
|
|
|
|
|
del self.avgpool
|
|
|
|
|