mirror of https://github.com/THUDM/CodeGeeX.git
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from codegeex.megatron import get_args
from codegeex.megatron import mpu

from .module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal


class CodeGeeXModel(MegatronModule):
    """Code Generative Model for Multilingual Program Synthesis."""

    def __init__(self, num_tokentypes=0, parallel_output=False):
        super(CodeGeeXModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        # Transformer backbone without a pooler head; its word embedding
        # weights are reused as the output projection (weight tying), see
        # forward() below.
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))
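
    # Sketch of the assumed initializer helpers (defined in .utils, following
    # Megatron-LM conventions; exact signatures in this fork may differ):
    #
    #     def init_method_normal(sigma):
    #         # weights ~ N(0, sigma^2)
    #         return lambda w: torch.nn.init.normal_(w, mean=0.0, std=sigma)
    #
    #     def scaled_init_method_normal(sigma, num_layers):
    #         # residual-path projections are shrunk by sqrt(2 * num_layers)
    #         # so activation variance stays roughly constant with depth
    #         std = sigma / math.sqrt(2.0 * num_layers)
    #         return lambda w: torch.nn.init.normal_(w, mean=0.0, std=std)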

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        labels=None,
        tokentype_ids=None,
        layer_past=None,
        get_key_value=False,
        forward_method_parallel_output=None,
        prompt_length=None,
        context_length=None,
    ):
        # Language model.
        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        attention_mask,
                                        tokentype_ids=tokentype_ids,
                                        layer_past=layer_past,
                                        get_key_value=get_key_value,
                                        prompt_length=prompt_length,
                                        context_length=context_length)

        # When caching key/values for incremental decoding, the backbone
        # returns (hidden_states, presents).
        if get_key_value:
            lm_output, presents = lm_output

        # No-op addition; torch.add returns a new tensor, so this copies
        # lm_output rather than changing its values.
        lm_output = torch.add(lm_output, 0)

        # Output: project hidden states onto the vocabulary with the tied
        # word embedding weights.
        parallel_output = self.parallel_output
        if forward_method_parallel_output is not None:
            parallel_output = forward_method_parallel_output
        output = parallel_lm_logits(
            lm_output,
            self.language_model.embedding.word_embeddings.weight,
            parallel_output)

        if get_key_value:
            output = [output, presents]

        if labels is None:
            return output
        else:
            if self.fp16_lm_cross_entropy:
                assert output.dtype == torch.half
                loss = mpu.vocab_parallel_cross_entropy(output, labels)
            else:
                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)

            return loss
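
    # Sketch of the assumed parallel_lm_logits helper (defined in
    # .language_model, following Megatron-LM; the mpu function names in this
    # fork may differ):
    #
    #     def parallel_lm_logits(input_, word_embeddings_weight, parallel_output):
    #         # compute logits for this rank's shard of the vocabulary
    #         input_parallel = mpu.copy_to_model_parallel_region(input_)
    #         logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    #         if parallel_output:
    #             # leave logits sharded across ranks; this pairs with
    #             # mpu.vocab_parallel_cross_entropy in forward() above
    #             return logits_parallel
    #         return mpu.gather_from_model_parallel_region(logits_parallel)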

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        # Nest the backbone weights under the language-model key so the
        # checkpoint layout matches load_state_dict() below.
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        return state_dict_
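
    # The returned dict has a single top-level entry,
    # {self._language_model_key: <nested backbone state>} -- with Megatron's
    # conventional key this is {"language_model": {...}} (assumed; the key is
    # chosen by get_language_model).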

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        # Accept either a checkpoint produced by
        # state_dict_for_save_checkpoint() (nested under the language-model
        # key) or a bare backbone state dict.
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
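

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). Megatron must be initialized
# before the model is built so that get_args() is populated; the import path
# and CLI-driven initialize_megatron() call below follow Megatron-LM
# conventions and are assumptions about this fork.
if __name__ == "__main__":
    from codegeex.megatron.initialize import initialize_megatron

    initialize_megatron()  # parses model config (layers, hidden size, ...) from CLI
    model = CodeGeeXModel(num_tokentypes=0, parallel_output=True)

    # Checkpoint round trip: save the nested state dict written by
    # state_dict_for_save_checkpoint(), then restore it; load_state_dict()
    # unwraps the language-model key again.
    torch.save(model.state_dict_for_save_checkpoint(), "/tmp/codegeex_ckpt.pt")
    model.load_state_dict(torch.load("/tmp/codegeex_ckpt.pt"))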