# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np
from glob import glob

import torch

from megatron import get_args, mpu, print_rank_0, update_num_microbatches, utils

_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    global _CHECKPOINT_VERSION
    return _CHECKPOINT_VERSION


def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the ones retrieved from the checkpoint."""
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        error_message = (
            "{} value from checkpoint ({}) is not equal to the "
            "input argument value ({}).".format(arg_name, checkpoint_value, args_value)
        )
        assert checkpoint_value == args_value, error_message

    _compare("num_layers")
    _compare("hidden_size")
    _compare("num_attention_heads")
    _compare("max_position_embeddings")
    if args.vocab_file:
        _compare("make_vocab_size_divisible_by")
        _compare("padded_vocab_size")
        _compare("tokenizer_type")
    if get_checkpoint_version() < 3.0:
        _compare("tensor_model_parallel_size", old_arg_name="model_parallel_size")
    if get_checkpoint_version() >= 3.0:
        _compare("tensor_model_parallel_size")
        _compare("pipeline_model_parallel_size")


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        # exist_ok avoids a race when several ranks create the same directory.
        os.makedirs(dirname, exist_ok=True)


def get_checkpoint_name(checkpoints_path, iteration, release=False):
    """A unified checkpoint name."""
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}".format(mpu.get_tensor_model_parallel_rank()),
            "model_optim_rng.pt",
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}".format(
            mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank()
        ),
        "model_optim_rng.pt",
    )


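# Illustrative layout produced by get_checkpoint_name(); the rank and iteration
# numbers below are examples only, the actual values depend on the parallel
# configuration:
#   <save>/iter_0001000/mp_rank_00/model_optim_rng.pt        # pipeline world size == 1
#   <save>/iter_0001000/mp_rank_00_000/model_optim_rng.pt    # tensor rank 0, pipeline rank 0
#   <save>/release/mp_rank_00/model_optim_rng.pt             # release checkpoint

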
def get_checkpoint_tracker_filename(checkpoints_path):
    """Tracker file records the latest checkpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")


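# Sketch of the tracker file, which sits in the checkpoint root next to the
# iter_* directories and holds either the latest iteration number or the
# string "release":
#   $ cat <save>/latest_checkpointed_iteration.txt
#   1000

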
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    if not args.deepspeed:
        model = utils.unwrap_model(model)

    print_rank_0(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    if (
        not torch.distributed.is_initialized()
        or mpu.get_data_parallel_rank() == 0
        or args.deepspeed
    ):

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict["args"] = args
        state_dict["checkpoint_version"] = 3.0
        state_dict["iteration"] = iteration
        state_dict["tokens"] = args.consumed_train_tokens

        # DeepSpeed saves the model/optimizer/scheduler
        if not args.deepspeed:
            if len(model) == 1:
                state_dict["model"] = model[0].state_dict_for_save_checkpoint()
            else:
                for i in range(len(model)):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    state_dict["model%d" % i] = model[
                        i
                    ].state_dict_for_save_checkpoint()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    state_dict["optimizer"] = optimizer.state_dict()
                if lr_scheduler is not None:
                    state_dict["lr_scheduler"] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict["random_rng_state"] = random.getstate()
            state_dict["np_rng_state"] = np.random.get_state()
            state_dict["torch_rng_state"] = torch.get_rng_state()
            state_dict["cuda_rng_state"] = torch.cuda.get_rng_state()
            state_dict["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        if not args.deepspeed:
            ensure_directory_exists(checkpoint_name)
            torch.save(state_dict, checkpoint_name)

    if args.deepspeed:
        # The megatron model exposes state_dict_for_save_checkpoint instead of the
        # standard state_dict; DeepSpeed calls state_dict when saving the module,
        # so it needs to point to the right function.
        if args.no_pipeline_parallel:
            original_state_dict = model[0].module.state_dict
            model[0].module.state_dict = model[0].module.state_dict_for_save_checkpoint

        # Saving is a collective communication
        checkpoint_name = get_checkpoint_name(args.save, iteration)

        # Trim off the filename plus the mp_rank_* and iter_* directories so
        # DeepSpeed saves under the checkpoint root.
        for _ in range(3):
            checkpoint_name = os.path.dirname(checkpoint_name)
        model[0].save_checkpoint(checkpoint_name, client_state=state_dict)

        if args.no_pipeline_parallel:
            model[0].module.state_dict = original_state_dict

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(
        " successfully saved checkpoint at iteration {:7d} to {}".format(
            iteration, args.save
        )
    )

    # And update the latest iteration
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()


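# Minimal usage sketch for save_checkpoint(); the surrounding loop and the
# helper names (setup_model_and_optimizer, train_step) are hypothetical and only
# illustrate where the call typically lives in a training script:
#
#     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
#     for iteration in range(1, args.train_iters + 1):
#         train_step(...)
#         if args.save and iteration % args.save_interval == 0:
#             save_checkpoint(iteration, model, optimizer, lr_scheduler)

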
def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, "module"):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = (
        attention_module.num_attention_heads_per_partition
    )
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h]"""

        intermediate_shape = (
            num_splits,
            num_attention_heads_per_partition,
            hidden_size_per_attention_head,
        ) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h]"""

        intermediate_shape = (
            num_attention_heads_per_partition,
            hidden_size_per_attention_head,
            num_splits,
        ) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t


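# Worked example of the reordering above (illustrative sizes): with num_splits=3
# (query/key/value), np=2 heads per partition and hn=4 dims per head, a
# checkpoint-version-0 tensor of shape [3 * 2 * 4, h] = [24, h] is viewed as
# [3, 2, 4, h], transposed to [2, 3, 4, h], and viewed back to [24, h], so the
# head dimension ends up ahead of the q/k/v split as newer checkpoints expect.

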
def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0
    """
    if checkpoint_version < 2.0:
        if isinstance(model, list):
            assert len(model) == 1
            model = model[0]
        for name, param in model.named_parameters():
            if name.endswith((".query_key_value.weight", ".query_key_value.bias")):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith((".key_value.weight", ".key_value.bias")):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
        print_rank_0(
            " successfully fixed query-key-values ordering for"
            " checkpoint version {}".format(checkpoint_version)
        )


def load_deepspeed_state(model):
    model = utils.unwrap_model(model)
    args = get_args()
    load_dir = args.load
    if os.path.isdir(load_dir):
        model_state_paths = glob(os.path.join(load_dir, "*model_states.pt"))
        assert len(model_state_paths) == 1, (
            "only support loading deepspeed checkpoint of model parallel size 1"
            ", but got {}".format(model_state_paths)
        )
        model_state_path = model_state_paths[0]
    else:
        model_state_path = load_dir
    state_dict = torch.load(model_state_path, map_location="cpu")
    state_dict = state_dict["module"]

    model[0].load_state_dict(state_dict, strict=True)


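# Note: this helper assumes a model-parallel size of 1, i.e. exactly one
# "*model_states.pt" file directly under args.load. A typical DeepSpeed layout
# satisfying that assumption might look like this (illustrative, not guaranteed):
#   <load>/mp_rank_00_model_states.pt

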
def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load", strict=True):
|
|
"""Load a model checkpoint and return the iteration.
|
|
strict (bool): whether to strictly enforce that the keys in
|
|
:attr:`state_dict` of the checkpoint match the names of
|
|
parameters and buffers in model.
|
|
"""
|
|
args = get_args()
|
|
load_dir = getattr(args, load_arg)
|
|
|
|
if args.deepspeed:
|
|
loaded_dir, state_dict = model[0].load_checkpoint(load_dir)
|
|
if loaded_dir is None:
|
|
print_rank_0(
|
|
"WARNING: could not find the metadata file {} ".format(load_dir)
|
|
)
|
|
print_rank_0(
|
|
" will not load any checkpoints and will start from " "random"
|
|
)
|
|
return 0
|
|
release = False
|
|
else:
|
|
model = utils.unwrap_model(model)
|
|
|
|
# Read the tracker file and set the iteration.
|
|
tracker_filename = get_checkpoint_tracker_filename(load_dir)
|
|
|
|
# If no tracker file, return iretation zero.
|
|
if not os.path.isfile(tracker_filename):
|
|
print_rank_0(
|
|
"WARNING: could not find the metadata file {} ".format(tracker_filename)
|
|
)
|
|
print_rank_0(
|
|
" will not load any checkpoints and will start from " "random"
|
|
)
|
|
return 0
|
|
|
|
# Otherwise, read the tracker file and either set the iteration or
|
|
# mark it as a release checkpoint.
|
|
iteration = 0
|
|
release = False
|
|
with open(tracker_filename, "r") as f:
|
|
metastring = f.read().strip()
|
|
try:
|
|
iteration = int(metastring)
|
|
except ValueError:
|
|
release = metastring == "release"
|
|
if not release:
|
|
print_rank_0(
|
|
"ERROR: Invalid metadata file {}. Exiting".format(
|
|
tracker_filename
|
|
)
|
|
)
|
|
sys.exit()
|
|
|
|
assert iteration > 0 or release, "error parsing metadata file {}".format(
|
|
tracker_filename
|
|
)
|
|
|
|
# Checkpoint.
|
|
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
|
|
print_rank_0(f" loading checkpoint from {args.load} at iteration {iteration}")
|
|
|
|
# Load the checkpoint.
|
|
try:
|
|
state_dict = torch.load(checkpoint_name, map_location="cpu")
|
|
except ModuleNotFoundError:
|
|
from megatron.fp16_deprecated import loss_scaler
|
|
|
|
# For backward compatibility.
|
|
print_rank_0(" > deserializing using the old code structure ...")
|
|
sys.modules["fp16.loss_scaler"] = sys.modules[
|
|
"megatron.fp16_deprecated.loss_scaler"
|
|
]
|
|
sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
|
|
"megatron.fp16_deprecated.loss_scaler"
|
|
]
|
|
state_dict = torch.load(checkpoint_name, map_location="cpu")
|
|
sys.modules.pop("fp16.loss_scaler", None)
|
|
sys.modules.pop("megatron.fp16.loss_scaler", None)
|
|
except BaseException as e:
|
|
print_rank_0("could not load the checkpoint")
|
|
print_rank_0(e)
|
|
sys.exit()
|
|
|
|
# set checkpoint version
|
|
set_checkpoint_version(state_dict.get("checkpoint_version", 0))
|
|
|
|
# Set iteration.
|
|
if args.finetune or release:
|
|
iteration = 0
|
|
else:
|
|
try:
|
|
iteration = state_dict["iteration"]
|
|
if "tokens" in state_dict:
|
|
args.consumed_train_tokens = state_dict["tokens"]
|
|
except KeyError:
|
|
try: # Backward compatible with older checkpoints
|
|
iteration = state_dict["total_iters"]
|
|
except KeyError:
|
|
print_rank_0(
|
|
"A metadata file exists but unable to load "
|
|
"iteration from checkpoint {}, exiting".format(checkpoint_name)
|
|
)
|
|
sys.exit()
|
|
|
|
# Check arguments.
|
|
assert args.consumed_train_samples == 0
|
|
assert args.consumed_valid_samples == 0
|
|
if "args" in state_dict:
|
|
checkpoint_args = state_dict["args"]
|
|
check_checkpoint_args(checkpoint_args)
|
|
args.consumed_train_samples = getattr(
|
|
checkpoint_args, "consumed_train_samples", 0
|
|
)
|
|
update_num_microbatches(consumed_samples=args.consumed_train_samples)
|
|
args.consumed_valid_samples = getattr(
|
|
checkpoint_args, "consumed_valid_samples", 0
|
|
)
|
|
else:
|
|
print_rank_0("could not find arguments in the checkpoint ...")
|
|
|
|
# Model.
|
|
if not args.deepspeed:
|
|
if len(model) == 1:
|
|
model[0].load_state_dict(state_dict["model"], strict=strict)
|
|
else:
|
|
for i in range(len(model)):
|
|
mpu.set_virtual_pipeline_model_parallel_rank(i)
|
|
model[i].load_state_dict(state_dict["model%d" % i], strict=strict)
|
|
|
|
# Fix up query/key/value matrix ordering if needed
|
|
checkpoint_version = get_checkpoint_version()
|
|
print_rank_0(f" checkpoint version {checkpoint_version}")
|
|
fix_query_key_value_ordering(model, checkpoint_version)
|
|
|
|
# Optimizer.
|
|
if not args.deepspeed:
|
|
if not release and not args.finetune and not args.no_load_optim:
|
|
try:
|
|
if optimizer is not None:
|
|
optimizer.load_state_dict(state_dict["optimizer"])
|
|
if lr_scheduler is not None:
|
|
lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
|
|
except KeyError:
|
|
print_rank_0(
|
|
"Unable to load optimizer from checkpoint {}. "
|
|
"Specify --no-load-optim or --finetune to prevent "
|
|
"attempting to load the optimizer state, "
|
|
"exiting ...".format(checkpoint_name)
|
|
)
|
|
sys.exit()
|
|
|
|
# rng states.
|
|
if not release and not args.finetune and not args.no_load_rng:
|
|
try:
|
|
random.setstate(state_dict["random_rng_state"])
|
|
np.random.set_state(state_dict["np_rng_state"])
|
|
torch.set_rng_state(state_dict["torch_rng_state"])
|
|
torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
|
|
# Check for empty states array
|
|
if not state_dict["rng_tracker_states"]:
|
|
raise KeyError
|
|
mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
|
|
except KeyError:
|
|
print_rank_0(
|
|
"Unable to load rng state from checkpoint {}. "
|
|
"Specify --no-load-rng or --finetune to prevent "
|
|
"attempting to load the rng state, "
|
|
"exiting ...".format(checkpoint_name)
|
|
)
|
|
sys.exit()
|
|
|
|
# Some utilities want to load a checkpoint without distributed being initialized
|
|
# if torch.distributed.is_initialized():
|
|
# torch.distributed.barrier()
|
|
|
|
print_rank_0(
|
|
f" successfully loaded checkpoint from {args.load} "
|
|
f"at iteration {iteration}"
|
|
)
|
|
|
|
return iteration
|
|
|
|
|
|
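# Minimal resume sketch for load_checkpoint(); assumes the usual megatron startup
# has already populated the global args (in particular args.load pointing at a
# previous --save directory):
#
#     iteration = load_checkpoint(model, optimizer, lr_scheduler)
#     if iteration == 0:
#         print_rank_0("no checkpoint loaded, starting from scratch")

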
def load_biencoder_checkpoint(
    model, only_query_model=False, only_context_model=False, custom_load_path=None
):
    """
    Selectively load retrieval models for indexing/retrieving
    from saved checkpoints.
    """

    args = get_args()

    model = utils.unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, "r") as f:
        iteration = int(f.read().strip())

    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print(
            "global rank {} is loading checkpoint {}".format(
                torch.distributed.get_rank(), checkpoint_name
            )
        )

    state_dict = torch.load(checkpoint_name, map_location="cpu")
    ret_state_dict = state_dict["model"]

    if only_query_model:
        ret_state_dict.pop("context_model")
    if only_context_model:
        ret_state_dict.pop("query_model")

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(" successfully loaded {}".format(checkpoint_name))

    return model