import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import torch
from torchvision import utils
from model import Generator
from tqdm import tqdm
import json
import glob

from PIL import Image

def make_image(tensor):
    """Convert a batch of image tensors to a uint8 NumPy array.

    Values are clamped to [-1, 1], shifted and scaled to [0, 255], cast to
    ``torch.uint8`` (the cast drops the fractional part, it does not round),
    and the axes are reordered with ``permute(0, 2, 3, 1)``.

    Args:
        tensor: 4-D floating-point tensor; presumably laid out as
            (batch, channel, height, width) -- confirm against callers.
            It is left unmodified.

    Returns:
        A CPU ``numpy.ndarray`` of dtype ``uint8`` with the axes in
        (0, 2, 3, 1) order relative to the input.
    """
    return (
        tensor.detach()
        # Out-of-place clamp: detach() shares storage with the input, so an
        # in-place clamp_ here would silently overwrite the caller's tensor.
        .clamp(min=-1, max=1)
        # The clamp result is a fresh tensor, so in-place ops are safe now.
        .add_(1)
        .div_(2)
        .mul_(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )

def generate(args, g_ema, device, mean_latent, model_name, g_ema_ffhq):
    """Sample images from ``g_ema`` and save them as numbered PNG files.

    For each of ``args.pics`` iterations a latent batch of shape
    ``(args.sample, args.latent)`` is drawn from a standard normal and fed
    to ``g_ema``. When ``args.form == "pair"`` the same latents are also fed
    to ``g_ema_ffhq`` and the two outputs are concatenated along dimension 3
    (FFHQ output first). Each result is written to
    ``<args.save_dir>/<zero-padded index>.png``.

    Args:
        args: Namespace providing ``save_dir``, ``pics``, ``sample``,
            ``latent``, ``truncation`` and ``form``.
        g_ema: Generator to sample from; it is switched to eval mode.
        device: Device on which the latent vectors are created.
        mean_latent: Passed as ``truncation_latent`` to both generators.
        model_name: Unused; kept so existing callers keep working.
        g_ema_ffhq: Second generator, only used when ``args.form == "pair"``.
    """
    # Local import: only needed for the torchvision compatibility check below.
    import inspect

    outdir = args.save_dir
    # An empty save_dir means "current directory": os.path.join("", name)
    # yields just the name, whereas os.makedirs("") would raise.
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    # torchvision renamed make_grid's ``range`` keyword to ``value_range`` and
    # later dropped the old name; pick whichever the installed version accepts.
    grid_params = inspect.signature(utils.make_grid).parameters
    range_key = "value_range" if "value_range" in grid_params else "range"
    save_kwargs = {"nrow": 1, "normalize": True, range_key: (-1, 1)}

    paired = args.form == "pair"

    with torch.no_grad():
        g_ema.eval()
        if paired:
            # Keep both generators in the same (eval) mode.
            g_ema_ffhq.eval()

        for i in tqdm(range(args.pics)):
            sample_z = torch.randn(args.sample, args.latent, device=device)

            res, _ = g_ema(
                [sample_z], truncation=args.truncation, truncation_latent=mean_latent
            )
            if paired:
                sample_face, _ = g_ema_ffhq(
                    [sample_z], truncation=args.truncation, truncation_latent=mean_latent
                )
                # Side-by-side layout: FFHQ output first, then g_ema output.
                res = torch.cat([sample_face, res], 3)

            outpath = os.path.join(outdir, f"{i:06d}.png")
            utils.save_image(res, outpath, **save_kwargs)




if __name__ == "__main__":
    # Script entry point: parse options, overlay the JSON config, load the
    # blended generator (and, for paired output, the FFHQ generator), then
    # write samples via generate().
    device = "cuda"

    parser = argparse.ArgumentParser(description="Generate samples from the generator")
    parser.add_argument('--config', type=str, default='config/conf_server_test_blend_shell.json')
    parser.add_argument('--name', type=str, default='')
    parser.add_argument('--save_dir', type=str, default='')

    parser.add_argument('--form', type=str, default='single')
    parser.add_argument(
        "--size", type=int, default=256, help="output image size of the generator"
    )
    parser.add_argument(
        "--sample",
        type=int,
        default=1,
        help="number of samples to be generated for each image",
    )
    parser.add_argument(
        "--pics", type=int, default=20, help="number of images to be generated"
    )
    parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
    parser.add_argument(
        "--truncation_mean",
        type=int,
        default=4096,
        help="number of vectors to calculate mean for the truncation",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="stylegan2-ffhq-config-f.pt",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier of the generator. config-f = 2, else = 1",
    )

    args = parser.parse_args()

    # Overlay the config file's "parameters" onto the parsed arguments.
    # Config values win over command-line values, and keys that argparse
    # does not declare (e.g. ffhq_ckpt) become attributes of args as well.
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)['parameters']
    vars(args).update(config)

    # --ckpt is always replaced by a blended checkpoint found under the
    # experiment directory. glob order is arbitrary, so sort to make the
    # choice deterministic and fail with a clear message when nothing matches.
    ckpt_pattern = 'face_generation/experiment_stylegan/' + args.name + '/models_blend/G_blend_' + '*'
    ckpt_matches = sorted(glob.glob(ckpt_pattern))
    if not ckpt_matches:
        raise FileNotFoundError(
            "no blended checkpoint matches %r" % ckpt_pattern
        )
    args.ckpt = ckpt_matches[0]

    args.latent = 512
    args.n_mlp = 8

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    # map_location avoids failures when the checkpoint was saved on a CUDA
    # index that is not visible in this process.
    checkpoint = torch.load(args.ckpt, map_location=device)
    # strict=False: missing or unexpected keys in the checkpoint are ignored.
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)

    # The FFHQ generator is only consumed by generate() in "pair" mode, so
    # skip building and loading it otherwise.
    g_ema_ffhq = None
    if args.form == "pair":
        ffhq_ckpt = getattr(args, "ffhq_ckpt", None)
        if not ffhq_ckpt:
            parser.error("form 'pair' requires 'ffhq_ckpt' in the config parameters")
        g_ema_ffhq = Generator(
            args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
        ).to(device)
        checkpoint_ffhq = torch.load(ffhq_ckpt, map_location=device)
        g_ema_ffhq.load_state_dict(checkpoint_ffhq["g_ema"], strict=False)

    # A mean latent is only needed when truncation is actually applied.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    model_name = os.path.basename(args.ckpt)
    # Report the directory generate() really writes to (args.save_dir).
    print('save generated samples to %s' % os.path.abspath(args.save_dir))
    generate(args, g_ema, device, mean_latent, model_name, g_ema_ffhq)