diff --git a/codegeex/megatron/convert_ckpt_parallel.py b/codegeex/megatron/convert_ckpt_parallel.py
index 7b625c1..f1b0e50 100644
--- a/codegeex/megatron/convert_ckpt_parallel.py
+++ b/codegeex/megatron/convert_ckpt_parallel.py
@@ -72,7 +72,7 @@ def main():
         word_emb_dict = get_element_from_dict_by_path(
             output_state_dict[i], "module.language_model.embedding.word_embeddings"
         )
-        word_emb_dict["weight"] = out_word_embeddings[i]
+        word_emb_dict["weight"] = out_word_embeddings[i].clone()
 
     print("Converting QueryEmbedding layers...")
     query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
@@ -82,7 +82,7 @@ def main():
         query_emb_dict = get_element_from_dict_by_path(
             output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
         )
-        query_emb_dict["weight"] = out_query_embeddings[i]
+        query_emb_dict["weight"] = out_query_embeddings[i].clone()
 
     print("Converting Transformer layers...")
     for layer_name in state_dict['module']['language_model']['transformer'].keys():
@@ -109,10 +109,10 @@ def main():
         for i in range(args.target_tensor_model_parallel_size):
             params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
             if type(params) is tuple:
-                params_dict[layer_name] = params[i]
+                params_dict[layer_name] = params[i].clone()
             else:
                 params_dict[layer_name] = params
-    
+
     os.makedirs(args.save_ckpt_path, exist_ok=True)
     for rank in range(args.target_tensor_model_parallel_size):
         save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
diff --git a/codegeex/megatron/merge_ckpt_parallel.py b/codegeex/megatron/merge_ckpt_parallel.py
index f032bef..c85f40f 100644
--- a/codegeex/megatron/merge_ckpt_parallel.py
+++ b/codegeex/megatron/merge_ckpt_parallel.py
@@ -25,6 +25,11 @@ def get_change_ckpt_args(parser):
         required=True,
         help='path to save ".pt" checkpoint.',
     )
+    group.add_argument(
+        '--save-name',
+        type=str,
+        help='name of checkpoint.',
+    )
     group.add_argument(
         '--source-tensor-model-parallel-size',
         type=int,
@@ -123,7 +128,10 @@ def main():
         save_ckpt_path = args.save_ckpt_path
     else:
         os.makedirs(args.save_ckpt_path, exist_ok=True)
-        save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
+        if args.save_name:
+            save_ckpt_path = os.path.join(args.save_ckpt_path, args.save_name)
+        else:
+            save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
 
     torch.save(sd, save_ckpt_path)
     print(f"Converted checkpoint saved in {save_ckpt_path}.")
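
Note on the .clone() additions in convert_ckpt_parallel.py: torch.chunk returns views that share the merged tensor's storage, and torch.save serializes a view's entire underlying storage, so without the clone each tensor-parallel rank's checkpoint would carry the full unsplit weight. A minimal sketch of the effect, not part of the patch, with an illustrative tensor size:

    import io
    import torch

    full = torch.zeros(1024, 1024)           # merged table, ~4 MB of float32
    shard = torch.chunk(full, 4, dim=0)[0]   # a view sharing full's storage

    def saved_size(t):
        # Serialize to an in-memory buffer and report how many bytes torch.save writes.
        buf = io.BytesIO()
        torch.save(t, buf)
        return buf.getbuffer().nbytes

    print(saved_size(shard))          # ~4 MB: the view drags the whole storage along
    print(saved_size(shard.clone()))  # ~1 MB: only the shard's own data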