mirror of https://github.com/menyifang/DCT-Net
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
3.4 KiB
Python
83 lines
3.4 KiB
Python
import torch
|
|
from torch.autograd import Variable
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import optim
|
|
|
|
from PIL import Image
|
|
import os
|
|
|
|
|
|
# gram matrix and loss
|
|
class GramMatrix(nn.Module):
|
|
def forward(self, input):
|
|
b, c, h, w = input.size()
|
|
F = input.view(b, c, h * w)
|
|
G = torch.bmm(F, F.transpose(1, 2))
|
|
G.div_(h * w)
|
|
return G
|
|
|
|
|
|
class GramMSELoss(nn.Module):
|
|
def forward(self, input, target):
|
|
out = nn.MSELoss()(GramMatrix()(input), target)
|
|
return (out)
|
|
|
|
|
|
# vgg definition that conveniently let's you grab the outputs from any layer
|
|
class VGG(nn.Module):
|
|
def __init__(self, pool='max'):
|
|
super(VGG, self).__init__()
|
|
# vgg modules
|
|
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
|
|
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
|
|
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
|
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
|
|
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
|
|
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
|
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
|
self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
|
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
|
|
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
|
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
|
self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
|
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
|
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
|
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
|
self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
|
|
|
if pool == 'max':
|
|
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
|
|
elif pool == 'avg':
|
|
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
|
|
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
|
|
self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
|
|
self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
|
|
self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
|
|
|
|
|
|
def forward(self, x):
|
|
out = {}
|
|
out['r11'] = F.relu(self.conv1_1(x))
|
|
out['r12'] = F.relu(self.conv1_2(out['r11']))
|
|
out['p1'] = self.pool1(out['r12'])
|
|
out['r21'] = F.relu(self.conv2_1(out['p1']))
|
|
out['r22'] = F.relu(self.conv2_2(out['r21']))
|
|
out['p2'] = self.pool2(out['r22'])
|
|
out['r31'] = F.relu(self.conv3_1(out['p2']))
|
|
out['r32'] = F.relu(self.conv3_2(out['r31']))
|
|
out['r33'] = F.relu(self.conv3_3(out['r32']))
|
|
out['r34'] = F.relu(self.conv3_4(out['r33']))
|
|
out['p3'] = self.pool3(out['r34'])
|
|
out['r41'] = F.relu(self.conv4_1(out['p3']))
|
|
out['r42'] = F.relu(self.conv4_2(out['r41']))
|
|
out['r43'] = F.relu(self.conv4_3(out['r42']))
|
|
conv_4_4 = self.conv4_4(out['r43'])
|
|
out['r44'] = F.relu(conv_4_4)
|
|
out['p4'] = self.pool4(out['r44'])
|
|
return conv_4_4 |