Merge pull request #69 from THUDM/develop

Merge develop branch
pull/70/head
Qinkai committed 2 years ago (via GitHub)
commit d91ed519ce

@ -0,0 +1,70 @@
import copy
from typing import *
from codegeex.megatron.model import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.torch.inference import get_token_stream
def get_model(
backend: str = "megatron",
quantized: bool = False,
):
pass
def generate(
model: CodeGeeXModel,
tokenizer: CodeGeeXTokenizer,
prompt: str,
out_seq_length: int,
seq_length: int = 2048,
top_k: int = 0,
top_p: float = 1.0,
temperature: float = 1.0,
micro_batch_size: int = 1,
backend: str = "megatron",
greedy: bool = False,
verbose: bool = False,
):
tokens = tokenizer.encode_code(prompt)
n_token_prompt = len(tokens)
if verbose:
print(f"Current prompt:\n{prompt}")
print("N_token_prompt:", n_token_prompt)
generated_codes = []
if backend == "megatron":
token_stream = get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
topk=top_k,
topp=top_p,
temperature=temperature,
greedy=greedy,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
generated_code = "".join(generated_code)
generated_codes.append(generated_code)
if verbose:
print(f"\nGenerated code {i}:\n{generated_code}")
if all(is_finished):
break
return generated_codes
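
For reference, a minimal usage sketch of the new generate helper, mirroring the Gradio demo added later in this PR (checkpoint path and prompt are placeholders, not part of the diff):

import torch
import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer

# Build the 13B configuration used by the demo scripts in this PR.
tokenizer = CodeGeeXTokenizer(tokenizer_path="./tokenizer", mode="codegeex-13b")
model = CodeGeeXModel(5120, 39, 40, 52224, 2048)  # hidden size, layers, heads, padded vocab, max positions
state_dict = torch.load("codegeex_13b.pt", map_location="cpu")["module"]  # placeholder checkpoint path
model.load_state_dict(state_dict)
model.eval()
model.half()
model.cuda()

prompt = "# language: Python\n# write a bubble sort function\n"
generated = codegeex.generate(
    model,
    tokenizer,
    prompt,
    out_seq_length=128,
    seq_length=2048,
    top_k=0,
    top_p=0.95,
    temperature=0.2,
    micro_batch_size=1,
    backend="megatron",
    verbose=True,
)
# 'generated' holds the completion(s) produced for the prompt (a list in the helper above).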

@ -189,7 +189,7 @@ def add_code_generation_args(parser):
nargs="*",
type=int,
default=None,
help='Identify the type of programming language to generate',
help='Specify bad ids that will not be used',
)
group.add_argument(
"--quantize",

@ -1,23 +1,32 @@
import os
import gzip
import json
from typing import *
LANGUAGE_TAG = {
"c" : "// language: C",
"c++" : "// language: C++",
"cpp" : "// language: C++",
"c" : "// language: C",
"c#" : "// language: C#",
"csharp" : "// language: C#",
"css" : "/* language: CSS */",
"cuda" : "// language: Cuda",
"dart" : "// language: Dart",
"lua" : "// language: Lua",
"objectivec" : "// language: Objective-C",
"objective-c" : "// language: Objective-C",
"objective-c++": "// language: Objective-C++",
"python" : "# language: Python",
"perl" : "# language: Perl",
"prolog" : f"% language: Prolog",
"swift" : "// language: swift",
"lisp" : "; language: Lisp",
"java" : "// language: Java",
"scala" : "// language: Scala",
"tex" : f"% language: TeX",
"vue" : "<!--language: Vue-->",
"markdown" : "<!--language: Markdown-->",
"html" : "<!--language: HTML-->",
"php" : "// language: PHP",
"js" : "// language: JavaScript",
@ -26,13 +35,32 @@ LANGUAGE_TAG = {
"go" : "// language: Go",
"shell" : "# language: Shell",
"rust" : "// language: Rust",
"css" : "/* language: CSS */",
"sql" : "-- language: SQL",
"kotlin" : "// language: Kotlin",
"vb" : "' language: Visual Basic",
"ruby" : "# language: Ruby",
"pascal" : "// language: Pascal",
"r" : "# language: R",
"fortran" : "!language: Fortran",
"lean" : "-- language: Lean",
"matlab" : f"% language: Matlab",
"delphi" : "{language: Delphi}",
"scheme" : "; language: Scheme",
"basic" : "' language: Basic",
"assembly" : "; language: Assembly",
"groovy" : "// language: Groovy",
"abap" : "* language: Abap",
"gdscript" : "# language: GDScript",
"haskell" : "-- language: Haskell",
"julia" : "# language: Julia",
"elixir" : "# language: Elixir",
"excel" : "' language: Excel",
"clojure" : "; language: Clojure",
"actionscript" : "// language: ActionScript",
"solidity" : "// language: Solidity",
"powershell" : "# language: PowerShell",
"erlang" : f"% language: Erlang",
"cobol" : "// language: Cobol",
}
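
The tags are meant to be prepended to the prompt with a trailing newline, as in the process_sample change below and the demo's predict function; a small illustration:

# Example (illustrative only): prefix a prompt with its language tag.
lang = "python"
prompt = "write a quick sort function\n"
if lang.lower() in LANGUAGE_TAG:
    prompt = LANGUAGE_TAG[lang.lower()] + "\n" + prompt
# prompt is now: "# language: Python\nwrite a quick sort function\n"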

@ -58,7 +58,7 @@ def process_sample(
try:
if language is not None and language in LANGUAGE_TAG.keys():
code = LANGUAGE_TAG[language] + sample["code"]
code = LANGUAGE_TAG[language] + "\n" + sample["code"]
else:
code = sample["code"]
except Exception as e:

@ -67,6 +67,9 @@ class PromptDatasetProcessor(object):
"""
Instead of processing lazily, we turn the iterable into a list.
"""
if sample is None:
return None
return list(self.process_sample(sample))
def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
@ -141,6 +144,9 @@ class LabelDatasetProcessor(object):
"""
Instead of processing lazily, we turn the iterable into a list.
"""
if sample is None:
return None
return list(self.process_sample(sample))
def process_sample_(self, sample) -> List[Dict[str, List[int]]]:

@ -415,6 +415,10 @@ def _add_network_size_args(parser):
help="Disable BERT binary head.",
dest="bert_binary_head",
)
group.add_argument(
"--compress",
action="store_true",
)
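# Note (added for clarity, not part of the diff): judging from the transformer changes later in
# this PR, --compress halves the attention query/key/value projection width and shrinks the MLP
# expansion from 4x to 2x of the hidden size.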
return parser
@ -560,6 +564,24 @@ def _add_regularization_args(parser):
group.add_argument(
"--sgd-momentum", type=float, default=0.9, help="Momentum factor for sgd"
)
group.add_argument(
"--shrink-logit-embedding-gradient",
action="store_true",
)
group.add_argument(
"--shrink-embedding-gradient-alpha",
type=float,
default=1.0,
help='Factor alpha by which to shrink the embedding gradient',
)
group.add_argument(
"--shrink-embedding-gradient-steps",
nargs='*',
default=None,
help='--shrink-embedding-gradient-steps <x1> <x2>: '
'keep the shrink alpha for x1 steps, '
'then warm it up to 1.0 over x2 steps',
)
return parser
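
An illustrative combination of these flags (semantics taken from the help strings and the schedule implemented later in this PR):

# Example (illustrative only): scale the embedding gradient by 0.1 for the first 1000 steps,
# then warm the factor up linearly to 1.0 over the next 2000 steps.
#   --shrink-logit-embedding-gradient \
#   --shrink-embedding-gradient-alpha 0.1 \
#   --shrink-embedding-gradient-steps 1000 2000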
@ -751,6 +773,10 @@ def _add_initialization_args(parser):
def _add_inference_args(parser):
group = parser.add_argument_group(title="initialization")
group.add_argument(
'--evaluation',
action="store_true",
)
group.add_argument(
'--beam-warmup',
action="store_true",

@ -84,7 +84,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False):
if release:
directory = ""
else:
directory = "iter_{:07d}".format(iteration)
directory = f"global_step{iteration}"
# Use both the tensor and pipeline MP rank.
if mpu.get_pipeline_model_parallel_world_size() == 1:
return os.path.join(
@ -174,7 +174,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
# Saving is a collective communication
checkpoint_name = get_checkpoint_name(args.save, iteration)
# Trim off the filename and mp_rank_* directory.
for _ in range(3):
for _ in range(2):
checkpoint_name = os.path.dirname(checkpoint_name)
model[0].save_checkpoint(checkpoint_name, client_state=state_dict)
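
For clarity, a small sketch of the new path handling (the mp_rank/filename layout is assumed from the usual Megatron-LM convention; it is not shown in this hunk):

import os

# Example (illustrative only): why two dirname() calls now suffice.
checkpoint_name = os.path.join("/ckpts", "global_step12000", "mp_rank_00", "model_optim_rng.pt")
for _ in range(2):  # strip the filename and the mp_rank_* directory
    checkpoint_name = os.path.dirname(checkpoint_name)
print(checkpoint_name)  # /ckpts/global_step12000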

@ -25,10 +25,11 @@ import torch
import torch.nn.functional as F
from dataclasses import dataclass
from codegeex.megatron import get_args
from codegeex.megatron import get_args, print_rank_0
from codegeex.megatron import get_tokenizer
from codegeex.megatron import mpu
from codegeex.megatron.utils import get_ltor_masks_and_position_ids
from codegeex.benchmark.utils import is_code_generation_finished
def get_batch(context_tokens, micro_batch_size=None):
@ -682,12 +683,17 @@ def beam_search(model, context_tokens, num_beams: int):
expanded_beams = expand_beams(beams, num_beams, model)
next_beams = []
for beam in expanded_beams:
if args.beam_warmup_length > 0:
if args.beam_warmup:
if len(beam.tokens) >= org_context_len + args.beam_warmup_length or beam.tokens[-1] == tokenizer.eod:
finished_beams.append(beam)
else:
next_beams.append(beam)
else:
if args.evaluation:
generated_code = tokenizer.detokenize(beam.tokens[org_context_len:])
if is_code_generation_finished(generated_code):
finished_beams.append(beam)
continue
if beam.tokens[-1] == tokenizer.eod:
finished_beams.append(beam)
else:
@ -842,7 +848,6 @@ def get_token_stream(
temperature: float = None,
topp: float = None,
topk: int = None,
beam_warmup: bool = False,
):
args = get_args()
tokenizer = get_tokenizer()
@ -866,42 +871,30 @@ def get_token_stream(
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, micro_batch_size)
if beam_warmup:
batch_token_iterator = sample_sequence_batch_beam(
model,
context_tokens_tensor,
context_length_tensor,
attention_mask,
position_ids,
return_scores=return_scores,
prompt_length=prompt_length,
bad_ids=bad_ids,
temperature=temperature,
topp=topp,
topk=topk,
beam_warmup=True,
)
else:
batch_token_iterator = sample_sequence_batch(
model,
context_tokens_tensor,
context_length_tensor,
attention_mask,
position_ids,
return_scores=return_scores,
prompt_length=prompt_length,
bad_ids=bad_ids,
temperature=temperature,
topp=topp,
topk=topk,
)
batch_token_iterator = sample_sequence_batch(
model,
context_tokens_tensor,
context_length_tensor,
attention_mask,
position_ids,
return_scores=return_scores,
prompt_length=prompt_length,
bad_ids=bad_ids,
temperature=temperature,
topp=topp,
topk=topk,
)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
if args.beam_search:
for beams in batch_token_iterator:
yield beams
else:
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def switch(val1, val2, boolean):
@ -957,284 +950,131 @@ def sample_sequence_batch(
if return_scores:
scores = torch.zeros([batch_size]).float().cuda()
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, context_length - 1, :]
if args.beam_search:
beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length],
num_beams=args.num_beams)
if args.beam_warmup:
beam = beams[0]
tokens_ = beam.tokens
tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1])
tokens_warmup = []
for i in range(batch_size):
tokens_warmup.append(tokens_.copy())
tokens, context_lengths = pad_batch(tokens_warmup, tokenizer.eod, args)
tokens = torch.cuda.LongTensor(tokens)
context_lengths = torch.cuda.LongTensor(context_lengths)
context_length = len(tokens_)
org_context_length = context_length
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
tokens, attention_mask, position_ids = get_batch(tokens, batch_size)
else:
types2use = None
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
yield beams
else:
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, context_length - 1, :]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
types2use = None
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if mpu.is_pipeline_last_stage():
if bad_ids is not None:
for bad_id in bad_ids:
logits[:, bad_id] = -10000
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
if return_scores:
orig_log_probs = torch.log_softmax(logits, dim=-1)
logits /= temperature
logits = top_k_logits(logits, top_k=topk, top_p=topp)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
if not args.greedy and return_scores:
indices = prev.view(-1, 1)
new_scores = orig_log_probs.gather(1, indices).view(-1)
new_scores = new_scores * started
new_scores = new_scores * is_done.bool().logical_not()
scores += new_scores
tokens[:, context_length] = new_tokens
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
if return_scores:
yield tokens, (lengths, scores)
else:
yield tokens, lengths
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if mpu.is_pipeline_last_stage():
if bad_ids is not None:
for bad_id in bad_ids:
logits[:, bad_id] = -10000
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
if return_scores:
orig_log_probs = torch.log_softmax(logits, dim=-1)
logits /= temperature
logits = top_k_logits(logits, top_k=topk, top_p=topp)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
if not args.greedy and return_scores:
indices = prev.view(-1, 1)
new_scores = orig_log_probs.gather(1, indices).view(-1)
new_scores = new_scores * started
new_scores = new_scores * is_done.bool().logical_not()
scores += new_scores
else:
if mpu.is_pipeline_first_stage():
tokens[:, context_length] = new_tokens
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None
else:
yield None, None
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
context_length += 1
counter += 1
if done:
break
def sample_sequence_batch_beam(
model,
context_tokens,
context_lengths,
attention_mask,
position_ids,
maxlen=None,
type_ids=None,
return_scores: bool = False,
prompt_length: int = None,
bad_ids: List = None,
temperature: float = None,
topp: float = None,
topk: int = None,
beam_warmup: bool = False,
):
args = get_args()
tokenizer = get_tokenizer()
temperature = temperature if temperature is not None else args.temperature
topp = topp if topp is not None else args.top_p
topk = topk if topk is not None else args.top_k
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
# added eos_id to support the function generate_samples_eval that passes
# eos_id as an argument and needs termination when that id is found.
if hasattr(args, "eos_id"):
eos_id = args.eos_id
else:
eos_id = tokenizer.eod
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
lengths = torch.ones([batch_size]).long().cuda() * maxlen
if return_scores:
scores = torch.zeros([batch_size]).float().cuda()
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
if beam_warmup:
beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length],
num_beams=args.num_beams)
beam = beams[0]
tokens_ = beam.tokens
tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1])
tokens_warmup = []
for i in range(batch_size):
tokens_warmup.append(tokens_.copy())
tokens, context_lengths = pad_batch(tokens_warmup, tokenizer.eod, args)
tokens = torch.cuda.LongTensor(tokens)
context_lengths = torch.cuda.LongTensor(context_lengths)
context_length = len(tokens_)
org_context_length = context_length
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
tokens, attention_mask, position_ids = get_batch(tokens, batch_size)
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, context_length - 1, :]
else:
types2use = None
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if mpu.is_pipeline_last_stage():
if bad_ids is not None:
for bad_id in bad_ids:
logits[:, bad_id] = -10000
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
if return_scores:
orig_log_probs = torch.log_softmax(logits, dim=-1)
logits /= temperature
logits = top_k_logits(logits, top_k=topk, top_p=topp)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
if not args.greedy and return_scores:
indices = prev.view(-1, 1)
new_scores = orig_log_probs.gather(1, indices).view(-1)
new_scores = new_scores * started
new_scores = new_scores * is_done.bool().logical_not()
scores += new_scores
tokens[:, context_length] = new_tokens
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
if return_scores:
yield tokens, (lengths, scores)
else:
yield tokens, lengths
yield tokens, (lengths, scores)
else:
yield tokens, lengths
else:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None
else:
yield None, None
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None
else:
yield None, None
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
context_length += 1
counter += 1
if done:
break
context_length += 1
counter += 1
if done:
break

@ -1,20 +1,8 @@
"""Get model parallel partitions."""
import os
import re
import random
import sys
import numpy as np
import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from codegeex.megatron import get_args
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.checkpointing import ensure_directory_exists
import argparse
def get_change_ckpt_args(parser):
@ -58,19 +46,10 @@ def get_element_from_dict_by_path(d, path):
def main():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))
initialize_megatron(
extra_args_provider=get_change_ckpt_args,
args_defaults={
"tokenizer_type": "GPT2BPETokenizer",
"no_load_rng" : True,
"no_load_optim" : True,
},
)
args = get_args()
parser = argparse.ArgumentParser()
parser = get_change_ckpt_args(parser)
args, _ = parser.parse_known_args()
print(f"Load ckpt from {args.load_ckpt_path}...")
state_dict = torch.load(args.load_ckpt_path, map_location="cpu")

@ -0,0 +1,133 @@
"""Merge model parallel partitions into a single checkpoint."""
import os
import torch
import random
from codegeex.megatron import get_args
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.checkpointing import ensure_directory_exists
def get_change_ckpt_args(parser):
"""Provide extra arguments required for merging."""
group = parser.add_argument_group(title='Mindspore to megatron')
group.add_argument(
'--load-ckpt-path',
type=str,
required=True,
help='dir to load model parallel partitions.',
)
group.add_argument(
'--save-ckpt-path',
type=str,
required=True,
help='path to save ".pt" checkpoint.',
)
group.add_argument(
'--source-tensor-model-parallel-size',
type=int,
default=2,
help='original tensor model parallel size',
)
return parser
def main():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))
initialize_megatron(
extra_args_provider=get_change_ckpt_args,
args_defaults={
"tokenizer_type": "GPT2BPETokenizer",
"no_load_rng" : True,
"no_load_optim" : True,
},
)
args = get_args()
model = CodeGeeXModel()
print(model.state_dict)
# Save the model.
sd = {}
sd['module'] = model.state_dict_for_save_checkpoint()
ensure_directory_exists(args.save_ckpt_path)
print(f"Load ckpt from {args.load_ckpt_path}...")
state_dict_list = []
for i in range(args.source_tensor_model_parallel_size):
try:
state_dict_list.append(torch.load(os.path.join(args.load_ckpt_path, f"mp_rank_{i:02d}_model_states.pt"), map_location="cpu"))
except Exception as e:
print(e)
exit(0)
print(f"Merging {len(state_dict_list)} partitions into a single ckpt...")
print("Merging Embedding layers...")
vocab_parallel_size = args.make_vocab_size_divisible_by // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['embedding']['word_embeddings']['weight'][i * vocab_parallel_size : (i + 1) * vocab_parallel_size, :] = state_dict_list[i]['module']['language_model']['embedding']['word_embeddings']['weight']
sd['module']['language_model']['embedding']['position_embeddings']['weight'] = state_dict_list[0]['module']['language_model']['embedding']['position_embeddings']['weight']
print("Merging QueryEmbedding layers...")
query_parallel_size = args.max_position_embeddings // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings'].pop('weight', None)
print("Merging Transformer layers...")
for layer_name in sd['module']['language_model']['transformer'].keys():
if "layernorm" in layer_name:
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
elif "attention" in layer_name and "weight" in layer_name:
if "dense" in layer_name:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
else:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
elif "weight" in layer_name and "dense" in layer_name:
if "h_to_4h" in layer_name:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
else:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
elif "bias" in layer_name:
if "mlp" in layer_name:
if "4h_to_h" in layer_name:
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
else:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
elif "attention" in layer_name:
if "dense" in layer_name:
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
else:
hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size
for i in range(args.source_tensor_model_parallel_size):
sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None)
else:
sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None)
if args.save_ckpt_path.endswith(".pt"):
save_ckpt_path = args.save_ckpt_path
else:
os.makedirs(args.save_ckpt_path, exist_ok=True)
save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
torch.save(sd, save_ckpt_path)
print(f"Converted checkpoint saved in {save_ckpt_path}.")
if __name__ == '__main__':
main()
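
A hedged invocation example for the new merge script (the script filename is a placeholder; the usual Megatron model-size arguments consumed by initialize_megatron are omitted):

# Example (illustrative only):
#   python merge_parallel_ckpt.py \
#       --load-ckpt-path /path/to/mp_rank_checkpoints \
#       --save-ckpt-path /path/to/merged_ckpt \
#       --source-tensor-model-parallel-size 2
# The merged file is written as <save-ckpt-path>/mp_rank_00_model_states.pt unless the
# given path already ends with ".pt".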

@ -19,17 +19,46 @@ import torch
import torch.nn.functional as F
from codegeex.megatron import get_args
from codegeex.megatron import mpu
from codegeex.megatron import mpu, print_rank_0
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.transformer import ParallelTransformer
from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal
from codegeex.megatron.mpu.initialize import get_tensor_model_parallel_world_size
def get_shrink_embedding_gradient_alpha(iteration):
args = get_args()
alpha = args.shrink_embedding_gradient_alpha
if args.shrink_embedding_gradient_steps is None:
return alpha
else:
x1 = int(args.shrink_embedding_gradient_steps[0])
x2 = int(args.shrink_embedding_gradient_steps[1])
if iteration <= x1:
return alpha
elif iteration >= x1 + x2:
return 1.0
else:
return alpha + (1 - alpha) * (iteration - x1) / x2
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
args = get_args()
if args.shrink_logit_embedding_gradient:
if hasattr(args, 'iteration'):
alpha = get_shrink_embedding_gradient_alpha(args.iteration + 1)
else:
alpha = args.shrink_embedding_gradient_alpha
word_embeddings_weight = word_embeddings_weight if alpha == 1.0 \
else (
word_embeddings_weight * alpha +
word_embeddings_weight.detach() * (1 - alpha)
)
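# Note (added for clarity, not part of the diff): the blend above leaves the forward value
# unchanged (alpha * W + (1 - alpha) * W == W), but only the alpha-scaled term is attached to
# the autograd graph, so the gradient reaching the shared embedding weight through the LM head
# is scaled by alpha.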
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight.half())
else:
@ -92,15 +121,19 @@ class Embedding(MegatronModule):
num_tokentypes=0,
):
super(Embedding, self).__init__()
args = get_args()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
self.max_sequence_length = max_sequence_length
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method)
self._word_embeddings_key = 'word_embeddings'
self.vocab_size = vocab_size
# Position embedding (serial).
@ -108,6 +141,7 @@ class Embedding(MegatronModule):
max_sequence_length, self.hidden_size)
self.position_embeddings = self.position_embeddings.half()
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
@ -190,7 +224,8 @@ class Embedding(MegatronModule):
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
state_dict_["weight"] = state_dict_["weight"][:self.vocab_size]
vocab_len = state_dict_['weight'].shape[0]
state_dict_["weight"] = state_dict_["weight"][:self.vocab_size // get_tensor_model_parallel_world_size()]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
@ -203,6 +238,17 @@ class Embedding(MegatronModule):
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
pos_len = state_dict_['weight'].shape[0]
max_seq_len = self.max_sequence_length
if pos_len < max_seq_len:
print_rank_0(f"Position embedding padded {pos_len} -> {max_seq_len}.")
position_embeddings_padded = torch.nn.Embedding(
max_seq_len - pos_len, self.hidden_size).half()
self.init_method(position_embeddings_padded.weight)
state_dict_['weight'] = torch.cat([state_dict_['weight'], position_embeddings_padded.weight], dim=0)
# self.position_embeddings = self.position_embeddings.half()
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
@ -284,12 +330,14 @@ class QueryEmbedding(MegatronModule):
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
self.max_sequence_length = max_sequence_length
# Top query position embedding (serial).
self.top_query_embeddings = mpu.VocabParallelEmbedding(
max_sequence_length, self.hidden_size, init_method=self.init_method)
self.top_query_embeddings = self.top_query_embeddings.half()
self._top_query_embeddings_key = 'top_query_embeddings'
# Initialize the top query position embeddings.
self.init_method(self.top_query_embeddings.weight)
@ -368,6 +416,14 @@ class QueryEmbedding(MegatronModule):
if 'top_query_embeddings' in key:
state_dict_[key.split('top_query_embeddings.')[1]] \
= state_dict[key]
pos_len = state_dict_['weight'].shape[0]
max_seq_len = self.max_sequence_length // get_tensor_model_parallel_world_size()
if pos_len < max_seq_len:
print_rank_0(f"Top query embedding padded {pos_len} -> {max_seq_len}.")
top_query_embeddings_padded = torch.nn.Embedding(
max_seq_len - pos_len, self.hidden_size).half()
self.init_method(top_query_embeddings_padded.weight)
state_dict_['weight'] = torch.cat([state_dict_['weight'], top_query_embeddings_padded.weight], dim=0)
self.top_query_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.

@ -61,14 +61,19 @@ class ParallelMLP(MegatronModule):
applied.
"""
def __init__(self, init_method, output_layer_init_method):
def __init__(
self,
init_method,
output_layer_init_method,
scale: int = 4,
):
super(ParallelMLP, self).__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4 * args.hidden_size,
scale * args.hidden_size,
gather_output=False,
init_method=init_method,
# skip_bias_add=True,
@ -78,7 +83,7 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
scale * args.hidden_size,
args.hidden_size,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
@ -112,10 +117,11 @@ class ParallelSelfAttention(MegatronModule):
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.hidden_size_per_partition = mpu.divide(
args.hidden_size // 2 if args.compress else args.hidden_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
args.hidden_size // 2 if args.compress else args.hidden_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
if hasattr(args, 'attention_upweight'):
@ -125,17 +131,17 @@ class ParallelSelfAttention(MegatronModule):
# Strided linear layer.
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
args.hidden_size // 2 if args.compress else args.hidden_size,
gather_output=False,
init_method=init_method)
self.key = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
args.hidden_size // 2 if args.compress else args.hidden_size,
gather_output=False,
init_method=init_method)
self.value = mpu.ColumnParallelLinear(
args.hidden_size,
args.hidden_size,
args.hidden_size // 2 if args.compress else args.hidden_size,
gather_output=False,
init_method=init_method)
@ -149,7 +155,7 @@ class ParallelSelfAttention(MegatronModule):
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
args.hidden_size // 2 if args.compress else args.hidden_size,
args.hidden_size,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
@ -264,7 +270,7 @@ class ParallelSelfAttention(MegatronModule):
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
attention_probs = self.softmax(attention_scores.half())
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@ -485,7 +491,7 @@ class ParallelTopQuerySelfAttention(MegatronModule):
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
attention_probs = self.softmax(attention_scores.half())
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@ -607,7 +613,8 @@ class ParallelTransformerLayer(MegatronModule):
self.ln_fp16 = False
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
output_layer_init_method,
scale=2 if args.compress else 4)
def forward(
self,

@ -0,0 +1,326 @@
import os
import torch
import logging
logging.getLogger("torch").setLevel(logging.WARNING)
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
from functools import partial
from codegeex.megatron import get_args, print_rank_0, get_timers, get_tokenizer, mpu
from codegeex.megatron.data.prompt_dataset import build_train_valid_test_datasets
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.training import pretrain
from codegeex.megatron.utils import get_ltor_masks_and_position_ids
from codegeex.megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building GPT model ...")
see_memory_usage(f"Before Building Model", force=True)
args = get_args()
with deepspeed.zero.Init(
data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device == "none" else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu,
):
if args.deepspeed and not args.no_pipeline_parallel:
model = CodeGeeXModelPipe(num_tokentypes=0, parallel_output=True)
# This is a hack to give us a reference to get_batch_pipe from within training.py
# We need to call model.set_batch_fn after deepspeed.initialize
model._megatron_batch_fn = get_batch_pipe
# Precompute the attention mask and store it in args. This avoids having to
# pipeline it as an activation during training. The mask is constant, and thus
# we can reuse it.
attention_mask = torch.tril(
torch.ones(
(1, args.seq_length, args.seq_length),
device=torch.cuda.current_device(),
)
).view(1, 1, args.seq_length, args.seq_length)
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
if args.fp16:
attention_mask = attention_mask.half()
elif args.bf16:
attention_mask = attention_mask.bfloat16()
# Attention mask must be bool.
args.attn_mask = attention_mask.to(torch.bool)
else:
model = CodeGeeXModel(
num_tokentypes=0,
parallel_output=True,
)
if args.load_state is not None:
timers = get_timers()
print_rank_0("Loading warmstarting model states ...")
timers("load-model-states").start()
mp_rank = mpu.get_tensor_model_parallel_rank()
if os.path.isdir(args.load_state):
model_path = os.path.join(
args.load_state, "mp_rank_{:02d}_model_states.pt".format(mp_rank)
)
else:
model_path = args.load_state
print_rank_0(f"Loading model from {model_path} ...")
state_dict = torch.load(model_path, map_location="cpu")
if "module" in state_dict:
state_dict = state_dict["module"] # strip other client states
model.load_state_dict(state_dict)
timers("load-model-states").stop()
timers.log(["load-model-states"])
see_memory_usage(f"After Building Model", force=True)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ["input_ids", "attention_mask", "labels"]
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b["input_ids"].contiguous()
# attn_mask_ = data_b["attention_mask"].contiguous()
labels_ = data_b["labels"].contiguous()
tokens = tokens_[:, :-1]
labels = labels_[:, 1:]
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
)
# mask loss to avoid predicting prompt and paddings
prompt_loss_mask = labels >= 0
loss_mask = prompt_loss_mask * loss_mask
return tokens, labels, loss_mask, attention_mask, position_ids
def get_batch_pipe(data):
"""Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ["input_ids"]
datatype = torch.int64
# Broadcast data.
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b["input_ids"].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
)
return (tokens, position_ids, attention_mask), (labels, loss_mask)
def loss_func(loss_mask, output_tensor):
args = get_args()
def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor):
if args.gold:
losses_ = losses.detach()
prob = torch.exp(-losses_) # Pθ(s)
torch.sqrt_(prob) # Pθ(s)ᵃ
torch.clamp_min_(prob, args.gold_beta) # max(Pθ(s)ᵃ,β)
losses = prob * losses
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min(loss_mask.sum(), 1e-8)
return loss
losses = output_tensor.float()
loss = compute_lm_loss(losses, loss_mask)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": averaged_loss[0]}
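# Example (added for illustration, not part of the diff): GOLD-style token weighting on toy numbers.
# The per-token weight is max(Pθ^0.5, β), computed from the detached loss:
#   losses = torch.tensor([0.5, 2.0, 6.0])      # per-token NLL = -log Pθ
#   prob = torch.exp(-losses)                   # Pθ ≈ [0.607, 0.135, 0.002]
#   weight = torch.clamp_min(prob.sqrt(), 0.1)  # ≈ [0.779, 0.368, 0.100] with β = 0.1
#   weighted = weight * losses                  # ≈ [0.389, 0.736, 0.600]
# Low-probability (hard) tokens are down-weighted, but never below β.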
def valid_loss_func(loss_mask, output_tensor):
args = get_args()
def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor):
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min(loss_mask.sum(), 1e-8)
return loss
losses = output_tensor.float()
loss = compute_lm_loss(losses, loss_mask)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers("batch-generator").stop()
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def valid_forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers("batch-generator").stop()
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(valid_loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0("> building train, validation, and test datasets " "for GPT ...")
if args.co_evaluation:
def dataset_partition_path_parsing(data_path):
dataset_path = {}
for index in range(len(data_path)):
dataset_path[data_path[index]] = data_path[index]
return dataset_path
assert args.valid_data_path is not None, "Valid data path must be given when --co-evaluation is turned on."
valid_data_path = dataset_partition_path_parsing(args.valid_data_path)
if args.test_data_path is not None:
test_data_path = dataset_partition_path_parsing(args.test_data_path)
else:
test_data_path = None
train_ds, _, _ = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string="1,0,0",
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
)
valid_ds = {}
for key, value in valid_data_path.items():
_, valid_ds_item, _ = build_train_valid_test_datasets(
data_prefix=[value],
data_impl=args.data_impl,
splits_string="0,1,0",
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
)
valid_ds[key] = valid_ds_item
if test_data_path is not None:
test_ds = {}
for key, value in test_data_path.items():
_, _, test_ds_item = build_train_valid_test_datasets(
data_prefix=[value],
data_impl=args.data_impl,
splits_string="0,0,1",
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
)
test_ds[key] = test_ds_item
else:
test_ds = None
elif args.valid_data_path is None:
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
)
else:
train_ds, _, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string="100,0,0",
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
)
_, valid_ds, _ = build_train_valid_test_datasets(
data_prefix=args.valid_data_path,
data_impl=args.data_impl,
splits_string="0,100,0",
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
)
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step,
valid_forward_step,
args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
)

@ -65,12 +65,6 @@ except ImportError:
from filelock import FileLock
import pathlib
try:
import bmcook
from bmcook import Config
except ImportError:
print("bmcook not imported.")
bmcook = None
def print_datetime(string):
@ -80,15 +74,11 @@ def print_datetime(string):
print_rank_0("[" + string + "] datetime: {} ".format(time_str))
def compress_setup(args, model, optimizer):
teacher = get_model(args)
cook_config = ConfigParser(args.cook_config)
CPMAntTrainer.set_compression(cook_config, model, optimizer, teacher=teacher, remove_ckptblock=False, target_linear=Linear)
def pretrain(
train_valid_test_dataset_provider,
model_provider,
forward_step_func,
valid_forward_step_func=None,
extra_args_provider=None,
args_defaults={},
):
@ -187,6 +177,7 @@ def pretrain(
if args.do_train and args.train_iters > 0:
iteration = train(
forward_step_func,
valid_forward_step_func,
model,
optimizer,
lr_scheduler,
@ -200,11 +191,11 @@ def pretrain(
if args.co_evaluation:
for key, value in valid_data_iterator.items():
evaluate_and_print_results(
prefix, forward_step_func, value, model, iteration, False, tag=key
prefix, valid_forward_step_func, value, model, iteration, False, tag=key
)
else:
evaluate_and_print_results(
prefix, forward_step_func, valid_data_iterator, model, iteration, False
prefix, valid_forward_step_func, valid_data_iterator, model, iteration, False
)
if args.save and iteration != 0:
@ -890,6 +881,7 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
def train(
forward_step_func,
valid_forward_step_func,
model,
optimizer,
lr_scheduler,
@ -987,11 +979,15 @@ def train(
if args.co_evaluation:
for key, value in valid_data_iterator.items():
evaluate_and_print_results(
prefix, forward_step_func, value, model, iteration, False, tag=key
prefix, valid_forward_step_func, value, model, iteration, False, tag=key
)
else:
if args.gold:
evaluate_and_print_results_gold(
prefix, forward_step_func, valid_data_iterator, model, iteration, False
)
evaluate_and_print_results(
prefix, forward_step_func, valid_data_iterator, model, iteration, False
prefix, valid_forward_step_func, valid_data_iterator, model, iteration, False
)
# Checkpointing
@ -1194,16 +1190,6 @@ def evaluate_and_print_results_gold(
total_loss_dict[key].item(),
iteration,
)
# writer.add_scalar(
# f"lm-loss-validation/{display_key} validation vs samples",
# total_loss_dict[key].item(),
# args.consumed_train_samples,
# )
# writer.add_scalar(
# f"lm-loss-validation/{display_key} validation vs tokens",
# total_loss_dict[key].item(),
# args.consumed_train_tokens,
# )
if args.log_validation_ppl_to_tensorboard:
writer.add_scalar(
f"lm-loss-validation/{display_key} validation ppl", ppl, iteration

@ -0,0 +1,50 @@
import os
import sys
import torch
import random
import argparse
import numpy as np
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--load-path",
type=str,
default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_fp32_52224.pt")
parser.add_argument("--save-path",
type=str,
default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt")
args, _ = parser.parse_known_args()
state_dict_path = args.load_path
print("Loading state dict ...")
sd = torch.load(state_dict_path, map_location="cpu")
for i in range(40):
if i < 39:
query_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.weight', None)
query_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.bias', None)
key_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.weight', None)
key_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.bias', None)
value_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.weight', None)
value_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.bias', None)
qkv_weight = torch.cat([query_weight, key_weight, value_weight], dim=0)
qkv_bias = torch.cat([query_bias, key_bias, value_bias])
sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.weight'] = qkv_weight
sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.bias'] = qkv_bias
else:
tq_key_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.weight', None)
tq_key_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.bias', None)
tq_value_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.weight', None)
tq_value_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.bias', None)
tq_kv_weight = torch.cat([tq_key_weight, tq_value_weight], dim=0)
tq_kv_bias = torch.cat([tq_key_bias, tq_value_bias])
sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.weight'] = tq_kv_weight
sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.bias'] = tq_kv_bias
save_ckpt_path = args.save_path
torch.save(sd, save_ckpt_path)
if __name__ == '__main__':
main()
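
A hedged invocation example for the new QKV-fusing utility (the script filename is a placeholder):

# Example (illustrative only): fold separate query/key/value tensors into the fused
# query_key_value (and topQueryLayer key_value) entries expected by the current model code.
#   python merge_qkv_ckpt.py \
#       --load-path /path/to/ckpt_ms_fp32.pt \
#       --save-path /path/to/ckpt_ms_qkv.pt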

@ -0,0 +1,198 @@
import json
import torch
import argparse
import gradio as gr
import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
from codegeex.data.data_utils import LANGUAGE_TAG
from codegeex.megatron.inference import set_random_seed
def model_provider(args):
"""Build the model."""
model = CodeGeeXModel(
args.hidden_size,
args.num_layers,
args.num_attention_heads,
args.padded_vocab_size,
args.max_position_embeddings
)
return model
def add_code_generation_args(parser):
group = parser.add_argument_group(title="code generation")
group.add_argument(
"--num-layers",
type=int,
default=39,
)
group.add_argument(
"--hidden-size",
type=int,
default=5120,
)
group.add_argument(
"--num-attention-heads",
type=int,
default=40,
)
group.add_argument(
"--padded-vocab-size",
type=int,
default=52224,
)
group.add_argument(
"--max-position-embeddings",
type=int,
default=2048,
)
group.add_argument(
"--tokenizer-path",
type=str,
default="./tokenizer",
)
group.add_argument(
"--example-path",
type=str,
default="./",
)
group.add_argument(
"--load",
type=str,
)
group.add_argument(
"--state-dict-path",
type=str,
)
group.add_argument(
"--micro-batch-size",
type=int,
default=1,
)
group.add_argument(
"--quantize",
action="store_true",
)
return parser
def main():
parser = argparse.ArgumentParser()
parser = add_code_generation_args(parser)
args, _ = parser.parse_known_args()
print("Loading tokenizer ...")
tokenizer = CodeGeeXTokenizer(
tokenizer_path=args.tokenizer_path,
mode="codegeex-13b")
print("Loading state dict ...")
state_dict = torch.load(args.load, map_location="cpu")
state_dict = state_dict["module"]
print("Building CodeGeeX model ...")
model = model_provider(args)
model.load_state_dict(state_dict)
model.eval()
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="torch")
model.cuda()
def predict(
prompt,
lang,
seed,
out_seq_length,
temperature,
top_k,
top_p,
):
set_random_seed(seed)
if lang.lower() in LANGUAGE_TAG:
prompt = LANGUAGE_TAG[lang.lower()] + "\n" + prompt
generated_code = codegeex.generate(
model,
tokenizer,
prompt,
out_seq_length=out_seq_length,
seq_length=args.max_position_embeddings,
top_k=top_k,
top_p=top_p,
temperature=temperature,
micro_batch_size=args.micro_batch_size,
backend="megatron",
verbose=True,
)
return prompt + generated_code
examples = []
with open(args.example_path, "r") as f:
for line in f:
examples.append(list(json.loads(line).values()))
with gr.Blocks() as demo:
gr.Markdown(
"""
<img src="https://raw.githubusercontent.com/THUDM/CodeGeeX/main/resources/logo/codegeex_logo.png">
""")
gr.Markdown(
"""
<p align="center">
🏠 <a href="https://codegeex.cn" target="_blank">Homepage</a> | 📖 <a href="http://keg.cs.tsinghua.edu.cn/codegeex/" target="_blank">Blog</a> | 🪧 <a href="https://codegeex.cn/playground" target="_blank">DEMO</a> | 🛠 <a href="https://marketplace.visualstudio.com/items?itemName=aminer.codegeex" target="_blank">VS Code</a> or <a href="https://plugins.jetbrains.com/plugin/20587-codegeex" target="_blank">Jetbrains</a> Extensions | 💻 <a href="https://github.com/THUDM/CodeGeeX" target="_blank">Source code</a> | 🤖 <a href="https://models.aminer.cn/codegeex/download/request" target="_blank">Download Model</a>
</p>
""")
gr.Markdown(
"""
We introduce CodeGeeX, a large-scale multilingual code generation model with 13 billion parameters, pre-trained on a large code corpus of more than 20 programming languages. CodeGeeX supports 15+ programming languages for both code generation and translation. CodeGeeX is open source, please refer to our [GitHub](https://github.com/THUDM/CodeGeeX) for more details. This is a minimal-functional DEMO, for other DEMOs like code translation, please visit our [Homepage](https://codegeex.cn). We also offer free [VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex) or [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex) extensions for full functionality.
""")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(lines=13, placeholder='Please enter the description or select an example input below.',label='Input')
with gr.Row():
gen = gr.Button("Generate")
clr = gr.Button("Clear")
outputs = gr.Textbox(lines=15, label='Output')
gr.Markdown(
"""
Generation Parameters
""")
with gr.Row():
with gr.Column():
lang = gr.Radio(
choices=["C++", "C", "C#", "Python", "Java", "HTML", "PHP", "JavaScript", "TypeScript", "Go",
"Rust",
"SQL", "Kotlin", "R", "Fortran"], value='lang', label='Programming Language',
default="Python")
with gr.Column():
seed = gr.Slider(maximum=10000, value=8888, step=1, label='Seed')
with gr.Row():
out_seq_length = gr.Slider(maximum=1024, value=128, minimum=1, step=1, label='Output Sequence Length')
temperature = gr.Slider(maximum=1, value=0.2, minimum=0, label='Temperature')
with gr.Row():
top_k = gr.Slider(maximum=40, value=0, minimum=0, step=1, label='Top K')
top_p = gr.Slider(maximum=1, value=1.0, minimum=0, label='Top P')
inputs = [prompt, lang, seed, out_seq_length, temperature, top_k, top_p]
gen.click(fn=predict, inputs=inputs, outputs=outputs)
clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=prompt)
gr_examples = gr.Examples(examples=examples, inputs=[prompt, lang],
label="Example Inputs (Click to insert an examplet it into the input box)",
examples_per_page=20)
demo.launch(server_port=6007)
if __name__ == '__main__':
with torch.no_grad():
main()
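
A hedged launch example for the new Gradio demo (the script filename is a placeholder; --example-path expects a JSON-lines file, one example per line):

# Example (illustrative only):
#   python codegeex_web_demo.py \
#       --load /path/to/codegeex_13b.pt \
#       --tokenizer-path ./tokenizer \
#       --example-path ./examples.jsonl
# Add --quantize to load the model with 8-bit weights; the demo serves on port 6007.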

@ -1,13 +1,9 @@
import os
import copy
import time
import torch
import random
import argparse
import numpy as np
from codegeex.torch.inference import get_token_stream
import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
@ -111,6 +107,10 @@ def add_code_generation_args(parser):
"--quantize",
action="store_true",
)
group.add_argument(
"--interative",
action="store_true",
)
return parser
@ -143,62 +143,48 @@ def main():
prompt = f.readlines()
prompt = "".join(prompt)
times = {}
out_seq_lengths = [args.out_seq_length]
micro_batch_size = args.micro_batch_size
seq_length = args.max_position_embeddings
for out_seq_length in out_seq_lengths:
print(f"Generating with out_seq_len {out_seq_length}...")
times[out_seq_length] = []
for prompt in [prompt]:
t0 = time.perf_counter()
tokens = tokenizer.encode_code(prompt)
print(tokens)
print("Current prompt:")
print(prompt)
n_token_prompt = len(tokens)
print("N_token_prompt:", n_token_prompt)
token_stream = get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
topk=args.top_k,
topp=args.top_p,
temperature=args.temperature,
greedy=args.greedy,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(
generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
generated_code = "".join(generated_code)
t1 = time.perf_counter()
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
times[out_seq_length].append(t1 - t0)
print("================================= Generated code:")
print(generated_code)
if all(is_finished):
break
print(times)
for out_seq_length in times.keys():
print(out_seq_length, np.mean(times[out_seq_length]))
while True:
print("\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> ")
prompts = []
while True:
try:
line = input()
except EOFError:
break
prompts.append(line)
prompt = "\n".join(prompts)
prompt = prompt.strip()
if not prompt:
print('Query should not be empty!')
continue
if prompt == "stop":
return
try:
t0 = time.perf_counter()
generated_code = codegeex.generate(
model,
tokenizer,
prompt,
out_seq_length=out_seq_length,
seq_length=args.max_position_embeddings,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
micro_batch_size=args.micro_batch_size,
backend="megatron",
verbose=True,
)
t1 = time.perf_counter()
print("Total generation time:", t1 - t0)
except (ValueError, FileNotFoundError) as e:
print(e)
continue
print("Generation finished.")
if __name__ == "__main__":
main()

@ -166,44 +166,57 @@ def main():
with open(args.prompt_file, "r") as f:
prompt = f.readlines()
prompt = "".join(prompt)
print_rank_0("Generating ...")
t0 = time.perf_counter()
for prompt in [prompt]:
tokens = tokenizer.tokenize(prompt)
print_rank_0(tokens)
print_rank_0("Current prompt:")
print_rank_0(prompt)
n_token_prompt = len(tokens)
print_rank_0(f"N_token_prompt: {n_token_prompt}")
token_stream = get_token_stream(
model,
[copy.deepcopy(tokens) for _ in range(args.micro_batch_size)],
micro_batch_size=args.micro_batch_size,
bad_ids=args.bad_ids,
)
is_finished = [False for _ in range(args.micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(args.micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
generated_tokens[j]) >= args.out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
t1 = time.perf_counter()
print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}")
print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
print_rank_0("================================= Generated code:")
print_rank_0(generated_code)
t0 = time.perf_counter()
if all(is_finished):
break
times = {}
out_seq_lengths = [args.out_seq_length]
micro_batch_size = args.micro_batch_size
for out_seq_length in out_seq_lengths:
print_rank_0(f"Generating with out_seq_len {out_seq_length}...")
times[out_seq_length] = []
for prompt in [prompt] * args.n_generation:
t0 = time.perf_counter()
tokens = tokenizer.tokenize(prompt)
print_rank_0(tokens)
print_rank_0("Current prompt:")
print_rank_0(prompt)
n_token_prompt = len(tokens)
print_rank_0(f"N_token_prompt:{n_token_prompt}")
token_stream = get_token_stream(
model,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
topk=args.top_k,
topp=args.top_p,
temperature=args.temperature,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
t1 = time.perf_counter()
print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}")
print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
times[out_seq_length].append(t1 - t0)
print_rank_0("================================= Generated code:")
print_rank_0(generated_code)
t0 = time.perf_counter()
if all(is_finished):
break
print_rank_0(times)
for out_seq_length in times.keys():
print_rank_0(f"{out_seq_length}, {np.mean(times[out_seq_length])}")
print_rank_0("Generation finished.")
if __name__ == "__main__":
main()
