refactor test inference

pull/69/head
Stanislas0 2 years ago
parent 7a7a59c16c
commit 843a946f41

@@ -1,13 +1,9 @@
-import os
-import copy
 import time
 import torch
-import random
 import argparse
 import numpy as np
-from codegeex.torch.inference import get_token_stream
+import codegeex
 from codegeex.torch import CodeGeeXModel
 from codegeex.tokenizer import CodeGeeXTokenizer
 from codegeex.quantization import quantize
@@ -111,6 +107,10 @@ def add_code_generation_args(parser):
         "--quantize",
         action="store_true",
     )
+    group.add_argument(
+        "--interative",
+        action="store_true",
+    )
     return parser
@@ -137,63 +137,51 @@ def main():
     if args.quantize:
         model = quantize(model, weight_bit_width=8, backend="torch")
     model.cuda()
+    torch.cuda.synchronize()
     with open(args.prompt_file, "r") as f:
         prompt = f.readlines()
         prompt = "".join(prompt)
-    times = {}
     out_seq_lengths = [args.out_seq_length]
-    micro_batch_size = args.micro_batch_size
-    seq_length = args.max_position_embeddings
     for out_seq_length in out_seq_lengths:
         print(f"Generating with out_seq_len {out_seq_length}...")
-        times[out_seq_length] = []
-        for prompt in [prompt]:
-            t0 = time.perf_counter()
-            tokens = tokenizer.encode_code(prompt)
-            print(tokens)
-            print("Current prompt:")
-            print(prompt)
-            n_token_prompt = len(tokens)
-            print("N_token_prompt:", n_token_prompt)
-            token_stream = get_token_stream(
-                model,
-                tokenizer,
-                seq_length,
-                out_seq_length,
-                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
-                micro_batch_size=micro_batch_size,
-                topk=args.top_k,
-                topp=args.top_p,
-                temperature=args.temperature,
-                greedy=args.greedy,
-            )
-            is_finished = [False for _ in range(micro_batch_size)]
-            for i, generated in enumerate(token_stream):
-                generated_tokens = generated[0]
-                for j in range(micro_batch_size):
-                    if is_finished[j]:
-                        continue
-                    if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
-                        is_finished[j] = True
-                        generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
-                        generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
-                        generated_code = "".join(generated_code)
-                        t1 = time.perf_counter()
-                        print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
-                        print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
-                        times[out_seq_length].append(t1 - t0)
-                        print("================================= Generated code:")
-                        print(generated_code)
-                if all(is_finished):
-                    break
-    print(times)
-    for out_seq_length in times.keys():
-        print(out_seq_length, np.mean(times[out_seq_length]))
+        while True:
+            print("\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> ")
+            prompts = []
+            while True:
+                try:
+                    line = input()
+                except EOFError:
+                    break
+                prompts.append(line)
+            prompt = "\n".join(prompts)
+            prompt = prompt.strip()
+            if not prompt:
+                print('Query should not be empty!')
+                continue
+            if prompt == "stop":
+                return
+            try:
+                t0 = time.perf_counter()
+                generated_code = codegeex.generate(
+                    model,
+                    tokenizer,
+                    prompt,
+                    out_seq_length=out_seq_length,
+                    seq_length=args.max_position_embeddings,
+                    top_k=args.top_k,
+                    top_p=args.top_p,
+                    temperature=args.temperature,
+                    micro_batch_size=args.micro_batch_size,
+                    backend="megatron",
+                    verbose=True,
+                )
+                t1 = time.perf_counter()
+                print("Total generation time:", t1 - t0)
+            except (ValueError, FileNotFoundError) as e:
+                print(e)
+                continue
     print("Generation finished.")
