# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General utilities."""
|
|
|
|
import sys
|
|
|
|
import torch
|
|
from torch.nn.parallel import DistributedDataParallel as torchDDP
|
|
|
|
from apex.multi_tensor_apply import multi_tensor_applier
|
|
import amp_C
|
|
|
|
from codegeex.megatron import get_args
|
|
from codegeex.megatron import print_rank_0
|
|
from codegeex.megatron import get_adlr_autoresume
|
|
from codegeex.megatron import mpu
|
|
from codegeex.megatron.model.module import param_is_not_shared
|
|
from codegeex.megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
|
|
|
|
|
|
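# unwrap_model: usage sketch (illustrative; `ddp_model` is a hypothetical
# torchDDP-wrapped module):
#
#     bare = unwrap_model(ddp_model)          # module in, module out
#     bare_list = unwrap_model([ddp_model])   # list in, list out
#
# Extra wrapper classes (e.g. a fp16 module wrapper) can be stripped as well
# by passing them in `module_instances`.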
def unwrap_model(model, module_instances=(torchDDP,)):
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model


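# calc_params_l2_norm: the squared per-rank norms are all-reduced over the
# model-parallel group, so every rank returns the norm of the full sharded
# model. Hypothetical training-loop logging:
#
#     params_norm = calc_params_l2_norm(model)
#     print_rank_0("params l2 norm: {:.4E}".format(params_norm))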
def calc_params_l2_norm(model):
    """Calculate l2 norm of parameters."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if args.bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm.
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False,  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(
        norm_2, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group()
    )
    return norm_2.item() ** 0.5


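# average_losses_across_data_parallel_group: example with a hypothetical
# scalar `lm_loss` from one forward step:
#
#     averaged = average_losses_across_data_parallel_group([lm_loss])
#     # averaged[0] == lm_loss averaged over all data-parallel replicas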
def average_losses_across_data_parallel_group(losses):
    """Average a list of scalar losses across the data-parallel group."""
    averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses, group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group()
    )

    return averaged_losses


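# report_memory: only ranks with data-parallel rank 0 print, i.e. one report
# per model-parallel rank of the first replica. Hypothetical call site:
#
#     report_memory("(after forward)")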
def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + " memory (MB)"
    string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes)
    string += " | max allocated: {}".format(
        torch.cuda.max_memory_allocated() / mega_bytes
    )
    string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes)
    string += " | max reserved: {}".format(
        torch.cuda.max_memory_reserved() / mega_bytes
    )
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)


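# print_params_min_max_norm: `optimizer` is expected to be the Megatron
# optimizer wrapper, hence the `optimizer.optimizer` indirection to reach the
# underlying torch.optim instance. Hypothetical debug call:
#
#     print_params_min_max_norm(optimizer, iteration)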
def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n"
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group["params"]:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format(
                iteration, rank, index, int(param.tensor_model_parallel)
            )
            string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm)
    print(string, flush=True)


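# check_adlr_autoresume_termination: typically called periodically from the
# training loop in Megatron-style trainers, e.g. (flag names assumed):
#
#     if args.adlr_autoresume and iteration % args.adlr_autoresume_interval == 0:
#         check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler)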
def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from codegeex.megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)


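# get_ltor_masks_and_position_ids: worked example with hypothetical token ids
# and eod_token = 0. For data = [[5, 9, 0, 7, 3]] with reset_position_ids=True
# and reset_attention_mask=True, the EOD at index 2 splits the row into two
# documents:
#
#     position_ids   -> [[0, 1, 2, 0, 1]]   # positions restart after the EOD
#     attention_mask -> tokens 3..4 cannot attend to tokens 0..2
#     loss_mask      -> [[1, 1, 0, 1, 1]]   # if eod_mask_loss=True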
def get_ltor_masks_and_position_ids(
    data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss
):
    """Build masks and position ids for left-to-right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(
        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
    ).view(att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Prevent attention across the EOD boundary.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1) :] -= i + 1 - prev_index
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = attention_mask < 0.5

    return attention_mask, loss_mask, position_ids


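# get_parameters_in_billions: parameters managed by DeepSpeed (detected via
# the `ds_id` attribute) are counted with `ds_numel`, plain tensors with
# `nelement()`; multiplying by the model-parallel world size approximates the
# total across tensor/pipeline shards. Hypothetical startup report, where
# `model` is the trainer's list of model chunks:
#
#     print_rank_0("model size: {:.2f}B".format(get_parameters_in_billions(model)))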
def get_parameters_in_billions(model):
    gpus_per_model = torch.distributed.get_world_size(
        group=mpu.get_model_parallel_group()
    )

    approx_parameters_in_billions = sum(
        [
            sum(
                [
                    p.ds_numel if hasattr(p, "ds_id") else p.nelement()
                    for p in model_module.parameters()
                ]
            )
            for model_module in model
        ]
    )

    return approx_parameters_in_billions * gpus_per_model / (1e9)