import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch import optim
from torch.autograd import Variable


# Gram matrix and loss
class GramMatrix(nn.Module):
    """Compute the Gram matrix of a batch of feature maps."""

    def forward(self, input):
        """Return the per-sample channel Gram matrix, divided by ``h * w``.

        Args:
            input: tensor whose ``size()`` unpacks into ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, c, c)``.
        """
        b, c, h, w = input.size()
        # reshape() behaves like view() for contiguous tensors but also
        # accepts non-contiguous ones (e.g. transposed feature maps), on
        # which view() raises.
        features = input.reshape(b, c, h * w)
        gram = torch.bmm(features, features.transpose(1, 2))
        # The bmm result is a freshly allocated tensor, so scaling it in
        # place never touches the caller's input.
        gram.div_(h * w)
        return gram


class GramMSELoss(nn.Module):
    """Mean-squared error between the Gram matrix of ``input`` and ``target``."""

    def forward(self, input, target):
        """Return ``MSE(GramMatrix()(input), target)`` with mean reduction.

        Args:
            input: feature tensor accepted by ``GramMatrix`` (``(b, c, h, w)``).
            target: tensor compared element-wise against the Gram matrix of
                ``input``; it is used as-is (no Gram matrix is computed for it).

        Returns:
            The loss tensor produced by ``F.mse_loss``.
        """
        return F.mse_loss(GramMatrix()(input), target)


# VGG-style convolutional stack (2-2-4-4-4 convolutions with a pooling layer
# after each group). Note that forward() returns only the pre-ReLU output of
# conv4_4, not a dictionary of per-layer outputs.
class VGG(nn.Module):
    """Convolutional feature extractor with selectable pooling.

    All sixteen convolutions and five pooling layers are registered as
    attributes (``conv1_1`` ... ``conv5_4``, ``pool1`` ... ``pool5``), although
    ``forward`` only runs the network up to ``conv4_4``.
    """

    def __init__(self, pool='max'):
        """Build the layers.

        Args:
            pool: ``'max'`` for ``nn.MaxPool2d`` or ``'avg'`` for
                ``nn.AvgPool2d`` pooling layers.

        Raises:
            ValueError: if ``pool`` is neither ``'max'`` nor ``'avg'``.
        """
        super(VGG, self).__init__()
        # Validate up front: previously an unknown value left the pool
        # attributes undefined and only failed later inside forward().
        if pool == 'max':
            pool_layer = nn.MaxPool2d
        elif pool == 'avg':
            pool_layer = nn.AvgPool2d
        else:
            raise ValueError(
                "pool must be 'max' or 'avg', got {!r}".format(pool)
            )

        # vgg modules (construction order kept so that seeded initialisation
        # and state_dict key order are unchanged)
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        # conv5_* (and pool4/pool5 below) are not used by forward(); they are
        # kept so the module's attributes and state_dict keys stay the same.
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.pool1 = pool_layer(kernel_size=2, stride=2)
        self.pool2 = pool_layer(kernel_size=2, stride=2)
        self.pool3 = pool_layer(kernel_size=2, stride=2)
        self.pool4 = pool_layer(kernel_size=2, stride=2)
        self.pool5 = pool_layer(kernel_size=2, stride=2)

    def forward(self, x):
        """Run ``x`` through the network up to ``conv4_4``.

        Args:
            x: input tensor with 3 channels, as required by ``conv1_1``.

        Returns:
            The output of ``conv4_4`` *before* its ReLU is applied.
        """
        h = F.relu(self.conv1_1(x))
        h = F.relu(self.conv1_2(h))
        h = self.pool1(h)

        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        h = self.pool2(h)

        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = F.relu(self.conv3_3(h))
        h = F.relu(self.conv3_4(h))
        h = self.pool3(h)

        h = F.relu(self.conv4_1(h))
        h = F.relu(self.conv4_2(h))
        h = F.relu(self.conv4_3(h))
        # The ReLU and pool4 that used to follow were computed and then
        # discarded, so they are skipped; the returned value is unchanged.
        return self.conv4_4(h)