mirror of https://github.com/THUDM/CodeGeeX.git
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities."""

import os
from datetime import datetime
import math
import sys
import time
import json

# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from codegeex.megatron import get_args
from codegeex.megatron import get_timers
from codegeex.megatron import get_tensorboard_writer
from codegeex.megatron import get_current_global_batch_size
from codegeex.megatron import get_num_microbatches
from codegeex.megatron import is_last_rank
from codegeex.megatron import update_num_microbatches
from codegeex.megatron import mpu
from codegeex.megatron import print_rank_0
from codegeex.megatron import print_rank_last
from codegeex.megatron.checkpointing import load_checkpoint
from codegeex.megatron.checkpointing import save_checkpoint
from codegeex.megatron.model import Float16Module
from codegeex.megatron.optimizer import get_megatron_optimizer
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.initialize import write_args_to_tensorboard
from codegeex.megatron.initialize import initialize_wandb_experiment
from codegeex.megatron.learning_rates import AnnealingLR
from codegeex.megatron.model import DistributedDataParallel as LocalDDP
from codegeex.megatron.utils import check_adlr_autoresume_termination
from codegeex.megatron.utils import unwrap_model
from codegeex.megatron.data.data_samplers import build_pretraining_data_loader
from codegeex.megatron.utils import calc_params_l2_norm
from codegeex.megatron.schedules import forward_backward_no_pipelining
from codegeex.megatron.schedules import forward_backward_pipelining_without_interleaving
from codegeex.megatron.schedules import forward_backward_pipelining_with_interleaving
from codegeex.megatron.utils import report_memory, flops_calculator

import deepspeed

try:
    import wandb
except ImportError:
    wandb = None

from filelock import FileLock
import pathlib

try:
    import bmcook
    from bmcook import Config
except ImportError:
    print("bmcook not imported.")
    bmcook = None

def print_datetime(string):
    """Note that this call will sync across all ranks."""
    torch.distributed.barrier()
    time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print_rank_0("[" + string + "] datetime: {} ".format(time_str))

def compress_setup(args, model, optimizer):
    # NOTE: this helper depends on bmcook compression utilities (ConfigParser,
    # CPMAntTrainer, Linear) that are not imported in this module; it is only
    # usable when those symbols are supplied by the bmcook integration.
    teacher = get_model(args)
    cook_config = ConfigParser(args.cook_config)
    CPMAntTrainer.set_compression(
        cook_config,
        model,
        optimizer,
        teacher=teacher,
        remove_ckptblock=False,
        target_linear=Linear,
    )

def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    forward_step_func,
    extra_args_provider=None,
    args_defaults={},
):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(
        extra_args_provider=extra_args_provider, args_defaults=args_defaults
    )

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0(
        "time to initialize megatron (seconds): {:.3f}".format(
            time.time() - _TRAIN_START_TIME
        )
    )
    print_datetime("after megatron is initialized")

    args = get_args()
    timers = get_timers()

    if args.local_rank == 0 and args.save is not None:
        print(f"Creating output dir ...")
        os.makedirs(args.save, exist_ok=True)

    if args.deepspeed:
        args.deepspeed_configuration = json.load(
            open(args.deepspeed_config, "r", encoding="utf-8")
        )

    # Model, optimizer, and learning rate.
    timers("model-and-optimizer-setup").start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers("model-and-optimizer-setup").stop()
    print_datetime("after model, optimizer, and learning rate scheduler are built")

    # Data stuff.
    timers("train/valid/test-data-iterators-setup").start()
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [
            data_iterators[0] for data_iterators in all_data_iterators
        ]
        valid_data_iterator = [
            data_iterators[1] for data_iterators in all_data_iterators
        ]
        test_data_iterator = [
            data_iterators[2] for data_iterators in all_data_iterators
        ]
    else:
        (
            train_data_iterator,
            valid_data_iterator,
            test_data_iterator,
        ) = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
    timers("train/valid/test-data-iterators-setup").stop()
    print_datetime("after dataloaders are built")

    # Print setup timing.
    print_rank_0("done with setup ...")
    timers.log(["model-and-optimizer-setup", "train/valid/test-data-iterators-setup"])
    print_rank_0("training ...")

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(
            forward_step_func,
            model,
            optimizer,
            lr_scheduler,
            train_data_iterator,
            valid_data_iterator,
        )
    print_datetime("after training is done")

    if args.do_valid:
        prefix = "the end of training for val data"
        if args.co_evaluation:
            for key, value in valid_data_iterator.items():
                evaluate_and_print_results(
                    prefix, forward_step_func, value, model, iteration, False, tag=key
                )
        else:
            evaluate_and_print_results(
                prefix, forward_step_func, valid_data_iterator, model, iteration, False
            )

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = "the end of training for test data"
        if args.co_evaluation:
            for key, value in test_data_iterator.items():
                evaluate_and_print_results(
                    prefix, forward_step_func, value, model, 0, True, tag=key
                )
        else:
            evaluate_and_print_results(
                prefix, forward_step_func, test_data_iterator, model, 0, True
            )

    if args.wandb_logging and is_last_rank():
        wandb.finish()

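# Illustrative sketch only (not called anywhere in this file): the rough shape
# of an entry-point script that drives pretrain(), in the style of Megatron's
# pretrain_gpt.py. The provider bodies and the args_defaults value below are
# hypothetical placeholders, not the actual CodeGeeX providers.
def _example_pretrain_entry_point():
    def model_provider(pre_process=True, post_process=True):
        # Return the bare model (CPU/fp32); fp16 and DDP wrapping happen later
        # in setup_model_and_optimizer().
        raise NotImplementedError("placeholder for a real model provider")

    def train_valid_test_datasets_provider(train_val_test_num_samples):
        # Return (train_ds, valid_ds, test_ds) covering the requested sizes.
        raise NotImplementedError("placeholder for a real dataset provider")

    def forward_step(data_iterator, model):
        # Expected to time "batch-generator", run the model, and return
        # (loss, {"lm loss": averaged_loss}).
        raise NotImplementedError("placeholder for a real forward step")

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
    )
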
def update_train_iters(args):

    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    # Constant batch size with sample-based training.
    if args.rampup_batch_size is None:
        args.train_iters = args.train_samples // args.global_batch_size

    else:
        # Sample based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase.
        while consumed_samples <= int(args.rampup_batch_size[2]):
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset
        update_num_microbatches(0, consistency_check=False)
        # Constant phase
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // args.global_batch_size
        args.train_iters = iterations

    print_rank_0("setting training iterations to {}".format(args.train_iters))

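# Worked example for the no-rampup branch above (hypothetical values): with
# --train-samples 1_000_000 and --global-batch-size 512, train_iters becomes
# 1_000_000 // 512 = 1953 iterations; the partial final batch is discarded.
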
def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model.
    if (
        mpu.get_pipeline_model_parallel_world_size() > 1
        and args.virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process, post_process=post_process
            )
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        model = model_provider_func(pre_process=pre_process, post_process=post_process)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(
            " > number of parameters on (tensor, pipeline) "
            "model parallel rank ({}, {}): {}".format(
                mpu.get_tensor_model_parallel_rank(),
                mpu.get_pipeline_model_parallel_rank(),
                sum(
                    [
                        sum(
                            [
                                p.ds_numel if hasattr(p, "ds_id") else p.nelement()
                                for p in model_module.parameters()
                            ]
                        )
                        for model_module in model
                    ]
                ),
            ),
            flush=True,
        )

    if args.deepspeed:
        return model

    # GPU allocation.
    print(f" > moving model to GPU ...", flush=True)
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())
    print(f" > moving to GPU done", flush=True)

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        print(f" > converting model to fp16 ...", flush=True)
        model = [Float16Module(model_module, args) for model_module in model]
        print(f" > converting to fp16 done", flush=True)

    if args.DDP_impl == "torch":
        i = torch.cuda.current_device()
        model = [
            torchDDP(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=mpu.get_data_parallel_group(),
            )
            for model_module in model
        ]
        return model

    if args.DDP_impl == "local":
        print(f" > creating DDP model ...", flush=True)
        model = [
            LocalDDP(
                model_module,
                args.accumulate_allreduce_grads_in_fp32,
                args.use_contiguous_buffers_in_ddp,
            )
            for model_module in model
        ]
        print(f" > creating DDP model done", flush=True)
        return model

    raise NotImplementedError(
        "Unknown DDP implementation specified: {}. Exiting.".format(args.DDP_impl)
    )

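# Note on wrapper order (useful when unwrapping below): outside of DeepSpeed,
# and with --fp16/--bf16 plus --DDP-impl local, get_model() returns a list of
# models wrapped roughly as LocalDDP(Float16Module(raw_model)) (or torchDDP(...)
# for --DDP-impl torch), which is why callers unwrap with
# (torchDDP, LocalDDP, Float16Module).
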
def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler."""
    args = get_args()

    # Iteration-based training.
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        raise Exception("either train-iters or train-samples should be provided.")

    lr_scheduler = AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler,
    )

    return lr_scheduler

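# Worked example of the step accounting above (hypothetical flag values): the
# scheduler counts progress in samples, which is why the iteration-based branch
# multiplies by the global batch size:
#     train_iters = 100_000, global_batch_size = 512, lr_warmup_fraction = 0.01
#     decay_steps  = 100_000 * 512     = 51_200_000 samples
#     warmup_steps = 0.01 * 51_200_000 = 512_000 samples
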
def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)
    unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")
        pp = mpu.get_pipeline_model_parallel_world_size()
        print_rank_0(pp)

        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model[0],
            optimizer=optimizer,
            args=args,
            lr_scheduler=lr_scheduler,
            mpu=mpu if args.no_pipeline_parallel else None,
        )
        print_rank_0("FinishInitialization.")
        if isinstance(model, deepspeed.PipelineEngine):
            # hack to get batch_fn from pretrain_gpt.py
            print_rank_0("InstancePipelineEngine.")
            model.set_batch_fn(model.module._megatron_batch_fn)

            assert (
                model.grid.get_pipe_parallel_rank()
                == mpu.get_pipeline_model_parallel_rank()
            )
            assert (
                model.grid.get_slice_parallel_rank()
                == mpu.get_tensor_model_parallel_rank()
            )
            assert model.grid.get_data_parallel_rank() == mpu.get_data_parallel_rank()
        model = [model]
        print_rank_0("Finishparallel.")

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers("load-checkpoint").start()
        if args.low_memory_load:
            load_start = time.perf_counter()
            with FileLock(
                os.path.join(pathlib.Path.home(), "checkpoint_lock"), timeout=-1
            ):
                this_rank_load_start = time.perf_counter()
                print(f"Rank {args.rank} is loading checkpoint ...")
                args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
                this_rank_load_time = time.perf_counter() - this_rank_load_start
                load_time = time.perf_counter() - load_start
                print(f"Rank {args.rank} loaded checkpoint, this rank time: {this_rank_load_time}, total time: {load_time}")
        else:
            args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        print(f"Rank {args.rank} loaded checkpoint and waiting for other ranks")
        torch.distributed.barrier()
        timers("load-checkpoint").stop()
        timers.log(["load-checkpoint"])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == "local"

    # get model without FP16 and/or TorchDDP wrappers
    if (
        args.iteration == 0
        and len(unwrapped_model) == 1
        and hasattr(unwrapped_model[0], "init_state_dict_from_bert")
    ):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler

def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    if args.deepspeed and args.ds_pipeline_enabled:
        skipped_iter = 0
        num_zeros_in_grad = 0
        assert isinstance(model[0], deepspeed.PipelineEngine)
        loss = model[0].train_batch(data_iter=data_iterator)
        grad_norm = model[0].get_global_grad_norm()
        return {"lm loss": loss}, skipped_iter, grad_norm, num_zeros_in_grad

    # Set grad to zero.
    if not args.deepspeed:
        if args.DDP_impl == "local" and args.use_contiguous_buffers_in_ddp:
            for partition in model:
                partition.zero_grad_buffer()
        else:
            optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            # print_rank_0("===> fb_func = w/ interleaving")
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, (
                "number of microbatches is not divisible by pipeline-parallel "
                "size when using interleaved schedule"
            )
        else:
            # print_rank_0("===> fb_func = w/o interleaving")
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        # print_rank_0("===> fb_func = no_pp")
        forward_backward_func = forward_backward_no_pipelining
    # print_rank_0("===> running fb_func")
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model, optimizer, timers, forward_only=False
    )

    # All-reduce if needed.
    if not args.deepspeed and args.DDP_impl == "local":
        timers("backward-params-all-reduce").start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers("backward-params-all-reduce").stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    if not args.deepspeed:
        timers("backward-embedding-all-reduce").start()
        if (
            mpu.is_pipeline_first_stage(ignore_virtual=True)
            or mpu.is_pipeline_last_stage(ignore_virtual=True)
        ) and mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = model[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = model[-1]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module)
            )

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                if args.DDP_impl == "local":
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
        timers("backward-embedding-all-reduce").stop()

    # Update parameters.
    timers("optimizer").start()
    # print_rank_0("===> start of update params")
    if args.deepspeed:
        increment = (
            get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
        )
        model[0].step(lr_kwargs={"increment": increment})
        update_successful = model[0].was_step_applied()
    else:
        update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    # print_rank_0("===> end of update params")
    timers("optimizer").stop()

    # Update learning rate.
    if args.deepspeed:
        skipped_iter = 0
        grad_norm = None
        num_zeros_in_grad = None

        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(
                losses_reduced_for_key
            )
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    else:
        if update_successful:
            increment = (
                get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
            )
            lr_scheduler.step(increment=increment)
            skipped_iter = 0
        else:
            skipped_iter = 1

        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Average loss across microbatches.
            loss_reduced = {}
            for key in losses_reduced[0]:
                losses_reduced_for_key = [x[key] for x in losses_reduced]
                loss_reduced[key] = sum(losses_reduced_for_key) / len(
                    losses_reduced_for_key
                )
            return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

def training_log(
    loss_dict,
    total_loss_dict,
    learning_rate,
    iteration,
    loss_scale,
    report_memory_flag,
    skipped_iter,
    grad_norm,
    params_norm,
    num_zeros_in_grad,
    model=None,
):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = "advanced iterations"
    skipped_iters_key = "skipped iterations"
    nan_iters_key = "nan iterations"
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = (
            total_loss_dict.get(advanced_iters_key, 0) + 1
        )
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = (
        total_loss_dict.get(skipped_iters_key, 0) + skipped_iter
    )
    # Update losses and set nan iterations
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = (
                total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
            )
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float("inf") or value == -float("inf") or value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int(
        got_nan
    )

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)

    add_to_logging("forward-compute")
    add_to_logging("forward-recv")
    add_to_logging("forward-send")
    add_to_logging("forward-backward-send-forward-backward-recv")
    add_to_logging("backward-compute")
    add_to_logging("backward-recv")
    add_to_logging("backward-send")
    add_to_logging("backward-send-forward-recv")
    add_to_logging("backward-send-backward-recv")
    add_to_logging("backward-params-all-reduce")
    add_to_logging("backward-embedding-all-reduce")
    add_to_logging("optimizer-copy-to-main-grad")
    add_to_logging("optimizer-unscale-and-check-inf")
    add_to_logging("optimizer-clip-main-grad")
    add_to_logging("optimizer-copy-main-to-model-params")
    add_to_logging("optimizer")
    add_to_logging("batch-generator")

    # Calculate batch size.
    batch_size = (
        args.micro_batch_size * args.data_parallel_size * get_num_microbatches()
    )

    total_iterations = (
        total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key]
    )

    # wandb logging.
    if (
        args.wandb_logging
        and (iteration % args.wandb_log_interval == 0)
        and is_last_rank()
    ):
        wandb.log(
            {
                "train/tokens": args.consumed_train_tokens,
                "train/lr": learning_rate,
            },
            step=iteration,
        )

        for k, v in loss_dict.items():
            wandb.log({f"train/{k}": v}, step=iteration)

        for k in timers_to_log:
            value = timers.timers[k].elapsed(reset=False)
            wandb.log({f"timer/{k}": value}, step=iteration)

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and is_last_rank():
        writer.add_scalar(
            "steps-vs-samples/y=steps,x=samples", iteration, args.consumed_train_samples
        )
        writer.add_scalar(
            "steps-vs-samples/y=samples,x=steps", args.consumed_train_samples, iteration
        )
        writer.add_scalar(
            "steps-vs-tokens/y=steps,x=tokens", iteration, args.consumed_train_tokens
        )
        writer.add_scalar(
            "steps-vs-tokens/y=tokens,x=steps", args.consumed_train_tokens, iteration
        )
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar("learning-rate/learning-rate", learning_rate, iteration)
            writer.add_scalar(
                "learning-rate/learning-rate vs samples",
                learning_rate,
                args.consumed_train_samples,
            )
            writer.add_scalar(
                "learning-rate/learning-rate vs tokens",
                learning_rate,
                args.consumed_train_tokens,
            )
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar("batch-size/batch-size", batch_size, iteration)
            writer.add_scalar(
                "batch-size/batch-size vs samples",
                batch_size,
                args.consumed_train_samples,
            )
        for key in loss_dict:
            writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration)
            # writer.add_scalar(
            #     f"lm-loss-training/{key}" + " vs samples",
            #     loss_dict[key],
            #     args.consumed_train_samples,
            # )
            # writer.add_scalar(
            #     f"lm-loss-training/{key}" + " vs tokens",
            #     loss_dict[key],
            #     args.consumed_train_tokens,
            # )
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar("loss-scale/loss-scale", loss_scale, iteration)
            writer.add_scalar(
                "loss-scale/loss-scale vs samples",
                loss_scale,
                args.consumed_train_samples,
            )
            writer.add_scalar(
                "loss-scale/loss-scale vs tokens",
                loss_scale,
                args.consumed_train_tokens,
            )
        if grad_norm is not None:
            writer.add_scalar("grad-norm/grad-norm", grad_norm, iteration)
            writer.add_scalar(
                "grad-norm/grad-norm vs samples", grad_norm, args.consumed_train_samples
            )
            writer.add_scalar(
                "grad-norm/grad-norm vs tokens", grad_norm, args.consumed_train_tokens
            )
        if num_zeros_in_grad is not None:
            writer.add_scalar("num-zeros/num-zeros", num_zeros_in_grad, iteration)
            writer.add_scalar(
                "num-zeros/num-zeros vs samples",
                num_zeros_in_grad,
                args.consumed_train_samples,
            )
            writer.add_scalar(
                "num-zeros/num-zeros vs tokens",
                num_zeros_in_grad,
                args.consumed_train_tokens,
            )
        if params_norm is not None:
            writer.add_scalar("params-norm/params-norm", params_norm, iteration)
            writer.add_scalar(
                "params-norm/params-norm vs samples",
                params_norm,
                args.consumed_train_samples,
            )
            writer.add_scalar(
                "params-norm/params-norm vs tokens",
                params_norm,
                args.consumed_train_tokens,
            )
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration, normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers("interval-time").elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations

        # log iteration time to wandb
        if args.wandb_logging and is_last_rank():
            wandb.log(
                {
                    "train/iteration-time": elapsed_time_per_iteration,
                },
                step=iteration,
            )

        # only the last rank process has a non-None _GLOBAL_TENSORBOARD_WRITER
        if writer and is_last_rank():
            if args.log_timers_to_tensorboard:
                writer.add_scalar(
                    "iteration-time/iteration-time",
                    elapsed_time_per_iteration,
                    iteration,
                )
                writer.add_scalar(
                    "iteration-time/iteration-time vs samples",
                    elapsed_time_per_iteration,
                    args.consumed_train_samples,
                )
                writer.add_scalar(
                    "iteration-time/iteration-time vs tokens",
                    elapsed_time_per_iteration,
                    args.consumed_train_tokens,
                )
        log_string = "==> iteration {:8d}/{:8d} |".format(iteration, args.train_iters)
        log_string += " consumed samples: {:12d} |".format(args.consumed_train_samples)
        log_string += " consumed tokens: {:12d} |".format(args.consumed_train_tokens)
        log_string += " elapsed time per iteration (ms): {:.1f} |".format(
            elapsed_time_per_iteration * 1000.0
        )
        log_string += " learning rate: {:.3E} |".format(learning_rate)
        log_string += " global batch size: {:5d} |".format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]:
                avg = total_loss_dict[key].item() / float(
                    max(1, total_loss_dict[advanced_iters_key])
                )
                if avg > 0.0:
                    log_string += " {}: {:.6E} |".format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += " loss scale: {:.1f} |".format(loss_scale)
        if grad_norm is not None:
            log_string += " grad norm: {:.3f} |".format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += " num zeros: {:.1f} |".format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += " params norm: {:.3f} |".format(params_norm)
        log_string += " number of skipped iterations: {:3d} |".format(
            total_loss_dict[skipped_iters_key]
        )
        log_string += " number of nan iterations: {:3d} |".format(
            total_loss_dict[nan_iters_key]
        )
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.0:
            # Report memory after optimizer state has been initialized.
            report_memory("(after {} iterations)".format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)
        flops_calculator(model, args, elapsed_time)

    return report_memory_flag

def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    timers = get_timers()
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers("save-checkpoint").start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers("save-checkpoint").stop()
    timers.log(["save-checkpoint"])

def train(
    forward_step_func,
    model,
    optimizer,
    lr_scheduler,
    train_data_iterator,
    valid_data_iterator,
):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    if args.wandb_logging:
        torch.distributed.barrier()
        print_datetime("before the initialization of wandb")
        timers("wandb-init").start()
        if is_last_rank():
            initialize_wandb_experiment()
        torch.distributed.barrier()
        timers("wandb-init").stop()
        timers.log(["wandb-init"])

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers("interval-time").start()
    print_datetime("before the start of training step")
    report_memory_flag = True

    while iteration < args.train_iters and (
        args.train_tokens is None or args.consumed_train_tokens < args.train_tokens
    ):
        # print_rank_0(f'=> iteration {iteration}')
        update_num_microbatches(args.consumed_train_samples)
        if args.deepspeed:
            # inform deepspeed of any batch size changes
            global_batch_size = (
                mpu.get_data_parallel_world_size()
                * args.micro_batch_size
                * get_num_microbatches()
            )
            model[0].set_train_batch_size(global_batch_size)

        # print_rank_0(f"==> running train step for iteration {iteration}")
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
            forward_step_func, train_data_iterator, model, optimizer, lr_scheduler
        )
        iteration += 1
        args.iteration = iteration
        new_samples = (
            mpu.get_data_parallel_world_size()
            * args.micro_batch_size
            * get_num_microbatches()
        )
        args.consumed_train_samples += new_samples
        args.consumed_train_tokens += new_samples * args.seq_length

        # Logging.
        if args.deepspeed:
            loss_scale = model[0].optimizer.cur_scale
        else:
            loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(
            loss_dict,
            total_loss_dict,
            optimizer.param_groups[0]["lr"],
            iteration,
            loss_scale,
            report_memory_flag,
            skipped_iter,
            grad_norm,
            params_norm,
            num_zeros_in_grad,
            model,
        )

        # Autoresume
        if args.adlr_autoresume and (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
            prefix = "iteration {}".format(iteration)
            if args.co_evaluation:
                for key, value in valid_data_iterator.items():
                    evaluate_and_print_results(
                        prefix, forward_step_func, value, model, iteration, False, tag=key
                    )
            else:
                evaluate_and_print_results(
                    prefix, forward_step_func, valid_data_iterator, model, iteration, False
                )

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and (iteration % args.save_interval == 0):  # debugging
            save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor([train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler)
                print_datetime("exiting program after {} minutes".format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler)
            torch.distributed.barrier()
            print_datetime("exiting program at iteration {}".format(iteration))
            sys.exit()

    return iteration

def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0("Evaluating iter {}/{}".format(iteration, args.eval_iters))

            if mpu.get_pipeline_model_parallel_world_size() > 1:
                if args.virtual_pipeline_model_parallel_size is not None:
                    forward_backward_func = (
                        forward_backward_pipelining_with_interleaving
                    )
                else:
                    forward_backward_func = (
                        forward_backward_pipelining_without_interleaving
                    )
            else:
                forward_backward_func = forward_backward_no_pipelining

            if args.deepspeed and not args.no_pipeline_parallel:
                # DeepSpeed uses eval_batch() and already aggregates losses.
                assert isinstance(model, list) and len(model) == 1
                loss = model[0].eval_batch(data_iterator)
                loss_dicts = [{"lm loss": loss}] * get_num_microbatches()
            else:
                loss_dicts = forward_backward_func(
                    forward_step_func,
                    data_iterator,
                    model,
                    optimizer=None,
                    timers=None,
                    forward_only=True,
                )

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = (
                            total_loss_dict.get(key, torch.cuda.FloatTensor([0.0]))
                            + loss_dict[key]
                        )

            args.consumed_valid_samples += (
                mpu.get_data_parallel_world_size()
                * args.micro_batch_size
                * get_num_microbatches()
            )
    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

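# Example of the normalization above (hypothetical numbers): with eval_iters=10
# and 4 microbatches per evaluation step, each accumulated loss is divided by
# 10 * 4 = 40, i.e. it becomes the mean per-microbatch loss over the run.
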
def evaluate_and_print_results(
    prefix, forward_step_func, data_iterator, model, iteration, verbose=False, tag=None
):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    if tag is None:
        string = " validation loss at {} | ".format(prefix)
    else:
        string = " validation loss for {} at {} | ".format(tag, prefix)
    for key in total_loss_dict:
        string += "{} value: {:.6E} | ".format(key, total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += "{} PPL: {:.6E} | ".format(key, ppl)

        if tag is not None:
            display_key = tag + "-" + key
        else:
            display_key = key

        if args.wandb_logging and is_last_rank():
            wandb.log(
                {
                    f"eval/{display_key}": total_loss_dict[key].item(),
                },
                step=iteration,
            )

        if writer and is_last_rank():
            writer.add_scalar(
                f"lm-loss-validation/{display_key} validation",
                total_loss_dict[key].item(),
                iteration,
            )
            # writer.add_scalar(
            #     f"lm-loss-validation/{display_key} validation vs samples",
            #     total_loss_dict[key].item(),
            #     args.consumed_train_samples,
            # )
            # writer.add_scalar(
            #     f"lm-loss-validation/{display_key} validation vs tokens",
            #     total_loss_dict[key].item(),
            #     args.consumed_train_tokens,
            # )
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar(
                    f"lm-loss-validation/{display_key} validation ppl", ppl, iteration
                )
                writer.add_scalar(
                    f"lm-loss-validation/{display_key} validation ppl vs samples",
                    ppl,
                    args.consumed_train_samples,
                )
                writer.add_scalar(
                    f"lm-loss-validation/{display_key} validation ppl vs tokens",
                    ppl,
                    args.consumed_train_tokens,
                )

    length = len(string) + 1
    print_rank_last("-" * length)
    print_rank_last(string)
    print_rank_last("-" * length)

def evaluate_and_print_results_gold(
    prefix, forward_step_func, data_iterator, model, iteration, verbose=False, tag=None
):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    if tag is None:
        string = " validation loss (gold) at {} | ".format(prefix)
    else:
        string = " validation loss (gold) for {} at {} | ".format(tag, prefix)
    for key in total_loss_dict:
        string += "{} value: {:.6E} | ".format(key, total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += "{} PPL: {:.6E} | ".format(key, ppl)

        if tag is not None:
            display_key = tag + "-" + key
        else:
            display_key = key

        if args.wandb_logging and is_last_rank():
            wandb.log(
                {
                    f"eval/{display_key}": total_loss_dict[key].item(),
                },
                step=iteration,
            )

        if writer and is_last_rank():
            writer.add_scalar(
                f"lm-loss-validation-gold/{display_key} validation",
                total_loss_dict[key].item(),
                iteration,
            )
            # writer.add_scalar(
            #     f"lm-loss-validation/{display_key} validation vs samples",
            #     total_loss_dict[key].item(),
            #     args.consumed_train_samples,
            # )
            # writer.add_scalar(
            #     f"lm-loss-validation/{display_key} validation vs tokens",
            #     total_loss_dict[key].item(),
            #     args.consumed_train_tokens,
            # )
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar(
                    f"lm-loss-validation/{display_key} validation ppl", ppl, iteration
                )
                writer.add_scalar(
                    f"lm-loss-validation/{display_key} validation ppl vs samples",
                    ppl,
                    args.consumed_train_samples,
                )
                writer.add_scalar(
                    f"lm-loss-validation/{display_key} validation ppl vs tokens",
                    ppl,
                    args.consumed_train_tokens,
                )

    length = len(string) + 1
    print_rank_last("-" * length)
    print_rank_last(string)
    print_rank_last("-" * length)

def cyclic_iter(iter):
    while True:
        for x in iter:
            yield x

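# Example: cyclic_iter([1, 2]) yields 1, 2, 1, 2, ... indefinitely; it is used
# below so that a "cyclic" dataloader never raises StopIteration mid-training.
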
def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provider):
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0("> building train, validation, and test datasets ...")

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert (
            args.train_samples is None
        ), "only backward compatibility support for iteration-based training"
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert (
            args.train_samples is None
        ), "only backward compatibility support for iteration-based training"
        args.consumed_valid_samples = (
            (args.iteration // args.eval_interval)
            * args.eval_iters
            * args.global_batch_size
        )

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [
            train_samples,
            eval_iters * args.global_batch_size,
            test_iters * args.global_batch_size,
        ]
        print_rank_0(" > datasets target sizes (minimum size):")
        print_rank_0("    train:      {}".format(train_val_test_num_samples[0]))
        print_rank_0("    validation: {}".format(train_val_test_num_samples[1]))
        print_rank_0("    test:       {}".format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples
        )

        # Build dataloaders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples
        )
        if args.co_evaluation:
            valid_dataloader = {}
            for key, value in valid_ds.items():
                valid_dataloader[key] = build_pretraining_data_loader(
                    value, args.consumed_valid_samples
                )
        else:
            valid_dataloader = build_pretraining_data_loader(
                valid_ds, args.consumed_valid_samples
            )
        if args.co_evaluation:
            if test_ds is not None:
                test_dataloader = {}
                for key, value in test_ds.items():
                    test_dataloader[key] = build_pretraining_data_loader(value, 0)
            else:
                test_dataloader = None
        else:
            test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(
        flags,
        mpu.get_tensor_model_parallel_src_rank(),
        group=mpu.get_tensor_model_parallel_group(),
    )
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ["single", "cyclic"]

    if train_dataloader is not None:
        train_data_iterator = (
            iter(train_dataloader)
            if dl_type == "single"
            else iter(cyclic_iter(train_dataloader))
        )
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        if args.co_evaluation:
            valid_data_iterator = {}
            for key, value in valid_dataloader.items():
                valid_data_iterator[key] = (
                    iter(value)
                    if dl_type == "single"
                    else iter(cyclic_iter(value))
                )
        else:
            valid_data_iterator = (
                iter(valid_dataloader)
                if dl_type == "single"
                else iter(cyclic_iter(valid_dataloader))
            )
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        if args.co_evaluation:
            test_data_iterator = {}
            for key, value in test_dataloader.items():
                test_data_iterator[key] = (
                    iter(value)
                    if dl_type == "single"
                    else iter(cyclic_iter(value))
                )
        else:
            test_data_iterator = (
                iter(test_dataloader)
                if dl_type == "single"
                else iter(cyclic_iter(test_dataloader))
            )
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator