# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
PanguAlpha train script
"""
import datetime
import numpy as np
import glob
import os
import math
import time
from collections import defaultdict
import moxing as mox
from tensorboardX import SummaryWriter
from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import TimeMonitor
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net, \
    build_searched_strategy, save_checkpoint, merge_sliced_parameter
import mindspore
from mindspore.common.parameter import Parameter
from mindspore import Tensor
from src.adam import AdamWeightDecayOp
from src.dataset import create_dataset
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data
from mindspore.profiler import Profiler

project_root = os.path.abspath(
    os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)


def set_parallel_context(args_opt):
r"""Set parallel context"""
D.init()
device_num = D.get_group_size()
rank = D.get_rank()
print("rank_id is {}, device_num is {}".format(rank, device_num))
context.reset_auto_parallel_context()
context.set_auto_parallel_context(
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt',
optimizer_weight_shard_size=16, optimizer_weight_shard_aggregated_save=True)
set_algo_parameters(elementwise_op_strategy_follow=True)
_set_multi_subgraphs()
return rank, device_num
def download_ckpt(args_opt, file_num, rank_num, rank_id):
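    """Download the original per-rank checkpoints from OBS.

    The work is split round-robin: this process only fetches checkpoints whose
    source rank satisfies rank % rank_num == rank_id, but the returned list
    contains the expected local paths for all file_num checkpoints.
    """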
    ckpt_list = []
    for rank in range(0, file_num):
        ckpt_name = f"code-13B{rank}_22-{args_opt.load_ckpt_epoch}_2.ckpt"
        local_file = os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}", ckpt_name)
        ckpt_list.append(local_file)
        if rank % rank_num != rank_id:
            continue
        time.sleep(rank * 0.05)
        os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}"))
        if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)):
            print(f"Checkpoint from rank {rank} doesn't exist!")
        mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), local_file)
        print("===download ckpt ok: ", local_file, flush=True)
    return ckpt_list


def get_needed_opt_shard_list(train_strategy_file, self_rank):
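    """Find the source ranks whose optimizer-parallel slices this rank needs.

    For each parameter in the strategy file, slices live on ranks
    group_index + i * opt_weight_shard_step for i in [0, opt_weight_shard_size),
    where group_index = self_rank % opt_weight_shard_step. For example, with
    opt_weight_shard_step=4, opt_weight_shard_size=32 and self_rank=6, the
    needed ranks are [2, 6, 10, ..., 126]. The longest such list is returned.
    """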
    train_strategy_origin = build_searched_strategy(train_strategy_file)
    strategy_keys = list(train_strategy_origin.keys())
    needed_ckpt_ranks = []
    for param_name in strategy_keys:
        opt_weight_shard_size = train_strategy_origin[param_name].opt_weight_shard_size
        opt_weight_shard_step = train_strategy_origin[param_name].opt_weight_shard_step
        if opt_weight_shard_size <= 0:
            continue
        group_index = self_rank % opt_weight_shard_step
        current_needed_ckpt_ranks = [group_index + i * opt_weight_shard_step for i in range(0, opt_weight_shard_size)]
        if len(current_needed_ckpt_ranks) > len(needed_ckpt_ranks):
            needed_ckpt_ranks = current_needed_ckpt_ranks
    return needed_ckpt_ranks


def transform_opt_shard(restore_local_ckpt_file_list, train_strategy_file, save_path):
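    """Merge downloaded opt-shard checkpoint slices into one checkpoint.

    Parameters that are not opt-sharded (absent from the strategy file, or with
    opt_weight_shard_step == 0) are taken from the first checkpoint as-is;
    sharded parameters are reassembled with merge_sliced_parameter. The merged
    checkpoint is saved as predict.ckpt under save_path.
    """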
    # check whether the checkpoint files have been downloaded
    for local_file in restore_local_ckpt_file_list:
        if not os.path.exists(local_file):
            raise ValueError("ckpt not downloaded: " + str(restore_local_ckpt_file_list))
        time.sleep(0.1)
    param_total_dict = defaultdict(dict)
    for file_index, local_file in enumerate(restore_local_ckpt_file_list):
        param_dict = load_checkpoint(local_file)
        for param_name, param in param_dict.items():
            param_total_dict[param_name][file_index] = param
    train_strategy_origin = build_searched_strategy(train_strategy_file)
    strategy_keys = list(train_strategy_origin.keys())
    merged_param_list = []
    for param_name in param_total_dict.keys():
        if param_name not in strategy_keys:
            each_param = {"name": param_name}
            each_param["data"] = param_total_dict[param_name][0]
            print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, flush=True)
            merged_param_list.append(each_param)
            continue
        opt_weight_shard_size = train_strategy_origin[param_name].opt_weight_shard_size
        opt_weight_shard_step = train_strategy_origin[param_name].opt_weight_shard_step
        if opt_weight_shard_step == 0:
            print("====not opt shard:", param_name)
            each_param = {"name": param_name}
            each_param["data"] = param_total_dict[param_name][0]
            print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, flush=True)
            merged_param_list.append(each_param)
            continue
        print("====do opt shard:", param_name)
        sliced_params = [param_total_dict[param_name][i] for i in range(len(param_total_dict[param_name]))]
        merged_param = merge_sliced_parameter(sliced_params, None)
        each_param = {"name": param_name}
        each_param["data"] = merged_param
        print("====", param_name, merged_param.data.asnumpy().shape, flush=True)
        merged_param_list.append(each_param)
    save_file = os.path.join(save_path, "predict.ckpt")
    save_checkpoint(merged_param_list, save_file)
    return save_file


def run_transform_opt_shard_ckpt(args_opt):
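    """Entry point: download, merge and re-upload the checkpoint for this rank.

    Note: the number of original checkpoint files is hard-coded to 128 in the
    download_ckpt call below.
    """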
    # Set execution mode
    context.set_context(
        mode=context.GRAPH_MODE, device_target=args_opt.device_target
    )
    # Set parallel context
    rank = 0
    device_num = 1
    if args_opt.distribute == "true":
        rank, device_num = set_parallel_context(args_opt)
    print("=====rank is: ", rank, flush=True)
    ckpt_file_list = download_ckpt(args_opt, 128, device_num, rank)
    needed_ckpt_ranks = get_needed_opt_shard_list(args_opt.strategy_load_ckpt_path, rank)
    restore_local_ckpt_file_list = [ckpt_file_list[i] for i in needed_ckpt_ranks]
    print("====restore_local_ckpt_file_list====", restore_local_ckpt_file_list, flush=True)
    save_path = os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")
    os.mkdir(save_path)
    save_file = transform_opt_shard(restore_local_ckpt_file_list, args_opt.strategy_load_ckpt_path, save_path)
    obs_save_path = args_opt.save_checkpoint_obs_path
    time.sleep(rank * 0.1)
    if not mox.file.exists(obs_save_path):
        mox.file.make_dirs(obs_save_path)
    rank_obs_save_path = os.path.join(obs_save_path, f"rank_{rank}")
    if not mox.file.exists(rank_obs_save_path):
        mox.file.make_dirs(rank_obs_save_path)
    rank_obs_save_file = os.path.join(rank_obs_save_path, f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt")
    if not os.path.exists(save_file):
        raise ValueError(save_file + " does not exist")
    mox.file.copy(save_file, rank_obs_save_file)
    print("=====save ok, save_path", save_path)


if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
    run_transform_opt_shard_ckpt(opt)
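
# Example launch (hypothetical flag names, assumed to mirror the attribute names
# returned by get_args; adjust to the actual argument parser in src/utils.py):
#
#   python path/to/this_script.py \
#       --distribute=true \
#       --device_target=Ascend \
#       --strategy_load_ckpt_path=/path/to/strategy.ckpt \
#       --load_ckpt_path=obs://bucket/ckpt/ \
#       --load_ckpt_epoch=1 \
#       --save_checkpoint_path=/cache/ckpt/ \
#       --save_checkpoint_obs_path=obs://bucket/merged_ckpt/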