@@ -390,8 +390,6 @@ class QueryEmbedding(MegatronModule):
                         = state_dict[key]
         pos_len = state_dict_['weight'].shape[0]
         max_seq_len = self.max_sequence_length // get_tensor_model_parallel_world_size()
-        print_rank_0(f"pos_len: {pos_len}")
-        print_rank_0(f"max_seq_len: {max_seq_len}")
         if pos_len < max_seq_len:
             print_rank_0(f"Top query embedding padded {pos_len} -> {max_seq_len}.")
             top_query_embeddings_padded = torch.nn.Embedding(
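
The hunk ends mid-call at `torch.nn.Embedding(`, so the rest of the padding step is not visible here. For context, below is a minimal, self-contained sketch of the idea the surrounding code implements: when the checkpoint's top query embedding has fewer position rows than the current model's (parallel-partitioned) maximum sequence length, freshly initialized rows are appended so the state dict can be loaded at full size. The helper name `pad_top_query_embedding`, the `init_method` default, and the concrete sizes in the usage example are assumptions for illustration, not the original implementation.

```python
import torch

def pad_top_query_embedding(state_dict_, max_seq_len, hidden_size,
                            init_method=torch.nn.init.normal_):
    """Hypothetical helper: pad a loaded top-query embedding weight
    up to max_seq_len rows by appending freshly initialized rows."""
    pos_len = state_dict_['weight'].shape[0]
    if pos_len < max_seq_len:
        # New rows for the positions the checkpoint does not cover.
        padding = torch.nn.Embedding(max_seq_len - pos_len, hidden_size)
        init_method(padding.weight)
        state_dict_['weight'] = torch.cat(
            [state_dict_['weight'], padding.weight.detach()], dim=0
        )
    return state_dict_

# Example: a checkpoint with 2048 position rows padded to 4096.
sd = {'weight': torch.randn(2048, 1024)}
sd = pad_top_query_embedding(sd, max_seq_len=4096, hidden_size=1024)
print(sd['weight'].shape)  # torch.Size([4096, 1024])
```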