import os
import zmq
import time
import torch
import random
import socket
import logging
import argparse
from typing import *
from codegeex.benchmark.utils import read_translation_dataset
from codegeex.megatron import get_args
from codegeex.megatron.inference import run_generation_distributed, model_provider
from codegeex.megatron.initialize import initialize_megatron
logging.getLogger("torch").setLevel(logging.WARNING)
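
# High-level flow (as read from this file): the __main__ launcher spawns one
# generation worker per local GPU via torch.multiprocessing, and the first
# host in the hostfile additionally runs a ZeroMQ task server that hands out
# translation entries to workers and tracks their progress.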

def add_code_generate_args(parser):
    """Code generation arguments."""
    group = parser.add_argument_group(title="code generation")
    group.add_argument(
        "--hostfile",
        type=str,
        default="./hostfile",
    )
    group.add_argument(
        "--channel-ip",
        type=str,
        default=None,
        help="IP for ZeroMQ channel",
    )
    group.add_argument(
        "--channel-port",
        type=int,
        default=5555,
        help="Port for ZeroMQ channel",
    )
    group.add_argument(
        "--master-port",
        type=int,
        default=5666,
    )
    group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature.",
    )
    group.add_argument(
        "--greedy",
        action="store_true",
        default=False,
        help="Use greedy sampling.",
    )
    group.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="Top-p (nucleus) sampling.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top-k sampling.",
    )
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=1024,
        help="Maximum length of the generated text.",
    )
    group.add_argument(
        "--input-path",
        type=str,
        default="./benchmark/humaneval/HumanEval.jsonl",
        help="Path to the input dataset.",
    )
    group.add_argument(
        "--num-samples",
        type=int,
        default=0,
        help="Number of samples to generate.",
    )
    group.add_argument(
        "--recompute",
        action="store_true",
        help="During generation recompute all attention "
        "instead of using previously computed keys/values.",
    )
    group.add_argument(
        "--load-deepspeed",
        action="store_true",
        help="Load DeepSpeed checkpoint.",
    )
    group.add_argument(
        "--ws-encoding-start-id",
        type=int,
        default=None,
        help="Start id for whitespace encoding.",
    )
    group.add_argument(
        "--ws-encoding-length",
        type=int,
        default=None,
        help="Length of whitespace encoding.",
    )
    group.add_argument(
        "--dataset",
        type=str,
        default="humaneval",
    )
    group.add_argument(
        "--samples-per-problem",
        type=int,
        default=200,
        help="Number of samples to generate for each problem.",
    )
    group.add_argument(
        "--output-prefix",
        type=str,
        default="./output/humaneval",
        help="Prefix for output files.",
    )
    group.add_argument(
        "--gen-node-rank",
        type=int,
        default=None,
    )
    group.add_argument(
        "--gen-node-world-size",
        type=int,
        default=None,
        help="Number of machines to use for generation.",
    )
    group.add_argument(
        "--gen-world-size",
        type=int,
        default=1,
        help="Total number of generation workers (one per GPU across all nodes).",
    )
    group.add_argument(
        "--gen-rank",
        type=int,
        default=0,
        help="Rank of this worker in the generation world.",
    )
    group.add_argument(
        "--extra-prompt",
        type=str,
        default=None,
        help="Extra prompt to use for generation.",
    )
    group.add_argument(
        "--verbose-interval",
        type=int,
        default=100,
    )
    group.add_argument(
        "--problem-split",
        type=str,
        default="test",
    )
    group.add_argument(
        "--prompt-type",
        type=str,
        default="notag",
    )
    group.add_argument(
        "--num-devices-per-node",
        type=int,
        default=None,
    )
    group.add_argument(
        "--return-scores",
        action="store_true",
    )
    group.add_argument(
        "--free-guidance",
        action="store_true",
    )
    group.add_argument(
        "--guide-temp",
        type=float,
        default=1.5,
    )
    group.add_argument(
        "--attention-upweight",
        type=float,
        default=None,
    )
    group.add_argument(
        "--bad-ids",
        nargs="*",
        type=int,
        default=None,
        help="Token ids that must not be generated.",
    )
    group.add_argument(
        "--src-path",
        type=str,
        default="",
        help="Path to the source-language dataset.",
    )
    group.add_argument(
        "--tgt-path",
        type=str,
        default="",
        help="Path to the target-language dataset.",
    )
    group.add_argument(
        "--language-src-type",
        type=str,
        default=None,
        help="Programming language of the source.",
    )
    group.add_argument(
        "--language-tgt-type",
        type=str,
        default=None,
        help="Programming language to translate into.",
    )
    return parser
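
# Example invocation (illustrative only; the paths and language names below
# are placeholders, and the Megatron model/checkpoint flags consumed by
# initialize_megatron, such as --load, are not listed exhaustively):
#
#   python <this_script>.py \
#       --hostfile ./hostfile \
#       --src-path ./data/humaneval_python.jsonl \
#       --tgt-path ./data/humaneval_cpp.jsonl \
#       --language-src-type python \
#       --language-tgt-type cpp \
#       --samples-per-problem 200 \
#       --micro-batch-size 1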

def main(node_rank: int, local_rank: int, master_port: int, num_devices: int):
    """Run generation in one process per local GPU."""
    # Each node forms its own torch.distributed group: one rank per local GPU.
    os.environ["WORLD_SIZE"] = str(num_devices)
    os.environ["RANK"] = str(local_rank)
    os.environ["MASTER_ADDR"] = "0.0.0.0"
    os.environ["MASTER_PORT"] = f"{master_port}"
    initialize_megatron(
        extra_args_provider=add_code_generate_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )
    # set_random_seed(node_rank * num_devices + local_rank)
    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # The generation rank is global across all generation nodes.
    world_size = args.gen_node_world_size * num_devices
    args.gen_rank = num_devices * node_rank + local_rank
    args.gen_world_size = world_size
    print(f"Generating on rank {args.gen_rank} of {args.gen_world_size}")

    # Set up the model and load the checkpoint.
    state_dict = torch.load(args.load, map_location="cpu")
    state_dict = state_dict["module"]

    print("Building CodeGeeX model ...")
    model = model_provider()
    model.load_state_dict(state_dict)
    model.eval()
    if args.fp16 and args.ln_fp16:
        model.half()
    model.cuda()

    # Generate samples, pulling tasks from the ZeroMQ server.
    run_generation_distributed(model)
    print(f"(gen_rank={args.gen_rank}, rank={local_rank}) finished, waiting ...")
    torch.distributed.barrier()
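
# The matching worker side lives in run_generation_distributed (not shown
# here). Judging from the server loop below, each worker sends JSON messages
# of the form {"rank": ..., "action": ...} over a ZeroMQ REQ socket, where
# action is "pull", "success", or anything else to signal failure (plus a
# "task_id" field when reporting results), and it stops once a pull returns
# {"task_id": None}.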

def server():
    """Task server: hands out translation entries to workers over ZeroMQ."""
    print("[ server ] starting ...", flush=True)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--channel-ip",
        type=str,
        default=None,
        help="IP for ZeroMQ channel",
    )
    parser.add_argument(
        "--channel-port",
        type=int,
        default=5555,
        help="Port for ZeroMQ channel",
    )
    parser.add_argument(
        "--master-port",
        type=int,
        default=6666,
        help="Port for torch distributed",
    )
    parser.add_argument(
        "--samples-per-problem",
        type=int,
        default=200,
        help="Number of samples to generate for each problem",
    )
    parser.add_argument(
        "--gen-node-world-size",
        type=int,
        default=1,
        help="Number of machines to use for generation",
    )
    parser.add_argument(
        "--src-path",
        type=str,
        default="",
        help="Path to the source-language dataset",
    )
    parser.add_argument(
        "--tgt-path",
        type=str,
        default="",
        help="Path to the target-language dataset",
    )
    parser.add_argument(
        "--problem-split",
        type=str,
        default="test",
    )
    parser.add_argument(
        "--micro-batch-size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--language-src-type",
        type=str,
        default=None,
        help="Programming language of the source",
    )
    parser.add_argument(
        "--language-tgt-type",
        type=str,
        default=None,
        help="Programming language to translate into",
    )
    args = parser.parse_known_args()[0]

    entries = read_translation_dataset(
        args.src_path,
        args.tgt_path,
        lang_src=args.language_src_type,
        lang_tgt=args.language_tgt_type,
        dataset_type="humaneval",
    )

    assert (
        args.samples_per_problem % args.micro_batch_size == 0
    ), "samples_per_problem should be divisible by micro_batch_size"

    # Each queued task produces one micro batch, so every entry is queued
    # samples_per_problem / micro_batch_size times, in shuffled order.
    res = []
    for entry in entries.values():
        res.extend([entry] * (args.samples_per_problem // args.micro_batch_size))
    random.shuffle(res)
    all_entries = res

    # Set up the ZeroMQ channel (REP socket; workers connect with REQ).
    print(f"[ server ] starting up on port {args.channel_port}", flush=True)
    context = zmq.Context()
    print("[ server ] creating socket", flush=True)
    zmq_socket = context.socket(zmq.REP)
    print(f"[ server ] binding to port {args.channel_port}", flush=True)
    zmq_socket.bind(f"tcp://*:{args.channel_port}")

    print(
        f"[ server ] loaded {len(entries)} entries, generated {len(all_entries)} samples",
        flush=True,
    )
    remaining_entries = all_entries.copy()
    running_workers = args.gen_node_world_size * torch.cuda.device_count()
    num_finished = 0

    print("[ server ] listening for requests ...", flush=True)
    start_time = time.perf_counter()
    while True:
        # Wait for the next request from a worker.
        msg = zmq_socket.recv_json()
        rank = msg["rank"]
        action = msg["action"]
        if action == "pull":
            if len(remaining_entries) == 0:
                # No work left: tell this worker to shut down.
                print(f"[ server ] Shutting down worker {rank}", flush=True)
                zmq_socket.send_json({"task_id": None})
                running_workers -= 1
                if running_workers == 0:
                    print("[ server ] All workers finished", flush=True)
                    break
            else:
                entry = remaining_entries.pop()
                time_elapsed = time.perf_counter() - start_time
                print(f"[ server ] Sending entry {entry['task_id']} to worker {rank}", flush=True)
                # Estimate remaining time by scaling elapsed time with the
                # ratio of unassigned to assigned tasks.
                remaining = (
                    len(remaining_entries)
                    / (len(all_entries) - len(remaining_entries))
                    * time_elapsed
                )
                time_per_sample = (
                    0.0 if num_finished == 0 else time_elapsed / num_finished / args.micro_batch_size
                )
                print(
                    f"[ server ] total {len(all_entries)}, assigned {len(all_entries) - len(remaining_entries)}, "
                    f"finished {num_finished}, "
                    f"elapsed {time_elapsed:.4f}",
                    f"speed {time_per_sample:.4f}s/sample",
                    f"remaining {remaining:.4f}",
                    flush=True,
                )
                zmq_socket.send_json({"task_id": entry})
        else:
            if action == "success":
                print(f"[ server ] {msg['task_id']} is finished", flush=True)
                zmq_socket.send_json({"pong": 1})
                num_finished += 1
            else:
                # The worker failed on this entry; requeue it for retry
                # instead of shutting the server down.
                print(f"[ server ] {msg['task_id']} is not finished", flush=True)
                remaining_entries.append(msg["task_id"])
                zmq_socket.send_json({"pong": 1})
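
# Minimal sketch of a compatible worker loop (hypothetical; the real client
# is implemented in run_generation_distributed, and channel_ip, channel_port,
# and gen_rank below stand for the corresponding arguments):
#
#   context = zmq.Context()
#   sock = context.socket(zmq.REQ)
#   sock.connect(f"tcp://{channel_ip}:{channel_port}")
#   while True:
#       sock.send_json({"rank": gen_rank, "action": "pull"})
#       entry = sock.recv_json()["task_id"]
#       if entry is None:
#           break  # server has no work left
#       try:
#           ...  # generate micro_batch_size samples for this entry
#           sock.send_json({"rank": gen_rank, "action": "success", "task_id": entry})
#       except Exception:
#           sock.send_json({"rank": gen_rank, "action": "fail", "task_id": entry})
#       sock.recv_json()  # consume the {"pong": 1} acknowledgement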

if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hostfile",
        type=str,
        default="./hostfile",
    )
    parser.add_argument(
        "--master-port",
        type=int,
        default=5666,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
    )
    args = parser.parse_known_args()[0]
    print("start method: " + torch.multiprocessing.get_start_method())

    processes = []
    num_devices = torch.cuda.device_count()

    # Determine this node's rank from its position in the hostfile.
    with open(args.hostfile, "r") as f:
        hosts = [host.strip() for host in f.readlines()]
    master_port = args.master_port
    node_rank = None
    local_ip = socket.gethostbyname(socket.gethostname())
    for i, host in enumerate(hosts):
        if host == local_ip:
            node_rank = i
            break
    assert (
        node_rank is not None
    ), f"Could not find hostname ({local_ip}) in hostfile"

    # Launch the task server on the first host only.
    if local_ip == hosts[0]:
        server_process = torch.multiprocessing.Process(target=server)
        print("Launching server ...")
        server_process.start()
        processes.append(server_process)

    # Launch one generation worker per local GPU.
    for i in range(num_devices):
        local_rank = i
        print(f"launching local rank {i}")
        p = torch.multiprocessing.Process(
            target=main,
            args=(node_rank, local_rank, master_port, num_devices),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
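
# Note: every host listed in the hostfile is expected to run this script;
# only the first host also starts the task server. Since main() hard-codes
# MASTER_ADDR to 0.0.0.0 with WORLD_SIZE=num_devices, each node appears to
# form its own local torch.distributed group, with cross-node coordination
# happening through the ZeroMQ channel.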