diff --git a/codegeex/megatron/merge_ckpt_parallel.py b/codegeex/megatron/merge_ckpt_parallel.py
index 2aea62d..f032bef 100644
--- a/codegeex/megatron/merge_ckpt_parallel.py
+++ b/codegeex/megatron/merge_ckpt_parallel.py
@@ -25,12 +25,6 @@ def get_change_ckpt_args(parser):
         required=True,
         help='path to save ".pt" checkpoint.',
     )
-    group.add_argument(
-        '--save-name',
-        type=str,
-        default="mp_rank_00_model_states.pt",
-        help='name of saved ".pt" checkpoint.',
-    )
     group.add_argument(
         '--source-tensor-model-parallel-size',
         type=int,
@@ -73,8 +67,6 @@ def main():
         exit(0)

     print(f"Merging {len(state_dict_list)} partitions into a single ckpt...")
-    output_state_dict = {}
-
     print("Merging Embedding layers...")
     vocab_parallel_size = args.make_vocab_size_divisible_by // args.source_tensor_model_parallel_size
     for i in range(args.source_tensor_model_parallel_size):
@@ -85,43 +77,54 @@ def main():
     print("Merging QueryEmbedding layers...")
     query_parallel_size = args.max_position_embeddings // args.source_tensor_model_parallel_size
     for i in range(args.source_tensor_model_parallel_size):
-        sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
+        sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings'].pop('weight', None)

     print("Merging Transformer layers...")
     for layer_name in sd['module']['language_model']['transformer'].keys():
         if "layernorm" in layer_name:
-            sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'][layer_name]
+            sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
         elif "attention" in layer_name and "weight" in layer_name:
             if "dense" in layer_name:
                 hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
                 for i in range(args.source_tensor_model_parallel_size):
-                    sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
+                    sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
             else:
                 hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
                 for i in range(args.source_tensor_model_parallel_size):
-                    sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
+                    sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
         elif "weight" in layer_name and "dense" in layer_name:
             if "h_to_4h" in layer_name:
                 hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
                 for i in range(args.source_tensor_model_parallel_size):
-                    sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
+                    sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
             else:
                 hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
                 for i in range(args.source_tensor_model_parallel_size):
-                    sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
+                    sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
         elif "bias" in layer_name:
-            if "dense" not in layer_name or "mlp" in layer_name:
+            if "mlp" in layer_name:
                 if "4h_to_h" in layer_name:
-                    sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'][layer_name]
+                    sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
                 else:
                     hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
                     for i in range(args.source_tensor_model_parallel_size):
-                        sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
+                        sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
+            elif "attention" in layer_name:
+                if "dense" in layer_name:
+                    sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
+                else:
+                    hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
+                    for i in range(args.source_tensor_model_parallel_size):
+                        sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
         else:
-            sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'][layer_name]
-
-    os.makedirs(args.save_ckpt_path, exist_ok=True)
-    save_ckpt_path = os.path.join(args.save_ckpt_path, args.save_name)
+            sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
+
+    if args.save_ckpt_path.endswith(".pt"):
+        save_ckpt_path = args.save_ckpt_path
+    else:
+        os.makedirs(args.save_ckpt_path, exist_ok=True)
+        save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
+
     torch.save(sd, save_ckpt_path)
     print(f"Converted checkpoint saved in {save_ckpt_path}.")
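
For reference, below is a minimal standalone sketch (not part of the patch) of the slice-assignment pattern the merge relies on: weights that were split along dim 0 across tensor-parallel ranks are stitched back by writing each shard into a row slice of the full tensor, while weights split along dim 1 go into a column slice. The `tp_size`, shard shapes, and the `"w"` key below are made up for illustration only.

```python
# Illustrative sketch only, assuming a toy 2-way tensor-parallel split.
import torch

tp_size = 2                       # stands in for --source-tensor-model-parallel-size
full = torch.empty(8, 4)          # full weight, split along dim 0 across ranks
shards = [{"w": torch.randn(4, 4)} for _ in range(tp_size)]  # per-rank state dicts

rows_per_rank = full.shape[0] // tp_size
for i in range(tp_size):
    # pop() both returns the shard and removes it from the rank-i dict,
    # so each partition's tensor can be garbage-collected once copied.
    full[i * rows_per_rank : (i + 1) * rows_per_rank, :] = shards[i].pop("w")

# Weights split along dim 1 (e.g. attention output / 4h_to_h dense in the patch)
# are merged the same way, but slicing columns instead:
#     full[:, i * cols_per_rank : (i + 1) * cols_per_rank] = shard
```

Using `.pop(...)` rather than plain indexing mirrors the patch's change throughout the transformer loop; tensors that are replicated across ranks (layernorms, the row-parallel dense biases) are simply taken from rank 0, as the diff shows.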