# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys

import numpy as np
from glob import glob

import torch

from codegeex.megatron import get_args, mpu, print_rank_0, update_num_microbatches, utils

_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    global _CHECKPOINT_VERSION
    return _CHECKPOINT_VERSION


def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and those retrieved from the checkpoint."""
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        error_message = (
            "{} value from checkpoint ({}) is not equal to the "
            "input argument value ({}).".format(arg_name, checkpoint_value, args_value)
        )
        assert checkpoint_value == args_value, error_message

    _compare("num_layers")
    _compare("hidden_size")
    _compare("num_attention_heads")
    _compare("max_position_embeddings")
    if args.vocab_file:
        _compare("make_vocab_size_divisible_by")
        _compare("padded_vocab_size")
        _compare("tokenizer_type")
    if get_checkpoint_version() < 3.0:
        _compare("tensor_model_parallel_size", old_arg_name="model_parallel_size")
    if get_checkpoint_version() >= 3.0:
        _compare("tensor_model_parallel_size")
        _compare("pipeline_model_parallel_size")


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration, release=False):
    """A unified checkpoint name."""
    if release:
        directory = ""
    else:
        directory = f"global_step{iteration}"
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}_model_states.pt".format(mpu.get_tensor_model_parallel_rank()),
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}_model_states.pt".format(
            mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank()
        ),
    )


def get_checkpoint_tracker_filename(checkpoints_path):
    """Tracker file records the latest checkpoint during training to restart from."""
    return os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")
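
# Illustrative layout only (the path below assumes --save=/checkpoints and
# iteration=1000; neither value comes from this module). With tensor-parallel
# rank 0 and no pipeline parallelism, the helpers above resolve to:
#   /checkpoints/latest_checkpointed_iteration.txt
#   /checkpoints/global_step1000/mp_rank_00_model_states.pt
# With pipeline parallelism enabled, the per-rank file instead follows the
# mp_rank_{tp:02d}_{pp:03d}_model_states.pt pattern, e.g. mp_rank_00_000_model_states.pt.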


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    if not args.deepspeed:
        model = utils.unwrap_model(model)

    print_rank_0(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    if (
        not torch.distributed.is_initialized()
        or mpu.get_data_parallel_rank() == 0
        or args.deepspeed
    ):
        # Arguments, iteration, and model.
        state_dict = {}
        state_dict["args"] = args
        state_dict["checkpoint_version"] = 3.0
        state_dict["iteration"] = iteration
        state_dict["tokens"] = args.consumed_train_tokens

        # DeepSpeed saves the model/optimizer/scheduler
        if not args.deepspeed:
            if len(model) == 1:
                state_dict["model"] = model[0].state_dict_for_save_checkpoint()
            else:
                for i in range(len(model)):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    state_dict["model%d" % i] = model[
                        i
                    ].state_dict_for_save_checkpoint()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    state_dict["optimizer"] = optimizer.state_dict()
                if lr_scheduler is not None:
                    state_dict["lr_scheduler"] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict["random_rng_state"] = random.getstate()
            state_dict["np_rng_state"] = np.random.get_state()
            state_dict["torch_rng_state"] = torch.get_rng_state()
            state_dict["cuda_rng_state"] = torch.cuda.get_rng_state()
            state_dict["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        if not args.deepspeed:
            ensure_directory_exists(checkpoint_name)
            torch.save(state_dict, checkpoint_name)

    if args.deepspeed:
        # The megatron model uses state_dict_for_save_checkpoint instead of the standard state_dict;
        # state_dict is used by deepspeed for module saving, so it needs to point to the right function.
        if args.no_pipeline_parallel:
            original_state_dict = model[0].module.state_dict
            model[0].module.state_dict = model[0].module.state_dict_for_save_checkpoint

        # Saving is a collective communication
        checkpoint_name = get_checkpoint_name(args.save, iteration)

        # Trim off the filename and mp_rank_* directory.
        for _ in range(2):
            checkpoint_name = os.path.dirname(checkpoint_name)
        model[0].save_checkpoint(checkpoint_name, client_state=state_dict)

        if args.no_pipeline_parallel:
            model[0].module.state_dict = original_state_dict

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(
        " successfully saved checkpoint at iteration {:7d} to {}".format(
            iteration, args.save
        )
    )

    # And update the latest iteration
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()


def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, "module"):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = (
        attention_module.num_attention_heads_per_partition
    )
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h]"""

        intermediate_shape = (
            num_splits,
            num_attention_heads_per_partition,
            hidden_size_per_attention_head,
        ) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h]"""

        intermediate_shape = (
            num_attention_heads_per_partition,
            hidden_size_per_attention_head,
            num_splits,
        ) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t


def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0
    """
    if checkpoint_version < 2.0:
        if isinstance(model, list):
            assert len(model) == 1
            model = model[0]
        for name, param in model.named_parameters():
            if name.endswith((".query_key_value.weight", ".query_key_value.bias")):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith((".key_value.weight", ".key_value.bias")):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
        print_rank_0(
            " successfully fixed query-key-values ordering for"
            " checkpoint version {}".format(checkpoint_version)
        )
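
# Worked example (illustrative only): for a ".query_key_value.weight" written by a
# version-0 checkpoint, the first dimension is grouped as [all q heads; all k heads;
# all v heads], i.e. it views as [3, np, hn, h] in _transpose_first_dim. The
# transpose(0, 1) above regroups it per attention head into [np, 3, hn, h], which is
# the ordering assumed by newer checkpoint versions.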


def load_deepspeed_state(model):
    model = utils.unwrap_model(model)
    args = get_args()
    load_dir = args.load
    if os.path.isdir(load_dir):
        model_state_paths = glob(os.path.join(load_dir, "*model_states.pt"))
        assert len(model_state_paths) == 1, (
            "only support loading deepspeed checkpoint of model parallel size 1"
            ", but got {}".format(model_state_paths)
        )
        model_state_path = model_state_paths[0]
    else:
        model_state_path = load_dir

    state_dict = torch.load(model_state_path, map_location="cpu")
    state_dict = state_dict["module"]
    model[0].load_state_dict(state_dict, strict=True)


def load_checkpoint(
    model,
    optimizer,
    lr_scheduler,
    load_arg="load",
    strict=True,
):
    """Load a model checkpoint and return the iteration.

    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    if args.deepspeed:
        loaded_dir, state_dict = model[0].load_checkpoint(load_dir)
        if loaded_dir is None:
            print_rank_0(
                "WARNING: could not find the metadata file {} ".format(load_dir)
            )
            print_rank_0(
                " will not load any checkpoints and will start from random"
            )
            return 0
        release = False
    else:
        model = utils.unwrap_model(model)

        if load_dir.endswith(".pt"):
            checkpoint_name = load_dir
            release = True
        else:
            # Read the tracker file and set the iteration.
            tracker_filename = get_checkpoint_tracker_filename(load_dir)

            # If no tracker file, return iteration zero.
            if not os.path.isfile(tracker_filename):
                print_rank_0(
                    "WARNING: could not find the metadata file {} ".format(
                        tracker_filename
                    )
                )
                iteration = 0
                release = True
                checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
                if not os.path.isfile(checkpoint_name):
                    print_rank_0(
                        " will not load any checkpoints and will start from random"
                    )
                    return 0
            else:
                # Otherwise, read the tracker file and either set the iteration or
                # mark it as a release checkpoint.
                iteration = 0
                release = False
                with open(tracker_filename, "r") as f:
                    metastring = f.read().strip()
                    try:
                        iteration = int(metastring)
                    except ValueError:
                        release = metastring == "release"
                        if not release:
                            print_rank_0(
                                "ERROR: Invalid metadata file {}. Exiting".format(
                                    tracker_filename
                                )
                            )
                            sys.exit()
                assert iteration > 0 or release, "error parsing metadata file {}".format(
                    tracker_filename
                )

            # Checkpoint.
            checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
            print_rank_0(
                f" loading checkpoint from {args.load} at iteration {iteration}"
            )

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_0(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        except BaseException as e:
            print_rank_0("could not load the checkpoint")
            print_rank_0(e)
            sys.exit()

    # Set checkpoint version.
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
            if "tokens" in state_dict:
                args.consumed_train_tokens = state_dict["tokens"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_0(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(checkpoint_name)
                )
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_0("could not find arguments in the checkpoint ...")

    # Model.
    if not args.deepspeed:
        if release:
            if len(model) == 1:
                model[0].load_state_dict(state_dict["module"], strict=strict)
            else:
                for i in range(len(model)):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    model[i].load_state_dict(state_dict["model%d" % i], strict=strict)
        else:
            if len(model) == 1:
                model[0].load_state_dict(state_dict["module"], strict=strict)
            else:
                for i in range(len(model)):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    model[i].load_state_dict(state_dict["model%d" % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f" checkpoint version {checkpoint_version}")
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not args.deepspeed:
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(state_dict["optimizer"])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
            except KeyError:
                print_rank_0(
                    "Unable to load optimizer from checkpoint {}. "
                    "Specify --no-load-optim or --finetune to prevent "
                    "attempting to load the optimizer state, "
                    "exiting ...".format(checkpoint_name)
                )
                sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            # Check for empty states array
            if not state_dict["rng_tracker_states"]:
                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
" "Specify --no-load-rng or --finetune to prevent " "attempting to load the rng state, " "exiting ...".format(checkpoint_name) ) sys.exit() # Some utilities want to load a checkpoint without distributed being initialized # if torch.distributed.is_initialized(): # torch.distributed.barrier() print_rank_0( f" successfully loaded checkpoint from {args.load} " f"at iteration {iteration}" ) return iteration def load_biencoder_checkpoint( model, only_query_model=False, only_context_model=False, custom_load_path=None ): """ selectively load retrieval models for indexing/retrieving from saved checkpoints """ args = get_args() model = utils.unwrap_model(model) load_path = custom_load_path if custom_load_path is not None else args.load tracker_filename = get_checkpoint_tracker_filename(load_path) with open(tracker_filename, "r") as f: iteration = int(f.read().strip()) checkpoint_name = get_checkpoint_name(load_path, iteration, False) if mpu.get_data_parallel_rank() == 0: print( "global rank {} is loading checkpoint {}".format( torch.distributed.get_rank(), checkpoint_name ) ) state_dict = torch.load(checkpoint_name, map_location="cpu") ret_state_dict = state_dict["model"] if only_query_model: ret_state_dict.pop("context_model") if only_context_model: ret_state_dict.pop("query_model") assert len(model) == 1 model[0].load_state_dict(ret_state_dict) torch.distributed.barrier() if mpu.get_data_parallel_rank() == 0: print(" successfully loaded {}".format(checkpoint_name)) return model