@@ -1,13 +1,9 @@
 import os
-import copy
 import time
 import torch
-import random
 import argparse
 import numpy as np
-
-from codegeex.torch.inference import get_token_stream
+import codegeex
 from codegeex.torch import CodeGeeXModel
 from codegeex.tokenizer import CodeGeeXTokenizer
 from codegeex.quantization import quantize
-
@@ -111,6 +107,10 @@ def add_code_generation_args(parser):
         "--quantize",
         action="store_true",
     )
+    group.add_argument(
+        "--interative",
+        action="store_true",
+    )
 
     return parser
 
@@ -137,63 +137,51 @@ def main():
     if args.quantize:
         model = quantize(model, weight_bit_width=8, backend="torch")
     model.cuda()
     torch.cuda.synchronize()
 
     with open(args.prompt_file, "r") as f:
         prompt = f.readlines()
         prompt = "".join(prompt)
 
-    times = {}
-    out_seq_lengths = [args.out_seq_length]
-    micro_batch_size = args.micro_batch_size
-    seq_length = args.max_position_embeddings
-    for out_seq_length in out_seq_lengths:
-        print(f"Generating with out_seq_len {out_seq_length}...")
-
-        times[out_seq_length] = []
-        for prompt in [prompt]:
+    out_seq_length = args.out_seq_length
+    while True:
+        print("\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> ")
+        prompts = []
+        while True:
+            try:
+                line = input()
+            except EOFError:
+                break
+            prompts.append(line)
+        prompt = "\n".join(prompts)
+        prompt = prompt.strip()
+        if not prompt:
+            print('Query should not be empty!')
+            continue
+        if prompt == "stop":
+            return
+        try:
             t0 = time.perf_counter()
-            tokens = tokenizer.encode_code(prompt)
-            print(tokens)
-            print("Current prompt:")
-            print(prompt)
-            n_token_prompt = len(tokens)
-            print("N_token_prompt:", n_token_prompt)
-            token_stream = get_token_stream(
+            generated_code = codegeex.generate(
                 model,
                 tokenizer,
-                seq_length,
-                out_seq_length,
-                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
-                micro_batch_size=micro_batch_size,
-                topk=args.top_k,
-                topp=args.top_p,
+                prompt,
+                out_seq_length=out_seq_length,
+                seq_length=args.max_position_embeddings,
+                top_k=args.top_k,
+                top_p=args.top_p,
+                temperature=args.temperature,
+                greedy=args.greedy,
+                micro_batch_size=args.micro_batch_size,
+                backend="megatron",
+                verbose=True,
             )
-            is_finished = [False for _ in range(micro_batch_size)]
-            for i, generated in enumerate(token_stream):
-                generated_tokens = generated[0]
-                for j in range(micro_batch_size):
-                    if is_finished[j]:
-                        continue
-                    if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
-                        is_finished[j] = True
-                        generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
-                        generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
-                        generated_code = "".join(generated_code)
-                        t1 = time.perf_counter()
-                        print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
-                        print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
-                        times[out_seq_length].append(t1 - t0)
-                        print("================================= Generated code:")
-                        print(generated_code)
-
-                if all(is_finished):
-                    break
-
-    print(times)
-    for out_seq_length in times.keys():
-        print(out_seq_length, np.mean(times[out_seq_length]))
+            t1 = time.perf_counter()
+            print("Total generation time:", t1 - t0)
+        except (ValueError, FileNotFoundError) as e:
+            print(e)
+            continue
 
     print("Generation finished.")
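
Note: for readers comparing the two call sites, below is a minimal, hypothetical sketch of driving the codegeex.generate API that this diff migrates to. It assumes `model` and `tokenizer` are a CodeGeeXModel and CodeGeeXTokenizer already loaded the way main() loads them; the keyword arguments are taken verbatim from the added lines above, while the concrete sampling values are illustrative placeholders, not values the diff prescribes.

import time

import codegeex

def run_once(model, tokenizer, prompt, out_seq_length=256):
    # Sketch only: mirrors the keyword arguments introduced in the diff.
    # `model` and `tokenizer` are assumed to be prepared as in main();
    # the sampling values below are assumptions for illustration.
    t0 = time.perf_counter()
    generated_code = codegeex.generate(
        model,
        tokenizer,
        prompt,
        out_seq_length=out_seq_length,
        seq_length=2048,  # assumed stand-in for args.max_position_embeddings
        top_k=0,
        top_p=0.95,
        temperature=0.8,
        greedy=False,
        micro_batch_size=1,
        backend="megatron",
        verbose=True,
    )
    t1 = time.perf_counter()
    print("Total generation time:", t1 - t0)
    return generated_code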