diff --git a/codegeex/oneflow/__init__.py b/codegeex/oneflow/__init__.py new file mode 100644 index 0000000..16975c0 --- /dev/null +++ b/codegeex/oneflow/__init__.py @@ -0,0 +1 @@ +from .codegeex_model import CodeGeeXModel \ No newline at end of file diff --git a/codegeex/oneflow/__pycache__/__init__.cpython-38.pyc b/codegeex/oneflow/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..632fead Binary files /dev/null and b/codegeex/oneflow/__pycache__/__init__.cpython-38.pyc differ diff --git a/codegeex/oneflow/__pycache__/codegeex_model.cpython-38.pyc b/codegeex/oneflow/__pycache__/codegeex_model.cpython-38.pyc new file mode 100644 index 0000000..1a4f515 Binary files /dev/null and b/codegeex/oneflow/__pycache__/codegeex_model.cpython-38.pyc differ diff --git a/codegeex/oneflow/__pycache__/inference.cpython-38.pyc b/codegeex/oneflow/__pycache__/inference.cpython-38.pyc new file mode 100644 index 0000000..dc8f57d Binary files /dev/null and b/codegeex/oneflow/__pycache__/inference.cpython-38.pyc differ diff --git a/codegeex/oneflow/codegeex_model.py b/codegeex/oneflow/codegeex_model.py new file mode 100644 index 0000000..9c1cfdb --- /dev/null +++ b/codegeex/oneflow/codegeex_model.py @@ -0,0 +1,1047 @@ +import math +import oneflow as torch +import oneflow.nn.functional as F +from oneflow.nn.parameter import Parameter + + +def fast_gelu(x): + """Mindspore's fast gelu implementation.""" + if hasattr(torch._C, 'quick_gelu'): + return torch._C.quick_gelu(x) + return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(0.851 * (x - torch.abs(x))) + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. + """ + + def __init__( + self, + hidden_size, + ): + super(MLP, self).__init__() + self.hidden_size = hidden_size + # Project to 4h. + self.dense_h_to_4h = torch.nn.Linear( + self.hidden_size, + 4 * self.hidden_size, + ) + + self.activation_func = fast_gelu + + # Project back to h. + self.dense_4h_to_h = torch.nn.Linear( + 4 * self.hidden_size, + self.hidden_size, + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class SelfAttention(torch.nn.Module): + """self-attention layer abstract class. + + Self-attention layer takes input with size [b, s, h] + and returns output of the same size. 
+ """ + + def __init__( + self, + hidden_size, + num_attention_heads, + layer_number, + fp16=True, + attention_softmax_in_fp32=True, + ): + super(SelfAttention, self).__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.fp16 = fp16 + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.layer_number = max(1, layer_number) + + assert self.hidden_size % self.num_attention_heads == 0 + self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) + + self.query = torch.nn.Linear(self.hidden_size, self.hidden_size) + self.key = torch.nn.Linear(self.hidden_size, self.hidden_size) + self.value = torch.nn.Linear(self.hidden_size, self.hidden_size) + + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.softmax = torch.nn.Softmax(dim=-1) + + self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size) + + def forward( + self, + hidden_states, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + layer_id=0, + ): + # hidden_states: [sq, b, h] + + # ===================== + # Query, Key, and Value + # ===================== + + if hasattr(torch._C, 'grouped_matmul_bias'): + query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([hidden_states, hidden_states, hidden_states], + [self.query.weight, self.key.weight, self.value.weight], + [self.query.bias, self.key.bias, self.value.bias]) + else: + query_layer = self.query(hidden_states) + key_layer = self.key(hidden_states) + value_layer = self.value(hidden_states) + + new_query_layer_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_query_layer_shape) + + new_query_layer_shape = key_layer.size()[:-1] + \ + (self.num_attention_heads, + self.hidden_size_per_attention_head) + key_layer = key_layer.view(*new_query_layer_shape) + + new_query_layer_shape = value_layer.size()[:-1] + \ + (self.num_attention_heads, + self.hidden_size_per_attention_head) + value_layer = value_layer.view(*new_query_layer_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + if layer_past is not None: + past_key, past_value = layer_past + key_layer = torch.cat((past_key.type_as(key_layer), + key_layer), dim=0) + value_layer = torch.cat((past_value.type_as(value_layer), + value_layer), dim=0) + if get_key_value: + present = (key_layer, value_layer) + + origin_query_layer = query_layer + origin_key_layer = key_layer + origin_value_layer = value_layer + + if hasattr(torch._C, 'fused_multi_head_attention_inference'): + if layer_past is not None: + context_layer = torch._C.fused_multi_head_attention_inference( + origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=False + ).transpose(0, 1) + else: + context_layer = torch._C.fused_multi_head_attention_inference( + origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=True + ).transpose(0, 1) + else: + # =================================== + # Raw attention scores. 
[b, np, sq, sk] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) + key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.matmul(query_layer.transpose(0, 1), + key_layer.permute(1, 2, 0)) / self.norm_factor + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # ================================================== + # Update attention mask for inference. [b, np, sq, sk] + # ================================================== + + if layer_id == 0: + if get_key_value: + with torch.no_grad(): + if layer_past is not None: + attention_mask = attention_mask[ + ..., + attention_scores.size(3) - 1, + :attention_scores.size(3)].unsqueeze(2) + else: + attention_mask = attention_mask[ + ..., + :attention_scores.size(3), + :attention_scores.size(3)] + + if context_length is not None: + attention_mask = torch.clone(attention_mask) + attention_mask[:, :, context_length:, :] = True + + attention_scores = attention_scores - attention_mask * 10000.0 + if self.attention_softmax_in_fp32: + attention_probs = self.softmax(attention_scores.float()).half() + else: + attention_probs = self.softmax(attention_scores) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sq, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3)) + + # change view [sq, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + if get_key_value: + output = [output, present] + + return output, attention_mask + + +class TopQuerySelfAttention(torch.nn.Module): + """Top query self-attention layer abstract class. + + Self-attention layer takes input with size [b, s, h] + and returns output of the same size. 
+ """ + + def __init__( + self, + hidden_size, + num_attention_heads, + layer_number, + fp16=True, + attention_softmax_in_fp32=True, + ): + super(TopQuerySelfAttention, self).__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.fp16 = fp16 + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.layer_number = max(1, layer_number) + + assert self.hidden_size % self.num_attention_heads == 0 + self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) + + self.query = torch.nn.Linear(self.hidden_size, self.hidden_size) + self.key = torch.nn.Linear(self.hidden_size, self.hidden_size) + self.value = torch.nn.Linear(self.hidden_size, self.hidden_size) + + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.softmax = torch.nn.Softmax(dim=-1) + + self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size) + + def forward( + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + + # hidden_states: [sq, b, h] + if hasattr(torch._C, 'grouped_matmul_bias'): + query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([query_hidden_state, hidden_states, hidden_states], + [self.query.weight, self.key.weight, self.value.weight], + [self.query.bias, self.key.bias, self.value.bias]) + else: + query_layer = self.query(query_hidden_state) + key_layer = self.key(hidden_states) + value_layer = self.value(hidden_states) + + new_query_layer_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_query_layer_shape) + + new_query_layer_shape = key_layer.size()[:-1] + \ + (self.num_attention_heads, + self.hidden_size_per_attention_head) + key_layer = key_layer.view(*new_query_layer_shape) + + new_query_layer_shape = value_layer.size()[:-1] + \ + (self.num_attention_heads, + self.hidden_size_per_attention_head) + value_layer = value_layer.view(*new_query_layer_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + if layer_past is not None: + past_key, past_value = layer_past + key_layer = torch.cat((past_key.type_as(key_layer), + key_layer), dim=0) + value_layer = torch.cat((past_value.type_as(value_layer), + value_layer), dim=0) + if get_key_value: + present = (key_layer, value_layer) + + # =================================== + # Raw attention scores. [b, np, sq, sk] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [s, b, np, hn] -> [s, b * np, hn] + query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) + key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.matmul(query_layer.transpose(0, 1), + key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor + + # change view to [b, np, s, s] + attention_scores = matmul_result.view(*output_size) + + # ================================================== + # Update attention mask for inference. 
[b, np, sq, sk] + # ================================================== + + if get_key_value: + with torch.no_grad(): + if layer_past is not None: + attention_mask = attention_mask[ + ..., + attention_scores.size(3) - 1, + :attention_scores.size(3)].unsqueeze(2) + else: + attention_mask = attention_mask[ + ..., + :attention_scores.size(3), + :attention_scores.size(3)] + + if context_length is not None: + attention_mask = torch.clone(attention_mask) + attention_mask[:, :, context_length:, :] = True + + # attention scores and attention mask [b, np, sq, sk] + # attention_scores = attention_mask_func(attention_scores, attention_mask) + if hasattr(torch._C, 'fused_scale_mask_softmax'): + attention_mask = ~attention_mask + if self.attention_softmax_in_fp32: + attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half() + else: + attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0) + else: + attention_scores = attention_scores - attention_mask * 10000.0 + if self.attention_softmax_in_fp32: + attention_probs = self.softmax(attention_scores.float()).half() + else: + attention_probs = self.softmax(attention_scores) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sq, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3)) + + # change view [sq, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + if get_key_value: + output = [output, present] + + return output + + +class TransformerLayer(torch.nn.Module): + """A single transformer layer. + + Transformore layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + layer_number, + layernorm_epsilon=1e-5, + fp16=True, + attention_softmax_in_fp32=True, + ): + super(TransformerLayer, self).__init__() + self.hidden_size = hidden_size + self.layernorm_epsilon = layernorm_epsilon + self.layer_number = layer_number + + # Layernorm on the input data. + self.input_layernorm = torch.nn.LayerNorm(hidden_size, + eps=self.layernorm_epsilon) + + # Self attention. + self.attention = SelfAttention(hidden_size, + num_attention_heads, + layer_number, + fp16, + attention_softmax_in_fp32) + + # Layernorm on the input data. 
+ self.post_attention_layernorm = torch.nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + self.mlp = MLP(self.hidden_size) + + def forward( + self, + hidden_states, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + layer_id=0, + ): + # hidden_states: [b, s, h] + # Use FP32 for Layernorm + # layernorm_output = self.input_layernorm(hidden_states.float()).half() + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_mask = self.attention(layernorm_output, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + layer_id=layer_id) + + if get_key_value: + attention_output, presents = attention_output + + # Residual connection. + residual = hidden_states + layernorm_input = attention_output + residual + + # Use FP32 for Layernorm + # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() + layernorm_output = self.post_attention_layernorm(layernorm_input) + mlp_output = self.mlp(layernorm_output) + output = mlp_output + layernorm_input + + if get_key_value: + output = [output, presents] + + return output, attention_mask + + +class TopQueryLayer(torch.nn.Module): + """A single top query layer. + + Top query layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + layer_number, + layernorm_epsilon=1e-5, + ): + super(TopQueryLayer, self).__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.layernorm_epsilon = layernorm_epsilon + self.layer_number = layer_number + + # Use FP32 for Layernorm + self.input_layernorm = torch.nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + # Self attention. + self.attention = TopQuerySelfAttention(self.hidden_size, + self.num_attention_heads, + self.layer_number) + # Layernorm on the input data. + self.post_attention_layernorm = torch.nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + # MLP + self.mlp = MLP(self.hidden_size) + + def forward( + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + # hidden_states: [b, s, h] + assert query_hidden_state != None + + # Use FP32 for Layernorm + # layernorm_output = self.input_layernorm(hidden_states.float()).half() + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. + attention_output = self.attention(layernorm_output, + query_hidden_state, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + + if get_key_value: + attention_output, presents = attention_output + + # Residual connection. + residual = hidden_states + layernorm_input = attention_output + residual + + # Use FP32 for Layernorm + # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. 
+ residual = layernorm_input + output = mlp_output + residual + + if get_key_value: + output = [output, presents] + + return output + + +class Transformer(torch.nn.Module): + """Transformer class.""" + + def __init__( + self, + hidden_size, + num_attention_heads, + num_layers, + layernorm_epsilon=1e-5, + ): + super(Transformer, self).__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.layernorm_epsilon = layernorm_epsilon + # Number of layers: + self.num_layers = num_layers + self.num_unique_layers = None + + ################# + assert self.num_unique_layers is None + ################# + + if self.num_unique_layers is None: + self.num_unique_layers = self.num_layers + assert self.num_layers % self.num_unique_layers == 0, \ + 'number of layers should be divisible by number of unique layers' + + # Transformer layers. + def build_layer(layer_number): + return TransformerLayer(self.hidden_size, self.num_attention_heads, layer_number) + + self.layers = torch.nn.ModuleList( + [build_layer(i + 1) for i in range(self.num_unique_layers)]) + + self.topQueryLayer = TopQueryLayer(self.hidden_size, + self.num_attention_heads, + self.num_unique_layers) + + self.final_layernorm = torch.nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + def _get_layer_index(self, layer_number): + return layer_number % self.num_unique_layers + + def _get_layer(self, layer_number): + return self.layers[self._get_layer_index(layer_number)] + + def forward( + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + # data format change to avoid explicit tranposes : [b s h] --> [s b h] + hidden_states = hidden_states.transpose(0, 1).contiguous() + query_hidden_state = query_hidden_state.transpose(0, 1).contiguous() + + origin_attention_mask = attention_mask + if get_key_value: + presents = [] + for index in range(self.num_layers): + layer = self._get_layer(index) + past = None + if layer_past is not None: + past = layer_past[index] + hidden_states, attention_mask = layer(hidden_states, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + layer_id=index) + if get_key_value: + hidden_states, present = hidden_states + presents.append(present) + + # Use FP32 for Layernorm + # hidden_states_ = self.final_layernorm(hidden_states.float()).half() + hidden_states_ = self.final_layernorm(hidden_states) + + ################################# + # top query layer + ################################# + past = None + if layer_past is not None: + past = layer_past[self.num_layers] + hidden_states = self.topQueryLayer(hidden_states_, + query_hidden_state, + origin_attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + + if get_key_value: + hidden_states, present = hidden_states + presents.append(present) + + # reverting data format change [s b h] --> [b s h] + output = hidden_states.transpose(0, 1).contiguous() + + if get_key_value: + output = [output, presents] + + return output + + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): + return self.state_dict(destination, prefix, keep_vars) + + +class Embedding(torch.nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. 
This + is used for positional embedding + """ + + def __init__( + self, + hidden_size, + vocab_size, + max_sequence_length, + ): + super(Embedding, self).__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + + # Word embeddings. + self.word_embeddings = torch.nn.Embedding(self.vocab_size, self.hidden_size) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding. + self.position_embeddings = torch.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.position_embeddings = self.position_embeddings.half() + self._position_embeddings_key = 'position_embeddings' + + def forward(self, input_ids, position_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + + return embeddings + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + state_dict_["weight"] = state_dict_["weight"][:self.vocab_size] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + +class QueryEmbedding(torch.nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + """ + + def __init__( + self, + hidden_size, + vocab_size, + max_sequence_length, + ): + super(QueryEmbedding, self).__init__() + + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + + # Top query position embedding (serial). + self.top_query_embeddings = torch.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.top_query_embeddings = self.top_query_embeddings.half() + self._top_query_embeddings_key = 'top_query_embeddings' + + def forward(self, position_ids): + # Embeddings. + embeddings = self.top_query_embeddings(position_ids) + + return embeddings + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._top_query_embeddings_key] \ + = self.top_query_embeddings.state_dict( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Position embedding. 
+ if self._top_query_embeddings_key in state_dict: + state_dict_ = state_dict[self._top_query_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'top_query_embeddings' in key: + state_dict_[key.split('top_query_embeddings.')[1]] \ + = state_dict[key] + self.top_query_embeddings.load_state_dict(state_dict_, strict=strict) + + +class TransformerLanguageModel(torch.nn.Module): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + attention_mask_func: a function that takes `unmaksed-attention-scores` + with size [b, np, s, s] and an `attention-mask` and will apply + the masking. The function should return a masked score of the + same size [b, np, s, s]. + masked-attention-scores = attention_mask_func( + unmaksed-attention-scores, attention-mask) + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + """ + + def __init__( + self, + hidden_size, + num_layers, + num_attention_heads, + padded_vocab_size, + max_position_embeddings, + ): + super(TransformerLanguageModel, self).__init__() + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.padded_vocab_size = padded_vocab_size + self.max_position_embeddings = max_position_embeddings + + # Embeddings + self.embedding = Embedding(self.hidden_size, + self.padded_vocab_size, + self.max_position_embeddings) + self._embedding_key = 'embedding' + + # Query embeddings + self.topQueryEmbedding = QueryEmbedding(self.hidden_size, + self.padded_vocab_size, + self.max_position_embeddings) + self._topQueryEmbedding_key = 'topQueryEmbedding' + + # Transformer + self.transformer = Transformer(self.hidden_size, + self.num_attention_heads, + self.num_layers) + self._transformer_key = 'transformer' + + def forward( + self, + input_ids, + position_ids, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + + # Embeddings. + embedding_output = self.embedding(input_ids, position_ids) + query_position_ids = position_ids + queryEmbedding_out = self.topQueryEmbedding(query_position_ids) + + # Transformer. + transformer_output = self.transformer(embedding_output, + queryEmbedding_out, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + + return transformer_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._embedding_key] \ + = self.embedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + state_dict_[self._topQueryEmbedding_key] \ + = self.topQueryEmbedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + state_dict_[self._transformer_key] \ + = self.transformer.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Embedding. + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. 
+ state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + if self._topQueryEmbedding_key in state_dict: + state_dict_ = state_dict[self._topQueryEmbedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict) + + # Transformer. + if self._transformer_key in state_dict: + state_dict_ = state_dict[self._transformer_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + self.transformer.load_state_dict(state_dict_, strict=strict) + + +class CodeGeeXModel(torch.nn.Module): + """CodeGeeX: A Multilingual Code Generation Model.""" + + def __init__( + self, + hidden_size, + num_layers, + num_attention_heads, + padded_vocab_size, + max_position_embeddings, + ): + super(CodeGeeXModel, self).__init__() + + self.language_model = TransformerLanguageModel(hidden_size, + num_layers, + num_attention_heads, + padded_vocab_size, + max_position_embeddings) + self._language_model_key = "language_model" + + def forward( + self, + input_ids, + position_ids, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + # Language model. + lm_output = self.language_model(input_ids, + position_ids, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + + if get_key_value: + lm_output, presents = lm_output + + output = F.linear(lm_output, self.language_model.embedding.word_embeddings.weight.half()) + + if get_key_value: + output = [output, presents] + + return output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + if self._language_model_key in state_dict: + state_dict = state_dict[self._language_model_key] + self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/codegeex/oneflow/inference.py b/codegeex/oneflow/inference.py new file mode 100644 index 0000000..23db106 --- /dev/null +++ b/codegeex/oneflow/inference.py @@ -0,0 +1,326 @@ +import copy +import json +import os +import time +from typing import * + +import oneflow as torch +import oneflow.nn.functional as F +from dataclasses import dataclass + + +def get_ltor_masks_and_position_ids( + data, + eod_token, + reset_position_ids, + reset_attention_mask, +): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = micro_batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril( + torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) + ).view(att_mask_batch, 1, seq_length, seq_length) + + # Position ids. 
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data)
+    # We need to clone as the ids will be modified based on batch index.
+    if reset_position_ids:
+        position_ids = position_ids.clone()
+
+    if reset_position_ids or reset_attention_mask:
+        # Loop through the batches:
+        for b in range(micro_batch_size):
+
+            # Find indices where EOD token is.
+            eod_index = position_ids[b, data[b] == eod_token]
+            # Detach indices from positions if going to modify positions.
+            if reset_position_ids:
+                eod_index = eod_index.clone()
+
+            # Loop through EOD indices:
+            prev_index = 0
+            for j in range(eod_index.size()[0]):
+                i = eod_index[j]
+                # Mask attention loss.
+                if reset_attention_mask:
+                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
+                # Reset positions.
+                if reset_position_ids:
+                    position_ids[b, (i + 1) :] -= i + 1 - prev_index
+                    prev_index = i + 1
+
+    # Convert attention mask to binary:
+    attention_mask = attention_mask < 0.5
+
+    return attention_mask, position_ids
+
+
+def get_batch(
+    context_tokens,
+    micro_batch_size,
+    eod_token,
+    reset_position_ids=False,
+    reset_attention_mask=False,
+):
+    """Generate batch from context tokens."""
+    tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
+    # Get the attention mask and position ids.
+    attention_mask, position_ids = get_ltor_masks_and_position_ids(
+        tokens,
+        eod_token,
+        reset_position_ids,
+        reset_attention_mask,
+    )
+
+    return tokens, attention_mask, position_ids
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
+    """This function has been mostly taken from huggingface conversational
+    ai code at
+    https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+    conversational-ai-with-transfer-learning-2d818ac26313"""
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the
+        # last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # Convert to 1D
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token
+        # above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        for i in range(sorted_indices.size(0)):
+            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
+            logits[i][indices_to_remove] = filter_value
+
+    return logits
+
+
+def pad_batch(batch, pad_id, seq_length):
+    context_lengths = []
+    for tokens in batch:
+        context_length = len(tokens)
+        if context_length < seq_length:
+            tokens.extend([pad_id] * (seq_length - context_length))
+        context_lengths.append(context_length)
+    return batch, context_lengths
+
+
+def forward_step(
+    model,
+    tokens,
+    seq_length,
+    position_ids,
+    attention_mask,
+    layer_past=None,
+    get_key_value=None,
+    prompt_length=None,
+    context_length=None,
+):
+    # Forward pass through the model.
+ output_tensor = model( + tokens, + position_ids, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) + + if get_key_value: + output_tensor, layer_past = output_tensor + + if get_key_value: + return output_tensor, layer_past + + return output_tensor + + +def get_token_stream( + model, + tokenizer, + seq_length, + out_seq_length, + context_tokens, + return_scores: bool = False, + prompt_length: int = None, + micro_batch_size: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + greedy: bool = False, + recompute: bool = False, +): + context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length) + + context_tokens_tensor = torch.cuda.LongTensor(context_tokens) + context_length_tensor = torch.cuda.LongTensor(context_lengths) + context_length = context_length_tensor.min().item() + tokens, attention_mask, position_ids = get_batch( + context_tokens_tensor, + micro_batch_size, + tokenizer.eos_token_id, + ) + + batch_token_iterator = sample_sequence_batch( + model, + tokenizer, + context_tokens_tensor, + context_length_tensor, + attention_mask, + position_ids, + seq_length=seq_length, + out_seq_length=out_seq_length, + return_scores=return_scores, + prompt_length=prompt_length, + bad_ids=bad_ids, + temperature=temperature, + topp=topp, + topk=topk, + greedy=greedy, + recompute=recompute, + ) + + for tokens, lengths in batch_token_iterator: + context_length += 1 + if tokens is not None: + yield tokens[:, :context_length], lengths + else: + yield None, None + + +def switch(val1, val2, boolean): + boolean = boolean.type_as(val1) + return (1 - boolean) * val1 + boolean * val2 + + +def sample_sequence_batch( + model, + tokenizer, + context_tokens, + context_lengths, + attention_mask, + position_ids, + seq_length, + out_seq_length, + maxlen=None, + return_scores: bool = False, + prompt_length: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + recompute: bool = False, + greedy: bool = False, +): + model.eval() + with torch.no_grad(): + context_length = context_lengths.min().item() + eos_id = tokenizer.eos_token_id + + counter = 0 + org_context_length = context_length + + layer_past = None + batch_size = context_tokens.size(0) + is_done = torch.zeros([batch_size]).byte().cuda() + tokens = context_tokens + if maxlen is None: + maxlen = seq_length - 1 + if maxlen > (org_context_length + out_seq_length): + maxlen = org_context_length + out_seq_length + + lengths = torch.ones([batch_size]).long().cuda() * maxlen + if return_scores: + scores = torch.zeros([batch_size]).float().cuda() + + while context_length <= (maxlen): + + if recompute: + logits = model(tokens, + position_ids, + attention_mask, + prompt_length=prompt_length, + context_length=context_length, + ) + logits = logits[:, context_length - 1, :] + else: + if counter == 0: + tokens2use = tokens[:, :context_length] + positions2use = position_ids[:, :context_length] + else: + tokens2use = tokens[:, context_length - 1].view( + batch_size, -1) + positions2use = position_ids[:, context_length - 1].view( + batch_size, -1) + logits, layer_past = model(tokens2use, + positions2use, + attention_mask, + layer_past=layer_past, + get_key_value=True, + prompt_length=prompt_length, + context_length=context_length, + ) + logits = logits[:, -1].view(batch_size, -1).contiguous() + + if bad_ids is not None: + for bad_id in bad_ids: + 
logits[:, bad_id] = -10000 + if greedy: + prev = torch.argmax(logits, dim=-1).view(-1) + else: + logits = logits.float() + if return_scores: + orig_log_probs = torch.log_softmax(logits, dim=-1) + logits /= temperature + logits = top_k_logits(logits, top_k=topk, top_p=topp) + log_probs = F.softmax(logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1).view(-1) + + started = context_lengths <= context_length + + new_tokens = switch(tokens[:, context_length].view(-1), prev, started) + + if not greedy and return_scores: + indices = prev.view(-1, 1) + new_scores = orig_log_probs.gather(1, indices).view(-1) + new_scores = new_scores * started + new_scores = new_scores * is_done.bool().logical_not() + scores += new_scores + + tokens[:, context_length] = new_tokens + done_token = (prev == eos_id).byte() & started.byte() + just_finished = (done_token & ~is_done).bool() + lengths[just_finished.view(-1)] = context_length + is_done = is_done | done_token + done = torch.all(is_done) + + if return_scores: + yield tokens, (lengths, scores) + else: + yield tokens, lengths + + context_length += 1 + counter += 1 + if done: + break diff --git a/scripts/test_inference_oneflow.sh b/scripts/test_inference_oneflow.sh new file mode 100644 index 0000000..939fb67 --- /dev/null +++ b/scripts/test_inference_oneflow.sh @@ -0,0 +1,39 @@ +# This script is used to test the inference of CodeGeeX. + +GPU=$1 +PROMPT_FILE=$2 + +SCRIPT_PATH=$(realpath "$0") +SCRIPT_DIR=$(dirname "$SCRIPT_PATH") +MAIN_DIR=$(dirname "$SCRIPT_DIR") +TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/" + +# import model configuration +source "$MAIN_DIR/configs/codegeex_13b.sh" + +# export CUDA settings +if [ -z "$GPU" ]; then + GPU=0 +fi + +export CUDA_HOME=/usr/local/cuda-11.1/ +export CUDA_VISIBLE_DEVICES=$GPU + +if [ -z "$PROMPT_FILE" ]; then + PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt +fi + +# remove --greedy if using sampling +CMD="python $MAIN_DIR/tests/test_inference_oneflow.py \ + --prompt-file $PROMPT_FILE \ + --tokenizer-path $TOKENIZER_PATH \ + --micro-batch-size 1 \ + --out-seq-length 1024 \ + --temperature 0.8 \ + --top-p 0.95 \ + --top-k 0 \ + --greedy \ + $MODEL_ARGS" + +echo "$CMD" +eval "$CMD"
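For orientation: the driver invoked above, tests/test_inference_oneflow.py, is not part of this diff. The sketch below is only an illustration of how the new codegeex.oneflow module could be driven end to end; the checkpoint path, the 13B hyperparameter values, the tokenizer stub, and the placeholder prompt ids are assumptions and are not taken from this patch.

import oneflow as flow

from codegeex.oneflow import CodeGeeXModel
from codegeex.oneflow.inference import get_token_stream

# Assumed CodeGeeX-13B shapes; the real values come from configs/codegeex_13b.sh.
model = CodeGeeXModel(
    hidden_size=5120,
    num_layers=39,
    num_attention_heads=40,
    padded_vocab_size=52224,
    max_position_embeddings=2048,
)

# Hypothetical path to a checkpoint already converted to OneFlow's format;
# converted Megatron-style checkpoints typically nest the weights under "module".
state_dict = flow.load("/path/to/codegeex_13b_oneflow")
model.load_state_dict(state_dict.get("module", state_dict))
model.half().to("cuda")
model.eval()


class TokenizerStub:
    # Stand-in for the tokenizer shipped under codegeex/tokenizer/;
    # the sampling code above only reads eos_token_id from it.
    eos_token_id = 50256  # hypothetical id, not taken from this patch


tokenizer = TokenizerStub()
prompt_ids = [1, 2, 3]  # placeholder: token ids produced by the real tokenizer

stream = get_token_stream(
    model,
    tokenizer,
    seq_length=2048,
    out_seq_length=256,
    context_tokens=[prompt_ids],
    micro_batch_size=1,
    temperature=0.8,
    topp=0.95,
    topk=0,
    greedy=False,
)

# The generator yields a growing [micro_batch_size, current_length] token tensor;
# keep the last non-empty yield and trim it to the reported length.
output = None
for tokens, lengths in stream:
    if tokens is not None:
        output = (tokens, lengths)

if output is not None:
    tokens, lengths = output
    generated_ids = tokens[0, : lengths[0].item()].tolist()
    # generated_ids holds prompt plus completion token ids for the real tokenizer to decode.

Note that the fused OneFlow kernels (torch._C.grouped_matmul_bias, fused_multi_head_attention_inference, fused_scale_mask_softmax) are picked up automatically through the hasattr checks in codegeex_model.py, so the same driver code works whether or not those ops are available.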