You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
CodeGeeX/codegeex/mindspore/generation_humaneval.py

259 lines
12 KiB
Python

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
PanGu predict run
"""
import json
import os
import time
import mindspore
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
import moxing as mox
import numpy as np
from mindspore import context, Tensor
from mindspore import export
from mindspore.context import ParallelMode
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel.nn.transformer import TransformerOpParallelConfig
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.code_tokenizer import CodeTokenizer
from src.pangu_alpha import EvalNet, PanguAlphaModel
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils import get_args
def load_model(args_opt):
    r"""
    Set up the MindSpore context, build the PanGu-Alpha evaluation network and
    load its checkpoint.

    Args:
        args_opt: Parsed options object. Attributes read here:
            device_target, distribute, stage_num, strategy_load_ckpt_path,
            use_past, export, op_level_model_parallel_num, micro_size,
            word_emb_dp, per_batch_size, seq_length, vocab_size,
            embedding_size, num_layers, num_heads, eod_id, load_ckpt_path,
            load_ckpt_name, param_init_type, load_ckpt_epoch and
            save_checkpoint_path. ``has_trained_epoches`` and
            ``has_trained_steps`` are written back when the checkpoint
            contains ``epoch_num`` and ``step_num``.

    Returns:
        A tuple ``(model_predict, config, rank)``: the ``Model`` wrapping the
        ``EvalNet``, the ``PanguAlphaConfig`` that was built, and the rank of
        this process (0 when not running distributed).
    """
    # Set execution mode
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target)
    context.set_context(variable_memory_max_size="30GB")
    # Set parallel context
    # NOTE: `distribute` is compared as the string "true"/"false", not as a bool.
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("rank_id is {}, device_num is {}".format(rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        # Single-device run: fixed rank/device count, only the strategy file is configured.
        rank = 0
        device_num = 1
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path)
    # save_graphs is set to False here; presumably this keeps the graph-upload
    # branch further down disabled unless this flag is changed.
    context.set_context(
        save_graphs=False,
        save_graphs_path="/cache/graphs_of_device_id_" + str(rank),
    )
    use_past = (args_opt.use_past == "true")
    print('local_rank:{}, start to run...'.format(rank), flush=True)
    # Exporting always enables use_past, regardless of the command-line value.
    if args_opt.export:
        use_past = True
    # Set model property
    model_parallel_num = args_opt.op_level_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
                                                  model_parallel=model_parallel_num,
                                                  pipeline_stage=args_opt.stage_num,
                                                  micro_batch_num=args_opt.micro_size,
                                                  optimizer_shard=False,
                                                  vocab_emb_dp=bool(args_opt.word_emb_dp),
                                                  recompute=True)
    per_batch_size = args_opt.per_batch_size
    # Global batch size = per-device batch size times the data-parallel degree.
    batch_size = per_batch_size * data_parallel_num
    config = PanguAlphaConfig(
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        hidden_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        post_layernorm_residual=False,
        dropout_rate=0.0,
        ffn_hidden_size=args_opt.embedding_size * 4,
        use_past=use_past,
        eod_token=args_opt.eod_id,
        eod_reset=False,
        parallel_config=parallel_config,
        load_ckpt_path=args_opt.load_ckpt_path,
        param_init_type=mstype.float32
        if args_opt.param_init_type == 'fp32'
        else mstype.float16,
    )
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)
    # NOTE(review): this value is only used if load_ckpt_epoch <= 0 and nothing
    # in this function reads it in that case; it is overwritten below otherwise.
    ckpt_name = args_opt.load_ckpt_name
    # Define network
    pangu_alpha = PanguAlphaModel(config)
    eval_net = EvalNet(pangu_alpha, pad_token=50256)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    # Compile network and obtain tensor layout for loading ckpt
    inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
    current_index = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32)
    # NOTE(review): predict_layout is assigned in every branch but never read
    # afterwards in this function.
    if args_opt.distribute == "false":
        predict_layout = None
    elif config.use_past:
        # With use_past the layout is inferred twice: once with a full-length
        # input (is_first_iteration=True) and once with a single-token input
        # (is_first_iteration=False).
        batch_valid_length = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32)
        init_true = Tensor([True], mstype.bool_)
        print("Input shape:", inputs_np.shape, flush=True)
        inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
        model_predict.predict_network.add_flags_recursive(is_first_iteration=True)
        print("is_first_iteration=True", flush=True)
        predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length)
        model_predict.predict_network.add_flags_recursive(is_first_iteration=False)
        print("is_first_iteration=False", flush=True)
        init_false = Tensor([False], mstype.bool_)
        _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length)
    else:
        predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)
    # Optionally upload the dumped graphs of this rank to object storage.
    if context.get_context("save_graphs"):
        print("==============save_graph", flush=True)
        jobid = os.environ["BATCH_JOB_ID"]
        rank_id = rank
        mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid)
        mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id),
                               dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id))
    print("======start load_distributed checkpoint", flush=True)
    # Checkpoint loading is skipped entirely when load_ckpt_epoch is not positive.
    if args_opt.load_ckpt_epoch > 0:
        # Delay grows with the rank index; presumably to stagger file-system access.
        time.sleep(rank * 0.1)
        # NOTE(review): os.mkdir raises FileExistsError if this directory is
        # already present (e.g. on a rerun with the same save path).
        os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}"))
        ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt"
        # NOTE(review): a missing checkpoint is only logged; the copy below is
        # still attempted.
        if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)):
            print(f"Checkpoint from rank {rank} doesn't exist!")
        # Copy this rank's checkpoint slice to the local save path, then load it.
        mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name),
                      os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name))
        param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name))
        if param_dict.get("epoch_num") and param_dict.get("step_num"):
            args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
            args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
        # Each rank creates its own marker directory, then polls until the
        # number of entries equals device_num before continuing. This assumes
        # the cache directory is visible to all ranks - TODO confirm.
        os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}')
        while True:
            num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1'))
            if num == device_num:
                break
            if rank % 8 == 0:
                print("Loaded ckpt in step 1: ", num)
            time.sleep(1)
        # Parameters are loaded into the backbone model, not the EvalNet wrapper.
        net_not_load = load_param_into_net(pangu_alpha, param_dict)
        print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True)
    return model_predict, config, rank
def export_mindir(model_predict, config):
    """Export the prediction network to two MINDIR files.

    One export is made with ``is_first_iteration=True`` and a
    ``(batch_size, seq_length)`` input, the other with
    ``is_first_iteration=False`` and a ``(batch_size, 1)`` input.

    Args:
        model_predict: Model whose ``predict_network`` is exported.
        config: Configuration providing ``batch_size`` and ``seq_length``.
    """
    network = model_predict.predict_network
    rows = config.batch_size

    # Placeholder inputs used for both exports.
    full_tokens = Tensor(np.ones(shape=(rows, config.seq_length)), mstype.int32)
    position = Tensor(np.array([0]), mstype.int32)
    valid_length = Tensor(np.array([0]), mstype.int32)
    # The same True-valued init tensor is fed to both exports.
    init_flag = Tensor([True], mstype.bool_)
    single_token = Tensor(np.ones(shape=(rows, 1)), mstype.int32)

    # (is_first_iteration flag, token input, output file name)
    export_plan = (
        (True, full_tokens, 'pangu_alpha_1024'),
        (False, single_token, 'pangu_alpha_1'),
    )
    for first_iteration, tokens, mindir_name in export_plan:
        network.add_flags_recursive(is_first_iteration=first_iteration)
        export(network, tokens, position, init_flag, valid_length,
               file_name=mindir_name, file_format='MINDIR')

    print("Export finished and now exit.")
def run_predict(model_predict, config, args_opt, rank):
    """Generate completions for every HumanEval-X C++ task and save them.

    Each prompt is prefixed with a C++ language tag, tokenized, repeated
    ``config.batch_size`` times and passed to ``generate_increment``. Enough
    batches are run to cover ``gen_times`` generations per task. Rank 0
    appends one JSON line per generation (``task_id`` and ``completion``)
    to the save file; ranks that are a multiple of 8 log progress.

    Args:
        model_predict: Model handed through to ``generate_increment``.
        config: Configuration providing ``batch_size``.
        args_opt: Parsed options; ``part``, ``temperature`` and
            ``load_ckpt_epoch`` are read here, and the object is passed on
            to ``generate_increment``.
        rank: Rank of this process.

    Raises:
        json.JSONDecodeError: If a non-blank line of the task file is not
            valid JSON.
    """
    from src.generate_humaneval import generate_increment

    # Define tokenizer
    tokenizer = CodeTokenizer(mode='6b')

    # Load the benchmark tasks, one JSON object per line.
    humaneval_path = '/home/work/sfs/xx/human_eval_x/data/humaneval_cpp.jsonl'  # TODO: set as current humaneval path
    with open(humaneval_path, 'r', encoding='utf-8') as task_file:
        # Lines keep their trailing newline, so blank lines must be detected
        # on the stripped text; a plain length check never filters them.
        humaneval = [json.loads(line) for line in task_file if line.strip()]

    batch_size = config.batch_size
    verbose = (rank % 8 == 0)
    part = int(args_opt.part)
    gen_times = 12  # TODO: set as generation times of current task
    print(f"gen times: {gen_times}, part: {part}")
    save_path = f'/home/work/sfs/xx/pangu_alpha_code/generation_humanevalx/cpp/temp_{args_opt.temperature}/samples_{args_opt.load_ckpt_epoch}_part_{part}.jsonl'  # TODO: set as current save path
    if rank == 0 and not os.path.exists(save_path):
        os.makedirs(os.path.split(save_path)[0], exist_ok=True)
        # Create an empty result file before opening up its permissions.
        with open(save_path, 'w'):
            pass
        os.system(f'sudo chmod 777 {save_path}')

    tag = "// language: C++\n"
    # Ceiling division: number of batches needed to reach gen_times generations.
    num_batches = (gen_times + batch_size - 1) // batch_size
    for i, task in enumerate(humaneval):
        sample = tag + task['prompt']
        if verbose:
            print(f"=================== prompt {i} ====================")
            print(sample, flush=True)
        # The prompt does not change between batches, so encode it once.
        tokenized_token = tokenizer.encode_code(sample)
        for j in range(num_batches):
            # A fresh array per batch, in case generation modifies its input.
            input_ids = np.array(tokenized_token).reshape(1, -1).repeat(batch_size, axis=0)
            # Call inference
            # The seed depends on the batch index and the part number.
            mindspore.set_seed(j + 8 * part)
            t0 = time.perf_counter()
            output_ids = generate_increment(model_predict, input_ids, args_opt, tokenizer, verbose)
            t1 = time.perf_counter()
            if verbose:
                print(f"=== Batch time: {t1 - t0}s")
                for k, out in enumerate(output_ids):
                    print(f"=================== generation {j * batch_size + k} ====================")
                    print(out, flush=True)
                    if rank == 0:
                        record = json.dumps({'task_id': task['task_id'], 'completion': out})
                        # Reopen per record so each result is on disk as soon
                        # as it is produced.
                        with open(save_path, 'a') as out_file:
                            out_file.write(record + '\n')
def main():
    """Parse the options, load the model, then either export it or run prediction."""
    print("===Enter main!")
    opt = get_args(True)
    set_parse(opt)
    model_predict, config, rank = load_model(opt)

    # Prediction is the default path; export only when requested.
    if not opt.export:
        run_predict(model_predict, config, opt, rank)
        return
    export_mindir(model_predict, config)
# Run the predict/export flow only when this file is executed as a script.
if __name__ == "__main__":
    main()