"""Merge model parallel partitions into a single checkpoint."""

import os
import random
import sys

import torch

from codegeex.megatron import get_args
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.checkpointing import ensure_directory_exists
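
# Example invocation -- a sketch, not taken from the repository; the script name and
# paths below are placeholders. Besides the flags added by get_change_ckpt_args(),
# initialize_megatron() still parses the standard Megatron model arguments
# (e.g. --num-layers, --hidden-size, --num-attention-heads, --max-position-embeddings),
# which must match the model that produced the partitions:
#
#   python merge_mp_partitions.py \
#       --load-ckpt-path <dir containing mp_rank_XX_model_states.pt> \
#       --save-ckpt-path <output dir or .pt file> \
#       --source-tensor-model-parallel-size 2 \
#       <standard Megatron model arguments>
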
def get_change_ckpt_args(parser):
    """Provide extra arguments required for merging."""
    group = parser.add_argument_group(title='Mindspore to megatron')
    group.add_argument(
        '--load-ckpt-path',
        type=str,
        required=True,
        help='dir to load model parallel partitions.',
    )
    group.add_argument(
        '--save-ckpt-path',
        type=str,
        required=True,
        help='path to save ".pt" checkpoint.',
    )
    group.add_argument(
        '--save-name',
        type=str,
        help='name of checkpoint.',
    )
    group.add_argument(
        '--source-tensor-model-parallel-size',
        type=int,
        default=2,
        help='original tensor model parallel size',
    )

    return parser


def main():
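    # initialize_megatron() brings up torch.distributed even for this single-process
    # merge, so give it a rendezvous address and a (hopefully free) random port.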
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))

    initialize_megatron(
        extra_args_provider=get_change_ckpt_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )
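
    # Build a full-size (unpartitioned) CodeGeeXModel on this single process to
    # serve as the template that the merged weights are copied into.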
    args = get_args()
    model = CodeGeeXModel()
    print(model.state_dict)

    # Target state dict for the merged checkpoint.
    sd = {}
    sd['module'] = model.state_dict_for_save_checkpoint()
    ensure_directory_exists(args.save_ckpt_path)
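
    # Load one partition per source tensor-parallel rank; partitions are expected
    # to be named mp_rank_00_model_states.pt, mp_rank_01_model_states.pt, ...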
print(f"Load ckpt from {args.load_ckpt_path}...")
|
|
state_dict_list = []
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
try:
|
|
state_dict_list.append(torch.load(os.path.join(args.load_ckpt_path, f"mp_rank_{i:02d}_model_states.pt"), map_location="cpu"))
|
|
except Exception as e:
|
|
print(e)
|
|
exit(0)
|
|
|
|
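
    # Word embeddings are split along the vocabulary dimension, so each rank's shard
    # is copied back into its row range of the full embedding matrix.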
print(f"Merging {len(state_dict_list)} partitions into a single ckpt...")
|
|
print("Merging Embedding layers...")
|
|
vocab_parallel_size = args.make_vocab_size_divisible_by // args.source_tensor_model_parallel_size
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
sd['module']['language_model']['embedding']['word_embeddings']['weight'][i * vocab_parallel_size : (i + 1) * vocab_parallel_size, :] = state_dict_list[i]['module']['language_model']['embedding']['word_embeddings']['weight']
|
|
|
|
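
    # Position embeddings are not tensor-parallel, so rank 0's copy is used as-is.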
    sd['module']['language_model']['embedding']['position_embeddings']['weight'] = \
        state_dict_list[0]['module']['language_model']['embedding']['position_embeddings']['weight']
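
    # The top query embeddings (CodeGeeX's top query layer) are split along the
    # position dimension and reassembled in the same row-range fashion.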
print("Merging QueryEmbedding layers...")
|
|
query_parallel_size = args.max_position_embeddings // args.source_tensor_model_parallel_size
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings'].pop('weight', None)
|
|
|
|
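
    # Transformer weights: column-parallel tensors (query_key_value, dense_h_to_4h)
    # were split along dim 0 and are stacked back row-wise; row-parallel tensors
    # (attention dense, dense_4h_to_h) were split along dim 1 and are stacked back
    # column-wise. Layernorms and replicated biases are taken from rank 0.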
print("Merging Transformer layers...")
|
|
for layer_name in sd['module']['language_model']['transformer'].keys():
|
|
if "layernorm" in layer_name:
|
|
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
elif "attention" in layer_name and "weight" in layer_name:
|
|
if "dense" in layer_name:
|
|
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
else:
|
|
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
elif "weight" in layer_name and "dense" in layer_name:
|
|
if "h_to_4h" in layer_name:
|
|
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
else:
|
|
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
elif "bias" in layer_name:
|
|
if "mlp" in layer_name:
|
|
if "4h_to_h" in layer_name:
|
|
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
else:
|
|
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
elif "attention" in layer_name:
|
|
if "dense" in layer_name:
|
|
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
else:
|
|
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
|
|
for i in range(args.source_tensor_model_parallel_size):
|
|
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
else:
|
|
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
|
|
|
|
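
    # Write the merged state dict either to the exact ".pt" path given, or into the
    # given directory under --save-name (default: mp_rank_00_model_states.pt).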
    if args.save_ckpt_path.endswith(".pt"):
        save_ckpt_path = args.save_ckpt_path
    else:
        os.makedirs(args.save_ckpt_path, exist_ok=True)
        if args.save_name:
            save_ckpt_path = os.path.join(args.save_ckpt_path, args.save_name)
        else:
            save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")

    torch.save(sd, save_ckpt_path)
    print(f"Converted checkpoint saved in {save_ckpt_path}.")


if __name__ == '__main__':
    main()