mirror of https://github.com/THUDM/CodeGeeX.git
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math
import torch
from torch.nn import LayerNorm

from codegeex.megatron import get_args
from codegeex.megatron import mpu
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.utils import fast_gelu

# Flags required to enable jit fusion kernels.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
                masked-attention-scores = attention_mask_func(
                    unmasked-attention-scores, attention-mask)
"""
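
# A quick worked example of the notation above (purely illustrative numbers,
# not a specific CodeGeeX configuration): with h = 1024, n = 16 heads and
# p = 4 model-parallel partitions, each partition holds np = 4 heads, a hidden
# slice of hp = 256, and every head works on hn = 64 dimensions, so an
# activation of shape [s, b, h] = [2048, 8, 1024] is handled as per-partition
# [2048, 8, 256] slices inside the column-parallel projections below.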


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """
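
    # A minimal sketch of the intended data flow (following the standard
    # Megatron MLP layout): the column-parallel h->4h projection produces a
    # local [s, b, 4h/p] slice on each model-parallel rank, fast_gelu is
    # applied element-wise to that slice, and the row-parallel 4h->h
    # projection combines the partial results back into the full [s, b, h]
    # output, returned together with its bias.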

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            # skip_bias_add=True,
        )

        self.activation_func = fast_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            # skip_bias_add=True,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)

        return output, output_bias


class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """
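
    # Head partitioning sketch: with model-parallel size p, each rank keeps
    # num_attention_heads / p heads and the matching hidden_size / p slices of
    # the query/key/value projections (gather_output=False keeps those slices
    # local), and the row-parallel output projection recombines the per-rank
    # context vectors into the full hidden size.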

    def __init__(self, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)
        if hasattr(args, 'attention_upweight'):
            self.attention_upweight = args.attention_upweight
        else:
            self.attention_upweight = None
        # Strided linear layer.
        self.query = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)
        self.key = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)
        self.value = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs for different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        query_layer, _ = self.query(hidden_states)
        key_layer, _ = self.key(hidden_states)
        value_layer, _ = self.value(hidden_states)

        new_query_layer_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_query_layer_shape)

        new_query_layer_shape = key_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        key_layer = key_layer.view(*new_query_layer_shape)

        new_query_layer_shape = value_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        value_layer = value_layer.view(*new_query_layer_shape)
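
        # Shape bookkeeping for the three views above (illustrative numbers,
        # not a fixed configuration): with sq = 128, b = 4, np = 4 heads per
        # partition and hn = 64, query/key/value each go from [128, 4, 256] to
        # [128, 4, 4, 64], i.e. [sq, b, np * hn] -> [sq, b, np, hn].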

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, sq, sk]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
        key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.matmul(query_layer.transpose(0, 1),
                                     key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)
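
        # The score for each head is Q K^T / sqrt(hn): queries and keys are
        # folded to [b * np, sq, hn] and [b * np, sk, hn], multiplied in a
        # single batched matmul, scaled by 1 / norm_factor, and then viewed
        # back as [b, np, sq, sk] for masking and softmax.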

        if self.attention_upweight is not None and layer_past is None:
            log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
                                                device=torch.cuda.current_device(),
                                                dtype=torch.half if self.fp16 else torch.float32)
            if prompt_length is None:
                log_attention_weights = self.attention_upweight
            else:
                log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
            attention_scores += log_attention_weights

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if context_length is not None:
            attention_mask = torch.clone(attention_mask)
            attention_mask[:, :, context_length:, :] = True

        # attention scores and attention mask [b, np, sq, sk]
        # attention_scores = attention_mask_func(attention_scores, attention_mask)
        attention_scores = attention_scores - attention_mask * 10000.0
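
        # The mask is applied additively: positions that must not be attended
        # to are expected to be 1 (True) in `attention_mask`, so subtracting
        # mask * 10000 pushes their logits far negative and the softmax
        # assigns them near-zero weight, a finite stand-in for -inf that stays
        # safe in fp16.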
        if self.attention_softmax_in_fp32:
            attention_probs = self.softmax(attention_scores.float()).half()
        else:
            attention_probs = self.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sq, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sq, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


class ParallelTopQuerySelfAttention(MegatronModule):
    """Parallel top query self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """
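
    # This mirrors ParallelSelfAttention with one difference: the query
    # projection is computed from a separate `query_hidden_state` tensor,
    # while keys and values still come from `hidden_states`. It is used by the
    # final "top query" layer, which attends over the transformer output with
    # a dedicated query input supplied by the caller.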

    def __init__(self, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelTopQuerySelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        if hasattr(args, 'attention_upweight_top'):
            self.attention_upweight = args.attention_upweight_top
        else:
            self.attention_upweight = None
        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        self.query = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.key = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.value = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs for different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [sq, b, h]

        query_layer, _ = self.query(query_hidden_state)
        key_layer, _ = self.key(hidden_states)
        value_layer, _ = self.value(hidden_states)

        new_query_layer_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_query_layer_shape)

        new_query_layer_shape = key_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        key_layer = key_layer.view(*new_query_layer_shape)

        new_query_layer_shape = value_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        value_layer = value_layer.view(*new_query_layer_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, sq, sk]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [s, b, np, hn] -> [s, b * np, hn]
        query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
        key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.matmul(query_layer.transpose(0, 1),
                                     key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor

        # change view to [b, np, s, s]
        attention_scores = matmul_result.view(*output_size)

        if self.attention_upweight is not None and layer_past is None:
            log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3),
                                                device=torch.cuda.current_device(),
                                                dtype=torch.half if self.fp16 else torch.float32)
            if prompt_length is None:
                log_attention_weights = self.attention_upweight
            else:
                log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
            attention_scores += log_attention_weights

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if context_length is not None:
            attention_mask = torch.clone(attention_mask)
            attention_mask[:, :, context_length:, :] = True

        # attention scores and attention mask [b, np, sq, sk]
        # attention_scores = attention_mask_func(attention_scores, attention_mask)
        attention_scores = attention_scores - attention_mask * 10000.0
        if self.attention_softmax_in_fp32:
            attention_probs = self.softmax(attention_scores.float()).half()
        else:
            attention_probs = self.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sq, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sq, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
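
# The fused helpers above are equivalent to the following reference
# computation (a minimal sketch; x, bias and residual are assumed to be
# broadcastable tensors and prob a dropout probability):
#
#     out = residual + torch.nn.functional.dropout(x + bias, p=prob, training=training)
#
# The @torch.jit.script wrappers exist so the add + dropout + residual chain
# can be fused into a single kernel, while get_bias_dropout_add() is the
# un-fused fallback used when bias_dropout_fusion is disabled.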


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelSelfAttention(init_method,
                                               output_layer_init_method,
                                               layer_number)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the output of the attention block.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)
        if hasattr(args, 'attention_upweight'):
            self.attention_upweight = args.attention_upweight
        else:
            self.attention_upweight = None
        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False
        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [s, b, h]
        if self.ln_fp16:
            layernorm_output = self.input_layernorm(hidden_states)
        else:
            layernorm_output = self.input_layernorm(hidden_states.float()).half()

        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value,
                           prompt_length=prompt_length,
                           context_length=context_length)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states
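
        # With the pre-LayerNorm configuration the residual branch carries the
        # raw hidden_states; only when
        # apply_residual_connection_post_layernorm is set does the normalized
        # tensor feed both the attention input and the residual path.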

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        if self.ln_fp16:
            layernorm_output = self.post_attention_layernorm(layernorm_input)
        else:
            layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()

        # MLP.
        mlp_output, _ = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = mlp_output + residual

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTopQueryLayer(MegatronModule):
    """A single top query layer.

    Top query layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTopQueryLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelTopQuerySelfAttention(init_method,
                                                       output_layer_init_method,
                                                       layer_number)

        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the output of the attention block.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [s, b, h]
        assert query_hidden_state is not None

        # Layer norm at the beginning of the transformer layer.
        if self.ln_fp16:
            layernorm_output = self.input_layernorm(hidden_states)
        else:
            layernorm_output = self.input_layernorm(hidden_states.float()).half()

        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           query_hidden_state,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value,
                           prompt_length=prompt_length,
                           context_length=context_length)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        if self.ln_fp16:
            layernorm_output = self.post_attention_layernorm(layernorm_input)
        else:
            layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()

        # MLP.
        mlp_output, _ = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = mlp_output + residual

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers:
        self.num_layers = args.num_layers
        self.num_unique_layers = None

        #################
        assert self.num_unique_layers is None
        #################

        if self.num_unique_layers is None:
            self.num_unique_layers = self.num_layers
        assert self.num_layers % self.num_unique_layers == 0, \
            'number of layers should be divisible by number of unique layers'
        self.param_sharing_style = 'grouped'

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method, layer_number)

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_unique_layers)])

        self.topQueryLayer = ParallelTopQueryLayer(
            init_method,
            output_layer_init_method, self.num_unique_layers)

        # Final layer norm before output.
        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False

        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

    def _get_layer_index(self, layer_number):
        if self.param_sharing_style == 'grouped':
            return layer_number % self.num_unique_layers
        if self.param_sharing_style == 'spaced':
            return layer_number // (self.num_layers // self.num_unique_layers)
        assert False, 'should not be here'
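
    # Example of the two sharing styles with num_layers = 4 and
    # num_unique_layers = 2: 'grouped' maps layers 0..3 onto parameter sets
    # [0, 1, 0, 1], while 'spaced' maps them onto [0, 0, 1, 1]. Since __init__
    # above forces num_unique_layers == num_layers, both styles currently
    # reduce to the identity mapping.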

    def _get_layer(self, layer_number):
        return self.layers[self._get_layer_index(layer_number)]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""

        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
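        # Checkpoint the layers in blocks of checkpoint_num_layers: only the
        # inputs of each block are kept during the forward pass and the
        # activations inside a block are recomputed in backward, trading extra
        # compute for a large reduction in stored activations.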
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):

        # Checks
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        # data format change to avoid explicit transposes: [b s h] --> [s b h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        query_hidden_state = query_hidden_state.transpose(0, 1).contiguous()
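
        # The layers below operate on [s, b, h]; putting the sequence
        # dimension first lets the attention code fold to [sq, b * np, hn]
        # with plain views instead of extra transposes, so the [b, s, h]
        # inputs are transposed once here and transposed back just before
        # returning.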

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value,
                                      prompt_length=prompt_length,
                                      context_length=context_length)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        if self.ln_fp16:
            hidden_states_ = self.final_layernorm(hidden_states)
        else:
            hidden_states_ = self.final_layernorm(hidden_states.float()).half()

        #################################
        # top query layer
        #################################
        past = None
        if layer_past is not None:
            past = layer_past[self.num_layers]
        hidden_states = self.topQueryLayer(hidden_states_,
                                           query_hidden_state,
                                           attention_mask,
                                           layer_past=past,
                                           get_key_value=get_key_value,
                                           prompt_length=prompt_length,
                                           context_length=context_length)

        if get_key_value:
            hidden_states, present = hidden_states
            presents.append(present)

        # reverting data format change [s b h] --> [b s h]
        output = hidden_states.transpose(0, 1).contiguous()

        if get_key_value:
            output = [output, presents]

        return output