# Mirror of https://github.com/THUDM/CodeGeeX.git
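"""Minimal end-to-end inference script for CodeGeeX.

Initializes Megatron locally, loads a checkpoint from --load, reads a prompt
from --prompt-file, and streams generated tokens until the end-of-document
token or --out-seq-length is reached, printing the decoded code and simple
timing statistics.
"""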
import copy
import os
import random
import time

import numpy as np
import torch

from codegeex.megatron import get_tokenizer, get_args
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.code_generation_utils import get_token_stream

torch.set_printoptions(precision=8)


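# Make sampling reproducible: seed NumPy and PyTorch (CPU and all CUDA
# devices) and force deterministic, non-benchmarking cuDNN kernels.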
def set_random_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


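# parallel_output=False asks the Megatron-style model to return gathered
# logits rather than logits split across tensor-parallel ranks, which is what
# single-stream sampling below expects.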
def model_provider():
    """Build the model."""

    model = CodeGeeXModel(num_tokentypes=0,
                          parallel_output=False)

    return model


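# Extra command-line flags merged into Megatron's argument parser through
# initialize_megatron(extra_args_provider=add_code_generation_args) in main().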
def add_code_generation_args(parser):
    """Code generation arguments."""
    group = parser.add_argument_group(title="code generation")

    group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    group.add_argument(
        "--greedy",
        action="store_true",
        default=False,
        help="Use greedy sampling.",
    )
    group.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="Top-p sampling.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top-k sampling.",
    )
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=2048,
        help="Size of the output generated text.",
    )
    group.add_argument(
        "--recompute",
        action="store_true",
        help="During generation recompute all attention "
        "instead of using previously computed keys/values.",
    )
    group.add_argument(
        "--ws-encoding-start-id",
        type=int,
        default=10,
        help="Start id for whitespace encoding.",
    )
    group.add_argument(
        "--ws-encoding-length",
        type=int,
        default=80,
        help="Length of whitespace encoding.",
    )
    group.add_argument(
        "--n-generation",
        type=int,
        default=10,
    )
    group.add_argument(
        "--eos-id",
        type=int,
        default=50256,
    )
    group.add_argument(
        "--prompt-file",
        type=str,
        default="./test_prompt.txt",
    )
    group.add_argument(
        "--perf-file",
        type=str,
        default="./perf_out.txt",
    )
    group.add_argument(
        "--perf-trace",
        type=str,
        default="./perf_out.txt",
    )
    group.add_argument(
        "--use-torch-profile",
        action="store_true",
    )
    group.add_argument(
        "--ln-fp32",
        action="store_true",
    )
    group.add_argument(
        "--bad-ids",
        nargs="*",
        type=int,
        default=None,
        help="Token ids that must not be generated (passed to get_token_stream).",
    )

    return parser


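# Example invocation (illustrative only; the checkpoint path and sampling
# values are placeholders, and the Megatron model-shape and tokenizer
# arguments that initialize_megatron additionally requires are not shown):
#
#   python <this_script>.py \
#       --load /path/to/codegeex_checkpoint.pt \
#       --prompt-file ./test_prompt.txt \
#       --micro-batch-size 1 \
#       --temperature 0.8 \
#       --top-p 0.95 \
#       --out-seq-length 1024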
def main():
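    # Megatron expects torch.distributed to be initialized even when running
    # a single process, so point it at localhost with a randomly chosen port.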
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))

    initialize_megatron(
        extra_args_provider=add_code_generation_args,
    )

    args = get_args()
    set_random_seed(args.seed)

    print("Loading tokenizer ...")
    tokenizer = get_tokenizer()

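    # args.load is expected to point at a single checkpoint file; the weights
    # themselves are stored under the "module" key of the saved dict.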
print("Loading state dict ...")
|
||
|
state_dict = torch.load(args.load, map_location="cpu")
|
||
|
state_dict = state_dict["module"]
|
||
|
|
||
|
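    # Build the model skeleton, load the pretrained weights, switch to eval
    # mode, optionally cast to half precision, and move everything to the GPU.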
print("Building CodeGeeX model ...")
|
||
|
model = model_provider()
|
||
|
model.load_state_dict(state_dict)
|
||
|
model.eval()
|
||
|
if args.fp16 and args.ln_fp16:
|
||
|
model.half()
|
||
|
model.cuda()
|
||
|
|
||
|
    with open(args.prompt_file, "r") as f:
        prompt = f.readlines()
        prompt = "".join(prompt)

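    # The prompt is replicated micro_batch_size times and generation is
    # streamed: get_token_stream yields progressively longer token tensors,
    # and each sequence is finalized once it emits the end-of-document token
    # or reaches --out-seq-length.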
print("Generating ...")
|
||
|
t0 = time.perf_counter()
|
||
|
for prompt in [prompt]:
|
||
|
tokens = tokenizer.tokenize(prompt)
|
||
|
print(tokens)
|
||
|
print("Current prompt:")
|
||
|
print(prompt)
|
||
|
n_token_prompt = len(tokens)
|
||
|
print("N_token_prompt:", n_token_prompt)
|
||
|
token_stream = get_token_stream(
|
||
|
model,
|
||
|
[copy.deepcopy(tokens) for _ in range(args.micro_batch_size)],
|
||
|
micro_batch_size=args.micro_batch_size,
|
||
|
bad_ids=args.bad_ids,
|
||
|
)
|
||
|
is_finished = [False for _ in range(args.micro_batch_size)]
|
||
|
for i, generated in enumerate(token_stream):
|
||
|
generated_tokens = generated[0]
|
||
|
for j in range(args.micro_batch_size):
|
||
|
if is_finished[j]:
|
||
|
continue
|
||
|
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
|
||
|
generated_tokens[j]) >= args.out_seq_length:
|
||
|
is_finished[j] = True
|
||
|
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
|
||
|
generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
|
||
|
t1 = time.perf_counter()
|
||
|
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
|
||
|
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
|
||
|
print("================================= Generated code:")
|
||
|
print(generated_code)
|
||
|
t0 = time.perf_counter()
|
||
|
if all(is_finished):
|
||
|
break
|
||
|
|
||
|
print("Generation finished.")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
main()
|