# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron arguments."""
import argparse
import os
import torch
import deepspeed
def parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(
description="Megatron-LM Arguments", allow_abbrev=False
)
# Standard arguments.
parser = _add_network_size_args(parser)
parser = _add_regularization_args(parser)
parser = _add_training_args(parser)
parser = _add_initialization_args(parser)
parser = _add_learning_rate_args(parser)
parser = _add_checkpointing_args(parser)
parser = _add_mixed_precision_args(parser)
parser = _add_distributed_args(parser)
parser = _add_validation_args(parser)
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_logging_args(parser)
parser = _add_zero_args(parser)
parser = _add_memoryopt_args(parser)
parser = _add_activation_checkpoint_args(parser)
parser = _add_inference_args(parser)
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
parser = deepspeed.add_config_arguments(parser)
# Parse.
if ignore_unknown_args:
args, _ = parser.parse_known_args()
else:
args = parser.parse_args()
# Helper flag indicating whether the DeepSpeed pipeline engine is enabled.
args.ds_pipeline_enabled = not args.no_pipeline_parallel
# Distributed args.
args.rank = int(os.getenv("RANK", "0"))
args.world_size = int(os.getenv("WORLD_SIZE", "1"))
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size
)
assert (
args.world_size % args.tensor_model_parallel_size == 0
), "world size" " ({}) is not divisible by tensor model parallel size ({})".format(
args.world_size, args.tensor_model_parallel_size
)
# Pipeline model parallel size.
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size),
)
# Checks.
if args.no_pipeline_parallel:
assert (
args.pipeline_model_parallel_size == 1
), "pipeline_model_parallel_size must be 1 if pipeline parallel is disabled"
model_parallel_size = (
args.pipeline_model_parallel_size * args.tensor_model_parallel_size
)
assert args.world_size % model_parallel_size == 0, (
"world size ({}) is not"
" divisible by tensor parallel size ({}) times pipeline parallel "
"size ({})".format(
args.world_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
)
)
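# The remaining ranks form the data-parallel group, e.g. world_size=16 with
# tensor_model_parallel_size=2 and pipeline_model_parallel_size=2 gives
# data_parallel_size = 16 // (2 * 2) = 4.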
args.data_parallel_size = args.world_size // model_parallel_size
if args.rank == 0:
print(
"using world size: {}, data-parallel-size: {}, "
"tensor-model-parallel size: {}, "
"pipeline-model-parallel size: {} ".format(
args.world_size,
args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
),
flush=True,
)
# Deprecated arguments
assert args.batch_size is None, (
"--batch-size argument is no longer " "valid, use --micro-batch-size instead"
)
del args.batch_size
assert args.warmup is None, (
"--warmup argument is no longer valid, use " "--lr-warmup-fraction instead"
)
del args.warmup
assert args.model_parallel_size is None, (
"--model-parallel-size is no "
"longer valid, use --tensor-model-parallel-size instead"
)
del args.model_parallel_size
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.force_default:
print(
"WARNING: overriding arguments for {key}:{v2} \
with default {key}:{v}".format(
key=key, v=defaults[key], v2=getattr(args, key)
),
flush=True,
)
setattr(args, key, defaults[key])
else:
if args.rank == 0:
print(
"WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}".format(
key=key, v=defaults[key], v2=getattr(args, key)
),
flush=True,
)
else:
setattr(args, key, defaults[key])
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
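# With no batch-size rampup this corresponds to a single micro-batch per
# data-parallel rank per step (see the --global-batch-size help below).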
if args.rank == 0:
print(
"setting global batch size to {}".format(args.global_batch_size),
flush=True,
)
assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.pipeline_model_parallel_size > 2, (
"pipeline-model-parallel size should be greater than 2 with "
"interleaved schedule"
)
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, (
"number of layers is not divisible by number of layers per virtual "
"pipeline stage"
)
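# Example: num_layers=24, pipeline_model_parallel_size=4 and
# num_layers_per_virtual_pipeline_stage=2 give (24 // 4) // 2 = 3 virtual
# pipeline stages per pipeline rank (interleaved schedule).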
args.virtual_pipeline_model_parallel_size = (
args.num_layers // args.pipeline_model_parallel_size
) // args.num_layers_per_virtual_pipeline_stage
else:
args.virtual_pipeline_model_parallel_size = None
# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
# bfloat16 requires gradient accumulation and all-reduce to
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print(
"accumulate and all-reduce gradients in fp32 for "
"bfloat16 data type.",
flush=True,
)
if args.rank == 0:
print("using {} for parameters ...".format(args.params_dtype), flush=True)
# If we do accumulation and all-reduces in fp32, we need to have
# local DDP and we should set the use-contiguous-buffers-in-ddp.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == "local"
args.use_contiguous_buffers_in_ddp = True
if args.dataloader_type is None:
args.dataloader_type = "single"
# Consumed samples and tokens.
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
args.consumed_train_tokens = 0
# Iteration-based training.
if args.train_iters:
# If we use iteration-based training, make sure the
# sample-based options are off.
assert args.train_samples is None, "expected iteration-based training"
assert (
args.lr_decay_samples is None
), "expected iteration-based learning rate decay"
assert (
args.lr_warmup_samples == 0
), "expected iteration-based learning rate warmup"
assert (
args.rampup_batch_size is None
), "expected no batch-size rampup for iteration-based training"
if args.lr_warmup_fraction is not None:
assert (
args.lr_warmup_iters == 0
), "can only specify one of lr-warmup-fraction and lr-warmup-iters"
# Sample-based training.
if args.train_samples:
# If we use sample-based training, make sure the
# iteration-based options are off.
assert args.train_iters is None, "expected sample-based training"
assert args.lr_decay_iters is None, "expected sample-based learning rate decay"
assert args.lr_warmup_iters == 0, "expected sample-based learning rate warmup"
if args.lr_warmup_fraction is not None:
assert args.lr_warmup_samples == 0, (
"can only specify one of lr-warmup-fraction " "and lr-warmup-samples"
)
# Check required arguments.
required_args = [
"num_layers",
"hidden_size",
"num_attention_heads",
"max_position_embeddings",
]
for req_arg in required_args:
_check_arg_is_not_none(args, req_arg)
# args.learned_position_embeddings = args.learned_position_embeddings > 0
# Checks.
if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size
if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0
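# Per-head projection size, e.g. hidden_size=1024 with num_attention_heads=16
# gives kv_channels = 64.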
args.kv_channels = args.hidden_size // args.num_attention_heads
if args.seq_length is not None:
assert args.encoder_seq_length is None
args.encoder_seq_length = args.seq_length
else:
assert args.encoder_seq_length is not None
args.seq_length = args.encoder_seq_length
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
if args.decoder_seq_length is not None:
assert args.max_position_embeddings >= args.decoder_seq_length
if args.lr is not None:
assert args.min_lr <= args.lr
if args.save is not None:
assert args.save_interval is not None
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, "lm cross entropy in fp16 only supported in fp16 mode."
if args.fp32_residual_connection:
assert (
args.fp16 or args.bf16
), "residual connection in fp32 only supported when using fp16 or bf16."
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, (
"for distribute-checkpointed-activations to work you "
"need to enable checkpoint-activations"
)
_print_args(args)
return args
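
# Minimal usage sketch (the module path is an assumption; adjust to your checkout):
#   from megatron.arguments import parse_args
#   args = parse_args(defaults={"tokenizer_type": "GPT2BPETokenizer"})
# A value passed via `defaults` is applied only when the corresponding flag was
# not given on the command line; with --force-default it also overrides values
# that were given (see the "Set input defaults" block above).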
def _print_args(args):
"""Print arguments."""
if args.rank == 0:
print("------------------------ arguments ------------------------", flush=True)
str_list = []
for arg in vars(args):
dots = "." * (48 - len(arg))
str_list.append(" {} {} {}".format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print("-------------------- end of arguments ---------------------", flush=True)
def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, "{} argument is None".format(arg)
def _add_network_size_args(parser):
group = parser.add_argument_group(title="network size")
group.add_argument(
"--num-layers",
type=int,
default=None,
help="Number of transformer layers.",
)
group.add_argument(
"--hidden-size",
type=int,
default=None,
help="Transformer hidden size.",
)
group.add_argument(
"--reward-growth",
type=str,
default="constant",
choices=["constant", "linear", "quadratic"],
help="Reward growth function.",
)
group.add_argument(
"--ffn-hidden-size",
type=int,
default=None,
help="Transformer Feed-Forward Network hidden size. "
"This is set to 4*hidden-size if not provided",
)
group.add_argument(
"--num-attention-heads",
type=int,
default=None,
help="Number of transformer attention heads.",
)
group.add_argument(
"--kv-channels",
type=int,
default=None,
help="Projection weights dimension in multi-head "
"attention. This is set to "
" args.hidden_size // args.num_attention_heads "
"if not provided.",
)
group.add_argument(
"--scale-embeddings",
action="store_true",
help="Scale embeddings by sqrt(d_model).",
)
group.add_argument(
"--max-position-embeddings",
type=int,
default=None,
help="Maximum number of position embeddings to use. "
"This is the size of position embedding.",
)
group.add_argument(
"--no-learned-position-embeddings",
action="store_true",
help="Do not learn position embeddings. ",
)
group.add_argument(
"--make-vocab-size-divisible-by",
type=int,
default=128,
help="Pad the vocab size to be divisible by this value."
"This is added for computational efficieny reasons.",
)
group.add_argument(
"--layernorm-epsilon", type=float, default=1e-5, help="Layer norm epsilon."
)
group.add_argument(
"--apply-residual-connection-post-layernorm",
action="store_true",
help="If set, use original BERT residula connection " "ordering.",
)
group.add_argument(
"--scaled-upper-triang-masked-softmax-fusion",
action="store_true",
help="Enable fusion of query_key_value scaling, "
"upper-triangular (causal) masking, and softmax.",
)
group.add_argument(
"--openai-gelu",
action="store_true",
help="Use OpenAIs GeLU implementation. This option"
"should not be used unless for backward compatibility"
"reasons.",
)
group.add_argument(
"--onnx-safe",
type=bool,
required=False,
help="Use workarounds for known problems with " "Torch ONNX exporter",
)
group.add_argument(
"--bert-no-binary-head",
action="store_false",
help="Disable BERT binary head.",
dest="bert_binary_head",
)
return parser
def _add_logging_args(parser):
group = parser.add_argument_group(title="logging")
group.add_argument(
"--log-params-norm",
action="store_true",
help="If set, calculate and log parameters norm.",
)
group.add_argument(
"--log-num-zeros-in-grad",
action="store_true",
help="If set, calculate and log the number of zeros in gradient.",
)
group.add_argument(
"--tensorboard-log-interval",
type=int,
default=1,
help="Report to tensorboard interval.",
)
group.add_argument(
"--tensorboard-queue-size",
type=int,
default=1000,
help="Size of the tensorboard queue for pending events "
"and summaries before one of the add calls forces a "
"flush to disk.",
)
group.add_argument(
"--log-timers-to-tensorboard",
action="store_true",
help="If set, write timers to tensorboard.",
)
group.add_argument(
"--log-batch-size-to-tensorboard",
action="store_true",
help="If set, write batch-size to tensorboard.",
)
group.add_argument(
"--no-log-learnig-rate-to-tensorboard",
action="store_false",
help="Disable learning rate logging to tensorboard.",
dest="log_learning_rate_to_tensorboard",
)
group.add_argument(
"--no-log-loss-scale-to-tensorboard",
action="store_false",
help="Disable loss-scale logging to tensorboard.",
dest="log_loss_scale_to_tensorboard",
)
group.add_argument(
"--log-validation-ppl-to-tensorboard",
action="store_true",
help="If set, write validation perplexity to " "tensorboard.",
)
group.add_argument(
"--wandb-logging",
action="store_true",
help="If set, log training progress to wandb.",
)
group.add_argument(
"--wandb-log-interval",
type=int,
default=1,
help="Log to wandb every N steps.",
)
return parser
def _add_regularization_args(parser):
group = parser.add_argument_group(title="regularization")
group.add_argument(
"--attention-dropout",
type=float,
default=0.1,
help="Post attention dropout probability.",
)
group.add_argument(
"--hidden-dropout",
type=float,
default=0.1,
help="Dropout probability for hidden state transformer.",
)
group.add_argument(
"--weight-decay",
type=float,
default=0.01,
help="Weight decay coefficient for L2 regularization.",
)
group.add_argument(
"--tempering",
type=float,
default=None,
help="Tempering coefficient for the model.",
)
group.add_argument(
"--gold",
action="store_true",
help="If set, use gold regularization.",
)
group.add_argument(
"--gold-beta",
type=float,
default=0.05,
help="Beta for GOLD tempering.",
)
group.add_argument(
"--play-tau",
type=float,
default=2.0
)
group.add_argument(
"--clip-grad",
type=float,
default=1.0,
help="Gradient clipping based on global L2 norm.",
)
group.add_argument(
"--adam-beta1",
type=float,
default=0.9,
help="First coefficient for computing running averages "
"of gradient and its square",
)
group.add_argument(
"--adam-beta2",
type=float,
default=0.999,
help="Second coefficient for computing running averages "
"of gradient and its square",
)
group.add_argument(
"--adam-eps",
type=float,
default=1e-08,
help="Term added to the denominator to improve" "numerical stability",
)
group.add_argument(
"--sgd-momentum", type=float, default=0.9, help="Momentum factor for sgd"
)
return parser
def _add_training_args(parser):
group = parser.add_argument_group(title="training")
group.add_argument(
"--micro-batch-size",
type=int,
default=None,
help="Batch size per model instance (local batch size). "
"Global batch size is local batch size times data "
"parallel size times number of micro batches.",
)
group.add_argument(
"--batch-size",
type=int,
default=None,
help="Old batch size parameter, do not use. " "Use --micro-batch-size instead",
)
group.add_argument(
"--global-batch-size",
type=int,
default=None,
help="Training batch size. If set, it should be a "
"multiple of micro-batch-size times data-parallel-size. "
"If this value is None, then "
"use micro-batch-size * data-parallel-size as the "
"global batch size. This choice will result in 1 for "
"number of micro-batches.",
)
group.add_argument(
"--rampup-batch-size",
nargs="*",
default=None,
help="Batch size ramp up with the following values:"
" --rampup-batch-size <start batch size> "
" <batch size incerement> "
" <ramp-up samples> "
"For example:"
" --rampup-batch-size 16 8 300000 \ "
" --global-batch-size 1024"
"will start with global batch size 16 and over "
" (1024 - 16) / 8 = 126 intervals will increase"
"the batch size linearly to 1024. In each interval"
"we will use approximately 300000 / 126 = 2380 samples.",
)
group.add_argument(
"--checkpoint-activations",
action="store_true",
help="Checkpoint activation to allow for training "
"with larger models, sequences, and batch sizes.",
)
group.add_argument(
"--distribute-checkpointed-activations",
action="store_true",
help="If set, distribute checkpointed activations "
"across model parallel group.",
)
group.add_argument(
"--checkpoint-num-layers",
type=int,
default=1,
help="chunk size (number of layers) for checkpointing.",
)
group.add_argument(
"--train-iters",
type=int,
default=None,
help="Total number of iterations to train over all "
"training runs. Note that either train-iters or "
"train-samples should be provided.",
)
group.add_argument(
"--train-samples",
type=int,
default=None,
help="Total number of samples to train over all "
"training runs. Note that either train-iters or "
"train-samples should be provided.",
)
group.add_argument(
"--train-tokens",
type=int,
default=None,
help="Total number of tokens to train over all " "training runs.",
)
group.add_argument(
"--log-interval", type=int, default=100, help="Report loss and timing interval."
)
group.add_argument(
"--exit-interval",
type=int,
default=None,
help="Exit the program after the iteration is divisible " "by this value.",
)
group.add_argument(
"--exit-duration-in-mins",
type=int,
default=None,
help="Exit the program after this many minutes.",
)
group.add_argument(
"--tensorboard-dir",
type=str,
default=None,
help="Write TensorBoard logs to this directory.",
)
group.add_argument(
"--no-masked-softmax-fusion",
action="store_false",
help="Disable fusion of query_key_value scaling, " "masking, and softmax.",
dest="masked_softmax_fusion",
)
group.add_argument(
"--no-bias-gelu-fusion",
action="store_false",
help="Disable bias and gelu fusion.",
dest="bias_gelu_fusion",
)
group.add_argument(
"--no-bias-dropout-fusion",
action="store_false",
help="Disable bias and dropout fusion.",
dest="bias_dropout_fusion",
)
group.add_argument(
"--optimizer",
type=str,
default="adam",
choices=["adam", "sgd"],
help="Optimizer function",
)
group.add_argument(
"--dataloader-type",
type=str,
default=None,
choices=["single", "cyclic"],
help="Single pass vs multiple pass data loader",
)
group.add_argument(
"--cpu-optimizer", action="store_true", help="Run optimizer on CPU"
)
group.add_argument(
"--cpu_torch_adam",
action="store_true",
help="Use Torch Adam as optimizer on CPU.",
)
group.add_argument(
"--no-pipeline-parallel",
action="store_true",
help="Disable pipeline parallelism",
)
group.add_argument(
"--ms-model",
action="store_true",
help="use model converted from Mindspore",
)
return parser
def _add_initialization_args(parser):
group = parser.add_argument_group(title="initialization")
group.add_argument(
"--seed",
type=int,
default=1234,
help="Random seed used for python, numpy, " "pytorch, and cuda.",
)
group.add_argument(
"--init-method-std",
type=float,
default=0.02,
help="Standard deviation of the zero mean normal "
"distribution used for weight initialization.",
)
group.add_argument(
"--init-method-xavier-uniform",
action="store_true",
help="Enable Xavier uniform parameter initialization",
)
return parser
def _add_inference_args(parser):
group = parser.add_argument_group(title="initialization")
group.add_argument(
'--beam-warmup',
action="store_true",
)
group.add_argument(
'--beam-warmup-length',
type=int,
default=0,
)
group.add_argument(
'--beam-search',
action="store_true",
)
group.add_argument(
'--beam-search-nucleus',
action="store_true",
)
group.add_argument(
'--num-beams',
type=int,
default=4,
)
return parser
def _add_learning_rate_args(parser):
group = parser.add_argument_group(title="learning rate")
group.add_argument(
"--lr",
type=float,
default=None,
help="Initial learning rate. Depending on decay style "
"and initial warmup, the learing rate at each "
"iteration would be different.",
)
group.add_argument(
"--lr-decay-style",
type=str,
default="linear",
choices=["constant", "linear", "cosine"],
help="Learning rate decay function.",
)
group.add_argument(
"--lr-decay-iters",
type=int,
default=None,
help="number of iterations to decay learning rate over,"
" If None defaults to `--train-iters`",
)
group.add_argument(
"--lr-decay-samples",
type=int,
default=None,
help="number of samples to decay learning rate over,"
" If None defaults to `--train-samples`",
)
group.add_argument(
"--lr-decay-tokens",
type=int,
default=None,
help="number of tokens to decay learning rate over,"
" If not None will override iter/sample-based decay",
)
group.add_argument(
"--lr-warmup-fraction",
type=float,
default=None,
help="fraction of lr-warmup-(iters/samples) to use " "for warmup (as a float)",
)
group.add_argument(
"--lr-warmup-iters",
type=int,
default=0,
help="number of iterations to linearly warmup " "learning rate over.",
)
group.add_argument(
"--lr-warmup-samples",
type=int,
default=0,
help="number of samples to linearly warmup " "learning rate over.",
)
group.add_argument(
"--warmup",
type=int,
default=None,
help="Old lr warmup argument, do not use. Use one of the"
"--lr-warmup-* arguments above",
)
group.add_argument(
"--min-lr",
type=float,
default=0.0,
help="Minumum value for learning rate. The scheduler"
"clip values below this threshold.",
)
group.add_argument(
"--override-lr-scheduler",
action="store_true",
help="Reset the values of the scheduler (learning rate,"
"warmup iterations, minimum learning rate, maximum "
"number of iterations, and decay style from input "
"arguments and ignore values from checkpoints. Note"
"that all the above values will be reset.",
)
group.add_argument(
"--use-checkpoint-lr-scheduler",
action="store_true",
help="Use checkpoint to set the values of the scheduler "
"(learning rate, warmup iterations, minimum learning "
"rate, maximum number of iterations, and decay style "
"from checkpoint and ignore input arguments.",
)
return parser
def _add_checkpointing_args(parser):
group = parser.add_argument_group(title="checkpointing")
group.add_argument(
"--save",
type=str,
default=None,
help="Output directory to save checkpoints to.",
)
group.add_argument(
"--save-interval",
type=int,
default=None,
help="Number of iterations between checkpoint saves.",
)
group.add_argument(
"--no-save-optim",
action="store_true",
default=None,
help="Do not save current optimizer.",
)
group.add_argument(
"--no-save-rng",
action="store_true",
default=None,
help="Do not save current rng state.",
)
group.add_argument(
"--load",
type=str,
default=None,
help="Directory containing a model checkpoint.",
)
group.add_argument(
"--low-memory-load",
action="store_true",
default=None,
help="Load model checkpoint in low memory mode."
"On each machine, workers load the checkpoint one at a time."
)
group.add_argument(
"--dist-timeout",
type=int,
default=30,
help="Timeout for Pytorch Distributed backend (in minutes).",
)
group.add_argument(
"--load-state",
type=str,
default=None,
help="Start training from a existing model state.",
)
group.add_argument(
"--no-load-optim",
action="store_true",
default=None,
help="Do not load optimizer when loading checkpoint.",
)
group.add_argument(
"--no-load-rng",
action="store_true",
default=None,
help="Do not load rng state when loading checkpoint.",
)
group.add_argument(
"--finetune",
action="store_true",
help="Load model for finetuning. Do not load optimizer "
"or rng state from checkpoint and set iteration to 0. "
"Assumed when loading a release checkpoint.",
)
return parser
def _add_mixed_precision_args(parser):
group = parser.add_argument_group(title="mixed precision")
group.add_argument("--fp16", action="store_true", help="Run model in fp16 mode.")
group.add_argument("--ln-fp16", action="store_true", help="Run layernorm in fp16 mode.")
group.add_argument(
"--bf16", action="store_true", help="Run model in bfloat16 mode."
)
group.add_argument(
"--loss-scale",
type=float,
default=None,
help="Static loss scaling, positive power of 2 "
"values can improve fp16 convergence. If None, dynamic"
"loss scaling is used.",
)
group.add_argument(
"--initial-loss-scale",
type=float,
default=2 ** 32,
help="Initial loss-scale for dynamic loss scaling.",
)
group.add_argument(
"--min-loss-scale",
type=float,
default=1.0,
help="Minimum loss scale for dynamic loss scale.",
)
group.add_argument(
"--loss-scale-window",
type=float,
default=1000,
help="Window over which to raise/lower dynamic scale.",
)
group.add_argument(
"--hysteresis", type=int, default=2, help="hysteresis for dynamic loss scaling"
)
group.add_argument(
"--fp32-residual-connection",
action="store_true",
help="Move residual connections to fp32.",
)
group.add_argument(
"--apply-query-key-layer-scaling",
action="store_true",
help="Scale Q * K^T by 1 / layer-number. If this flag "
"is set, then it will automatically set "
"attention-softmax-in-fp32 to true.",
)
group.add_argument(
"--attention-softmax-in-fp32",
action="store_true",
help="Run attention masking and softmax in fp32. "
"This flag is ignored unless "
"--no-query-key-layer-scaling is specified.",
)
group.add_argument(
"--accumulate-allreduce-grads-in-fp32",
action="store_true",
help="Gradient accumulation and all-reduce in fp32.",
)
group.add_argument(
"--fp16-lm-cross-entropy",
action="store_true",
help="Move the cross entropy unreduced loss calculation" "for lm head to fp16.",
)
return parser
def _add_distributed_args(parser):
group = parser.add_argument_group(title="distributed")
group.add_argument(
"--tensor-model-parallel-size",
type=int,
default=1,
help="Degree of tensor model parallelism.",
)
group.add_argument(
"--pipeline-model-parallel-size",
type=int,
default=1,
help="Degree of pipeline model parallelism.",
)
group.add_argument(
"--model-parallel-size",
type=int,
default=None,
help="Old model parallel argument, do not use. Use "
"--tensor-model-parallel-size instead.",
)
group.add_argument(
"--num-layers-per-virtual-pipeline-stage",
type=int,
default=None,
help="Number of layers per virtual pipeline stage",
)
group.add_argument(
"--distributed-backend",
default="nccl",
choices=["nccl", "gloo"],
help="Which backend to use for distributed training.",
)
group.add_argument(
"--DDP-impl",
default="local",
choices=["local", "torch"],
help="which DistributedDataParallel implementation " "to use.",
)
group.add_argument(
"--use-contiguous-buffers-in-ddp",
action="store_true",
help="If set, use contiguous buffer in DDP. Note that "
"this option only works woth local DDP.",
)
group.add_argument(
"--no-scatter-gather-tensors-in-pipeline",
action="store_false",
help="Use scatter/gather to optimize communication of tensors in pipeline",
dest="scatter_gather_tensors_in_pipeline",
)
group.add_argument(
"--local_rank",
type=int,
default=None,
help="local rank passed from distributed launcher.",
)
group.add_argument(
"--lazy-mpu-init",
type=bool,
required=False,
help="If set to True, initialize_megatron() "
"skips DDP initialization and returns function to "
"complete it instead.Also turns on "
"--use-cpu-initialization flag. This is for "
"external DDP manager.",
)
group.add_argument(
"--use-cpu-initialization",
action="store_true",
default=None,
help="If set, affine parallel weights " "initialization uses CPU",
)
group.add_argument(
"--force-device",
type=int,
default=None,
help="Force the model to run on a particular gpu",
)
group.add_argument(
"--force-default",
action="store_true",
help="Force setting default arguments for distributed training",
)
return parser
def _add_validation_args(parser):
group = parser.add_argument_group(title="validation")
group.add_argument(
"--eval-iters",
type=int,
default=100,
help="Number of iterations to run for evaluation" "validation/test for.",
)
group.add_argument(
"--eval-interval",
type=int,
default=1000,
help="Interval between running evaluation on " "validation set.",
)
group.add_argument(
"--co-evaluation",
action="store_true",
help="If set, run evaluation on each part of the validation set"
)
return parser
def _add_data_args(parser):
group = parser.add_argument_group(title="data and dataloader")
group.add_argument(
"--data-path",
nargs="*",
default=None,
help="Path to the training dataset. Accepted format:"
"1) a single data path, 2) multiple datasets in the"
"form: dataset1-weight dataset1-path dataset2-weight "
"dataset2-path ...",
)
group.add_argument(
"--valid-data-path",
nargs="*",
default=None,
help="Path to the validation dataset. Accepted format:"
"1) a single data path, 2) multiple datasets in the"
"form: dataset1-weight dataset1-path dataset2-weight "
"dataset2-path ...;"
"when co-evaluation is enabled, the form will be dataset1-tag dataset1-path ...",
)
group.add_argument("--index-cache-dir", type=str, default=None, help="Path to the index cache")
group.add_argument(
"--test-data-path",
nargs="*",
default=None,
help="Path to the test dataset. Accepted format:"
"1) a single data path, 2) multiple datasets in the"
"form: dataset1-tag dataset1-path dataset2-tag "
"dataset2-path ...",
)
group.add_argument(
"--split",
type=str,
default="969, 30, 1",
help="Comma-separated list of proportions for training,"
" validation, and test split. For example the split "
"`90,5,5` will use 90%% of data for training, 5%% for "
"validation and 5%% for test.",
)
group.add_argument(
"--vocab-file",
type=str,
default=None,
help="Path to the vocab file.",
)
group.add_argument(
"--merge-file",
type=str,
default=None,
help="Path to the BPE merge file.",
)
group.add_argument(
"--tokenizer-path",
type=str,
default=None,
help="Path to the tokenizer dir.",
)
group.add_argument(
"--vocab-extra-ids",
type=int,
default=0,
help="Number of additional vocabulary tokens. "
"They are used for span masking in the T5 model",
)
group.add_argument(
"--seq-length",
type=int,
default=None,
help="Maximum sequence length to process.",
)
group.add_argument(
"--encoder-seq-length",
type=int,
default=None,
help="Maximum encoder sequence length to process."
"This should be exclusive of --seq-length",
)
group.add_argument(
"--decoder-seq-length",
type=int,
default=None,
help="Maximum decoder sequence length to process.",
)
group.add_argument(
"--retriever-seq-length",
type=int,
default=256,
help="Maximum sequence length for the biencoder model " " for retriever",
)
group.add_argument(
"--sample-rate",
type=float,
default=1.0,
help="sample rate for training data. Supposed to be 0 " " < sample_rate < 1",
)
group.add_argument(
"--mask-prob",
type=float,
default=0.15,
help="Probability of replacing a token with mask.",
)
group.add_argument(
"--short-seq-prob",
type=float,
default=0.1,
help="Probability of producing a short sequence.",
)
group.add_argument("--mmap-warmup", action="store_true", help="Warm up mmap files.")
group.add_argument(
"--num-workers", type=int, default=2, help="Dataloader number of workers."
)
group.add_argument(
"--tokenizer-type",
type=str,
default=None,
choices=["BertWordPieceLowerCase", "BertWordPieceCase", "GPT2BPETokenizer"],
help="What type of tokenizer to use.",
)
group.add_argument(
"--data-impl",
type=str,
default="infer",
choices=["lazy", "cached", "mmap", "infer"],
help="Implementation of indexed datasets.",
)
group.add_argument(
"--reset-position-ids",
action="store_true",
help="Reset posistion ids after end-of-document token.",
)
group.add_argument(
"--reset-attention-mask",
action="store_true",
help="Reset self attention masks after " "end-of-document token.",
)
group.add_argument(
"--eod-mask-loss",
action="store_true",
help="Mask loss for the end of document tokens.",
)
return parser
def _add_autoresume_args(parser):
group = parser.add_argument_group(title="autoresume")
group.add_argument(
"--adlr-autoresume",
action="store_true",
help="Enable autoresume on adlr cluster.",
)
group.add_argument(
"--adlr-autoresume-interval",
type=int,
default=1000,
help="Intervals over which check for autoresume" "termination signal",
)
return parser
def _add_biencoder_args(parser):
group = parser.add_argument_group(title="biencoder")
# network size
group.add_argument(
"--ict-head-size",
type=int,
default=None,
help="Size of block embeddings to be used in ICT and "
"REALM (paper default: 128)",
)
group.add_argument(
"--biencoder-projection-dim",
type=int,
default=0,
help="Size of projection head used in biencoder (paper" " default: 128)",
)
group.add_argument(
"--biencoder-shared-query-context-model",
action="store_true",
help="Whether to share the parameters of the query "
"and context models or not",
)
# checkpointing
group.add_argument(
"--ict-load",
type=str,
default=None,
help="Directory containing an ICTBertModel checkpoint",
)
group.add_argument(
"--bert-load",
type=str,
default=None,
help="Directory containing an BertModel checkpoint "
"(needed to start ICT and REALM)",
)
# data
group.add_argument(
"--titles-data-path",
type=str,
default=None,
help="Path to titles dataset used for ICT",
)
group.add_argument(
"--query-in-block-prob",
type=float,
default=0.1,
help="Probability of keeping query in block for " "ICT dataset",
)
group.add_argument(
"--use-one-sent-docs",
action="store_true",
help="Whether to use one sentence documents in ICT",
)
group.add_argument(
"--evidence-data-path",
type=str,
default=None,
help="Path to Wikipedia Evidence frm DPR paper",
)
# training
group.add_argument(
"--retriever-report-topk-accuracies",
nargs="+",
type=int,
default=[],
help="Which top-k accuracies to report " "(e.g. '1 5 20')",
)
group.add_argument(
"--retriever-score-scaling",
action="store_true",
help="Whether to scale retriever scores by inverse "
"square root of hidden size",
)
# faiss index
group.add_argument(
"--block-data-path",
type=str,
default=None,
help="Where to save/load BlockData to/from",
)
group.add_argument(
"--embedding-path",
type=str,
default=None,
help="Where to save/load Open-Retrieval Embedding" " data to/from",
)
# indexer
group.add_argument(
"--indexer-batch-size",
type=int,
default=128,
help="How large of batches to use when doing indexing " "jobs",
)
group.add_argument(
"--indexer-log-interval",
type=int,
default=1000,
help="After how many batches should the indexer " "report progress",
)
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
group.add_argument(
"--num-classes",
type=int,
default=1000,
help="num of classes in vision classificaiton task",
)
group.add_argument(
"--img-dim",
type=int,
default=224,
help="Image size for vision classification task",
)
group.add_argument(
"--num-channels",
type=int,
default=3,
help="Number of channels in input image data",
)
group.add_argument(
"--patch-dim", type=int, default=16, help="patch dimension used in vit"
)
return parser
def _add_zero_args(parser):
"""Text generate arguments."""
group = parser.add_argument_group("ZeRO configurations", "configurations")
group.add_argument("--zero-stage", type=int, default=1.0)
group.add_argument(
"--zero-reduce-scatter",
action="store_true",
help="Use reduce scatter if specified",
)
group.add_argument(
"--zero-contigious-gradients",
action="store_true",
help="Use contigious memory optimizaiton if specified",
)
group.add_argument("--zero-reduce-bucket-size", type=int, default=0.0)
group.add_argument("--zero-allgather-bucket-size", type=int, default=0.0)
group.add_argument(
"--remote-device",
type=str,
default="none",
choices=["none", "cpu", "nvme"],
help="Remote device for ZeRO-3 initialized parameters.",
)
group.add_argument(
"--use-pin-memory",
action="store_true",
help="Use pinned CPU memory for ZeRO-3 initialized model parameters.",
)
return parser
def _add_memoryopt_args(parser):
"""Memory optimization arguments."""
group = parser.add_argument_group("Memory optimizations", "configurations")
group.add_argument(
"--scattered-embeddings",
action="store_true",
help="Save memory by scattering embedding activations. "
"Introduces dropout differences across MP configurations.",
)
group.add_argument(
"--split-transformers",
action="store_true",
help="Save memory by splitting transformer layers into two parts, "
"allowing for more frequent activation checkpoint savings.",
)
group.add_argument(
"--memory-centric-tiled-linear",
action="store_true",
help="Save memory by tiling with deepspeed.zero.TiledLinear.",
)
group.add_argument(
"--tile-factor",
type=int,
default=1,
help="Make all linear layers the same size of [hidden/tile_factor, hidden/tile_factor]. "
"Must be enabled with --memory-centric-tiled-linear. "
"Example A: if tile_factor=1, the qkv layer [hidden, 3* hidden] would be converted into [1,3] tiles of size [hidden,hidden]. "
"Example B: if tile_factor=2, the intermediate layer [4*hidden, hidden] will be converted into [8, 2] tiles of size [hidden/2, hidden/2]. "
"Default is 1.",
)
return parser
def _add_activation_checkpoint_args(parser):
group = parser.add_argument_group(
"Activation Checkpointing", "Checkpointing Configurations"
)
group.add_argument(
"--deepspeed-activation-checkpointing",
action="store_true",
help="uses activation checkpointing from deepspeed",
)
group.add_argument(
"--partition-activations",
action="store_true",
help="partition Activations across GPUs before checkpointing.",
)
group.add_argument(
"--contigious-checkpointing",
action="store_true",
help="Contigious memory checkpointing for activatoins.",
)
group.add_argument(
"--checkpoint-in-cpu",
action="store_true",
help="Move the activation checkpoints to CPU.",
)
group.add_argument(
"--synchronize-each-layer",
action="store_true",
help="does a synchronize at the beginning and end of each checkpointed layer.",
)
group.add_argument(
"--profile-backward",
action="store_true",
help="Enables backward pass profiling for checkpointed layers.",
)
return parser