# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PanguAlpha model"""
import copy
import os

import mindspore.common.dtype as mstype
import mindspore.nn as nn
import numpy as np
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.nn import Cell
# from mindspore.parallel.nn.layers import _LayerNorm
from mindspore.nn.transformer.layers import _Dropout, _LayerNorm
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.parallel.nn import MoEConfig
from mindspore.parallel.nn.transformer import (
    VocabEmbedding,
    TransformerEncoder,
    TransformerEncoderLayer,
    AttentionMask,
)


class EmbeddingLayer(nn.Cell):
    r"""Embedding layer of the PanGuAlpha Model"""

    def __init__(self, config):
        super(EmbeddingLayer, self).__init__()
        # Only for the pipeline mode, the embedding needs to be row sliced.
        dp = config.parallel_config.embedding_dp_mp_config.data_parallel
        mp = config.parallel_config.embedding_dp_mp_config.model_parallel
        self.word_embedding = VocabEmbedding(
            vocab_size=(config.vocab_size // 1024 + 1) * 1024,
            embedding_size=config.hidden_size,
            param_init=initializer(
                "normal",
                [(config.vocab_size // 1024 + 1) * 1024, config.hidden_size],
                # dtype=config.param_init_type,
                dtype=mstype.float32,
            ),
            parallel_config=config.parallel_config.embedding_dp_mp_config,
        )
        self.word_embedding.gather.shard(((mp, 1), (dp, 1)))
        # self.word_embedding.embedding_table.parallel_optimizer = True
        copied_parallel_config = copy.deepcopy(config.parallel_config)
        copied_parallel_config.vocab_emb_dp = True
        self.position_embedding = VocabEmbedding(
            vocab_size=config.seq_length,
            embedding_size=config.hidden_size,
            param_init=initializer(
                "normal",
                [config.seq_length, config.hidden_size],
                # dtype=config.param_init_type,
                dtype=mstype.float32,
            ),
            parallel_config=copied_parallel_config.embedding_dp_mp_config,
        )
        self.split = P.Split(1, 2).shard(
            ((config.parallel_config.data_parallel, 1, 1),)
        )
        self.print = P.Print()
        self.add = P.Add().shard(
            (
                (config.parallel_config.data_parallel, 1, 1),
                (config.parallel_config.data_parallel, 1, 1),
            )
        )
        # self.dropout = nn.Dropout(1 - config.dropout_rate)
        self.dropout = _Dropout(1 - config.dropout_rate)
        self.dropout.shard(((config.parallel_config.data_parallel, 1, 1),))
        # self.dropout.dropout.shard(
        #     ((config.parallel_config.data_parallel, 1, 1),)
        # )
        self.is_first_iteration = True
        self.use_past = config.use_past
        self.batch_size = config.batch_size

    def construct(
        self, input_ids, input_position, init_reset, batch_valid_length
    ):
        word_embedding, word_table = self.word_embedding(input_ids)
        if self.use_past and not self.is_first_iteration:
            _, seq_length = F.shape(input_ids)
            # self.print("==batch_valid_length is: ", batch_valid_length, ", transform into: ", self.batch_size, "*", seq_length)
            input_position = batch_valid_length.view(self.batch_size, seq_length)
        position_embedding, _ = self.position_embedding(input_position)
        embed = self.add(word_embedding, position_embedding)
        embed = self.dropout(embed)
        return embed, word_table

    def get_word_embedding_weight(self):
        return self.word_embedding.embedding_table


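# Illustrative sketch, not used by the model: EmbeddingLayer above pads the
# vocabulary size up to the next multiple of 1024 via
# (vocab_size // 1024 + 1) * 1024 so the word embedding table shards evenly.
# The default vocab_size below is a made-up value chosen only for demonstration.
def _example_padded_vocab_size(vocab_size=40000):
    """Return the padded vocab size that the word embedding table would use."""
    padded = (vocab_size // 1024 + 1) * 1024
    # e.g. 40000 -> 40960; an exact multiple is still bumped (40960 -> 41984),
    # matching the expression used in EmbeddingLayer.__init__ above.
    return padded

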
class QueryLayer(TransformerEncoderLayer):
    r"""Query Layer at the final layer."""

    def __init__(
        self,
        batch_size,
        hidden_size,
        ffn_hidden_size,
        num_heads,
        seq_length,
        attention_dropout_rate=0.1,
        hidden_dropout_rate=0.1,
        post_layernorm_residual=False,
        param_init_type=mstype.float32,
        hidden_act="fast_gelu",
        use_past=False,
        parallel_config=None,
        softmax_compute_type=mstype.float32,
    ):
        super(QueryLayer, self).__init__(
            batch_size=batch_size,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_heads=num_heads,
            seq_length=seq_length,
            attention_dropout_rate=attention_dropout_rate,
            hidden_dropout_rate=hidden_dropout_rate,
            post_layernorm_residual=post_layernorm_residual,
            param_init_type=param_init_type,
            hidden_act=hidden_act,
            use_past=use_past,
            parallel_config=parallel_config.dp_mp_config,
            softmax_compute_type=softmax_compute_type,
        )

    def construct(
        self,
        x,
        query_vector,
        input_mask,
        init_reset=True,
        batch_valid_length=None,
    ):
        r"""
        The forward process of the block.
        """
        # [bs * seq_length, embedding_size]
        input_x = self.layernorm1(x)
        input_x = F.cast(input_x, self.dtype)

        # indicate whether to reset the saved states
        key_reset = None
        value_reset = None

        if self.use_past:
            # reset states, init_reset True for reuse and False for reset
            key_reset = self.assign(
                self.key_past,
                self.mul(self.key_past, F.cast(init_reset, self.dtype)),
            )
            value_reset = self.assign(
                self.value_past,
                self.mul(self.value_past, F.cast(init_reset, self.dtype)),
            )
            # add dependency for desired execution order
            input_x = F.depend(input_x, key_reset)
            input_x = F.depend(input_x, value_reset)

        attention, layer_present = self.attention(
            query_vector,
            input_x,
            input_x,
            input_mask,
            self.key_past,
            self.value_past,
            batch_valid_length,
        )
        # For post-layernorm the inputs for the residual path are the output of self-attention and the output of layernorm
        if self.post_layernorm_residual:
            x = self.add(input_x, attention)
        # For pre-layernorm the inputs for the residual path are the output of self-attention and the input of this layer
        else:
            x = self.add(x, attention)

        output_x = self.layernorm2(x)
        output_x = F.cast(output_x, self.dtype)
        mlp_logit = self.output(output_x)

        value_update = None
        key_update = None
        if self.use_past:
            # current key and value
            key_present, value_present = layer_present
            # update key and value calculated this step
            key_update = self.assign(self.key_past, key_present)
            value_update = self.assign(self.value_past, value_present)
            # add dependency for desired execution order
            key_update = F.depend(key_update, key_reset)
            value_update = F.depend(value_update, value_reset)

        # add dependency for desired execution order
        mlp_logit = F.depend(mlp_logit, value_update)
        mlp_logit = F.depend(mlp_logit, key_update)

        if self.post_layernorm_residual:
            output = self.add(output_x, mlp_logit)
        else:
            output = self.add(x, mlp_logit)
        return output, layer_present


class PanGuHead(Cell):
    """
    Head to get the logits of each token in the vocab
    Args:
        hidden_size: the hidden size of the backbone states
        compute_type: the compute type of the matmul, float16 by default
        parallel_config: the parallel config of the network
    Inputs:
        state: the output of the backbone
        embedding_table: the embedding table of the vocabulary
    Returns:
        logits: Tensor, the logits of the corresponding inputs
    """

    def __init__(
        self, hidden_size, compute_type=mstype.float16, parallel_config=None
    ):
        super(PanGuHead, self).__init__()
        if parallel_config.vocab_emb_dp:
            self.matmul = P.MatMul(transpose_b=True).shard(
                ((parallel_config.data_parallel, 1), (1, 1))
            )
        else:
            print("====vocab_emb_dp is false", flush=True)
            self.matmul = P.MatMul(transpose_b=True).shard(
                (
                    (parallel_config.data_parallel, 1),
                    (parallel_config.model_parallel, 1),
                )
            )
        self.hidden_size = hidden_size
        self.dtype = compute_type
        self.cast = P.Cast()

    def construct(self, state, embed):
        state = P.Reshape()(state, (-1, self.hidden_size))
        # output logits over vocabulary [bs*seq_length, vocab_size]
        logits = self.matmul(
            self.cast(state, self.dtype), self.cast(embed, self.dtype)
        )
        return logits


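# Illustrative sketch, not used by the network: PanGuHead computes logits as
# state @ embedding_table.T after flattening state to 2-D (weight tying with
# the word embedding). The batch, sequence, hidden and vocab sizes below are
# made-up demo values; only the shape contract mirrors PanGuHead.construct.
def _example_head_matmul(batch_size=2, seq_length=4, hidden_size=8, vocab_size=16):
    """Show the shape contract of PanGuHead with plain numpy."""
    state = np.zeros((batch_size, seq_length, hidden_size), np.float16)
    embedding_table = np.zeros((vocab_size, hidden_size), np.float16)
    state_2d = state.reshape(-1, hidden_size)  # [bs * seq_length, hidden_size]
    logits = state_2d @ embedding_table.T      # [bs * seq_length, vocab_size]
    return logits.shape                        # (8, 16) for the defaults

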
def set_parallel_configure_for_layer(
    network, layer_id, offset, parallel_config, layers
):
    r"""
    Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.

    Args:
        network(Cell) - Represents the transformer block
        layer_id(int) - The layer index of the current module, counting from zero.
        offset(int) - The offset applied to layer_id when there are other modules in the net.
        layers(int) - The total number of layers used for the model.
    """
    # Used for the pipeline's stage setting.
    # Since the final layer is not included here, we manually add 1 to `layers`:
    # originally, with two stages the layers would be split as [15, 16 + 1];
    # with 1 added, the split on the two stages becomes [16, 15 + 1].
    pp_dis = max(int((layers + 1) / parallel_config.pipeline_stage), 1)
    # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
    pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
    network.pipeline_stage = pp_id
    print(f"pipeline stage id is {pp_id}", flush=True)

    # Used for the optimizer's fusion tag
    dis = max(int((layers + 1) / parallel_config.gradient_aggregation_group), 1)
    if parallel_config.pipeline_stage > 1:
        # we give the fusion in pipeline mode a fixed value, otherwise the performance may become worse.
        network.set_comm_fusion(2)
    else:
        network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
    # Used for enabling recomputation of the block
    if parallel_config.recompute:
        network.recompute(recompute_slice_activation=True)


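# Illustrative sketch with made-up numbers: how the helper above maps layer
# indices to pipeline stages. With layers=32 and 2 stages,
# pp_dis = max(int(33 / 2), 1) = 16, so layers 0..15 land on stage 0 and
# layers 16..31 on stage 1 (clamped by the min()).
def _example_stage_assignment(layers=32, pipeline_stage=2, offset=0):
    """Return the pipeline stage id of every layer index, mirroring the math above."""
    pp_dis = max(int((layers + 1) / pipeline_stage), 1)
    return [min((layer_id + offset) // pp_dis, pipeline_stage - 1)
            for layer_id in range(layers)]

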
class PanguAlpha_Model(Cell):
    r"""The base backbone of the PanGuAlpha model"""

    def __init__(self, config):
        super(PanguAlpha_Model, self).__init__()
        self.is_pipeline = config.parallel_config.pipeline_stage > 1
        self.embedding = EmbeddingLayer(config)
        self.config = config
        self.layernorm = _LayerNorm((config.hidden_size,)).to_float(
            mstype.float32
        )
        if config.parallel_config.pipeline_stage > 1:
            self.layernorm.set_comm_fusion(2)
        else:
            self.layernorm.set_comm_fusion(
                config.parallel_config.gradient_aggregation_group
            )
        self.layernorm.shard(((config.parallel_config.data_parallel, 1),))
        self.layernorm.pipeline_stage = (
            config.parallel_config.pipeline_stage - 1
        )
        # Configure the shard settings of the embedding layer
        self.embedding.pipeline_stage = 0
        self.num_layers = config.num_layers
        if config.use_moe:
            moe_config = MoEConfig(
                expert_num=config.parallel_config.data_parallel
                * config.per_dp_dim_expert_num
            )
        else:
            moe_config = MoEConfig(expert_num=1)
        # The shard setting of the Transformer is set within the class StackedTransformer
        self.blocks = TransformerEncoder(num_layers=config.num_layers - 1,
                                         batch_size=config.batch_size,
                                         hidden_size=config.hidden_size,
                                         ffn_hidden_size=config.ffn_hidden_size,
                                         num_heads=config.num_heads,
                                         seq_length=config.seq_length,
                                         attention_dropout_rate=config.dropout_rate,
                                         hidden_dropout_rate=config.dropout_rate,
                                         lambda_func=set_parallel_configure_for_layer,
                                         hidden_act="fast_gelu",
                                         param_init_type=config.param_init_type,
                                         use_past=config.use_past,
                                         parallel_config=config.parallel_config,
                                         moe_config=moe_config,
                                         softmax_compute_type=config.softmax_compute_type).blocks
        for block in self.blocks:
            block.attention.dense1.bias.parallel_optimizer = False
            block.attention.dense2.bias.parallel_optimizer = False
            block.attention.dense3.bias.parallel_optimizer = False
            block.output.mapping.bias.parallel_optimizer = False
        copied_parallel_config = copy.deepcopy(config.parallel_config)
        copied_parallel_config.vocab_emb_dp = True
        self.top_query_embedding = VocabEmbedding(vocab_size=config.seq_length,
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal",
                                                                         [config.seq_length, config.hidden_size],
                                                                         dtype=mstype.float32),
                                                  # dtype=config.param_init_type),
                                                  parallel_config=copied_parallel_config.embedding_dp_mp_config)
        self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
        if config.parallel_config.pipeline_stage > 1:
            self.top_query_embedding.set_comm_fusion(2)
        else:
            self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)

        self.top_query_layer = QueryLayer(batch_size=config.batch_size,
                                          hidden_size=config.hidden_size,
                                          ffn_hidden_size=config.ffn_hidden_size,
                                          num_heads=config.num_heads,
                                          seq_length=config.seq_length,
                                          attention_dropout_rate=config.dropout_rate,
                                          hidden_dropout_rate=config.dropout_rate,
                                          hidden_act=config.hidden_act,
                                          param_init_type=config.param_init_type,
                                          use_past=config.use_past,
                                          parallel_config=config.parallel_config)
        self.top_query_layer.attention.dense1.bias.parallel_optimizer = False
        self.top_query_layer.attention.dense2.bias.parallel_optimizer = False
        self.top_query_layer.attention.dense3.bias.parallel_optimizer = False
        self.top_query_layer.output.mapping.bias.parallel_optimizer = False
        if config.parallel_config.recompute:
            self.top_query_layer.recompute()
        self.top_query_layer.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
        self.top_query_layer.pipeline_stage = config.parallel_config.pipeline_stage - 1

        self.dtype = mstype.float16

        # if config.load_ckpt_path:
        #     self.load_embedding_from_ckpt(config.load_ckpt_path)

    def construct(self, input_ids,
                  input_position,
                  encoder_masks,
                  init_reset=True,
                  batch_valid_length=None):
        r"""forward pass of the model"""
        embed, word_table = self.embedding(input_ids, input_position, init_reset, batch_valid_length)
        hidden_state = P.Cast()(embed, self.dtype)
        if init_reset is False:
            hidden_state = self.reshape_to_2d(hidden_state)
        # encoder_mask = self.create_encoder_mask(encoder_masks)
        if self.blocks is not None:
            for i in range(self.num_layers - 1):
                hidden_state, _ = self.blocks[i](hidden_state, encoder_masks, init_reset, batch_valid_length)
        if self.is_pipeline:
            top_query_hidden_states, _ = self.top_query_embedding(input_position)
            top_query_hidden_states = self.reshape_to_2d(top_query_hidden_states)
            encoder_output, _ = self.top_query_layer(hidden_state, top_query_hidden_states,
                                                     encoder_masks, init_reset, batch_valid_length)
            encoder_output = self.layernorm(encoder_output)
        else:
            hidden_state = self.reshape_to_2d(hidden_state)
            encoder_output = self.layernorm(hidden_state)
            encoder_output = P.Cast()(encoder_output, self.dtype)
            top_query_hidden_states, _ = self.top_query_embedding(input_position)
            top_query_hidden_states = self.reshape_to_2d(top_query_hidden_states)
            encoder_output, _ = self.top_query_layer(encoder_output, top_query_hidden_states,
                                                     encoder_masks, init_reset, batch_valid_length)

        return encoder_output, word_table

    def reshape_to_2d(self, x):
        r"""reshape nd tensor to 2d, if n <= 2, keep the original shape."""
        shape = F.shape(x)
        if len(shape) <= 2:
            return x
        x = F.reshape(x, (-1, shape[-1]))
        return x

    def load_embedding_from_ckpt(self, load_ckpt_path):
        r"""load the weights from the checkpoint"""

        def load_param(path):
            if os.path.exists(path):
                p_table = np.load(path)
                table_param = Tensor(p_table, mstype.float32)
            else:
                raise ValueError(
                    f"{path} does not exist, "
                    f"please check whether the embedding file exists."
                )
            return table_param

        # three embedding tables need to be loaded
        # Loading the embedding tables from the ckpt path:
        position_embedding_path = os.path.join(load_ckpt_path, 'position_embedding.npy')
        word_embedding_path = os.path.join(load_ckpt_path, 'word_embedding.npy')
        top_query_embedding_path = os.path.join(load_ckpt_path, 'top_query_embedding.npy')
        self.embedding.word_embedding.embedding_table = Parameter(initializer(load_param(word_embedding_path),
                                                                              [self.config.vocab_size,
                                                                               self.config.hidden_size]),
                                                                  name='word_embedding_table', parallel_optimizer=False)
        self.embedding.position_embedding.embedding_table = Parameter(initializer(load_param(position_embedding_path),
                                                                                  [self.config.seq_length,
                                                                                   self.config.hidden_size]),
                                                                      name='position_embedding_table',
                                                                      parallel_optimizer=False)
        self.top_query_embedding.embedding_table = Parameter(initializer(load_param(top_query_embedding_path),
                                                                         [self.config.seq_length,
                                                                          self.config.hidden_size]),
                                                             name='query_embedding_table', parallel_optimizer=False)


class PanguAlphaModel(nn.Cell):
    """
    The PanguAlpha network, consisting of two parts: the backbone and the head
    Args:
        config(PanguAlphaConfig): the config of network
    Inputs:
        input_ids: the tokenized inputs
        input_mask: the mask indicating whether each position is a valid input
        past: the previous feature map
    Returns:
        logits: Tensor, the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
    """

    def __init__(self, config):
        super(PanguAlphaModel, self).__init__()
        # Network head to get logits over vocabulary
        copied_parallel_config = copy.deepcopy(config.parallel_config)
        if copied_parallel_config.pipeline_stage > 1:
            copied_parallel_config.vocab_emb_dp = False
        self.head = PanGuHead(
            hidden_size=config.hidden_size,
            parallel_config=copied_parallel_config,
        )
        self.head.pipeline_stage = config.parallel_config.pipeline_stage - 1
        self.backbone = PanguAlpha_Model(config)
        self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)

    def construct(self, input_ids, input_position, attention_mask,
                  init_reset=True, batch_valid_length=None):
        output_states, word_table = self.backbone(input_ids, input_position, attention_mask,
                                                  init_reset, batch_valid_length)
        logits = self.head(output_states, word_table)
        return logits


class PanGUAlphaWithLoss(Cell):
    """
    PanguAlpha training loss for generation.
    Args:
        config(PanGUConfig)
    Inputs:
        input_ids: the tokenized inputs
        past: the previous feature map
    Returns:
        output: Tensor, the loss of the network
    """

    def __init__(self, config, network, loss):
        super(PanGUAlphaWithLoss, self).__init__(auto_prefix=False)
        self.batch_size = config.batch_size
        self.seq_length = config.seq_length
        dp = config.parallel_config.data_parallel
        self.network = network
        self.eod_token = config.eod_token
        self.loss = loss

        self.slice = P.StridedSlice().shard(((dp, 1),))
        self.not_equal = P.NotEqual().shard(((dp, 1), ()))
        # self.batch_size = config.batch_size
        self.len = config.seq_length
        self.slice2 = P.StridedSlice().shard(((dp, 1, 1),))
        self.micro_batch_step = 1
        self.print = P.Print()
        if config.parallel_config.pipeline_stage > 1:
            self.micro_batch_step = config.parallel_config.micro_batch_num

    def construct(self, input_ids, input_position=None, attention_mask=None):
        r"""Forward process of the pangu alpha model"""
        tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))
        # P.Print()("==net tokens is:", tokens)
        input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1))
        decoder_attention_masks = self.slice2(attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len),
                                              (1, 1, 1))
        input_mask = F.cast(self.not_equal(tokens, self.eod_token), mstype.float32)
        logits = self.network(tokens, input_position, decoder_attention_masks)
        # P.Print()("==logits_is:", logits, ",shape is:", logits.shape)
        # Get the labels corresponding to the input tokens
        labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1), (1, 1))
        labels = P.Reshape()(labels, (-1,))
        input_mask = P.Reshape()(input_mask, (-1,))
        # P.Print()("==input_mask is:", input_mask)
        output = self.loss(logits, labels, input_mask)
        # P.Print()("==net output is:", output)
        return output


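# Illustrative sketch in plain Python with made-up token ids: the loss wrapper
# above feeds input_ids[:, :-1] to the network as tokens and uses
# input_ids[:, 1:] as labels, so each position is trained to predict the next
# token. This helper only mirrors that shift; it is not used by the model.
def _example_shifted_labels(input_ids=((11, 12, 13, 14),)):
    """Return the (tokens, labels) pairs produced by the next-token shift."""
    tokens = [row[:-1] for row in input_ids]  # what the network sees
    labels = [row[1:] for row in input_ids]   # what the loss compares against
    return tokens, labels

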
class EvalNet(nn.Cell):
    """
    PanguAlpha evaluation net
    Args:
        backbone: backbone network of PanguAlpha
        generate: enable generate mode
    Inputs:
        input_ids: the tokenized inputs
        current_index: the index of the current token
        init_reset: whether to reset the saved states
    Returns:
        outputs: Tensor, corresponding output for different tasks
    """

    def __init__(self, backbone, generate=False, pad_token=6, seq_length=2048):
        super(EvalNet, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.pad_token = pad_token
        self.argmax = P.Argmax()
        self.generate = generate
        self.topk = P.TopK(sorted=True).shard(((1, 1),))
        self.gather = P.Gather().shard(((1, 1), (1,)))
        self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),))
        self.get_attention_mask = AttentionMask(seq_length)
        self.expand = P.ExpandDims().shard(((1, 1, 1),))
        self.all_ones_attention_mask = Tensor(np.ones((1, 1, seq_length)), mstype.float32)
        self.not_equal = P.NotEqual().shard(((1, 1), ()))

    def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None):
        """evaluation net"""
        # input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32)
        input_mask = F.cast(self.not_equal(input_ids, self.pad_token), mstype.float32)
        bs, seq_length = F.shape(input_ids)
        if self.is_first_iteration is False:
            attention_mask = P.Tile()(self.all_ones_attention_mask, (bs, 1, 1))
        else:
            attention_mask = self.get_attention_mask(input_mask)
        input_position = F.tuple_to_array(F.make_range(seq_length))
        input_position = P.Tile()(input_position, (bs, 1))
        logits = self.backbone(input_ids, input_position, attention_mask,
                               init_reset, batch_valid_length)
        index = current_index.view(-1,)
        # P.Print()("==logits_is:", logits, ",shape is:", logits.shape)
        # P.Print()("==index_is:", index, ",shape is:", index.shape)
        logits = self.gather(logits, index, 0)
        logits = logits.view(bs, 1, -1)
        log_probs = self.log_softmax(logits)
        return log_probs


class LogitsNet(nn.Cell):
    """
    PanguAlpha evaluation net that returns raw logits
    Args:
        backbone: backbone network of PanguAlpha
        generate: enable generate mode
    Inputs:
        input_ids: the tokenized inputs
        init_reset: whether to reset the saved states
    Returns:
        outputs: Tensor, corresponding output for different tasks
    """

    def __init__(self, backbone, generate=False, pad_token=6, seq_length=2048):
        super(LogitsNet, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.pad_token = pad_token
        self.argmax = P.Argmax()
        self.generate = generate
        self.topk = P.TopK(sorted=True).shard(((1, 1),))
        self.gather = P.Gather().shard(((1, 1), (1,)))
        self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),))
        self.get_attention_mask = AttentionMask(seq_length)
        self.expand = P.ExpandDims().shard(((1, 1, 1),))
        self.all_ones_attention_mask = Tensor(np.ones((1, 1, seq_length)), mstype.float32)
        self.not_equal = P.NotEqual().shard(((1, 1), ()))

    def construct(self, input_ids, init_reset=True, batch_valid_length=None, attention_mask=None):
        """evaluation net"""
        # input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32)
        input_mask = F.cast(self.not_equal(input_ids, self.pad_token), mstype.float32)
        bs, seq_length = F.shape(input_ids)
        if attention_mask is None:
            if self.is_first_iteration is False:
                attention_mask = P.Tile()(self.all_ones_attention_mask, (bs, 1, 1))
            else:
                attention_mask = self.get_attention_mask(input_mask)
        input_position = F.tuple_to_array(F.make_range(seq_length))
        input_position = P.Tile()(input_position, (bs, 1))
        logits = self.backbone(input_ids, input_position, attention_mask,
                               init_reset, batch_valid_length)

        return logits


class PanGUAlphaWithFinetuneLoss(Cell):
    """
    PanguAlpha training loss for finetuning.
    Args:
        config(PanGUConfig)
    Inputs:
        input_ids: the tokenized inputs
        past: the previous feature map
    Returns:
        output: Tensor, the loss of the network
    """

    def __init__(self, config, network, loss):
        super(PanGUAlphaWithFinetuneLoss, self).__init__(auto_prefix=False)
        self.batch_size = config.batch_size
        self.seq_length = config.seq_length
        dp = config.parallel_config.data_parallel
        self.network = network
        self.eod_token = config.eod_token
        self.loss = loss

        self.slice = P.StridedSlice().shard(((dp, 1),))
        self.not_equal = P.NotEqual().shard(((dp, 1), ()))
        # self.batch_size = config.batch_size
        self.len = config.seq_length
        self.slice2 = P.StridedSlice().shard(((dp, 1, 1),))
        self.micro_batch_step = 1
        self.print = P.Print()
        if config.parallel_config.pipeline_stage > 1:
            self.micro_batch_step = config.parallel_config.micro_batch_num

    def construct(self, input_ids, loss_mask, input_position, attention_mask):
        r"""Forward process of the pangu alpha model"""
        tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))
        # P.Print()("==net tokens is:", tokens)
        input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1))
        decoder_attention_masks = self.slice2(attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len),
                                              (1, 1, 1))
        input_mask = F.cast(self.not_equal(tokens, self.eod_token), mstype.float32)
        logits = self.network(tokens, input_position, decoder_attention_masks)
        # P.Print()("===logits: ", logits, ", shape: ", logits.shape)
        # Get the labels corresponding to the input tokens
        labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1), (1, 1))
        labels = P.Reshape()(labels, (-1,))
        input_mask = P.Reshape()(loss_mask, (-1,))
        # P.Print()("===labels: ", labels, ", shape: ", labels.shape)
        # input_mask = P.Reshape()(input_mask, (-1,))
        # P.Print()("===input_mask: ", input_mask, ", shape: ", input_mask.shape)
        output = self.loss(logits, labels, input_mask)
        # P.Print()("==net output is:", output)
        return output