mirror of https://github.com/menyifang/DCT-Net
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
818 B
Python
31 lines
818 B
Python
2 years ago
|
import torch
|
||
|
from torch import nn
|
||
|
from .vgg import VGG
|
||
|
import os
|
||
|
from configs.paths_config import model_paths
|
||
|
|
||
|
class VggLoss(nn.Module):
|
||
|
def __init__(self):
|
||
|
super(VggLoss, self).__init__()
|
||
|
print("Loading VGG19 model from path: {}".format(model_paths["vgg"]))
|
||
|
|
||
|
self.vgg_model = VGG()
|
||
|
self.vgg_model.load_state_dict(torch.load(model_paths['vgg']))
|
||
|
self.vgg_model.cuda()
|
||
|
self.vgg_model.eval()
|
||
|
|
||
|
self.l1loss = torch.nn.L1Loss()
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
def forward(self, input_photo, output):
|
||
|
vgg_photo = self.vgg_model(input_photo)
|
||
|
vgg_output = self.vgg_model(output)
|
||
|
n, c, h, w = vgg_photo.shape
|
||
|
# h, w, c = vgg_photo.get_shape().as_list()[1:]
|
||
|
loss = self.l1loss(vgg_photo, vgg_output)
|
||
|
|
||
|
return loss
|