Merge pull request #71 from THUDM/develop

Develop
Qinkai 2 years ago committed by GitHub
commit 54052c3cc2

@@ -72,7 +72,7 @@ def main():
         word_emb_dict = get_element_from_dict_by_path(
             output_state_dict[i], "module.language_model.embedding.word_embeddings"
         )
-        word_emb_dict["weight"] = out_word_embeddings[i]
+        word_emb_dict["weight"] = out_word_embeddings[i].clone()
     print("Converting QueryEmbedding layers...")
     query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
@@ -82,7 +82,7 @@ def main():
         query_emb_dict = get_element_from_dict_by_path(
             output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
         )
-        query_emb_dict["weight"] = out_query_embeddings[i]
+        query_emb_dict["weight"] = out_query_embeddings[i].clone()
     print("Converting Transformer layers...")
     for layer_name in state_dict['module']['language_model']['transformer'].keys():
@@ -109,10 +109,10 @@ def main():
         for i in range(args.target_tensor_model_parallel_size):
             params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
             if type(params) is tuple:
-                params_dict[layer_name] = params[i]
+                params_dict[layer_name] = params[i].clone()
             else:
                 params_dict[layer_name] = params
     os.makedirs(args.save_ckpt_path, exist_ok=True)
     for rank in range(args.target_tensor_model_parallel_size):
         save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
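Note: the diff itself does not state why `.clone()` was added. A plausible reason (a PyTorch behavior, not something claimed in this PR) is that `torch.save` serializes a tensor's entire underlying storage, so writing an un-cloned view of the full weight makes every per-rank shard as large as the original tensor. A minimal sketch of that effect, with illustrative shapes only:

```python
# Sketch only: shows that torch.save keeps the full storage behind a view,
# while .clone() stores just the shard. Shapes are illustrative, not from the PR.
import os
import tempfile

import torch

full = torch.zeros(4096, 1024)          # stand-in for a full embedding matrix
shards = torch.chunk(full, 2, dim=0)    # per-rank shards; these are views of `full`

with tempfile.TemporaryDirectory() as tmp:
    view_path = os.path.join(tmp, "view.pt")
    clone_path = os.path.join(tmp, "clone.pt")
    torch.save({"weight": shards[0]}, view_path)            # serializes all of `full`
    torch.save({"weight": shards[0].clone()}, clone_path)   # serializes only the shard
    print(os.path.getsize(view_path), ">", os.path.getsize(clone_path))
```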

@@ -25,6 +25,11 @@ def get_change_ckpt_args(parser):
         required=True,
         help='path to save ".pt" checkpoint.',
     )
+    group.add_argument(
+        '--save-name',
+        type=str,
+        help='name of checkpoint.',
+    )
     group.add_argument(
         '--source-tensor-model-parallel-size',
         type=int,
@@ -123,7 +128,10 @@ def main():
         save_ckpt_path = args.save_ckpt_path
     else:
         os.makedirs(args.save_ckpt_path, exist_ok=True)
-        save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
+        if args.save_name:
+            save_ckpt_path = os.path.join(args.save_ckpt_path, args.save_name)
+        else:
+            save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
     torch.save(sd, save_ckpt_path)
     print(f"Converted checkpoint saved in {save_ckpt_path}.")
