# model blending technique
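# Layer swapping / blending: build one generator whose coarse (low-resolution)
# layers come from one StyleGAN2 checkpoint and whose fine (high-resolution)
# layers come from another, with an optional sigmoid crossfade between the
# two around a chosen resolution (see blend_models below).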
import argparse
import copy
import math
import os

import torch

from model import Generator

def extract_conv_names(model):
    # Unused leftover helper: it collects the state-dict key names and the
    # resolution/level tables but was never finished, so it just returns the
    # (empty) name list.
    keys = list(model.keys())
    conv_names = []
    resolutions = [4 * 2 ** x for x in range(9)]
    level_names = [["Conv0_up", "Const"], ["Conv1", "ToRGB"]]
    return conv_names

def blend_models(model_1, model_2, resolution, level, blend_width=None):
    # Blend two StyleGAN2 generators around the layer operating at
    # `resolution`. `level` is accepted for interface compatibility but is
    # currently unused.
    resolutions = [4 * 2 ** i for i in range(7)]  # 4, 8, ..., 256
    mid = resolutions.index(resolution)
    device = "cuda"
    size = 256
    latent = 512
    n_mlp = 8
    channel_multiplier = 2
    G_1 = Generator(
        size, latent, n_mlp, channel_multiplier=channel_multiplier
    ).to(device)
    ckpt_ffhq = torch.load(model_1, map_location=lambda storage, loc: storage)
    G_1.load_state_dict(ckpt_ffhq["g"], strict=False)
    G_2 = Generator(
        size, latent, n_mlp, channel_multiplier=channel_multiplier
    ).to(device)
    ckpt_toon = torch.load(model_2, map_location=lambda storage, loc: storage)
    G_2.load_state_dict(ckpt_toon["g_ema"])
    # G_1 = stylegan2.models.load(model_1)
    # G_2 = stylegan2.models.load(model_2)
    model_1_state_dict = G_1.state_dict()
    model_2_state_dict = G_2.state_dict()
    assert model_1_state_dict.keys() == model_2_state_dict.keys()
    # nn.Module has no .clone(); deep-copy the first generator instead.
    G_out = copy.deepcopy(G_1)
    layers = []
    ys = []
    for k, v in model_1_state_dict.items():
        if k.startswith('convs.'):
            # Parse the layer index after the prefix (handles indices >= 10).
            pos = int(k.split('.')[1])
            x = pos - mid
            if blend_width:
                # Sigmoid crossfade: the weight shifts smoothly from model_1
                # (below `mid`) to model_2 (above `mid`).
                exponent = -x / blend_width
                y = 1 / (1 + math.exp(exponent))
            else:
                # Hard switch at `mid`.
                y = 1 if x > 0 else 0
            layers.append(k)
            ys.append(y)
        elif k.startswith('to_rgbs.'):
            pos = int(k.split('.')[1])
            x = pos - mid
            if blend_width:
                exponent = -x / blend_width
                y = 1 / (1 + math.exp(exponent))
            else:
                y = 1 if x > 0 else 0
            layers.append(k)
            ys.append(y)
    out_state = G_out.state_dict()
    for y, layer in zip(ys, layers):
        # Per-layer linear interpolation between the two checkpoints.
        out_state[layer] = y * model_2_state_dict[layer] + \
            (1 - y) * model_1_state_dict[layer]
        print('blend layer %s with weight %s' % (layer, y))
    G_out.load_state_dict(out_state)
    return G_out
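
# A minimal usage sketch for blend_models (the paths below are placeholders,
# not files shipped with this repo): crossfade around the 16x16 layer with a
# width of one layer, then save the blended weights.
#
#   G = blend_models("ffhq_256.pt", "toon_256.pt", resolution=16,
#                    level=None, blend_width=1.0)
#   torch.save({"g_ema": G.state_dict()}, "G_blend_16.pt")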

def blend_models_2(model_1, model_2, resolution, level, blend_width=None):
    # Variant written against the standalone `stylegan2` package (note: it is
    # not imported in this file), whose models expose .clone() and use the
    # G_synthesis.conv_blocks / G_synthesis.to_data_layers key naming.
    # Here `resolution` is a list: blocks at those resolutions are kept from
    # model_1 and every other block is taken from model_2.
    # resolution = f"{resolution}x{resolution}"
    resolutions = [4 * 2 ** i for i in range(7)]
    mid = [resolutions.index(r) for r in resolution]
    G_1 = stylegan2.models.load(model_1)
    G_2 = stylegan2.models.load(model_2)
    model_1_state_dict = G_1.state_dict()
    model_2_state_dict = G_2.state_dict()
    assert model_1_state_dict.keys() == model_2_state_dict.keys()
    G_out = G_1.clone()
    layers = []
    ys = []
    for k, v in model_1_state_dict.items():
        if k.startswith('G_synthesis.conv_blocks.'):
            # Parse the block index after the prefix (handles indices >= 10).
            pos = int(k.split('.')[2])
            y = 0 if pos in mid else 1
            layers.append(k)
            ys.append(y)
        elif k.startswith('G_synthesis.to_data_layers.'):
            pos = int(k.split('.')[2])
            y = 0 if pos in mid else 1
            layers.append(k)
            ys.append(y)
    # print(ys, layers)
    out_state = G_out.state_dict()
    for y, layer in zip(ys, layers):
        out_state[layer] = y * model_2_state_dict[layer] + \
            (1 - y) * model_1_state_dict[layer]
    G_out.load_state_dict(out_state)
    return G_out
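
# Unlike blend_models, blend_models_2 performs a hard per-block swap (y is 0
# or 1): e.g. resolution=[4, 8] keeps only the 4x4 and 8x8 blocks of model_1
# and takes every other block from model_2.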

def main(name):
    resolution = 4
    model_name = '001000.pt'
    G_out = blend_models(
        "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
        "face_generation/experiment_stylegan/" + name + "/models/" + model_name,
        resolution,
        None,
    )
    # G_out.save('G_blend.pth')
    outdir = os.path.join('face_generation/experiment_stylegan', name, 'models_blend')
    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(
        outdir, 'G_blend_' + model_name[:-3] + '_' + str(resolution) + '.pt'
    )
    torch.save(
        {
            "g_ema": G_out.state_dict(),
        },
        # 'G_blend_570000_16.pth',
        outpath,
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="style blender")
    parser.add_argument('--name', type=str, default='')
    args = parser.parse_args()
    print('model name: %s' % args.name)
    main(args.name)
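
# Example invocation (the script filename and experiment name are
# placeholders):
#   python blend.py --name my_experiment
# This blends the pretrained FFHQ checkpoint with
# face_generation/experiment_stylegan/my_experiment/models/001000.pt at
# resolution 4 and writes the result to that experiment's models_blend/ dir.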