You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

200 lines
7.3 KiB
Python

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron Module"""
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from codegeex.megatron import get_args
from codegeex.megatron import mpu
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
def param_is_not_shared(param):
return not hasattr(param, "shared") or not param.shared
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(
self, destination=None, prefix="", keep_vars=False
):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)
def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage(ignore_virtual=True):
return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if not self.share_word_embeddings:
raise Exception(
"word_embeddings_weight() called for last "
"stage, but share_word_embeddings is false"
)
return self.word_embeddings.weight
raise Exception(
"word_embeddings_weight() should be " "called for first and last stage only"
)
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception(
"initialize_word_embeddings() was called but "
"share_word_embeddings is false"
)
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if mpu.is_pipeline_last_stage():
assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = "word_embeddings_for_head"
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = mpu.VocabParallelEmbedding(
args.padded_vocab_size,
args.hidden_size,
init_method=init_method_normal(args.init_method_std),
)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(
self.word_embeddings_weight().data, group=mpu.get_embedding_group()
)
else:
print(
"WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
#is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class Float16Module(MegatronModule):
def __init__(self, module, args):
super(Float16Module, self).__init__()
if args.fp16:
self.add_module("module", module.half())
def float16_convertor(val):
return val.half()
elif args.bf16:
self.add_module("module", module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception("should not be here")
self.float16_convertor = float16_convertor
def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage():
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage():
outputs = float16_to_fp32(outputs)
return outputs
def state_dict(self, destination=None, prefix="", keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def state_dict_for_save_checkpoint(
self, destination=None, prefix="", keep_vars=False
):
return self.module.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)