# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
from torch.nn import LayerNorm
from codegeex.megatron import get_args
from codegeex.megatron import mpu
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.utils import fast_gelu
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
"""
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
"""
def __init__(self, init_method, output_layer_init_method):
super(ParallelMLP, self).__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4 * args.hidden_size,
gather_output=False,
init_method=init_method,
# skip_bias_add=True,
)
self.activation_func = fast_gelu
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.hidden_size,
input_is_parallel=False,
init_method=output_layer_init_method,
# skip_bias_add=True,
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
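# Shape flow through ParallelMLP under p-way tensor model parallelism (a sketch
# implied by gather_output=False on the column-parallel projection; in the usual
# Megatron scheme the row-parallel projection all-reduces its output):
#   [s, b, h] -> dense_h_to_4h -> [s, b, 4h/p] -> fast_gelu
#             -> dense_4h_to_h -> [s, b, h]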
class ParallelSelfAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
self.layer_number = max(1, layer_number)
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
if hasattr(args, 'attention_upweight'):
self.attention_upweight = args.attention_upweight
else:
self.attention_upweight = None
# Strided linear layer.
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.key = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.value = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.softmax = torch.nn.Softmax(dim=-1)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of model parallel partitions,
# but on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
args.hidden_size,
input_is_parallel=False,
init_method=output_layer_init_method,
skip_bias_add=True)
def forward(
self,
hidden_states,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
# hidden_states: [sq, b, h]
# =====================
# Query, Key, and Value
# =====================
query_layer, _ = self.query(hidden_states)
key_layer, _ = self.key(hidden_states)
value_layer, _ = self.value(hidden_states)
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.matmul(query_layer.transpose(0, 1),
key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
if self.attention_upweight is not None and layer_past is None:
log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
device=torch.cuda.current_device(),
dtype=torch.half if self.fp16 else torch.float32)
if prompt_length is None:
log_attention_weights = self.attention_upweight
else:
log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
attention_scores += log_attention_weights
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
if context_length is not None:
attention_mask = torch.clone(attention_mask)
attention_mask[:, :, context_length:, :] = True
# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sq, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sq, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
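# matmul: [b * np, sq, hn]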
context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
if get_key_value:
output = [output, present]
return output, bias
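# Incremental decoding sketch (hypothetical driver code, not part of this
# module): on the first call pass get_key_value=True and keep the returned
# `present` (keys/values cached along the sequence dimension); on later calls
# feed it back via layer_past so only the new tokens are projected:
#   (out, present), bias = attn(hidden_states, mask, get_key_value=True)
#   (out, present), bias = attn(next_states, mask,
#                               layer_past=present, get_key_value=True)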
class ParallelTopQuerySelfAttention(MegatronModule):
"""Parallel top query self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number):
super(ParallelTopQuerySelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
self.layer_number = max(1, layer_number)
if hasattr(args, 'attention_upweight_top'):
self.attention_upweight = args.attention_upweight_top
else:
self.attention_upweight = None
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.key = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.value = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
gather_output=False,
init_method=init_method)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.softmax = torch.nn.Softmax(dim=-1)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of model parallel partitions,
# but on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
args.hidden_size,
input_is_parallel=False,
init_method=output_layer_init_method,
skip_bias_add=True)
def forward(
self,
hidden_states,
query_hidden_state,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
# hidden_states: [sq, b, h]
query_layer, _ = self.query(query_hidden_state)
key_layer, _ = self.key(hidden_states)
value_layer, _ = self.value(hidden_states)
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [s, b, np, hn] -> [s, b * np, hn]
query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.matmul(query_layer.transpose(0, 1),
key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
# change view to [b, np, s, s]
attention_scores = matmul_result.view(*output_size)
if self.attention_upweight is not None and layer_past is None:
log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
device=torch.cuda.current_device(),
dtype=torch.half if self.fp16 else torch.float32)
if prompt_length is None:
log_attention_weights = self.attention_upweight
else:
log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
attention_scores += log_attention_weights
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
if context_length is not None:
attention_mask = torch.clone(attention_mask)
attention_mask[:, :, context_length:, :] = True
# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sq, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sq, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
if get_key_value:
output = [output, present]
return output, bias
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, False)
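# Both @torch.jit.script variants above simply fix `training` as a constant so
# the fuser can specialize; each computes
#   out = dropout(x + bias, p) + residual
# exactly as bias_dropout_add() does.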
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelSelfAttention(init_method,
output_layer_init_method,
layer_number)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the attention output.
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if hasattr(args, 'attention_upweight'):
self.attention_upweight = args.attention_upweight
else:
self.attention_upweight = None
if hasattr(args, 'ln_fp16'):
self.ln_fp16 = args.ln_fp16
else:
self.ln_fp16 = False
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(
self,
hidden_states,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
# hidden_states: [s, b, h]
if self.ln_fp16:
layernorm_output = self.input_layernorm(hidden_states)
else:
layernorm_output = self.input_layernorm(hidden_states.float()).half()
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for a nn.Module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the self attention.
if self.ln_fp16:
layernorm_output = self.post_attention_layernorm(layernorm_input)
else:
layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()
# MLP.
mlp_output, _ = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = mlp_output + residual
if get_key_value:
output = [output, presents]
return output
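# Residual structure implemented by this (pre-layernorm) layer:
#   x1 = x  + dropout(attn(LN(x)) + attn_bias)
#   y  = x1 + mlp(LN(x1))
# (with apply_residual_connection_post_layernorm, LN(x) / LN(x1) replace
#  x / x1 as the residual branches)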
class ParallelTopQueryLayer(MegatronModule):
"""A single top query layer.
Top query layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number):
args = get_args()
super(ParallelTopQueryLayer, self).__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelTopQuerySelfAttention(init_method,
output_layer_init_method,
layer_number)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the attention output.
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if hasattr(args, 'ln_fp16'):
self.ln_fp16 = args.ln_fp16
else:
self.ln_fp16 = False
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(
self,
hidden_states,
query_hidden_state,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
# hidden_states: [s, b, h]
assert query_hidden_state is not None
# Layer norm at the beginning of the transformer layer.
if self.ln_fp16:
layernorm_output = self.input_layernorm(hidden_states)
else:
layernorm_output = self.input_layernorm(hidden_states.float()).half()
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
query_hidden_state,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for a nn.Module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the self attention.
if self.ln_fp16:
layernorm_output = self.post_attention_layernorm(layernorm_input)
else:
layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()
# MLP.
mlp_output, _ = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = mlp_output + residual
if get_key_value:
output = [output, presents]
return output
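# In the top query layer, attention queries are computed from the separate
# `query_hidden_state` input while keys and values come from the output of the
# last transformer layer; everything else mirrors ParallelTransformerLayer.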
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, init_method, output_layer_init_method):
super(ParallelTransformer, self).__init__()
args = get_args()
# Store activation checkpointing flag.
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
# Number of layers:
self.num_layers = args.num_layers
self.num_unique_layers = None
#################
assert self.num_unique_layers is None
#################
if self.num_unique_layers is None:
self.num_unique_layers = self.num_layers
assert self.num_layers % self.num_unique_layers == 0, \
'number of layers should be divisible by number of unique layers'
self.param_sharing_style = 'grouped'
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method,
output_layer_init_method, layer_number)
self.layers = torch.nn.ModuleList(
[build_layer(i + 1) for i in range(self.num_unique_layers)])
self.topQueryLayer = ParallelTopQueryLayer(
init_method,
output_layer_init_method, self.num_unique_layers)
# Final layer norm before output.
if hasattr(args, 'ln_fp16'):
self.ln_fp16 = args.ln_fp16
else:
self.ln_fp16 = False
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
def _get_layer_index(self, layer_number):
if self.param_sharing_style == 'grouped':
return layer_number % self.num_unique_layers
if self.param_sharing_style == 'spaced':
return layer_number // (self.num_layers // self.num_unique_layers)
assert False, 'should not be here'
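# Example (num_layers=4, num_unique_layers=2):
#   'grouped' maps layers to unique indices [0, 1, 0, 1];
#   'spaced' maps them to [0, 0, 1, 1].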
def _get_layer(self, layer_number):
return self.layers[self._get_layer_index(layer_number)]
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, inputs[1])
return x_
return custom_forward
# Make sure memory is freed.
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(
self,
hidden_states,
query_hidden_state,
attention_mask,
layer_past=None,
get_key_value=False,
prompt_length=None,
context_length=None,
):
# Checks
if layer_past is not None:
assert get_key_value, \
'for not None values in layer_past, ' \
'expected get_key_value to be set'
if get_key_value:
assert not self.checkpoint_activations, \
'get_key_value does not work with ' \
'activation checkpointing'
# data format change to avoid explicit transposes: [b s h] --> [s b h]
hidden_states = hidden_states.transpose(0, 1).contiguous()
query_hidden_state = query_hidden_state.transpose(0, 1).contiguous()
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
else:
if get_key_value:
presents = []
for index in range(self.num_layers):
layer = self._get_layer(index)
past = None
if layer_past is not None:
past = layer_past[index]
hidden_states = layer(hidden_states,
attention_mask,
layer_past=past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
if self.ln_fp16:
hidden_states_ = self.final_layernorm(hidden_states)
else:
hidden_states_ = self.final_layernorm(hidden_states.float()).half()
#################################
# top query layer
#################################
past = None
if layer_past is not None:
past = layer_past[self.num_layers]
hidden_states = self.topQueryLayer(hidden_states_,
query_hidden_state,
attention_mask,
layer_past=past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
# reverting data format change [s b h] --> [b s h]
output = hidden_states.transpose(0, 1).contiguous()
if get_key_value:
output = [output, presents]
return output
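# Usage sketch (hypothetical; assumes megatron args and mpu model parallelism
# are initialized elsewhere):
#   transformer = ParallelTransformer(init_method, output_layer_init_method)
#   # hidden_states, query_hidden_state: [b, s, h];
#   # attention_mask: additive/boolean mask broadcastable to [b, np, s, s]
#   output = transformer(hidden_states, query_hidden_state, attention_mask)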