CodeGeeX/tests/test_inference_megatron.py


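"""Smoke test for CodeGeeX batched inference under Megatron: load a
checkpoint, repeatedly generate from a prompt file, and report wall-clock
time and per-token latency for each generation."""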
import copy
import time
import torch
import numpy as np
from codegeex.megatron import get_tokenizer, get_args, print_rank_0
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.code_generation_utils import get_token_stream
from codegeex.quantization import quantize
from codegeex.megatron.training import get_model
from codegeex.megatron.checkpointing import load_checkpoint
# Print tensors with extra precision to ease numerical debugging.
torch.set_printoptions(precision=8)

def set_random_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    print_rank_0("Building CodeGeeX model ...")
    model = CodeGeeXModel(num_tokentypes=0, parallel_output=False)
    return model

def add_code_generation_args(parser):
    """Code generation arguments."""
    group = parser.add_argument_group(title="code generation")
    group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    group.add_argument(
        "--greedy",
        action="store_true",
        default=False,
        help="Use greedy sampling.",
    )
    group.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="Top-p (nucleus) sampling.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top-k sampling.",
    )
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=2048,
        help="Maximum total sequence length (prompt plus generated tokens).",
    )
    group.add_argument(
        "--recompute",
        action="store_true",
        help="During generation recompute all attention "
        "instead of using previously computed keys/values.",
    )
    group.add_argument(
        "--ws-encoding-start-id",
        type=int,
        default=10,
        help="Start id for whitespace encoding.",
    )
    group.add_argument(
        "--ws-encoding-length",
        type=int,
        default=10,
        help="Length of whitespace encoding.",
    )
    group.add_argument(
        "--n-generation",
        type=int,
        default=10,
        help="Number of generations to run for the prompt.",
    )
    group.add_argument(
        "--eos-id",
        type=int,
        default=50256,
        help="Token id treated as end of sequence.",
    )
    group.add_argument(
        "--prompt-file",
        type=str,
        default="./test_prompt.txt",
        help="Path to the prompt text file.",
    )
    group.add_argument(
        "--perf-file",
        type=str,
        default="./perf_out.txt",
    )
    group.add_argument(
        "--perf-trace",
        type=str,
        default="./perf_out.txt",
    )
    group.add_argument(
        "--use-torch-profile",
        action="store_true",
    )
    group.add_argument(
        "--ln-fp16",
        action="store_true",
        help="Cast the full model, including LayerNorm, to fp16.",
    )
    group.add_argument(
        "--bad-ids",
        nargs="*",
        type=int,
        default=None,
        help="Ids of tokens that should not be generated.",
    )
    group.add_argument(
        "--quantize",
        action="store_true",
        help="Quantize model weights to int8.",
    )
    return parser

def main():
    initialize_megatron(
        extra_args_provider=add_code_generation_args,
        args_defaults={
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )
    args = get_args()
    set_random_seed(args.seed)

    print_rank_0("Loading tokenizer ...")
    tokenizer = get_tokenizer()

    print_rank_0("Loading state dict ...")
    model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)
    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]
    model.eval()
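
    # Optional precision handling: cast everything (including LayerNorm) to
    # fp16, and/or quantize the weights to int8 before generation.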
    if args.fp16 and args.ln_fp16:
        model.half()
    if args.quantize:
        model = quantize(model, weight_bit_width=8, backend="megatron")

    with open(args.prompt_file, "r") as f:
        prompt = f.read()

    times = {}
    out_seq_lengths = [args.out_seq_length]
    micro_batch_size = args.micro_batch_size
    for out_seq_length in out_seq_lengths:
        print_rank_0(f"Generating with out_seq_len {out_seq_length}...")
        times[out_seq_length] = []
        for _ in range(args.n_generation):
            t0 = time.perf_counter()
            tokens = tokenizer.tokenize(prompt)
            print_rank_0(tokens)
            print_rank_0("Current prompt:")
            print_rank_0(prompt)
            n_token_prompt = len(tokens)
            print_rank_0(f"N_token_prompt:{n_token_prompt}")
            token_stream = get_token_stream(
                model,
                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
                micro_batch_size=micro_batch_size,
                topk=args.top_k,
                topp=args.top_p,
                temperature=args.temperature,
            )
            is_finished = [False for _ in range(micro_batch_size)]
            for generated in token_stream:
                generated_tokens = generated[0]
                for j in range(micro_batch_size):
                    if is_finished[j]:
                        continue
                    # A sequence is finished once it emits the end-of-document
                    # token or hits the output length budget.
                    if (
                        generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod
                        or len(generated_tokens[j]) >= out_seq_length
                    ):
                        is_finished[j] = True
                        generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
                        generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
                        t1 = time.perf_counter()
                        print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}")
                        print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
                        times[out_seq_length].append(t1 - t0)
                        print_rank_0("================================= Generated code:")
                        print_rank_0(generated_code)
                        # Restart the timer so the next sequence to finish is
                        # timed from this completion.
                        t0 = time.perf_counter()
                if all(is_finished):
                    break

    print_rank_0(times)
    for out_seq_length in times.keys():
        print_rank_0(f"{out_seq_length}, {np.mean(times[out_seq_length])}")
    print_rank_0("Generation finished.")

if __name__ == "__main__":
    main()