pull/69/head
Stanislas0 2 years ago
parent 5c9dab3701
commit 079b9ebd94

@@ -390,8 +390,6 @@ class QueryEmbedding(MegatronModule):
                         = state_dict[key]
         pos_len = state_dict_['weight'].shape[0]
         max_seq_len = self.max_sequence_length // get_tensor_model_parallel_world_size()
-        print_rank_0(f"pos_len: {pos_len}")
-        print_rank_0(f"max_seq_len: {max_seq_len}")
         if pos_len < max_seq_len:
             print_rank_0(f"Top query embedding padded {pos_len} -> {max_seq_len}.")
             top_query_embeddings_padded = torch.nn.Embedding(
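
The hunk removes two debug prints and leads into the padding branch: when the checkpoint's top query embedding holds fewer rows (pos_len) than this tensor-parallel rank requires (max_seq_len), a larger embedding is allocated and the loaded rows are copied into it. Below is a minimal standalone sketch of that padding step; the helper name pad_position_embedding and its return-the-weight form are illustrative assumptions, not code from the repository.

import torch


def pad_position_embedding(weight: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    # Hypothetical helper sketching the padding step the hunk begins.
    pos_len, hidden_size = weight.shape
    if pos_len >= max_seq_len:
        # Checkpoint already covers every position this rank needs.
        return weight
    # Allocate a larger embedding; rows beyond pos_len keep their fresh init.
    padded = torch.nn.Embedding(max_seq_len, hidden_size)
    with torch.no_grad():
        padded.weight[:pos_len] = weight  # copy loaded rows into the head
    return padded.weight.detach().clone()


# Example (shapes are made up): pad a 2048-position checkpoint to 4096 rows.
# padded = pad_position_embedding(torch.randn(2048, 5120), 4096)

In the actual method this happens in place while loading the state dict, so the padded weight replaces the module's own embedding rather than being returned.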
