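# Distributed code-generation worker for CodeGeeX benchmark evaluation.
# Each generation rank pulls tasks from a central scheduler over ZeroMQ
# (REQ/REP), samples or beam-searches completions from the model, and
# appends finished / unfinished samples to per-rank JSONL files.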
import copy
import json
import random
import traceback
from typing import *

import numpy
import torch
import zmq

from codegeex.benchmark.utils import is_code_generation_finished, cleanup_code
from codegeex.megatron import get_args, get_tokenizer
from codegeex.megatron import mpu
from codegeex.megatron.code_generation_utils import get_token_stream
from codegeex.megatron.model import CodeGeeXModel


def model_provider():
    """Build the model."""
    model = CodeGeeXModel(num_tokentypes=0,
                          parallel_output=False)
    return model


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    # Keep the model-parallel CUDA RNG state in sync across ranks.
    mpu.model_parallel_cuda_manual_seed(seed)
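

# Each record written by run_generation_distributed carries a "finish" code:
#   2 -- generation ended on the tokenizer's end-of-document (eod) token,
#   1 -- generation was stopped by the dataset-specific completion heuristic,
#   0 -- the sample was still unfinished when the token stream ended.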
def run_generation_distributed(model):
    args = get_args()
    if hasattr(args, "language_tgt_type"):
        language_type = args.language_tgt_type
    else:
        language_type = args.language_type
    print(f"Connecting to tcp://{args.channel_ip}:{args.channel_port}")
    context = zmq.Context()
    # REQ socket: this rank strictly alternates send (pull / report) and
    # recv (reply) with the central scheduler.
    socket = context.socket(zmq.REQ)
    socket.connect(f"tcp://{args.channel_ip}:{args.channel_port}")
    output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"
    unfinished_output_file_path = args.output_prefix + f"_unfinished_rank{args.gen_rank}.jsonl"
    # NOTE: starts empty; the CodeContest branch below indexes into it, so it
    # must be populated before such tasks arrive.
    problems = {}
    print("Building tokenizer...")
    tokenizer = get_tokenizer()

    with open(output_file_path, "w") as f:
        with open(unfinished_output_file_path, "w") as unfinished_f:
            while True:
                socket.send_json({"rank": args.gen_rank, "action": "pull"})
                resp = socket.recv_json()
                try:
                    # A None sentinel from the scheduler means the task queue
                    # is drained: exit the pull loop.
                    if "codecontest" in args.dataset.lower():
                        if resp["contest_name"] is None:
                            break
                    elif resp["task_id"] is None:
                        break

                    if "codecontest" in args.dataset.lower():
                        current_spec = problems[resp["contest_name"]]
                        prompt = current_spec.prompt
                    else:
                        # For other datasets the reply carries the full task
                        # dict under "task_id".
                        current_spec = resp["task_id"]
                        prompt = current_spec["prompt"]

                    # Per-task sampling overrides, if the scheduler sent any.
                    temperature = None if "temperature" not in resp else resp["temperature"]
                    topp = None if "topp" not in resp else resp["topp"]

                    f.flush()
                    unfinished_f.flush()
                    tokens = tokenizer.tokenize(prompt)
                    n_token_prompt = len(tokens)
                    if n_token_prompt >= args.seq_length:
                        # The prompt alone already fills the context window.
                        continue
                    if "micro_batch_size" in resp:
                        micro_batch_size = resp["micro_batch_size"]
                    else:
                        micro_batch_size = args.micro_batch_size
                    if args.beam_search:
                        beams = get_token_stream(
                            model,
                            [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
                            return_scores=args.return_scores,
                            prompt_length=n_token_prompt,
                            micro_batch_size=micro_batch_size,
                            bad_ids=args.bad_ids,
                            temperature=temperature,
                            topp=topp,
                            beam_warmup=args.beam_warmup,
                        )
                        for beam in beams:
                            # Drop a trailing end-of-document token, if any.
                            generated_tokens_ = beam.tokens
                            if generated_tokens_[-1] == tokenizer.eod:
                                generated_tokens_ = generated_tokens_[:-1]
                            generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
                            generated_code = cleanup_code(generated_code,
                                                          language_type=language_type,
                                                          dataset=args.dataset)
                            f.write(
                                json.dumps(
                                    {
                                        "task_id": current_spec["task_id"],
                                        "prompt": prompt,
                                        "generation": generated_code,
                                        "scores": beam.score,
                                        "finish": 2 if beam.tokens[-1] == tokenizer.eod else 1,
                                        "output": beam.tokens,
                                    }
                                )
                                + "\n"
                            )
                        socket.send_json(
                            {
                                "rank": args.gen_rank,
                                "action": "success",
                                "task_id": current_spec["task_id"],
                            }
                        )
                        socket.recv()
                        continue

                    token_stream = get_token_stream(
                        model,
                        [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
                        return_scores=args.return_scores,
                        prompt_length=n_token_prompt,
                        micro_batch_size=micro_batch_size,
                        bad_ids=args.bad_ids,
                        temperature=temperature,
                        topp=topp,
                        beam_warmup=args.beam_warmup,
                    )
                    is_finished = [False for _ in range(micro_batch_size)]
                    for generated in token_stream:
                        generated_tokens = generated[0]
                        if args.return_scores:
                            scores = generated[1][1]
                        else:
                            scores = None

                        for i in range(micro_batch_size):
                            if is_finished[i]:
                                continue

                            generated_tokens_ = generated_tokens[i].cpu().numpy().tolist()
                            # Drop a trailing end-of-document token, if any.
                            if generated_tokens_[-1] == tokenizer.eod:
                                generated_tokens_ = generated_tokens_[:-1]
                            generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
                            if generated_tokens[i].cpu().numpy()[-1] == tokenizer.eod or \
                                    is_code_generation_finished(
                                        generated_code,
                                        language_type=language_type,
                                        dataset=args.dataset,
                                    ):
                                is_finished[i] = True
                                generated_code = cleanup_code(generated_code,
                                                              language_type=language_type,
                                                              dataset=args.dataset)
                                f.write(
                                    json.dumps(
                                        {
                                            "task_id": current_spec["task_id"],
                                            "prompt": prompt,
                                            "generation": generated_code,
                                            "scores": 0.0 if scores is None else scores[i].detach().cpu().item(),
                                            "finish": 2 if generated_tokens[i].cpu().numpy()[-1] == tokenizer.eod else 1,
                                            "output": generated_tokens[i].cpu().numpy().tolist(),
                                        }
                                    )
                                    + "\n"
                                )

                            if len(generated_tokens[i]) >= args.out_seq_length:
                                # Per-sample token budget exhausted.
                                break

                        if all(is_finished):
                            break

                    # Anything still unfinished when the stream ends goes to
                    # the per-rank "unfinished" file with finish code 0.
                    for i in range(micro_batch_size):
                        if not is_finished[i]:
                            generated_tokens_ = generated_tokens[i].cpu().numpy().tolist()
                            generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
                            unfinished_f.write(
                                json.dumps(
                                    {
                                        "task_id": current_spec["task_id"],
                                        "prompt": prompt,
                                        "generation": generated_code,
                                        "scores": 0.0 if scores is None else scores[i].detach().cpu().item(),
                                        "finish": 0,
                                        "output": generated_tokens_,
                                    }
                                )
                                + "\n"
                            )

                    socket.send_json(
                        {
                            "rank": args.gen_rank,
                            "action": "success",
                            "task_id": current_spec["task_id"],
                        }
                    )
                    socket.recv()

                except Exception as e:
                    print(f"*** (rank={args.gen_rank}) crashed.")
                    print(f"    error: {repr(e)}")
                    traceback.print_exc()
                    # Report the failure so the scheduler can requeue the task.
                    if args.dataset.lower() == "codecontest":
                        socket.send_json(
                            {
                                "rank": args.gen_rank,
                                "action": "fail",
                                "contest_name": current_spec.name,
                                "micro_batch_size": micro_batch_size,
                            }
                        )
                    else:
                        socket.send_json(
                            {
                                "rank": args.gen_rank,
                                "action": "fail",
                                "task_id": current_spec["task_id"],
                            }
                        )
                    socket.recv()
                    continue
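

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: a minimal scheduler
# speaking the worker protocol above, assuming a non-CodeContest dataset.
# Workers send {"rank", "action": "pull"} and expect a reply whose "task_id"
# field is either a task dict (with at least a "prompt" key) or None once the
# queue is drained; "success"/"fail" reports only need an acknowledgement
# frame, because the worker just calls socket.recv() on them. The function
# name and default port below are hypothetical.
# ---------------------------------------------------------------------------
def _example_scheduler(tasks: List[dict], port: int = 5555):
    """Serve tasks to pulling workers (sketch; runs until interrupted)."""
    context = zmq.Context()
    socket = context.socket(zmq.REP)  # pairs with the workers' REQ sockets
    socket.bind(f"tcp://*:{port}")
    pending = list(tasks)
    while True:
        req = socket.recv_json()
        if req["action"] == "pull":
            # Hand out the next task, or the None sentinel once drained.
            socket.send_json({"task_id": pending.pop(0) if pending else None})
        elif req["action"] == "fail":
            # Requeue the failed task so another rank can retry it.
            pending.append(req["task_id"])
            socket.send_json({})
        else:
            # "success": nothing to do beyond acknowledging the report.
            socket.send_json({})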