mirror of https://github.com/THUDM/CodeGeeX.git
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
|
|
PanguAlpha train script
|
|
"""
|
|
|
|
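# NOTE: the import list below appears to be carried over from the PanguAlpha
# training script; several modules (e.g. numpy, glob, tensorboardX, the training
# wrap cells) are not referenced in this file.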
import datetime
import numpy as np
import glob
import os
import math
import time
from collections import defaultdict

import moxing as mox

from tensorboardX import SummaryWriter

from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import TimeMonitor
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net

import mindspore
from mindspore.train.serialization import build_searched_strategy, save_checkpoint, \
    merge_sliced_parameter
from mindspore.common.parameter import Parameter
from mindspore import Tensor

from src.adam import AdamWeightDecayOp
from src.dataset import create_dataset
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data
from mindspore.profiler import Profiler

project_root = os.path.abspath(
    os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)

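
# Overall flow (driven by run_transform_opt_shard_ckpt below):
#   1. Each rank downloads its share of the per-rank checkpoint slices from OBS
#      via moxing (download_ckpt).
#   2. The parallel strategy file is parsed to work out which source ranks hold
#      the optimizer-sharded slices this rank needs (get_needed_opt_shard_list).
#   3. Those slices are merged back into full parameters and saved as a single
#      predict.ckpt (transform_opt_shard), which is then uploaded back to OBS.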
def set_parallel_context(args_opt):
    r"""Set parallel context"""
    D.init()
    device_num = D.get_group_size()
    rank = D.get_rank()
    print("rank_id is {}, device_num is {}".format(rank, device_num))
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
        full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
        enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt',
        optimizer_weight_shard_size=16, optimizer_weight_shard_aggregated_save=True)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank, device_num

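
# download_ckpt: build the full list of expected local checkpoint paths for
# `file_num` source ranks, and let each local process (rank_id out of rank_num)
# fetch its share of the files from OBS via moxing. The short sleep appears to
# stagger filesystem operations across processes.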
def download_ckpt(args_opt, file_num, rank_num, rank_id):
    ckpt_list = []
    for rank in range(0, file_num):
        ckpt_name = f"code-13B{rank}_22-{args_opt.load_ckpt_epoch}_2.ckpt"
        local_file = os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}", ckpt_name)
        ckpt_list.append(local_file)
        if rank % rank_num != rank_id:
            continue
        time.sleep(rank * 0.05)
        os.makedirs(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}"), exist_ok=True)
        obs_file = os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)
        if not mox.file.exists(obs_file):
            raise FileNotFoundError(f"Checkpoint from rank {rank} doesn't exist: {obs_file}")
        mox.file.copy(obs_file, local_file)
        print("===download ckpt ok: ", local_file, flush=True)
    return ckpt_list

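
# get_needed_opt_shard_list: parse the training strategy file and determine which
# source ranks hold the optimizer-sharded slices that `self_rank` needs. For a
# parameter sharded with step `s` and size `n`, the slices live on ranks
# self_rank % s, self_rank % s + s, ..., self_rank % s + (n - 1) * s; the longest
# such list across all parameters is returned.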
def get_needed_opt_shard_list(train_strategy_file, self_rank):
    train_strategy_origin = build_searched_strategy(train_strategy_file)
    strategy_keys = list(train_strategy_origin.keys())
    needed_ckpt_ranks = []
    for param_name in strategy_keys:
        opt_weight_shard_size = train_strategy_origin[param_name].opt_weight_shard_size
        opt_weight_shard_step = train_strategy_origin[param_name].opt_weight_shard_step
        if opt_weight_shard_size <= 0:
            continue
        group_index = self_rank % opt_weight_shard_step
        current_needed_ckpt_ranks = [group_index + i * opt_weight_shard_step for i in range(0, opt_weight_shard_size)]
        if len(current_needed_ckpt_ranks) > len(needed_ckpt_ranks):
            needed_ckpt_ranks = current_needed_ckpt_ranks
    return needed_ckpt_ranks

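
# transform_opt_shard: load every checkpoint slice in the restore list, then for
# each parameter either take it directly from the first slice (when it is not
# optimizer-sharded) or merge the per-rank slices with merge_sliced_parameter,
# and save the merged result as predict.ckpt under `save_path`.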
def transform_opt_shard(restore_local_ckpt_file_list, train_strategy_file, save_path):
    # check whether the ckpt files have been downloaded
    for local_file in restore_local_ckpt_file_list:
        if not os.path.exists(local_file):
            raise ValueError("ckpt not downloaded: " + local_file)
        time.sleep(0.1)
    param_total_dict = defaultdict(dict)
    for file_index, local_file in enumerate(restore_local_ckpt_file_list):
        param_dict = load_checkpoint(local_file)
        for param_name, param in param_dict.items():
            param_total_dict[param_name][file_index] = param

    train_strategy_origin = build_searched_strategy(train_strategy_file)
    strategy_keys = list(train_strategy_origin.keys())
    merged_param_list = []
    for param_name in param_total_dict.keys():
        if param_name not in strategy_keys:
            each_param = {"name": param_name}
            each_param["data"] = param_total_dict[param_name][0]
            print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, flush=True)
            merged_param_list.append(each_param)
            continue
        opt_weight_shard_step = train_strategy_origin[param_name].opt_weight_shard_step
        if opt_weight_shard_step == 0:
            print("====not opt shard:", param_name)
            each_param = {"name": param_name}
            each_param["data"] = param_total_dict[param_name][0]
            print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, flush=True)
            merged_param_list.append(each_param)
            continue
        print("====do opt shard:", param_name)
        sliced_params = [param_total_dict[param_name][i] for i in range(len(param_total_dict[param_name]))]
        merged_param = merge_sliced_parameter(sliced_params, None)
        each_param = {"name": param_name}
        each_param["data"] = merged_param
        print("====", param_name, merged_param.data.asnumpy().shape, flush=True)
        merged_param_list.append(each_param)
    save_file = os.path.join(save_path, "predict.ckpt")
    save_checkpoint(merged_param_list, save_file)
    return save_file

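
# run_transform_opt_shard_ckpt: driver. Sets up the (semi-auto-parallel) context,
# downloads the source checkpoint slices, merges the slices this rank needs, and
# uploads the merged per-rank checkpoint back to OBS. The hard-coded file count
# of 128 is presumably the number of ranks the original checkpoint was saved with.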
def run_transform_opt_shard_ckpt(args_opt):
    # Set execution mode
    context.set_context(
        mode=context.GRAPH_MODE, device_target=args_opt.device_target
    )
    # Set parallel context
    rank = 0
    device_num = 1
    if args_opt.distribute == "true":
        rank, device_num = set_parallel_context(args_opt)
    print("=====rank is: ", rank, flush=True)

    ckpt_file_list = download_ckpt(args_opt, 128, device_num, rank)
    needed_ckpt_ranks = get_needed_opt_shard_list(args_opt.strategy_load_ckpt_path, rank)
    restore_local_ckpt_file_list = [ckpt_file_list[i] for i in needed_ckpt_ranks]
    print("====restore_local_ckpt_file_list====", restore_local_ckpt_file_list, flush=True)

    save_path = os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")
    os.makedirs(save_path, exist_ok=True)
    save_file = transform_opt_shard(restore_local_ckpt_file_list, args_opt.strategy_load_ckpt_path, save_path)

    obs_save_path = args_opt.save_checkpoint_obs_path
    time.sleep(rank * 0.1)
    if not mox.file.exists(obs_save_path):
        mox.file.make_dirs(obs_save_path)
    rank_obs_save_path = os.path.join(obs_save_path, f"rank_{rank}")
    if not mox.file.exists(rank_obs_save_path):
        mox.file.make_dirs(rank_obs_save_path)
    rank_obs_save_file = os.path.join(rank_obs_save_path, f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt")
    if not os.path.exists(save_file):
        raise ValueError(save_file + " does not exist")
    mox.file.copy(save_file, rank_obs_save_file)
    print("=====save ok, save_path", save_path)

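
# Entry point. get_args() (src/utils.py) is expected to provide at least:
# distribute, device_target, full_batch, optimizer_shard, strategy_load_ckpt_path,
# load_ckpt_path, load_ckpt_epoch, save_checkpoint_path and save_checkpoint_obs_path,
# all of which are referenced above.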
if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
    run_transform_opt_shard_ckpt(opt)