# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
PanguAlpha model-parallel checkpoint transform script: downloads the sharded
checkpoints from OBS, merges the model-parallel slices according to a parallel
strategy file, and uploads the merged single-device checkpoint back to OBS.
"""
import datetime
import argparse
import numpy as np
import json
import glob
import os
import math
import time
from pathlib2 import Path
from collections import defaultdict
import moxing as mox

from tensorboardX import SummaryWriter

from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

import mindspore
from mindspore.train.serialization import load_checkpoint, build_searched_strategy, save_checkpoint, \
    merge_sliced_parameter, _convert_to_list, _convert_to_layout
from mindspore.common.parameter import Parameter
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
from mindspore import Tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices

from src.adam import AdamWeightDecayOp
from src.dataset import create_dataset
from src.pangu_alpha import PanGUAlphaWithLoss, PanguAlphaModel
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data
from src.callbacks import LossCallBack, SaveCheckpointCallback
from mindspore.profiler import Profiler

project_root = os.path.abspath(
    os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)


def set_weight_decay(params):
    """
    Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
    """
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [
        {"params": decay_params, "weight_decay": 1e-1},
        {"params": other_params, "weight_decay": 0.0},
        {"order_params": params},
    ]
    return group_params
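
# Illustrative only: a minimal sketch of how the parameter groups returned above
# would typically be fed to an optimizer, mirroring the PanguAlpha training flow.
# `net_with_loss` is a placeholder for a PanGUAlphaWithLoss instance, and the
# learning-rate/eps values are assumptions, not part of this transform script.
#
#   group_params = set_weight_decay(net_with_loss.trainable_params())
#   optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=1e-4, eps=1e-8)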


def add_checkpoint_callback_policy(args_param, callback, rank_id):
    r"""
    Add checkpoint policy to callback.
    """
    if args_param.save_checkpoint:
        # checkpoint stores epoch_num and step_num info
        ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args_param.save_checkpoint_steps,
            keep_checkpoint_max=args_param.keep_checkpoint_max,
            integrated_save=False,
            append_info=ckpt_append_info,
        )

        # save checkpoint into rank directory
        ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                     directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                     config=ckpt_config)

        callback.append(ckpoint_cb)

        saveckpt_cb = SaveCheckpointCallback(cache_dir=args_param.save_checkpoint_path,
                                             bucket=args_param.save_checkpoint_obs_path,
                                             local_rank=rank_id,
                                             has_trained_epoch=args_param.has_trained_epoches,
                                             has_trained_step=args_param.has_trained_steps,
                                             syn_times=args_param.save_checkpoint_steps)
        callback.append(saveckpt_cb)
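
# Illustrative only: a sketch of how this checkpoint policy would be wired into a
# MindSpore training loop. `model`, `train_dataset`, and `epoch_num` are
# placeholders that are not defined in this script.
#
#   callback = []
#   add_checkpoint_callback_policy(args_param, callback, rank_id)
#   model.train(epoch_num, train_dataset, callbacks=callback, dataset_sink_mode=True)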


def set_parallel_context(args_opt):
    r"""Set parallel context"""
    D.init()
    device_num = D.get_group_size()
    rank = D.get_rank()
    print("rank_id is {}, device_num is {}".format(rank, device_num))
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
        full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
        enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt',
        optimizer_weight_shard_size=16, optimizer_weight_shard_aggregated_save=True)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank, device_num


def download_ckpt(args_opt, file_num, rank_num, rank_id):
    """
    Download the `file_num` sharded checkpoints of the loaded epoch from OBS into
    local per-rank directories. The download work is split across the `rank_num`
    devices, with each rank fetching only the shards assigned to it. Returns the
    full list of expected local shard paths.
    """
    ckpt_list = []
    for rank in range(0, file_num):
        ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt"
        local_file = os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}", ckpt_name)
        ckpt_list.append(local_file)
        if rank % rank_num != rank_id:
            continue
        time.sleep(rank * 0.05)
        if not os.path.exists(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}")):
            os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}"))
        if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)):
            print(f"Checkpoint from rank {rank} doesn't exist!")
        mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), local_file)
        print("===download ckpt ok: ", local_file, flush=True)
    print(ckpt_list)
    return ckpt_list


def get_needed_model_parallel_list(train_strategy_file, self_rank):
    """
    Infer, from the training strategy file, which source checkpoint ranks are
    needed to reconstruct the full parameters.
    """
    train_strategy_origin = build_searched_strategy(train_strategy_file)
    strategy_keys = list(train_strategy_origin.keys())
    train_strategy = _convert_to_list(train_strategy_origin)
    rank_list = _infer_rank_list(train_strategy, None)
    needed_ckpt_ranks = []
    for param_name in strategy_keys:
        param_needs_rank_list = rank_list[param_name][0]
        if len(param_needs_rank_list) > len(needed_ckpt_ranks):  # strictly speaking, the union of all rank lists should be taken
            needed_ckpt_ranks = param_needs_rank_list
    return needed_ckpt_ranks
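
# Illustrative only: the inline comment above notes that taking the union of all
# per-parameter rank lists would be the safer choice. A minimal sketch of that
# variant follows; it is not called anywhere in this script.
def get_needed_model_parallel_union(train_strategy_file):
    """Return the sorted union of source ranks needed by any parameter."""
    train_strategy_origin = build_searched_strategy(train_strategy_file)
    train_strategy = _convert_to_list(train_strategy_origin)
    rank_list = _infer_rank_list(train_strategy, None)
    needed = set()
    for param_name in train_strategy_origin.keys():
        # rank_list[param_name][0] is the list of source ranks holding slices of this parameter
        needed.update(rank_list[param_name][0])
    return sorted(needed)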


def transform_model_parallel(restore_local_ckpt_file_list, train_strategy_file, save_path, using_fp16=False):
    """
    Merge the model-parallel checkpoint slices listed in `restore_local_ckpt_file_list`
    into a single checkpoint according to the training strategy file, and save it
    under `save_path` as predict_1p.ckpt.
    """
    # check whether the checkpoint files have been downloaded
    for local_file in restore_local_ckpt_file_list:
        if not os.path.exists(local_file):
            raise ValueError("ckpt not downloaded: ", restore_local_ckpt_file_list)
        time.sleep(0.1)
    param_total_dict = defaultdict(dict)
    for file_index, local_file in enumerate(restore_local_ckpt_file_list):
        param_dict = load_checkpoint(local_file)
        for param_name, param in param_dict.items():
            if "adam" in param_name:
                continue
            print(f"===loading {file_index}: ", param_name, flush=True)
            param_total_dict[param_name][file_index] = param
    print("===load param done.", flush=True)
    train_strategy_origin = build_searched_strategy(train_strategy_file)
    train_strategy = _convert_to_list(train_strategy_origin)
    rank_list = _infer_rank_list(train_strategy, None)
    strategy_keys = list(train_strategy_origin.keys())
    merged_param_list = []
    for param_name in param_total_dict.keys():
        if "adam" in param_name:
            continue
        if param_name not in strategy_keys:
            # parameter is not sliced; take the copy from the first checkpoint as-is
            each_param = {"name": param_name}
            each_param["data"] = param_total_dict[param_name][0]
            print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, flush=True)
            merged_param_list.append(each_param)
            continue
        param_unique_strategy = _remove_repeated_slices(train_strategy[param_name])
        _param_unique_strategy = _convert_to_layout(param_name, param_unique_strategy)
        sliced_params = []
        if using_fp16 and "embedding" not in param_name and "layernorm" not in param_name:
            for i in rank_list[param_name][0]:
                slice_param = param_total_dict[param_name][i]
                layerwise_parallel = slice_param.layerwise_parallel
                requires_grad = slice_param.requires_grad
                sliced_data = slice_param.data.asnumpy()
                sliced_data = sliced_data.astype(np.float16)
                param_fp16 = Parameter(Tensor(sliced_data), param_name, requires_grad, layerwise_parallel)
                sliced_params.append(param_fp16)
        else:
            sliced_params = [param_total_dict[param_name][i] for i in rank_list[param_name][0]]
        merged_param = merge_sliced_parameter(sliced_params, _param_unique_strategy)
        each_param = {"name": param_name}
        each_param["data"] = merged_param
        print("====", param_name, merged_param.data.asnumpy().shape, flush=True)
        merged_param_list.append(each_param)
    save_file = os.path.join(save_path, "predict_1p.ckpt")
    save_checkpoint(merged_param_list, save_file)
    return save_file
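
# Illustrative only: a sketch of how the merged checkpoint produced above could be
# loaded for single-device evaluation. `PanguAlphaModel` is imported in this file,
# but the config needed to build it is not shown here, so `net` is a placeholder.
#
#   param_dict = load_checkpoint(os.path.join(save_path, "predict_1p.ckpt"))
#   mindspore.load_param_into_net(net, param_dict)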


def run_transform_model_parallel_ckpt(args_opt):
    """
    Entry point: download the sharded checkpoints, merge them on rank 0 according
    to the loaded strategy file, and upload the merged checkpoint to OBS.
    """
    # Set execution mode
    context.set_context(
        mode=context.GRAPH_MODE, device_target=args_opt.device_target
    )
    # Set parallel context
    rank = 0
    device_num = 1
    if args_opt.distribute == "true":
        rank, device_num = set_parallel_context(args_opt)
    print("=====rank is: ", rank, flush=True)
    # 8 is the number of source rank checkpoint files to download
    ckpt_file_list = download_ckpt(args_opt, 8, device_num, rank)
    if rank != 0:
        return
    needed_ckpt_ranks = get_needed_model_parallel_list(args_opt.strategy_load_ckpt_path, rank)
    restore_local_ckpt_file_list = [ckpt_file_list[i] for i in needed_ckpt_ranks]
    print("====restore_local_ckpt_file_list====", restore_local_ckpt_file_list, flush=True)
    save_path = os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    save_file = transform_model_parallel(restore_local_ckpt_file_list, args_opt.strategy_load_ckpt_path, save_path)
    obs_save_path = args_opt.save_checkpoint_obs_path
    time.sleep(rank * 0.1)
    if not mox.file.exists(obs_save_path):
        mox.file.make_dirs(obs_save_path)
    rank_obs_save_path = os.path.join(obs_save_path, f"rank_{rank}")
    if not mox.file.exists(rank_obs_save_path):
        mox.file.make_dirs(rank_obs_save_path)
    rank_obs_save_file = os.path.join(rank_obs_save_path, f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt")
    if not os.path.exists(save_file):
        raise ValueError(save_file + " does not exist")
    mox.file.copy(save_file, rank_obs_save_file)
    print("=====save ok, save_path", save_path)


if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
    run_transform_model_parallel_ckpt(opt)
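
# Hypothetical invocation, for orientation only. The actual flag names and defaults
# are defined by get_args() in src/utils.py, which is not shown here; the paths and
# epoch value below are placeholders.
#
#   python <this_script>.py \
#       --distribute=true \
#       --device_target=Ascend \
#       --load_ckpt_path=obs://bucket/ckpts \
#       --load_ckpt_epoch=1 \
#       --strategy_load_ckpt_path=/cache/strategy.ckpt \
#       --save_checkpoint_path=/cache/ckpt_merge \
#       --save_checkpoint_obs_path=obs://bucket/merged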