mirror of https://github.com/THUDM/CodeGeeX.git
commit d91ed519ce
@@ -0,0 +1,70 @@
import copy

from typing import *
from codegeex.megatron.model import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.torch.inference import get_token_stream


def get_model(
    backend: str = "megatron",
    quantized: bool = False,
):
    pass


def generate(
    model: CodeGeeXModel,
    tokenizer: CodeGeeXTokenizer,
    prompt: str,
    out_seq_length: int,
    seq_length: int = 2048,
    top_k: int = 0,
    top_p: float = 1.0,
    temperature: float = 1.0,
    micro_batch_size: int = 1,
    backend: str = "megatron",
    greedy: bool = False,
    verbose: bool = False,
):
    tokens = tokenizer.encode_code(prompt)
    n_token_prompt = len(tokens)

    if verbose:
        print(f"Current prompt:\n{prompt}")
        print("N_token_prompt:", n_token_prompt)

    generated_codes = []
    if backend == "megatron":
        token_stream = get_token_stream(
            model,
            tokenizer,
            seq_length,
            out_seq_length,
            [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
            micro_batch_size=micro_batch_size,
            topk=top_k,
            topp=top_p,
            temperature=temperature,
            greedy=greedy,
        )
        is_finished = [False for _ in range(micro_batch_size)]
        for i, generated in enumerate(token_stream):
            generated_tokens = generated[0]
            for j in range(micro_batch_size):
                if is_finished[j]:
                    continue

                if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
                    is_finished[j] = True
                    generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
                    generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
                    generated_code = "".join(generated_code)
                    generated_codes.append(generated_code)
                    if verbose:
                        print(f"\nGenerated code {i}:\n{generated_code}")

            if all(is_finished):
                break

    return generated_codes
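For reference, a minimal sketch of how the `generate` helper above is meant to be called. The `model` and `tokenizer` objects are assumed to be an already-loaded CodeGeeXModel and CodeGeeXTokenizer (built elsewhere, for example as in the Gradio demo later in this diff); the prompt and sampling values are illustrative only, not defaults defined by this file:

# Sketch only: `model` and `tokenizer` must already exist; nothing below is built by this file.
candidates = generate(
    model,
    tokenizer,
    prompt="# language: Python\n# write a function that reverses a string\n",
    out_seq_length=256,
    seq_length=2048,
    top_k=0,
    top_p=0.95,
    temperature=0.2,
    micro_batch_size=1,
    backend="megatron",
    verbose=True,
)
print(candidates[0])  # generate() returns one completion string per micro-batch entry
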
@@ -0,0 +1,133 @@
"""Merge model parallel partitions into a single checkpoint."""

import os
import torch
import random

from codegeex.megatron import get_args
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.checkpointing import ensure_directory_exists


def get_change_ckpt_args(parser):
    """Provide extra arguments required for merging."""
    group = parser.add_argument_group(title='Mindspore to megatron')
    group.add_argument(
        '--load-ckpt-path',
        type=str,
        required=True,
        help='dir to load model parallel partitions.',
    )
    group.add_argument(
        '--save-ckpt-path',
        type=str,
        required=True,
        help='path to save ".pt" checkpoint.',
    )
    group.add_argument(
        '--source-tensor-model-parallel-size',
        type=int,
        default=2,
        help='original tensor model parallel size',
    )

    return parser


def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))

    initialize_megatron(
        extra_args_provider=get_change_ckpt_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )

    args = get_args()
    model = CodeGeeXModel()
    print(model.state_dict)

    # Save the model.
    sd = {}
    sd['module'] = model.state_dict_for_save_checkpoint()
    ensure_directory_exists(args.save_ckpt_path)

    print(f"Load ckpt from {args.load_ckpt_path}...")
    state_dict_list = []
    for i in range(args.source_tensor_model_parallel_size):
        try:
            state_dict_list.append(torch.load(os.path.join(args.load_ckpt_path, f"mp_rank_{i:02d}_model_states.pt"), map_location="cpu"))
        except Exception as e:
            print(e)
            exit(0)

    print(f"Merging {len(state_dict_list)} partitions into a single ckpt...")
    print("Merging Embedding layers...")
    vocab_parallel_size = args.make_vocab_size_divisible_by // args.source_tensor_model_parallel_size
    for i in range(args.source_tensor_model_parallel_size):
        sd['module']['language_model']['embedding']['word_embeddings']['weight'][i * vocab_parallel_size : (i + 1) * vocab_parallel_size, :] = state_dict_list[i]['module']['language_model']['embedding']['word_embeddings']['weight']

    sd['module']['language_model']['embedding']['position_embeddings']['weight'] = state_dict_list[0]['module']['language_model']['embedding']['position_embeddings']['weight']

    print("Merging QueryEmbedding layers...")
    query_parallel_size = args.max_position_embeddings // args.source_tensor_model_parallel_size
    for i in range(args.source_tensor_model_parallel_size):
        sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings'].pop('weight', None)

    print("Merging Transformer layers...")
    for layer_name in sd['module']['language_model']['transformer'].keys():
        if "layernorm" in layer_name:
            sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
        elif "attention" in layer_name and "weight" in layer_name:
            if "dense" in layer_name:
                hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
                for i in range(args.source_tensor_model_parallel_size):
                    sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
            else:
                hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
                for i in range(args.source_tensor_model_parallel_size):
                    sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
        elif "weight" in layer_name and "dense" in layer_name:
            if "h_to_4h" in layer_name:
                hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
                for i in range(args.source_tensor_model_parallel_size):
                    sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
            else:
                hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
                for i in range(args.source_tensor_model_parallel_size):
                    sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
        elif "bias" in layer_name:
            if "mlp" in layer_name:
                if "4h_to_h" in layer_name:
                    sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
                else:
                    hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
                    for i in range(args.source_tensor_model_parallel_size):
                        sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
            elif "attention" in layer_name:
                if "dense" in layer_name:
                    sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
                else:
                    hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
                    for i in range(args.source_tensor_model_parallel_size):
                        sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
        else:
            sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)

    if args.save_ckpt_path.endswith(".pt"):
        save_ckpt_path = args.save_ckpt_path
    else:
        os.makedirs(args.save_ckpt_path, exist_ok=True)
        save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")

    torch.save(sd, save_ckpt_path)
    print(f"Converted checkpoint saved in {save_ckpt_path}.")


if __name__ == '__main__':
    main()
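The merge logic above follows the usual Megatron tensor-parallel layout: column-parallel weights (query_key_value, dense_h_to_4h) are recombined along the output dimension, row-parallel weights (attention dense, dense_4h_to_h) along the input dimension, and replicated tensors (layernorms, row-parallel biases) are taken from rank 0. A minimal, self-contained sketch of that recombination pattern with toy tensors (a world size of 2 and tiny shapes, not the real model sizes):

import torch

tp = 2  # toy tensor-parallel world size

# Column-parallel weights: each rank holds a slice of the output dimension,
# so the full matrix is the shards stacked along dim 0.
col_shards = [torch.full((4, 8), float(rank)) for rank in range(tp)]
merged_col = torch.cat(col_shards, dim=0)   # shape (8, 8)

# Row-parallel weights: each rank holds a slice of the input dimension,
# so the full matrix is the shards stacked along dim 1.
row_shards = [torch.full((8, 4), float(rank)) for rank in range(tp)]
merged_row = torch.cat(row_shards, dim=1)   # shape (8, 8)

# Replicated tensors (layernorm parameters, row-parallel biases) are identical
# on every rank, so the script simply keeps rank 0's copy.
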
@@ -0,0 +1,326 @@
import os
import torch
import logging

logging.getLogger("torch").setLevel(logging.WARNING)

import deepspeed
from deepspeed.runtime.utils import see_memory_usage
from functools import partial

from codegeex.megatron import get_args, print_rank_0, get_timers, get_tokenizer, mpu
from codegeex.megatron.data.prompt_dataset import build_train_valid_test_datasets
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.training import pretrain
from codegeex.megatron.utils import get_ltor_masks_and_position_ids
from codegeex.megatron.utils import average_losses_across_data_parallel_group


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0("building GPT model ...")
    see_memory_usage(f"Before Building Model", force=True)

    args = get_args()
    with deepspeed.zero.Init(
        data_parallel_group=mpu.get_data_parallel_group(),
        remote_device=None if args.remote_device == "none" else args.remote_device,
        config_dict_or_path=args.deepspeed_config,
        enabled=args.zero_stage == 3,
        mpu=mpu,
    ):
        if args.deepspeed and not args.no_pipeline_parallel:
            # NOTE: CodeGeeXModelPipe is referenced here but not imported above; this
            # pipeline-parallel branch assumes that class is available from
            # codegeex.megatron.model.
            model = CodeGeeXModelPipe(num_tokentypes=0, parallel_output=True)
            # This is a hack to give us a reference to get_batch_pipe from within training.py
            # We need to call model.set_batch_fn after deepspeed.initialize
            model._megatron_batch_fn = get_batch_pipe

            # Precompute the attention mask and store it in args. This avoids having to
            # pipeline it as an activation during training. The mask is constant, and thus
            # we can reuse it.
            attention_mask = torch.tril(
                torch.ones(
                    (1, args.seq_length, args.seq_length),
                    device=torch.cuda.current_device(),
                )
            ).view(1, 1, args.seq_length, args.seq_length)

            # Convert attention mask to binary:
            attention_mask = attention_mask < 0.5
            if args.fp16:
                attention_mask = attention_mask.half()
            elif args.bf16:
                attention_mask = attention_mask.bfloat16()

            # Attention mask must be bool.
            args.attn_mask = attention_mask.to(torch.bool)

        else:
            model = CodeGeeXModel(
                num_tokentypes=0,
                parallel_output=True,
            )

            if args.load_state is not None:
                timers = get_timers()
                print_rank_0("Loading warmstarting model states ...")
                timers("load-model-states").start()
                mp_rank = mpu.get_tensor_model_parallel_rank()
                if os.path.isdir(args.load_state):
                    model_path = os.path.join(
                        args.load_state, "mp_rank_{:02d}_model_states.pt".format(mp_rank)
                    )
                else:
                    model_path = args.load_state
                print_rank_0(f"Loading model from {model_path} ...")
                state_dict = torch.load(model_path, map_location="cpu")
                if "module" in state_dict:
                    state_dict = state_dict["module"]  # strip other client states
                model.load_state_dict(state_dict)
                timers("load-model-states").stop()
                timers.log(["load-model-states"])
    see_memory_usage(f"After Building Model", force=True)

    return model


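# Illustration (not part of the training path): how the attention mask built in
# model_provider above behaves. torch.tril() yields a lower-triangular matrix of ones,
# and the `< 0.5` comparison flips it into a boolean mask that is True strictly above
# the diagonal, i.e. True marks positions attention is NOT allowed to see.
# Toy sequence length of 4 in place of args.seq_length:
_demo_mask = torch.tril(torch.ones((1, 4, 4))).view(1, 1, 4, 4) < 0.5
# _demo_mask[0, 0].int() ->
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
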
def get_batch(data_iterator):
    """Generate a batch."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ["input_ids", "attention_mask", "labels"]
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b["input_ids"].contiguous()
    # attn_mask_ = data_b["attention_mask"].contiguous()
    labels_ = data_b["labels"].contiguous()

    tokens = tokens_[:, :-1]
    labels = labels_[:, 1:]

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )

    # Mask the loss to avoid predicting prompt and padding tokens.
    prompt_loss_mask = labels >= 0
    loss_mask = prompt_loss_mask * loss_mask

    return tokens, labels, loss_mask, attention_mask, position_ids


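# Illustration (not part of the training path): the shift-and-mask performed in get_batch
# above. Inputs and labels are the same sequence offset by one position, and any position
# whose label is negative (the dataset's marker for prompt/padding tokens; -100 is assumed
# here purely for illustration) is dropped from the loss.
_demo_input_ids = torch.tensor([[11, 12, 13, 14, 15]])
_demo_raw_labels = torch.tensor([[-100, -100, 13, 14, 15]])
_demo_tokens = _demo_input_ids[:, :-1]          # model input: [[11, 12, 13, 14]]
_demo_labels = _demo_raw_labels[:, 1:]          # targets:     [[-100, 13, 14, 15]]
_demo_loss_mask = (_demo_labels >= 0).float()   # -> [[0., 1., 1., 1.]]
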
def get_batch_pipe(data):
    """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ["input_ids"]
    datatype = torch.int64

    # Broadcast data.
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b["input_ids"].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )

    return (tokens, position_ids, attention_mask), (labels, loss_mask)


def loss_func(loss_mask, output_tensor):
    args = get_args()

    def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor):
        if args.gold:
            losses_ = losses.detach()
            prob = torch.exp(-losses_)  # Pθ(s)
            torch.sqrt_(prob)  # Pθ(s)ᵃ
            torch.clamp_min_(prob, args.gold_beta)  # max(Pθ(s)ᵃ, β)
            losses = prob * losses

        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min(loss_mask.sum(), 1e-8)

        return loss

    losses = output_tensor.float()
    loss = compute_lm_loss(losses, loss_mask)

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {"lm loss": averaged_loss[0]}


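# Illustration (not part of the training path): the GOLD-style reweighting applied in
# loss_func above when --gold is set. Each token's loss is scaled by
# max(sqrt(exp(-loss)), gold_beta), computed on detached values, which down-weights tokens
# the model currently assigns very low probability. gold_beta = 0.1 is assumed here for
# illustration only.
_demo_losses = torch.tensor([0.1, 2.0, 8.0])                 # per-token cross-entropy
_demo_prob = torch.exp(-_demo_losses.detach())               # Pθ(s)
_demo_prob = torch.clamp_min(torch.sqrt(_demo_prob), 0.1)    # max(Pθ(s)^0.5, β)
_demo_weighted = _demo_prob * _demo_losses                   # ~[0.0951, 0.7358, 0.8000]
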
def valid_loss_func(loss_mask, output_tensor):
    args = get_args()

    def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor):
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min(loss_mask.sum(), 1e-8)

        return loss

    losses = output_tensor.float()
    loss = compute_lm_loss(losses, loss_mask)

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {"lm loss": averaged_loss[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers("batch-generator").start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers("batch-generator").stop()

    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    return output_tensor, partial(loss_func, loss_mask)


def valid_forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers("batch-generator").start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers("batch-generator").stop()

    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    return output_tensor, partial(valid_loss_func, loss_mask)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0("> building train, validation, and test datasets for GPT ...")
    if args.co_evaluation:
        def dataset_partition_path_parsing(data_path):
            dataset_path = {}
            for index in range(len(data_path)):
                dataset_path[data_path[index]] = data_path[index]
            return dataset_path

        assert args.valid_data_path is not None, "Valid data path must be given when --co-evaluation is turned on."
        valid_data_path = dataset_partition_path_parsing(args.valid_data_path)
        if args.test_data_path is not None:
            test_data_path = dataset_partition_path_parsing(args.test_data_path)
        else:
            test_data_path = None
        train_ds, _, _ = build_train_valid_test_datasets(
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string="1,0,0",
            train_valid_test_num_samples=train_val_test_num_samples,
            seq_length=args.seq_length,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup),
        )
        valid_ds = {}
        for key, value in valid_data_path.items():
            _, valid_ds_item, _ = build_train_valid_test_datasets(
                data_prefix=[value],
                data_impl=args.data_impl,
                splits_string="0,1,0",
                train_valid_test_num_samples=train_val_test_num_samples,
                seq_length=args.seq_length,
                seed=args.seed,
                skip_warmup=(not args.mmap_warmup),
            )
            valid_ds[key] = valid_ds_item
        if test_data_path is not None:
            test_ds = {}
            for key, value in test_data_path.items():
                _, _, test_ds_item = build_train_valid_test_datasets(
                    data_prefix=[value],
                    data_impl=args.data_impl,
                    splits_string="0,0,1",
                    train_valid_test_num_samples=train_val_test_num_samples,
                    seq_length=args.seq_length,
                    seed=args.seed,
                    skip_warmup=(not args.mmap_warmup),
                )
                test_ds[key] = test_ds_item
        else:
            test_ds = None
    elif args.valid_data_path is None:
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
            seq_length=args.seq_length,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup),
        )
    else:
        train_ds, _, test_ds = build_train_valid_test_datasets(
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string="100,0,0",
            train_valid_test_num_samples=train_val_test_num_samples,
            seq_length=args.seq_length,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup),
        )

        _, valid_ds, _ = build_train_valid_test_datasets(
            data_prefix=args.valid_data_path,
            data_impl=args.data_impl,
            splits_string="0,100,0",
            train_valid_test_num_samples=train_val_test_num_samples,
            seq_length=args.seq_length,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup),
        )

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        valid_forward_step,
        args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
    )

@@ -0,0 +1,50 @@
import os
import sys
import torch
import random
import argparse
import numpy as np


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--load-path",
                        type=str,
                        default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_fp32_52224.pt")
    parser.add_argument("--save-path",
                        type=str,
                        default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt")

    args, _ = parser.parse_known_args()

    state_dict_path = args.load_path
    print("Loading state dict ...")
    sd = torch.load(state_dict_path, map_location="cpu")

    for i in range(40):
        if i < 39:
            query_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.weight', None)
            query_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.bias', None)
            key_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.weight', None)
            key_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.bias', None)
            value_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.weight', None)
            value_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.bias', None)
            qkv_weight = torch.cat([query_weight, key_weight, value_weight], dim=0)
            qkv_bias = torch.cat([query_bias, key_bias, value_bias])
            sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.weight'] = qkv_weight
            sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.bias'] = qkv_bias
        else:
            tq_key_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.weight', None)
            tq_key_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.bias', None)
            tq_value_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.weight', None)
            tq_value_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.bias', None)
            tq_kv_weight = torch.cat([tq_key_weight, tq_value_weight], dim=0)
            tq_kv_bias = torch.cat([tq_key_bias, tq_value_bias])
            sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.weight'] = tq_kv_weight
            sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.bias'] = tq_kv_bias

    save_ckpt_path = args.save_path
    torch.save(sd, save_ckpt_path)


if __name__ == '__main__':
    main()
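The conversion above simply stacks the separate query, key and value projections into a fused query_key_value parameter (and key/value into key_value for the top query layer) by concatenating along the output dimension. A toy shape check of that operation, using a hidden size of 8 as a stand-in for the real 5120:

import torch

hidden = 8  # stand-in for the 13B model's hidden size of 5120

query_weight = torch.randn(hidden, hidden)
key_weight = torch.randn(hidden, hidden)
value_weight = torch.randn(hidden, hidden)

qkv_weight = torch.cat([query_weight, key_weight, value_weight], dim=0)
qkv_bias = torch.cat([torch.zeros(hidden)] * 3)

print(qkv_weight.shape, qkv_bias.shape)  # torch.Size([24, 8]) torch.Size([24])
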
@@ -0,0 +1,198 @@
import json
import torch
import argparse
import gradio as gr

import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
from codegeex.data.data_utils import LANGUAGE_TAG
from codegeex.megatron.inference import set_random_seed


def model_provider(args):
    """Build the model."""

    model = CodeGeeXModel(
        args.hidden_size,
        args.num_layers,
        args.num_attention_heads,
        args.padded_vocab_size,
        args.max_position_embeddings,
    )
    return model


def add_code_generation_args(parser):
    group = parser.add_argument_group(title="code generation")
    group.add_argument(
        "--num-layers",
        type=int,
        default=39,
    )
    group.add_argument(
        "--hidden-size",
        type=int,
        default=5120,
    )
    group.add_argument(
        "--num-attention-heads",
        type=int,
        default=40,
    )
    group.add_argument(
        "--padded-vocab-size",
        type=int,
        default=52224,
    )
    group.add_argument(
        "--max-position-embeddings",
        type=int,
        default=2048,
    )
    group.add_argument(
        "--tokenizer-path",
        type=str,
        default="./tokenizer",
    )
    group.add_argument(
        "--example-path",
        type=str,
        default="./",
    )
    group.add_argument(
        "--load",
        type=str,
    )
    group.add_argument(
        "--state-dict-path",
        type=str,
    )
    group.add_argument(
        "--micro-batch-size",
        type=int,
        default=1,
    )
    group.add_argument(
        "--quantize",
        action="store_true",
    )

    return parser


def main():
    parser = argparse.ArgumentParser()
    parser = add_code_generation_args(parser)
    args, _ = parser.parse_known_args()

    print("Loading tokenizer ...")
    tokenizer = CodeGeeXTokenizer(
        tokenizer_path=args.tokenizer_path,
        mode="codegeex-13b")

    print("Loading state dict ...")
    state_dict = torch.load(args.load, map_location="cpu")
    state_dict = state_dict["module"]

    print("Building CodeGeeX model ...")
    model = model_provider(args)
    model.load_state_dict(state_dict)
    model.eval()
    model.half()
    if args.quantize:
        model = quantize(model, weight_bit_width=8, backend="torch")
    model.cuda()

    def predict(
        prompt,
        lang,
        seed,
        out_seq_length,
        temperature,
        top_k,
        top_p,
    ):
        set_random_seed(seed)
        if lang.lower() in LANGUAGE_TAG:
            prompt = LANGUAGE_TAG[lang.lower()] + "\n" + prompt

        generated_code = codegeex.generate(
            model,
            tokenizer,
            prompt,
            out_seq_length=out_seq_length,
            seq_length=args.max_position_embeddings,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            micro_batch_size=args.micro_batch_size,
            backend="megatron",
            verbose=True,
        )
        return prompt + generated_code

    examples = []
    with open(args.example_path, "r") as f:
        for line in f:
            examples.append(list(json.loads(line).values()))

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            <img src="https://raw.githubusercontent.com/THUDM/CodeGeeX/main/resources/logo/codegeex_logo.png">
            """)
        gr.Markdown(
            """
            <p align="center">
                🏠 <a href="https://codegeex.cn" target="_blank">Homepage</a> | 📖 <a href="http://keg.cs.tsinghua.edu.cn/codegeex/" target="_blank">Blog</a> | 🪧 <a href="https://codegeex.cn/playground" target="_blank">DEMO</a> | 🛠 <a href="https://marketplace.visualstudio.com/items?itemName=aminer.codegeex" target="_blank">VS Code</a> or <a href="https://plugins.jetbrains.com/plugin/20587-codegeex" target="_blank">Jetbrains</a> Extensions | 💻 <a href="https://github.com/THUDM/CodeGeeX" target="_blank">Source code</a> | 🤖 <a href="https://models.aminer.cn/codegeex/download/request" target="_blank">Download Model</a>
            </p>
            """)
        gr.Markdown(
            """
            We introduce CodeGeeX, a large-scale multilingual code generation model with 13 billion parameters, pre-trained on a large code corpus of more than 20 programming languages. CodeGeeX supports 15+ programming languages for both code generation and translation. CodeGeeX is open source; please refer to our [GitHub](https://github.com/THUDM/CodeGeeX) for more details. This is a minimal-functionality DEMO; for other DEMOs such as code translation, please visit our [Homepage](https://codegeex.cn). We also offer free [VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex) and [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex) extensions for full functionality.
            """)

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=13, placeholder='Please enter the description or select an example input below.', label='Input')
                with gr.Row():
                    gen = gr.Button("Generate")
                    clr = gr.Button("Clear")

            outputs = gr.Textbox(lines=15, label='Output')

        gr.Markdown(
            """
            Generation Parameters
            """)
        with gr.Row():
            with gr.Column():
                lang = gr.Radio(
                    choices=["C++", "C", "C#", "Python", "Java", "HTML", "PHP", "JavaScript", "TypeScript", "Go",
                             "Rust", "SQL", "Kotlin", "R", "Fortran"],
                    value="Python", label='Programming Language')
            with gr.Column():
                seed = gr.Slider(maximum=10000, value=8888, step=1, label='Seed')
                with gr.Row():
                    out_seq_length = gr.Slider(maximum=1024, value=128, minimum=1, step=1, label='Output Sequence Length')
                    temperature = gr.Slider(maximum=1, value=0.2, minimum=0, label='Temperature')
                with gr.Row():
                    top_k = gr.Slider(maximum=40, value=0, minimum=0, step=1, label='Top K')
                    top_p = gr.Slider(maximum=1, value=1.0, minimum=0, label='Top P')

        inputs = [prompt, lang, seed, out_seq_length, temperature, top_k, top_p]
        gen.click(fn=predict, inputs=inputs, outputs=outputs)
        clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=prompt)

        gr_examples = gr.Examples(examples=examples, inputs=[prompt, lang],
                                  label="Example Inputs (Click to insert an example into the input box)",
                                  examples_per_page=20)

    demo.launch(server_port=6007)


if __name__ == '__main__':
    with torch.no_grad():
        main()
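For context, the `predict` function above prepends a language tag to the user prompt before generation. A small, self-contained illustration of that step; the real mapping lives in codegeex.data.data_utils.LANGUAGE_TAG, and the tag string shown for Python is an assumption used here for illustration only:

# Stand-in for codegeex.data.data_utils.LANGUAGE_TAG; the actual tag strings ship with the repo.
LANGUAGE_TAG = {"python": "# language: Python"}

lang = "Python"
prompt = "def quick_sort(arr):"
if lang.lower() in LANGUAGE_TAG:
    prompt = LANGUAGE_TAG[lang.lower()] + "\n" + prompt

print(prompt)
# # language: Python
# def quick_sort(arr):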