# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math

import torch
from torch.nn import LayerNorm

from codegeex.megatron import get_args
from codegeex.megatron import mpu
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.utils import fast_gelu

# Flags required to enable jit fusion kernels.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
        h: hidden size
        n: number of attention heads
        p: number of model parallel partitions
        np: n/p
        hp: h/p
        hn: h/n
        b: batch size
        s: sequence length
        l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
                masked-attention-scores = attention_mask_func(
                    unmasked-attention-scores, attention-mask)
"""


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform a nonlinear transformation (fast_gelu),
    and project the state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            # skip_bias_add=True,
        )

        self.activation_func = fast_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            # skip_bias_add=True,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)

        return output, output_bias


class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
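        # Illustrative example (hypothetical sizes, not read from args): with
        # hidden_size h = 4096, num_attention_heads n = 32 and a model-parallel
        # world size p = 2, the values below are hp = 2048, np = 16, hn = 128.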
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        if hasattr(args, 'attention_upweight'):
            self.attention_upweight = args.attention_upweight
        else:
            self.attention_upweight = None

        # Separate query, key and value projections (column-parallel).
        self.query = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)
        self.key = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)
        self.value = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================
        query_layer, _ = self.query(hidden_states)
        key_layer, _ = self.key(hidden_states)
        value_layer, _ = self.value(hidden_states)

        new_query_layer_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_query_layer_shape)

        new_query_layer_shape = key_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        key_layer = key_layer.view(*new_query_layer_shape)

        new_query_layer_shape = value_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        value_layer = value_layer.view(*new_query_layer_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, sq, sk]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.contiguous().view(
            output_size[2], output_size[0] * output_size[1], -1)
        key_layer = key_layer.contiguous().view(
            output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [b * np, sq, sk]
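        # The two transposed views below are [b * np, sq, hn] and
        # [b * np, hn, sk]; their batched product is the [b * np, sq, sk]
        # score matrix, scaled by 1/sqrt(hn) through self.norm_factor.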
        matmul_result = torch.matmul(
            query_layer.transpose(0, 1),
            key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        if self.attention_upweight is not None and layer_past is None:
            log_attention_weights = torch.zeros(
                attention_scores.size(3),
                attention_scores.size(3),
                device=torch.cuda.current_device(),
                dtype=torch.half if self.fp16 else torch.float32)
            if prompt_length is None:
                log_attention_weights = self.attention_upweight
            else:
                log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
            attention_scores += log_attention_weights

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================
        if context_length is not None:
            attention_mask = torch.clone(attention_mask)
            attention_mask[:, :, context_length:, :] = True

        # attention scores and attention mask [b, np, sq, sk]
        # attention_scores = attention_mask_func(attention_scores, attention_mask)
        attention_scores = attention_scores - attention_mask * 10000.0
        if self.attention_softmax_in_fp32:
            attention_probs = self.softmax(attention_scores.float()).half()
        else:
            attention_probs = self.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sq, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sq, b * np, hn]
        value_layer = value_layer.view(
            value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1)

        context_layer = torch.bmm(
            attention_probs,
            value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


class ParallelTopQuerySelfAttention(MegatronModule):
    """Parallel top query self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
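
    Unlike ParallelSelfAttention, the query here is computed from a separate
    `query_hidden_state` (the top query embedding), while keys and values are
    computed from `hidden_states`.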
""" def __init__(self, init_method, output_layer_init_method, layer_number): super(ParallelTopQuerySelfAttention, self).__init__() args = get_args() self.fp16 = args.fp16 self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 self.layer_number = max(1, layer_number) if hasattr(args, 'attention_upweight_top'): self.attention_upweight = args.attention_upweight_top else: self.attention_upweight = None # Per attention head and per partition values. world_size = mpu.get_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size) self.hidden_size_per_attention_head = mpu.divide( args.hidden_size, args.num_attention_heads) self.num_attention_heads_per_partition = mpu.divide( args.num_attention_heads, world_size) self.query = mpu.ColumnParallelLinear( args.hidden_size, args.hidden_size, gather_output=False, init_method=init_method) self.key = mpu.ColumnParallelLinear( args.hidden_size, args.hidden_size, gather_output=False, init_method=init_method) self.value = mpu.ColumnParallelLinear( args.hidden_size, args.hidden_size, gather_output=False, init_method=init_method) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.softmax = torch.nn.Softmax(dim=-1) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. self.attention_dropout = torch.nn.Dropout(args.attention_dropout) # Output. self.dense = mpu.RowParallelLinear( args.hidden_size, args.hidden_size, input_is_parallel=False, init_method=output_layer_init_method, skip_bias_add=True) def forward( self, hidden_states, query_hidden_state, attention_mask, layer_past=None, get_key_value=False, prompt_length=None, context_length=None, ): # hidden_states: [sq, b, h] query_layer, _ = self.query(query_hidden_state) key_layer, _ = self.key(hidden_states) value_layer, _ = self.value(hidden_states) new_query_layer_shape = query_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_query_layer_shape) new_query_layer_shape = key_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) key_layer = key_layer.view(*new_query_layer_shape) new_query_layer_shape = value_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) value_layer = value_layer.view(*new_query_layer_shape) # ================================== # Adjust key and value for inference # ================================== if layer_past is not None: past_key, past_value = layer_past key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) if get_key_value: present = (key_layer, value_layer) # =================================== # Raw attention scores. [b, np, sq, sk] # =================================== # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [s, b, np, hn] -> [s, b * np, hn] query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) # Raw attention scores. 
        matmul_result = torch.matmul(
            query_layer.transpose(0, 1),
            key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor

        # change view to [b, np, s, s]
        attention_scores = matmul_result.view(*output_size)

        if self.attention_upweight is not None and layer_past is None:
            log_attention_weights = torch.zeros(
                attention_scores.size(3),
                attention_scores.size(3),
                device=torch.cuda.current_device(),
                dtype=torch.half if self.fp16 else torch.float32)
            if prompt_length is None:
                log_attention_weights = self.attention_upweight
            else:
                log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight
            attention_scores += log_attention_weights

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================
        if context_length is not None:
            attention_mask = torch.clone(attention_mask)
            attention_mask[:, :, context_length:, :] = True

        # attention scores and attention mask [b, np, sq, sk]
        # attention_scores = attention_mask_func(attention_scores, attention_mask)
        attention_scores = attention_scores - attention_mask * 10000.0
        if self.attention_softmax_in_fp32:
            attention_probs = self.softmax(attention_scores.float()).half()
        else:
            attention_probs = self.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sq, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sq, b * np, hn]
        value_layer = value_layer.view(
            value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(
            attention_probs,
            value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelSelfAttention(init_method,
                                               output_layer_init_method,
                                               layer_number)

        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon)

        if hasattr(args, 'attention_upweight'):
            self.attention_upweight = args.attention_upweight
        else:
            self.attention_upweight = None

        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False

        # MLP
        self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [b, s, h]
        if self.ln_fp16:
            layernorm_output = self.input_layernorm(hidden_states)
        else:
            layernorm_output = self.input_layernorm(hidden_states.float()).half()

        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value,
                           prompt_length=prompt_length,
                           context_length=context_length)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
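        # When ln_fp16 is False, LayerNorm is computed in fp32 for numerical
        # stability and the result is cast back to half precision.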
        if self.ln_fp16:
            layernorm_output = self.post_attention_layernorm(layernorm_input)
        else:
            layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()

        # MLP.
        mlp_output, _ = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = mlp_output + residual

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTopQueryLayer(MegatronModule):
    """A single top query layer.

    Top query layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTopQueryLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelTopQuerySelfAttention(init_method,
                                                       output_layer_init_method,
                                                       layer_number)

        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon)

        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False

        # MLP
        self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # hidden_states: [b, s, h]
        assert query_hidden_state is not None

        # Layer norm at the beginning of the transformer layer.
        if self.ln_fp16:
            layernorm_output = self.input_layernorm(hidden_states)
        else:
            layernorm_output = self.input_layernorm(hidden_states.float()).half()

        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           query_hidden_state,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value,
                           prompt_length=prompt_length,
                           context_length=context_length)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        if self.ln_fp16:
            layernorm_output = self.post_attention_layernorm(layernorm_input)
        else:
            layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half()

        # MLP.
        mlp_output, _ = self.mlp(layernorm_output)

        # Second residual connection.
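        # Unlike the attention path, no dropout is applied to the MLP output
        # before it is added to the residual.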
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = mlp_output + residual

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers:
        self.num_layers = args.num_layers
        self.num_unique_layers = None

        #################
        assert self.num_unique_layers is None
        #################

        if self.num_unique_layers is None:
            self.num_unique_layers = self.num_layers
        assert self.num_layers % self.num_unique_layers == 0, \
            'number of layers should be divisible by number of unique layers'
        self.param_sharing_style = 'grouped'

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method, output_layer_init_method, layer_number)

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_unique_layers)])

        self.topQueryLayer = ParallelTopQueryLayer(
            init_method, output_layer_init_method, self.num_unique_layers)

        # Final layer norm before output.
        if hasattr(args, 'ln_fp16'):
            self.ln_fp16 = args.ln_fp16
        else:
            self.ln_fp16 = False

        self.final_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon)

    def _get_layer_index(self, layer_number):
        if self.param_sharing_style == 'grouped':
            return layer_number % self.num_unique_layers
        if self.param_sharing_style == 'spaced':
            return layer_number // (self.num_layers // self.num_unique_layers)
        assert False, 'should not be here'

    def _get_layer(self, layer_number):
        return self.layers[self._get_layer_index(layer_number)]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_
            return custom_forward

        # Make sure memory is freed.
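        # Layers are then checkpointed in chunks of checkpoint_num_layers:
        # activations inside each chunk are recomputed in the backward pass
        # rather than stored.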
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        # Checks
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        query_hidden_state = query_hidden_state.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value,
                                      prompt_length=prompt_length,
                                      context_length=context_length)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        if self.ln_fp16:
            hidden_states_ = self.final_layernorm(hidden_states)
        else:
            hidden_states_ = self.final_layernorm(hidden_states.float()).half()

        #################################
        # top query layer
        #################################
        past = None
        if layer_past is not None:
            past = layer_past[self.num_layers]

        hidden_states = self.topQueryLayer(hidden_states_,
                                           query_hidden_state,
                                           attention_mask,
                                           layer_past=past,
                                           get_key_value=get_key_value,
                                           prompt_length=prompt_length,
                                           context_length=context_length)

        if get_key_value:
            hidden_states, present = hidden_states
            presents.append(present)

        # Reverting data format change [s b h] --> [b s h].
        output = hidden_states.transpose(0, 1).contiguous()

        if get_key_value:
            output = [output, presents]

        return output
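

if __name__ == "__main__":
    # Minimal shape-check sketch (illustrative only, not part of the model):
    # it reproduces the raw attention-score math used in the forward passes
    # above with plain CPU tensors and made-up sizes (s = sequence length,
    # b = batch, n = heads per partition, hn = hidden size per head).
    s, b, n, hn = 8, 2, 4, 16
    q = torch.randn(s, b, n, hn).view(s, b * n, hn)
    k = torch.randn(s, b, n, hn).view(s, b * n, hn)
    scores = torch.matmul(q.transpose(0, 1),
                          k.transpose(0, 1).transpose(1, 2)) / math.sqrt(hn)
    # [b * n, s, s] -> [b, n, s, s], matching the attention_scores layout.
    print("attention scores:", tuple(scores.view(b, n, s, s).shape))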