# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
from torch.nn import LayerNorm
from codegeex.megatron import get_args
from codegeex.megatron import mpu
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.utils import fast_gelu
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
            unmasked-attention-scores, attention-mask)
"""
class ParallelMLP(MegatronModule):
"""MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform a nonlinear transformation, and project the
    state back into h hidden dimension.
"""
def __init__(self, init_method, output_layer_init_method):
super(ParallelMLP, self).__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4 * args.hidden_size,
gather_output=False,
init_method=init_method,
# skip_bias_add=True,
)
self.activation_func = fast_gelu
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.hidden_size,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
# skip_bias_add=True,
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
class ParallelSelfAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
    Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
self.layer_number = max(1, layer_number)
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
if hasattr(args, 'attention_upweight'):
self.attention_upweight = args.attention_upweight
else:
self.attention_upweight = None
# Strided linear layer.
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.key = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.value = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
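        # Attention scores are divided by norm_factor = sqrt(hn), the standard
        # scaled-dot-product scaling.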
self.softmax = torch.nn.Softmax(dim=-1)
# Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
args.hidden_size,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
skip_bias_add=True)
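        # With skip_bias_add=True the dense bias is returned separately so the
        # enclosing transformer layer can fuse it with dropout + residual add
        # (see bias_dropout_add below).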
def forward(
self,
hidden_states,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
# hidden_states: [sq, b, h]
# =====================
# Query, Key, and Value
# =====================
query_layer, _ = self.query(hidden_states)
key_layer, _ = self.key(hidden_states)
value_layer, _ = self.value(hidden_states)
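        # Reshape Q, K and V from [sq, b, hp] to [sq, b, np, hn] so that
        # attention is computed independently per head.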
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
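        # When get_key_value is set, the (possibly extended) key/value tensors
        # are cached below as `present`; callers pass them back as `layer_past`
        # on the next decoding step.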
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.matmul(query_layer.transpose(0, 1),
key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
if self.attention_upweight is not None and layer_past is None:
log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
device=torch.cuda.current_device(),
dtype=torch.half if self.fp16 else torch.float32)
if prompt_length is None:
log_attention_weights = self.attention_upweight
else:
log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
attention_scores += log_attention_weights
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
if context_length is not None:
attention_mask = torch.clone(attention_mask)
attention_mask[:, :, context_length:, :] = True
# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
attention_scores = attention_scores - attention_mask * 10000.0
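        # The mask is applied additively: positions where attention_mask is 1
        # receive a -10000.0 penalty, driving their softmax probability to ~0.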
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sq, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sq, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# # [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# # [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
if get_key_value:
output = [output, present]
return output, bias
class ParallelTopQuerySelfAttention(MegatronModule):
"""Parallel top query self-attention layer abstract class.
    Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number):
super(ParallelTopQuerySelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
self.layer_number = max(1, layer_number)
if hasattr(args, 'attention_upweight_top'):
self.attention_upweight = args.attention_upweight_top
else:
self.attention_upweight = None
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.key = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.value = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.softmax = torch.nn.Softmax(dim=-1)
# Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
args.hidden_size,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
skip_bias_add=True)
def forward(
self,
hidden_states,
query_hidden_state,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
# hidden_states: [sq, b, h]
query_layer, _ = self.query(query_hidden_state)
key_layer, _ = self.key(hidden_states)
value_layer, _ = self.value(hidden_states)
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [s, b, np, hn] -> [s, b * np, hn]
query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.matmul(query_layer.transpose(0, 1),
key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
# change view to [b, np, s, s]
attention_scores = matmul_result.view(*output_size)
if self.attention_upweight is not None and layer_past is None:
log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
device=torch.cuda.current_device(),
dtype=torch.half if self.fp16 else torch.float32)
if prompt_length is None:
log_attention_weights = self.attention_upweight
else:
log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
attention_scores += log_attention_weights
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
if context_length is not None:
attention_mask = torch.clone(attention_mask)
attention_mask[:, :, context_length:, :] = True
# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sq, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sq, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
if get_key_value:
output = [output, present]
return output, bias
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
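# The fused variants below are TorchScript-compiled with the `training` flag
# baked in as a constant so the JIT fuser can fuse the dropout and residual
# add (see the jit fusion flags set at the top of this file).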
@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, False)
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
    Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelSelfAttention(init_method,
output_layer_init_method,
layer_number)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if hasattr(args, 'attention_upweight'):
self.attention_upweight = args.attention_upweight
else:
self.attention_upweight = None
if hasattr(args, 'ln_fp16'):
self.ln_fp16 = args.ln_fp16
else:
self.ln_fp16 = False
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(
self,
hidden_states,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
        # hidden_states: [s, b, h]
if self.ln_fp16:
layernorm_output = self.input_layernorm(hidden_states)
else:
layernorm_output = self.input_layernorm(hidden_states.float()).half()
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the self attention.
if self.ln_fp16:
layernorm_output = self.post_attention_layernorm(layernorm_input)
else:
layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()
mlp_output, _ = self.mlp(layernorm_output)
# MLP.
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = mlp_output + residual
if get_key_value:
output = [output, presents]
return output
class ParallelTransformerLayerPipe(ParallelTransformerLayer):
"""Extends ParallelTransformerLayer to forward attention_mask through the pipeline.
Forward has two usages that affect attention mask communication:
1) forward((input, attn_mask) , **kwargs) -> (output, mask)
When the attention mask is provided as the second positional
argument, typical pipeline behavior is used and both the output
*and* mask are returned in a tuple. This tuple is then forwarded
to the next stage in the pipeline.
This version is useful if masks are dynamic.
2) forward(input, **kwargs) -> output
When the mask is static over all samples, it is advantageous to
cache the mask and avoid communicating it.
If no mask is provided, the module will query `self._args.attn_mask`
for the mask and only return `super().forward(...)`
"""
def forward(self, inputs, **kwargs):
assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
if torch.is_tensor(inputs) or len(inputs) == 1:
# No attention mask forwarded, search for args.attn_mask
if not hasattr(self, "_args"):
self._args = get_args()
hidden_states, attention_mask = inputs, self._args.attn_mask
return super().forward(hidden_states, attention_mask, **kwargs)
elif len(inputs) == 2:
# Attention mask is an activation.
hidden_states, attention_mask = inputs[0], inputs[1]
return super().forward(*inputs, **kwargs), attention_mask
else:
raise RuntimeError("Received more inputs than understood.")
class ParallelTopQueryLayer(MegatronModule):
"""A single top query layer.
    Top query layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number):
args = get_args()
super(ParallelTopQueryLayer, self).__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelTopQuerySelfAttention(init_method,
output_layer_init_method,
layer_number)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if hasattr(args, 'ln_fp16'):
self.ln_fp16 = args.ln_fp16
else:
self.ln_fp16 = False
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(
self,
hidden_states,
query_hidden_state,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
        # hidden_states: [s, b, h]
        assert query_hidden_state is not None
# Layer norm at the begining of the transformer layer.
if self.ln_fp16:
layernorm_output = self.input_layernorm(hidden_states)
else:
layernorm_output = self.input_layernorm(hidden_states.float()).half()
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
query_hidden_state,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the self attention.
if self.ln_fp16:
layernorm_output = self.post_attention_layernorm(layernorm_input)
else:
layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()
# MLP.
mlp_output, _ = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = mlp_output + residual
if get_key_value:
output = [output, presents]
return output
class ParallelTopQueryLayerPipe(ParallelTopQueryLayer):
"""Extends ParallelTopQueryLayer to forward attention_mask through the pipeline.
Forward has two usages that affect attention mask communication:
1) forward((input, attn_mask) , **kwargs) -> (output, mask)
When the attention mask is provided as the second positional
argument, typical pipeline behavior is used and both the output
*and* mask are returned in a tuple. This tuple is then forwarded
to the next stage in the pipeline.
This version is useful if masks are dynamic.
2) forward(input, **kwargs) -> output
When the mask is static over all samples, it is advantageous to
cache the mask and avoid communicating it.
If no mask is provided, the module will query `self._args.attn_mask`
for the mask and only return `super().forward(...)`
"""
def forward(self, inputs, **kwargs):
assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
if torch.is_tensor(inputs) or len(inputs) == 2:
# No attention mask forwarded, search for args.attn_mask
if not hasattr(self, "_args"):
self._args = get_args()
hidden_states, query_hidden_state = inputs
attention_mask = self._args.attn_mask
return super().forward(hidden_states, query_hidden_state, attention_mask, **kwargs)
elif len(inputs) == 3:
# Attention mask is an activation.
            hidden_states, query_hidden_state, attention_mask = inputs[0], inputs[1], inputs[2]
return super().forward(*inputs, **kwargs), attention_mask
else:
raise RuntimeError("Received more inputs than understood.")
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, init_method, output_layer_init_method):
super(ParallelTransformer, self).__init__()
args = get_args()
        # Store activation checkpointing flag.
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
# Number of layers:
self.num_layers = args.num_layers
self.num_unique_layers = None
#################
assert self.num_unique_layers is None
#################
if self.num_unique_layers is None:
self.num_unique_layers = self.num_layers
assert self.num_layers % self.num_unique_layers == 0, \
'number of layers should be divisible by number of unique layers'
self.param_sharing_style = 'grouped'
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method,
output_layer_init_method, layer_number)
self.layers = torch.nn.ModuleList(
[build_layer(i + 1) for i in range(self.num_unique_layers)])
self.topQueryLayer = ParallelTopQueryLayer(
init_method,
output_layer_init_method, self.num_unique_layers)
# Final layer norm before output.
if hasattr(args, 'ln_fp16'):
self.ln_fp16 = args.ln_fp16
else:
self.ln_fp16 = False
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
def _get_layer_index(self, layer_number):
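        # Worked example (hypothetical sizes): with num_layers = 4 and
        # num_unique_layers = 2, layers (0, 1, 2, 3) map to parameters
        # (0, 1, 0, 1) under 'grouped' sharing and (0, 0, 1, 1) under 'spaced'.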
if self.param_sharing_style == 'grouped':
return layer_number % self.num_unique_layers
if self.param_sharing_style == 'spaced':
return layer_number // (self.num_layers // self.num_unique_layers)
assert False, 'should not be here'
def _get_layer(self, layer_number):
return self.layers[self._get_layer_index(layer_number)]
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, inputs[1])
return x_
return custom_forward
# Make sure memory is freed.
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states,
query_hidden_state,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
# Checks
if layer_past is not None:
assert get_key_value, \
'for not None values in layer_past, ' \
'expected get_key_value to be set'
if get_key_value:
assert not self.checkpoint_activations, \
'get_key_value does not work with ' \
'activation checkpointing'
        # data format change to avoid explicit transposes: [b s h] --> [s b h]
hidden_states = hidden_states.transpose(0, 1).contiguous()
query_hidden_state = query_hidden_state.transpose(0, 1).contiguous()
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
else:
if get_key_value:
presents = []
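                # Collects one (key, value) pair per transformer layer; the
                # top query layer appends one more pair below.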
for index in range(self.num_layers):
layer = self._get_layer(index)
past = None
if layer_past is not None:
past = layer_past[index]
hidden_states = layer(hidden_states,
attention_mask,
layer_past=past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
if self.ln_fp16:
hidden_states_ = self.final_layernorm(hidden_states)
else:
hidden_states_ = self.final_layernorm(hidden_states.float()).half()
#################################
# top query layer
#################################
past = None
if layer_past is not None:
past = layer_past[self.num_layers]
hidden_states = self.topQueryLayer(hidden_states_,
query_hidden_state,
attention_mask,
layer_past=past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
# reverting data format change [s b h] --> [b s h]
output = hidden_states.transpose(0, 1).contiguous()
if get_key_value:
output = [output, presents]
return output