# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np
from glob import glob

import torch

from codegeex.megatron import get_args, mpu, print_rank_0, update_num_microbatches, utils

# Checkpoint format version for this process, recorded by
# set_checkpoint_version() and read by get_checkpoint_version().
# Stays None until a version has been recorded.
_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    """Record the checkpoint format version for this process.

    The first call stores ``value``. Any later call must pass a value equal
    to the stored one, otherwise an AssertionError is raised.
    """
    global _CHECKPOINT_VERSION
    already_recorded = _CHECKPOINT_VERSION is not None
    assert not already_recorded or _CHECKPOINT_VERSION == value, "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    """Return the version stored by ``set_checkpoint_version``.

    Returns ``None`` if no version has been recorded yet.
    """
    # Reading a module global needs no ``global`` declaration.
    return _CHECKPOINT_VERSION


def check_checkpoint_args(checkpoint_args):
    """Verify that architecture arguments saved in a checkpoint agree with
    the arguments of the current run.

    Each mismatch raises an AssertionError naming the argument and both
    values. Tokenizer/vocabulary fields are only compared when
    ``args.vocab_file`` is set. For checkpoint versions below 3.0 the
    tensor-parallel size is read from the legacy ``model_parallel_size``
    field. From 3.0 on, both the tensor- and pipeline-parallel sizes are
    compared.

    Args:
        checkpoint_args: the ``args`` object stored in the checkpoint.
    """
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        # Legacy checkpoints may store the value under a different name.
        saved_name = arg_name if old_arg_name is None else old_arg_name
        checkpoint_value = getattr(checkpoint_args, saved_name)
        args_value = getattr(args, arg_name)
        error_message = (
            "{} value from checkpoint ({}) is not equal to the "
            "input argument value ({}).".format(arg_name, checkpoint_value, args_value)
        )
        assert checkpoint_value == args_value, error_message

    for name in (
        "num_layers",
        "hidden_size",
        "num_attention_heads",
        "max_position_embeddings",
    ):
        _compare(name)

    if args.vocab_file:
        for name in (
            "make_vocab_size_divisible_by",
            "padded_vocab_size",
            "tokenizer_type",
        ):
            _compare(name)

    version = get_checkpoint_version()
    if version < 3.0:
        _compare("tensor_model_parallel_size", old_arg_name="model_parallel_size")
    elif version >= 3.0:
        _compare("tensor_model_parallel_size")
        _compare("pipeline_model_parallel_size")


def ensure_directory_exists(filename):
    """Create the parent directory of ``filename`` if it does not exist yet.

    When there is more than one tensor/pipeline rank, several processes save
    their own files into the same ``global_step{N}`` directory at the same
    time. Creation must therefore tolerate the directory appearing between
    the existence check and ``makedirs``. A bare filename (no directory
    component) needs nothing to be created.

    Args:
        filename: path of a file that is about to be written.
    """
    dirname = os.path.dirname(filename)
    if dirname:
        # exist_ok=True avoids the check-then-create race (FileExistsError)
        # between ranks writing into the same directory.
        os.makedirs(dirname, exist_ok=True)


def get_checkpoint_name(checkpoints_path, iteration, release=False):
    """Return the path of this rank's model-state file within a checkpoint.

    Non-release checkpoints live in ``<checkpoints_path>/global_step{iteration}``.
    Release checkpoints sit directly under ``checkpoints_path``. The file name
    always encodes the tensor-parallel rank. It also encodes the
    pipeline-parallel rank when the pipeline world size is larger than 1.
    """
    step_dir = "" if release else f"global_step{iteration}"
    uses_pipeline = mpu.get_pipeline_model_parallel_world_size() != 1
    tp_rank = mpu.get_tensor_model_parallel_rank()
    if uses_pipeline:
        file_name = "mp_rank_{:02d}_{:03d}_model_states.pt".format(
            tp_rank, mpu.get_pipeline_model_parallel_rank()
        )
    else:
        file_name = "mp_rank_{:02d}_model_states.pt".format(tp_rank)
    return os.path.join(checkpoints_path, step_dir, file_name)


def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the tracker file under ``checkpoints_path``.

    The tracker file records the most recently saved iteration (or the word
    ``release``), so that training can restart from that checkpoint.
    """
    tracker_basename = "latest_checkpointed_iteration.txt"
    return os.path.join(checkpoints_path, tracker_basename)


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint for ``iteration`` under ``args.save``.

    Without DeepSpeed, data-parallel rank 0 of every model-parallel group
    writes its own ``mp_rank_*_model_states.pt`` file. The file holds args,
    iteration, consumed tokens, model weights, and optionally the optimizer,
    LR-scheduler and RNG state.

    With DeepSpeed, every rank takes part in the engine's collective
    ``save_checkpoint`` and passes the same metadata dict as
    ``client_state``.

    Afterwards global rank 0 records ``iteration`` in the tracker file, so a
    later ``load_checkpoint`` can resume from it.

    Args:
        iteration: training iteration being saved.
        model: list of model chunks; unwrapped here unless using DeepSpeed.
        optimizer: saved unless ``args.no_save_optim`` (non-DeepSpeed only).
        lr_scheduler: saved unless ``args.no_save_optim`` (non-DeepSpeed only).
    """
    args = get_args()

    if not args.deepspeed:
        model = utils.unwrap_model(model)

    print_rank_0(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    # Without DeepSpeed only data-parallel rank 0 writes to disk. With
    # DeepSpeed every rank needs ``state_dict`` to pass as client_state.
    if (
        not torch.distributed.is_initialized()
        or mpu.get_data_parallel_rank() == 0
        or args.deepspeed
    ):
        # Arguments, iteration, and model.
        state_dict = {}
        state_dict["args"] = args
        state_dict["checkpoint_version"] = 3.0
        state_dict["iteration"] = iteration
        state_dict["tokens"] = args.consumed_train_tokens

        # DeepSpeed saves the model/optimizer/scheduler itself.
        if not args.deepspeed:
            if len(model) == 1:
                state_dict["model"] = model[0].state_dict_for_save_checkpoint()
            else:
                for i, model_chunk in enumerate(model):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    state_dict["model%d" % i] = model_chunk.state_dict_for_save_checkpoint()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    state_dict["optimizer"] = optimizer.state_dict()
                if lr_scheduler is not None:
                    state_dict["lr_scheduler"] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict["random_rng_state"] = random.getstate()
            state_dict["np_rng_state"] = np.random.get_state()
            state_dict["torch_rng_state"] = torch.get_rng_state()
            state_dict["cuda_rng_state"] = torch.cuda.get_rng_state()
            state_dict["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        if not args.deepspeed:
            checkpoint_name = get_checkpoint_name(args.save, iteration)
            ensure_directory_exists(checkpoint_name)
            torch.save(state_dict, checkpoint_name)

    if args.deepspeed:
        # The Megatron model serializes through state_dict_for_save_checkpoint
        # rather than the standard state_dict. DeepSpeed calls state_dict for
        # module saving, so temporarily point it at the right function.
        if args.no_pipeline_parallel:
            original_state_dict = model[0].module.state_dict
            model[0].module.state_dict = model[0].module.state_dict_for_save_checkpoint
        try:
            # Saving is a collective communication.
            checkpoint_name = get_checkpoint_name(args.save, iteration)
            # Strip the file name and the global_step{N} directory so that
            # DeepSpeed receives the save root. Presumably DeepSpeed creates
            # its own per-step directory; confirm against the DeepSpeed version.
            for _ in range(2):
                checkpoint_name = os.path.dirname(checkpoint_name)
            model[0].save_checkpoint(checkpoint_name, client_state=state_dict)
        finally:
            # Undo the patch even if saving fails, so the module is not left
            # with a redirected state_dict.
            if args.no_pipeline_parallel:
                model[0].module.state_dict = original_state_dict

    # Wait so everyone is done (necessary).
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(
        "  successfully saved checkpoint at iteration {:7d} to {}".format(
            iteration, args.save
        )
    )

    # Update the latest iteration. Write a temporary file and rename it, so
    # a crash mid-write can never leave a truncated tracker behind.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        tmp_tracker_filename = tracker_filename + ".tmp"
        with open(tmp_tracker_filename, "w") as f:
            f.write(str(iteration))
        os.replace(tmp_tracker_filename, tracker_filename)

    # Wait so everyone is done (not necessary).
    if torch.distributed.is_initialized():
        torch.distributed.barrier()


def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, "module"):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = (
        attention_module.num_attention_heads_per_partition
    )
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h]"""

        intermediate_shape = (
            num_splits,
            num_attention_heads_per_partition,
            hidden_size_per_attention_head,
        ) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h]"""

        intermediate_shape = (
            num_attention_heads_per_partition,
            hidden_size_per_attention_head,
            num_splits,
        ) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t


def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value parameter ordering for checkpoints older than 2.0.

    Fused ``query_key_value`` (3 splits) and ``key_value`` (2 splits) weights
    and biases are rewritten in place via ``_transpose_first_dim``:

    - Version 0 stores the splits first (``[num_splits, np, hn]``).
    - Version 1.0 stores them last (``[np, hn, num_splits]``).
    - Versions >= 2.0 are left untouched.

    If a fused parameter is found and the version is neither 0 nor 1.0, an
    error is printed and the process exits with a non-zero status.

    Args:
        model: a module, or a single-element list containing one.
        checkpoint_version: version number read from the checkpoint.
    """
    if not checkpoint_version < 2.0:
        return

    if isinstance(model, list):
        assert len(model) == 1
        model = model[0]

    if checkpoint_version == 0:
        num_splits_first = True
    elif checkpoint_version == 1.0:
        num_splits_first = False
    else:
        # Unsupported version; reported once a fused parameter is reached.
        num_splits_first = None

    # (name suffixes, number of fused projections). The suffix sets are
    # disjoint: ".query_key_value.*" does not end with ".key_value.*".
    fused_param_kinds = (
        ((".query_key_value.weight", ".query_key_value.bias"), 3),
        ((".key_value.weight", ".key_value.bias"), 2),
    )
    for name, param in model.named_parameters():
        for suffixes, num_splits in fused_param_kinds:
            if not name.endswith(suffixes):
                continue
            if num_splits_first is None:
                print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                # Non-zero status: this is a fatal error, not a clean exit.
                sys.exit(1)
            fixed_param = _transpose_first_dim(
                param.data, num_splits, num_splits_first, model
            )
            param.data.copy_(fixed_param)
            break

    print_rank_0(
        " succesfully fixed query-key-values ordering for"
        " checkpoint version {}".format(checkpoint_version)
    )


