CodeGeeX/codegeex/mindspore/save_1p_ckpt_from_8p_ckpt.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Merge an 8-way model-parallel CodeGeeX (PanGu-Alpha architecture) checkpoint
into a single checkpoint that can be loaded on one device for inference.
"""
import datetime
import argparse
import numpy as np
import json
import glob
import os
import math
import time
from pathlib2 import Path
from collections import defaultdict
import moxing as mox
from tensorboardX import SummaryWriter
from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
import mindspore
from mindspore.train.serialization import load_checkpoint, build_searched_strategy, save_checkpoint, \
    merge_sliced_parameter, _convert_to_list, _convert_to_layout
from mindspore.common.parameter import Parameter
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
from mindspore import Tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from src.adam import AdamWeightDecayOp
from src.dataset import create_dataset
from src.pangu_alpha import PanGUAlphaWithLoss, PanguAlphaModel
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data
from src.callbacks import LossCallBack, SaveCheckpointCallback
from mindspore.profiler import Profiler
project_root = os.path.abspath(
    os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)
def set_weight_decay(params):
    """
    Set weight decay coefficient: zero for bias and layernorm, 1e-1 for the rest.
    """
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [
        {"params": decay_params, "weight_decay": 1e-1},
        {"params": other_params, "weight_decay": 0.0},
        {"order_params": params},
    ]
    return group_params
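# Usage sketch (hypothetical, mirroring how the PanGu-Alpha training scripts group parameters;
# FP32StateAdamWeightDecay is imported above but its exact signature is assumed here):
#   group_params = set_weight_decay(network.trainable_params())
#   optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr_schedule, eps=1e-8)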
def add_checkpoint_callback_policy(args_param, callback, rank_id):
    r"""
    Add checkpoint policy to callback.
    """
    if args_param.save_checkpoint:
        # checkpoint stores epoch_num and step_num info
        ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args_param.save_checkpoint_steps,
            keep_checkpoint_max=args_param.keep_checkpoint_max,
            integrated_save=False,
            append_info=ckpt_append_info,
        )
        # save checkpoint into rank directory
        ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                     directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                     config=ckpt_config)
        callback.append(ckpoint_cb)
        saveckpt_cb = SaveCheckpointCallback(cache_dir=args_param.save_checkpoint_path,
                                             bucket=args_param.save_checkpoint_obs_path,
                                             local_rank=rank_id,
                                             has_trained_epoch=args_param.has_trained_epoches,
                                             has_trained_step=args_param.has_trained_steps,
                                             syn_times=args_param.save_checkpoint_steps)
        callback.append(saveckpt_cb)
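# Note: this checkpoint-callback helper is carried over from the training script and is not
# called in the merge flow below. A hypothetical training setup would collect callbacks and
# hand them to Model.train, e.g.:
#   callback = [LossCallBack(...)]
#   add_checkpoint_callback_policy(args_param, callback, rank_id)
#   model.train(epoch_num, train_dataset, callbacks=callback)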
def set_parallel_context(args_opt):
    r"""Set parallel context"""
    D.init()
    device_num = D.get_group_size()
    rank = D.get_rank()
    print("rank_id is {}, device_num is {}".format(rank, device_num))
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
        full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
        enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt',
        optimizer_weight_shard_size=16, optimizer_weight_shard_aggregated_save=True)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank, device_num
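# Minimal usage sketch (assumes the process was launched as a MindSpore distributed job with
# the usual rank environment, e.g. a rank table on Ascend, already configured):
#   rank, device_num = set_parallel_context(args_opt)
# The returned rank decides which checkpoint shards this process downloads below.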
def download_ckpt(args_opt, file_num, rank_num, rank_id):
    """Download the sliced checkpoints of all `file_num` source ranks from OBS to the local cache,
    splitting the download work across the `rank_num` running processes."""
    ckpt_list = []
    for rank in range(0, file_num):
        ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt"
        local_file = os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}", ckpt_name)
        ckpt_list.append(local_file)
        if rank % rank_num != rank_id:
            continue
        time.sleep(rank * 0.05)
        if not os.path.exists(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}")):
            os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}"))
        if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)):
            print(f"Checkpoint from rank {rank} doesn't exist!")
        mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), local_file)
        print("===download ckpt ok: ", local_file, flush=True)
    print(ckpt_list)
    return ckpt_list
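# Expected remote layout, derived from the path construction above:
#   <load_ckpt_path>/rank_<i>/code-13B<i>-<load_ckpt_epoch>.ckpt   for i in 0..file_num-1
# Each shard is copied to <save_checkpoint_path>/origin_rank_<i>/ on the local node, and the
# full list of local paths is returned (including shards downloaded by other processes).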
def get_needed_model_parallel_list(train_strategy_file, self_rank):
    """Infer which source-rank checkpoint slices are needed to rebuild the full parameters."""
    train_strategy_origin = build_searched_strategy(train_strategy_file)
    strategy_keys = list(train_strategy_origin.keys())
    train_strategy = _convert_to_list(train_strategy_origin)
    rank_list = _infer_rank_list(train_strategy, None)
    needed_ckpt_ranks = []
    for param_name in strategy_keys:
        param_needs_rank_list = rank_list[param_name][0]
        if len(param_needs_rank_list) > len(needed_ckpt_ranks):  # strictly, this should be the union of the rank lists
            needed_ckpt_ranks = param_needs_rank_list
    return needed_ckpt_ranks
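# A more defensive variant (a sketch, not what this script does) would take the union of the
# per-parameter rank lists instead of keeping only the longest one:
#   needed_ckpt_ranks = sorted({r for name in strategy_keys for r in rank_list[name][0]})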
def transform_model_parallel(restore_local_ckpt_file_list, train_strategy_file, save_path, using_fp16=False):
    # check whether the ckpt files have been downloaded
    for local_file in restore_local_ckpt_file_list:
        if not os.path.exists(local_file):
            raise ValueError("ckpt not downloaded: ", restore_local_ckpt_file_list)
        time.sleep(0.1)
    param_total_dict = defaultdict(dict)
    for file_index, local_file in enumerate(restore_local_ckpt_file_list):
        param_dict = load_checkpoint(local_file)
        for param_name, param in param_dict.items():
            if "adam" in param_name:
                continue
            print(f"===loading {file_index}: ", param_name, flush=True)
            param_total_dict[param_name][file_index] = param
    print("===load param done.", flush=True)
    train_strategy_origin = build_searched_strategy(train_strategy_file)
    train_strategy = _convert_to_list(train_strategy_origin)
    rank_list = _infer_rank_list(train_strategy, None)
    strategy_keys = list(train_strategy_origin.keys())
    merged_param_list = []
    for param_name in param_total_dict.keys():
        if "adam" in param_name:
            continue
        if param_name not in strategy_keys:
            # parameter is not sliced; take the copy from the first loaded checkpoint
            each_param = {"name": param_name}
            each_param["data"] = param_total_dict[param_name][0]
            print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, flush=True)
            merged_param_list.append(each_param)
            continue
        param_unique_strategy = _remove_repeated_slices(train_strategy[param_name])
        _param_unique_strategy = _convert_to_layout(param_name, param_unique_strategy)
        sliced_params = []
        if using_fp16 and "embedding" not in param_name and "layernorm" not in param_name:
            # cast slices to fp16 before merging (embedding and layernorm stay in fp32)
            for i in rank_list[param_name][0]:
                slice_param = param_total_dict[param_name][i]
                layerwise_parallel = slice_param.layerwise_parallel
                requires_grad = slice_param.requires_grad
                sliced_data = slice_param.data.asnumpy()
                sliced_data = sliced_data.astype(np.float16)
                param_fp16 = Parameter(Tensor(sliced_data), param_name, requires_grad, layerwise_parallel)
                sliced_params.append(param_fp16)
        else:
            sliced_params = [param_total_dict[param_name][i] for i in rank_list[param_name][0]]
        merged_param = merge_sliced_parameter(sliced_params, _param_unique_strategy)
        each_param = {"name": param_name}
        each_param["data"] = merged_param
        print("====", param_name, merged_param.data.asnumpy().shape, flush=True)
        merged_param_list.append(each_param)
    save_file = os.path.join(save_path, "predict_1p.ckpt")
    save_checkpoint(merged_param_list, save_file)
    return save_file
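# The merged predict_1p.ckpt can then be loaded for single-device inference, e.g. (a sketch,
# assuming a network `net` built without model parallelism):
#   param_dict = load_checkpoint(save_file)
#   load_param_into_net(net, param_dict)   # load_param_into_net from mindspore.train.serialization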
def run_transform_model_parallel_ckpt(args_opt):
    """Merge the 8-way sliced checkpoints into a single checkpoint and upload it to OBS."""
    # Set execution mode
    context.set_context(
        mode=context.GRAPH_MODE, device_target=args_opt.device_target
    )
    # Set parallel context
    rank = 0
    device_num = 1
    if args_opt.distribute == "true":
        rank, device_num = set_parallel_context(args_opt)
    print("=====rank is: ", rank, flush=True)
    ckpt_file_list = download_ckpt(args_opt, 8, device_num, rank)
    if rank != 0:
        return
    needed_ckpt_ranks = get_needed_model_parallel_list(args_opt.strategy_load_ckpt_path, rank)
    restore_local_ckpt_file_list = [ckpt_file_list[i] for i in needed_ckpt_ranks]
    print("====restore_local_ckpt_file_list====", restore_local_ckpt_file_list, flush=True)
    save_path = os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    save_file = transform_model_parallel(restore_local_ckpt_file_list, args_opt.strategy_load_ckpt_path, save_path)
    obs_save_path = args_opt.save_checkpoint_obs_path
    time.sleep(rank * 0.1)
    if not mox.file.exists(obs_save_path):
        mox.file.make_dirs(obs_save_path)
    rank_obs_save_path = os.path.join(obs_save_path, f"rank_{rank}")
    if not mox.file.exists(rank_obs_save_path):
        mox.file.make_dirs(rank_obs_save_path)
    rank_obs_save_file = os.path.join(rank_obs_save_path, f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt")
    if not os.path.exists(save_file):
        raise ValueError(save_file + " not exists")
    mox.file.copy(save_file, rank_obs_save_file)
    print("=====save ok, save_path", save_path)
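# Example launch (hypothetical flag names; get_args in src.utils defines the actual CLI):
#   python save_1p_ckpt_from_8p_ckpt.py --distribute true \
#       --strategy_load_ckpt_path ./strategy.ckpt \
#       --load_ckpt_path obs://<bucket>/ckpts --load_ckpt_epoch <epoch> \
#       --save_checkpoint_path /cache/ckpt --save_checkpoint_obs_path obs://<bucket>/merged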
if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
    run_transform_model_parallel_ckpt(opt)