mirror of https://github.com/menyifang/DCT-Net
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
145 lines
4.4 KiB
Python
145 lines
4.4 KiB
Python
# model blending technique
|
|
import os
|
|
import cv2 as cv
|
|
import torch
|
|
from model import Generator
|
|
import math
|
|
import argparse
|
|
|
|
def extract_conv_names(model):
|
|
model = list(model.keys())
|
|
conv_name = []
|
|
resolutions = [4*2**x for x in range(9)]
|
|
level_names = [["Conv0_up", "Const"], ["Conv1", "ToRGB"]]
|
|
|
|
|
|
def blend_models(model_1, model_2, resolution, level, blend_width=None):
|
|
resolutions = [4 * 2 ** i for i in range(7)]
|
|
mid = resolutions.index(resolution)
|
|
|
|
device = "cuda"
|
|
|
|
size = 256
|
|
latent = 512
|
|
n_mlp = 8
|
|
channel_multiplier =2
|
|
G_1 = Generator(
|
|
size, latent, n_mlp, channel_multiplier=channel_multiplier
|
|
).to(device)
|
|
ckpt_ffhq = torch.load(model_1, map_location=lambda storage, loc: storage)
|
|
G_1.load_state_dict(ckpt_ffhq["g"], strict=False)
|
|
|
|
|
|
G_2 = Generator(
|
|
size, latent, n_mlp, channel_multiplier=channel_multiplier
|
|
).to(device)
|
|
ckpt_toon = torch.load(model_2)
|
|
G_2.load_state_dict(ckpt_toon["g_ema"])
|
|
|
|
|
|
|
|
# G_1 = stylegan2.models.load(model_1)
|
|
# G_2 = stylegan2.models.load(model_2)
|
|
model_1_state_dict = G_1.state_dict()
|
|
model_2_state_dict = G_2.state_dict()
|
|
assert(model_1_state_dict.keys() == model_2_state_dict.keys())
|
|
G_out = G_1.clone()
|
|
|
|
layers = []
|
|
ys = []
|
|
for k, v in model_1_state_dict.items():
|
|
if k.startswith('convs.'):
|
|
pos = int(k[len('convs.')])
|
|
x = pos - mid
|
|
if blend_width:
|
|
exponent = -x / blend_width
|
|
y = 1 / (1 + math.exp(exponent))
|
|
else:
|
|
y = 1 if x > 0 else 0
|
|
|
|
layers.append(k)
|
|
ys.append(y)
|
|
elif k.startswith('to_rgbs.'):
|
|
pos = int(k[len('to_rgbs.')])
|
|
x = pos - mid
|
|
if blend_width:
|
|
exponent = -x / blend_width
|
|
y = 1 / (1 + math.exp(exponent))
|
|
else:
|
|
y = 1 if x > 0 else 0
|
|
layers.append(k)
|
|
ys.append(y)
|
|
out_state = G_out.state_dict()
|
|
for y, layer in zip(ys, layers):
|
|
out_state[layer] = y * model_2_state_dict[layer] + \
|
|
(1 - y) * model_1_state_dict[layer]
|
|
print('blend layer %s'%str(y))
|
|
G_out.load_state_dict(out_state)
|
|
return G_out
|
|
|
|
|
|
def blend_models_2(model_1, model_2, resolution, level, blend_width=None):
|
|
# resolution = f"{resolution}x{resolution}"
|
|
resolutions = [4 * 2 ** i for i in range(7)]
|
|
mid = [resolutions.index(r) for r in resolution]
|
|
|
|
G_1 = stylegan2.models.load(model_1)
|
|
G_2 = stylegan2.models.load(model_2)
|
|
model_1_state_dict = G_1.state_dict()
|
|
model_2_state_dict = G_2.state_dict()
|
|
assert(model_1_state_dict.keys() == model_2_state_dict.keys())
|
|
G_out = G_1.clone()
|
|
|
|
layers = []
|
|
ys = []
|
|
for k, v in model_1_state_dict.items():
|
|
if k.startswith('G_synthesis.conv_blocks.'):
|
|
pos = int(k[len('G_synthesis.conv_blocks.')])
|
|
y = 0 if pos in mid else 1
|
|
layers.append(k)
|
|
ys.append(y)
|
|
elif k.startswith('G_synthesis.to_data_layers.'):
|
|
pos = int(k[len('G_synthesis.to_data_layers.')])
|
|
y = 0 if pos in mid else 1
|
|
layers.append(k)
|
|
ys.append(y)
|
|
# print(ys, layers)
|
|
out_state = G_out.state_dict()
|
|
for y, layer in zip(ys, layers):
|
|
out_state[layer] = y * model_2_state_dict[layer] + \
|
|
(1 - y) * model_1_state_dict[layer]
|
|
G_out.load_state_dict(out_state)
|
|
return G_out
|
|
|
|
|
|
def main(name):
|
|
|
|
resolution = 4
|
|
|
|
model_name = '001000.pt'
|
|
|
|
G_out = blend_models("pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
|
|
"face_generation/experiment_stylegan/"+name+"/models/"+model_name,
|
|
resolution,
|
|
None)
|
|
# G_out.save('G_blend.pth')
|
|
outdir = os.path.join('face_generation/experiment_stylegan',name,'models_blend')
|
|
if not os.path.exists(outdir):
|
|
os.makedirs(outdir)
|
|
|
|
outpath = os.path.join(outdir, 'G_blend_'+str(model_name[:-3])+'_'+ str(resolution)+'.pt')
|
|
torch.save(
|
|
{
|
|
"g_ema": G_out.state_dict(),
|
|
},
|
|
# 'G_blend_570000_16.pth',
|
|
outpath
|
|
)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser(description="style blender")
|
|
parser.add_argument('--name', type=str, default='')
|
|
args = parser.parse_args()
|
|
print('model name:%s'%args.name)
|
|
main(args.name) |