pull/69/head
Stanislas0 2 years ago
parent 965abd81b4
commit 76c550d876

@@ -84,7 +84,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False):
     if release:
         directory = ""
     else:
-        directory = "iter_{:07d}".format(iteration)
+        directory = f"global_step{iteration}"
     # Use both the tensor and pipeline MP rank.
     if mpu.get_pipeline_model_parallel_world_size() == 1:
         return os.path.join(
@@ -174,7 +174,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         # Saving is a collective communication
         checkpoint_name = get_checkpoint_name(args.save, iteration)
 
         # Trim off the filename and mp_rank_* directory.
-        for _ in range(3):
+        for _ in range(2):
             checkpoint_name = os.path.dirname(checkpoint_name)
         model[0].save_checkpoint(checkpoint_name, client_state=state_dict)
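For context, a minimal sketch (not part of the diff) of how the renamed iteration directory composes into the full checkpoint path and why two path components are now stripped instead of three. The save root, iteration, rank, and leaf filename below are made up; the mp_rank_*/filename layout is assumed from the "Trim off" comment and standard Megatron conventions.

import os

# Hypothetical values for illustration only.
save_root = "/checkpoints/codegeex"
iteration = 213000
tp_rank = 0

# Path shaped like the one get_checkpoint_name builds after this change
# (previously the middle component was "iter_{:07d}".format(iteration)).
checkpoint_name = os.path.join(
    save_root,
    f"global_step{iteration}",
    "mp_rank_{:02d}".format(tp_rank),       # assumed rank directory format
    "model_optim_rng.pt",                   # assumed Megatron checkpoint filename
)

# Trimming two components removes the filename and the mp_rank_* directory,
# leaving the per-iteration directory that is handed to model[0].save_checkpoint.
for _ in range(2):
    checkpoint_name = os.path.dirname(checkpoint_name)

print(checkpoint_name)  # /checkpoints/codegeex/global_step213000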

@@ -19,10 +19,11 @@ import torch
 import torch.nn.functional as F
 
 from codegeex.megatron import get_args
-from codegeex.megatron import mpu
+from codegeex.megatron import mpu, print_rank_0
 from codegeex.megatron.model.module import MegatronModule
 from codegeex.megatron.model.transformer import ParallelTransformer
 from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal
+from codegeex.megatron.mpu.initialize import get_tensor_model_parallel_world_size
 
 
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
@@ -92,22 +93,33 @@ class Embedding(MegatronModule):
         num_tokentypes=0,
     ):
         super(Embedding, self).__init__()
+        args = get_args()
 
         self.hidden_size = hidden_size
         self.init_method = init_method
         self.num_tokentypes = num_tokentypes
+        self.max_sequence_length = max_sequence_length
 
         # Word embeddings (parallel).
         self.word_embeddings = mpu.VocabParallelEmbedding(
             vocab_size, self.hidden_size, init_method=self.init_method)
-        self._word_embeddings_key = 'word_embeddings'
+        if args.compress:
+            self._word_embeddings_key = 'word_embedding'
+        else:
+            self._word_embeddings_key = 'word_embeddings'
 
         self.vocab_size = vocab_size
 
         # Position embedding (serial).
         self.position_embeddings = torch.nn.Embedding(
             max_sequence_length, self.hidden_size)
+        self.position_embeddings = self.position_embeddings.half()
-        self._position_embeddings_key = 'position_embeddings'
+        if args.compress:
+            self._position_embeddings_key = 'position_embedding'
+        else:
+            self._position_embeddings_key = 'position_embeddings'
         # Initialize the position embeddings.
         self.init_method(self.position_embeddings.weight)
@@ -190,7 +202,8 @@ class Embedding(MegatronModule):
                 if 'word_embeddings' in key:
                     state_dict_[key.split('word_embeddings.')[1]] \
                         = state_dict[key]
-        state_dict_["weight"] = state_dict_["weight"][:self.vocab_size]
+        vocab_len = state_dict_['weight'].shape[0]
+        state_dict_["weight"] = state_dict_["weight"][:self.vocab_size // get_tensor_model_parallel_world_size()]
         self.word_embeddings.load_state_dict(state_dict_, strict=strict)
 
         # Position embedding.
@@ -203,6 +216,17 @@ class Embedding(MegatronModule):
                 if 'position_embeddings' in key:
                     state_dict_[key.split('position_embeddings.')[1]] \
                         = state_dict[key]
+        pos_len = state_dict_['weight'].shape[0]
+        max_seq_len = self.max_sequence_length
+        if pos_len < max_seq_len:
+            print_rank_0(f"Position embedding padded {pos_len} -> {max_seq_len}.")
+            position_embeddings_padded = torch.nn.Embedding(
+                max_seq_len - pos_len, self.hidden_size).half()
+            self.init_method(position_embeddings_padded.weight)
+            state_dict_['weight'] = torch.cat([state_dict_['weight'], position_embeddings_padded.weight], dim=0)
+        # self.position_embeddings = self.position_embeddings.half()
         self.position_embeddings.load_state_dict(state_dict_, strict=strict)
 
         # Tokentype embedding.
@@ -284,12 +308,14 @@ class QueryEmbedding(MegatronModule):
         self.hidden_size = hidden_size
         self.init_method = init_method
         self.num_tokentypes = num_tokentypes
+        self.max_sequence_length = max_sequence_length
 
         # Top query position embedding (serial).
         self.top_query_embeddings = mpu.VocabParallelEmbedding(
             max_sequence_length, self.hidden_size, init_method=self.init_method)
+        self.top_query_embeddings = self.top_query_embeddings.half()
         self._top_query_embeddings_key = 'top_query_embeddings'
         # Initialize the top query position embeddings.
         self.init_method(self.top_query_embeddings.weight)
@@ -368,6 +394,16 @@ class QueryEmbedding(MegatronModule):
                 if 'top_query_embeddings' in key:
                     state_dict_[key.split('top_query_embeddings.')[1]] \
                         = state_dict[key]
+        pos_len = state_dict_['weight'].shape[0]
+        max_seq_len = self.max_sequence_length // get_tensor_model_parallel_world_size()
+        print_rank_0(f"pos_len: {pos_len}")
+        print_rank_0(f"max_seq_len: {max_seq_len}")
+        if pos_len < max_seq_len:
+            print_rank_0(f"Top query embedding padded {pos_len} -> {max_seq_len}.")
+            top_query_embeddings_padded = torch.nn.Embedding(
+                max_seq_len - pos_len, self.hidden_size).half()
+            self.init_method(top_query_embeddings_padded.weight)
+            state_dict_['weight'] = torch.cat([state_dict_['weight'], top_query_embeddings_padded.weight], dim=0)
         self.top_query_embeddings.load_state_dict(state_dict_, strict=strict)
 
         # Tokentype embedding.
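Both padding hunks above follow the same pattern. Below is a standalone sketch of that pattern (not the commit's code): the sizes are made up and a plain normal initializer stands in for init_method. A weight loaded with fewer rows than the model expects is extended by appending freshly initialized rows.

import torch

# Assumed sizes for illustration.
hidden_size = 16
pos_len, max_seq_len = 2048, 8192   # rows in the checkpoint vs. rows the model expects

loaded_weight = torch.randn(pos_len, hidden_size).half()

if pos_len < max_seq_len:
    # Freshly initialized rows for the positions missing from the checkpoint.
    padding = torch.nn.Embedding(max_seq_len - pos_len, hidden_size)
    torch.nn.init.normal_(padding.weight, mean=0.0, std=0.02)  # stand-in for init_method
    loaded_weight = torch.cat([loaded_weight, padding.weight.half()], dim=0)

assert loaded_weight.shape == (max_seq_len, hidden_size)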

@@ -0,0 +1,50 @@
+import os
+import sys
+import torch
+import random
+import argparse
+import numpy as np
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--load-path",
+                        type=str,
+                        default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_fp32_52224.pt")
+    parser.add_argument("--save-path",
+                        type=str,
+                        default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt")
+    args, _ = parser.parse_known_args()
+
+    state_dict_path = args.load_path
+    print("Loading state dict ...")
+    sd = torch.load(state_dict_path, map_location="cpu")
+
+    for i in range(40):
+        if i < 39:
+            query_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.weight', None)
+            query_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.bias', None)
+            key_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.weight', None)
+            key_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.bias', None)
+            value_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.weight', None)
+            value_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.bias', None)
+            qkv_weight = torch.cat([query_weight, key_weight, value_weight], dim=0)
+            qkv_bias = torch.cat([query_bias, key_bias, value_bias])
+            sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.weight'] = qkv_weight
+            sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.bias'] = qkv_bias
+        else:
+            tq_key_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.weight', None)
+            tq_key_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.bias', None)
+            tq_value_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.weight', None)
+            tq_value_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.bias', None)
+            tq_kv_weight = torch.cat([tq_key_weight, tq_value_weight], dim=0)
+            tq_kv_bias = torch.cat([tq_key_bias, tq_value_bias])
+            sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.weight'] = tq_kv_weight
+            sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.bias'] = tq_kv_bias
+
+    save_ckpt_path = args.save_path
+    torch.save(sd, save_ckpt_path)
+
+
+if __name__ == '__main__':
+    main()
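The new file's name is not visible in this view. A hypothetical sanity check after running it (the load path below is the script's own --save-path default, and the key names follow the code above; none of this is part of the commit):

import torch

# Load the converted checkpoint written by the script above.
sd = torch.load("/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt", map_location="cpu")
transformer = sd['module']['language_model']['transformer']

# The separate q/k/v entries should be gone and the fused ones present.
assert 'layers.0.attention.query.weight' not in transformer
assert 'layers.0.attention.query_key_value.weight' in transformer
assert 'topQueryLayer.attention.key_value.weight' in transformer

# The fused weight stacks query, key, and value along dim 0, so it should
# have three times the rows of a single projection.
qkv_weight = transformer['layers.0.attention.query_key_value.weight']
print(qkv_weight.shape)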

@@ -176,8 +176,7 @@ def main():
             for j in range(micro_batch_size):
                 if is_finished[j]:
                     continue
-                if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(
-                        generated_tokens[j]) >= out_seq_length:
+                if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
                     is_finished[j] = True
                     generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
                     generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])

@@ -166,44 +166,57 @@ def main():
     with open(args.prompt_file, "r") as f:
         prompt = f.readlines()
         prompt = "".join(prompt)
-    print_rank_0("Generating ...")
-    t0 = time.perf_counter()
-    for prompt in [prompt]:
-        tokens = tokenizer.tokenize(prompt)
-        print_rank_0(tokens)
-        print_rank_0("Current prompt:")
-        print_rank_0(prompt)
-        n_token_prompt = len(tokens)
-        print_rank_0(f"N_token_prompt: {n_token_prompt}")
-        token_stream = get_token_stream(
-            model,
-            [copy.deepcopy(tokens) for _ in range(args.micro_batch_size)],
-            micro_batch_size=args.micro_batch_size,
-            bad_ids=args.bad_ids,
-        )
-        is_finished = [False for _ in range(args.micro_batch_size)]
-        for i, generated in enumerate(token_stream):
-            generated_tokens = generated[0]
-            for j in range(args.micro_batch_size):
-                if is_finished[j]:
-                    continue
-                if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
-                        generated_tokens[j]) >= args.out_seq_length:
-                    is_finished[j] = True
-                    generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
-                    generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
-                    t1 = time.perf_counter()
-                    print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}")
-                    print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
-                    print_rank_0("================================= Generated code:")
-                    print_rank_0(generated_code)
-                    t0 = time.perf_counter()
-            if all(is_finished):
-                break
+    times = {}
+    out_seq_lengths = [args.out_seq_length]
+    micro_batch_size = args.micro_batch_size
+    for out_seq_length in out_seq_lengths:
+        print_rank_0(f"Generating with out_seq_len {out_seq_length}...")
+        times[out_seq_length] = []
+        for prompt in [prompt] * args.n_generation:
+            t0 = time.perf_counter()
+            tokens = tokenizer.tokenize(prompt)
+            print_rank_0(tokens)
+            print_rank_0("Current prompt:")
+            print_rank_0(prompt)
+            n_token_prompt = len(tokens)
+            print_rank_0(f"N_token_prompt:{n_token_prompt}")
+            token_stream = get_token_stream(
+                model,
+                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
+                micro_batch_size=micro_batch_size,
+                topk=args.top_k,
+                topp=args.top_p,
+                temperature=args.temperature,
+            )
+            is_finished = [False for _ in range(micro_batch_size)]
+            for i, generated in enumerate(token_stream):
+                generated_tokens = generated[0]
+                for j in range(micro_batch_size):
+                    if is_finished[j]:
+                        continue
+                    if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
+                            generated_tokens[j]) >= out_seq_length:
+                        is_finished[j] = True
+                        generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
+                        generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
+                        t1 = time.perf_counter()
+                        print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}")
+                        print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
+                        times[out_seq_length].append(t1 - t0)
+                        print_rank_0("================================= Generated code:")
+                        print_rank_0(generated_code)
+                        t0 = time.perf_counter()
+                if all(is_finished):
+                    break
+    print_rank_0(times)
+    for out_seq_length in times.keys():
+        print_rank_0(f"{out_seq_length}, {np.mean(times[out_seq_length])}")
+    print_rank_0("Generation finished.")
 
 if __name__ == "__main__":
     main()
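As a rough guide to reading the new timing output, a small sketch (made-up numbers, not part of the commit) of how the recorded per-run wall-clock times relate to the per-token latency printed in the inner loop, assuming a run stops at the length limit rather than on an end-of-document token:

import numpy as np

# Made-up measurements: total seconds per run for one out_seq_length setting.
times = {2048: [95.2, 96.1, 94.8]}
n_token_prompt = 32

for out_seq_length, samples in times.items():
    mean_total = np.mean(samples)
    # If generation ran to the length limit, this many new tokens were produced.
    tokens_generated = out_seq_length - n_token_prompt
    print(f"{out_seq_length}: {mean_total:.1f}s total, "
          f"{mean_total / tokens_generated:.3f}s/token")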
