Fix merge parallel ckpt

pull/69/head
Stanislas0 2 years ago
parent 0a3ec0e53f
commit 5c9dab3701

@ -25,12 +25,6 @@ def get_change_ckpt_args(parser):
required=True,
help='path to save ".pt" checkpoint.',
)
group.add_argument(
'--save-name',
type=str,
default="mp_rank_00_model_states.pt",
help='name of saved ".pt" checkpoint.',
)
group.add_argument(
'--source-tensor-model-parallel-size',
type=int,
@ -73,8 +67,6 @@ def main():
exit(0)
print(f"Merging {len(state_dict_list)} partitions into a single ckpt...")
output_state_dict = {}
print("Merging Embedding layers...")
vocab_parallel_size = args.make_vocab_size_divisible_by // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
@ -85,43 +77,54 @@ def main():
print("Merging QueryEmbedding layers...")
query_parallel_size = args.max_position_embeddings // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings'].pop('weight', None)
print("Merging Transformer layers...")
for layer_name in sd['module']['language_model']['transformer'].keys():
if "layernorm" in layer_name:
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'][layer_name]
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
elif "attention" in layer_name and "weight" in layer_name:
if "dense" in layer_name:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
else:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
elif "weight" in layer_name and "dense" in layer_name:
if "h_to_4h" in layer_name:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
else:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
elif "bias" in layer_name:
if "dense" not in layer_name or "mlp" in layer_name:
if "mlp" in layer_name:
if "4h_to_h" in layer_name:
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'][layer_name]
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
else:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
elif "attention" in layer_name:
if "dense" in layer_name:
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
else:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'][layer_name]
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
else:
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'][layer_name]
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
if args.save_ckpt_path.endswith(".pt"):
save_ckpt_path = args.save_ckpt_path
else:
os.makedirs(args.save_ckpt_path, exist_ok=True)
save_ckpt_path = os.path.join(args.save_ckpt_path, args.save_name)
save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
torch.save(sd, save_ckpt_path)
print(f"Converted checkpoint saved in {save_ckpt_path}.")

Loading…
Cancel
Save