def load_deepspeed_state(model):
    """Load a DeepSpeed ``*model_states.pt`` file into ``model[0]``.

    ``args.load`` may be either of:

    - a directory containing exactly one file matching ``*model_states.pt``
      (model-parallel size 1);
    - the path of such a file.

    The file's ``"module"`` entry is loaded with ``strict=True``.
    """
    model = utils.unwrap_model(model)
    load_dir = get_args().load

    if not os.path.isdir(load_dir):
        model_state_path = load_dir
    else:
        candidates = glob(os.path.join(load_dir, "*model_states.pt"))
        assert len(candidates) == 1, (
            "only support loading deepspeed checkpoint of model parallel size 1"
            ", but got {}".format(candidates)
        )
        model_state_path = candidates[0]

    module_state = torch.load(model_state_path, map_location="cpu")["module"]
    model[0].load_state_dict(module_state, strict=True)


def load_checkpoint(
    model,
    optimizer,
    lr_scheduler,
    load_arg="load",
    strict=True,
):
    """Load a model checkpoint and return the iteration to resume from.

    The checkpoint location is ``getattr(args, load_arg)``:

    - With DeepSpeed, the engine's ``load_checkpoint`` is used.
    - A path ending in ``.pt`` is loaded directly as a release checkpoint.
    - Otherwise the location is a directory. Its tracker file gives either
      an iteration or ``release``. If the tracker file is missing, this
      rank's release file is used if present; if not, nothing is loaded and
      0 is returned.

    Release checkpoints and ``args.finetune`` reset the iteration to 0 and
    skip restoring the optimizer and RNG state. Fatal problems are printed
    and the process exits with a non-zero status.

    Args:
        model: list of model chunks.
        optimizer: optimizer to restore (non-DeepSpeed only); may be None.
        lr_scheduler: LR scheduler to restore (non-DeepSpeed only); may be None.
        load_arg: name of the ``args`` attribute holding the load path.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of
            parameters and buffers in model.

    Returns:
        The iteration number; 0 when starting from scratch, finetuning, or
        loading a release checkpoint.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    if args.deepspeed:
        loaded_dir, state_dict = model[0].load_checkpoint(load_dir)
        if loaded_dir is None:
            print_rank_0(
                "WARNING: could not find the metadata file {} ".format(load_dir)
            )
            print_rank_0(
                "    will not load any checkpoints and will start from random"
            )
            return 0
        release = False
        # The error messages below format checkpoint_name. Without this
        # binding they would raise NameError on the DeepSpeed path.
        checkpoint_name = loaded_dir
    else:
        model = utils.unwrap_model(model)

        if load_dir.endswith(".pt"):
            checkpoint_name = load_dir
            release = True
        else:
            tracker_filename = get_checkpoint_tracker_filename(load_dir)

            if not os.path.isfile(tracker_filename):
                # No tracker: fall back to a release checkpoint, if any.
                print_rank_0(
                    "WARNING: could not find the metadata file {} ".format(tracker_filename)
                )
                iteration = 0
                release = True
                checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
                if not os.path.isfile(checkpoint_name):
                    print_rank_0(
                        "    will not load any checkpoints and will start from random"
                    )
                    return 0
            else:
                # The tracker holds either an iteration number or "release".
                iteration = 0
                release = False
                with open(tracker_filename, "r") as f:
                    metastring = f.read().strip()
                try:
                    iteration = int(metastring)
                except ValueError:
                    release = metastring == "release"
                    if not release:
                        print_rank_0(
                            "ERROR: Invalid metadata file {}. Exiting".format(
                                tracker_filename
                            )
                        )
                        sys.exit(1)

                assert iteration > 0 or release, "error parsing metadata file {}".format(
                    tracker_filename
                )

                checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
                print_rank_0(f" loading checkpoint from {load_dir} at iteration {iteration}")

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility with checkpoints pickled against the
            # old module layout.
            print_rank_0(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        except Exception as e:
            # Exception (not BaseException) so that Ctrl-C is not swallowed.
            print_rank_0("could not load the checkpoint")
            print_rank_0(e)
            sys.exit(1)

    # Set checkpoint version.
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
            if "tokens" in state_dict:
                args.consumed_train_tokens = state_dict["tokens"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_0(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(checkpoint_name)
                )
                sys.exit(1)

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_0("could not find arguments in the checkpoint ...")

    # Model (DeepSpeed has already loaded it inside load_checkpoint).
    if not args.deepspeed:
        if len(model) == 1:
            # save_checkpoint stores the weights under "model", while
            # model_states.pt-style files (as read by load_deepspeed_state)
            # use "module". Accept either, preferring "module" as before.
            model_key = "module" if "module" in state_dict else "model"
            model[0].load_state_dict(state_dict[model_key], strict=strict)
        else:
            for i, model_chunk in enumerate(model):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model_chunk.load_state_dict(state_dict["model%d" % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed.
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f" checkpoint version {checkpoint_version}")
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not args.deepspeed:
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(state_dict["optimizer"])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
            except KeyError:
                print_rank_0(
                    "Unable to load optimizer from checkpoint {}. "
                    "Specify --no-load-optim or --finetune to prevent "
                    "attempting to load the optimizer state, "
                    "exiting ...".format(checkpoint_name)
                )
                sys.exit(1)

    # RNG states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            # Treat an empty tracker-states entry like a missing one.
            if not state_dict["rng_tracker_states"]:
                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_0(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name)
            )
            sys.exit(1)

    # No barrier here: some utilities load a checkpoint without
    # torch.distributed being initialized.

    print_rank_0(
        f"  successfully loaded checkpoint from {load_dir} "
        f"at iteration {iteration}"
    )

    return iteration


def load_biencoder_checkpoint(
    model, only_query_model=False, only_context_model=False, custom_load_path=None
):
    """Selectively load the query and/or context towers of a retrieval
    (bi-encoder) model from a saved checkpoint, for indexing/retrieval.

    The iteration is read from the tracker file under the load path
    (``custom_load_path`` if given, otherwise ``args.load``). The tracker
    must contain an integer iteration. ``only_query_model`` drops the
    ``context_model`` entry, and ``only_context_model`` drops
    ``query_model``, before loading.

    This calls ``torch.distributed.barrier()`` unconditionally, so the
    process group must be initialized.

    Returns:
        The unwrapped list of model chunks.
    """
    args = get_args()
    model = utils.unwrap_model(model)
    load_path = args.load if custom_load_path is None else custom_load_path

    tracker_path = get_checkpoint_tracker_filename(load_path)
    with open(tracker_path, "r") as tracker:
        iteration = int(tracker.read().strip())

    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print(
            "global rank {} is loading checkpoint {}".format(
                torch.distributed.get_rank(), checkpoint_name
            )
        )

    encoder_states = torch.load(checkpoint_name, map_location="cpu")["model"]
    if only_query_model:
        encoder_states.pop("context_model")
    if only_context_model:
        encoder_states.pop("query_model")

    assert len(model) == 1
    model[0].load_state_dict(encoder_states)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(" successfully loaded {}".format(checkpoint_name))

    return model