You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

585 lines
18 KiB
Python

import os
import argparse
import math
import random
import json
import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
from criteria import id_loss
try:
import wandb
except ImportError:
wandb = None
from dataset import MultiResolutionDataset
from distributed import (
get_rank,
synchronize,
reduce_loss_dict,
reduce_sum,
get_world_size,
)
from op import conv2d_gradfix
from non_leaking import augment, AdaptiveAugment
def data_sampler(dataset, shuffle, distributed):
if distributed:
return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
return data.RandomSampler(dataset)
else:
return data.SequentialSampler(dataset)
def requires_grad(model, flag=True):
for p in model.parameters():
p.requires_grad = flag
def accumulate(model1, model2, decay=0.999):
par1 = dict(model1.named_parameters())
par2 = dict(model2.named_parameters())
for k in par1.keys():
par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
def sample_data(loader):
while True:
for batch in loader:
yield batch
def d_logistic_loss(real_pred, fake_pred):
real_loss = F.softplus(-real_pred)
fake_loss = F.softplus(fake_pred)
return real_loss.mean() + fake_loss.mean()
def d_r1_loss(real_pred, real_img):
with conv2d_gradfix.no_weight_gradients():
grad_real, = autograd.grad(
outputs=real_pred.sum(), inputs=real_img, create_graph=True
)
grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
def g_nonsaturating_loss(fake_pred):
loss = F.softplus(-fake_pred).mean()
return loss
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
noise = torch.randn_like(fake_img) / math.sqrt(
fake_img.shape[2] * fake_img.shape[3]
)
grad, = autograd.grad(
outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
)
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
path_penalty = (path_lengths - path_mean).pow(2).mean()
return path_penalty, path_mean.detach(), path_lengths
def make_noise(batch, latent_dim, n_noise, device):
if n_noise == 1:
return torch.randn(batch, latent_dim, device=device)
noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
return noises
def mixing_noise(batch, latent_dim, prob, device):
if prob > 0 and random.random() < prob:
return make_noise(batch, latent_dim, 2, device)
else:
return [make_noise(batch, latent_dim, 1, device)]
def set_grad_none(model, targets):
for n, p in model.named_parameters():
if n in targets:
p.grad = None
def train(args, loader, generator, generator_source, discriminator, g_optim, d_optim, g_ema, device):
loader = sample_data(loader)
pbar = range(args.iter)
if get_rank() == 0:
pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
mean_path_length = 0
d_loss_val = 0
r1_loss = torch.tensor(0.0, device=device)
g_loss_val = 0
path_loss = torch.tensor(0.0, device=device)
path_lengths = torch.tensor(0.0, device=device)
mean_path_length_avg = 0
loss_dict = {}
### add id loss, content loss
g_id_loss = id_loss.IDLoss().to(device).eval()
if args.distributed:
g_module = generator.module
d_module = discriminator.module
else:
g_module = generator
d_module = discriminator
accum = 0.5 ** (32 / (10 * 1000))
ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
r_t_stat = 0
if args.augment and args.augment_p == 0:
ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)
# sample_z = torch.randn(args.n_sample, args.latent, device=device)
sample_z = torch.load('noise.pt').to(device)
for idx in pbar:
i = idx + args.start_iter
if i > args.iter:
print("Done!")
break
real_img = next(loader)
real_img = real_img.to(device)
requires_grad(generator, False)
requires_grad(discriminator, True)
noise = mixing_noise(args.batch, args.latent, args.mixing, device)
fake_img, _ = generator(noise)
if args.augment:
real_img_aug, _ = augment(real_img, ada_aug_p)
fake_img, _ = augment(fake_img, ada_aug_p)
else:
real_img_aug = real_img
fake_pred = discriminator(fake_img)
real_pred = discriminator(real_img_aug)
d_loss = d_logistic_loss(real_pred, fake_pred)
loss_dict["d"] = d_loss
loss_dict["real_score"] = real_pred.mean()
loss_dict["fake_score"] = fake_pred.mean()
discriminator.zero_grad()
d_loss.backward()
d_optim.step()
if args.augment and args.augment_p == 0:
ada_aug_p = ada_augment.tune(real_pred)
r_t_stat = ada_augment.r_t_stat
d_regularize = i % args.d_reg_every == 0
if d_regularize:
real_img.requires_grad = True
if args.augment:
real_img_aug, _ = augment(real_img, ada_aug_p)
else:
real_img_aug = real_img
real_pred = discriminator(real_img_aug)
r1_loss = d_r1_loss(real_pred, real_img)
discriminator.zero_grad()
(args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
d_optim.step()
loss_dict["r1"] = r1_loss
requires_grad(generator, True)
requires_grad(discriminator, False)
noise = mixing_noise(args.batch, args.latent, args.mixing, device)
fake_img, _ = generator(noise)
if args.augment:
fake_img, _ = augment(fake_img, ada_aug_p)
fake_pred = discriminator(fake_img)
g_loss = g_nonsaturating_loss(fake_pred)
### id loss
fake_img_source, _ = generator_source(noise)
loss_id = g_id_loss(fake_img, fake_img_source)
loss_id = float(loss_id)
loss_dict['loss_id'] = loss_id
# print(loss_id)
g_loss += loss_id*0.1
loss_dict["g"] = g_loss
generator.zero_grad()
g_loss.backward()
g_optim.step()
g_regularize = i % args.g_reg_every == 0
if g_regularize:
path_batch_size = max(1, args.batch // args.path_batch_shrink)
noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
fake_img, latents = generator(noise, return_latents=True)
path_loss, mean_path_length, path_lengths = g_path_regularize(
fake_img, latents, mean_path_length
)
generator.zero_grad()
weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
if args.path_batch_shrink:
weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
weighted_path_loss.backward()
g_optim.step()
mean_path_length_avg = (
reduce_sum(mean_path_length).item() / get_world_size()
)
loss_dict["path"] = path_loss
loss_dict["path_length"] = path_lengths.mean()
accumulate(g_ema, g_module, accum)
loss_reduced = reduce_loss_dict(loss_dict)
d_loss_val = loss_reduced["d"].mean().item()
g_loss_val = loss_reduced["g"].mean().item()
r1_val = loss_reduced["r1"].mean().item()
path_loss_val = loss_reduced["path"].mean().item()
real_score_val = loss_reduced["real_score"].mean().item()
fake_score_val = loss_reduced["fake_score"].mean().item()
path_length_val = loss_reduced["path_length"].mean().item()
if get_rank() == 0:
pbar.set_description(
(
f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
f"augment: {ada_aug_p:.4f}"
)
)
if wandb and args.wandb:
wandb.log(
{
"Generator": g_loss_val,
"Discriminator": d_loss_val,
"Augment": ada_aug_p,
"Rt": r_t_stat,
"R1": r1_val,
"Path Length Regularization": path_loss_val,
"Mean Path Length": mean_path_length,
"Real Score": real_score_val,
"Fake Score": fake_score_val,
"Path Length": path_length_val,
}
)
if i % args.sample_every == 0:
with torch.no_grad():
g_ema.eval()
sample, _ = g_ema([sample_z])
outpath = os.path.join(args.output, args.name,'sample', str(i).zfill(6)+'.png')
utils.save_image(
sample,
outpath,
# nrow=int(args.n_sample ** 0.5),
nrow=int(math.sqrt(sample.shape[0])),
normalize=True,
range=(-1, 1),
)
if i % args.save_every == 0:
outpath_ckpt = os.path.join(args.output, args.name, 'models', str(i).zfill(6) + '.pt')
torch.save(
{
"g": g_module.state_dict(),
"d": d_module.state_dict(),
"g_ema": g_ema.state_dict(),
"g_optim": g_optim.state_dict(),
"d_optim": d_optim.state_dict(),
"args": args,
"ada_aug_p": ada_aug_p,
},
outpath_ckpt,
)
if __name__ == "__main__":
device = "cuda"
parser = argparse.ArgumentParser(description="StyleGAN2 trainer")
parser.add_argument('--config', type=str, default='config/conf_server_train_condition.json')
parser.add_argument("--path", type=str, help="path to the lmdb dataset")
parser.add_argument("--name", type=str, help="name of experiment")
parser.add_argument('--arch', type=str, default='stylegan2', help='model architectures (stylegan2 | swagan)')
parser.add_argument(
"--iter", type=int, default=800000, help="total training iterations"
)
parser.add_argument(
"--batch", type=int, default=1, help="batch sizes for each gpus"
)
parser.add_argument(
"--n_sample",
type=int,
default=64,
help="number of the samples generated during training",
)
parser.add_argument(
"--size", type=int, default=1024, help="image sizes for the model"
)
parser.add_argument(
"--r1", type=float, default=10, help="weight of the r1 regularization"
)
parser.add_argument(
"--path_regularize",
type=float,
default=2,
help="weight of the path length regularization",
)
parser.add_argument(
"--path_batch_shrink",
type=int,
default=2,
help="batch size reducing factor for the path length regularization (reduce memory consumption)",
)
parser.add_argument(
"--d_reg_every",
type=int,
default=16,
help="interval of the applying r1 regularization",
)
parser.add_argument(
"--g_reg_every",
type=int,
default=4,
help="interval of the applying path length regularization",
)
parser.add_argument(
"--mixing", type=float, default=0.9, help="probability of latent code mixing"
)
parser.add_argument(
"--ckpt",
type=str,
help="path to the checkpoints to resume training",
)
parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
parser.add_argument(
"--channel_multiplier",
type=int,
default=2,
help="channel multiplier factor for the model. config-f = 2, else = 1",
)
parser.add_argument(
"--wandb", action="store_true", help="use weights and biases logging"
)
parser.add_argument(
"--local_rank", type=int, default=0, help="local rank for distributed training"
)
parser.add_argument(
"--augment", action="store_true", help="apply non leaking augmentation"
)
parser.add_argument(
"--augment_p",
type=float,
default=0,
help="probability of applying augmentation. 0 = use adaptive augmentation",
)
parser.add_argument(
"--ada_target",
type=float,
default=0.6,
help="target augmentation probability for adaptive augmentation",
)
parser.add_argument(
"--ada_length",
type=int,
default=500 * 1000,
help="target duraing to reach augmentation probability for adaptive augmentation",
)
parser.add_argument(
"--ada_every",
type=int,
default=256,
help="probability update interval of the adaptive augmentation",
)
args = parser.parse_args()
# from config updata paras
opt = vars(args)
with open(args.config) as f:
config = json.load(f)['parameters']
for key, value in config.items():
opt[key] = value
output_sample = os.path.join(args.output, args.name,'sample')
output_model = os.path.join(args.output, args.name,'models')
if not os.path.exists(output_sample):
os.makedirs(output_sample)
if not os.path.exists(output_model):
os.makedirs(output_model)
n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
args.distributed = n_gpu > 1
if args.distributed:
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
synchronize()
args.latent = 512
args.n_mlp = 8
args.start_iter = 0
if args.arch == 'stylegan2':
from model import Generator, Discriminator
elif args.arch == 'swagan':
from swagan import Generator, Discriminator
generator = Generator(
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
).to(device)
### add source G
generator_source = Generator(
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
).to(device)
generator_source.eval()
discriminator = Discriminator(
args.size, channel_multiplier=args.channel_multiplier
).to(device)
g_ema = Generator(
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
).to(device)
g_ema.eval()
accumulate(g_ema, generator, 0)
g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
g_optim = optim.Adam(
generator.parameters(),
lr=args.lr * g_reg_ratio,
betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
)
d_optim = optim.Adam(
discriminator.parameters(),
lr=args.lr * d_reg_ratio,
betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
)
if args.ckpt is not None:
print("load model:", args.ckpt)
ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
try:
ckpt_name = os.path.basename(args.ckpt)
args.start_iter = int(os.path.splitext(ckpt_name)[0])
except ValueError:
pass
generator.load_state_dict(ckpt["g"], strict=False)
discriminator.load_state_dict(ckpt["d"], strict=False)
g_ema.load_state_dict(ckpt["g_ema"], strict=False)
if args.ckpt_ffhq is not None:
ckpt_ffhq = torch.load(args.ckpt_ffhq, map_location=lambda storage, loc: storage)
generator_source.load_state_dict(ckpt_ffhq["g"], strict=False)
else:
print('No source ckpt!')
# generator.load_state_dict(ckpt["g"])
# discriminator.load_state_dict(ckpt["d"])
# g_ema.load_state_dict(ckpt["g_ema"])
# g_optim.load_state_dict(ckpt["g_optim"])
# d_optim.load_state_dict(ckpt["d_optim"])
if args.distributed:
generator = nn.parallel.DistributedDataParallel(
generator,
device_ids=[args.local_rank],
output_device=args.local_rank,
broadcast_buffers=False,
)
discriminator = nn.parallel.DistributedDataParallel(
discriminator,
device_ids=[args.local_rank],
output_device=args.local_rank,
broadcast_buffers=False,
)
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
]
)
dataset = MultiResolutionDataset(args.path, transform, args.size)
loader = data.DataLoader(
dataset,
batch_size=args.batch,
sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
drop_last=True,
)
if get_rank() == 0 and wandb is not None and args.wandb:
wandb.init(project="stylegan 2")
train(args, loader, generator, generator_source, discriminator, g_optim, d_optim, g_ema, device)