mirror of https://github.com/THUDM/CodeGeeX.git
Refactor
parent 31c3bf351d
commit a866d8db20
@@ -0,0 +1,145 @@
"""Get model parallel partitions."""

import os
import re
import random
import sys

import numpy as np
import torch

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

from codegeex.megatron import get_args
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.checkpointing import ensure_directory_exists


def get_change_ckpt_args(parser):
    """Provide extra arguments required for checkpoint conversion."""
    group = parser.add_argument_group(title='Mindspore to megatron')
    group.add_argument(
        '--load-ckpt-path',
        type=str,
        required=True,
        help='path to load ".pt" checkpoint.',
    )
    group.add_argument(
        '--save-ckpt-path',
        type=str,
        required=True,
        help='dir to save converted checkpoints.',
    )
    group.add_argument(
        '--target-tensor-model-parallel-size',
        type=int,
        default=2,
        help='target tensor model parallel size',
    )

    return parser


def get_element_from_dict_by_path(d, path):
    """
    Get an element from a dictionary by path. If the element is not present,
    recursively add empty dictionaries.

    Args:
        d (dict): the dictionary to get the element from
        path (str): the path to the element, delimited by "."
    """
    path = path.split(".")
    for k in path:
        if k not in d:
            d[k] = {}
        d = d[k]
    return d
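
# A minimal sketch of how the helper above behaves (the values below are
# hypothetical and added purely for illustration, not part of the original file):
#
#     d = {}
#     leaf = get_element_from_dict_by_path(d, "module.language_model.transformer")
#     # d is now {"module": {"language_model": {"transformer": {}}}} and `leaf`
#     # is the innermost (empty) dict, so assigning into `leaf` fills the
#     # nested structure in place.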


def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))

    initialize_megatron(
        extra_args_provider=get_change_ckpt_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )

    args = get_args()
    print(f"Load ckpt from {args.load_ckpt_path}...")
    state_dict = torch.load(args.load_ckpt_path, map_location="cpu")

    print(f"Splitting ckpt into {args.target_tensor_model_parallel_size} parts...")
    output_state_dict = []
    for i in range(args.target_tensor_model_parallel_size):
        output_state_dict.append({})

    print("Converting Embedding layers...")
    word_embeddings = state_dict['module']['language_model']['embedding']['word_embeddings']['weight']
    position_embeddings = state_dict['module']['language_model']['embedding']['position_embeddings']['weight']
    out_word_embeddings = torch.chunk(word_embeddings, args.target_tensor_model_parallel_size, dim=0)
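    # Illustrative note (not in the original script): torch.chunk splits the
    # word-embedding matrix along dim 0, the padded vocabulary dimension. For
    # example, with a hypothetical [52224, 5120] weight and a target size of 2,
    # each rank would receive a [26112, 5120] shard, while the position
    # embeddings below are copied to every rank unchanged.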

    for i in range(args.target_tensor_model_parallel_size):
        pos_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], "module.language_model.embedding.position_embeddings"
        )
        pos_emb_dict["weight"] = position_embeddings

        word_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], "module.language_model.embedding.word_embeddings"
        )
        word_emb_dict["weight"] = out_word_embeddings[i]

    print("Converting QueryEmbedding layers...")
    query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
    out_query_embeddings = torch.chunk(query_embeddings, args.target_tensor_model_parallel_size, dim=0)

    for i in range(args.target_tensor_model_parallel_size):
        query_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
        )
        query_emb_dict["weight"] = out_query_embeddings[i]

    print("Converting Transformer layers...")
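    # Summary of the split rules applied in the loop below (descriptive comment
    # only, derived from the branches that follow):
    #   - layernorm weights and biases are kept whole and replicated on every rank;
    #   - attention QKV weights and mlp h_to_4h weights are split along dim 0
    #     (the output dimension, column-parallel);
    #   - attention dense and mlp 4h_to_h weights are split along dim 1
    #     (the input dimension, row-parallel);
    #   - biases of column-parallel layers are split along dim 0, while biases
    #     of row-parallel layers (attention dense, 4h_to_h) are replicated.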
    for layer_name in state_dict['module']['language_model']['transformer'].keys():
        params = state_dict['module']['language_model']['transformer'][layer_name]
        if "layernorm" in layer_name:
            pass
        elif "attention" in layer_name and "weight" in layer_name:
            if "dense" in layer_name:
                params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1)
            else:
                params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
        elif "weight" in layer_name and "dense" in layer_name:
            if "h_to_4h" in layer_name:
                params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
            else:
                params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1)
        elif "bias" in layer_name:
            if "dense" not in layer_name or "mlp" in layer_name:
                if "4h_to_h" in layer_name:
                    pass
                else:
                    params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)

        for i in range(args.target_tensor_model_parallel_size):
            params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
            if isinstance(params, tuple):
                params_dict[layer_name] = params[i]
            else:
                params_dict[layer_name] = params

    os.makedirs(args.save_ckpt_path, exist_ok=True)
    for rank in range(args.target_tensor_model_parallel_size):
        save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
        torch.save(output_state_dict[rank], save_ckpt_path)
        print(f"Converted checkpoint saved in {save_ckpt_path}.")


if __name__ == '__main__':
    main()
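
# Example invocation of this conversion script (illustrative only; the script
# name and paths are placeholders, and Megatron's own required model arguments
# are omitted for brevity):
#
#     python <this_script> \
#         --load-ckpt-path /path/to/merged_checkpoint.pt \
#         --save-ckpt-path /path/to/output_dir \
#         --target-tensor-model-parallel-size 2
#
# With a target size of 2 this would write mp_rank_00_model_states.pt and
# mp_rank_01_model_states.pt under /path/to/output_dir.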
@@ -1,201 +0,0 @@
import os
import copy
import time
import torch
import random
import argparse
import numpy as np

from codegeex.torch.inference import get_token_stream
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize


def model_provider(args):
    """Build the model."""

    model = CodeGeeXModel(
        args.hidden_size,
        args.num_layers,
        args.num_attention_heads,
        args.padded_vocab_size,
        args.max_position_embeddings,
    )

    return model


def add_code_generation_args(parser):
    group = parser.add_argument_group(title="code generation")
    group.add_argument(
        "--num-layers",
        type=int,
        default=39,
    )
    group.add_argument(
        "--hidden-size",
        type=int,
        default=5120,
    )
    group.add_argument(
        "--num-attention-heads",
        type=int,
        default=40,
    )
    group.add_argument(
        "--padded-vocab-size",
        type=int,
        default=52224,
    )
    group.add_argument(
        "--max-position-embeddings",
        type=int,
        default=2048,
    )
    group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    group.add_argument(
        "--greedy",
        action="store_true",
        default=False,
        help="Use greedy sampling.",
    )
    group.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="Top p sampling.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top k sampling.",
    )
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=2048,
        help="Size of the output generated text.",
    )
    group.add_argument(
        "--prompt-file",
        type=str,
        default="./test_prompt.txt",
    )
    group.add_argument(
        "--tokenizer-path",
        type=str,
        default="./tokenizer",
    )
    group.add_argument(
        "--load",
        type=str,
    )
    group.add_argument(
        "--state-dict-path",
        type=str,
    )
    group.add_argument(
        "--micro-batch-size",
        type=int,
        default=1,
    )

    return parser


def main():
    parser = argparse.ArgumentParser()
    parser = add_code_generation_args(parser)
    args, _ = parser.parse_known_args()

    print("Building CodeGeeX model ...")
    model = model_provider(args)

    print("Loading tokenizer ...")
    tokenizer = CodeGeeXTokenizer(
        tokenizer_path=args.tokenizer_path,
        mode="codegeex-13b",
    )

    print("Loading state dict ...")
    state_dict = torch.load(args.load, map_location="cpu")
    state_dict = state_dict["module"]

    print("Building CodeGeeX model ...")
    model = model_provider(args)
    model.load_state_dict(state_dict)
    model.eval()
    model.half()
    model = quantize(model, weight_bit_width=8)
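    # Descriptive note (not in the original file): the model is cast to fp16
    # first and then passed through `quantize` with weight_bit_width=8, which,
    # as the argument suggests, replaces the weights with an 8-bit
    # representation before the model is moved to the GPU below.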
    model.cuda()

    with open(args.prompt_file, "r") as f:
        prompt = f.readlines()
        prompt = "".join(prompt)

    times = {}
    out_seq_lengths = [args.out_seq_length]
    micro_batch_size = args.micro_batch_size
    seq_length = args.max_position_embeddings
    for out_seq_length in out_seq_lengths:
        print(f"Generating with out_seq_len {out_seq_length}...")

        times[out_seq_length] = []
        for prompt in [prompt]:
            t0 = time.perf_counter()
            tokens = tokenizer.encode_code(prompt)
            print(tokens)
            print("Current prompt:")
            print(prompt)
            n_token_prompt = len(tokens)
            print("N_token_prompt:", n_token_prompt)
            token_stream = get_token_stream(
                model,
                tokenizer,
                seq_length,
                out_seq_length,
                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
                micro_batch_size=micro_batch_size,
                topk=args.top_k,
                topp=args.top_p,
                temperature=args.temperature,
                greedy=args.greedy,
            )
            is_finished = [False for _ in range(micro_batch_size)]
            for i, generated in enumerate(token_stream):
                generated_tokens = generated[0]
                for j in range(micro_batch_size):
                    if is_finished[j]:
                        continue
                    if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
                        is_finished[j] = True
                        generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
                        generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
                        generated_code = "".join(generated_code)
                        t1 = time.perf_counter()
                        print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
                        print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
                        times[out_seq_length].append(t1 - t0)
                        print("================================= Generated code:")
                        print(generated_code)

                if all(is_finished):
                    break

    print(times)
    for out_seq_length in times.keys():
        print(out_seq_length, np.mean(times[out_seq_length]))

    print("Generation finished.")


if __name__ == "__main__":
    main()
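
# Example invocation of this benchmark script (illustrative only; the script
# name and paths are placeholders, not taken from the diff; the flags match
# the parser defined above):
#
#     python <this_script> \
#         --load /path/to/checkpoint.pt \
#         --tokenizer-path /path/to/tokenizer \
#         --prompt-file ./test_prompt.txt \
#         --out-seq-length 256 \
#         --top-p 0.95 \
#         --temperature 0.8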