|
|
@ -72,7 +72,7 @@ def main():
|
|
|
|
word_emb_dict = get_element_from_dict_by_path(
|
|
|
|
word_emb_dict = get_element_from_dict_by_path(
|
|
|
|
output_state_dict[i], "module.language_model.embedding.word_embeddings"
|
|
|
|
output_state_dict[i], "module.language_model.embedding.word_embeddings"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
word_emb_dict["weight"] = out_word_embeddings[i]
|
|
|
|
word_emb_dict["weight"] = out_word_embeddings[i].clone()
|
|
|
|
|
|
|
|
|
|
|
|
print("Converting QueryEmbedding layers...")
|
|
|
|
print("Converting QueryEmbedding layers...")
|
|
|
|
query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
|
|
|
|
query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
|
|
|
@ -82,7 +82,7 @@ def main():
|
|
|
|
query_emb_dict = get_element_from_dict_by_path(
|
|
|
|
query_emb_dict = get_element_from_dict_by_path(
|
|
|
|
output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
|
|
|
|
output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
query_emb_dict["weight"] = out_query_embeddings[i]
|
|
|
|
query_emb_dict["weight"] = out_query_embeddings[i].clone()
|
|
|
|
|
|
|
|
|
|
|
|
print("Converting Transformer layers...")
|
|
|
|
print("Converting Transformer layers...")
|
|
|
|
for layer_name in state_dict['module']['language_model']['transformer'].keys():
|
|
|
|
for layer_name in state_dict['module']['language_model']['transformer'].keys():
|
|
|
@ -109,10 +109,10 @@ def main():
|
|
|
|
for i in range(args.target_tensor_model_parallel_size):
|
|
|
|
for i in range(args.target_tensor_model_parallel_size):
|
|
|
|
params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
|
|
|
|
params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
|
|
|
|
if type(params) is tuple:
|
|
|
|
if type(params) is tuple:
|
|
|
|
params_dict[layer_name] = params[i]
|
|
|
|
params_dict[layer_name] = params[i].clone()
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
params_dict[layer_name] = params
|
|
|
|
params_dict[layer_name] = params
|
|
|
|
|
|
|
|
|
|
|
|
os.makedirs(args.save_ckpt_path, exist_ok=True)
|
|
|
|
os.makedirs(args.save_ckpt_path, exist_ok=True)
|
|
|
|
for rank in range(args.target_tensor_model_parallel_size):
|
|
|
|
for rank in range(args.target_tensor_model_parallel_size):
|
|
|
|
save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
|
|
|
|
save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
|
|
|
|