# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer based language model."""
import torch
import torch.nn.functional as F
from codegeex.megatron import get_args
from codegeex.megatron import mpu, print_rank_0
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.transformer import ParallelTransformer
from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal
from codegeex.megatron.mpu.initialize import get_tensor_model_parallel_world_size
def get_shrink_embedding_gradient_alpha(iteration):
args = get_args()
alpha = args.shrink_embedding_gradient_alpha
if args.shrink_embedding_gradient_steps is None:
return alpha
else:
x1 = int(args.shrink_embedding_gradient_steps[0])
x2 = int(args.shrink_embedding_gradient_steps[1])
if iteration <= x1:
return alpha
elif iteration >= x1 + x2:
return 1.0
else:
            return alpha + (1 - alpha) * (iteration - x1) / x2
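# A quick illustration of the schedule above, with purely hypothetical values
# (not CodeGeeX defaults): with shrink_embedding_gradient_alpha=0.1 and
# shrink_embedding_gradient_steps=[1000, 500], the returned alpha is 0.1 up to
# iteration 1000, ramps linearly from 0.1 to 1.0 between iterations 1000 and
# 1500, and stays at 1.0 afterwards.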
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
args = get_args()
if args.shrink_logit_embedding_gradient:
if hasattr(args, 'iteration'):
alpha = get_shrink_embedding_gradient_alpha(args.iteration + 1)
else:
alpha = args.shrink_embedding_gradient_alpha
        if alpha != 1.0:
            word_embeddings_weight = (
                word_embeddings_weight * alpha
                + word_embeddings_weight.detach() * (1 - alpha)
            )
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight.half())
else:
logits_parallel = F.linear(input_parallel, word_embeddings_weight.half(), bias)
# Gather if needed.
if parallel_output:
return logits_parallel
return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
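# Note on the alpha blend above: word_embeddings_weight * alpha
# + word_embeddings_weight.detach() * (1 - alpha) has the same forward value as
# the original weight, but only the first term carries gradient, so the
# gradient flowing back into the shared embedding from the LM head is scaled by
# alpha while the logits themselves are unchanged.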
def get_language_model(
num_tokentypes,
add_pooler,
init_method=None,
scaled_init_method=None,
):
"""Build language model and return along with the key to save."""
args = get_args()
if init_method is None:
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
# Language model.
language_model = TransformerLanguageModel(
init_method=init_method,
output_layer_init_method=scaled_init_method,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler)
# key used for checkpoints.
language_model_key = 'language_model'
return language_model, language_model_key
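# Minimal usage sketch (the call site and argument values are illustrative, not
# taken from the training scripts):
#     language_model, language_model_key = get_language_model(
#         num_tokentypes=0, add_pooler=False)
# The returned key is what callers use to namespace this module's entries in
# checkpoint state dicts.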
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0,
):
super(Embedding, self).__init__()
args = get_args()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
self.max_sequence_length = max_sequence_length
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method)
self._word_embeddings_key = 'word_embeddings'
self.vocab_size = vocab_size
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self.position_embeddings = self.position_embeddings.half()
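        # The word embeddings above are sharded across tensor-parallel ranks
        # (VocabParallelEmbedding); this position table is a plain nn.Embedding
        # replicated on every rank.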
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
def state_dict_for_save_checkpoint(
self, destination=None, prefix='', keep_vars=False,
):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
vocab_len = state_dict_['weight'].shape[0]
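        # The checkpoint shard may hold more rows than this model expects (for
        # example, padding added to make the vocabulary size divisible); keep
        # only the first vocab_size // tensor_model_parallel_world_size rows.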
state_dict_["weight"] = state_dict_["weight"][:self.vocab_size // get_tensor_model_parallel_world_size()]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
pos_len = state_dict_['weight'].shape[0]
max_seq_len = self.max_sequence_length
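        # If the checkpoint was trained with a shorter context than this model
        # is configured for (say 2048 positions loaded into an 8192-position
        # model; the numbers are purely illustrative), pad the table with
        # freshly initialized rows up to max_sequence_length.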
if pos_len < max_seq_len:
print_rank_0(f"Position embedding padded {pos_len} -> {max_seq_len}.")
position_embeddings_padded = torch.nn.Embedding(
max_seq_len - pos_len, self.hidden_size).half()
self.init_method(position_embeddings_padded.weight)
state_dict_['weight'] = torch.cat([state_dict_['weight'], position_embeddings_padded.weight], dim=0)
# self.position_embeddings = self.position_embeddings.half()
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_,
strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it', flush=True)
class EmbeddingPipe(Embedding):
def forward(self, inputs, **kwargs):
if not hasattr(self, "_args"):
self._args = get_args()
input_ids = inputs[0]
position_ids = inputs[1]
if hasattr(self._args, "attn_mask"):
attention_mask = None
else:
attention_mask = inputs[2]
if len(inputs) == 4:
tokentype_ids = inputs[3]
else:
tokentype_ids = None
embeddings = super().forward(
input_ids, position_ids, tokentype_ids=tokentype_ids
)
# If cmd args has attn_mask, we don't forward it as an activation.
if hasattr(self._args, "attn_mask"):
return embeddings
        else:
            assert False, "EmbeddingPipe expects the attention mask to be stored on args.attn_mask"
            return embeddings, attention_mask
@property
def word_embeddings_weight(self):
"""Easy accessory for the DeepSpeed pipeline engine to tie embeddings across stages."""
return self.word_embeddings.weight
class QueryEmbedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0):
super(QueryEmbedding, self).__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
self.max_sequence_length = max_sequence_length
# Top query position embedding (serial).
self.top_query_embeddings = mpu.VocabParallelEmbedding(
max_sequence_length, self.hidden_size, init_method=self.init_method)
self.top_query_embeddings = self.top_query_embeddings.half()
self._top_query_embeddings_key = 'top_query_embeddings'
# Initialize the top query position embeddings.
self.init_method(self.top_query_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, position_ids, tokentype_ids=None):
# Embeddings.
embeddings = self.top_query_embeddings(position_ids)
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
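    # Unlike Embedding.forward, only position_ids are embedded here; the result
    # is passed to ParallelTransformer alongside the regular embedding output
    # (see TransformerLanguageModel.forward), where it presumably drives the
    # top query attention layer.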
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._top_query_embeddings_key] \
= self.top_query_embeddings.state_dict(
destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Position embedding.
if self._top_query_embeddings_key in state_dict:
state_dict_ = state_dict[self._top_query_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'top_query_embeddings' in key:
state_dict_[key.split('top_query_embeddings.')[1]] \
= state_dict[key]
pos_len = state_dict_['weight'].shape[0]
max_seq_len = self.max_sequence_length // get_tensor_model_parallel_world_size()
if pos_len < max_seq_len:
print_rank_0(f"Top query embedding padded {pos_len} -> {max_seq_len}.")
top_query_embeddings_padded = torch.nn.Embedding(
max_seq_len - pos_len, self.hidden_size).half()
self.init_method(top_query_embeddings_padded.weight)
state_dict_['weight'] = torch.cat([state_dict_['weight'], top_query_embeddings_padded.weight], dim=0)
self.top_query_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_,
strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it', flush=True)
class QueryEmbeddingPipe(QueryEmbedding):
def forward(self, inputs, **kwargs):
if not hasattr(self, "_args"):
self._args = get_args()
position_ids = inputs[0]
if hasattr(self._args, "attn_mask"):
attention_mask = None
else:
attention_mask = inputs[1]
if len(inputs) == 3:
tokentype_ids = inputs[2]
else:
tokentype_ids = None
embeddings = super().forward(
position_ids, tokentype_ids=tokentype_ids,
)
# If cmd args has attn_mask, we don't forward it as an activation.
if hasattr(self._args, "attn_mask"):
return embeddings
        else:
            assert False, "QueryEmbeddingPipe expects the attention mask to be stored on args.attn_mask"
            return embeddings, attention_mask
@property
def word_embeddings_weight(self):
"""Easy accessory for the DeepSpeed pipeline engine to tie embeddings across stages."""
return self.top_query_embeddings.weight
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmaksed-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmaksed-attention-scores, attention-mask)
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
init_method,
output_layer_init_method,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.add_pooler = add_pooler
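        # NOTE: no Pooler module is constructed in this variant, so add_pooler
        # is expected to stay False; state_dict_for_save_checkpoint and
        # load_state_dict below reference self._pooler_key, which is never set.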
# Embeddings
self.embedding = Embedding(self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes)
self._embedding_key = 'embedding'
# Query embeddings
self.topQueryEmbedding = QueryEmbedding(self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes)
self._topQueryEmbedding_key = 'topQueryEmbedding'
# Transformer
self.transformer = ParallelTransformer(
self.init_method,
output_layer_init_method)
self._transformer_key = 'transformer'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.transformer.set_input_tensor(input_tensor)
def forward(
self,
input_ids,
position_ids,
attention_mask,
tokentype_ids=None,
layer_past=None,
get_key_value=False,
pooling_sequence_index=0,
prompt_length=None,
context_length=None,
):
# Embeddings.
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
query_position_ids = position_ids
queryEmbedding_out = self.topQueryEmbedding(query_position_ids,
tokentype_ids=tokentype_ids)
# Transformer.
transformer_output = self.transformer(embedding_output,
queryEmbedding_out,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length, )
return transformer_output
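    # Rough call sketch (shapes are assumptions, not taken from the training
    # scripts): input_ids and position_ids of shape [batch, seq_len] plus an
    # attention mask go in; the transformer output comes back, and with
    # get_key_value=True the cached key/value tensors for incremental decoding
    # are returned alongside it.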
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._topQueryEmbedding_key] \
= self.topQueryEmbedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Embedding.
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
if self._topQueryEmbedding_key in state_dict:
state_dict_ = state_dict[self._topQueryEmbedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict)
# Transformer.
if self._transformer_key in state_dict:
state_dict_ = state_dict[self._transformer_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)