import math import oneflow as torch import oneflow.nn.functional as F from oneflow.nn.parameter import Parameter from ..quantization import QuantizedLinear def fast_gelu(x): """Mindspore's fast gelu implementation.""" if hasattr(torch._C, 'quick_gelu'): return torch._C.quick_gelu(x) return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(0.851 * (x - torch.abs(x))) class MLP(torch.nn.Module): """MLP. MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. At the end, dropout is also applied. """ def __init__( self, hidden_size, ): super(MLP, self).__init__() self.hidden_size = hidden_size # Project to 4h. self.dense_h_to_4h = torch.nn.Linear( self.hidden_size, 4 * self.hidden_size, ) self.activation_func = fast_gelu # Project back to h. self.dense_4h_to_h = torch.nn.Linear( 4 * self.hidden_size, self.hidden_size, ) def forward(self, hidden_states): # [s, b, 4hp] intermediate_parallel = self.dense_h_to_4h(hidden_states) intermediate_parallel = self.activation_func(intermediate_parallel) # [s, b, h] output = self.dense_4h_to_h(intermediate_parallel) return output class SelfAttention(torch.nn.Module): """self-attention layer abstract class. Self-attention layer takes input with size [b, s, h] and returns output of the same size. """ def __init__( self, hidden_size, num_attention_heads, layer_number, fp16=True, attention_softmax_in_fp32=True, ): super(SelfAttention, self).__init__() self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.fp16 = fp16 self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.layer_number = max(1, layer_number) assert self.hidden_size % self.num_attention_heads == 0 self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) self.query = torch.nn.Linear(self.hidden_size, self.hidden_size) self.key = torch.nn.Linear(self.hidden_size, self.hidden_size) self.value = torch.nn.Linear(self.hidden_size, self.hidden_size) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.softmax = torch.nn.Softmax(dim=-1) self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size) def forward( self, hidden_states, attention_mask, layer_past=None, get_key_value=False, prompt_length=None, context_length=None, layer_id=0, ): # hidden_states: [sq, b, h] # ===================== # Query, Key, and Value # ===================== if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear): query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([hidden_states, hidden_states, hidden_states], [self.query.weight, self.key.weight, self.value.weight], [self.query.bias, self.key.bias, self.value.bias]) else: query_layer = self.query(hidden_states) key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2') if fallback: if hasattr(torch._C, 'fused_codegeex_qkv_reshape'): query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads) else: new_query_layer_shape = query_layer.size()[:-1] + \ (self.num_attention_heads, self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_query_layer_shape) new_query_layer_shape = key_layer.size()[:-1] + \ (self.num_attention_heads, self.hidden_size_per_attention_head) key_layer = key_layer.view(*new_query_layer_shape) new_query_layer_shape = value_layer.size()[:-1] + \ (self.num_attention_heads, self.hidden_size_per_attention_head) value_layer = value_layer.view(*new_query_layer_shape) # ================================== # Adjust key and value for inference # ================================== if layer_past is not None: past_key, past_value = layer_past key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) if get_key_value: present = (key_layer, value_layer) # =================================== # Raw attention scores. [b, np, sq, sk] # =================================== # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) # Raw attention scores. [b * np, sq, sk] matmul_result = torch.matmul(query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # ================================================== # Update attention mask for inference. [b, np, sq, sk] # ================================================== if layer_id == 0: if get_key_value: with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ ..., attention_scores.size(3) - 1, :attention_scores.size(3)].unsqueeze(2) else: attention_mask = attention_mask[ ..., :attention_scores.size(3), :attention_scores.size(3)] if context_length is not None: attention_mask = torch.clone(attention_mask) attention_mask[:, :, context_length:, :] = True attention_mask = ~attention_mask attention_mask = attention_mask.contiguous() # attention scores and attention mask [b, np, sq, sk] # attention_scores = attention_mask_func(attention_scores, attention_mask) if hasattr(torch._C, 'fused_scale_mask_softmax'): if self.attention_softmax_in_fp32: attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half() else: attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0) else: attention_scores = attention_scores - attention_mask * 10000.0 if self.attention_softmax_in_fp32: attention_probs = self.softmax(attention_scores.float()).half() else: attention_probs = self.softmax(attention_scores) # ========================= # Context layer. [sq, b, hp] # ========================= # value_layer -> context layer. # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) # change view [sq, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) # # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ (self.hidden_size,) context_layer = context_layer.view(*new_context_layer_shape) else: if layer_past is not None: past_key, past_value = layer_past key_layer, value_layer = torch._C.fused_attention_concat_past_key_value( past_key=past_key, past_key_layout="MB(HK)", past_value=past_value, past_value_layout="MB(HK)", key=key_layer, key_layout="MB(HK)", value=value_layer, value_layout="MB(HK)", key_head_size=self.hidden_size_per_attention_head, ) if get_key_value: present = (key_layer, value_layer) context_layer = torch._C.fused_multi_head_attention_inference_v2( query=query_layer, key=key_layer, value=value_layer, query_head_size=self.hidden_size_per_attention_head, causal=True, causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0], query_layout="MB(HK)", key_layout="MB(HK)", value_layout="MB(HK)", output_layout="MB(HK)", ) # ================= # Output. [sq, b, h] # ================= output = self.dense(context_layer) if get_key_value: output = [output, present] return output, attention_mask class TopQuerySelfAttention(torch.nn.Module): """Top query self-attention layer abstract class. Self-attention layer takes input with size [b, s, h] and returns output of the same size. """ def __init__( self, hidden_size, num_attention_heads, layer_number, fp16=True, attention_softmax_in_fp32=True, ): super(TopQuerySelfAttention, self).__init__() self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.fp16 = fp16 self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.layer_number = max(1, layer_number) assert self.hidden_size % self.num_attention_heads == 0 self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) self.query = torch.nn.Linear(self.hidden_size, self.hidden_size) self.key = torch.nn.Linear(self.hidden_size, self.hidden_size) self.value = torch.nn.Linear(self.hidden_size, self.hidden_size) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.softmax = torch.nn.Softmax(dim=-1) self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size) def forward( self, hidden_states, query_hidden_state, attention_mask, layer_past=None, get_key_value=False, prompt_length=None, context_length=None, ): # hidden_states: [sq, b, h] if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear): query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([query_hidden_state, hidden_states, hidden_states], [self.query.weight, self.key.weight, self.value.weight], [self.query.bias, self.key.bias, self.value.bias]) else: query_layer = self.query(query_hidden_state) key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2') if fallback: if hasattr(torch._C, 'fused_codegeex_qkv_reshape'): query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads) else: new_query_layer_shape = query_layer.size()[:-1] + \ (self.num_attention_heads, self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_query_layer_shape) new_query_layer_shape = key_layer.size()[:-1] + \ (self.num_attention_heads, self.hidden_size_per_attention_head) key_layer = key_layer.view(*new_query_layer_shape) new_query_layer_shape = value_layer.size()[:-1] + \ (self.num_attention_heads, self.hidden_size_per_attention_head) value_layer = value_layer.view(*new_query_layer_shape) # ================================== # Adjust key and value for inference # ================================== if layer_past is not None: past_key, past_value = layer_past key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) if get_key_value: present = (key_layer, value_layer) # =================================== # Raw attention scores. [b, np, sq, sk] # =================================== # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [s, b, np, hn] -> [s, b * np, hn] query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) # Raw attention scores. [b * np, sq, sk] matmul_result = torch.matmul(query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor # change view to [b, np, s, s] attention_scores = matmul_result.view(*output_size) # ================================================== # Update attention mask for inference. [b, np, sq, sk] # ================================================== if get_key_value: with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ ..., attention_scores.size(3) - 1, :attention_scores.size(3)].unsqueeze(2) else: attention_mask = attention_mask[ ..., :attention_scores.size(3), :attention_scores.size(3)] if context_length is not None: attention_mask = torch.clone(attention_mask) attention_mask[:, :, context_length:, :] = True # attention scores and attention mask [b, np, sq, sk] # attention_scores = attention_mask_func(attention_scores, attention_mask) attention_scores = attention_scores - attention_mask * 10000.0 if self.attention_softmax_in_fp32: attention_probs = self.softmax(attention_scores.float()).half() else: attention_probs = self.softmax(attention_scores) # ========================= # Context layer. [sq, b, hp] # ========================= # value_layer -> context layer. # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) # change view [sq, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ (self.hidden_size,) context_layer = context_layer.view(*new_context_layer_shape) else: if layer_past is not None: past_key, past_value = layer_past key_layer, value_layer = torch._C.fused_attention_concat_past_key_value( past_key=past_key, past_key_layout="MB(HK)", past_value=past_value, past_value_layout="MB(HK)", key=key_layer, key_layout="MB(HK)", value=value_layer, value_layout="MB(HK)", key_head_size=self.hidden_size_per_attention_head, ) if get_key_value: present = (key_layer, value_layer) if hasattr(torch._C, 'fused_multi_head_attention_inference_v2'): context_layer = torch._C.fused_multi_head_attention_inference_v2( query=query_layer, key=key_layer, value=value_layer, query_head_size=self.hidden_size_per_attention_head, causal=True, causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0], query_layout="MB(HK)", key_layout="MB(HK)", value_layout="MB(HK)", output_layout="MB(HK)", ) # ================= # Output. [sq, b, h] # ================= output = self.dense(context_layer) if get_key_value: output = [output, present] return output class TransformerLayer(torch.nn.Module): """A single transformer layer. Transformore layer takes input with size [b, s, h] and returns an output of the same size. """ def __init__( self, hidden_size, num_attention_heads, layer_number, layernorm_epsilon=1e-5, fp16=True, attention_softmax_in_fp32=True, ): super(TransformerLayer, self).__init__() self.hidden_size = hidden_size self.layernorm_epsilon = layernorm_epsilon self.layer_number = layer_number # Layernorm on the input data. self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=self.layernorm_epsilon) # Self attention. self.attention = SelfAttention(hidden_size, num_attention_heads, layer_number, fp16, attention_softmax_in_fp32) # Layernorm on the input data. self.post_attention_layernorm = torch.nn.LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) self.mlp = MLP(self.hidden_size) def forward( self, hidden_states, attention_mask, layer_past=None, get_key_value=False, prompt_length=None, context_length=None, layer_id=0, ): # hidden_states: [b, s, h] # Use FP32 for Layernorm # layernorm_output = self.input_layernorm(hidden_states.float()).half() layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_mask = self.attention(layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value, prompt_length=prompt_length, context_length=context_length, layer_id=layer_id) if get_key_value: attention_output, presents = attention_output # Residual connection. residual = hidden_states layernorm_input = attention_output + residual # Use FP32 for Layernorm # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() layernorm_output = self.post_attention_layernorm(layernorm_input) mlp_output = self.mlp(layernorm_output) output = mlp_output + layernorm_input if get_key_value: output = [output, presents] return output, attention_mask class TopQueryLayer(torch.nn.Module): """A single top query layer. Top query layer takes input with size [b, s, h] and returns an output of the same size. """ def __init__( self, hidden_size, num_attention_heads, layer_number, layernorm_epsilon=1e-5, ): super(TopQueryLayer, self).__init__() self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.layernorm_epsilon = layernorm_epsilon self.layer_number = layer_number # Use FP32 for Layernorm self.input_layernorm = torch.nn.LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) # Self attention. self.attention = TopQuerySelfAttention(self.hidden_size, self.num_attention_heads, self.layer_number) # Layernorm on the input data. self.post_attention_layernorm = torch.nn.LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) # MLP self.mlp = MLP(self.hidden_size) def forward( self, hidden_states, query_hidden_state, attention_mask, layer_past=None, get_key_value=False, prompt_length=None, context_length=None, ): # hidden_states: [b, s, h] assert query_hidden_state != None # Use FP32 for Layernorm # layernorm_output = self.input_layernorm(hidden_states.float()).half() layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output = self.attention(layernorm_output, query_hidden_state, attention_mask, layer_past=layer_past, get_key_value=get_key_value, prompt_length=prompt_length, context_length=context_length) if get_key_value: attention_output, presents = attention_output # Residual connection. residual = hidden_states layernorm_input = attention_output + residual # Use FP32 for Layernorm # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. mlp_output = self.mlp(layernorm_output) # Second residual connection. residual = layernorm_input output = mlp_output + residual if get_key_value: output = [output, presents] return output class Transformer(torch.nn.Module): """Transformer class.""" def __init__( self, hidden_size, num_attention_heads, num_layers, layernorm_epsilon=1e-5, ): super(Transformer, self).__init__() self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.layernorm_epsilon = layernorm_epsilon # Number of layers: self.num_layers = num_layers self.num_unique_layers = None ################# assert self.num_unique_layers is None ################# if self.num_unique_layers is None: self.num_unique_layers = self.num_layers assert self.num_layers % self.num_unique_layers == 0, \ 'number of layers should be divisible by number of unique layers' # Transformer layers. def build_layer(layer_number): return TransformerLayer(self.hidden_size, self.num_attention_heads, layer_number) self.layers = torch.nn.ModuleList( [build_layer(i + 1) for i in range(self.num_unique_layers)]) self.topQueryLayer = TopQueryLayer(self.hidden_size, self.num_attention_heads, self.num_unique_layers) self.final_layernorm = torch.nn.LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) def _get_layer_index(self, layer_number): return layer_number % self.num_unique_layers def _get_layer(self, layer_number): return self.layers[self._get_layer_index(layer_number)] def forward( self, hidden_states, query_hidden_state, attention_mask, layer_past=None, get_key_value=False, prompt_length=None, context_length=None, ): # data format change to avoid explicit tranposes : [b s h] --> [s b h] hidden_states = hidden_states.transpose(0, 1).contiguous() query_hidden_state = query_hidden_state.transpose(0, 1).contiguous() origin_attention_mask = attention_mask if get_key_value: presents = [] for index in range(self.num_layers): layer = self._get_layer(index) past = None if layer_past is not None: past = layer_past[index] hidden_states, attention_mask = layer(hidden_states, attention_mask, layer_past=past, get_key_value=get_key_value, prompt_length=prompt_length, context_length=context_length, layer_id=index) if get_key_value: hidden_states, present = hidden_states presents.append(present) # Use FP32 for Layernorm # hidden_states_ = self.final_layernorm(hidden_states.float()).half() hidden_states_ = self.final_layernorm(hidden_states) ################################# # top query layer ################################# past = None if layer_past is not None: past = layer_past[self.num_layers] hidden_states = self.topQueryLayer(hidden_states_, query_hidden_state, origin_attention_mask, layer_past=past, get_key_value=get_key_value, prompt_length=prompt_length, context_length=context_length) if get_key_value: hidden_states, present = hidden_states presents.append(present) # reverting data format change [s b h] --> [b s h] output = hidden_states.transpose(0, 1).contiguous() if get_key_value: output = [output, presents] return output def state_dict_for_save_checkpoint( self, destination=None, prefix="", keep_vars=False ): return self.state_dict(destination, prefix, keep_vars) class Embedding(torch.nn.Module): """Language model embeddings. Arguments: hidden_size: hidden size vocab_size: vocabulary size max_sequence_length: maximum size of sequence. This is used for positional embedding """ def __init__( self, hidden_size, vocab_size, max_sequence_length, ): super(Embedding, self).__init__() self.hidden_size = hidden_size self.vocab_size = vocab_size self.max_sequence_length = max_sequence_length # Word embeddings. self.word_embeddings = torch.nn.Embedding(self.vocab_size, self.hidden_size) self._word_embeddings_key = 'word_embeddings' # Position embedding. self.position_embeddings = torch.nn.Embedding(self.max_sequence_length, self.hidden_size) self.position_embeddings = self.position_embeddings.half() self._position_embeddings_key = 'position_embeddings' def forward(self, input_ids, position_ids): # Embeddings. words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) embeddings = words_embeddings + position_embeddings return embeddings def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): """For easy load.""" state_dict_ = {} state_dict_[self._word_embeddings_key] \ = self.word_embeddings.state_dict(destination, prefix, keep_vars) state_dict_[self._position_embeddings_key] \ = self.position_embeddings.state_dict( destination, prefix, keep_vars) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" # Word embedding. if self._word_embeddings_key in state_dict: state_dict_ = state_dict[self._word_embeddings_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if 'word_embeddings' in key: state_dict_[key.split('word_embeddings.')[1]] \ = state_dict[key] state_dict_["weight"] = state_dict_["weight"][:self.vocab_size] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. if self._position_embeddings_key in state_dict: state_dict_ = state_dict[self._position_embeddings_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if 'position_embeddings' in key: state_dict_[key.split('position_embeddings.')[1]] \ = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) class QueryEmbedding(torch.nn.Module): """Language model embeddings. Arguments: hidden_size: hidden size vocab_size: vocabulary size max_sequence_length: maximum size of sequence. This is used for positional embedding """ def __init__( self, hidden_size, vocab_size, max_sequence_length, ): super(QueryEmbedding, self).__init__() self.hidden_size = hidden_size self.vocab_size = vocab_size self.max_sequence_length = max_sequence_length # Top query position embedding (serial). self.top_query_embeddings = torch.nn.Embedding(self.max_sequence_length, self.hidden_size) self.top_query_embeddings = self.top_query_embeddings.half() self._top_query_embeddings_key = 'top_query_embeddings' def forward(self, position_ids): # Embeddings. embeddings = self.top_query_embeddings(position_ids) return embeddings def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): """For easy load.""" state_dict_ = {} state_dict_[self._top_query_embeddings_key] \ = self.top_query_embeddings.state_dict( destination, prefix, keep_vars) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" # Position embedding. if self._top_query_embeddings_key in state_dict: state_dict_ = state_dict[self._top_query_embeddings_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if 'top_query_embeddings' in key: state_dict_[key.split('top_query_embeddings.')[1]] \ = state_dict[key] self.top_query_embeddings.load_state_dict(state_dict_, strict=strict) class TransformerLanguageModel(torch.nn.Module): """Transformer language model. Arguments: transformer_hparams: transformer hyperparameters attention_mask_func: a function that takes `unmaksed-attention-scores` with size [b, np, s, s] and an `attention-mask` and will apply the masking. The function should return a masked score of the same size [b, np, s, s]. masked-attention-scores = attention_mask_func( unmaksed-attention-scores, attention-mask) vocab_size: vocabulary size max_sequence_length: maximum size of sequence. This is used for positional embedding """ def __init__( self, hidden_size, num_layers, num_attention_heads, padded_vocab_size, max_position_embeddings, ): super(TransformerLanguageModel, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.num_attention_heads = num_attention_heads self.padded_vocab_size = padded_vocab_size self.max_position_embeddings = max_position_embeddings # Embeddings self.embedding = Embedding(self.hidden_size, self.padded_vocab_size, self.max_position_embeddings) self._embedding_key = 'embedding' # Query embeddings self.topQueryEmbedding = QueryEmbedding(self.hidden_size, self.padded_vocab_size, self.max_position_embeddings) self._topQueryEmbedding_key = 'topQueryEmbedding' # Transformer self.transformer = Transformer(self.hidden_size, self.num_attention_heads, self.num_layers) self._transformer_key = 'transformer' def forward( self, input_ids, position_ids, attention_mask, layer_past=None, get_key_value=False, prompt_length=None, context_length=None, ): # Embeddings. embedding_output = self.embedding(input_ids, position_ids) query_position_ids = position_ids queryEmbedding_out = self.topQueryEmbedding(query_position_ids) # Transformer. transformer_output = self.transformer(embedding_output, queryEmbedding_out, attention_mask, layer_past=layer_past, get_key_value=get_key_value, prompt_length=prompt_length, context_length=context_length) return transformer_output def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): """For easy load.""" state_dict_ = {} state_dict_[self._embedding_key] \ = self.embedding.state_dict_for_save_checkpoint( destination, prefix, keep_vars) state_dict_[self._topQueryEmbedding_key] \ = self.topQueryEmbedding.state_dict_for_save_checkpoint( destination, prefix, keep_vars) state_dict_[self._transformer_key] \ = self.transformer.state_dict_for_save_checkpoint( destination, prefix, keep_vars) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" # Embedding. if self._embedding_key in state_dict: state_dict_ = state_dict[self._embedding_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if '_embeddings' in key: state_dict_[key] = state_dict[key] self.embedding.load_state_dict(state_dict_, strict=strict) if self._topQueryEmbedding_key in state_dict: state_dict_ = state_dict[self._topQueryEmbedding_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if '_embeddings' in key: state_dict_[key] = state_dict[key] self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict) # Transformer. if self._transformer_key in state_dict: state_dict_ = state_dict[self._transformer_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if 'transformer.' in key: state_dict_[key.split('transformer.')[1]] = state_dict[key] self.transformer.load_state_dict(state_dict_, strict=strict) class CodeGeeXModel(torch.nn.Module): """CodeGeeX: A Multilingual Code Generation Model.""" def __init__( self, hidden_size, num_layers, num_attention_heads, padded_vocab_size, max_position_embeddings, ): super(CodeGeeXModel, self).__init__() self.language_model = TransformerLanguageModel(hidden_size, num_layers, num_attention_heads, padded_vocab_size, max_position_embeddings) self._language_model_key = "language_model" def forward( self, input_ids, position_ids, attention_mask, layer_past=None, get_key_value=False, prompt_length=None, context_length=None, ): # Language model. lm_output = self.language_model(input_ids, position_ids, attention_mask, layer_past=layer_past, get_key_value=get_key_value, prompt_length=prompt_length, context_length=context_length) if get_key_value: lm_output, presents = lm_output output = F.linear(lm_output, self.language_model.embedding.word_embeddings.weight.half()) if get_key_value: output = [output, presents] return output def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): state_dict_ = {} state_dict_[self._language_model_key] \ = self.language_model.state_dict_for_save_checkpoint( destination, prefix, keep_vars) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" if self._language_model_key in state_dict: state_dict = state_dict[self._language_model_key] self.language_model.load_state_dict(state_dict, strict=strict)