mirror of https://github.com/THUDM/CodeGeeX.git
Refactor
parent
31c3bf351d
commit
a866d8db20
@@ -0,0 +1,145 @@
"""Get model parallel partitions."""

import os
import re
import random
import sys

import numpy as np
import torch

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

from codegeex.megatron import get_args
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.checkpointing import ensure_directory_exists


def get_change_ckpt_args(parser):
    """Provide extra arguments required for checkpoint conversion."""
    group = parser.add_argument_group(title='Mindspore to megatron')
    group.add_argument(
        '--load-ckpt-path',
        type=str,
        required=True,
        help='path to load ".pt" checkpoint.',
    )
    group.add_argument(
        '--save-ckpt-path',
        type=str,
        required=True,
        help='dir to save converted checkpoints.',
    )
    group.add_argument(
        '--target-tensor-model-parallel-size',
        type=int,
        default=2,
        help='target tensor model parallel size',
    )

    return parser


def get_element_from_dict_by_path(d, path):
    """
    Get element from dictionary by path. If the element is not present, recursively add empty dictionaries.

    Args:
        d (dict): the dictionary to get the element from
        path (str): the path to the element, delimited by "."
    """
    path = path.split(".")
    for k in path:
        if k not in d:
            d[k] = {}
        d = d[k]
    return d
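
# Illustrative note (added, not part of the original script): the helper builds
# nested dicts on demand and returns a reference into the outer dict, e.g.
#   d = {}
#   leaf = get_element_from_dict_by_path(d, "module.language_model.transformer")
#   leaf["w"] = 1
#   # d is now {"module": {"language_model": {"transformer": {"w": 1}}}}
# so assignments through the returned reference land at the dotted path.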


def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))

    initialize_megatron(
        extra_args_provider=get_change_ckpt_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )

    args = get_args()
    print(f"Load ckpt from {args.load_ckpt_path}...")
    state_dict = torch.load(args.load_ckpt_path, map_location="cpu")

    print(f"Splitting ckpt into {args.target_tensor_model_parallel_size} parts...")
    output_state_dict = []
    for i in range(args.target_tensor_model_parallel_size):
        output_state_dict.append({})

    print("Converting Embedding layers...")
    word_embeddings = state_dict['module']['language_model']['embedding']['word_embeddings']['weight']
    position_embeddings = state_dict['module']['language_model']['embedding']['position_embeddings']['weight']
    out_word_embeddings = torch.chunk(word_embeddings, args.target_tensor_model_parallel_size, dim=0)

    for i in range(args.target_tensor_model_parallel_size):
        pos_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], "module.language_model.embedding.position_embeddings"
        )
        pos_emb_dict["weight"] = position_embeddings

        word_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], "module.language_model.embedding.word_embeddings"
        )
        word_emb_dict["weight"] = out_word_embeddings[i]

    print("Converting QueryEmbedding layers...")
    query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
    out_query_embeddings = torch.chunk(query_embeddings, args.target_tensor_model_parallel_size, dim=0)

    for i in range(args.target_tensor_model_parallel_size):
        query_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
        )
        query_emb_dict["weight"] = out_query_embeddings[i]

    print("Converting Transformer layers...")
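    # Added commentary (not in the original): the chunk dimension below follows the
    # usual Megatron-style tensor-parallel layout. Column-parallel weights (the QKV
    # projection and dense_h_to_4h) are split along dim 0 together with their biases;
    # row-parallel weights (the attention output dense and dense_4h_to_h) are split
    # along dim 1 and their biases are replicated. LayerNorm parameters are copied
    # unchanged to every rank.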
    for layer_name in state_dict['module']['language_model']['transformer'].keys():
        params = state_dict['module']['language_model']['transformer'][layer_name]
        if "layernorm" in layer_name:
            pass
        elif "attention" in layer_name and "weight" in layer_name:
            if "dense" in layer_name:
                params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1)
            else:
                params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
        elif "weight" in layer_name and "dense" in layer_name:
            if "h_to_4h" in layer_name:
                params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
            else:
                params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1)
        elif "bias" in layer_name:
            if "dense" not in layer_name or "mlp" in layer_name:
                if "4h_to_h" in layer_name:
                    pass
                else:
                    params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)

        for i in range(args.target_tensor_model_parallel_size):
            params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
            if type(params) is tuple:
                params_dict[layer_name] = params[i]
            else:
                params_dict[layer_name] = params

    os.makedirs(args.save_ckpt_path, exist_ok=True)
    for rank in range(args.target_tensor_model_parallel_size):
        save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
        torch.save(output_state_dict[rank], save_ckpt_path)
        print(f"Converted checkpoint saved in {save_ckpt_path}.")


if __name__ == '__main__':
    main()
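
A minimal way to exercise the new partitioning script and sanity-check its output could look like the sketch below. The invocation line, the script filename, the paths and the TP value are assumptions; only the mp_rank_XX_model_states.pt naming and the chunking layout come from the script above.

# Illustrative sanity check for the shards produced above (not part of the commit).
#
# The new script would be invoked roughly as:
#   python <split_script>.py --load-ckpt-path ./ckpt/codegeex_13b.pt \
#       --save-ckpt-path ./ckpt_mp2 --target-tensor-model-parallel-size 2

import os
import torch

SAVE_DIR = "./ckpt_mp2"   # assumed value of --save-ckpt-path
TP = 2                    # assumed value of --target-tensor-model-parallel-size

shards = [
    torch.load(os.path.join(SAVE_DIR, f"mp_rank_{rank:02d}_model_states.pt"), map_location="cpu")
    for rank in range(TP)
]

# Word embeddings were chunked along dim 0, so concatenating the shards
# should restore the full vocabulary dimension.
word_parts = [s["module"]["language_model"]["embedding"]["word_embeddings"]["weight"] for s in shards]
print("merged word embedding shape:", tuple(torch.cat(word_parts, dim=0).shape))

# Position embeddings were replicated, so every rank should hold an identical copy.
pos = [s["module"]["language_model"]["embedding"]["position_embeddings"]["weight"] for s in shards]
assert all(torch.equal(pos[0], p) for p in pos)
print("position embeddings identical across ranks: OK")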
@@ -1,201 +0,0 @@
import os
import copy
import time
import torch
import random
import argparse
import numpy as np

from codegeex.torch.inference import get_token_stream
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize


def model_provider(args):
    """Build the model."""

    model = CodeGeeXModel(
        args.hidden_size,
        args.num_layers,
        args.num_attention_heads,
        args.padded_vocab_size,
        args.max_position_embeddings,
    )

    return model


def add_code_generation_args(parser):
    group = parser.add_argument_group(title="code generation")
    group.add_argument(
        "--num-layers",
        type=int,
        default=39,
    )
    group.add_argument(
        "--hidden-size",
        type=int,
        default=5120,
    )
    group.add_argument(
        "--num-attention-heads",
        type=int,
        default=40,
    )
    group.add_argument(
        "--padded-vocab-size",
        type=int,
        default=52224,
    )
    group.add_argument(
        "--max-position-embeddings",
        type=int,
        default=2048,
    )
    group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    group.add_argument(
        "--greedy",
        action="store_true",
        default=False,
        help="Use greedy sampling.",
    )
    group.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="Top p sampling.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top k sampling.",
    )
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=2048,
        help="Size of the output generated text.",
    )
    group.add_argument(
        "--prompt-file",
        type=str,
        default="./test_prompt.txt",
    )
    group.add_argument(
        "--tokenizer-path",
        type=str,
        default="./tokenizer",
    )
    group.add_argument(
        "--load",
        type=str,
    )
    group.add_argument(
        "--state-dict-path",
        type=str,
    )
    group.add_argument(
        "--micro-batch-size",
        type=int,
        default=1,
    )

    return parser

def main():
    parser = argparse.ArgumentParser()
    parser = add_code_generation_args(parser)
    args, _ = parser.parse_known_args()

    print("Loading tokenizer ...")
    tokenizer = CodeGeeXTokenizer(
        tokenizer_path=args.tokenizer_path,
        mode="codegeex-13b")

    print("Loading state dict ...")
    state_dict = torch.load(args.load, map_location="cpu")
    state_dict = state_dict["module"]

    print("Building CodeGeeX model ...")
    model = model_provider(args)
    model.load_state_dict(state_dict)
    model.eval()
    model.half()
    model = quantize(model, weight_bit_width=8)
    model.cuda()

    with open(args.prompt_file, "r") as f:
        prompt = f.readlines()
        prompt = "".join(prompt)

    times = {}
    out_seq_lengths = [args.out_seq_length]
    micro_batch_size = args.micro_batch_size
    seq_length = args.max_position_embeddings
    for out_seq_length in out_seq_lengths:
        print(f"Generating with out_seq_len {out_seq_length}...")

        times[out_seq_length] = []
        for prompt in [prompt]:
            t0 = time.perf_counter()
            tokens = tokenizer.encode_code(prompt)
            print(tokens)
            print("Current prompt:")
            print(prompt)
            n_token_prompt = len(tokens)
            print("N_token_prompt:", n_token_prompt)
            token_stream = get_token_stream(
                model,
                tokenizer,
                seq_length,
                out_seq_length,
                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
                micro_batch_size=micro_batch_size,
                topk=args.top_k,
                topp=args.top_p,
                temperature=args.temperature,
                greedy=args.greedy,
            )
            is_finished = [False for _ in range(micro_batch_size)]
            for i, generated in enumerate(token_stream):
                generated_tokens = generated[0]
                for j in range(micro_batch_size):
                    if is_finished[j]:
                        continue
                    if (generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id
                            or len(generated_tokens[j]) >= out_seq_length):
                        is_finished[j] = True
                        generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
                        generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
                        generated_code = "".join(generated_code)
                        t1 = time.perf_counter()
                        print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
                        print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
                        times[out_seq_length].append(t1 - t0)
                        print("================================= Generated code:")
                        print(generated_code)

                if all(is_finished):
                    break

    print(times)
    for out_seq_length in times.keys():
        print(out_seq_length, np.mean(times[out_seq_length]))

    print("Generation finished.")


if __name__ == "__main__":
    main()
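
The --temperature, --greedy, --top-p and --top-k flags of the removed script are forwarded unchanged to get_token_stream. As a rough illustration of what such sampling controls typically do (a generic sketch, not the implementation inside codegeex.torch.inference), a single next-token sampler might look like this:

# Generic sketch of temperature / top-k / top-p sampling; illustrative only.
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=0.0, greedy=False):
    # logits: 1-D tensor of vocabulary scores for the next position.
    if greedy:
        return int(torch.argmax(logits))

    logits = logits / max(temperature, 1e-5)

    if top_k > 0:
        # Keep only the k highest-scoring tokens.
        kth = torch.topk(logits, top_k).values[-1]
        logits[logits < kth] = float("-inf")

    if top_p > 0.0:
        # Nucleus sampling: keep the smallest set of tokens whose cumulative
        # probability exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumprobs > top_p
        remove[1:] = remove[:-1].clone()
        remove[0] = False
        logits[sorted_idx[remove]] = float("-inf")

    probs = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))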