refactor test inference

pull/69/head
Stanislas0 2 years ago
parent 7a7a59c16c
commit 843a946f41

@@ -1,13 +1,9 @@
import os
import copy
import time
import torch
import random
import argparse
import numpy as np
from codegeex.torch.inference import get_token_stream
import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
@@ -111,6 +107,10 @@ def add_code_generation_args(parser):
"--quantize",
action="store_true",
)
group.add_argument(
"--interative",
action="store_true",
)
return parser
@@ -137,63 +137,51 @@ def main():
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="torch")
model.cuda()
torch.cuda.synchronize()
with open(args.prompt_file, "r") as f:
prompt = f.readlines()
prompt = "".join(prompt)
times = {}
out_seq_lengths = [args.out_seq_length]
micro_batch_size = args.micro_batch_size
seq_length = args.max_position_embeddings
for out_seq_length in out_seq_lengths:
print(f"Generating with out_seq_len {out_seq_length}...")
times[out_seq_length] = []
for prompt in [prompt]:
while True:
print("\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> ")
prompts = []
while True:
try:
line = input()
except EOFError:
break
prompts.append(line)
prompt = "\n".join(prompts)
prompt = prompt.strip()
if not prompt:
print('Query should not be empty!')
continue
if prompt == "stop":
return
try:
t0 = time.perf_counter()
tokens = tokenizer.encode_code(prompt)
print(tokens)
print("Current prompt:")
print(prompt)
n_token_prompt = len(tokens)
print("N_token_prompt:", n_token_prompt)
token_stream = get_token_stream(
generated_code = codegeex.generate(
model,
tokenizer,
seq_length,
out_seq_length,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
topk=args.top_k,
topp=args.top_p,
prompt,
out_seq_length=out_seq_length,
seq_length=args.max_position_embeddings,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
greedy=args.greedy,
micro_batch_size=args.micro_batch_size,
backend="megatron",
verbose=True,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
generated_code = "".join(generated_code)
t1 = time.perf_counter()
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
times[out_seq_length].append(t1 - t0)
print("================================= Generated code:")
print(generated_code)
if all(is_finished):
break
print(times)
for out_seq_length in times.keys():
print(out_seq_length, np.mean(times[out_seq_length]))
print("Total generation time:", t1 - t0)
except (ValueError, FileNotFoundError) as e:
print(e)
continue
print("Generation finished.")
