# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
PanguAlpha train script
"""

import datetime
import glob
import json
import math
import os
import time

import mindspore.common.dtype as mstype
import mindspore.communication.management as D
import mindspore.nn as nn
import moxing as mox
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, MicroBatchInterleaved
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
from mindspore.profiler import Profiler
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net
from tensorboardX import SummaryWriter

from src.adam import AdamWeightDecayOp
from src.callbacks import EvalCallBack, LossCallBack, SaveCheckpointCallback
from src.dataset import create_dataset
from src.metrics import PPLMetric, ValidationLoss
from src.pangu_alpha import PanGUAlphaWithLoss, PanguAlphaModel
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data

project_root = os.path.abspath(
    os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)


def set_weight_decay(params):
    """
    Set the weight decay coefficient: zero for bias and layernorm parameters, 1e-1 for the rest.
    """
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [
        {"params": decay_params, "weight_decay": 1e-1},
        {"params": other_params, "weight_decay": 0.0},
        {"order_params": params},
    ]
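    # The grouped-parameter format above is what MindSpore optimizers accept: a per-group
    # "weight_decay" override plus "order_params" to preserve the original parameter order.
    # For illustration only (the parameter names are hypothetical), the result looks like:
    #   [{"params": [dense.weight, ...], "weight_decay": 1e-1},
    #    {"params": [dense.bias, layernorm.gamma, ...], "weight_decay": 0.0},
    #    {"order_params": <all parameters in their original order>}]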
    return group_params


def add_checkpoint_callback_policy(args_param, callback, rank_id):
    r"""
    Add checkpoint policy to callback.
    """
    if args_param.save_checkpoint:
        # checkpoint store epoch_num and step_num info
        ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args_param.save_checkpoint_steps,
            keep_checkpoint_max=args_param.keep_checkpoint_max,
            integrated_save=False,
            append_info=ckpt_append_info,
        )

        # save checkpoint into rank directory
        ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                     directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                     config=ckpt_config)

        callback.append(ckpoint_cb)

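        # SaveCheckpointCallback is defined in src/callbacks.py; judging by its arguments it
        # uploads the locally cached checkpoints to the OBS bucket (save_checkpoint_obs_path)
        # roughly every `syn_times` steps. See src/callbacks.py for the authoritative behaviour.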
        saveckpt_cb = SaveCheckpointCallback(cache_dir=args_param.save_checkpoint_path,
                                             bucket=args_param.save_checkpoint_obs_path,
                                             local_rank=rank_id,
                                             has_trained_epoch=args_param.has_trained_epoches,
                                             has_trained_step=args_param.has_trained_steps,
                                             syn_times=args_param.save_checkpoint_steps)
        callback.append(saveckpt_cb)


def set_parallel_context(args_opt):
    r"""Set parallel context"""
    D.init()
    device_num = D.get_group_size()
    rank = D.get_rank()
    print("rank_id is {}, device_num is {}".format(rank, device_num))
    context.reset_auto_parallel_context()
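    # Semi-auto parallel: operator-level sharding strategies are loaded from / saved to the
    # strategy checkpoint files, and enable_parallel_optimizer shards optimizer states across
    # groups of 16 devices (optimizer_weight_shard_size=16).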
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
        full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
        enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt',
        optimizer_weight_shard_size=16)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank, device_num


def run_train(args_opt):
    r"""The main training process."""
    os.environ["HCCL_CONNECT_TIMEOUT"] = "2000"
    # Set execution mode
    context.set_context(
        mode=context.GRAPH_MODE, device_target=args_opt.device_target
    )
    if args_opt.profiling:
        profiler = Profiler(output_path="/cache/profiler_data")
    context.set_context(variable_memory_max_size="30GB")
    # Set parallel context
    rank = 0
    device_num = 1
    if args_opt.distribute == "true":
        rank, device_num = set_parallel_context(args_opt)
    context.set_context(
        save_graphs=False,
        save_graphs_path="/cache/graphs_of_device_id_" + str(rank),
    )
    cache_url = '/cache/Data/'
    eval_cache_url = '/cache/EvalData/'
    if not args_opt.offline:
        download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank)
        download_data(src_data_url=args_opt.eval_data_url, tgt_data_path=eval_cache_url, rank=rank)
    # Set model property
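    # Devices are split into data-parallel x model-parallel groups. As a worked example
    # (the numbers are illustrative, not from this repo's configs): with 512 devices and
    # op_level_model_parallel_num=8, data_parallel_num = 512 / 8 = 64, so the global batch
    # per step is per_batch_size * 64.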
    model_parallel_num = args_opt.op_level_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    batch_size = args_opt.per_batch_size * data_parallel_num
    parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num,
                                                  pipeline_stage=args_opt.stage_num,
                                                  micro_batch_num=args_opt.micro_size,
                                                  optimizer_shard=bool(args_opt.optimizer_shard),
                                                  vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True,
                                                  gradient_aggregation_group=args_opt.gradient_aggregation_group)

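    # MicroBatchInterleaved (used below when micro_interleaved_size > 1) splits each step's
    # batch into `micro_interleaved_size` slices so computation of one slice can overlap with
    # communication of another; the model config therefore uses the per-slice batch size
    # batch_size // micro_interleaved_size.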
    micro_interleaved_size = args_opt.micro_interleaved_size
    config = PanguAlphaConfig(
        batch_size=batch_size // micro_interleaved_size,
        num_heads=args_opt.num_heads,
        hidden_size=args_opt.embedding_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        num_layers=args_opt.num_layers,
        ffn_hidden_size=args_opt.embedding_size * 4,
        eod_token=args_opt.eod_id,
        load_ckpt_path=args_opt.load_ckpt_path,
        param_init_type=mstype.float32
        if args_opt.param_init_type == "fp32"
        else mstype.float16,
        dropout_rate=args_opt.dropout_rate,
        enable_offload=bool(args_opt.opt_offload),
        use_moe=bool(args_opt.use_moe),
        per_dp_dim_expert_num=args_opt.per_dp_dim_expert_num,
        hidden_act="fast_gelu" if args_opt.device_target != "GPU" else "gelu",
        parallel_config=parallel_config,
    )
    print("===config is: ", config, flush=True)
    # Define network
    pangu_alpha = PanguAlphaModel(config=config)
    loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
    if micro_interleaved_size > 1:
        print("===using MicroBatchInterleaved", flush=True)
        pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
                                                          micro_interleaved_size)
    else:
        pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
    pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
    print("=====args_opt is: ", args_opt, flush=True)
    # Warm-up and cosine decay learning rate
    lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
    params = pangu_alpha_with_loss.trainable_params()
    group_params = set_weight_decay(params)
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    elif args_opt.opt_offload:
        optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95,
                                      param_init_type=config.param_init_type)
    else:
        optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
    # Initial scaling sens
    loss_scale_value = math.pow(2, 32)
    epoch_num = args_opt.epoch_size

    if args_opt.load_ckpt_epoch > 0:
        time.sleep(rank * 0.05)
        os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}"))
        ckpt_name = f"code-13B{rank}_20-{args_opt.load_ckpt_epoch}_2.ckpt"
        if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)):
            print(f"Checkpoint from rank {rank} doesn't exist!")
        mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name),
                      os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name))
        param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name))
        # TODO: remove after warming-up!
        # param_dict.pop('global_step')
        # TODO: add them back if not for the 1st run!
        # if param_dict.get("epoch_num") and param_dict.get("step_num"):
        #     args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
        #     args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
        # args_opt.has_trained_steps = 9000

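        # File-system barrier: every rank creates its own marker directory, then waits until
        # all `device_num` ranks have done the same, so no rank moves on before every rank has
        # finished staging its checkpoint.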
        os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}')
        while True:
            num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1'))
            if num == device_num:
                break
            if rank % 64 == 0:
                print("Loaded ckpt in step 1: ", num)
            time.sleep(1)

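    # Only the last rank writes TensorBoard summaries; the other ranks keep summary_writer as
    # None, and the callbacks are presumably expected to skip logging in that case (see
    # src/callbacks.py).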
    if args_opt.tb_dir is not None and rank == device_num - 1:
        os.makedirs(args_opt.tb_dir, exist_ok=True)
        summary_writer = SummaryWriter(args_opt.tb_dir)
        os.system(f'chmod -R 777 {args_opt.tb_dir}')
    else:
        summary_writer = None

    # Dataset loading mindrecord files
    ds, ds_eval = create_dataset(config.batch_size * micro_interleaved_size, data_path=args_opt.code_data,
                                 args_opt=args_opt, data_start_index=0,
                                 eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
                                 eod_id=args_opt.eod_id,
                                 device_num=device_num, rank=rank, epoch=epoch_num,
                                 train_and_eval=bool(args_opt.train_and_eval_mode), val_ratio=0.001)
    actual_epoch_num = int(ds.get_dataset_size() / args_opt.sink_size)
    callback = [
        TimeMonitor(args_opt.sink_size),
    ]
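    # Dynamic loss scaling: start at 2**32, halve the scale whenever an overflow is detected
    # (scale_factor=2), and multiply it back up after 1000 consecutive overflow-free steps
    # (scale_window=1000).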
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
    pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell(
        pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
        config=config)
    if ds_eval:
        ppl_metric = PPLMetric(config.seq_length)
        validation_loss = ValidationLoss(eod_token=args_opt.eod_id)
        model = Model(pangu_alpha_with_grads, eval_network=pangu_alpha_with_loss,
                      metrics={"ppl": ppl_metric, "validation_loss": validation_loss})
        callback.append(
            EvalCallBack(
                model=model,
                eval_dataset=ds_eval,
                ppl_metric=ppl_metric,
                validation_loss=validation_loss,
                print_per_step=10,
                has_trained_step=args_opt.has_trained_steps,
                local_rank=rank,
                rank_size=device_num,
                tb_writer=summary_writer
            )
        )
    else:
        model = Model(pangu_alpha_with_grads)
    if args_opt.load_ckpt_epoch > 0:
        print("===build model and load ckpt")
        time_stamp = datetime.datetime.now()
        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before building", flush=True)
        model.build(train_dataset=ds, sink_size=args_opt.sink_size, epoch=actual_epoch_num)
        time_stamp = datetime.datetime.now()
        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before loading ckpt", flush=True)
        net_not_load = load_param_into_net(pangu_alpha_with_loss, param_dict)
        opt_not_load = load_param_into_net(optimizer, param_dict)

        os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/2/rank_{rank}')
        while True:
            num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/2'))
            if num == device_num:
                break
            if rank % 64 == 0:
                print("Loaded ckpt in step 2: ", num)
            time.sleep(1)
    callback.append(
        LossCallBack(
            name=args_opt.ckpt_name_prefix,
            dataset_size=args_opt.sink_size,
            local_rank=rank,
            rank_size=device_num,
            has_trained_epoch=args_opt.has_trained_epoches,
            has_trained_step=args_opt.has_trained_steps,
            micro_size=args_opt.micro_size * micro_interleaved_size,
            tb_writer=summary_writer,
        )
    )
    if not args_opt.profiling:
        add_checkpoint_callback_policy(args_opt, callback, rank)
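    # Incremental training resumes from a strategy-sharded distributed checkpoint:
    # infer_train_layout derives this job's sharding strategy and load_distributed_checkpoint
    # re-slices the 512 source files to match it.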
    if args_opt.incremental_training:
        strategy = model.infer_train_layout(train_dataset=ds, sink_size=args_opt.sink_size)
        print("======start load_distributed checkpoint", flush=True)
        # For 2.6B and 13B models, the number of ckpt files is 512.
        ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"filerted_{ckpt_rank}.ckpt") for ckpt_rank in
                          range(0, 512)]
        print(f"Loading from path {ckpt_file_list[0]}", flush=True)
        load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy)
    print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)

    try:
        model.train(10 if args_opt.profiling else actual_epoch_num, ds, callbacks=callback,
                    sink_size=args_opt.sink_size, dataset_sink_mode=True)
    finally:
        if args_opt.profiling:
            jobid = os.environ["BATCH_JOB_ID"]
            profiler.analyse()
            rank_id = rank
            if context.get_context("save_graphs"):
                mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid)
                mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id),
                                       dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id))
            if rank_id % 8 == 0:
                mox.file.make_dirs("s3://wudao-1/yyf/profiler_" + jobid)
                mox.file.copy_parallel(src_url="/cache/profiler_data",
                                       dst_url="s3://wudao-1/yyf/profiler_" + jobid + "/" + str(rank_id))


def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
    r"""
    Load checkpoint process.
    """
    print("======start single checkpoint", flush=True)
    ckpt_name = args_param.ckpt_name_prefix
    ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()), f"{ckpt_name}*.ckpt")
    ckpt_all_files = glob.glob(ckpt_pattern)

    if not ckpt_all_files:
        print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
              f"current ckpt_files found is {ckpt_all_files} "
              f"with pattern {ckpt_pattern}, so skip the loading.")
        return

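    # Exclude "*_breakpoint.ckpt" files: those appear to be exception checkpoints saved when a
    # rank fails and are handled by restore_exception_checkpoint instead.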
    ckpt_exp_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()),
                                    f"{ckpt_name}*_breakpoint.ckpt")
    ckpt_exp_files = glob.glob(ckpt_exp_pattern)
    ckpt_files = []
    for file in ckpt_all_files:
        if file not in ckpt_exp_files:
            ckpt_files.append(file)

    if not ckpt_files:
        print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
              f"current ckpt_files found is {ckpt_files} "
              f"with pattern {ckpt_pattern}, so skip the loading.")
        return

    ckpt_files.sort(key=os.path.getmtime, reverse=True)
    time_stamp = datetime.datetime.now()
    print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading",
          flush=True)
    # Load checkpoint files latest file
    print(f'Start to load from {ckpt_files[0]}')
    param_dict = load_checkpoint(ckpt_files[0])
    if param_dict.get("epoch_num") and param_dict.get("step_num"):
        args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
        args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
    model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
    load_param_into_net(network, param_dict)


def get_exception_checkpoints(args_param):
    """
    Get exception checkpoint based on restore ranks
    Args:
        args_param: training model parameters
    Returns: exception checkpoint list
    """
    print("======start exception checkpoint", flush=True)
    restore_ranks = os.getenv("RESTORE_RANKS")
    if not restore_ranks:
        return None

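    # RESTORE_RANKS is expected to hold a comma-separated list of rank ids whose exception
    # checkpoints should be picked up, e.g. "0,8,16" (illustrative value).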
    restore_rank_list = list(map(int, restore_ranks.split(",")))
    ckpt_file_list = []
    ckpt_name = args_param.ckpt_name_prefix
    for ckpt_rank in restore_rank_list:
        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
                                    f"rank_{ckpt_rank}",
                                    f"{ckpt_name}*_breakpoint.ckpt")
        ckpt_files = glob.glob(ckpt_pattern)
        if not ckpt_files:
            print(
                f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                f"current ckpt_files found is {ckpt_files} "
                f"with pattern {ckpt_pattern}, so skip the loading.")
            return None
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
        ckpt_file_list.append(ckpt_files[0])
    print(f"checkpoint file {ckpt_file_list}")
    return ckpt_file_list


def check_exception_checkpoints(ckpt_file_list):
    """
    Check exception checkpoint sizes.
    Args:
        ckpt_file_list: exception checkpoints
    Returns: result of the exception checkpoint size check.
    """
    ckpt_size_list = []
    for ckpt_file in ckpt_file_list:
        ckpt_size_list.append(os.path.getsize(ckpt_file))

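    # All exception checkpoints are expected to have the same file size; a mismatch is treated
    # as a sign that at least one of them is incomplete.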
    if len(set(ckpt_size_list)) > 1:
        return False

    return True


def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, epoch):
    """
    Restore exception checkpoint.
    Args:
        args_param: training job params
        sink_size: training job sink size
        dataset: dataset for training
        model: model
        network: pangu_alpha network
        epoch: training epoch
    Returns: load exception checkpoint success or not.
    """
    if os.getenv("RESTORE_RANKS") == "-1":
        return False

    ckpt_file_list = get_exception_checkpoints(args_param)

    restore_flag = False
    if ckpt_file_list:
        restore_flag = check_exception_checkpoints(ckpt_file_list)

    if not restore_flag:
        return False

    ckpt_name = args_param.ckpt_name_prefix
    restore_ranks_map = os.getenv("RESTORE_RANKS_MAP")
    if not restore_ranks_map:
        return False

    try:
        print("whether run into load process")
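        # RESTORE_RANKS_MAP is expected to be a JSON object mapping a comma-separated list of
        # ranks to the rank whose checkpoint they should restore from, e.g. '{"0,1,2,3": 0}'
        # (illustrative value): every rank listed in the key loads rank 0's breakpoint file.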
        restore_ranks_map_json = json.loads(restore_ranks_map)
        map_rank_id = D.get_rank()
        for key in restore_ranks_map_json.keys():
            key_list = list(key.split(","))
            if str(D.get_rank()) in key_list:
                map_rank_id = restore_ranks_map_json.get(key)

        print(f"loading map rank id {map_rank_id}")
        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
                                    f"rank_{map_rank_id}",
                                    f"{ckpt_name}*breakpoint.ckpt")
        ckpt_files = glob.glob(ckpt_pattern)
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
        print(f" checkpoint files {ckpt_files[0]}")
        param_dict = load_checkpoint(ckpt_files[0])
        print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}")
        if param_dict.get("epoch_num") and param_dict.get("step_num"):
            args_param.has_trained_epoches = int(
                param_dict["epoch_num"].data.asnumpy())
            args_param.has_trained_steps = int(
                param_dict["step_num"].data.asnumpy())

        # Load checkpoint files
        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
        load_param_into_net(network, param_dict)
    except TypeError:
        return False
    else:
        return True


def set_pipeline_parallel_context(args_opt):
    """
    Set parallel context in the pipeline training process.
    """
    D.init()
    device_num = D.get_group_size()
    rank_id = D.get_rank()
    print("rank_id is {}, device_num is {}".format(rank_id, device_num))
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
        full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
        device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
        pipeline_stages=args_opt.stage_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank_id, device_num


def run_train_pipeline(args_opt):
    r"""The main training process in pipeline."""
    # Set hccl connect time
    os.environ['HCCL_CONNECT_TIMEOUT'] = "6000"

    context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    if args_opt.profiling:
        profiler = Profiler(output_path="./profiler_data")
    context.set_context(variable_memory_max_size="30GB")

    rank_id = 0
    device_num = 1
    if args_opt.distribute == "true":
        rank_id, device_num = set_pipeline_parallel_context(args_opt)

    # copy data from the cloud to the /cache/Data
    cache_url = '/cache/Data/'
    eval_cache_url = '/cache/EvalData/'
    if not args_opt.offline:
        download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank_id)
        download_data(src_data_url=args_opt.eval_data_url, tgt_data_path=eval_cache_url, rank=rank_id)
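    # Pipeline topology. As a worked example (the numbers are illustrative): 512 devices with
    # stage_num=8 give 64 devices per stage; with op_level_model_parallel_num=8 that leaves
    # data_parallel_num = 8, and the global batch per step is
    # per_batch_size * data_parallel_num * micro_size.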
    model_parallel_num = args_opt.op_level_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    print("Topology:", model_parallel_num, data_parallel_num, stage_device_num)
    if data_parallel_num <= 1 and args_opt.optimizer_shard == 1:
        raise ValueError("The dp must be larger than 1 when applying optimizer shard.")
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size

    parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
                                                  model_parallel=model_parallel_num,
                                                  pipeline_stage=args_opt.stage_num,
                                                  micro_batch_num=args_opt.micro_size,
                                                  optimizer_shard=bool(args_opt.optimizer_shard),
                                                  vocab_emb_dp=bool(args_opt.word_emb_dp),
                                                  recompute=True,
                                                  )
    config = PanguAlphaConfig(
        batch_size=batch_size // parallel_config.micro_batch_num,
        num_heads=args_opt.num_heads,
        hidden_size=args_opt.embedding_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        num_layers=args_opt.num_layers,
        ffn_hidden_size=args_opt.embedding_size * 4,
        eod_token=args_opt.eod_id,
        load_ckpt_path=args_opt.load_ckpt_path,
        param_init_type=mstype.float32
        if args_opt.param_init_type == "fp32"
        else mstype.float16,
        enable_offload=bool(args_opt.opt_offload),
        parallel_config=parallel_config,
        apply_scale_normalization=args_opt.apply_scale_normalization,
    )

    print("===config is: ", config, flush=True)
    pangu_alpha = PanguAlphaModel(config=config)
    loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
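    # PipelineCell splits each global batch into micro_batch_num micro-batches so that the
    # pipeline stages can work on different micro-batches concurrently.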
    pangu_alpha_with_loss_net = PipelineCell(PanGUAlphaWithLoss(config, pangu_alpha, loss),
                                             config.parallel_config.micro_batch_num)
    pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
    print("=====args_opt is: ", args_opt, flush=True)
    lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
    params = pangu_alpha.infer_param_pipeline_stage()
    group_params = set_weight_decay(params)
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    elif args_opt.opt_offload:
        optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95,
                                      param_init_type=config.param_init_type)
    else:
        optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)

    ds = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=args_opt.code_data,
                        device_num=stage_device_num, args_opt=args_opt,
                        rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
                        full_batch=context.get_auto_parallel_context("full_batch"),
                        column_name=args_opt.data_column_name)
    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(
            args_opt.ckpt_name_prefix,
            callback_size,
            rank_id,
            device_num,
            micro_size=parallel_config.micro_batch_num,
            tb_dir=args_opt.tb_dir,
        ),
    ]
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
    pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
    if args_opt.train_and_eval_mode:
        ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=eval_cache_url,
                                 args_opt=args_opt,
                                 device_num=stage_device_num, rank=rank_id % stage_device_num, eod_reset=True,
                                 data_start_index=0, full_batch=bool(args_opt.full_batch),
                                 column_name=args_opt.data_column_name,
                                 num_samples=args_opt.eval_steps * config.batch_size)
        ppl_metric = PPLMetric(config.seq_length)
        pangu_alpha_with_loss_eval_net = _VirtualDatasetCell(PanGUAlphaWithLoss(config, pangu_alpha, loss))
        model = Model(pangu_alpha_with_grads, eval_network=pangu_alpha_with_loss_eval_net, metrics={"ppl": ppl_metric})
        model.build(ds, ds_eval, sink_size=callback_size)
        eval_callback = EvalCallBack(model, ds_eval, ppl_metric)
        callback.append(eval_callback)
    else:
        model = Model(pangu_alpha_with_grads)

    if args_opt.pre_trained:
        flag = restore_exception_checkpoint(args_opt, callback_size, ds, model,
                                            pangu_alpha_with_grads, epoch=actual_epoch_num)
        if not flag:
            restore_checkpoint(args_opt, callback_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)

    callback = [
        TimeMonitor(callback_size),
        LossCallBack(
            args_opt.ckpt_name_prefix,
            callback_size,
            rank_id,
            device_num,
            args_opt.has_trained_epoches,
            args_opt.has_trained_steps,
            tb_dir=args_opt.tb_dir,
        ),
    ]
    # add_checkpoint_callback_policy(args_opt, callback, rank_id)

    print("------ train start -------")
    model.train(10 if args_opt.profiling else actual_epoch_num, ds, callbacks=callback,
                sink_size=callback_size, dataset_sink_mode=True)
    if args_opt.profiling:
        profiler.analyse()


if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
    if opt.per_batch_size == 0:
        raise ValueError("The per_batch_size has not been configured.")
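    # stage_num > 1 selects pipeline-parallel training; otherwise the plain semi-auto
    # data/model-parallel path in run_train is used.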
    if opt.stage_num > 1:
        if bool(opt.use_moe) or bool(opt.opt_offload):
            raise ValueError("Currently, moe and host device mode are not supported in pipeline parallel.")
        run_train_pipeline(opt)
    else:
        run_train(opt)