Fix issue with torchvision 0.11.0

pull/84/merge
Peter Lin 3 years ago
parent e58f7bbdf0
commit 918a97f347

@ -1,6 +1,6 @@
import torch
from torch import nn from torch import nn
from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
from torchvision.models.utils import load_state_dict_from_url
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
class MobileNetV3LargeEncoder(MobileNetV3): class MobileNetV3LargeEncoder(MobileNetV3):
@ -27,7 +27,7 @@ class MobileNetV3LargeEncoder(MobileNetV3):
) )
if pretrained: if pretrained:
self.load_state_dict(load_state_dict_from_url( self.load_state_dict(torch.hub.load_state_dict_from_url(
'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')) 'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
del self.avgpool del self.avgpool

@ -1,6 +1,6 @@
import torch
from torch import nn from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck from torchvision.models.resnet import ResNet, Bottleneck
from torchvision.models.utils import load_state_dict_from_url
class ResNet50Encoder(ResNet): class ResNet50Encoder(ResNet):
def __init__(self, pretrained: bool = False): def __init__(self, pretrained: bool = False):
@ -11,7 +11,7 @@ class ResNet50Encoder(ResNet):
norm_layer=None) norm_layer=None)
if pretrained: if pretrained:
self.load_state_dict(load_state_dict_from_url( self.load_state_dict(torch.hub.load_state_dict_from_url(
'https://download.pytorch.org/models/resnet50-0676ba61.pth')) 'https://download.pytorch.org/models/resnet50-0676ba61.pth'))
del self.avgpool del self.avgpool

Loading…
Cancel
Save