mirror of https://github.com/menyifang/DCT-Net
update training code
parent
d3c83d8a62
commit
2acc41f9e5
File diff suppressed because one or more lines are too long
@ -0,0 +1,45 @@
|
||||
|
||||
data_root='data'
|
||||
align_dir='raw_style_data_faces'
|
||||
|
||||
echo "STEP: start to prepare data for stylegan ..."
|
||||
cd $data_root
|
||||
if [ ! -d stylegan ]; then
|
||||
mkdir stylegan
|
||||
fi
|
||||
cd stylegan
|
||||
stylegan_data_dir=$(pwd)
|
||||
if [ ! -d "$(date +"%Y%m%d")" ]; then
|
||||
mkdir "$(date +"%Y%m%d")"
|
||||
fi
|
||||
cd "$(date +"%Y%m%d")"
|
||||
cp $align_dir . -r
|
||||
if [ -d $(echo $align_dir) ]; then
|
||||
cp $(echo $align_dir) . -r
|
||||
fi
|
||||
src_dir_sg=$(pwd)
|
||||
|
||||
cd $data_root/../source
|
||||
outdir_sg="$(echo $stylegan_data_dir)/traindata_$(echo $stylename)_256_$(date +"%m%d")"
|
||||
echo $outdir_sg
|
||||
echo $src_dir_sg
|
||||
if [ ! -d "$outdir_sg" ]; then
|
||||
python prepare_data.py --size 256 --out $outdir_sg $src_dir_sg
|
||||
fi
|
||||
echo "prepare data for stylegan finished!"
|
||||
|
||||
### train model
|
||||
#cd $data_root
|
||||
#cd stylegan
|
||||
#stylegan_data_dir=$(pwd)
|
||||
#outdir_sg="$(echo $stylegan_data_dir)/traindata_$(echo $stylename)_256_$(date +"%m%d")"
|
||||
#echo "STEP:start to train the style learner ..."
|
||||
#echo $outdir_sg
|
||||
#exp_name="ffhq_$(echo $stylename)_s256_id01_$(date +"%m%d")"
|
||||
#cd /data/vdb/qingyao/cartoon/mycode/stylegan2-pytorch
|
||||
#model_path=face_generation/experiment_stylegan/$(echo $exp_name)/models/001000.pt
|
||||
#if [ ! -f "$model_path" ]; then
|
||||
# CUDA_VISIBLE_DEVICES=6 python train_condition.py --name $exp_name --path $outdir_sg --config config/conf_server_train_condition_shell.json
|
||||
#fi
|
||||
#### [training...]
|
||||
#echo "train the style learner finished!"
|
Binary file not shown.
@ -0,0 +1,78 @@
|
||||
import oss2
|
||||
import argparse
|
||||
import cv2
|
||||
import glob
|
||||
import os
|
||||
import tqdm
|
||||
import numpy as np
|
||||
# from .utils import get_rmbg_alpha, get_img_from_url,reasonable_resize,major_detection,crop_img
|
||||
import tqdm
|
||||
import urllib
|
||||
import random
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="process remove bg result")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="Path to images.")
|
||||
parser.add_argument("--save_dir", type=str, default="", help="Path to save images.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
args.save_dir = os.path.join(args.data_dir, 'total_flip')
|
||||
form = 'single'
|
||||
|
||||
|
||||
def flipImage(image):
|
||||
new_image = cv2.flip(image, 1)
|
||||
return new_image
|
||||
|
||||
def all_file(file_dir):
|
||||
L=[]
|
||||
for root, dirs, files in os.walk(file_dir):
|
||||
for file in files:
|
||||
extend = os.path.splitext(file)[1]
|
||||
if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG':
|
||||
L.append(os.path.join(root, file))
|
||||
return L
|
||||
|
||||
|
||||
paths = all_file(args.data_dir)
|
||||
|
||||
|
||||
def process(path):
|
||||
|
||||
print(path)
|
||||
outpath = args.save_dir+path[len(args.data_dir):]
|
||||
if os.path.exists(outpath):
|
||||
return
|
||||
|
||||
sub_dir = os.path.dirname(outpath)
|
||||
# print(sub_dir)
|
||||
if not os.path.exists(sub_dir):
|
||||
os.makedirs(sub_dir,exist_ok=True)
|
||||
|
||||
img = cv2.imread(path, -1)
|
||||
h, w, c = img.shape
|
||||
if form == "pair":
|
||||
imga = img[:, :int(w / 2), :]
|
||||
imgb = img[:, int(w / 2):, :]
|
||||
imga = flipImage(imga)
|
||||
imgb = flipImage(imgb)
|
||||
res = cv2.hconcat([imga, imgb]) # 水平拼接
|
||||
|
||||
else:
|
||||
res = flipImage(img)
|
||||
|
||||
cv2.imwrite(outpath, res)
|
||||
print('save %s' % outpath)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# main(args)
|
||||
pool = Pool(100)
|
||||
rl = pool.map(process, paths)
|
||||
pool.close()
|
||||
pool.join()
|
@ -0,0 +1,101 @@
|
||||
import oss2
|
||||
import argparse
|
||||
import cv2
|
||||
import glob
|
||||
import os
|
||||
import tqdm
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import urllib
|
||||
import random
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="process remove bg result")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="Path to images.")
|
||||
parser.add_argument("--save_dir", type=str, default="", help="Path to save images.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
args.save_dir = os.path.join(args.data_dir, 'total_rotate')
|
||||
form = 'single'
|
||||
|
||||
if not os.path.exists(args.save_dir):
|
||||
os.makedirs(args.save_dir,exist_ok=True)
|
||||
|
||||
def all_file(file_dir):
|
||||
L=[]
|
||||
for root, dirs, files in os.walk(file_dir):
|
||||
for file in files:
|
||||
extend = os.path.splitext(file)[1]
|
||||
if extend == '.png' or extend == '.jpg' or extend == '.jpeg':
|
||||
L.append(os.path.join(root, file))
|
||||
return L
|
||||
|
||||
def rotateImage(image, angle):
|
||||
row,col,_ = image.shape
|
||||
center=tuple(np.array([row,col])/2)
|
||||
rot_mat = cv2.getRotationMatrix2D(center,angle,1.0)
|
||||
new_image = cv2.warpAffine(image, rot_mat, (col,row), borderMode=cv2.BORDER_REFLECT)
|
||||
return new_image
|
||||
|
||||
|
||||
paths = all_file(args.data_dir)
|
||||
|
||||
|
||||
def process(path):
|
||||
|
||||
if 'total_scale' in path:
|
||||
return
|
||||
|
||||
outpath = args.save_dir + path[len(args.data_dir):]
|
||||
sub_dir = os.path.dirname(outpath)
|
||||
if not os.path.exists(sub_dir):
|
||||
os.makedirs(sub_dir, exist_ok=True)
|
||||
|
||||
img0 = cv2.imread(path, -1)
|
||||
h, w, c = img0.shape
|
||||
img = img0[:, :, :3].copy()
|
||||
if c == 4:
|
||||
alpha = img0[:, :, 3]
|
||||
mask = alpha[:, :, np.newaxis].copy() / 255.
|
||||
img = (img * mask + (1 - mask) * 255)
|
||||
|
||||
imgb = None
|
||||
imgc = None
|
||||
if form is 'single':
|
||||
imga = img
|
||||
elif form is 'pair':
|
||||
imga = img[:, :int(w / 2), :]
|
||||
imgb = img[:, int(w / 2):, :]
|
||||
elif form is 'tuple':
|
||||
imga = img[:, :int(w / 3), :]
|
||||
imgb = img[:, int(w / 3): int(w * 2 / 3), :]
|
||||
imgc = img[:, int(w * 2 / 3):, :]
|
||||
|
||||
angles = [ random.randint(-10, 0), random.randint(0, 10)]
|
||||
|
||||
for angle in angles:
|
||||
|
||||
imga_r = rotateImage(imga, angle)
|
||||
if form is 'single':
|
||||
res = imga_r
|
||||
elif form is 'pair':
|
||||
imgb_r = rotateImage(imgb, angle)
|
||||
res = cv2.hconcat([imga_r, imgb_r]) # 水平拼接
|
||||
else:
|
||||
imgb_r = rotateImage(imgb, angle)
|
||||
imgc_r = rotateImage(imgc, angle)
|
||||
res = cv2.hconcat([imga_r, imgb_r, imgc_r]) # 水平拼接
|
||||
|
||||
cv2.imwrite(outpath[:-4]+'_'+str(angle)+'.png', res)
|
||||
print('save %s'% outpath)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# main(args)
|
||||
pool = Pool(100)
|
||||
rl = pool.map(process, paths)
|
||||
pool.close()
|
||||
pool.join()
|
@ -0,0 +1,117 @@
|
||||
import oss2
|
||||
import argparse
|
||||
import cv2
|
||||
import glob
|
||||
import os
|
||||
import tqdm
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import urllib
|
||||
import random
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="process remove bg result")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="Path to images.")
|
||||
parser.add_argument("--save_dir", type=str, default="", help="Path to save images.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
args.save_dir = os.path.join(args.data_dir, 'total_scale')
|
||||
form = 'single'
|
||||
|
||||
if not os.path.exists(args.save_dir):
|
||||
os.makedirs(args.save_dir,exist_ok=True)
|
||||
|
||||
def all_file(file_dir):
|
||||
L=[]
|
||||
for root, dirs, files in os.walk(file_dir):
|
||||
for file in files:
|
||||
extend = os.path.splitext(file)[1]
|
||||
if extend == '.png' or extend == '.jpg' or extend == '.jpeg':
|
||||
L.append(os.path.join(root, file))
|
||||
return L
|
||||
|
||||
def scaleImage(image, degree):
|
||||
|
||||
h, w, _ = image.shape
|
||||
canvas = np.ones((h, w, 3), dtype="uint8")*255
|
||||
nw, nh = (int(w*degree), int(h*degree))
|
||||
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) # w, h
|
||||
|
||||
if degree<1:
|
||||
canvas[int((h-nh)/2):int((h-nh)/2)+nh, int((w-nw)/2):int((w-nw)/2)+nw,:] = image
|
||||
elif degree>1:
|
||||
canvas = image[int((nh-h)/2):int((nh-h)/2)+h, int((nw-w)/2):int((nw-w)/2)+w, :]
|
||||
else:
|
||||
canvas = image.copy()
|
||||
|
||||
return canvas
|
||||
|
||||
def scaleImage2(image, degree, angle=0):
|
||||
row,col,_ = image.shape
|
||||
center=tuple(np.array([row,col])/2)
|
||||
rot_mat = cv2.getRotationMatrix2D(center,angle,degree)
|
||||
new_image = cv2.warpAffine(image, rot_mat, (col,row), borderMode=cv2.BORDER_REFLECT)
|
||||
return new_image
|
||||
|
||||
|
||||
paths = all_file(args.data_dir)
|
||||
|
||||
|
||||
def process(path):
|
||||
|
||||
outpath = args.save_dir+path[len(args.data_dir):]
|
||||
sub_dir = os.path.dirname(outpath)
|
||||
if not os.path.exists(sub_dir):
|
||||
os.makedirs(sub_dir, exist_ok=True)
|
||||
|
||||
|
||||
img0 = cv2.imread(path, -1)
|
||||
h, w, c = img0.shape
|
||||
img = img0[:, :, :3].copy()
|
||||
if c==4:
|
||||
alpha = img0[:, :, 3]
|
||||
mask = alpha[:, :, np.newaxis].copy() / 255.
|
||||
img = (img * mask + (1 - mask) * 255)
|
||||
|
||||
imgb = None
|
||||
imgc = None
|
||||
if form is 'single':
|
||||
imga = img
|
||||
elif form is 'pair':
|
||||
imga = img[:, :int(w / 2), :]
|
||||
imgb = img[:, int(w / 2):, :]
|
||||
elif form is 'tuple':
|
||||
imga = img[:, :int(w / 3), :]
|
||||
imgb = img[:, int(w / 3): int(w * 2 / 3), :]
|
||||
imgc = img[:, int(w * 2 / 3):, :]
|
||||
|
||||
if random.random()>0.9:
|
||||
angles = [random.uniform(1, 1.1)]
|
||||
else:
|
||||
angles = [random.uniform(0.8, 1)]
|
||||
|
||||
for angle in angles:
|
||||
|
||||
imga_r = scaleImage(imga, angle)
|
||||
if form is 'single':
|
||||
res = imga_r
|
||||
elif form is 'pair':
|
||||
imgb_r = scaleImage(imgb, angle)
|
||||
res = cv2.hconcat([imga_r, imgb_r]) # 水平拼接
|
||||
else:
|
||||
imgb_r = scaleImage(imgb, angle)
|
||||
imgc_r = scaleImage(imgc, angle)
|
||||
res = cv2.hconcat([imga_r, imgb_r, imgc_r]) # 水平拼接
|
||||
|
||||
cv2.imwrite(outpath[:-4]+'_'+str(angle)+'.png', res)
|
||||
print('save %s'% outpath)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# main(args)
|
||||
pool = Pool(100)
|
||||
rl = pool.map(process, paths)
|
||||
pool.close()
|
||||
pool.join()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,12 @@
|
||||
{
|
||||
"parameters": {
|
||||
"output": "face_generation/experiment_stylegan",
|
||||
"ffhq_ckpt": "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
|
||||
"size": 256,
|
||||
"sample": 1,
|
||||
"pics": 5000,
|
||||
"truncation": 0.7,
|
||||
"form": "single"
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,14 @@
|
||||
{
|
||||
"parameters": {
|
||||
"output": "face_generation/experiment_stylegan",
|
||||
"ckpt": "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
|
||||
"ckpt_ffhq": "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
|
||||
"size": 256,
|
||||
"batch": 8,
|
||||
"n_sample": 4,
|
||||
"iter": 1500,
|
||||
"sample_every": 100,
|
||||
"save_every": 100
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,119 @@
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
||||
|
||||
"""
|
||||
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
|
||||
class Flatten(Module):
|
||||
def forward(self, input):
|
||||
return input.view(input.size(0), -1)
|
||||
|
||||
|
||||
def l2_norm(input, axis=1):
|
||||
norm = torch.norm(input, 2, axis, True)
|
||||
output = torch.div(input, norm)
|
||||
return output
|
||||
|
||||
|
||||
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
||||
""" A named tuple describing a ResNet block. """
|
||||
|
||||
|
||||
def get_block(in_channel, depth, num_units, stride=2):
|
||||
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
||||
|
||||
|
||||
def get_blocks(num_layers):
|
||||
if num_layers == 50:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=4),
|
||||
get_block(in_channel=128, depth=256, num_units=14),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
elif num_layers == 100:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=13),
|
||||
get_block(in_channel=128, depth=256, num_units=30),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
elif num_layers == 152:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=8),
|
||||
get_block(in_channel=128, depth=256, num_units=36),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
else:
|
||||
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
|
||||
return blocks
|
||||
|
||||
|
||||
class SEModule(Module):
|
||||
def __init__(self, channels, reduction):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2d(1)
|
||||
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
||||
self.relu = ReLU(inplace=True)
|
||||
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
||||
self.sigmoid = Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
module_input = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.sigmoid(x)
|
||||
return module_input * x
|
||||
|
||||
|
||||
class bottleneck_IR(Module):
|
||||
def __init__(self, in_channel, depth, stride):
|
||||
super(bottleneck_IR, self).__init__()
|
||||
if in_channel == depth:
|
||||
self.shortcut_layer = MaxPool2d(1, stride)
|
||||
else:
|
||||
self.shortcut_layer = Sequential(
|
||||
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
||||
BatchNorm2d(depth)
|
||||
)
|
||||
self.res_layer = Sequential(
|
||||
BatchNorm2d(in_channel),
|
||||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
||||
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = self.shortcut_layer(x)
|
||||
res = self.res_layer(x)
|
||||
return res + shortcut
|
||||
|
||||
|
||||
class bottleneck_IR_SE(Module):
|
||||
def __init__(self, in_channel, depth, stride):
|
||||
super(bottleneck_IR_SE, self).__init__()
|
||||
if in_channel == depth:
|
||||
self.shortcut_layer = MaxPool2d(1, stride)
|
||||
else:
|
||||
self.shortcut_layer = Sequential(
|
||||
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
||||
BatchNorm2d(depth)
|
||||
)
|
||||
self.res_layer = Sequential(
|
||||
BatchNorm2d(in_channel),
|
||||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
||||
PReLU(depth),
|
||||
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
||||
BatchNorm2d(depth),
|
||||
SEModule(depth, 16)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = self.shortcut_layer(x)
|
||||
res = self.res_layer(x)
|
||||
return res + shortcut
|
@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from .model_irse import Backbone
|
||||
|
||||
|
||||
class IDLoss(nn.Module):
|
||||
def __init__(self):
|
||||
super(IDLoss, self).__init__()
|
||||
print('Loading ResNet ArcFace')
|
||||
model_paths = '/data/vdb/qingyao/cartoon/mycode/pretrained_models/model_ir_se50.pth'
|
||||
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
||||
self.facenet.load_state_dict(torch.load(model_paths))
|
||||
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
||||
self.facenet.eval()
|
||||
|
||||
def extract_feats(self, x):
|
||||
x = x[:, :, 35:223, 32:220] # Crop interesting region
|
||||
x = self.face_pool(x)
|
||||
x_feats = self.facenet(x)
|
||||
return x_feats
|
||||
|
||||
def forward(self, y_hat, x):
|
||||
n_samples = x.shape[0]
|
||||
x_feats = self.extract_feats(x)
|
||||
y_hat_feats = self.extract_feats(y_hat)
|
||||
loss = 0
|
||||
sim_improvement = 0
|
||||
id_logs = []
|
||||
count = 0
|
||||
for i in range(n_samples):
|
||||
diff_input = y_hat_feats[i].dot(x_feats[i])
|
||||
id_logs.append({
|
||||
'diff_input': float(diff_input)
|
||||
})
|
||||
# loss += 1 - diff_target
|
||||
# modify
|
||||
loss += 1 - diff_input
|
||||
count += 1
|
||||
|
||||
return loss / count
|
@ -0,0 +1,35 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from criteria.lpips.networks import get_network, LinLayers
|
||||
from criteria.lpips.utils import get_state_dict
|
||||
|
||||
|
||||
class LPIPS(nn.Module):
|
||||
r"""Creates a criterion that measures
|
||||
Learned Perceptual Image Patch Similarity (LPIPS).
|
||||
Arguments:
|
||||
net_type (str): the network type to compare the features:
|
||||
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
|
||||
version (str): the version of LPIPS. Default: 0.1.
|
||||
"""
|
||||
def __init__(self, net_type: str = 'alex', version: str = '0.1'):
|
||||
|
||||
assert version in ['0.1'], 'v0.1 is only supported now'
|
||||
|
||||
super(LPIPS, self).__init__()
|
||||
|
||||
# pretrained network
|
||||
self.net = get_network(net_type).to("cuda")
|
||||
|
||||
# linear layers
|
||||
self.lin = LinLayers(self.net.n_channels_list).to("cuda")
|
||||
self.lin.load_state_dict(get_state_dict(net_type, version))
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||
feat_x, feat_y = self.net(x), self.net(y)
|
||||
|
||||
diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
|
||||
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
|
||||
|
||||
return torch.sum(torch.cat(res, 0)) / x.shape[0]
|
@ -0,0 +1,96 @@
|
||||
from typing import Sequence
|
||||
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from criteria.lpips.utils import normalize_activation
|
||||
|
||||
|
||||
def get_network(net_type: str):
|
||||
if net_type == 'alex':
|
||||
return AlexNet()
|
||||
elif net_type == 'squeeze':
|
||||
return SqueezeNet()
|
||||
elif net_type == 'vgg':
|
||||
return VGG16()
|
||||
else:
|
||||
raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
|
||||
|
||||
|
||||
class LinLayers(nn.ModuleList):
|
||||
def __init__(self, n_channels_list: Sequence[int]):
|
||||
super(LinLayers, self).__init__([
|
||||
nn.Sequential(
|
||||
nn.Identity(),
|
||||
nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
|
||||
) for nc in n_channels_list
|
||||
])
|
||||
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
class BaseNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(BaseNet, self).__init__()
|
||||
|
||||
# register buffer
|
||||
self.register_buffer(
|
||||
'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
||||
self.register_buffer(
|
||||
'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
||||
|
||||
def set_requires_grad(self, state: bool):
|
||||
for param in chain(self.parameters(), self.buffers()):
|
||||
param.requires_grad = state
|
||||
|
||||
def z_score(self, x: torch.Tensor):
|
||||
return (x - self.mean) / self.std
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.z_score(x)
|
||||
|
||||
output = []
|
||||
for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
|
||||
x = layer(x)
|
||||
if i in self.target_layers:
|
||||
output.append(normalize_activation(x))
|
||||
if len(output) == len(self.target_layers):
|
||||
break
|
||||
return output
|
||||
|
||||
|
||||
class SqueezeNet(BaseNet):
|
||||
def __init__(self):
|
||||
super(SqueezeNet, self).__init__()
|
||||
|
||||
self.layers = models.squeezenet1_1(True).features
|
||||
self.target_layers = [2, 5, 8, 10, 11, 12, 13]
|
||||
self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
|
||||
|
||||
self.set_requires_grad(False)
|
||||
|
||||
|
||||
class AlexNet(BaseNet):
|
||||
def __init__(self):
|
||||
super(AlexNet, self).__init__()
|
||||
|
||||
self.layers = models.alexnet(True).features
|
||||
self.target_layers = [2, 5, 8, 10, 12]
|
||||
self.n_channels_list = [64, 192, 384, 256, 256]
|
||||
|
||||
self.set_requires_grad(False)
|
||||
|
||||
|
||||
class VGG16(BaseNet):
|
||||
def __init__(self):
|
||||
super(VGG16, self).__init__()
|
||||
|
||||
self.layers = models.vgg16(True).features
|
||||
self.target_layers = [4, 9, 16, 23, 30]
|
||||
self.n_channels_list = [64, 128, 256, 512, 512]
|
||||
|
||||
self.set_requires_grad(False)
|
@ -0,0 +1,30 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def normalize_activation(x, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
|
||||
return x / (norm_factor + eps)
|
||||
|
||||
|
||||
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
|
||||
# build url
|
||||
url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
|
||||
+ f'master/lpips/weights/v{version}/{net_type}.pth'
|
||||
|
||||
# download
|
||||
old_state_dict = torch.hub.load_state_dict_from_url(
|
||||
url, progress=True,
|
||||
map_location=None if torch.cuda.is_available() else torch.device('cpu')
|
||||
)
|
||||
|
||||
# rename keys
|
||||
new_state_dict = OrderedDict()
|
||||
for key, val in old_state_dict.items():
|
||||
new_key = key
|
||||
new_key = new_key.replace('lin', '')
|
||||
new_key = new_key.replace('model.', '')
|
||||
new_state_dict[new_key] = val
|
||||
|
||||
return new_state_dict
|
@ -0,0 +1,69 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from configs.paths_config import model_paths
|
||||
|
||||
|
||||
class MocoLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MocoLoss, self).__init__()
|
||||
print("Loading MOCO model from path: {}".format(model_paths["moco"]))
|
||||
self.model = self.__load_model()
|
||||
self.model.cuda()
|
||||
self.model.eval()
|
||||
|
||||
@staticmethod
|
||||
def __load_model():
|
||||
import torchvision.models as models
|
||||
model = models.__dict__["resnet50"]()
|
||||
# freeze all layers but the last fc
|
||||
for name, param in model.named_parameters():
|
||||
if name not in ['fc.weight', 'fc.bias']:
|
||||
param.requires_grad = False
|
||||
checkpoint = torch.load(model_paths['moco'], map_location="cpu")
|
||||
state_dict = checkpoint['state_dict']
|
||||
# rename moco pre-trained keys
|
||||
for k in list(state_dict.keys()):
|
||||
# retain only encoder_q up to before the embedding layer
|
||||
if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
|
||||
# remove prefix
|
||||
state_dict[k[len("module.encoder_q."):]] = state_dict[k]
|
||||
# delete renamed or unused k
|
||||
del state_dict[k]
|
||||
msg = model.load_state_dict(state_dict, strict=False)
|
||||
assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
|
||||
# remove output layer
|
||||
model = nn.Sequential(*list(model.children())[:-1]).cuda()
|
||||
return model
|
||||
|
||||
def extract_feats(self, x):
|
||||
x = F.interpolate(x, size=224)
|
||||
x_feats = self.model(x)
|
||||
x_feats = nn.functional.normalize(x_feats, dim=1)
|
||||
x_feats = x_feats.squeeze()
|
||||
return x_feats
|
||||
|
||||
def forward(self, y_hat, y, x):
|
||||
n_samples = x.shape[0]
|
||||
x_feats = self.extract_feats(x)
|
||||
y_feats = self.extract_feats(y)
|
||||
y_hat_feats = self.extract_feats(y_hat)
|
||||
y_feats = y_feats.detach()
|
||||
loss = 0
|
||||
sim_improvement = 0
|
||||
sim_logs = []
|
||||
count = 0
|
||||
for i in range(n_samples):
|
||||
diff_target = y_hat_feats[i].dot(y_feats[i])
|
||||
diff_input = y_hat_feats[i].dot(x_feats[i])
|
||||
diff_views = y_feats[i].dot(x_feats[i])
|
||||
sim_logs.append({'diff_target': float(diff_target),
|
||||
'diff_input': float(diff_input),
|
||||
'diff_views': float(diff_views)})
|
||||
loss += 1 - diff_target
|
||||
sim_diff = float(diff_target) - float(diff_views)
|
||||
sim_improvement += sim_diff
|
||||
count += 1
|
||||
|
||||
return loss / count, sim_improvement / count, sim_logs
|
@ -0,0 +1,84 @@
|
||||
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
||||
from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
||||
|
||||
"""
|
||||
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
|
||||
class Backbone(Module):
|
||||
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
||||
super(Backbone, self).__init__()
|
||||
assert input_size in [112, 224], "input_size should be 112 or 224"
|
||||
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
||||
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
||||
blocks = get_blocks(num_layers)
|
||||
if mode == 'ir':
|
||||
unit_module = bottleneck_IR
|
||||
elif mode == 'ir_se':
|
||||
unit_module = bottleneck_IR_SE
|
||||
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
||||
BatchNorm2d(64),
|
||||
PReLU(64))
|
||||
if input_size == 112:
|
||||
self.output_layer = Sequential(BatchNorm2d(512),
|
||||
Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
Linear(512 * 7 * 7, 512),
|
||||
BatchNorm1d(512, affine=affine))
|
||||
else:
|
||||
self.output_layer = Sequential(BatchNorm2d(512),
|
||||
Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
Linear(512 * 14 * 14, 512),
|
||||
BatchNorm1d(512, affine=affine))
|
||||
|
||||
modules = []
|
||||
for block in blocks:
|
||||
for bottleneck in block:
|
||||
modules.append(unit_module(bottleneck.in_channel,
|
||||
bottleneck.depth,
|
||||
bottleneck.stride))
|
||||
self.body = Sequential(*modules)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.body(x)
|
||||
x = self.output_layer(x)
|
||||
return l2_norm(x)
|
||||
|
||||
|
||||
def IR_50(input_size):
|
||||
"""Constructs a ir-50 model."""
|
||||
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_101(input_size):
|
||||
"""Constructs a ir-101 model."""
|
||||
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_152(input_size):
|
||||
"""Constructs a ir-152 model."""
|
||||
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_50(input_size):
|
||||
"""Constructs a ir_se-50 model."""
|
||||
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_101(input_size):
|
||||
"""Constructs a ir_se-101 model."""
|
||||
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_152(input_size):
|
||||
"""Constructs a ir_se-152 model."""
|
||||
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
@ -0,0 +1,83 @@
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import optim
|
||||
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
|
||||
# gram matrix and loss
|
||||
class GramMatrix(nn.Module):
|
||||
def forward(self, input):
|
||||
b, c, h, w = input.size()
|
||||
F = input.view(b, c, h * w)
|
||||
G = torch.bmm(F, F.transpose(1, 2))
|
||||
G.div_(h * w)
|
||||
return G
|
||||
|
||||
|
||||
class GramMSELoss(nn.Module):
|
||||
def forward(self, input, target):
|
||||
out = nn.MSELoss()(GramMatrix()(input), target)
|
||||
return (out)
|
||||
|
||||
|
||||
# vgg definition that conveniently let's you grab the outputs from any layer
|
||||
class VGG(nn.Module):
|
||||
def __init__(self, pool='max'):
|
||||
super(VGG, self).__init__()
|
||||
# vgg modules
|
||||
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
|
||||
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
|
||||
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
||||
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
|
||||
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
|
||||
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
||||
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
||||
self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
||||
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
|
||||
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
|
||||
if pool == 'max':
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
elif pool == 'avg':
|
||||
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
out = {}
|
||||
out['r11'] = F.relu(self.conv1_1(x))
|
||||
out['r12'] = F.relu(self.conv1_2(out['r11']))
|
||||
out['p1'] = self.pool1(out['r12'])
|
||||
out['r21'] = F.relu(self.conv2_1(out['p1']))
|
||||
out['r22'] = F.relu(self.conv2_2(out['r21']))
|
||||
out['p2'] = self.pool2(out['r22'])
|
||||
out['r31'] = F.relu(self.conv3_1(out['p2']))
|
||||
out['r32'] = F.relu(self.conv3_2(out['r31']))
|
||||
out['r33'] = F.relu(self.conv3_3(out['r32']))
|
||||
out['r34'] = F.relu(self.conv3_4(out['r33']))
|
||||
out['p3'] = self.pool3(out['r34'])
|
||||
out['r41'] = F.relu(self.conv4_1(out['p3']))
|
||||
out['r42'] = F.relu(self.conv4_2(out['r41']))
|
||||
out['r43'] = F.relu(self.conv4_3(out['r42']))
|
||||
conv_4_4 = self.conv4_4(out['r43'])
|
||||
out['r44'] = F.relu(conv_4_4)
|
||||
out['p4'] = self.pool4(out['r44'])
|
||||
return conv_4_4
|
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from .vgg import VGG
|
||||
import os
|
||||
from configs.paths_config import model_paths
|
||||
|
||||
class VggLoss(nn.Module):
|
||||
def __init__(self):
|
||||
super(VggLoss, self).__init__()
|
||||
print("Loading VGG19 model from path: {}".format(model_paths["vgg"]))
|
||||
|
||||
self.vgg_model = VGG()
|
||||
self.vgg_model.load_state_dict(torch.load(model_paths['vgg']))
|
||||
self.vgg_model.cuda()
|
||||
self.vgg_model.eval()
|
||||
|
||||
self.l1loss = torch.nn.L1Loss()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def forward(self, input_photo, output):
|
||||
vgg_photo = self.vgg_model(input_photo)
|
||||
vgg_output = self.vgg_model(output)
|
||||
n, c, h, w = vgg_photo.shape
|
||||
# h, w, c = vgg_photo.get_shape().as_list()[1:]
|
||||
loss = self.l1loss(vgg_photo, vgg_output)
|
||||
|
||||
return loss
|
@ -0,0 +1,14 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class WNormLoss(nn.Module):
|
||||
|
||||
def __init__(self, start_from_latent_avg=True):
|
||||
super(WNormLoss, self).__init__()
|
||||
self.start_from_latent_avg = start_from_latent_avg
|
||||
|
||||
def forward(self, latent, latent_avg=None):
|
||||
if self.start_from_latent_avg:
|
||||
latent = latent - latent_avg
|
||||
return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
|
@ -0,0 +1,40 @@
|
||||
from io import BytesIO
|
||||
|
||||
import lmdb
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class MultiResolutionDataset(Dataset):
|
||||
def __init__(self, path, transform, resolution=256):
|
||||
self.env = lmdb.open(
|
||||
path,
|
||||
max_readers=32,
|
||||
readonly=True,
|
||||
lock=False,
|
||||
readahead=False,
|
||||
meminit=False,
|
||||
)
|
||||
|
||||
if not self.env:
|
||||
raise IOError('Cannot open lmdb dataset', path)
|
||||
|
||||
with self.env.begin(write=False) as txn:
|
||||
self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
|
||||
|
||||
self.resolution = resolution
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, index):
|
||||
with self.env.begin(write=False) as txn:
|
||||
key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
|
||||
img_bytes = txn.get(key)
|
||||
|
||||
buffer = BytesIO(img_bytes)
|
||||
img = Image.open(buffer)
|
||||
img = self.transform(img)
|
||||
|
||||
return img
|
@ -0,0 +1,126 @@
|
||||
import math
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def synchronize():
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
if world_size == 1:
|
||||
return
|
||||
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def reduce_sum(tensor):
|
||||
if not dist.is_available():
|
||||
return tensor
|
||||
|
||||
if not dist.is_initialized():
|
||||
return tensor
|
||||
|
||||
tensor = tensor.clone()
|
||||
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def gather_grad(params):
|
||||
world_size = get_world_size()
|
||||
|
||||
if world_size == 1:
|
||||
return
|
||||
|
||||
for param in params:
|
||||
if param.grad is not None:
|
||||
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
|
||||
param.grad.data.div_(world_size)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
world_size = get_world_size()
|
||||
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to('cuda')
|
||||
|
||||
local_size = torch.IntTensor([tensor.numel()]).to('cuda')
|
||||
size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
|
||||
|
||||
if local_size != max_size:
|
||||
padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
|
||||
tensor = torch.cat((tensor, padding), 0)
|
||||
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_loss_dict(loss_dict):
|
||||
world_size = get_world_size()
|
||||
|
||||
if world_size < 2:
|
||||
return loss_dict
|
||||
|
||||
with torch.no_grad():
|
||||
keys = []
|
||||
losses = []
|
||||
|
||||
for k in sorted(loss_dict.keys()):
|
||||
keys.append(k)
|
||||
losses.append(loss_dict[k])
|
||||
|
||||
losses = torch.stack(losses, 0)
|
||||
dist.reduce(losses, dst=0)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
losses /= world_size
|
||||
|
||||
reduced_losses = {k: v for k, v in zip(keys, losses)}
|
||||
|
||||
return reduced_losses
|
@ -0,0 +1,147 @@
|
||||
import argparse
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
||||
import torch
|
||||
from torchvision import utils
|
||||
from model import Generator
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
import glob
|
||||
|
||||
from PIL import Image
|
||||
|
||||
def make_image(tensor):
|
||||
return (
|
||||
tensor.detach()
|
||||
.clamp_(min=-1, max=1)
|
||||
.add(1)
|
||||
.div_(2)
|
||||
.mul(255)
|
||||
.type(torch.uint8)
|
||||
.permute(0, 2, 3, 1)
|
||||
.to("cpu")
|
||||
.numpy()
|
||||
)
|
||||
|
||||
def generate(args, g_ema, device, mean_latent, model_name, g_ema_ffhq):
|
||||
|
||||
outdir = args.save_dir
|
||||
|
||||
# print(outdir)
|
||||
# outdir = os.path.join(args.output, args.name, 'eval','toons_paired_0512')
|
||||
if not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
|
||||
with torch.no_grad():
|
||||
g_ema.eval()
|
||||
for i in tqdm(range(args.pics)):
|
||||
sample_z = torch.randn(args.sample, args.latent, device=device)
|
||||
|
||||
res, _ = g_ema(
|
||||
[sample_z], truncation=args.truncation, truncation_latent=mean_latent
|
||||
)
|
||||
if args.form == "pair":
|
||||
sample_face, _ = g_ema_ffhq(
|
||||
[sample_z], truncation=args.truncation, truncation_latent=mean_latent
|
||||
)
|
||||
res = torch.cat([sample_face, res], 3)
|
||||
|
||||
outpath = os.path.join(outdir, str(i).zfill(6)+'.png')
|
||||
utils.save_image(
|
||||
res,
|
||||
outpath,
|
||||
# f"sample/{str(i).zfill(6)}.png",
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
range=(-1, 1),
|
||||
)
|
||||
# print('save %s'% outpath)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = "cuda"
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate samples from the generator")
|
||||
parser.add_argument('--config', type=str, default='config/conf_server_test_blend_shell.json')
|
||||
parser.add_argument('--name', type=str, default='')
|
||||
parser.add_argument('--save_dir', type=str, default='')
|
||||
|
||||
parser.add_argument('--form', type=str, default='single')
|
||||
parser.add_argument(
|
||||
"--size", type=int, default=256, help="output image size of the generator"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of samples to be generated for each image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pics", type=int, default=20, help="number of images to be generated"
|
||||
)
|
||||
parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
|
||||
parser.add_argument(
|
||||
"--truncation_mean",
|
||||
type=int,
|
||||
default=4096,
|
||||
help="number of vectors to calculate mean for the truncation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="stylegan2-ffhq-config-f.pt",
|
||||
help="path to the model checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel_multiplier",
|
||||
type=int,
|
||||
default=2,
|
||||
help="channel multiplier of the generator. config-f = 2, else = 1",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
# from config updata paras
|
||||
opt = vars(args)
|
||||
with open(args.config) as f:
|
||||
config = json.load(f)['parameters']
|
||||
for key, value in config.items():
|
||||
opt[key] = value
|
||||
|
||||
# args.ckpt = 'face_generation/experiment_stylegan/'+args.name+'/models_blend/G_blend_001000_4.pt'
|
||||
args.ckpt = 'face_generation/experiment_stylegan/'+args.name+'/models_blend/G_blend_'
|
||||
args.ckpt = glob.glob(args.ckpt+'*')[0]
|
||||
|
||||
args.latent = 512
|
||||
args.n_mlp = 8
|
||||
|
||||
g_ema = Generator(
|
||||
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
|
||||
).to(device)
|
||||
checkpoint = torch.load(args.ckpt)
|
||||
|
||||
# g_ema.load_state_dict(checkpoint["g_ema"])
|
||||
g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
|
||||
|
||||
## add G_ffhq
|
||||
g_ema_ffhq = Generator(
|
||||
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
|
||||
).to(device)
|
||||
checkpoint_ffhq = torch.load(args.ffhq_ckpt)
|
||||
g_ema_ffhq.load_state_dict(checkpoint_ffhq["g_ema"], strict=False)
|
||||
|
||||
|
||||
if args.truncation < 1:
|
||||
with torch.no_grad():
|
||||
mean_latent = g_ema.mean_latent(args.truncation_mean)
|
||||
else:
|
||||
mean_latent = None
|
||||
|
||||
model_name = os.path.basename(args.ckpt)
|
||||
print('save generated samples to %s'% os.path.join(args.output, args.name, 'eval_blend', model_name))
|
||||
generate(args, g_ema, device, mean_latent, model_name, g_ema_ffhq)
|
||||
# generate_style_mix(args, g_ema, device, mean_latent, model_name, g_ema_ffhq)
|
||||
|
||||
# latent_path = 'test2.pt'
|
||||
# generate_from_latent(args, g_ema, device, mean_latent, latent_path)
|
@ -0,0 +1,939 @@
|
||||
import math
|
||||
import random
|
||||
import functools
|
||||
import operator
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
||||
|
||||
|
||||
def make_kernel(k):
|
||||
k = torch.tensor(k, dtype=torch.float32)
|
||||
|
||||
if k.ndim == 1:
|
||||
k = k[None, :] * k[:, None]
|
||||
|
||||
k /= k.sum()
|
||||
|
||||
return k
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, kernel, factor=2):
|
||||
super().__init__()
|
||||
|
||||
self.factor = factor
|
||||
kernel = make_kernel(kernel) * (factor ** 2)
|
||||
self.register_buffer("kernel", kernel)
|
||||
|
||||
p = kernel.shape[0] - factor
|
||||
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2
|
||||
|
||||
self.pad = (pad0, pad1)
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, kernel, factor=2):
|
||||
super().__init__()
|
||||
|
||||
self.factor = factor
|
||||
kernel = make_kernel(kernel)
|
||||
self.register_buffer("kernel", kernel)
|
||||
|
||||
p = kernel.shape[0] - factor
|
||||
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
self.pad = (pad0, pad1)
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self, kernel, pad, upsample_factor=1):
|
||||
super().__init__()
|
||||
|
||||
kernel = make_kernel(kernel)
|
||||
|
||||
if upsample_factor > 1:
|
||||
kernel = kernel * (upsample_factor ** 2)
|
||||
|
||||
self.register_buffer("kernel", kernel)
|
||||
|
||||
self.pad = pad
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
|
||||
def __init__(
|
||||
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
||||
)
|
||||
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
||||
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channel))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
out = conv2d_gradfix.conv2d(
|
||||
input,
|
||||
self.weight * self.scale,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
||||
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
|
||||
)
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
def forward(self, input):
|
||||
if self.activation:
|
||||
out = F.linear(input, self.weight * self.scale)
|
||||
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
||||
|
||||
else:
|
||||
out = F.linear(
|
||||
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
|
||||
)
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim,
|
||||
demodulate=True,
|
||||
upsample=False,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
fused=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.eps = 1e-8
|
||||
self.kernel_size = kernel_size
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.upsample = upsample
|
||||
self.downsample = downsample
|
||||
|
||||
if upsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2 + 1
|
||||
|
||||
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
||||
|
||||
fan_in = in_channel * kernel_size ** 2
|
||||
self.scale = 1 / math.sqrt(fan_in)
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
||||
)
|
||||
|
||||
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
||||
|
||||
self.demodulate = demodulate
|
||||
self.fused = fused
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
|
||||
f"upsample={self.upsample}, downsample={self.downsample})"
|
||||
)
|
||||
|
||||
def forward(self, input, style):
|
||||
batch, in_channel, height, width = input.shape
|
||||
|
||||
if not self.fused:
|
||||
weight = self.scale * self.weight.squeeze(0)
|
||||
style = self.modulation(style)
|
||||
|
||||
if self.demodulate:
|
||||
w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
|
||||
dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
|
||||
|
||||
input = input * style.reshape(batch, in_channel, 1, 1)
|
||||
|
||||
if self.upsample:
|
||||
weight = weight.transpose(0, 1)
|
||||
out = conv2d_gradfix.conv_transpose2d(
|
||||
input, weight, padding=0, stride=2
|
||||
)
|
||||
out = self.blur(out)
|
||||
|
||||
elif self.downsample:
|
||||
input = self.blur(input)
|
||||
out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
|
||||
|
||||
else:
|
||||
out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
|
||||
|
||||
if self.demodulate:
|
||||
out = out * dcoefs.view(batch, -1, 1, 1)
|
||||
|
||||
return out
|
||||
|
||||
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
||||
weight = self.scale * self.weight * style
|
||||
|
||||
if self.demodulate:
|
||||
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
||||
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
||||
|
||||
weight = weight.view(
|
||||
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
|
||||
if self.upsample:
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
weight = weight.view(
|
||||
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
weight = weight.transpose(1, 2).reshape(
|
||||
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
out = conv2d_gradfix.conv_transpose2d(
|
||||
input, weight, padding=0, stride=2, groups=batch
|
||||
)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
out = self.blur(out)
|
||||
|
||||
elif self.downsample:
|
||||
input = self.blur(input)
|
||||
_, _, height, width = input.shape
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
out = conv2d_gradfix.conv2d(
|
||||
input, weight, padding=0, stride=2, groups=batch
|
||||
)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
|
||||
else:
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
out = conv2d_gradfix.conv2d(
|
||||
input, weight, padding=self.padding, groups=batch
|
||||
)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class NoiseInjection(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, image, noise=None):
|
||||
if noise is None:
|
||||
batch, _, height, width = image.shape
|
||||
noise = image.new_empty(batch, 1, height, width).normal_()
|
||||
|
||||
return image + self.weight * noise
|
||||
|
||||
|
||||
class ConstantInput(nn.Module):
|
||||
def __init__(self, channel, size=4):
|
||||
super().__init__()
|
||||
|
||||
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
||||
|
||||
def forward(self, input):
|
||||
batch = input.shape[0]
|
||||
out = self.input.repeat(batch, 1, 1, 1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class StyledConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim,
|
||||
upsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
demodulate=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv = ModulatedConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim,
|
||||
upsample=upsample,
|
||||
blur_kernel=blur_kernel,
|
||||
demodulate=demodulate,
|
||||
)
|
||||
|
||||
self.noise = NoiseInjection()
|
||||
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
||||
# self.activate = ScaledLeakyReLU(0.2)
|
||||
self.activate = FusedLeakyReLU(out_channel)
|
||||
|
||||
def forward(self, input, style, noise=None):
|
||||
out = self.conv(input, style)
|
||||
out = self.noise(out, noise=noise)
|
||||
# out = out + self.bias
|
||||
out = self.activate(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
|
||||
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
if upsample:
|
||||
self.upsample = Upsample(blur_kernel)
|
||||
|
||||
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
||||
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
||||
|
||||
def forward(self, input, style, skip=None):
|
||||
out = self.conv(input, style)
|
||||
out = out + self.bias
|
||||
|
||||
if skip is not None:
|
||||
skip = self.upsample(skip)
|
||||
|
||||
out = out + skip
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
style_dim,
|
||||
n_mlp,
|
||||
channel_multiplier=2,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
lr_mlp=0.01,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.size = size
|
||||
|
||||
self.style_dim = style_dim
|
||||
|
||||
layers = [PixelNorm()]
|
||||
|
||||
for i in range(n_mlp):
|
||||
layers.append(
|
||||
EqualLinear(
|
||||
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
||||
)
|
||||
)
|
||||
|
||||
self.style = nn.Sequential(*layers)
|
||||
|
||||
self.channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256 * channel_multiplier,
|
||||
128: 128 * channel_multiplier,
|
||||
256: 64 * channel_multiplier,
|
||||
512: 32 * channel_multiplier,
|
||||
1024: 16 * channel_multiplier,
|
||||
}
|
||||
|
||||
self.input = ConstantInput(self.channels[4])
|
||||
self.conv1 = StyledConv(
|
||||
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
||||
)
|
||||
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
||||
|
||||
self.log_size = int(math.log(size, 2))# 256, 8
|
||||
self.num_layers = (self.log_size - 2) * 2 + 1
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.upsamples = nn.ModuleList()
|
||||
self.to_rgbs = nn.ModuleList()
|
||||
self.noises = nn.Module()
|
||||
|
||||
in_channel = self.channels[4]
|
||||
|
||||
for layer_idx in range(self.num_layers):
|
||||
res = (layer_idx + 5) // 2
|
||||
shape = [1, 1, 2 ** res, 2 ** res]
|
||||
self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channel = self.channels[2 ** i]
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
in_channel,
|
||||
out_channel,
|
||||
3,
|
||||
style_dim,
|
||||
upsample=True,
|
||||
blur_kernel=blur_kernel,
|
||||
)
|
||||
)
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
||||
)
|
||||
)
|
||||
|
||||
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
||||
|
||||
in_channel = out_channel
|
||||
|
||||
self.n_latent = self.log_size * 2 - 2
|
||||
|
||||
def make_noise(self):
|
||||
device = self.input.input.device
|
||||
|
||||
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
for _ in range(2):
|
||||
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
||||
|
||||
return noises
|
||||
|
||||
def mean_latent(self, n_latent):
|
||||
latent_in = torch.randn(
|
||||
n_latent, self.style_dim, device=self.input.input.device
|
||||
)
|
||||
latent = self.style(latent_in).mean(0, keepdim=True)
|
||||
|
||||
return latent
|
||||
|
||||
def get_latent(self, input):
|
||||
return self.style(input)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
styles,
|
||||
return_latents=False,
|
||||
inject_index=None,
|
||||
truncation=1,
|
||||
truncation_latent=None,
|
||||
input_is_latent=False,
|
||||
noise=None,
|
||||
randomize_noise=True,
|
||||
):
|
||||
if not input_is_latent:
|
||||
styles = [self.style(s) for s in styles]
|
||||
|
||||
if noise is None:
|
||||
if randomize_noise:
|
||||
noise = [None] * self.num_layers
|
||||
else:
|
||||
noise = [
|
||||
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
|
||||
]
|
||||
|
||||
if truncation < 1:
|
||||
style_t = []
|
||||
|
||||
for style in styles:
|
||||
style_t.append(
|
||||
truncation_latent + truncation * (style - truncation_latent)
|
||||
)
|
||||
|
||||
styles = style_t
|
||||
|
||||
if len(styles) < 2:
|
||||
inject_index = self.n_latent
|
||||
|
||||
if styles[0].ndim < 3:
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
|
||||
else:
|
||||
latent = styles[0]
|
||||
|
||||
else:
|
||||
if inject_index is None:
|
||||
inject_index = random.randint(1, self.n_latent - 1)
|
||||
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
||||
|
||||
latent = torch.cat([latent, latent2], 1)
|
||||
|
||||
out = self.input(latent)
|
||||
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
||||
|
||||
skip = self.to_rgb1(out, latent[:, 1])
|
||||
|
||||
i = 1
|
||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
||||
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
||||
):
|
||||
out = conv1(out, latent[:, i], noise=noise1)
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
|
||||
if return_latents:
|
||||
return image, latent
|
||||
|
||||
else:
|
||||
return image, None
|
||||
|
||||
def clone(self):
|
||||
"""
|
||||
Create a copy of this model.
|
||||
Returns:
|
||||
model_copy (nn.Module)
|
||||
"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
class StyleFusion(nn.Module):
|
||||
def __init__(self, n_latent):
|
||||
super().__init__()
|
||||
self.n_latent = n_latent
|
||||
self.weight = nn.Parameter(torch.zeros(n_latent))
|
||||
|
||||
def forward(self, sty1, sty2):
|
||||
sty_t = sty1.clone()
|
||||
for i in range(self.n_latent):
|
||||
sty_t[:, i, :] = sty1[:, i, :]* self.weight[i] + sty2[:, i, :]* (1-self.weight[i])
|
||||
|
||||
return sty_t
|
||||
|
||||
|
||||
class Generator_resty(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
style_dim,
|
||||
n_mlp,
|
||||
channel_multiplier=2,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
lr_mlp=0.01,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.size = size
|
||||
self.style_dim = style_dim
|
||||
|
||||
self.log_size = int(math.log(size, 2))
|
||||
self.num_layers = (self.log_size - 2) * 2 + 1
|
||||
self.n_latent = self.log_size * 2 - 2
|
||||
# add
|
||||
self.fusion_style = StyleFusion(self.n_latent)
|
||||
|
||||
layers = [PixelNorm()]
|
||||
|
||||
# inject to w+, 14*512
|
||||
# for i in range(n_mlp):
|
||||
# if i==n_mlp-1:
|
||||
# layers.append(
|
||||
# EqualLinear(
|
||||
# style_dim, style_dim * self.n_latent, lr_mul=lr_mlp, activation="fused_lrelu"
|
||||
# )
|
||||
# )
|
||||
# else:
|
||||
# layers.append(
|
||||
# EqualLinear(
|
||||
# style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
||||
# )
|
||||
# )
|
||||
|
||||
for i in range(n_mlp):
|
||||
layers.append(
|
||||
EqualLinear(
|
||||
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
||||
)
|
||||
)
|
||||
layers.append(
|
||||
EqualLinear(
|
||||
style_dim, style_dim * self.n_latent, lr_mul=lr_mlp, activation="fused_lrelu"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
self.style = nn.Sequential(*layers)
|
||||
|
||||
self.channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256 * channel_multiplier,
|
||||
128: 128 * channel_multiplier,
|
||||
256: 64 * channel_multiplier,
|
||||
512: 32 * channel_multiplier,
|
||||
1024: 16 * channel_multiplier,
|
||||
}
|
||||
|
||||
self.input = ConstantInput(self.channels[4])
|
||||
self.conv1 = StyledConv(
|
||||
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
||||
)
|
||||
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
||||
|
||||
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.upsamples = nn.ModuleList()
|
||||
self.to_rgbs = nn.ModuleList()
|
||||
self.noises = nn.Module()
|
||||
|
||||
in_channel = self.channels[4]
|
||||
|
||||
for layer_idx in range(self.num_layers):
|
||||
res = (layer_idx + 5) // 2
|
||||
shape = [1, 1, 2 ** res, 2 ** res]
|
||||
self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channel = self.channels[2 ** i]
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
in_channel,
|
||||
out_channel,
|
||||
3,
|
||||
style_dim,
|
||||
upsample=True,
|
||||
blur_kernel=blur_kernel,
|
||||
)
|
||||
)
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
||||
)
|
||||
)
|
||||
|
||||
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
||||
|
||||
in_channel = out_channel
|
||||
|
||||
|
||||
|
||||
def make_noise(self):
|
||||
device = self.input.input.device
|
||||
|
||||
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
for _ in range(2):
|
||||
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
||||
|
||||
return noises
|
||||
|
||||
def mean_latent(self, n_latent):
|
||||
latent_in = torch.randn(
|
||||
n_latent, self.style_dim, device=self.input.input.device
|
||||
)
|
||||
latent = self.style(latent_in).mean(0, keepdim=True)
|
||||
|
||||
return latent
|
||||
|
||||
def get_latent(self, input):
|
||||
return self.style(input)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
styles,
|
||||
style_single,
|
||||
generator_source,
|
||||
return_latents=False,
|
||||
inject_index=None,
|
||||
truncation=1,
|
||||
truncation_latent=None,
|
||||
input_is_latent=False,
|
||||
noise=None,
|
||||
randomize_noise=True,
|
||||
):
|
||||
|
||||
styles = [generator_source.style(s) for s in styles] # [2,bs,512] or [1,bs,512]
|
||||
if truncation < 1:
|
||||
style_t = []
|
||||
for style in styles:
|
||||
style_t.append(
|
||||
truncation_latent + truncation * (style - truncation_latent)
|
||||
)
|
||||
styles = style_t
|
||||
|
||||
if len(styles) < 2:
|
||||
inject_index = self.n_latent
|
||||
if styles[0].ndim < 3:
|
||||
source_styles = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
else:
|
||||
source_styles = styles[0]
|
||||
else:
|
||||
if inject_index is None:
|
||||
inject_index = random.randint(1, self.n_latent - 1)
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
||||
source_styles = torch.cat([latent, latent2], 1)
|
||||
|
||||
styles2 = self.style(style_single[0]) # bs*7068
|
||||
styles2 = styles2.view(-1, self.n_latent, 512) # bs*14*512
|
||||
|
||||
latent = self.fusion_style(styles2, source_styles)
|
||||
# test
|
||||
# latent = source_styles.clone()
|
||||
|
||||
|
||||
|
||||
if noise is None:
|
||||
if randomize_noise:
|
||||
noise = [None] * self.num_layers
|
||||
else:
|
||||
noise = [
|
||||
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
|
||||
]
|
||||
|
||||
|
||||
|
||||
out = self.input(latent)
|
||||
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
||||
|
||||
skip = self.to_rgb1(out, latent[:, 1])
|
||||
|
||||
i = 1
|
||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
||||
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
||||
):
|
||||
out = conv1(out, latent[:, i], noise=noise1)
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
|
||||
if return_latents:
|
||||
return image, latent
|
||||
|
||||
else:
|
||||
return image, None
|
||||
|
||||
def clone(self):
|
||||
"""
|
||||
Create a copy of this model.
|
||||
Returns:
|
||||
model_copy (nn.Module)
|
||||
"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
class ConvLayer(nn.Sequential):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
bias=True,
|
||||
activate=True,
|
||||
):
|
||||
layers = []
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
||||
|
||||
stride = 2
|
||||
self.padding = 0
|
||||
|
||||
else:
|
||||
stride = 1
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
layers.append(
|
||||
EqualConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
padding=self.padding,
|
||||
stride=stride,
|
||||
bias=bias and not activate,
|
||||
)
|
||||
)
|
||||
|
||||
if activate:
|
||||
layers.append(FusedLeakyReLU(out_channel, bias=bias))
|
||||
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
||||
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
||||
|
||||
self.skip = ConvLayer(
|
||||
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.conv1(input)
|
||||
out = self.conv2(out)
|
||||
|
||||
skip = self.skip(input)
|
||||
out = (out + skip) / math.sqrt(2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256 * channel_multiplier,
|
||||
128: 128 * channel_multiplier,
|
||||
256: 64 * channel_multiplier,
|
||||
512: 32 * channel_multiplier,
|
||||
1024: 16 * channel_multiplier,
|
||||
}
|
||||
|
||||
convs = [ConvLayer(3, channels[size], 1)]
|
||||
|
||||
log_size = int(math.log(size, 2))
|
||||
|
||||
in_channel = channels[size]
|
||||
|
||||
for i in range(log_size, 2, -1):
|
||||
out_channel = channels[2 ** (i - 1)]
|
||||
|
||||
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
||||
|
||||
in_channel = out_channel
|
||||
|
||||
self.convs = nn.Sequential(*convs)
|
||||
|
||||
self.stddev_group = 4
|
||||
self.stddev_feat = 1
|
||||
|
||||
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
||||
self.final_linear = nn.Sequential(
|
||||
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
|
||||
EqualLinear(channels[4], 1),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.convs(input)
|
||||
|
||||
batch, channel, height, width = out.shape
|
||||
group = min(batch, self.stddev_group)
|
||||
stddev = out.view(
|
||||
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
||||
)
|
||||
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
||||
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
||||
stddev = stddev.repeat(group, 1, height, width)
|
||||
out = torch.cat([out, stddev], 1)
|
||||
|
||||
out = self.final_conv(out)
|
||||
|
||||
out = out.view(batch, -1)
|
||||
out = self.final_linear(out)
|
||||
|
||||
return out
|
||||
|
@ -0,0 +1,465 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import autograd
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
from distributed import reduce_sum
|
||||
from op import upfirdn2d
|
||||
|
||||
|
||||
class AdaptiveAugment:
|
||||
def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
|
||||
self.ada_aug_target = ada_aug_target
|
||||
self.ada_aug_len = ada_aug_len
|
||||
self.update_every = update_every
|
||||
|
||||
self.ada_update = 0
|
||||
self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
|
||||
self.r_t_stat = 0
|
||||
self.ada_aug_p = 0
|
||||
|
||||
@torch.no_grad()
|
||||
def tune(self, real_pred):
|
||||
self.ada_aug_buf += torch.tensor(
|
||||
(torch.sign(real_pred).sum().item(), real_pred.shape[0]),
|
||||
device=real_pred.device,
|
||||
)
|
||||
self.ada_update += 1
|
||||
|
||||
if self.ada_update % self.update_every == 0:
|
||||
self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
|
||||
pred_signs, n_pred = self.ada_aug_buf.tolist()
|
||||
|
||||
self.r_t_stat = pred_signs / n_pred
|
||||
|
||||
if self.r_t_stat > self.ada_aug_target:
|
||||
sign = 1
|
||||
|
||||
else:
|
||||
sign = -1
|
||||
|
||||
self.ada_aug_p += sign * n_pred / self.ada_aug_len
|
||||
self.ada_aug_p = min(1, max(0, self.ada_aug_p))
|
||||
self.ada_aug_buf.mul_(0)
|
||||
self.ada_update = 0
|
||||
|
||||
return self.ada_aug_p
|
||||
|
||||
|
||||
SYM6 = (
|
||||
0.015404109327027373,
|
||||
0.0034907120842174702,
|
||||
-0.11799011114819057,
|
||||
-0.048311742585633,
|
||||
0.4910559419267466,
|
||||
0.787641141030194,
|
||||
0.3379294217276218,
|
||||
-0.07263752278646252,
|
||||
-0.021060292512300564,
|
||||
0.04472490177066578,
|
||||
0.0017677118642428036,
|
||||
-0.007800708325034148,
|
||||
)
|
||||
|
||||
|
||||
def translate_mat(t_x, t_y, device="cpu"):
|
||||
batch = t_x.shape[0]
|
||||
|
||||
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
|
||||
translate = torch.stack((t_x, t_y), 1)
|
||||
mat[:, :2, 2] = translate
|
||||
|
||||
return mat
|
||||
|
||||
|
||||
def rotate_mat(theta, device="cpu"):
|
||||
batch = theta.shape[0]
|
||||
|
||||
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
|
||||
sin_t = torch.sin(theta)
|
||||
cos_t = torch.cos(theta)
|
||||
rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
|
||||
mat[:, :2, :2] = rot
|
||||
|
||||
return mat
|
||||
|
||||
|
||||
def scale_mat(s_x, s_y, device="cpu"):
|
||||
batch = s_x.shape[0]
|
||||
|
||||
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
|
||||
mat[:, 0, 0] = s_x
|
||||
mat[:, 1, 1] = s_y
|
||||
|
||||
return mat
|
||||
|
||||
|
||||
def translate3d_mat(t_x, t_y, t_z):
|
||||
batch = t_x.shape[0]
|
||||
|
||||
mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
||||
translate = torch.stack((t_x, t_y, t_z), 1)
|
||||
mat[:, :3, 3] = translate
|
||||
|
||||
return mat
|
||||
|
||||
|
||||
def rotate3d_mat(axis, theta):
|
||||
batch = theta.shape[0]
|
||||
|
||||
u_x, u_y, u_z = axis
|
||||
|
||||
eye = torch.eye(3).unsqueeze(0)
|
||||
cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
|
||||
outer = torch.tensor(axis)
|
||||
outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
|
||||
|
||||
sin_t = torch.sin(theta).view(-1, 1, 1)
|
||||
cos_t = torch.cos(theta).view(-1, 1, 1)
|
||||
|
||||
rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
|
||||
|
||||
eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
||||
eye_4[:, :3, :3] = rot
|
||||
|
||||
return eye_4
|
||||
|
||||
|
||||
def scale3d_mat(s_x, s_y, s_z):
|
||||
batch = s_x.shape[0]
|
||||
|
||||
mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
||||
mat[:, 0, 0] = s_x
|
||||
mat[:, 1, 1] = s_y
|
||||
mat[:, 2, 2] = s_z
|
||||
|
||||
return mat
|
||||
|
||||
|
||||
def luma_flip_mat(axis, i):
|
||||
batch = i.shape[0]
|
||||
|
||||
eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
||||
axis = torch.tensor(axis + (0,))
|
||||
flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
|
||||
|
||||
return eye - flip
|
||||
|
||||
|
||||
def saturation_mat(axis, i):
|
||||
batch = i.shape[0]
|
||||
|
||||
eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
|
||||
axis = torch.tensor(axis + (0,))
|
||||
axis = torch.ger(axis, axis)
|
||||
saturate = axis + (eye - axis) * i.view(-1, 1, 1)
|
||||
|
||||
return saturate
|
||||
|
||||
|
||||
def lognormal_sample(size, mean=0, std=1, device="cpu"):
|
||||
return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
|
||||
|
||||
|
||||
def category_sample(size, categories, device="cpu"):
|
||||
category = torch.tensor(categories, device=device)
|
||||
sample = torch.randint(high=len(categories), size=(size,), device=device)
|
||||
|
||||
return category[sample]
|
||||
|
||||
|
||||
def uniform_sample(size, low, high, device="cpu"):
|
||||
return torch.empty(size, device=device).uniform_(low, high)
|
||||
|
||||
|
||||
def normal_sample(size, mean=0, std=1, device="cpu"):
|
||||
return torch.empty(size, device=device).normal_(mean, std)
|
||||
|
||||
|
||||
def bernoulli_sample(size, p, device="cpu"):
|
||||
return torch.empty(size, device=device).bernoulli_(p)
|
||||
|
||||
|
||||
def random_mat_apply(p, transform, prev, eye, device="cpu"):
|
||||
size = transform.shape[0]
|
||||
select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
|
||||
select_transform = select * transform + (1 - select) * eye
|
||||
|
||||
return select_transform @ prev
|
||||
|
||||
|
||||
def sample_affine(p, size, height, width, device="cpu"):
|
||||
G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
|
||||
eye = G
|
||||
|
||||
# flip
|
||||
param = category_sample(size, (0, 1))
|
||||
Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
|
||||
G = random_mat_apply(p, Gc, G, eye, device=device)
|
||||
# print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
|
||||
|
||||
# 90 rotate
|
||||
param = category_sample(size, (0, 3))
|
||||
Gc = rotate_mat(-math.pi / 2 * param, device=device)
|
||||
G = random_mat_apply(p, Gc, G, eye, device=device)
|
||||
# print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
|
||||
|
||||
# integer translate
|
||||
param = uniform_sample(size, -0.125, 0.125)
|
||||
param_height = torch.round(param * height) / height
|
||||
param_width = torch.round(param * width) / width
|
||||
Gc = translate_mat(param_width, param_height, device=device)
|
||||
G = random_mat_apply(p, Gc, G, eye, device=device)
|
||||
# print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
|
||||
|
||||
# isotropic scale
|
||||
param = lognormal_sample(size, std=0.2 * math.log(2))
|
||||
Gc = scale_mat(param, param, device=device)
|
||||
G = random_mat_apply(p, Gc, G, eye, device=device)
|
||||
# print('isotropic scale', G, scale_mat(param, param), sep='\n')
|
||||
|
||||
p_rot = 1 - math.sqrt(1 - p)
|
||||
|
||||
# pre-rotate
|
||||
param = uniform_sample(size, -math.pi, math.pi)
|
||||
Gc = rotate_mat(-param, device=device)
|
||||
G = random_mat_apply(p_rot, Gc, G, eye, device=device)
|
||||
# print('pre-rotate', G, rotate_mat(-param), sep='\n')
|
||||
|
||||
# anisotropic scale
|
||||
param = lognormal_sample(size, std=0.2 * math.log(2))
|
||||
Gc = scale_mat(param, 1 / param, device=device)
|
||||
G = random_mat_apply(p, Gc, G, eye, device=device)
|
||||
# print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
|
||||
|
||||
# post-rotate
|
||||
param = uniform_sample(size, -math.pi, math.pi)
|
||||
Gc = rotate_mat(-param, device=device)
|
||||
G = random_mat_apply(p_rot, Gc, G, eye, device=device)
|
||||
# print('post-rotate', G, rotate_mat(-param), sep='\n')
|
||||
|
||||
# fractional translate
|
||||
param = normal_sample(size, std=0.125)
|
||||
Gc = translate_mat(param, param, device=device)
|
||||
G = random_mat_apply(p, Gc, G, eye, device=device)
|
||||
# print('fractional translate', G, translate_mat(param, param), sep='\n')
|
||||
|
||||
return G
|
||||
|
||||
|
||||
def sample_color(p, size):
|
||||
C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
|
||||
eye = C
|
||||
axis_val = 1 / math.sqrt(3)
|
||||
axis = (axis_val, axis_val, axis_val)
|
||||
|
||||
# brightness
|
||||
param = normal_sample(size, std=0.2)
|
||||
Cc = translate3d_mat(param, param, param)
|
||||
C = random_mat_apply(p, Cc, C, eye)
|
||||
|
||||
# contrast
|
||||
param = lognormal_sample(size, std=0.5 * math.log(2))
|
||||
Cc = scale3d_mat(param, param, param)
|
||||
C = random_mat_apply(p, Cc, C, eye)
|
||||
|
||||
# luma flip
|
||||
param = category_sample(size, (0, 1))
|
||||
Cc = luma_flip_mat(axis, param)
|
||||
C = random_mat_apply(p, Cc, C, eye)
|
||||
|
||||
# hue rotation
|
||||
param = uniform_sample(size, -math.pi, math.pi)
|
||||
Cc = rotate3d_mat(axis, param)
|
||||
C = random_mat_apply(p, Cc, C, eye)
|
||||
|
||||
# saturation
|
||||
param = lognormal_sample(size, std=1 * math.log(2))
|
||||
Cc = saturation_mat(axis, param)
|
||||
C = random_mat_apply(p, Cc, C, eye)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
def make_grid(shape, x0, x1, y0, y1, device):
|
||||
n, c, h, w = shape
|
||||
grid = torch.empty(n, h, w, 3, device=device)
|
||||
grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
|
||||
grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
|
||||
grid[:, :, :, 2] = 1
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def affine_grid(grid, mat):
|
||||
n, h, w, _ = grid.shape
|
||||
return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
|
||||
|
||||
|
||||
def get_padding(G, height, width, kernel_size):
|
||||
device = G.device
|
||||
|
||||
cx = (width - 1) / 2
|
||||
cy = (height - 1) / 2
|
||||
cp = torch.tensor(
|
||||
[(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
|
||||
)
|
||||
cp = G @ cp.T
|
||||
|
||||
pad_k = kernel_size // 4
|
||||
|
||||
pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
|
||||
pad = torch.cat((-pad, pad)).max(1).values
|
||||
pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
|
||||
pad = pad.max(torch.tensor([0, 0] * 2, device=device))
|
||||
pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
|
||||
|
||||
pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
|
||||
|
||||
return pad_x1, pad_x2, pad_y1, pad_y2
|
||||
|
||||
|
||||
def try_sample_affine_and_pad(img, p, kernel_size, G=None):
|
||||
batch, _, height, width = img.shape
|
||||
|
||||
G_try = G
|
||||
|
||||
if G is None:
|
||||
G_try = torch.inverse(sample_affine(p, batch, height, width))
|
||||
|
||||
pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
|
||||
|
||||
img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
|
||||
|
||||
return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
|
||||
|
||||
|
||||
class GridSampleForward(autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, grid):
|
||||
out = F.grid_sample(
|
||||
input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
|
||||
)
|
||||
ctx.save_for_backward(input, grid)
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, grid = ctx.saved_tensors
|
||||
grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
|
||||
|
||||
return grad_input, grad_grid
|
||||
|
||||
|
||||
class GridSampleBackward(autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input, grid):
|
||||
op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
||||
ctx.save_for_backward(grid)
|
||||
|
||||
return grad_input, grad_grid
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_grad_input, grad_grad_grid):
|
||||
grid, = ctx.saved_tensors
|
||||
grad_grad_output = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
|
||||
|
||||
return grad_grad_output, None, None
|
||||
|
||||
|
||||
grid_sample = GridSampleForward.apply
|
||||
|
||||
|
||||
def scale_mat_single(s_x, s_y):
|
||||
return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
|
||||
|
||||
|
||||
def translate_mat_single(t_x, t_y):
|
||||
return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
|
||||
|
||||
|
||||
def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
|
||||
kernel = antialiasing_kernel
|
||||
len_k = len(kernel)
|
||||
|
||||
kernel = torch.as_tensor(kernel).to(img)
|
||||
# kernel = torch.ger(kernel, kernel).to(img)
|
||||
kernel_flip = torch.flip(kernel, (0,))
|
||||
|
||||
img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
|
||||
img, p, len_k, G
|
||||
)
|
||||
|
||||
G_inv = (
|
||||
translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
|
||||
@ G
|
||||
)
|
||||
up_pad = (
|
||||
(len_k + 2 - 1) // 2,
|
||||
(len_k - 2) // 2,
|
||||
(len_k + 2 - 1) // 2,
|
||||
(len_k - 2) // 2,
|
||||
)
|
||||
img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
|
||||
img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
|
||||
G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
|
||||
G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
|
||||
batch_size, channel, height, width = img.shape
|
||||
pad_k = len_k // 4
|
||||
shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
|
||||
G_inv = (
|
||||
scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
|
||||
@ G_inv
|
||||
@ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
|
||||
)
|
||||
grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
|
||||
img_affine = grid_sample(img_2x, grid)
|
||||
d_p = -pad_k * 2
|
||||
down_pad = (
|
||||
d_p + (len_k - 2 + 1) // 2,
|
||||
d_p + (len_k - 2) // 2,
|
||||
d_p + (len_k - 2 + 1) // 2,
|
||||
d_p + (len_k - 2) // 2,
|
||||
)
|
||||
img_down = upfirdn2d(
|
||||
img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
|
||||
)
|
||||
img_down = upfirdn2d(
|
||||
img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
|
||||
)
|
||||
|
||||
return img_down, G
|
||||
|
||||
|
||||
def apply_color(img, mat):
|
||||
batch = img.shape[0]
|
||||
img = img.permute(0, 2, 3, 1)
|
||||
mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
|
||||
mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
|
||||
img = img @ mat_mul + mat_add
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def random_apply_color(img, p, C=None):
|
||||
if C is None:
|
||||
C = sample_color(p, img.shape[0])
|
||||
|
||||
img = apply_color(img, C.to(img))
|
||||
|
||||
return img, C
|
||||
|
||||
|
||||
def augment(img, p, transform_matrix=(None, None)):
|
||||
img, G = random_apply_affine(img, p, transform_matrix[0])
|
||||
img, C = random_apply_color(img, p, transform_matrix[1])
|
||||
|
||||
return img, (G, C)
|
@ -0,0 +1,2 @@
|
||||
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
||||
from .upfirdn2d import upfirdn2d
|
@ -0,0 +1,227 @@
|
||||
import contextlib
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import autograd
|
||||
from torch.nn import functional as F
|
||||
|
||||
enabled = True
|
||||
weight_gradients_disabled = False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_weight_gradients():
|
||||
global weight_gradients_disabled
|
||||
|
||||
old = weight_gradients_disabled
|
||||
weight_gradients_disabled = True
|
||||
yield
|
||||
weight_gradients_disabled = old
|
||||
|
||||
|
||||
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if could_use_op(input):
|
||||
return conv2d_gradfix(
|
||||
transpose=False,
|
||||
weight_shape=weight.shape,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
).apply(input, weight, bias)
|
||||
|
||||
return F.conv2d(
|
||||
input=input,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
|
||||
def conv_transpose2d(
|
||||
input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=1,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
groups=1,
|
||||
dilation=1,
|
||||
):
|
||||
if could_use_op(input):
|
||||
return conv2d_gradfix(
|
||||
transpose=True,
|
||||
weight_shape=weight.shape,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation,
|
||||
).apply(input, weight, bias)
|
||||
|
||||
return F.conv_transpose2d(
|
||||
input=input,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
|
||||
def could_use_op(input):
|
||||
if (not enabled) or (not torch.backends.cudnn.enabled):
|
||||
return False
|
||||
|
||||
if input.device.type != "cuda":
|
||||
return False
|
||||
|
||||
if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
|
||||
return True
|
||||
|
||||
warnings.warn(
|
||||
f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def ensure_tuple(xs, ndim):
|
||||
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
||||
|
||||
return xs
|
||||
|
||||
|
||||
conv2d_gradfix_cache = dict()
|
||||
|
||||
|
||||
def conv2d_gradfix(
|
||||
transpose, weight_shape, stride, padding, output_padding, dilation, groups
|
||||
):
|
||||
ndim = 2
|
||||
weight_shape = tuple(weight_shape)
|
||||
stride = ensure_tuple(stride, ndim)
|
||||
padding = ensure_tuple(padding, ndim)
|
||||
output_padding = ensure_tuple(output_padding, ndim)
|
||||
dilation = ensure_tuple(dilation, ndim)
|
||||
|
||||
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
||||
if key in conv2d_gradfix_cache:
|
||||
return conv2d_gradfix_cache[key]
|
||||
|
||||
common_kwargs = dict(
|
||||
stride=stride, padding=padding, dilation=dilation, groups=groups
|
||||
)
|
||||
|
||||
def calc_output_padding(input_shape, output_shape):
|
||||
if transpose:
|
||||
return [0, 0]
|
||||
|
||||
return [
|
||||
input_shape[i + 2]
|
||||
- (output_shape[i + 2] - 1) * stride[i]
|
||||
- (1 - 2 * padding[i])
|
||||
- dilation[i] * (weight_shape[i + 2] - 1)
|
||||
for i in range(ndim)
|
||||
]
|
||||
|
||||
class Conv2d(autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias):
|
||||
if not transpose:
|
||||
out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
||||
|
||||
else:
|
||||
out = F.conv_transpose2d(
|
||||
input=input,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
output_padding=output_padding,
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(input, weight)
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_input, grad_weight, grad_bias = None, None, None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
p = calc_output_padding(
|
||||
input_shape=input.shape, output_shape=grad_output.shape
|
||||
)
|
||||
grad_input = conv2d_gradfix(
|
||||
transpose=(not transpose),
|
||||
weight_shape=weight_shape,
|
||||
output_padding=p,
|
||||
**common_kwargs,
|
||||
).apply(grad_output, weight, None)
|
||||
|
||||
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
||||
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = grad_output.sum((0, 2, 3))
|
||||
|
||||
return grad_input, grad_weight, grad_bias
|
||||
|
||||
class Conv2dGradWeight(autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input):
|
||||
op = torch._C._jit_get_operation(
|
||||
"aten::cudnn_convolution_backward_weight"
|
||||
if not transpose
|
||||
else "aten::cudnn_convolution_transpose_backward_weight"
|
||||
)
|
||||
flags = [
|
||||
torch.backends.cudnn.benchmark,
|
||||
torch.backends.cudnn.deterministic,
|
||||
torch.backends.cudnn.allow_tf32,
|
||||
]
|
||||
grad_weight = op(
|
||||
weight_shape,
|
||||
grad_output,
|
||||
input,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
*flags,
|
||||
)
|
||||
ctx.save_for_backward(grad_output, input)
|
||||
|
||||
return grad_weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_grad_weight):
|
||||
grad_output, input = ctx.saved_tensors
|
||||
grad_grad_output, grad_grad_input = None, None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
p = calc_output_padding(
|
||||
input_shape=input.shape, output_shape=grad_output.shape
|
||||
)
|
||||
grad_grad_input = conv2d_gradfix(
|
||||
transpose=(not transpose),
|
||||
weight_shape=weight_shape,
|
||||
output_padding=p,
|
||||
**common_kwargs,
|
||||
).apply(grad_output, grad_grad_weight, None)
|
||||
|
||||
return grad_grad_output, grad_grad_input
|
||||
|
||||
conv2d_gradfix_cache[key] = Conv2d
|
||||
|
||||
return Conv2d
|
@ -0,0 +1,119 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Function
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
fused = load(
|
||||
"fused",
|
||||
sources=[
|
||||
os.path.join(module_path, "fused_bias_act.cpp"),
|
||||
os.path.join(module_path, "fused_bias_act_kernel.cu"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FusedLeakyReLUFunctionBackward(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, out, bias, negative_slope, scale):
|
||||
ctx.save_for_backward(out)
|
||||
ctx.negative_slope = negative_slope
|
||||
ctx.scale = scale
|
||||
|
||||
empty = grad_output.new_empty(0)
|
||||
|
||||
grad_input = fused.fused_bias_act(
|
||||
grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
|
||||
)
|
||||
|
||||
dim = [0]
|
||||
|
||||
if grad_input.ndim > 2:
|
||||
dim += list(range(2, grad_input.ndim))
|
||||
|
||||
if bias:
|
||||
grad_bias = grad_input.sum(dim).detach()
|
||||
|
||||
else:
|
||||
grad_bias = empty
|
||||
|
||||
return grad_input, grad_bias
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradgrad_input, gradgrad_bias):
|
||||
out, = ctx.saved_tensors
|
||||
gradgrad_out = fused.fused_bias_act(
|
||||
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
|
||||
)
|
||||
|
||||
return gradgrad_out, None, None, None, None
|
||||
|
||||
|
||||
class FusedLeakyReLUFunction(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, bias, negative_slope, scale):
|
||||
empty = input.new_empty(0)
|
||||
|
||||
ctx.bias = bias is not None
|
||||
|
||||
if bias is None:
|
||||
bias = empty
|
||||
|
||||
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
|
||||
ctx.save_for_backward(out)
|
||||
ctx.negative_slope = negative_slope
|
||||
ctx.scale = scale
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
out, = ctx.saved_tensors
|
||||
|
||||
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
||||
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
|
||||
)
|
||||
|
||||
if not ctx.bias:
|
||||
grad_bias = None
|
||||
|
||||
return grad_input, grad_bias, None, None
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
|
||||
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
|
||||
super().__init__()
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(channel))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.negative_slope = negative_slope
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input):
|
||||
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
||||
|
||||
|
||||
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
|
||||
if input.device.type == "cpu":
|
||||
if bias is not None:
|
||||
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
||||
return (
|
||||
F.leaky_relu(
|
||||
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
|
||||
)
|
||||
* scale
|
||||
)
|
||||
|
||||
else:
|
||||
return F.leaky_relu(input, negative_slope=0.2) * scale
|
||||
|
||||
else:
|
||||
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
@ -0,0 +1,32 @@
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
|
||||
const torch::Tensor &bias,
|
||||
const torch::Tensor &refer, int act, int grad,
|
||||
float alpha, float scale);
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
torch::Tensor fused_bias_act(const torch::Tensor &input,
|
||||
const torch::Tensor &bias,
|
||||
const torch::Tensor &refer, int act, int grad,
|
||||
float alpha, float scale) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(bias);
|
||||
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
|
||||
}
|
@ -0,0 +1,105 @@
|
||||
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
||||
//
|
||||
// This work is made available under the Nvidia Source Code License-NC.
|
||||
// To view a copy of this license, visit
|
||||
// https://nvlabs.github.io/stylegan2/license.html
|
||||
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
template <typename scalar_t>
|
||||
static __global__ void
|
||||
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
|
||||
const scalar_t *p_ref, int act, int grad, scalar_t alpha,
|
||||
scalar_t scale, int loop_x, int size_x, int step_b,
|
||||
int size_b, int use_bias, int use_ref) {
|
||||
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
|
||||
|
||||
scalar_t zero = 0.0;
|
||||
|
||||
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
|
||||
loop_idx++, xi += blockDim.x) {
|
||||
scalar_t x = p_x[xi];
|
||||
|
||||
if (use_bias) {
|
||||
x += p_b[(xi / step_b) % size_b];
|
||||
}
|
||||
|
||||
scalar_t ref = use_ref ? p_ref[xi] : zero;
|
||||
|
||||
scalar_t y;
|
||||
|
||||
switch (act * 10 + grad) {
|
||||
default:
|
||||
case 10:
|
||||
y = x;
|
||||
break;
|
||||
case 11:
|
||||
y = x;
|
||||
break;
|
||||
case 12:
|
||||
y = 0.0;
|
||||
break;
|
||||
|
||||
case 30:
|
||||
y = (x > 0.0) ? x : x * alpha;
|
||||
break;
|
||||
case 31:
|
||||
y = (ref > 0.0) ? x : x * alpha;
|
||||
break;
|
||||
case 32:
|
||||
y = 0.0;
|
||||
break;
|
||||
}
|
||||
|
||||
out[xi] = y * scale;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
|
||||
const torch::Tensor &bias,
|
||||
const torch::Tensor &refer, int act, int grad,
|
||||
float alpha, float scale) {
|
||||
int curDevice = -1;
|
||||
cudaGetDevice(&curDevice);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
auto x = input.contiguous();
|
||||
auto b = bias.contiguous();
|
||||
auto ref = refer.contiguous();
|
||||
|
||||
int use_bias = b.numel() ? 1 : 0;
|
||||
int use_ref = ref.numel() ? 1 : 0;
|
||||
|
||||
int size_x = x.numel();
|
||||
int size_b = b.numel();
|
||||
int step_b = 1;
|
||||
|
||||
for (int i = 1 + 1; i < x.dim(); i++) {
|
||||
step_b *= x.size(i);
|
||||
}
|
||||
|
||||
int loop_x = 4;
|
||||
int block_size = 4 * 32;
|
||||
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
|
||||
|
||||
auto y = torch::empty_like(x);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
x.scalar_type(), "fused_bias_act_kernel", [&] {
|
||||
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
||||
y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
|
||||
b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
|
||||
scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
|
||||
});
|
||||
|
||||
return y;
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
|
||||
const torch::Tensor &kernel, int up_x, int up_y,
|
||||
int down_x, int down_y, int pad_x0, int pad_x1,
|
||||
int pad_y0, int pad_y1);
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
|
||||
int up_x, int up_y, int down_x, int down_y, int pad_x0,
|
||||
int pad_x1, int pad_y0, int pad_y1) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(kernel);
|
||||
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
|
||||
pad_y0, pad_y1);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
|
||||
}
|
@ -0,0 +1,209 @@
|
||||
from collections import abc
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Function
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
upfirdn2d_op = load(
|
||||
"upfirdn2d",
|
||||
sources=[
|
||||
os.path.join(module_path, "upfirdn2d.cpp"),
|
||||
os.path.join(module_path, "upfirdn2d_kernel.cu"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class UpFirDn2dBackward(Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
|
||||
):
|
||||
|
||||
up_x, up_y = up
|
||||
down_x, down_y = down
|
||||
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
||||
|
||||
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
||||
|
||||
grad_input = upfirdn2d_op.upfirdn2d(
|
||||
grad_output,
|
||||
grad_kernel,
|
||||
down_x,
|
||||
down_y,
|
||||
up_x,
|
||||
up_y,
|
||||
g_pad_x0,
|
||||
g_pad_x1,
|
||||
g_pad_y0,
|
||||
g_pad_y1,
|
||||
)
|
||||
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
|
||||
|
||||
ctx.save_for_backward(kernel)
|
||||
|
||||
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
||||
|
||||
ctx.up_x = up_x
|
||||
ctx.up_y = up_y
|
||||
ctx.down_x = down_x
|
||||
ctx.down_y = down_y
|
||||
ctx.pad_x0 = pad_x0
|
||||
ctx.pad_x1 = pad_x1
|
||||
ctx.pad_y0 = pad_y0
|
||||
ctx.pad_y1 = pad_y1
|
||||
ctx.in_size = in_size
|
||||
ctx.out_size = out_size
|
||||
|
||||
return grad_input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradgrad_input):
|
||||
kernel, = ctx.saved_tensors
|
||||
|
||||
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
|
||||
|
||||
gradgrad_out = upfirdn2d_op.upfirdn2d(
|
||||
gradgrad_input,
|
||||
kernel,
|
||||
ctx.up_x,
|
||||
ctx.up_y,
|
||||
ctx.down_x,
|
||||
ctx.down_y,
|
||||
ctx.pad_x0,
|
||||
ctx.pad_x1,
|
||||
ctx.pad_y0,
|
||||
ctx.pad_y1,
|
||||
)
|
||||
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
|
||||
gradgrad_out = gradgrad_out.view(
|
||||
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
|
||||
)
|
||||
|
||||
return gradgrad_out, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class UpFirDn2d(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, kernel, up, down, pad):
|
||||
up_x, up_y = up
|
||||
down_x, down_y = down
|
||||
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
||||
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
batch, channel, in_h, in_w = input.shape
|
||||
ctx.in_size = input.shape
|
||||
|
||||
input = input.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
|
||||
ctx.out_size = (out_h, out_w)
|
||||
|
||||
ctx.up = (up_x, up_y)
|
||||
ctx.down = (down_x, down_y)
|
||||
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
|
||||
|
||||
g_pad_x0 = kernel_w - pad_x0 - 1
|
||||
g_pad_y0 = kernel_h - pad_y0 - 1
|
||||
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
|
||||
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
|
||||
|
||||
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
|
||||
|
||||
out = upfirdn2d_op.upfirdn2d(
|
||||
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
||||
)
|
||||
# out = out.view(major, out_h, out_w, minor)
|
||||
out = out.view(-1, channel, out_h, out_w)
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
kernel, grad_kernel = ctx.saved_tensors
|
||||
|
||||
grad_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = UpFirDn2dBackward.apply(
|
||||
grad_output,
|
||||
kernel,
|
||||
grad_kernel,
|
||||
ctx.up,
|
||||
ctx.down,
|
||||
ctx.pad,
|
||||
ctx.g_pad,
|
||||
ctx.in_size,
|
||||
ctx.out_size,
|
||||
)
|
||||
|
||||
return grad_input, None, None, None, None
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
if not isinstance(up, abc.Iterable):
|
||||
up = (up, up)
|
||||
|
||||
if not isinstance(down, abc.Iterable):
|
||||
down = (down, down)
|
||||
|
||||
if len(pad) == 2:
|
||||
pad = (pad[0], pad[1], pad[0], pad[1])
|
||||
|
||||
if input.device.type == "cpu":
|
||||
out = upfirdn2d_native(input, kernel, *up, *down, *pad)
|
||||
|
||||
else:
|
||||
out = UpFirDn2d.apply(input, kernel, up, down, pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
||||
):
|
||||
_, channel, in_h, in_w = input.shape
|
||||
input = input.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(
|
||||
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
||||
)
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape(
|
||||
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
||||
)
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
@ -0,0 +1,369 @@
|
||||
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
||||
//
|
||||
// This work is made available under the Nvidia Source Code License-NC.
|
||||
// To view a copy of this license, visit
|
||||
// https://nvlabs.github.io/stylegan2/license.html
|
||||
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
||||
int c = a / b;
|
||||
|
||||
if (c * b > a) {
|
||||
c--;
|
||||
}
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
struct UpFirDn2DKernelParams {
|
||||
int up_x;
|
||||
int up_y;
|
||||
int down_x;
|
||||
int down_y;
|
||||
int pad_x0;
|
||||
int pad_x1;
|
||||
int pad_y0;
|
||||
int pad_y1;
|
||||
|
||||
int major_dim;
|
||||
int in_h;
|
||||
int in_w;
|
||||
int minor_dim;
|
||||
int kernel_h;
|
||||
int kernel_w;
|
||||
int out_h;
|
||||
int out_w;
|
||||
int loop_major;
|
||||
int loop_x;
|
||||
};
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
|
||||
const scalar_t *kernel,
|
||||
const UpFirDn2DKernelParams p) {
|
||||
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int out_y = minor_idx / p.minor_dim;
|
||||
minor_idx -= out_y * p.minor_dim;
|
||||
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
|
||||
int major_idx_base = blockIdx.z * p.loop_major;
|
||||
|
||||
if (out_x_base >= p.out_w || out_y >= p.out_h ||
|
||||
major_idx_base >= p.major_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
|
||||
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
|
||||
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
|
||||
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
|
||||
|
||||
for (int loop_major = 0, major_idx = major_idx_base;
|
||||
loop_major < p.loop_major && major_idx < p.major_dim;
|
||||
loop_major++, major_idx++) {
|
||||
for (int loop_x = 0, out_x = out_x_base;
|
||||
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
|
||||
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
|
||||
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
|
||||
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
|
||||
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
|
||||
|
||||
const scalar_t *x_p =
|
||||
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
|
||||
minor_idx];
|
||||
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
|
||||
int x_px = p.minor_dim;
|
||||
int k_px = -p.up_x;
|
||||
int x_py = p.in_w * p.minor_dim;
|
||||
int k_py = -p.up_y * p.kernel_w;
|
||||
|
||||
scalar_t v = 0.0f;
|
||||
|
||||
for (int y = 0; y < h; y++) {
|
||||
for (int x = 0; x < w; x++) {
|
||||
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
|
||||
x_p += x_px;
|
||||
k_p += k_px;
|
||||
}
|
||||
|
||||
x_p += x_py - w * x_px;
|
||||
k_p += k_py - w * k_px;
|
||||
}
|
||||
|
||||
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
||||
minor_idx] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
|
||||
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
|
||||
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
|
||||
const scalar_t *kernel,
|
||||
const UpFirDn2DKernelParams p) {
|
||||
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
|
||||
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
|
||||
|
||||
__shared__ volatile float sk[kernel_h][kernel_w];
|
||||
__shared__ volatile float sx[tile_in_h][tile_in_w];
|
||||
|
||||
int minor_idx = blockIdx.x;
|
||||
int tile_out_y = minor_idx / p.minor_dim;
|
||||
minor_idx -= tile_out_y * p.minor_dim;
|
||||
tile_out_y *= tile_out_h;
|
||||
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
|
||||
int major_idx_base = blockIdx.z * p.loop_major;
|
||||
|
||||
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
|
||||
major_idx_base >= p.major_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
|
||||
tap_idx += blockDim.x) {
|
||||
int ky = tap_idx / kernel_w;
|
||||
int kx = tap_idx - ky * kernel_w;
|
||||
scalar_t v = 0.0;
|
||||
|
||||
if (kx < p.kernel_w & ky < p.kernel_h) {
|
||||
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
|
||||
}
|
||||
|
||||
sk[ky][kx] = v;
|
||||
}
|
||||
|
||||
for (int loop_major = 0, major_idx = major_idx_base;
|
||||
loop_major < p.loop_major & major_idx < p.major_dim;
|
||||
loop_major++, major_idx++) {
|
||||
for (int loop_x = 0, tile_out_x = tile_out_x_base;
|
||||
loop_x < p.loop_x & tile_out_x < p.out_w;
|
||||
loop_x++, tile_out_x += tile_out_w) {
|
||||
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
|
||||
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
|
||||
int tile_in_x = floor_div(tile_mid_x, up_x);
|
||||
int tile_in_y = floor_div(tile_mid_y, up_y);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
|
||||
in_idx += blockDim.x) {
|
||||
int rel_in_y = in_idx / tile_in_w;
|
||||
int rel_in_x = in_idx - rel_in_y * tile_in_w;
|
||||
int in_x = rel_in_x + tile_in_x;
|
||||
int in_y = rel_in_y + tile_in_y;
|
||||
|
||||
scalar_t v = 0.0;
|
||||
|
||||
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
|
||||
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
|
||||
p.minor_dim +
|
||||
minor_idx];
|
||||
}
|
||||
|
||||
sx[rel_in_y][rel_in_x] = v;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
|
||||
out_idx += blockDim.x) {
|
||||
int rel_out_y = out_idx / tile_out_w;
|
||||
int rel_out_x = out_idx - rel_out_y * tile_out_w;
|
||||
int out_x = rel_out_x + tile_out_x;
|
||||
int out_y = rel_out_y + tile_out_y;
|
||||
|
||||
int mid_x = tile_mid_x + rel_out_x * down_x;
|
||||
int mid_y = tile_mid_y + rel_out_y * down_y;
|
||||
int in_x = floor_div(mid_x, up_x);
|
||||
int in_y = floor_div(mid_y, up_y);
|
||||
int rel_in_x = in_x - tile_in_x;
|
||||
int rel_in_y = in_y - tile_in_y;
|
||||
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
||||
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
||||
|
||||
scalar_t v = 0.0;
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < kernel_h / up_y; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < kernel_w / up_x; x++)
|
||||
v += sx[rel_in_y + y][rel_in_x + x] *
|
||||
sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
||||
|
||||
if (out_x < p.out_w & out_y < p.out_h) {
|
||||
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
||||
minor_idx] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
|
||||
const torch::Tensor &kernel, int up_x, int up_y,
|
||||
int down_x, int down_y, int pad_x0, int pad_x1,
|
||||
int pad_y0, int pad_y1) {
|
||||
int curDevice = -1;
|
||||
cudaGetDevice(&curDevice);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
UpFirDn2DKernelParams p;
|
||||
|
||||
auto x = input.contiguous();
|
||||
auto k = kernel.contiguous();
|
||||
|
||||
p.major_dim = x.size(0);
|
||||
p.in_h = x.size(1);
|
||||
p.in_w = x.size(2);
|
||||
p.minor_dim = x.size(3);
|
||||
p.kernel_h = k.size(0);
|
||||
p.kernel_w = k.size(1);
|
||||
p.up_x = up_x;
|
||||
p.up_y = up_y;
|
||||
p.down_x = down_x;
|
||||
p.down_y = down_y;
|
||||
p.pad_x0 = pad_x0;
|
||||
p.pad_x1 = pad_x1;
|
||||
p.pad_y0 = pad_y0;
|
||||
p.pad_y1 = pad_y1;
|
||||
|
||||
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
|
||||
p.down_y;
|
||||
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
|
||||
p.down_x;
|
||||
|
||||
auto out =
|
||||
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
|
||||
|
||||
int mode = -1;
|
||||
|
||||
int tile_out_h = -1;
|
||||
int tile_out_w = -1;
|
||||
|
||||
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
||||
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
||||
mode = 1;
|
||||
tile_out_h = 16;
|
||||
tile_out_w = 64;
|
||||
}
|
||||
|
||||
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
||||
p.kernel_h <= 3 && p.kernel_w <= 3) {
|
||||
mode = 2;
|
||||
tile_out_h = 16;
|
||||
tile_out_w = 64;
|
||||
}
|
||||
|
||||
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
||||
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
||||
mode = 3;
|
||||
tile_out_h = 16;
|
||||
tile_out_w = 64;
|
||||
}
|
||||
|
||||
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
||||
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
||||
mode = 4;
|
||||
tile_out_h = 16;
|
||||
tile_out_w = 64;
|
||||
}
|
||||
|
||||
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
||||
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
||||
mode = 5;
|
||||
tile_out_h = 8;
|
||||
tile_out_w = 32;
|
||||
}
|
||||
|
||||
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
||||
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
||||
mode = 6;
|
||||
tile_out_h = 8;
|
||||
tile_out_w = 32;
|
||||
}
|
||||
|
||||
dim3 block_size;
|
||||
dim3 grid_size;
|
||||
|
||||
if (tile_out_h > 0 && tile_out_w > 0) {
|
||||
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
||||
p.loop_x = 1;
|
||||
block_size = dim3(32 * 8, 1, 1);
|
||||
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
|
||||
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
|
||||
(p.major_dim - 1) / p.loop_major + 1);
|
||||
} else {
|
||||
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
||||
p.loop_x = 4;
|
||||
block_size = dim3(4, 32, 1);
|
||||
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
|
||||
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
|
||||
(p.major_dim - 1) / p.loop_major + 1);
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
|
||||
switch (mode) {
|
||||
case 1:
|
||||
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 2:
|
||||
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 3:
|
||||
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 4:
|
||||
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 5:
|
||||
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 6:
|
||||
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
default:
|
||||
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
}
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
import multiprocessing
|
||||
from functools import partial
|
||||
|
||||
from PIL import Image
|
||||
import lmdb
|
||||
from tqdm import tqdm
|
||||
from torchvision import datasets
|
||||
from torchvision.transforms import functional as trans_fn
|
||||
import os
|
||||
|
||||
def resize_and_convert(img, size, resample, quality=100):
|
||||
img = trans_fn.resize(img, size, resample)
|
||||
img = trans_fn.center_crop(img, size)
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="jpeg", quality=quality)
|
||||
val = buffer.getvalue()
|
||||
|
||||
return val
|
||||
|
||||
|
||||
def resize_multiple(
|
||||
img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
|
||||
):
|
||||
imgs = []
|
||||
|
||||
for size in sizes:
|
||||
imgs.append(resize_and_convert(img, size, resample, quality))
|
||||
|
||||
return imgs
|
||||
|
||||
|
||||
def resize_worker(img_file, sizes, resample):
|
||||
i, file = img_file
|
||||
img = Image.open(file)
|
||||
img = img.convert("RGB")
|
||||
out = resize_multiple(img, sizes=sizes, resample=resample)
|
||||
|
||||
return i, out
|
||||
|
||||
|
||||
def prepare(
|
||||
env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
|
||||
):
|
||||
resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
|
||||
|
||||
files = sorted(dataset.imgs, key=lambda x: x[0])
|
||||
files = [(i, file) for i, (file, label) in enumerate(files)]
|
||||
total = 0
|
||||
|
||||
with multiprocessing.Pool(n_worker) as pool:
|
||||
for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
|
||||
for size, img in zip(sizes, imgs):
|
||||
key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
|
||||
|
||||
with env.begin(write=True) as txn:
|
||||
txn.put(key, img)
|
||||
|
||||
total += 1
|
||||
|
||||
with env.begin(write=True) as txn:
|
||||
txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Preprocess images for model training")
|
||||
parser.add_argument("path", type=str, help="path to the image dataset")
|
||||
parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
|
||||
parser.add_argument(
|
||||
"--size",
|
||||
type=str,
|
||||
default="128,256,512,1024",
|
||||
help="resolutions of images for the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_worker",
|
||||
type=int,
|
||||
default=16,
|
||||
help="number of workers for preparing dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resample",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
help="resampling methods for resizing images",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not os.path.exists(str(args.out)):
|
||||
os.makedirs(str(args.out))
|
||||
|
||||
resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
|
||||
resample = resample_map[args.resample]
|
||||
|
||||
sizes = [int(s.strip()) for s in args.size.split(",")]
|
||||
|
||||
print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
|
||||
|
||||
imgset = datasets.ImageFolder(args.path)
|
||||
|
||||
with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
|
||||
prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
|
@ -0,0 +1,145 @@
|
||||
# model blending technique
|
||||
import os
|
||||
import cv2 as cv
|
||||
import torch
|
||||
from model import Generator
|
||||
import math
|
||||
import argparse
|
||||
|
||||
def extract_conv_names(model):
|
||||
model = list(model.keys())
|
||||
conv_name = []
|
||||
resolutions = [4*2**x for x in range(9)]
|
||||
level_names = [["Conv0_up", "Const"], ["Conv1", "ToRGB"]]
|
||||
|
||||
|
||||
def blend_models(model_1, model_2, resolution, level, blend_width=None):
|
||||
resolutions = [4 * 2 ** i for i in range(7)]
|
||||
mid = resolutions.index(resolution)
|
||||
|
||||
device = "cuda"
|
||||
|
||||
size = 256
|
||||
latent = 512
|
||||
n_mlp = 8
|
||||
channel_multiplier =2
|
||||
G_1 = Generator(
|
||||
size, latent, n_mlp, channel_multiplier=channel_multiplier
|
||||
).to(device)
|
||||
ckpt_ffhq = torch.load(model_1, map_location=lambda storage, loc: storage)
|
||||
G_1.load_state_dict(ckpt_ffhq["g"], strict=False)
|
||||
|
||||
|
||||
G_2 = Generator(
|
||||
size, latent, n_mlp, channel_multiplier=channel_multiplier
|
||||
).to(device)
|
||||
ckpt_toon = torch.load(model_2)
|
||||
G_2.load_state_dict(ckpt_toon["g_ema"])
|
||||
|
||||
|
||||
|
||||
# G_1 = stylegan2.models.load(model_1)
|
||||
# G_2 = stylegan2.models.load(model_2)
|
||||
model_1_state_dict = G_1.state_dict()
|
||||
model_2_state_dict = G_2.state_dict()
|
||||
assert(model_1_state_dict.keys() == model_2_state_dict.keys())
|
||||
G_out = G_1.clone()
|
||||
|
||||
layers = []
|
||||
ys = []
|
||||
for k, v in model_1_state_dict.items():
|
||||
if k.startswith('convs.'):
|
||||
pos = int(k[len('convs.')])
|
||||
x = pos - mid
|
||||
if blend_width:
|
||||
exponent = -x / blend_width
|
||||
y = 1 / (1 + math.exp(exponent))
|
||||
else:
|
||||
y = 1 if x > 0 else 0
|
||||
|
||||
layers.append(k)
|
||||
ys.append(y)
|
||||
elif k.startswith('to_rgbs.'):
|
||||
pos = int(k[len('to_rgbs.')])
|
||||
x = pos - mid
|
||||
if blend_width:
|
||||
exponent = -x / blend_width
|
||||
y = 1 / (1 + math.exp(exponent))
|
||||
else:
|
||||
y = 1 if x > 0 else 0
|
||||
layers.append(k)
|
||||
ys.append(y)
|
||||
out_state = G_out.state_dict()
|
||||
for y, layer in zip(ys, layers):
|
||||
out_state[layer] = y * model_2_state_dict[layer] + \
|
||||
(1 - y) * model_1_state_dict[layer]
|
||||
print('blend layer %s'%str(y))
|
||||
G_out.load_state_dict(out_state)
|
||||
return G_out
|
||||
|
||||
|
||||
def blend_models_2(model_1, model_2, resolution, level, blend_width=None):
|
||||
# resolution = f"{resolution}x{resolution}"
|
||||
resolutions = [4 * 2 ** i for i in range(7)]
|
||||
mid = [resolutions.index(r) for r in resolution]
|
||||
|
||||
G_1 = stylegan2.models.load(model_1)
|
||||
G_2 = stylegan2.models.load(model_2)
|
||||
model_1_state_dict = G_1.state_dict()
|
||||
model_2_state_dict = G_2.state_dict()
|
||||
assert(model_1_state_dict.keys() == model_2_state_dict.keys())
|
||||
G_out = G_1.clone()
|
||||
|
||||
layers = []
|
||||
ys = []
|
||||
for k, v in model_1_state_dict.items():
|
||||
if k.startswith('G_synthesis.conv_blocks.'):
|
||||
pos = int(k[len('G_synthesis.conv_blocks.')])
|
||||
y = 0 if pos in mid else 1
|
||||
layers.append(k)
|
||||
ys.append(y)
|
||||
elif k.startswith('G_synthesis.to_data_layers.'):
|
||||
pos = int(k[len('G_synthesis.to_data_layers.')])
|
||||
y = 0 if pos in mid else 1
|
||||
layers.append(k)
|
||||
ys.append(y)
|
||||
# print(ys, layers)
|
||||
out_state = G_out.state_dict()
|
||||
for y, layer in zip(ys, layers):
|
||||
out_state[layer] = y * model_2_state_dict[layer] + \
|
||||
(1 - y) * model_1_state_dict[layer]
|
||||
G_out.load_state_dict(out_state)
|
||||
return G_out
|
||||
|
||||
|
||||
def main(name):
|
||||
|
||||
resolution = 4
|
||||
|
||||
model_name = '001000.pt'
|
||||
|
||||
G_out = blend_models("pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
|
||||
"face_generation/experiment_stylegan/"+name+"/models/"+model_name,
|
||||
resolution,
|
||||
None)
|
||||
# G_out.save('G_blend.pth')
|
||||
outdir = os.path.join('face_generation/experiment_stylegan',name,'models_blend')
|
||||
if not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
|
||||
outpath = os.path.join(outdir, 'G_blend_'+str(model_name[:-3])+'_'+ str(resolution)+'.pt')
|
||||
torch.save(
|
||||
{
|
||||
"g_ema": G_out.state_dict(),
|
||||
},
|
||||
# 'G_blend_570000_16.pth',
|
||||
outpath
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="style blender")
|
||||
parser.add_argument('--name', type=str, default='')
|
||||
args = parser.parse_args()
|
||||
print('model name:%s'%args.name)
|
||||
main(args.name)
|
@ -0,0 +1,584 @@
|
||||
import os
|
||||
import argparse
|
||||
import math
|
||||
import random
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, autograd, optim
|
||||
from torch.nn import functional as F
|
||||
from torch.utils import data
|
||||
import torch.distributed as dist
|
||||
from torchvision import transforms, utils
|
||||
from tqdm import tqdm
|
||||
from criteria import id_loss
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
except ImportError:
|
||||
wandb = None
|
||||
|
||||
|
||||
from dataset import MultiResolutionDataset
|
||||
from distributed import (
|
||||
get_rank,
|
||||
synchronize,
|
||||
reduce_loss_dict,
|
||||
reduce_sum,
|
||||
get_world_size,
|
||||
)
|
||||
from op import conv2d_gradfix
|
||||
from non_leaking import augment, AdaptiveAugment
|
||||
|
||||
|
||||
def data_sampler(dataset, shuffle, distributed):
|
||||
if distributed:
|
||||
return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
|
||||
if shuffle:
|
||||
return data.RandomSampler(dataset)
|
||||
|
||||
else:
|
||||
return data.SequentialSampler(dataset)
|
||||
|
||||
|
||||
def requires_grad(model, flag=True):
|
||||
for p in model.parameters():
|
||||
p.requires_grad = flag
|
||||
|
||||
|
||||
def accumulate(model1, model2, decay=0.999):
|
||||
par1 = dict(model1.named_parameters())
|
||||
par2 = dict(model2.named_parameters())
|
||||
|
||||
for k in par1.keys():
|
||||
par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
|
||||
|
||||
|
||||
def sample_data(loader):
|
||||
while True:
|
||||
for batch in loader:
|
||||
yield batch
|
||||
|
||||
|
||||
def d_logistic_loss(real_pred, fake_pred):
|
||||
real_loss = F.softplus(-real_pred)
|
||||
fake_loss = F.softplus(fake_pred)
|
||||
|
||||
return real_loss.mean() + fake_loss.mean()
|
||||
|
||||
|
||||
def d_r1_loss(real_pred, real_img):
|
||||
with conv2d_gradfix.no_weight_gradients():
|
||||
grad_real, = autograd.grad(
|
||||
outputs=real_pred.sum(), inputs=real_img, create_graph=True
|
||||
)
|
||||
grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
|
||||
|
||||
return grad_penalty
|
||||
|
||||
|
||||
def g_nonsaturating_loss(fake_pred):
|
||||
loss = F.softplus(-fake_pred).mean()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
|
||||
noise = torch.randn_like(fake_img) / math.sqrt(
|
||||
fake_img.shape[2] * fake_img.shape[3]
|
||||
)
|
||||
grad, = autograd.grad(
|
||||
outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
|
||||
)
|
||||
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
|
||||
|
||||
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
|
||||
|
||||
path_penalty = (path_lengths - path_mean).pow(2).mean()
|
||||
|
||||
return path_penalty, path_mean.detach(), path_lengths
|
||||
|
||||
|
||||
def make_noise(batch, latent_dim, n_noise, device):
|
||||
if n_noise == 1:
|
||||
return torch.randn(batch, latent_dim, device=device)
|
||||
|
||||
noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
|
||||
|
||||
return noises
|
||||
|
||||
|
||||
def mixing_noise(batch, latent_dim, prob, device):
|
||||
if prob > 0 and random.random() < prob:
|
||||
return make_noise(batch, latent_dim, 2, device)
|
||||
|
||||
else:
|
||||
return [make_noise(batch, latent_dim, 1, device)]
|
||||
|
||||
|
||||
def set_grad_none(model, targets):
|
||||
for n, p in model.named_parameters():
|
||||
if n in targets:
|
||||
p.grad = None
|
||||
|
||||
|
||||
def train(args, loader, generator, generator_source, discriminator, g_optim, d_optim, g_ema, device):
|
||||
loader = sample_data(loader)
|
||||
|
||||
pbar = range(args.iter)
|
||||
|
||||
if get_rank() == 0:
|
||||
pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
|
||||
|
||||
mean_path_length = 0
|
||||
|
||||
d_loss_val = 0
|
||||
r1_loss = torch.tensor(0.0, device=device)
|
||||
g_loss_val = 0
|
||||
path_loss = torch.tensor(0.0, device=device)
|
||||
path_lengths = torch.tensor(0.0, device=device)
|
||||
mean_path_length_avg = 0
|
||||
loss_dict = {}
|
||||
|
||||
### add id loss, content loss
|
||||
g_id_loss = id_loss.IDLoss().to(device).eval()
|
||||
|
||||
if args.distributed:
|
||||
g_module = generator.module
|
||||
d_module = discriminator.module
|
||||
|
||||
else:
|
||||
g_module = generator
|
||||
d_module = discriminator
|
||||
|
||||
accum = 0.5 ** (32 / (10 * 1000))
|
||||
ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
|
||||
r_t_stat = 0
|
||||
|
||||
if args.augment and args.augment_p == 0:
|
||||
ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)
|
||||
|
||||
# sample_z = torch.randn(args.n_sample, args.latent, device=device)
|
||||
sample_z = torch.load('noise.pt').to(device)
|
||||
|
||||
for idx in pbar:
|
||||
i = idx + args.start_iter
|
||||
|
||||
if i > args.iter:
|
||||
print("Done!")
|
||||
|
||||
break
|
||||
|
||||
real_img = next(loader)
|
||||
real_img = real_img.to(device)
|
||||
|
||||
requires_grad(generator, False)
|
||||
requires_grad(discriminator, True)
|
||||
|
||||
noise = mixing_noise(args.batch, args.latent, args.mixing, device)
|
||||
fake_img, _ = generator(noise)
|
||||
|
||||
if args.augment:
|
||||
real_img_aug, _ = augment(real_img, ada_aug_p)
|
||||
fake_img, _ = augment(fake_img, ada_aug_p)
|
||||
|
||||
else:
|
||||
real_img_aug = real_img
|
||||
|
||||
fake_pred = discriminator(fake_img)
|
||||
real_pred = discriminator(real_img_aug)
|
||||
d_loss = d_logistic_loss(real_pred, fake_pred)
|
||||
|
||||
loss_dict["d"] = d_loss
|
||||
loss_dict["real_score"] = real_pred.mean()
|
||||
loss_dict["fake_score"] = fake_pred.mean()
|
||||
|
||||
discriminator.zero_grad()
|
||||
d_loss.backward()
|
||||
d_optim.step()
|
||||
|
||||
if args.augment and args.augment_p == 0:
|
||||
ada_aug_p = ada_augment.tune(real_pred)
|
||||
r_t_stat = ada_augment.r_t_stat
|
||||
|
||||
d_regularize = i % args.d_reg_every == 0
|
||||
|
||||
if d_regularize:
|
||||
real_img.requires_grad = True
|
||||
|
||||
if args.augment:
|
||||
real_img_aug, _ = augment(real_img, ada_aug_p)
|
||||
|
||||
else:
|
||||
real_img_aug = real_img
|
||||
|
||||
real_pred = discriminator(real_img_aug)
|
||||
r1_loss = d_r1_loss(real_pred, real_img)
|
||||
|
||||
discriminator.zero_grad()
|
||||
(args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
|
||||
|
||||
d_optim.step()
|
||||
|
||||
loss_dict["r1"] = r1_loss
|
||||
|
||||
requires_grad(generator, True)
|
||||
requires_grad(discriminator, False)
|
||||
|
||||
noise = mixing_noise(args.batch, args.latent, args.mixing, device)
|
||||
fake_img, _ = generator(noise)
|
||||
|
||||
|
||||
if args.augment:
|
||||
fake_img, _ = augment(fake_img, ada_aug_p)
|
||||
|
||||
fake_pred = discriminator(fake_img)
|
||||
g_loss = g_nonsaturating_loss(fake_pred)
|
||||
|
||||
### id loss
|
||||
fake_img_source, _ = generator_source(noise)
|
||||
loss_id = g_id_loss(fake_img, fake_img_source)
|
||||
loss_id = float(loss_id)
|
||||
loss_dict['loss_id'] = loss_id
|
||||
# print(loss_id)
|
||||
|
||||
g_loss += loss_id*0.1
|
||||
|
||||
|
||||
loss_dict["g"] = g_loss
|
||||
|
||||
generator.zero_grad()
|
||||
g_loss.backward()
|
||||
g_optim.step()
|
||||
|
||||
g_regularize = i % args.g_reg_every == 0
|
||||
|
||||
if g_regularize:
|
||||
path_batch_size = max(1, args.batch // args.path_batch_shrink)
|
||||
noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
|
||||
fake_img, latents = generator(noise, return_latents=True)
|
||||
|
||||
path_loss, mean_path_length, path_lengths = g_path_regularize(
|
||||
fake_img, latents, mean_path_length
|
||||
)
|
||||
|
||||
generator.zero_grad()
|
||||
weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
|
||||
|
||||
if args.path_batch_shrink:
|
||||
weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
|
||||
|
||||
weighted_path_loss.backward()
|
||||
|
||||
g_optim.step()
|
||||
|
||||
mean_path_length_avg = (
|
||||
reduce_sum(mean_path_length).item() / get_world_size()
|
||||
)
|
||||
|
||||
loss_dict["path"] = path_loss
|
||||
loss_dict["path_length"] = path_lengths.mean()
|
||||
|
||||
accumulate(g_ema, g_module, accum)
|
||||
|
||||
loss_reduced = reduce_loss_dict(loss_dict)
|
||||
|
||||
d_loss_val = loss_reduced["d"].mean().item()
|
||||
g_loss_val = loss_reduced["g"].mean().item()
|
||||
r1_val = loss_reduced["r1"].mean().item()
|
||||
path_loss_val = loss_reduced["path"].mean().item()
|
||||
real_score_val = loss_reduced["real_score"].mean().item()
|
||||
fake_score_val = loss_reduced["fake_score"].mean().item()
|
||||
path_length_val = loss_reduced["path_length"].mean().item()
|
||||
|
||||
if get_rank() == 0:
|
||||
pbar.set_description(
|
||||
(
|
||||
f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
|
||||
f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
|
||||
f"augment: {ada_aug_p:.4f}"
|
||||
)
|
||||
)
|
||||
|
||||
if wandb and args.wandb:
|
||||
wandb.log(
|
||||
{
|
||||
"Generator": g_loss_val,
|
||||
"Discriminator": d_loss_val,
|
||||
"Augment": ada_aug_p,
|
||||
"Rt": r_t_stat,
|
||||
"R1": r1_val,
|
||||
"Path Length Regularization": path_loss_val,
|
||||
"Mean Path Length": mean_path_length,
|
||||
"Real Score": real_score_val,
|
||||
"Fake Score": fake_score_val,
|
||||
"Path Length": path_length_val,
|
||||
}
|
||||
)
|
||||
|
||||
if i % args.sample_every == 0:
|
||||
with torch.no_grad():
|
||||
g_ema.eval()
|
||||
sample, _ = g_ema([sample_z])
|
||||
outpath = os.path.join(args.output, args.name,'sample', str(i).zfill(6)+'.png')
|
||||
utils.save_image(
|
||||
sample,
|
||||
outpath,
|
||||
# nrow=int(args.n_sample ** 0.5),
|
||||
nrow=int(math.sqrt(sample.shape[0])),
|
||||
normalize=True,
|
||||
range=(-1, 1),
|
||||
)
|
||||
|
||||
if i % args.save_every == 0:
|
||||
outpath_ckpt = os.path.join(args.output, args.name, 'models', str(i).zfill(6) + '.pt')
|
||||
torch.save(
|
||||
{
|
||||
"g": g_module.state_dict(),
|
||||
"d": d_module.state_dict(),
|
||||
"g_ema": g_ema.state_dict(),
|
||||
"g_optim": g_optim.state_dict(),
|
||||
"d_optim": d_optim.state_dict(),
|
||||
"args": args,
|
||||
"ada_aug_p": ada_aug_p,
|
||||
},
|
||||
outpath_ckpt,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = "cuda"
|
||||
|
||||
parser = argparse.ArgumentParser(description="StyleGAN2 trainer")
|
||||
parser.add_argument('--config', type=str, default='config/conf_server_train_condition.json')
|
||||
parser.add_argument("--path", type=str, help="path to the lmdb dataset")
|
||||
parser.add_argument("--name", type=str, help="name of experiment")
|
||||
parser.add_argument('--arch', type=str, default='stylegan2', help='model architectures (stylegan2 | swagan)')
|
||||
parser.add_argument(
|
||||
"--iter", type=int, default=800000, help="total training iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch", type=int, default=1, help="batch sizes for each gpus"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_sample",
|
||||
type=int,
|
||||
default=64,
|
||||
help="number of the samples generated during training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--size", type=int, default=1024, help="image sizes for the model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--r1", type=float, default=10, help="weight of the r1 regularization"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path_regularize",
|
||||
type=float,
|
||||
default=2,
|
||||
help="weight of the path length regularization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path_batch_shrink",
|
||||
type=int,
|
||||
default=2,
|
||||
help="batch size reducing factor for the path length regularization (reduce memory consumption)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--d_reg_every",
|
||||
type=int,
|
||||
default=16,
|
||||
help="interval of the applying r1 regularization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g_reg_every",
|
||||
type=int,
|
||||
default=4,
|
||||
help="interval of the applying path length regularization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixing", type=float, default=0.9, help="probability of latent code mixing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
help="path to the checkpoints to resume training",
|
||||
)
|
||||
parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
|
||||
parser.add_argument(
|
||||
"--channel_multiplier",
|
||||
type=int,
|
||||
default=2,
|
||||
help="channel multiplier factor for the model. config-f = 2, else = 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wandb", action="store_true", help="use weights and biases logging"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_rank", type=int, default=0, help="local rank for distributed training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--augment", action="store_true", help="apply non leaking augmentation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--augment_p",
|
||||
type=float,
|
||||
default=0,
|
||||
help="probability of applying augmentation. 0 = use adaptive augmentation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ada_target",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="target augmentation probability for adaptive augmentation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ada_length",
|
||||
type=int,
|
||||
default=500 * 1000,
|
||||
help="target duraing to reach augmentation probability for adaptive augmentation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ada_every",
|
||||
type=int,
|
||||
default=256,
|
||||
help="probability update interval of the adaptive augmentation",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# from config updata paras
|
||||
opt = vars(args)
|
||||
with open(args.config) as f:
|
||||
config = json.load(f)['parameters']
|
||||
for key, value in config.items():
|
||||
opt[key] = value
|
||||
|
||||
output_sample = os.path.join(args.output, args.name,'sample')
|
||||
output_model = os.path.join(args.output, args.name,'models')
|
||||
|
||||
if not os.path.exists(output_sample):
|
||||
os.makedirs(output_sample)
|
||||
if not os.path.exists(output_model):
|
||||
os.makedirs(output_model)
|
||||
|
||||
|
||||
n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
|
||||
args.distributed = n_gpu > 1
|
||||
|
||||
if args.distributed:
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||
synchronize()
|
||||
|
||||
args.latent = 512
|
||||
args.n_mlp = 8
|
||||
|
||||
args.start_iter = 0
|
||||
|
||||
if args.arch == 'stylegan2':
|
||||
from model import Generator, Discriminator
|
||||
|
||||
elif args.arch == 'swagan':
|
||||
from swagan import Generator, Discriminator
|
||||
|
||||
generator = Generator(
|
||||
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
|
||||
).to(device)
|
||||
|
||||
### add source G
|
||||
generator_source = Generator(
|
||||
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
|
||||
).to(device)
|
||||
generator_source.eval()
|
||||
|
||||
discriminator = Discriminator(
|
||||
args.size, channel_multiplier=args.channel_multiplier
|
||||
).to(device)
|
||||
g_ema = Generator(
|
||||
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
|
||||
).to(device)
|
||||
g_ema.eval()
|
||||
accumulate(g_ema, generator, 0)
|
||||
|
||||
g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
|
||||
d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
|
||||
|
||||
g_optim = optim.Adam(
|
||||
generator.parameters(),
|
||||
lr=args.lr * g_reg_ratio,
|
||||
betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
|
||||
)
|
||||
d_optim = optim.Adam(
|
||||
discriminator.parameters(),
|
||||
lr=args.lr * d_reg_ratio,
|
||||
betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
|
||||
)
|
||||
|
||||
if args.ckpt is not None:
|
||||
print("load model:", args.ckpt)
|
||||
|
||||
ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
|
||||
|
||||
try:
|
||||
ckpt_name = os.path.basename(args.ckpt)
|
||||
args.start_iter = int(os.path.splitext(ckpt_name)[0])
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
generator.load_state_dict(ckpt["g"], strict=False)
|
||||
discriminator.load_state_dict(ckpt["d"], strict=False)
|
||||
g_ema.load_state_dict(ckpt["g_ema"], strict=False)
|
||||
|
||||
if args.ckpt_ffhq is not None:
|
||||
ckpt_ffhq = torch.load(args.ckpt_ffhq, map_location=lambda storage, loc: storage)
|
||||
generator_source.load_state_dict(ckpt_ffhq["g"], strict=False)
|
||||
else:
|
||||
print('No source ckpt!')
|
||||
|
||||
|
||||
# generator.load_state_dict(ckpt["g"])
|
||||
# discriminator.load_state_dict(ckpt["d"])
|
||||
# g_ema.load_state_dict(ckpt["g_ema"])
|
||||
|
||||
# g_optim.load_state_dict(ckpt["g_optim"])
|
||||
# d_optim.load_state_dict(ckpt["d_optim"])
|
||||
|
||||
if args.distributed:
|
||||
generator = nn.parallel.DistributedDataParallel(
|
||||
generator,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
broadcast_buffers=False,
|
||||
)
|
||||
|
||||
discriminator = nn.parallel.DistributedDataParallel(
|
||||
discriminator,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
broadcast_buffers=False,
|
||||
)
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
|
||||
]
|
||||
)
|
||||
|
||||
dataset = MultiResolutionDataset(args.path, transform, args.size)
|
||||
loader = data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch,
|
||||
sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
if get_rank() == 0 and wandb is not None and args.wandb:
|
||||
wandb.init(project="stylegan 2")
|
||||
|
||||
train(args, loader, generator, generator_source, discriminator, g_optim, d_optim, g_ema, device)
|
@ -0,0 +1,36 @@
|
||||
import os
|
||||
import cv2
|
||||
from modelscope.trainers.cv import CartoonTranslationTrainer
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
data_photo = os.path.join(args.data_dir, 'face_photo')
|
||||
data_cartoon = os.path.join(args.data_dir, 'face_cartoon')
|
||||
|
||||
style = args.style
|
||||
if style == "anime":
|
||||
style = ""
|
||||
else:
|
||||
style = '-' + style
|
||||
model_id = 'damo/cv_unet_person-image-cartoon' + style + '_compound-models'
|
||||
|
||||
max_steps = 300000
|
||||
trainer = CartoonTranslationTrainer(
|
||||
model=model_id,
|
||||
work_dir=args.work_dir,
|
||||
photo=data_photo,
|
||||
cartoon=data_cartoon,
|
||||
max_steps=max_steps)
|
||||
trainer.train()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="process remove bg result")
|
||||
parser.add_argument("--data_dir", type=str, default='', help="Path to training images.")
|
||||
parser.add_argument("--work_dir", type=str, default='', help="Path to save results.")
|
||||
parser.add_argument("--style", type=str, default='anime', help="resume training from similar style.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
Loading…
Reference in New Issue