# mirror of https://github.com/THUDM/CodeGeeX.git
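# Interactive inference demo for CodeGeeX-13B: it loads the tokenizer and an
# fp16 checkpoint into a CodeGeeXModel, optionally quantizes the weights to
# 8-bit, and then generates code for prompts entered on stdin.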
import time
import argparse

import numpy as np
import torch

import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize


def model_provider(args):
    """Build the model."""
    model = CodeGeeXModel(
        args.hidden_size,
        args.num_layers,
        args.num_attention_heads,
        args.padded_vocab_size,
        args.max_position_embeddings,
    )

    return model

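# For reference, the CLI defaults below describe the 13B checkpoint this demo
# targets (the tokenizer is loaded in "codegeex-13b" mode in main()), so the
# model can also be built directly. A minimal sketch, using the same positional
# arguments as model_provider above:
#
#   model = CodeGeeXModel(5120, 39, 40, 52224, 2048)
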
def add_code_generation_args(parser):
    group = parser.add_argument_group(title="code generation")
    group.add_argument(
        "--num-layers",
        type=int,
        default=39,
    )
    group.add_argument(
        "--hidden-size",
        type=int,
        default=5120,
    )
    group.add_argument(
        "--num-attention-heads",
        type=int,
        default=40,
    )
    group.add_argument(
        "--padded-vocab-size",
        type=int,
        default=52224,
    )
    group.add_argument(
        "--max-position-embeddings",
        type=int,
        default=2048,
    )
    group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    group.add_argument(
        "--greedy",
        action="store_true",
        default=False,
        help="Use greedy sampling.",
    )
    group.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="Top-p sampling.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top-k sampling.",
    )
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=2048,
        help="Size of the output generated text.",
    )
    group.add_argument(
        "--prompt-file",
        type=str,
        default="./test_prompt.txt",
    )
    group.add_argument(
        "--tokenizer-path",
        type=str,
        default="./tokenizer",
    )
    group.add_argument(
        "--load",
        type=str,
    )
    group.add_argument(
        "--state-dict-path",
        type=str,
    )
    group.add_argument(
        "--micro-batch-size",
        type=int,
        default=1,
    )
    group.add_argument(
        "--quantize",
        action="store_true",
    )
    group.add_argument(
        "--interactive",
        action="store_true",
    )

    return parser

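# Example invocation (the script name and checkpoint path here are illustrative,
# not taken from the repository):
#
#   python test_inference.py \
#       --tokenizer-path ./tokenizer \
#       --load ./codegeex_13b.pt \
#       --temperature 0.8 --top-p 0.95 --out-seq-length 1024
#
# Note that --greedy, --state-dict-path, and --interactive are parsed but never
# read anywhere in this script.
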
def main():
    parser = argparse.ArgumentParser()
    parser = add_code_generation_args(parser)
    args, _ = parser.parse_known_args()

    print("Loading tokenizer ...")
    tokenizer = CodeGeeXTokenizer(
        tokenizer_path=args.tokenizer_path,
        mode="codegeex-13b",
    )

    print("Loading state dict ...")
    state_dict = torch.load(args.load, map_location="cpu")
    state_dict = state_dict["module"]

    print("Building CodeGeeX model ...")
    model = model_provider(args)
    model.load_state_dict(state_dict)
    model.eval()
    # Run inference in fp16; optionally quantize the weights to 8-bit before
    # moving the model to the GPU.
    model.half()
    if args.quantize:
        model = quantize(model, weight_bit_width=8, backend="torch")
    model.cuda()
    torch.cuda.synchronize()

    with open(args.prompt_file, "r") as f:
        prompt = f.readlines()
        prompt = "".join(prompt)

    out_seq_lengths = [args.out_seq_length]
    for out_seq_length in out_seq_lengths:
        print(f"Generating with out_seq_len {out_seq_length}...")
        while True:
            print("\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> ")
            # The prompt loaded from --prompt-file above is replaced by whatever
            # is typed here; input lines are collected until EOF (Ctrl-D).
            prompts = []
            while True:
                try:
                    line = input()
                except EOFError:
                    break
                prompts.append(line)
            prompt = "\n".join(prompts)
            prompt = prompt.strip()
            if not prompt:
                print("Query should not be empty!")
                continue
            if prompt == "stop":
                return

            try:
                t0 = time.perf_counter()
                generated_code = codegeex.generate(
                    model,
                    tokenizer,
                    prompt,
                    out_seq_length=out_seq_length,
                    seq_length=args.max_position_embeddings,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    temperature=args.temperature,
                    micro_batch_size=args.micro_batch_size,
                    backend="megatron",
                    verbose=True,
                )
                t1 = time.perf_counter()
                print("Total generation time:", t1 - t0)
            except (ValueError, FileNotFoundError) as e:
                print(e)
                continue

    print("Generation finished.")


if __name__ == "__main__":
    main()