CodeGeeX/tests/test_inference.py

import time
import torch
import argparse
import numpy as np
import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize


def model_provider(args):
    """Build the model."""
    model = CodeGeeXModel(
        args.hidden_size,
        args.num_layers,
        args.num_attention_heads,
        args.padded_vocab_size,
        args.max_position_embeddings,
    )
    return model
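
# Illustrative only (not part of the original script): CodeGeeXModel takes the
# five shape hyperparameters positionally, so, assuming the constructor accepts
# arbitrary small shapes, the wiring can be smoke-tested without allocating the
# full 13B model, e.g.
#   tiny_model = CodeGeeXModel(64, 2, 4, 512, 128)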


def add_code_generation_args(parser):
    group = parser.add_argument_group(title="code generation")
    # The model-shape defaults below correspond to the CodeGeeX-13B configuration.
    group.add_argument(
        "--num-layers",
        type=int,
        default=39,
    )
    group.add_argument(
        "--hidden-size",
        type=int,
        default=5120,
    )
    group.add_argument(
        "--num-attention-heads",
        type=int,
        default=40,
    )
    group.add_argument(
        "--padded-vocab-size",
        type=int,
        default=52224,
    )
    group.add_argument(
        "--max-position-embeddings",
        type=int,
        default=2048,
    )
    group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    group.add_argument(
        "--greedy",
        action="store_true",
        default=False,
        help="Use greedy sampling.",
    )
    group.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="Top-p (nucleus) sampling.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top-k sampling.",
    )
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=2048,
        help="Maximum length of the generated text.",
    )
    group.add_argument(
        "--prompt-file",
        type=str,
        default="./test_prompt.txt",
    )
    group.add_argument(
        "--tokenizer-path",
        type=str,
        default="./tokenizer",
    )
    group.add_argument(
        "--load",
        type=str,
    )
    group.add_argument(
        "--state-dict-path",
        type=str,
    )
    group.add_argument(
        "--micro-batch-size",
        type=int,
        default=1,
    )
    group.add_argument(
        "--quantize",
        action="store_true",
    )
    group.add_argument(
        "--interactive",
        action="store_true",
    )
    return parser
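

# For reference, a minimal sketch of how top-k / top-p (nucleus) filtering is
# conventionally applied to next-token logits. Illustrative only: the actual
# filtering used by this script lives inside codegeex.generate, and
# _sketch_top_k_top_p is a hypothetical helper, not part of the CodeGeeX API.
def _sketch_top_k_top_p(logits, top_k=0, top_p=0.0):
    # logits: 1-D tensor of vocabulary scores for the next token.
    logits = logits.clone()
    if top_k > 0:
        # Keep only the k largest logits; everything else becomes -inf.
        kth_value = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    if top_p > 0.0:
        # Nucleus sampling: keep the smallest prefix of the sorted distribution
        # whose cumulative probability exceeds top_p; mask the rest.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        to_remove = cumulative > top_p
        to_remove[1:] = to_remove[:-1].clone()  # shift so the crossing token is kept
        to_remove[0] = False
        logits[sorted_indices[to_remove]] = float("-inf")
    return logits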


def main():
    parser = argparse.ArgumentParser()
    parser = add_code_generation_args(parser)
    args, _ = parser.parse_known_args()

    print("Loading tokenizer ...")
    tokenizer = CodeGeeXTokenizer(
        tokenizer_path=args.tokenizer_path,
        mode="codegeex-13b")

    print("Loading state dict ...")
    state_dict = torch.load(args.load, map_location="cpu")
    state_dict = state_dict["module"]

    print("Building CodeGeeX model ...")
    model = model_provider(args)
    model.load_state_dict(state_dict)
    model.eval()
    model.half()
    if args.quantize:
        # Optionally quantize the FP16 weights to int8 to reduce GPU memory.
        model = quantize(model, weight_bit_width=8, backend="torch")
    model.cuda()
    torch.cuda.synchronize()

    # Note: this file-based prompt is immediately overwritten by the
    # interactive query loop below.
    with open(args.prompt_file, "r") as f:
        prompt = f.readlines()
        prompt = "".join(prompt)

    out_seq_lengths = [args.out_seq_length]
    for out_seq_length in out_seq_lengths:
        print(f"Generating with out_seq_len {out_seq_length}...")
        while True:
            print("\nPlease input query (end multi-line input with Ctrl-D, type 'stop' to exit) >>> ")
            prompts = []
            while True:
                try:
                    line = input()
                except EOFError:
                    break
                prompts.append(line)
            prompt = "\n".join(prompts)
            prompt = prompt.strip()
            if not prompt:
                print("Query should not be empty!")
                continue
            if prompt == "stop":
                return
            try:
                t0 = time.perf_counter()
                generated_code = codegeex.generate(
                    model,
                    tokenizer,
                    prompt,
                    out_seq_length=out_seq_length,
                    seq_length=args.max_position_embeddings,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    temperature=args.temperature,
                    micro_batch_size=args.micro_batch_size,
                    backend="megatron",
                    verbose=True,
                )
                t1 = time.perf_counter()
                print("Total generation time:", t1 - t0)
            except (ValueError, FileNotFoundError) as e:
                print(e)
                continue

    print("Generation finished.")


if __name__ == "__main__":
    main()
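
# Example invocation (paths are placeholders; point --load at a merged
# CodeGeeX-13B checkpoint and --tokenizer-path at the matching tokenizer):
#
#   python tests/test_inference.py \
#       --load ./codegeex_13b.pt \
#       --tokenizer-path ./tokenizer \
#       --quantize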