diff --git a/tests/test_inference.py b/tests/test_inference.py
index cfda7b9..c4267ab 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -1,13 +1,9 @@
-
-import os
-import copy
 import time
 import torch
-import random
 import argparse
 import numpy as np
-from codegeex.torch.inference import get_token_stream
+import codegeex
 from codegeex.torch import CodeGeeXModel
 from codegeex.tokenizer import CodeGeeXTokenizer
 from codegeex.quantization import quantize
 
@@ -111,6 +107,10 @@ def add_code_generation_args(parser):
     group.add_argument(
         "--quantize",
         action="store_true",
     )
+    group.add_argument(
+        "--interactive",
+        action="store_true",
+    )
 
     return parser
@@ -137,66 +137,54 @@ def main():
     if args.quantize:
         model = quantize(model, weight_bit_width=8, backend="torch")
     model.cuda()
+    torch.cuda.synchronize()
 
     with open(args.prompt_file, "r") as f:
         prompt = f.readlines()
         prompt = "".join(prompt)
 
-    times = {}
     out_seq_lengths = [args.out_seq_length]
-    micro_batch_size = args.micro_batch_size
-    seq_length = args.max_position_embeddings
     for out_seq_length in out_seq_lengths:
         print(f"Generating with out_seq_len {out_seq_length}...")
-
-        times[out_seq_length] = []
-        for prompt in [prompt]:
-            t0 = time.perf_counter()
-            tokens = tokenizer.encode_code(prompt)
-            print(tokens)
-            print("Current prompt:")
-            print(prompt)
-            n_token_prompt = len(tokens)
-            print("N_token_prompt:", n_token_prompt)
-            token_stream = get_token_stream(
-                model,
-                tokenizer,
-                seq_length,
-                out_seq_length,
-                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
-                micro_batch_size=micro_batch_size,
-                topk=args.top_k,
-                topp=args.top_p,
-                temperature=args.temperature,
-                greedy=args.greedy,
-            )
-            is_finished = [False for _ in range(micro_batch_size)]
-            for i, generated in enumerate(token_stream):
-                generated_tokens = generated[0]
-                for j in range(micro_batch_size):
-                    if is_finished[j]:
-                        continue
-                    if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
-                        is_finished[j] = True
-                        generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
-                        generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
-                        generated_code = "".join(generated_code)
-                        t1 = time.perf_counter()
-                        print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
-                        print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
-                        times[out_seq_length].append(t1 - t0)
-                        print("================================= Generated code:")
-                        print(generated_code)
-
-                if all(is_finished):
-                    break
-
-    print(times)
-    for out_seq_length in times.keys():
-        print(out_seq_length, np.mean(times[out_seq_length]))
-
+        while True:
+            print("\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> ")
+            prompts = []
+            while True:
+                try:
+                    line = input()
+                except EOFError:
+                    break
+                prompts.append(line)
+            prompt = "\n".join(prompts)
+            prompt = prompt.strip()
+            if not prompt:
+                print('Query should not be empty!')
+                continue
+            if prompt == "stop":
+                return
+            try:
+                t0 = time.perf_counter()
+                generated_code = codegeex.generate(
+                    model,
+                    tokenizer,
+                    prompt,
+                    out_seq_length=out_seq_length,
+                    seq_length=args.max_position_embeddings,
+                    top_k=args.top_k,
+                    top_p=args.top_p,
+                    temperature=args.temperature,
+                    micro_batch_size=args.micro_batch_size,
+                    backend="megatron",
+                    verbose=True,
+                )
+                t1 = time.perf_counter()
+                print("Total generation time:", t1 - t0)
+            except (ValueError, FileNotFoundError) as e:
+                print(e)
+                continue
+            print("Generation finished.")
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file