diff --git a/codegeex/paddle/__init__.py b/codegeex/paddle/__init__.py new file mode 100644 index 0000000..16975c0 --- /dev/null +++ b/codegeex/paddle/__init__.py @@ -0,0 +1 @@ +from .codegeex_model import CodeGeeXModel \ No newline at end of file diff --git a/codegeex/paddle/codegeex_model.py b/codegeex/paddle/codegeex_model.py new file mode 100644 index 0000000..4d946b0 --- /dev/null +++ b/codegeex/paddle/codegeex_model.py @@ -0,0 +1,1010 @@ +import math +import paddle +import paddle.nn.functional as F + + +def fast_gelu(x): + """Mindspore's fast gelu implementation.""" + return x / (1 + paddle.exp(-1.702 * paddle.abs(x))) * paddle.exp(0.851 * (x - paddle.abs(x))) + + +class MLP(paddle.nn.Layer): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. + """ + + def __init__( + self, + hidden_size, + ): + super(MLP, self).__init__() + self.hidden_size = hidden_size + # Project to 4h. + self.dense_h_to_4h = paddle.nn.Linear( + self.hidden_size, + 4 * self.hidden_size, + ) + + self.activation_func = fast_gelu + + # Project back to h. + self.dense_4h_to_h = paddle.nn.Linear( + 4 * self.hidden_size, + self.hidden_size, + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class SelfAttention(paddle.nn.Layer): + """self-attention layer abstract class. + + Self-attention layer takes input with size [b, s, h] + and returns output of the same size. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + layer_number, + fp16=True, + attention_softmax_in_fp32=True, + ): + super(SelfAttention, self).__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.fp16 = fp16 + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.layer_number = max(1, layer_number) + + assert self.hidden_size % self.num_attention_heads == 0 + self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) + + self.query = paddle.nn.Linear(self.hidden_size, self.hidden_size) + self.key = paddle.nn.Linear(self.hidden_size, self.hidden_size) + self.value = paddle.nn.Linear(self.hidden_size, self.hidden_size) + + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.softmax = paddle.nn.Softmax(axis=-1) + + self.dense = paddle.nn.Linear(self.hidden_size, self.hidden_size) + + def forward( + self, + hidden_states, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + # hidden_states: [sq, b, h] + + # ===================== + # Query, Key, and Value + # ===================== + + query_layer = self.query(hidden_states) + key_layer = self.key(hidden_states) + value_layer = self.value(hidden_states) + + new_query_layer_shape = query_layer.shape[:-1] + \ + [self.num_attention_heads, + self.hidden_size_per_attention_head] + query_layer = query_layer.reshape(new_query_layer_shape) + + new_query_layer_shape = key_layer.shape[:-1] + \ + [self.num_attention_heads, + self.hidden_size_per_attention_head] + key_layer = key_layer.reshape(new_query_layer_shape) + + new_query_layer_shape = value_layer.shape[:-1] + \ + [self.num_attention_heads, + self.hidden_size_per_attention_head] + 
value_layer = value_layer.reshape(new_query_layer_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + if layer_past is not None: + past_key, past_value = layer_past + key_layer = paddle.concat((past_key.cast(key_layer.dtype), + key_layer), axis=0) + value_layer = paddle.concat((past_value.cast(value_layer.dtype), + value_layer), axis=0) + if get_key_value: + present = (key_layer, value_layer) + + # =================================== + # Raw attention scores. [b, np, sq, sk] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.shape[1], + query_layer.shape[2], + query_layer.shape[0], + key_layer.shape[0]) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.reshape([output_size[2], output_size[0] * output_size[1], -1]) + key_layer = key_layer.reshape([output_size[3], output_size[0] * output_size[1], -1]) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = paddle.matmul(query_layer.transpose([1, 0, 2]), + key_layer.transpose([1, 0, 2]).transpose([0, 2, 1])) / self.norm_factor + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.reshape(output_size) + + # ================================================== + # Update attention mask for inference. [b, np, sq, sk] + # ================================================== + + if get_key_value: + with paddle.no_grad(): + if layer_past is not None: + attention_mask = attention_mask[ + ..., + attention_scores.shape[3] - 1, + :attention_scores.shape[3]].unsqueeze(2) + else: + attention_mask = attention_mask[ + ..., + :attention_scores.shape[3], + :attention_scores.shape[3]] + + if context_length is not None: + attention_mask = paddle.clone(attention_mask) + attention_mask[:, :, context_length:, :] = True + + # attention scores and attention mask [b, np, sq, sk] + # attention_scores = attention_mask_func(attention_scores, attention_mask) + attention_scores = attention_scores - attention_mask * 10000.0 + if self.attention_softmax_in_fp32: + attention_probs = self.softmax(attention_scores.cast("float32")).cast("float16") + else: + attention_probs = self.softmax(attention_scores) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sq, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.shape[1], + value_layer.shape[2], + query_layer.shape[0], + value_layer.shape[3]) + + # change view [sq, b * np, hn] + value_layer = value_layer.reshape([value_layer.shape[0], output_size[0] * output_size[1], -1]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.reshape([output_size[0] * output_size[1], + output_size[2], -1]) + + context_layer = paddle.bmm(attention_probs, value_layer.unsqueeze(0).transpose([0, 2, 1, 3]).squeeze(0)) + + # change view [b, np, sq, hn] + context_layer = context_layer.reshape(output_size) + + # # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.transpose([2, 0, 1, 3]) + + # # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.shape[:-2] + \ + [self.hidden_size,] + context_layer = context_layer.reshape(new_context_layer_shape) + + # ================= + # Output. 
[sq, b, h] + # ================= + + output = self.dense(context_layer) + + if get_key_value: + output = [output, present] + + return output + + +class TopQuerySelfAttention(paddle.nn.Layer): + """Top query self-attention layer abstract class. + + Self-attention layer takes input with size [b, s, h] + and returns output of the same size. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + layer_number, + fp16=True, + attention_softmax_in_fp32=True, + ): + super(TopQuerySelfAttention, self).__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.fp16 = fp16 + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.layer_number = max(1, layer_number) + + assert self.hidden_size % self.num_attention_heads == 0 + self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) + + self.query = paddle.nn.Linear(self.hidden_size, self.hidden_size) + self.key = paddle.nn.Linear(self.hidden_size, self.hidden_size) + self.value = paddle.nn.Linear(self.hidden_size, self.hidden_size) + + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.softmax = paddle.nn.Softmax(axis=-1) + + self.dense = paddle.nn.Linear(self.hidden_size, self.hidden_size) + + def forward( + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + + # hidden_states: [sq, b, h] + query_layer = self.query(query_hidden_state) + key_layer = self.key(hidden_states) + value_layer = self.value(hidden_states) + + new_query_layer_shape = query_layer.shape[:-1] + \ + [self.num_attention_heads, + self.hidden_size_per_attention_head] + query_layer = query_layer.reshape(new_query_layer_shape) + + new_query_layer_shape = key_layer.shape[:-1] + \ + [self.num_attention_heads, + self.hidden_size_per_attention_head] + key_layer = key_layer.reshape(new_query_layer_shape) + + new_query_layer_shape = value_layer.shape[:-1] + \ + [self.num_attention_heads, + self.hidden_size_per_attention_head] + value_layer = value_layer.reshape(new_query_layer_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + if layer_past is not None: + past_key, past_value = layer_past + key_layer = paddle.concat((past_key.cast(key_layer.dtype), + key_layer), axis=0) + value_layer = paddle.concat((past_value.cast(value_layer.dtype), + value_layer), axis=0) + if get_key_value: + present = (key_layer, value_layer) + + # =================================== + # Raw attention scores. [b, np, sq, sk] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.shape[1], + query_layer.shape[2], + query_layer.shape[0], + key_layer.shape[0]) + + # [s, b, np, hn] -> [s, b * np, hn] + query_layer = query_layer.reshape([output_size[2], output_size[0] * output_size[1], -1]) + key_layer = key_layer.reshape([output_size[3], output_size[0] * output_size[1], -1]) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = paddle.matmul(query_layer.transpose([1, 0, 2]), + key_layer.transpose([1, 0, 2]).transpose([0, 2, 1])) / self.norm_factor + + # change view to [b, np, s, s] + attention_scores = matmul_result.reshape(output_size) + + # ================================================== + # Update attention mask for inference. 
[b, np, sq, sk] + # ================================================== + + if get_key_value: + with paddle.no_grad(): + if layer_past is not None: + attention_mask = attention_mask[ + ..., + attention_scores.shape[3] - 1, + :attention_scores.shape[3]].unsqueeze(2) + else: + attention_mask = attention_mask[ + ..., + :attention_scores.shape[3], + :attention_scores.shape[3]] + + if context_length is not None: + attention_mask = paddle.clone(attention_mask) + attention_mask[:, :, context_length:, :] = True + + # attention scores and attention mask [b, np, sq, sk] + # attention_scores = attention_mask_func(attention_scores, attention_mask) + attention_scores = attention_scores - attention_mask * 10000.0 + if self.attention_softmax_in_fp32: + attention_probs = self.softmax(attention_scores.cast("float32")).cast("float16") + else: + attention_probs = self.softmax(attention_scores) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sq, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.shape[1], + value_layer.shape[2], + query_layer.shape[0], + value_layer.shape[3]) + + # change view [sq, b * np, hn] + value_layer = value_layer.reshape([value_layer.shape[0], output_size[0] * output_size[1], -1]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.reshape([output_size[0] * output_size[1], + output_size[2], -1]) + + # matmul: [b * np, sq, hn] + context_layer = paddle.bmm(attention_probs, value_layer.unsqueeze(0).transpose([0, 2, 1, 3]).squeeze(0)) + + # change view [b, np, sq, hn] + context_layer = context_layer.reshape(output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.transpose([2, 0, 1, 3]) + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.shape[:-2] + \ + [self.hidden_size,] + context_layer = context_layer.reshape(new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + if get_key_value: + output = [output, present] + + return output + + +class TransformerLayer(paddle.nn.Layer): + """A single transformer layer. + + Transformore layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + layer_number, + layernorm_epsilon=1e-5, + fp16=True, + attention_softmax_in_fp32=True, + ): + super(TransformerLayer, self).__init__() + self.hidden_size = hidden_size + self.layernorm_epsilon = layernorm_epsilon + self.layer_number = layer_number + + # Layernorm on the input data. + self.input_layernorm = paddle.nn.LayerNorm(hidden_size, + epsilon=self.layernorm_epsilon) + + # Self attention. + self.attention = SelfAttention(hidden_size, + num_attention_heads, + layer_number, + fp16, + attention_softmax_in_fp32) + + # Layernorm on the input data. + self.post_attention_layernorm = paddle.nn.LayerNorm(self.hidden_size, + epsilon=self.layernorm_epsilon) + self.mlp = MLP(self.hidden_size) + + def forward( + self, + hidden_states, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + # hidden_states: [b, s, h] + # Use FP32 for Layernorm + # layernorm_output = self.input_layernorm(hidden_states.cast("float32")).cast("float16") + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. 
+ attention_output = self.attention(layernorm_output, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + + if get_key_value: + attention_output, presents = attention_output + + # Residual connection. + residual = hidden_states + layernorm_input = attention_output + residual + + # Use FP32 for Layernorm + # layernorm_output = self.post_attention_layernorm(layernorm_input.cast("float32")).cast("float16") + layernorm_output = self.post_attention_layernorm(layernorm_input) + mlp_output = self.mlp(layernorm_output) + output = mlp_output + layernorm_input + + if get_key_value: + output = [output, presents] + + return output + + +class TopQueryLayer(paddle.nn.Layer): + """A single top query layer. + + Top query layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + layer_number, + layernorm_epsilon=1e-5, + ): + super(TopQueryLayer, self).__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.layernorm_epsilon = layernorm_epsilon + self.layer_number = layer_number + + # Use FP32 for Layernorm + self.input_layernorm = paddle.nn.LayerNorm(self.hidden_size, + epsilon=self.layernorm_epsilon) + + # Self attention. + self.attention = TopQuerySelfAttention(self.hidden_size, + self.num_attention_heads, + self.layer_number) + # Layernorm on the input data. + self.post_attention_layernorm = paddle.nn.LayerNorm(self.hidden_size, + epsilon=self.layernorm_epsilon) + + # MLP + self.mlp = MLP(self.hidden_size) + + def forward( + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + # hidden_states: [b, s, h] + # assert query_hidden_state != None + + # Use FP32 for Layernorm + # layernorm_output = self.input_layernorm(hidden_states.cast("float32")).cast("float16") + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. + attention_output = self.attention(layernorm_output, + query_hidden_state, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + + if get_key_value: + attention_output, presents = attention_output + + # Residual connection. + residual = hidden_states + layernorm_input = attention_output + residual + + # Use FP32 for Layernorm + # layernorm_output = self.post_attention_layernorm(layernorm_input.cast("float32")).cast("float16") + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. 
+ residual = layernorm_input + output = mlp_output + residual + + if get_key_value: + output = [output, presents] + + return output + + +class Transformer(paddle.nn.Layer): + """Transformer class.""" + + def __init__( + self, + hidden_size, + num_attention_heads, + num_layers, + layernorm_epsilon=1e-5, + ): + super(Transformer, self).__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.layernorm_epsilon = layernorm_epsilon + # Number of layers: + self.num_layers = num_layers + self.num_unique_layers = None + + ################# + assert self.num_unique_layers is None + ################# + + if self.num_unique_layers is None: + self.num_unique_layers = self.num_layers + assert self.num_layers % self.num_unique_layers == 0, \ + 'number of layers should be divisible by number of unique layers' + + # Transformer layers. + def build_layer(layer_number): + return TransformerLayer(self.hidden_size, self.num_attention_heads, layer_number) + + self.layers = paddle.nn.LayerList( + [build_layer(i + 1) for i in range(self.num_unique_layers)]) + + self.topQueryLayer = TopQueryLayer(self.hidden_size, + self.num_attention_heads, + self.num_unique_layers) + + self.final_layernorm = paddle.nn.LayerNorm(self.hidden_size, + epsilon=self.layernorm_epsilon) + + def _get_layer_index(self, layer_number): + return layer_number % self.num_unique_layers + + def _get_layer(self, layer_number): + return self.layers[self._get_layer_index(layer_number)] + + def forward( + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + # data format change to avoid explicit tranposes : [b s h] --> [s b h] + hidden_states = hidden_states.transpose([1, 0, 2]) + query_hidden_state = query_hidden_state.transpose([1, 0, 2]) + + + if get_key_value: + presents = [] + for index in range(self.num_layers): + layer = self._get_layer(index) + past = None + if layer_past is not None: + past = layer_past[index] + hidden_states = layer(hidden_states, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + if get_key_value: + hidden_states, present = hidden_states + presents.append(present) + + # Use FP32 for Layernorm + # hidden_states_ = self.final_layernorm(hidden_states.cast("float32")).cast("float16") + hidden_states_ = self.final_layernorm(hidden_states) + + ################################# + # top query layer + ################################# + past = None + if layer_past is not None: + past = layer_past[self.num_layers] + hidden_states = self.topQueryLayer(hidden_states_, + query_hidden_state, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + + if get_key_value: + hidden_states, present = hidden_states + presents.append(present) + + # reverting data format change [s b h] --> [b s h] + output = hidden_states.transpose([1, 0, 2]) + + if get_key_value: + output = [output, presents] + + return output + + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): + return self.state_dict(destination, prefix, keep_vars) + + +class Embedding(paddle.nn.Layer): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. 
This + is used for positional embedding + """ + + def __init__( + self, + hidden_size, + vocab_size, + max_sequence_length, + ): + super(Embedding, self).__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + + # Word embeddings. + self.word_embeddings = paddle.nn.Embedding(self.vocab_size, self.hidden_size) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding. + self.position_embeddings = paddle.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.position_embeddings = self.position_embeddings.to(dtype="float16") + self._position_embeddings_key = 'position_embeddings' + + def forward(self, input_ids, position_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + + return embeddings + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict( + destination, prefix, keep_vars) + + return state_dict_ + + def set_state_dict(self, state_dict, use_structured_name=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + state_dict_["weight"] = state_dict_["weight"][:self.vocab_size] + self.word_embeddings.set_state_dict(state_dict_, use_structured_name=use_structured_name) + + # Position embedding. + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.set_state_dict(state_dict_, use_structured_name=use_structured_name) + + +class QueryEmbedding(paddle.nn.Layer): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + """ + + def __init__( + self, + hidden_size, + vocab_size, + max_sequence_length, + ): + super(QueryEmbedding, self).__init__() + + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + + # Top query position embedding (serial). + self.top_query_embeddings = paddle.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.top_query_embeddings = self.top_query_embeddings.to(dtype="float16") + self._top_query_embeddings_key = 'top_query_embeddings' + + def forward(self, position_ids): + # Embeddings. 
+        embeddings = self.top_query_embeddings(position_ids)
+
+        return embeddings
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        """For easy load."""
+
+        state_dict_ = {}
+        state_dict_[self._top_query_embeddings_key] \
+            = self.top_query_embeddings.state_dict(
+            destination, prefix, keep_vars)
+
+        return state_dict_
+
+    def set_state_dict(self, state_dict, use_structured_name=True):
+        """Customized load."""
+
+        # Position embedding.
+        if self._top_query_embeddings_key in state_dict:
+            state_dict_ = state_dict[self._top_query_embeddings_key]
+        else:
+            # for backward compatibility.
+            state_dict_ = {}
+            for key in state_dict.keys():
+                if 'top_query_embeddings' in key:
+                    state_dict_[key.split('top_query_embeddings.')[1]] \
+                        = state_dict[key]
+        self.top_query_embeddings.set_state_dict(state_dict_, use_structured_name=use_structured_name)
+
+
+class TransformerLanguageModel(paddle.nn.Layer):
+    """Transformer language model.
+
+    Arguments:
+        hidden_size: hidden size
+        num_layers: number of transformer layers
+        num_attention_heads: number of attention heads
+        padded_vocab_size: padded vocabulary size
+        max_position_embeddings: maximum sequence length. This
+            is used for positional embedding
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_layers,
+        num_attention_heads,
+        padded_vocab_size,
+        max_position_embeddings,
+    ):
+        super(TransformerLanguageModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.num_attention_heads = num_attention_heads
+        self.padded_vocab_size = padded_vocab_size
+        self.max_position_embeddings = max_position_embeddings
+
+        # Embeddings
+        self.embedding = Embedding(self.hidden_size,
+                                   self.padded_vocab_size,
+                                   self.max_position_embeddings)
+        self._embedding_key = 'embedding'
+
+        # Query embeddings
+        self.topQueryEmbedding = QueryEmbedding(self.hidden_size,
+                                                self.padded_vocab_size,
+                                                self.max_position_embeddings)
+        self._topQueryEmbedding_key = 'topQueryEmbedding'
+
+        # Transformer
+        self.transformer = Transformer(self.hidden_size,
+                                       self.num_attention_heads,
+                                       self.num_layers)
+        self._transformer_key = 'transformer'
+
+    def forward(
+        self,
+        input_ids,
+        position_ids,
+        attention_mask,
+        layer_past=None,
+        get_key_value=False,
+        prompt_length=None,
+        context_length=None,
+    ):
+
+        # Embeddings.
+        embedding_output = self.embedding(input_ids, position_ids)
+        query_position_ids = position_ids
+        queryEmbedding_out = self.topQueryEmbedding(query_position_ids)
+
+        # Transformer.
+ transformer_output = self.transformer(embedding_output, + queryEmbedding_out, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length) + + return transformer_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._embedding_key] \ + = self.embedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + state_dict_[self._topQueryEmbedding_key] \ + = self.topQueryEmbedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + state_dict_[self._transformer_key] \ + = self.transformer.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + return state_dict_ + + def set_state_dict(self, state_dict, use_structured_name=True): + """Customized load.""" + + # Embedding. + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.set_state_dict(state_dict_, use_structured_name=use_structured_name) + + if self._topQueryEmbedding_key in state_dict: + state_dict_ = state_dict[self._topQueryEmbedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.topQueryEmbedding.set_state_dict(state_dict_, use_structured_name=use_structured_name) + + # Transformer. + if self._transformer_key in state_dict: + state_dict_ = state_dict[self._transformer_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + self.transformer.set_state_dict(state_dict_, use_structured_name=use_structured_name) + + +class CodeGeeXModel(paddle.nn.Layer): + """CodeGeeX: A Multilingual Code Generation Model.""" + + def __init__( + self, + hidden_size, + num_layers, + num_attention_heads, + padded_vocab_size, + max_position_embeddings, + ): + super(CodeGeeXModel, self).__init__() + + self.language_model = TransformerLanguageModel(hidden_size, + num_layers, + num_attention_heads, + padded_vocab_size, + max_position_embeddings) + self._language_model_key = "language_model" + + def forward( + self, + input_ids, + position_ids, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, + ): + # Language model. 
+        lm_output = self.language_model(input_ids,
+                                        position_ids,
+                                        attention_mask,
+                                        layer_past=layer_past,
+                                        get_key_value=get_key_value,
+                                        prompt_length=prompt_length,
+                                        context_length=context_length)
+
+        if get_key_value:
+            lm_output, presents = lm_output
+
+        output = F.linear(lm_output, self.language_model.embedding.word_embeddings.weight.cast("float16").transpose([1, 0]))
+
+        if get_key_value:
+            output = [output, presents]
+
+        return output
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+
+        state_dict_ = {}
+        state_dict_[self._language_model_key] \
+            = self.language_model.state_dict_for_save_checkpoint(
+            destination, prefix, keep_vars)
+        return state_dict_
+
+    def set_state_dict(self, state_dict, use_structured_name=True):
+        """Customized load."""
+
+        if self._language_model_key in state_dict:
+            state_dict = state_dict[self._language_model_key]
+        self.language_model.set_state_dict(state_dict, use_structured_name=use_structured_name)
diff --git a/codegeex/paddle/inference.py b/codegeex/paddle/inference.py
new file mode 100644
index 0000000..cbc8158
--- /dev/null
+++ b/codegeex/paddle/inference.py
@@ -0,0 +1,326 @@
+import copy
+import json
+import os
+import time
+from typing import *
+
+import paddle
+import paddle.nn.functional as F
+from dataclasses import dataclass
+
+
+def get_ltor_masks_and_position_ids(
+    data,
+    eod_token,
+    reset_position_ids,
+    reset_attention_mask,
+):
+    """Build masks and position ids for a left-to-right model."""
+
+    # Extract batch size and sequence length.
+    micro_batch_size, seq_length = data.shape
+
+    # Attention mask (lower triangular).
+    if reset_attention_mask:
+        att_mask_batch = micro_batch_size
+    else:
+        att_mask_batch = 1
+    attention_mask = paddle.tril(
+        paddle.ones((att_mask_batch, seq_length, seq_length))
+    ).reshape([att_mask_batch, 1, seq_length, seq_length])
+
+    # Position ids.
+    position_ids = paddle.arange(seq_length, dtype="int64")
+    position_ids = position_ids.unsqueeze(0).expand_as(data)
+    # We need to clone as the ids will be modified based on batch index.
+    if reset_position_ids:
+        position_ids = position_ids.clone()
+
+    if reset_position_ids or reset_attention_mask:
+        # Loop through the batches:
+        for b in range(micro_batch_size):
+
+            # Find indices where the EOD token is.
+            eod_index = position_ids[b, data[b] == eod_token]
+            # Detach indices from positions if going to modify positions.
+            if reset_position_ids:
+                eod_index = eod_index.clone()
+
+            # Loop through EOD indices:
+            prev_index = 0
+            for j in range(eod_index.shape[0]):
+                i = eod_index[j]
+                # Mask attention loss.
+                if reset_attention_mask:
+                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
+                # Reset positions.
+                if reset_position_ids:
+                    position_ids[b, (i + 1) :] -= i + 1 - prev_index
+                    prev_index = i + 1
+
+    # Convert attention mask to binary:
+    attention_mask = attention_mask < 0.5
+
+    return attention_mask, position_ids
+
+
+def get_batch(
+    context_tokens,
+    micro_batch_size,
+    eod_token,
+    reset_position_ids=False,
+    reset_attention_mask=False,
+):
+    """Generate batch from context tokens."""
+    tokens = context_tokens.reshape([micro_batch_size, -1]).cuda()
+    # Get the attention mask and position ids.
+    attention_mask, position_ids = get_ltor_masks_and_position_ids(
+        tokens,
+        eod_token,
+        reset_position_ids,
+        reset_attention_mask,
+    )
+
+    return tokens, attention_mask, position_ids
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
+    """This function has been mostly taken from huggingface conversational
+    ai code at
+        https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+            conversational-ai-with-transfer-learning-2d818ac26313"""
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the
+        # last token of the top-k
+        indices_to_remove = logits < paddle.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # Sort logits for nucleus (top-p) filtering. paddle.sort returns only the
+        # sorted values; paddle.argsort gives the matching indices.
+        sorted_logits = paddle.sort(logits, descending=True, axis=-1)
+        sorted_indices = paddle.argsort(logits, descending=True, axis=-1)
+        cumulative_probs = paddle.cumsum(F.softmax(sorted_logits, axis=-1), axis=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token
+        # above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        for i in range(sorted_indices.shape[0]):
+            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
+            logits[i][indices_to_remove] = filter_value
+
+    return logits
+
+
+def pad_batch(batch, pad_id, seq_length):
+    context_lengths = []
+    for tokens in batch:
+        context_length = len(tokens)
+        if context_length < seq_length:
+            tokens.extend([pad_id] * (seq_length - context_length))
+        context_lengths.append(context_length)
+    return batch, context_lengths
+
+
+def forward_step(
+    model,
+    tokens,
+    seq_length,
+    position_ids,
+    attention_mask,
+    layer_past=None,
+    get_key_value=None,
+    prompt_length=None,
+    context_length=None,
+):
+    # Forward pass through the model.
+ output_tensor = model( + tokens, + position_ids, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) + + if get_key_value: + output_tensor, layer_past = output_tensor + + if get_key_value: + return output_tensor, layer_past + + return output_tensor + + +def get_token_stream( + model, + tokenizer, + seq_length, + out_seq_length, + context_tokens, + return_scores: bool = False, + prompt_length: int = None, + micro_batch_size: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + greedy: bool = False, + recompute: bool = False, +): + context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length) + + context_tokens_tensor = paddle.to_tensor(context_tokens, dtype="int64") + context_length_tensor = paddle.to_tensor(context_lengths, dtype="int64") + context_length = context_length_tensor.min().item() + tokens, attention_mask, position_ids = get_batch( + context_tokens_tensor, + micro_batch_size, + tokenizer.eos_token_id, + ) + + batch_token_iterator = sample_sequence_batch( + model, + tokenizer, + context_tokens_tensor, + context_length_tensor, + attention_mask, + position_ids, + seq_length=seq_length, + out_seq_length=out_seq_length, + return_scores=return_scores, + prompt_length=prompt_length, + bad_ids=bad_ids, + temperature=temperature, + topp=topp, + topk=topk, + greedy=greedy, + recompute=recompute, + ) + + for tokens, lengths in batch_token_iterator: + context_length += 1 + if tokens is not None: + yield tokens[:, :context_length], lengths + else: + yield None, None + + +def switch(val1, val2, boolean): + boolean = boolean.cast(val1.dtype) + return (1 - boolean) * val1 + boolean * val2 + + +def sample_sequence_batch( + model, + tokenizer, + context_tokens, + context_lengths, + attention_mask, + position_ids, + seq_length, + out_seq_length, + maxlen=None, + return_scores: bool = False, + prompt_length: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + recompute: bool = False, + greedy: bool = False, +): + model.eval() + with paddle.no_grad(): + context_length = context_lengths.min().item() + eos_id = tokenizer.eos_token_id + + counter = 0 + org_context_length = context_length + + layer_past = None + batch_size = context_tokens.shape[0] + is_done = paddle.zeros([batch_size]).cast("uint8").cuda() + tokens = context_tokens + if maxlen is None: + maxlen = seq_length - 1 + if maxlen > (org_context_length + out_seq_length): + maxlen = org_context_length + out_seq_length + + lengths = paddle.ones([batch_size]).cast("int64").cuda() * maxlen + if return_scores: + scores = paddle.zeros([batch_size]).cast("float32").cuda() + + while context_length <= (maxlen): + + if recompute: + logits = model(tokens, + position_ids, + attention_mask, + prompt_length=prompt_length, + context_length=context_length, + ) + logits = logits[:, context_length - 1, :] + else: + if counter == 0: + tokens2use = tokens[:, :context_length] + positions2use = position_ids[:, :context_length] + else: + tokens2use = tokens[:, context_length - 1].reshape([ + batch_size, -1]) + positions2use = position_ids[:, context_length - 1].reshape([ + batch_size, -1]) + logits, layer_past = model(tokens2use, + positions2use, + attention_mask, + layer_past=layer_past, + get_key_value=True, + prompt_length=prompt_length, + context_length=context_length, + ) + logits = logits[:, -1].reshape([batch_size, -1]) + + if 
bad_ids is not None: + for bad_id in bad_ids: + logits[:, bad_id] = -10000 + if greedy: + prev = paddle.argmax(logits, axis=-1).reshape([-1]) + else: + logits = logits.cast("float32") + if return_scores: + orig_log_probs = paddle.log_softmax(logits, axis=-1) + logits /= temperature + logits = top_k_logits(logits, top_k=topk, top_p=topp) + log_probs = F.softmax(logits, axis=-1) + prev = paddle.multinomial(log_probs, num_samples=1).reshape([-1]) + + started = context_lengths <= context_length + + new_tokens = switch(tokens[:, context_length].reshape([-1]), prev, started) + + if not greedy and return_scores: + indices = prev.reshape([-1, 1]) + new_scores = orig_log_probs.gather(1, indices).reshape([-1]) + new_scores = new_scores * started + new_scores = new_scores * is_done.cast("bool").logical_not() + scores += new_scores + + tokens[:, context_length] = new_tokens + done_token = (prev == eos_id).cast("uint8") & started.cast("uint8") + just_finished = (done_token & ~is_done).cast("bool") + lengths[just_finished.reshape([-1])] = context_length + is_done = is_done | done_token + done = paddle.all(is_done.cast("bool")) + + if return_scores: + yield tokens, (lengths, scores) + else: + yield tokens, lengths + + context_length += 1 + counter += 1 + if done: + break diff --git a/configs/codegeex_13b_paddle.sh b/configs/codegeex_13b_paddle.sh new file mode 100644 index 0000000..7c8005d --- /dev/null +++ b/configs/codegeex_13b_paddle.sh @@ -0,0 +1,16 @@ +# CodeGeeX-13B paddle configuration + +CHECKPOINT_PATH="" + +MODEL_ARGS="--num-layers 39 \ + --hidden-size 5120 \ + --num-attention-heads 40 \ + --max-position-embeddings 2048 \ + --attention-softmax-in-fp32 \ + --load "$CHECKPOINT_PATH" \ + --layernorm-epsilon 1e-5 \ + --fp16 \ + --ws-encoding-start-id 10 \ + --ws-encoding-length 10 \ + --make-vocab-size-divisible-by 52224 \ + --seq-length 2048" \ No newline at end of file diff --git a/scripts/test_inference_paddle.sh b/scripts/test_inference_paddle.sh new file mode 100644 index 0000000..ed4c472 --- /dev/null +++ b/scripts/test_inference_paddle.sh @@ -0,0 +1,39 @@ +# This script is used to test the inference of CodeGeeX. 
+ +GPU=$1 +PROMPT_FILE=$2 + +SCRIPT_PATH=$(realpath "$0") +SCRIPT_DIR=$(dirname "$SCRIPT_PATH") +MAIN_DIR=$(dirname "$SCRIPT_DIR") +TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/" + +# import model configuration +source "$MAIN_DIR/configs/codegeex_13b_paddle.sh" + +# export CUDA settings +if [ -z "$GPU" ]; then + GPU=0 +fi + +export CUDA_HOME=/usr/local/cuda-11.1/ +export CUDA_VISIBLE_DEVICES=$GPU + +if [ -z "$PROMPT_FILE" ]; then + PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt +fi + +# remove --greedy if using sampling +CMD="python $MAIN_DIR/tests/test_inference_paddle.py \ + --prompt-file $PROMPT_FILE \ + --tokenizer-path $TOKENIZER_PATH \ + --micro-batch-size 1 \ + --out-seq-length 1024 \ + --temperature 0.8 \ + --top-p 0.95 \ + --top-k 0 \ + --greedy \ + $MODEL_ARGS" + +echo "$CMD" +eval "$CMD" diff --git a/tests/test_inference_paddle.py b/tests/test_inference_paddle.py new file mode 100644 index 0000000..4a6c21a --- /dev/null +++ b/tests/test_inference_paddle.py @@ -0,0 +1,213 @@ + +import os +import copy +import time +import paddle +import random +import argparse +import numpy as np + +from codegeex.paddle.inference import get_token_stream +from codegeex.paddle import CodeGeeXModel +from codegeex.tokenizer import CodeGeeXTokenizer + + +def model_provider(args): + """Build the model.""" + + old_dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float16") + model = CodeGeeXModel( + args.hidden_size, + args.num_layers, + args.num_attention_heads, + args.padded_vocab_size, + args.max_position_embeddings + ) + model.language_model.embedding.word_embeddings.to(dtype="float32") + model.language_model.embedding.position_embeddings.to(dtype="float32") + model.language_model.topQueryEmbedding.top_query_embeddings.to(dtype="float32") + for i in model.language_model.transformer.layers: + i.input_layernorm.to(dtype="float32") + i.post_attention_layernorm.to(dtype="float32") + model.language_model.transformer.topQueryLayer.input_layernorm.to(dtype="float32") + model.language_model.transformer.topQueryLayer.post_attention_layernorm.to(dtype="float32") + model.language_model.transformer.final_layernorm.to(dtype="float32") + paddle.set_default_dtype(old_dtype) + + return model + + +def add_code_generation_args(parser): + group = parser.add_argument_group(title="code generation") + group.add_argument( + "--num-layers", + type=int, + default=39, + ) + group.add_argument( + "--hidden-size", + type=int, + default=5120, + ) + group.add_argument( + "--num-attention-heads", + type=int, + default=40, + ) + group.add_argument( + "--padded-vocab-size", + type=int, + default=52224, + ) + group.add_argument( + "--max-position-embeddings", + type=int, + default=2048, + ) + group.add_argument( + "--temperature", + type=float, + default=1.0, + help="Sampling temperature.", + ) + group.add_argument( + "--greedy", + action="store_true", + default=False, + help="Use greedy sampling.", + ) + group.add_argument( + "--top-p", + type=float, + default=0.0, + help="Top p sampling.", + ) + group.add_argument( + "--top-k", + type=int, + default=0, + help="Top k sampling.", + ) + group.add_argument( + "--out-seq-length", + type=int, + default=2048, + help="Size of the output generated text.", + ) + group.add_argument( + "--prompt-file", + type=str, + default="./test_prompt.txt", + ) + group.add_argument( + "--tokenizer-path", + type=str, + default="./tokenizer", + ) + group.add_argument( + "--load", + type=str, + ) + group.add_argument( + "--state-dict-path", + type=str, + ) + group.add_argument( + 
"--micro-batch-size", + type=int, + default=1, + ) + group.add_argument( + "--quantize", + action="store_true", + ) + + return parser + + +def main(): + parser = argparse.ArgumentParser() + parser = add_code_generation_args(parser) + args, _ = parser.parse_known_args() + + print("Loading tokenizer ...") + tokenizer = CodeGeeXTokenizer( + tokenizer_path=args.tokenizer_path, + mode="codegeex-13b") + + print("Loading state dict ...") + state_dict = paddle.load(args.load) + state_dict = state_dict["module"] + + print("Building CodeGeeX model ...") + model = model_provider(args) + model.set_state_dict(state_dict) + model.eval() + model.to(dtype="float16") + if args.quantize: + raise NotImplementedError("quantize") + + with open(args.prompt_file, "r") as f: + prompt = f.readlines() + prompt = "".join(prompt) + + times = {} + out_seq_lengths = [args.out_seq_length] + micro_batch_size = args.micro_batch_size + seq_length = args.max_position_embeddings + for out_seq_length in out_seq_lengths: + print(f"Generating with out_seq_len {out_seq_length}...") + + times[out_seq_length] = [] + for prompt in [prompt]: + t0 = time.perf_counter() + tokens = tokenizer.encode_code(prompt) + print(tokens) + print("Current prompt:") + print(prompt) + n_token_prompt = len(tokens) + print("N_token_prompt:", n_token_prompt) + token_stream = get_token_stream( + model, + tokenizer, + seq_length, + out_seq_length, + [copy.deepcopy(tokens) for _ in range(micro_batch_size)], + micro_batch_size=micro_batch_size, + topk=args.top_k, + topp=args.top_p, + temperature=args.temperature, + greedy=args.greedy, + ) + is_finished = [False for _ in range(micro_batch_size)] + for i, generated in enumerate(token_stream): + generated_tokens = generated[0] + for j in range(micro_batch_size): + if is_finished[j]: + continue + if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len( + generated_tokens[j]) >= out_seq_length: + is_finished[j] = True + generated_tokens_ = generated_tokens[j].cpu().numpy().tolist() + generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:]) + generated_code = "".join(generated_code) + t1 = time.perf_counter() + print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt) + print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token") + times[out_seq_length].append(t1 - t0) + print("================================= Generated code:") + print(generated_code) + + if all(is_finished): + break + + print(times) + for out_seq_length in times.keys(): + print(out_seq_length, np.mean(times[out_seq_length])) + + print("Generation finished.") + + +if __name__ == "__main__": + main()