# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from codegeex.megatron import get_args, mpu
from codegeex.megatron.model import LayerNorm
from codegeex.megatron.enums import AttnMaskType
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.language_model import (
    parallel_lm_logits,
    get_language_model,
    EmbeddingPipe,
    QueryEmbeddingPipe,
)
from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal
from codegeex.megatron.model.transformer import ParallelTransformerLayerPipe, ParallelTopQueryLayerPipe
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec


class CodeGeeXModel(MegatronModule):
    """Code Generation Model for Multilingual Program Synthesis."""

    def __init__(self, num_tokentypes=0, parallel_output=False):
        super(CodeGeeXModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        labels=None,
        tokentype_ids=None,
        layer_past=None,
        get_key_value=False,
        forward_method_parallel_output=None,
        prompt_length=None,
        context_length=None,
    ):
        # Language model.
        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        attention_mask,
                                        tokentype_ids=tokentype_ids,
                                        layer_past=layer_past,
                                        get_key_value=get_key_value,
                                        prompt_length=prompt_length,
                                        context_length=context_length)

        if get_key_value:
            lm_output, presents = lm_output

        lm_output = torch.add(lm_output, 0)

        # Output.
        parallel_output = self.parallel_output
        if forward_method_parallel_output is not None:
            parallel_output = forward_method_parallel_output
        output = parallel_lm_logits(
            lm_output,
            self.language_model.embedding.word_embeddings.weight,
            parallel_output)

        if get_key_value:
            output = [output, presents]

        if labels is None:
            return output
        else:
            if self.fp16_lm_cross_entropy:
                assert output.dtype == torch.half
                loss = mpu.vocab_parallel_cross_entropy(output, labels)
            else:
                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
            return loss

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)


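# Illustrative sketch, not part of the original file: how a single incremental
# decoding step typically drives CodeGeeXModel.forward(). It assumes Megatron has
# already been initialized and that logits come back as [batch, seq, vocab] with the
# full vocabulary gathered (parallel_output=False, the constructor default); the
# helper name and shapes are assumptions for illustration only.
def _example_decode_step(model, tokens, position_ids, attention_mask, layer_past=None):
    # With get_key_value=True the forward pass returns [logits, presents]; `presents`
    # holds the cached key/value states and is fed back as `layer_past` next step.
    logits, presents = model(
        tokens,
        position_ids,
        attention_mask,
        get_key_value=True,
        layer_past=layer_past,
    )
    next_token = logits[:, -1, :].argmax(dim=-1)  # greedy pick over the vocabulary
    return next_token, presents

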
def CrossEntropy(output, labels):
    labels, loss_mask = labels[0], labels[1]

    args = get_args()

    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss


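# Illustrative sketch, not part of the original file: the same masked, token-averaged
# loss computed with plain PyTorch instead of the tensor-parallel
# mpu.vocab_parallel_cross_entropy, handy for sanity-checking the reduction on a
# single device. Shapes are assumptions: logits [batch, seq, vocab],
# labels and loss_mask [batch, seq].
def _reference_masked_cross_entropy(logits, labels, loss_mask):
    import torch.nn.functional as F

    per_token = F.cross_entropy(
        logits.float().view(-1, logits.size(-1)),  # [batch * seq, vocab]
        labels.view(-1),                           # [batch * seq]
        reduction="none",
    )
    loss_mask = loss_mask.view(-1).float()
    # Average only over unmasked (non-padding) token positions.
    return torch.sum(per_token * loss_mask) / loss_mask.sum()

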
class CodeGeeXModelPipe(PipelineModule, MegatronModule):
    """Pipeline version of CodeGeeX."""

    def __init__(self, num_tokentypes=0, parallel_output=True):
        args = get_args()
        self.parallel_output = parallel_output

        init_method = init_method_normal(args.init_method_std)

        self.specs = []

        # Embedding layer
        self.specs.append(
            TiedLayerSpec(
                "embed",
                EmbeddingPipe,
                args.hidden_size,
                args.padded_vocab_size,
                args.max_position_embeddings,
                args.hidden_dropout,
                init_method=init_method,
                num_tokentypes=num_tokentypes,
                tied_weight_attr="word_embeddings_weight",
            )
        )

        self.specs.append(lambda x: x.transpose(0, 1).contiguous())

        for layer_idx in range(args.num_layers):
            self.specs.append(
                LayerSpec(
                    ParallelTransformerLayerPipe,
                    init_method=init_method,
                    output_layer_init_method=scaled_init_method_normal(
                        args.init_method_std, args.num_layers
                    ),
                    layer_number=layer_idx,
                    self_attn_mask_type=AttnMaskType.causal,
                )
            )

        # Undo data format change
        self.specs.append(lambda x: x.transpose(0, 1).contiguous())

        # Final layernorm after transformer layers
        self.specs.append(
            LayerSpec(LayerNorm, args.hidden_size, eps=args.layernorm_epsilon)
        )

        def _logits_helper(embedding, lm_output):
            """A wrapper to massage inputs/outputs from pipeline."""
            return parallel_lm_logits(
                lm_output, embedding.word_embeddings_weight, self.parallel_output
            )

        self.specs.append(
            TiedLayerSpec(
                "embed",
                EmbeddingPipe,
                args.hidden_size,
                args.padded_vocab_size,
                args.max_position_embeddings,
                args.hidden_dropout,
                init_method=init_method,
                num_tokentypes=num_tokentypes,
                forward_fn=_logits_helper,
                tied_weight_attr="word_embeddings_weight",
            )
        )

        if args.checkpoint_activations:
            interval = args.checkpoint_num_layers
        else:
            interval = 0

        from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

        topo = PipeModelDataParallelTopology(
            num_pp=mpu.get_pipeline_model_parallel_world_size(),
            num_mp=mpu.get_tensor_model_parallel_world_size(),
            num_dp=mpu.get_data_parallel_world_size(),
        )

        super().__init__(
            layers=self.specs,
            loss_fn=CrossEntropy,
            topology=topo,
            activation_checkpoint_interval=interval,
            partition_method="type:transformer",
        )


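# Illustrative sketch, not part of the original file: wiring the pipeline-parallel
# model into a DeepSpeed engine. The helper name and argument choices are assumptions
# for illustration, not CodeGeeX's actual training entry point (which also sets up
# data loaders, the optimizer, and learning-rate schedules).
def _example_build_pipeline_engine():
    import deepspeed

    args = get_args()
    model = CodeGeeXModelPipe(num_tokentypes=0, parallel_output=True)
    # PipelineModule already carries its topology and loss_fn; deepspeed.initialize
    # returns (engine, optimizer, dataloader, lr_scheduler).
    engine, _, _, _ = deepspeed.initialize(
        model=model,
        args=args,  # assumed: args.deepspeed_config points at a DeepSpeed JSON config
        model_parameters=[p for p in model.parameters() if p.requires_grad],
    )
    return engine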