diff --git a/codegeex/__init__.py b/codegeex/__init__.py index e69de29..954e66b 100644 --- a/codegeex/__init__.py +++ b/codegeex/__init__.py @@ -0,0 +1,70 @@ +import copy + +from typing import * +from codegeex.megatron.model import CodeGeeXModel +from codegeex.tokenizer import CodeGeeXTokenizer +from codegeex.torch.inference import get_token_stream + + +def get_model( + backend: str = "megatron", + quantized: bool = False, +): + pass + + +def generate( + model: CodeGeeXModel, + tokenizer: CodeGeeXTokenizer, + prompt: str, + out_seq_length: int, + seq_length: int = 2048, + top_k: int = 0, + top_p: float = 1.0, + temperature: float = 1.0, + micro_batch_size: int = 1, + backend: str = "megatron", + greedy: bool = False, + verbose: bool = False, +): + tokens = tokenizer.encode_code(prompt) + n_token_prompt = len(tokens) + + if verbose: + print(f"Current prompt:\n{prompt}") + print("N_token_prompt:", n_token_prompt) + + generated_codes = [] + if backend == "megatron": + token_stream = get_token_stream( + model, + tokenizer, + seq_length, + out_seq_length, + [copy.deepcopy(tokens) for _ in range(micro_batch_size)], + micro_batch_size=micro_batch_size, + topk=top_k, + topp=top_p, + temperature=temperature, + greedy=greedy, + ) + is_finished = [False for _ in range(micro_batch_size)] + for i, generated in enumerate(token_stream): + generated_tokens = generated[0] + for j in range(micro_batch_size): + if is_finished[j]: + continue + + if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length: + is_finished[j] = True + generated_tokens_ = generated_tokens[j].cpu().numpy().tolist() + generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:]) + generated_code = "".join(generated_code) + generated_codes.append(generated_code) + if verbose: + print(f"\nGenerated code {i}:\n{generated_code}") + + if all(is_finished): + break + + return generated_codes \ No newline at end of file diff --git a/codegeex/benchmark/humaneval-x/generate_humaneval_x.py b/codegeex/benchmark/humaneval-x/generate_humaneval_x.py index 82923f9..75eee09 100644 --- a/codegeex/benchmark/humaneval-x/generate_humaneval_x.py +++ b/codegeex/benchmark/humaneval-x/generate_humaneval_x.py @@ -189,7 +189,7 @@ def add_code_generation_args(parser): nargs="*", type=int, default=None, - help='Identify the type of programming language to generate', + help='Specify bad ids that will not be used', ) group.add_argument( "--quantize", diff --git a/codegeex/data/data_utils.py b/codegeex/data/data_utils.py index 90e7e4c..9d5e631 100644 --- a/codegeex/data/data_utils.py +++ b/codegeex/data/data_utils.py @@ -1,23 +1,32 @@ +import os import gzip import json from typing import * - LANGUAGE_TAG = { + "c" : "// language: C", "c++" : "// language: C++", "cpp" : "// language: C++", - "c" : "// language: C", "c#" : "// language: C#", "csharp" : "// language: C#", + "css" : "/* language: CSS */", "cuda" : "// language: Cuda", + "dart" : "// language: Dart", + "lua" : "// language: Lua", + "objectivec" : "// language: Objective-C", "objective-c" : "// language: Objective-C", "objective-c++": "// language: Objective-C++", "python" : "# language: Python", "perl" : "# language: Perl", + "prolog" : f"% language: Prolog", + "swift" : "// language: swift", + "lisp" : "; language: Lisp", "java" : "// language: Java", "scala" : "// language: Scala", "tex" : f"% language: TeX", + "vue" : "", + "markdown" : "", "html" : "", "php" : "// language: PHP", "js" : "// language: JavaScript", @@ -26,13 +35,32 @@ 
LANGUAGE_TAG = { "go" : "// language: Go", "shell" : "# language: Shell", "rust" : "// language: Rust", - "css" : "/* language: CSS */", "sql" : "-- language: SQL", "kotlin" : "// language: Kotlin", + "vb" : "' language: Visual Basic", + "ruby" : "# language: Ruby", "pascal" : "// language: Pascal", "r" : "# language: R", "fortran" : "!language: Fortran", "lean" : "-- language: Lean", + "matlab" : f"% language: Matlab", + "delphi" : "{language: Delphi}", + "scheme" : "; language: Scheme", + "basic" : "' language: Basic", + "assembly" : "; language: Assembly", + "groovy" : "// language: Groovy", + "abap" : "* language: Abap", + "gdscript" : "# language: GDScript", + "haskell" : "-- language: Haskell", + "julia" : "# language: Julia", + "elixir" : "# language: Elixir", + "excel" : "' language: Excel", + "clojure" : "; language: Clojure", + "actionscript" : "// language: ActionScript", + "solidity" : "// language: Solidity", + "powershell" : "# language: PowerShell", + "erlang" : f"% language: Erlang", + "cobol" : "// language: Cobol", } diff --git a/codegeex/data/process_pretrain_dataset.py b/codegeex/data/process_pretrain_dataset.py index 4942266..8d9f03b 100644 --- a/codegeex/data/process_pretrain_dataset.py +++ b/codegeex/data/process_pretrain_dataset.py @@ -58,7 +58,7 @@ def process_sample( try: if language is not None and language in LANGUAGE_TAG.keys(): - code = LANGUAGE_TAG[language] + sample["code"] + code = LANGUAGE_TAG[language] + "\n" + sample["code"] else: code = sample["code"] except Exception as e: diff --git a/codegeex/data/processor.py b/codegeex/data/processor.py index 4ea507f..25775da 100644 --- a/codegeex/data/processor.py +++ b/codegeex/data/processor.py @@ -67,6 +67,9 @@ class PromptDatasetProcessor(object): """ Instead of processing lazily, we turn the iterable into a list. """ + if sample is None: + return None + return list(self.process_sample(sample)) def process_sample_(self, sample) -> List[Dict[str, List[int]]]: @@ -141,6 +144,9 @@ class LabelDatasetProcessor(object): """ Instead of processing lazily, we turn the iterable into a list. 
""" + if sample is None: + return None + return list(self.process_sample(sample)) def process_sample_(self, sample) -> List[Dict[str, List[int]]]: diff --git a/codegeex/megatron/arguments.py b/codegeex/megatron/arguments.py index e63b7be..315f4b9 100644 --- a/codegeex/megatron/arguments.py +++ b/codegeex/megatron/arguments.py @@ -415,6 +415,10 @@ def _add_network_size_args(parser): help="Disable BERT binary head.", dest="bert_binary_head", ) + group.add_argument( + "--compress", + action="store_true", + ) return parser @@ -560,6 +564,24 @@ def _add_regularization_args(parser): group.add_argument( "--sgd-momentum", type=float, default=0.9, help="Momentum factor for sgd" ) + group.add_argument( + "--shrink-logit-embedding-gradient", + action="store_true", + ) + group.add_argument( + "--shrink-embedding-gradient-alpha", + type=float, + default=1.0, + help='Shrink embedding gradient for alpha', + ) + group.add_argument( + "--shrink-embedding-gradient-steps", + nargs='*', + default=None, + help='--shrink-embedding-gradient-steps ' + 'Shrink embedding gradient alpha for x1 steps,' + 'then warm it up to 1.0 with x2 steps', + ) return parser @@ -751,6 +773,10 @@ def _add_initialization_args(parser): def _add_inference_args(parser): group = parser.add_argument_group(title="initialization") + group.add_argument( + '--evaluation', + action="store_true", + ) group.add_argument( '--beam-warmup', action="store_true", diff --git a/codegeex/megatron/checkpointing.py b/codegeex/megatron/checkpointing.py index 738107c..53f71e2 100644 --- a/codegeex/megatron/checkpointing.py +++ b/codegeex/megatron/checkpointing.py @@ -84,7 +84,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False): if release: directory = "" else: - directory = "iter_{:07d}".format(iteration) + directory = f"global_step{iteration}" # Use both the tensor and pipeline MP rank. if mpu.get_pipeline_model_parallel_world_size() == 1: return os.path.join( @@ -174,7 +174,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): # Saving is a collective communication checkpoint_name = get_checkpoint_name(args.save, iteration) # Trim off the filename and mp_rank_* directory. 
- for _ in range(3): + for _ in range(2): checkpoint_name = os.path.dirname(checkpoint_name) model[0].save_checkpoint(checkpoint_name, client_state=state_dict) diff --git a/codegeex/megatron/code_generation_utils.py b/codegeex/megatron/code_generation_utils.py index 31df158..8ec082d 100644 --- a/codegeex/megatron/code_generation_utils.py +++ b/codegeex/megatron/code_generation_utils.py @@ -25,10 +25,11 @@ import torch import torch.nn.functional as F from dataclasses import dataclass -from codegeex.megatron import get_args +from codegeex.megatron import get_args, print_rank_0 from codegeex.megatron import get_tokenizer from codegeex.megatron import mpu from codegeex.megatron.utils import get_ltor_masks_and_position_ids +from codegeex.benchmark.utils import is_code_generation_finished def get_batch(context_tokens, micro_batch_size=None): @@ -682,12 +683,17 @@ def beam_search(model, context_tokens, num_beams: int): expanded_beams = expand_beams(beams, num_beams, model) next_beams = [] for beam in expanded_beams: - if args.beam_warmup_length > 0: + if args.beam_warmup: if len(beam.tokens) >= org_context_len + args.beam_warmup_length or beam.tokens[-1] == tokenizer.eod: finished_beams.append(beam) else: next_beams.append(beam) else: + if args.evaluation: + generated_code = tokenizer.detokenize(beam.tokens[org_context_len:]) + if is_code_generation_finished(generated_code): + finished_beams.append(beam) + continue if beam.tokens[-1] == tokenizer.eod: finished_beams.append(beam) else: @@ -842,7 +848,6 @@ def get_token_stream( temperature: float = None, topp: float = None, topk: int = None, - beam_warmup: bool = False, ): args = get_args() tokenizer = get_tokenizer() @@ -866,42 +871,30 @@ def get_token_stream( context_length = context_length_tensor.min().item() tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, micro_batch_size) - if beam_warmup: - batch_token_iterator = sample_sequence_batch_beam( - model, - context_tokens_tensor, - context_length_tensor, - attention_mask, - position_ids, - return_scores=return_scores, - prompt_length=prompt_length, - bad_ids=bad_ids, - temperature=temperature, - topp=topp, - topk=topk, - beam_warmup=True, - ) - else: - batch_token_iterator = sample_sequence_batch( - model, - context_tokens_tensor, - context_length_tensor, - attention_mask, - position_ids, - return_scores=return_scores, - prompt_length=prompt_length, - bad_ids=bad_ids, - temperature=temperature, - topp=topp, - topk=topk, - ) + batch_token_iterator = sample_sequence_batch( + model, + context_tokens_tensor, + context_length_tensor, + attention_mask, + position_ids, + return_scores=return_scores, + prompt_length=prompt_length, + bad_ids=bad_ids, + temperature=temperature, + topp=topp, + topk=topk, + ) - for tokens, lengths in batch_token_iterator: - context_length += 1 - if tokens is not None: - yield tokens[:, :context_length], lengths - else: - yield None, None + if args.beam_search: + for beams in batch_token_iterator: + yield beams + else: + for tokens, lengths in batch_token_iterator: + context_length += 1 + if tokens is not None: + yield tokens[:, :context_length], lengths + else: + yield None, None def switch(val1, val2, boolean): @@ -957,284 +950,131 @@ def sample_sequence_batch( if return_scores: scores = torch.zeros([batch_size]).float().cuda() - while context_length <= (maxlen): - - if args.recompute: - logits = model(tokens, - position_ids, - attention_mask, - tokentype_ids=type_ids, - forward_method_parallel_output=False, - prompt_length=prompt_length, - 
context_length=context_length, - ) - logits = logits[:, context_length - 1, :] + if args.beam_search: + beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length], + num_beams=args.num_beams) + if args.beam_warmup: + beam = beams[0] + tokens_ = beam.tokens + tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1]) + tokens_warmup = [] + for i in range(batch_size): + tokens_warmup.append(tokens_.copy()) + tokens, context_lengths = pad_batch(tokens_warmup, tokenizer.eod, args) + tokens = torch.cuda.LongTensor(tokens) + context_lengths = torch.cuda.LongTensor(context_lengths) + context_length = len(tokens_) + org_context_length = context_length + if maxlen is None: + maxlen = args.seq_length - 1 + if maxlen > (org_context_length + args.out_seq_length): + maxlen = org_context_length + args.out_seq_length + lengths = torch.ones([batch_size]).long().cuda() * maxlen + tokens, attention_mask, position_ids = get_batch(tokens, batch_size) else: - types2use = None - if counter == 0: - tokens2use = tokens[:, :context_length] - positions2use = position_ids[:, :context_length] - if type_ids is not None: - types2use = type_ids[:, :context_length] + yield beams + else: + while context_length <= (maxlen): + if args.recompute: + logits = model(tokens, + position_ids, + attention_mask, + tokentype_ids=type_ids, + forward_method_parallel_output=False, + prompt_length=prompt_length, + context_length=context_length, + ) + logits = logits[:, context_length - 1, :] else: - tokens2use = tokens[:, context_length - 1].view( - batch_size, -1) - positions2use = position_ids[:, context_length - 1].view( - batch_size, -1) - if type_ids is not None: - types2use = type_ids[:, context_length - 1].view( + types2use = None + if counter == 0: + tokens2use = tokens[:, :context_length] + positions2use = position_ids[:, :context_length] + if type_ids is not None: + types2use = type_ids[:, :context_length] + else: + tokens2use = tokens[:, context_length - 1].view( batch_size, -1) - logits, layer_past = model(tokens2use, - positions2use, - attention_mask, - layer_past=layer_past, - get_key_value=True, - tokentype_ids=types2use, - forward_method_parallel_output=False, - prompt_length=prompt_length, - context_length=context_length, - ) - logits = logits[:, -1].view(batch_size, -1).contiguous() - - if mpu.is_pipeline_last_stage(): - if bad_ids is not None: - for bad_id in bad_ids: - logits[:, bad_id] = -10000 - if args.greedy: - prev = torch.argmax(logits, dim=-1).view(-1) - else: - logits = logits.float() - if return_scores: - orig_log_probs = torch.log_softmax(logits, dim=-1) - logits /= temperature - logits = top_k_logits(logits, top_k=topk, top_p=topp) - log_probs = F.softmax(logits, dim=-1) - prev = torch.multinomial(log_probs, num_samples=1).view(-1) - - started = context_lengths <= context_length - - new_tokens = switch(tokens[:, context_length].view(-1), prev, started) - - if not args.greedy and return_scores: - indices = prev.view(-1, 1) - new_scores = orig_log_probs.gather(1, indices).view(-1) - new_scores = new_scores * started - new_scores = new_scores * is_done.bool().logical_not() - scores += new_scores - - tokens[:, context_length] = new_tokens - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() - torch.distributed.broadcast(new_tokens, src, group) - - done_token = (prev == eos_id).byte() & started.byte() - just_finished = (done_token & ~is_done).bool() - lengths[just_finished.view(-1)] = context_length - is_done = is_done | done_token - 
- done = torch.all(is_done) - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_pipeline_model_parallel_group() - torch.distributed.broadcast(done, src, group) - - if return_scores: - yield tokens, (lengths, scores) - else: - yield tokens, lengths + positions2use = position_ids[:, context_length - 1].view( + batch_size, -1) + if type_ids is not None: + types2use = type_ids[:, context_length - 1].view( + batch_size, -1) + logits, layer_past = model(tokens2use, + positions2use, + attention_mask, + layer_past=layer_past, + get_key_value=True, + tokentype_ids=types2use, + forward_method_parallel_output=False, + prompt_length=prompt_length, + context_length=context_length, + ) + logits = logits[:, -1].view(batch_size, -1).contiguous() + + if mpu.is_pipeline_last_stage(): + if bad_ids is not None: + for bad_id in bad_ids: + logits[:, bad_id] = -10000 + if args.greedy: + prev = torch.argmax(logits, dim=-1).view(-1) + else: + logits = logits.float() + if return_scores: + orig_log_probs = torch.log_softmax(logits, dim=-1) + logits /= temperature + logits = top_k_logits(logits, top_k=topk, top_p=topp) + log_probs = F.softmax(logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1).view(-1) + + started = context_lengths <= context_length + + new_tokens = switch(tokens[:, context_length].view(-1), prev, started) + + if not args.greedy and return_scores: + indices = prev.view(-1, 1) + new_scores = orig_log_probs.gather(1, indices).view(-1) + new_scores = new_scores * started + new_scores = new_scores * is_done.bool().logical_not() + scores += new_scores - else: - if mpu.is_pipeline_first_stage(): + tokens[:, context_length] = new_tokens src = mpu.get_pipeline_model_parallel_last_rank() group = mpu.get_embedding_group() - new_tokens = torch.empty_like(tokens[:, context_length]) torch.distributed.broadcast(new_tokens, src, group) - tokens[:, context_length] = new_tokens - yield tokens, None - else: - yield None, None - - done = torch.cuda.ByteTensor([0]) - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_pipeline_model_parallel_group() - torch.distributed.broadcast(done, src, group) - - context_length += 1 - counter += 1 - if done: - break - - -def sample_sequence_batch_beam( - model, - context_tokens, - context_lengths, - attention_mask, - position_ids, - maxlen=None, - type_ids=None, - return_scores: bool = False, - prompt_length: int = None, - bad_ids: List = None, - temperature: float = None, - topp: float = None, - topk: int = None, - beam_warmup: bool = False, -): - args = get_args() - tokenizer = get_tokenizer() - temperature = temperature if temperature is not None else args.temperature - topp = topp if topp is not None else args.top_p - topk = topk if topk is not None else args.top_k - - model.eval() - with torch.no_grad(): - context_length = context_lengths.min().item() - - # added eos_id to support the function generate_samples_eval that passes - # eos_id as an argument and needs termination when that id id found. 
- if hasattr(args, "eos_id"): - eos_id = args.eos_id - else: - eos_id = tokenizer.eod - - counter = 0 - org_context_length = context_length - layer_past = None - batch_size = context_tokens.size(0) - is_done = torch.zeros([batch_size]).byte().cuda() - tokens = context_tokens - if maxlen is None: - maxlen = args.seq_length - 1 - if maxlen > (org_context_length + args.out_seq_length): - maxlen = org_context_length + args.out_seq_length + done_token = (prev == eos_id).byte() & started.byte() + just_finished = (done_token & ~is_done).bool() + lengths[just_finished.view(-1)] = context_length + is_done = is_done | done_token - lengths = torch.ones([batch_size]).long().cuda() * maxlen - if return_scores: - scores = torch.zeros([batch_size]).float().cuda() + done = torch.all(is_done) + src = mpu.get_pipeline_model_parallel_last_rank() + group = mpu.get_pipeline_model_parallel_group() + torch.distributed.broadcast(done, src, group) - if beam_warmup: - beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length], - num_beams=args.num_beams) - beam = beams[0] - tokens_ = beam.tokens - tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1]) - tokens_warmup = [] - for i in range(batch_size): - tokens_warmup.append(tokens_.copy()) - tokens, context_lengths = pad_batch(tokens_warmup, tokenizer.eod, args) - tokens = torch.cuda.LongTensor(tokens) - context_lengths = torch.cuda.LongTensor(context_lengths) - context_length = len(tokens_) - org_context_length = context_length - if maxlen is None: - maxlen = args.seq_length - 1 - if maxlen > (org_context_length + args.out_seq_length): - maxlen = org_context_length + args.out_seq_length - lengths = torch.ones([batch_size]).long().cuda() * maxlen - tokens, attention_mask, position_ids = get_batch(tokens, batch_size) - - while context_length <= (maxlen): - if args.recompute: - logits = model(tokens, - position_ids, - attention_mask, - tokentype_ids=type_ids, - forward_method_parallel_output=False, - prompt_length=prompt_length, - context_length=context_length, - ) - logits = logits[:, context_length - 1, :] - else: - types2use = None - if counter == 0: - tokens2use = tokens[:, :context_length] - positions2use = position_ids[:, :context_length] - if type_ids is not None: - types2use = type_ids[:, :context_length] - else: - tokens2use = tokens[:, context_length - 1].view( - batch_size, -1) - positions2use = position_ids[:, context_length - 1].view( - batch_size, -1) - if type_ids is not None: - types2use = type_ids[:, context_length - 1].view( - batch_size, -1) - logits, layer_past = model(tokens2use, - positions2use, - attention_mask, - layer_past=layer_past, - get_key_value=True, - tokentype_ids=types2use, - forward_method_parallel_output=False, - prompt_length=prompt_length, - context_length=context_length, - ) - logits = logits[:, -1].view(batch_size, -1).contiguous() - - if mpu.is_pipeline_last_stage(): - if bad_ids is not None: - for bad_id in bad_ids: - logits[:, bad_id] = -10000 - if args.greedy: - prev = torch.argmax(logits, dim=-1).view(-1) - else: - logits = logits.float() if return_scores: - orig_log_probs = torch.log_softmax(logits, dim=-1) - logits /= temperature - logits = top_k_logits(logits, top_k=topk, top_p=topp) - log_probs = F.softmax(logits, dim=-1) - prev = torch.multinomial(log_probs, num_samples=1).view(-1) - - started = context_lengths <= context_length - - new_tokens = switch(tokens[:, context_length].view(-1), prev, started) - - if not args.greedy and return_scores: - indices = prev.view(-1, 
1) - new_scores = orig_log_probs.gather(1, indices).view(-1) - new_scores = new_scores * started - new_scores = new_scores * is_done.bool().logical_not() - scores += new_scores - - tokens[:, context_length] = new_tokens - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() - torch.distributed.broadcast(new_tokens, src, group) - - done_token = (prev == eos_id).byte() & started.byte() - just_finished = (done_token & ~is_done).bool() - lengths[just_finished.view(-1)] = context_length - is_done = is_done | done_token - - done = torch.all(is_done) - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_pipeline_model_parallel_group() - torch.distributed.broadcast(done, src, group) - - if return_scores: - yield tokens, (lengths, scores) - else: - yield tokens, lengths + yield tokens, (lengths, scores) + else: + yield tokens, lengths - else: - if mpu.is_pipeline_first_stage(): - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() - new_tokens = torch.empty_like(tokens[:, context_length]) - torch.distributed.broadcast(new_tokens, src, group) - tokens[:, context_length] = new_tokens - yield tokens, None else: - yield None, None - - done = torch.cuda.ByteTensor([0]) - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_pipeline_model_parallel_group() - torch.distributed.broadcast(done, src, group) + if mpu.is_pipeline_first_stage(): + src = mpu.get_pipeline_model_parallel_last_rank() + group = mpu.get_embedding_group() + new_tokens = torch.empty_like(tokens[:, context_length]) + torch.distributed.broadcast(new_tokens, src, group) + tokens[:, context_length] = new_tokens + yield tokens, None + else: + yield None, None + + done = torch.cuda.ByteTensor([0]) + src = mpu.get_pipeline_model_parallel_last_rank() + group = mpu.get_pipeline_model_parallel_group() + torch.distributed.broadcast(done, src, group) - context_length += 1 - counter += 1 - if done: - break + context_length += 1 + counter += 1 + if done: + break diff --git a/codegeex/megatron/convert_ckpt_parallel.py b/codegeex/megatron/convert_ckpt_parallel.py index c6adcdc..7b625c1 100644 --- a/codegeex/megatron/convert_ckpt_parallel.py +++ b/codegeex/megatron/convert_ckpt_parallel.py @@ -1,20 +1,8 @@ """Get model parallel partitions.""" import os -import re -import random -import sys - -import numpy as np import torch - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.pardir))) - -from codegeex.megatron import get_args -from codegeex.megatron.model import CodeGeeXModel -from codegeex.megatron.initialize import initialize_megatron -from codegeex.megatron.checkpointing import ensure_directory_exists +import argparse def get_change_ckpt_args(parser): @@ -58,19 +46,10 @@ def get_element_from_dict_by_path(d, path): def main(): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(random.randint(10000, 20000)) - - initialize_megatron( - extra_args_provider=get_change_ckpt_args, - args_defaults={ - "tokenizer_type": "GPT2BPETokenizer", - "no_load_rng" : True, - "no_load_optim" : True, - }, - ) - - args = get_args() + parser = argparse.ArgumentParser() + parser = get_change_ckpt_args(parser) + args, _ = parser.parse_known_args() + print(f"Load ckpt from {args.load_ckpt_path}...") state_dict = torch.load(args.load_ckpt_path, map_location="cpu") diff --git a/codegeex/megatron/merge_ckpt_parallel.py b/codegeex/megatron/merge_ckpt_parallel.py new file mode 100644 index 0000000..f032bef --- /dev/null +++ 
b/codegeex/megatron/merge_ckpt_parallel.py @@ -0,0 +1,133 @@ +"""Merge model parallel partitions into a single checkpoint.""" + +import os +import torch +import random + +from codegeex.megatron import get_args +from codegeex.megatron.model import CodeGeeXModel +from codegeex.megatron.initialize import initialize_megatron +from codegeex.megatron.checkpointing import ensure_directory_exists + + +def get_change_ckpt_args(parser): + """Provide extra arguments required for merging.""" + group = parser.add_argument_group(title='Mindspore to megatron') + group.add_argument( + '--load-ckpt-path', + type=str, + required=True, + help='dir to load model parallel partitions.', + ) + group.add_argument( + '--save-ckpt-path', + type=str, + required=True, + help='path to save ".pt" checkpoint.', + ) + group.add_argument( + '--source-tensor-model-parallel-size', + type=int, + default=2, + help='original tensor model parallel size', + ) + + return parser + + +def main(): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(random.randint(10000, 20000)) + + initialize_megatron( + extra_args_provider=get_change_ckpt_args, + args_defaults={ + "tokenizer_type": "GPT2BPETokenizer", + "no_load_rng" : True, + "no_load_optim" : True, + }, + ) + + args = get_args() + model = CodeGeeXModel() + print(model.state_dict) + + # Save the model. + sd = {} + sd['module'] = model.state_dict_for_save_checkpoint() + ensure_directory_exists(args.save_ckpt_path) + + print(f"Load ckpt from {args.load_ckpt_path}...") + state_dict_list = [] + for i in range(args.source_tensor_model_parallel_size): + try: + state_dict_list.append(torch.load(os.path.join(args.load_ckpt_path, f"mp_rank_{i:02d}_model_states.pt"), map_location="cpu")) + except Exception as e: + print(e) + exit(0) + + print(f"Merging {len(state_dict_list)} partitions into a single ckpt...") + print("Merging Embedding layers...") + vocab_parallel_size = args.make_vocab_size_divisible_by // args.source_tensor_model_parallel_size + for i in range(args.source_tensor_model_parallel_size): + sd['module']['language_model']['embedding']['word_embeddings']['weight'][i * vocab_parallel_size : (i + 1) * vocab_parallel_size, :] = state_dict_list[i]['module']['language_model']['embedding']['word_embeddings']['weight'] + + sd['module']['language_model']['embedding']['position_embeddings']['weight'] = state_dict_list[0]['module']['language_model']['embedding']['position_embeddings']['weight'] + + print("Merging QueryEmbedding layers...") + query_parallel_size = args.max_position_embeddings // args.source_tensor_model_parallel_size + for i in range(args.source_tensor_model_parallel_size): + sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings'].pop('weight', None) + + print("Merging Transformer layers...") + for layer_name in sd['module']['language_model']['transformer'].keys(): + if "layernorm" in layer_name: + sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None) + elif "attention" in layer_name and "weight" in layer_name: + if "dense" in layer_name: + hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size + for i in range(args.source_tensor_model_parallel_size): + 
sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + else: + hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size + for i in range(args.source_tensor_model_parallel_size): + sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + elif "weight" in layer_name and "dense" in layer_name: + if "h_to_4h" in layer_name: + hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size + for i in range(args.source_tensor_model_parallel_size): + sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + else: + hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size + for i in range(args.source_tensor_model_parallel_size): + sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + elif "bias" in layer_name: + if "mlp" in layer_name: + if "4h_to_h" in layer_name: + sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None) + else: + hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size + for i in range(args.source_tensor_model_parallel_size): + sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + elif "attention" in layer_name: + if "dense" in layer_name: + sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None) + else: + hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size + for i in range(args.source_tensor_model_parallel_size): + sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + else: + sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None) + + if args.save_ckpt_path.endswith(".pt"): + save_ckpt_path = args.save_ckpt_path + else: + os.makedirs(args.save_ckpt_path, exist_ok=True) + save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt") + + torch.save(sd, save_ckpt_path) + print(f"Converted checkpoint saved in {save_ckpt_path}.") + + +if __name__ == '__main__': + main() diff --git a/codegeex/megatron/model/language_model.py b/codegeex/megatron/model/language_model.py index d9798d2..e0b5a9c 100644 --- a/codegeex/megatron/model/language_model.py +++ b/codegeex/megatron/model/language_model.py @@ -19,17 +19,46 @@ import torch import 
torch.nn.functional as F from codegeex.megatron import get_args -from codegeex.megatron import mpu +from codegeex.megatron import mpu, print_rank_0 from codegeex.megatron.model.module import MegatronModule from codegeex.megatron.model.transformer import ParallelTransformer from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal +from codegeex.megatron.mpu.initialize import get_tensor_model_parallel_world_size +def get_shrink_embedding_gradient_alpha(iteration): + args = get_args() + + alpha = args.shrink_embedding_gradient_alpha + if args.shrink_embedding_gradient_steps is None: + return alpha + else: + x1 = int(args.shrink_embedding_gradient_steps[0]) + x2 = int(args.shrink_embedding_gradient_steps[1]) + if iteration <= x1: + return alpha + elif iteration >= x1 + x2: + return 1.0 + else: + # Linearly warm the shrink factor from alpha back to 1.0 over x2 steps. + return alpha + (1 - alpha) * (iteration - x1) / x2 + + def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): """LM logits using word embedding weights.""" # Parallel logits. input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) # Matrix multiply. + args = get_args() + if args.shrink_logit_embedding_gradient: + if hasattr(args, 'iteration'): + alpha = get_shrink_embedding_gradient_alpha(args.iteration + 1) + else: + alpha = args.shrink_embedding_gradient_alpha + word_embeddings_weight = word_embeddings_weight if alpha == 1.0 \ + else ( + word_embeddings_weight * alpha + + word_embeddings_weight.detach() * (1 - alpha) + ) if bias is None: logits_parallel = F.linear(input_parallel, word_embeddings_weight.half()) else: @@ -92,15 +121,19 @@ class Embedding(MegatronModule): num_tokentypes=0, ): super(Embedding, self).__init__() - + + args = get_args() + self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes - + self.max_sequence_length = max_sequence_length + # Word embeddings (parallel). self.word_embeddings = mpu.VocabParallelEmbedding( vocab_size, self.hidden_size, init_method=self.init_method) self._word_embeddings_key = 'word_embeddings' + self.vocab_size = vocab_size # Position embedding (serial). @@ -108,6 +141,7 @@ class Embedding(MegatronModule): max_sequence_length, self.hidden_size) self.position_embeddings = self.position_embeddings.half() self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. self.init_method(self.position_embeddings.weight) @@ -190,7 +224,8 @@ class Embedding(MegatronModule): if 'word_embeddings' in key: state_dict_[key.split('word_embeddings.')[1]] \ = state_dict[key] - state_dict_["weight"] = state_dict_["weight"][:self.vocab_size] + vocab_len = state_dict_['weight'].shape[0] + state_dict_["weight"] = state_dict_["weight"][:self.vocab_size // get_tensor_model_parallel_world_size()] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. 
@@ -203,6 +238,17 @@ class Embedding(MegatronModule): if 'position_embeddings' in key: state_dict_[key.split('position_embeddings.')[1]] \ = state_dict[key] + + pos_len = state_dict_['weight'].shape[0] + max_seq_len = self.max_sequence_length + if pos_len < max_seq_len: + print_rank_0(f"Position embedding padded {pos_len} -> {max_seq_len}.") + position_embeddings_padded = torch.nn.Embedding( + max_seq_len - pos_len, self.hidden_size).half() + self.init_method(position_embeddings_padded.weight) + state_dict_['weight'] = torch.cat([state_dict_['weight'], position_embeddings_padded.weight], dim=0) + + # self.position_embeddings = self.position_embeddings.half() self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. @@ -284,12 +330,14 @@ class QueryEmbedding(MegatronModule): self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes - + self.max_sequence_length = max_sequence_length + # Top query position embedding (serial). self.top_query_embeddings = mpu.VocabParallelEmbedding( max_sequence_length, self.hidden_size, init_method=self.init_method) self.top_query_embeddings = self.top_query_embeddings.half() self._top_query_embeddings_key = 'top_query_embeddings' + # Initialize the top query position embeddings. self.init_method(self.top_query_embeddings.weight) @@ -368,6 +416,14 @@ class QueryEmbedding(MegatronModule): if 'top_query_embeddings' in key: state_dict_[key.split('top_query_embeddings.')[1]] \ = state_dict[key] + pos_len = state_dict_['weight'].shape[0] + max_seq_len = self.max_sequence_length // get_tensor_model_parallel_world_size() + if pos_len < max_seq_len: + print_rank_0(f"Top query embedding padded {pos_len} -> {max_seq_len}.") + top_query_embeddings_padded = torch.nn.Embedding( + max_seq_len - pos_len, self.hidden_size).half() + self.init_method(top_query_embeddings_padded.weight) + state_dict_['weight'] = torch.cat([state_dict_['weight'], top_query_embeddings_padded.weight], dim=0) self.top_query_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. diff --git a/codegeex/megatron/model/transformer.py b/codegeex/megatron/model/transformer.py index a66580e..02e66aa 100644 --- a/codegeex/megatron/model/transformer.py +++ b/codegeex/megatron/model/transformer.py @@ -61,14 +61,19 @@ class ParallelMLP(MegatronModule): applied. """ - def __init__(self, init_method, output_layer_init_method): + def __init__( + self, + init_method, + output_layer_init_method, + scale: int = 4, + ): super(ParallelMLP, self).__init__() args = get_args() # Project to 4h. self.dense_h_to_4h = mpu.ColumnParallelLinear( args.hidden_size, - 4 * args.hidden_size, + scale * args.hidden_size, gather_output=False, init_method=init_method, # skip_bias_add=True, @@ -78,7 +83,7 @@ class ParallelMLP(MegatronModule): # Project back to h. self.dense_4h_to_h = mpu.RowParallelLinear( - 4 * args.hidden_size, + scale * args.hidden_size, args.hidden_size, input_is_parallel=True if args.tensor_model_parallel_size > 1 else False, init_method=output_layer_init_method, @@ -112,10 +117,11 @@ class ParallelSelfAttention(MegatronModule): # Per attention head and per partition values. 
world_size = mpu.get_model_parallel_world_size() - self.hidden_size_per_partition = mpu.divide(args.hidden_size, - world_size) + self.hidden_size_per_partition = mpu.divide( + args.hidden_size // 2 if args.compress else args.hidden_size, + world_size) self.hidden_size_per_attention_head = mpu.divide( - args.hidden_size, args.num_attention_heads) + args.hidden_size // 2 if args.compress else args.hidden_size, args.num_attention_heads) self.num_attention_heads_per_partition = mpu.divide( args.num_attention_heads, world_size) if hasattr(args, 'attention_upweight'): @@ -125,17 +131,17 @@ class ParallelSelfAttention(MegatronModule): # Strided linear layer. self.query = mpu.ColumnParallelLinear( args.hidden_size, - args.hidden_size, + args.hidden_size // 2 if args.compress else args.hidden_size, gather_output=False, init_method=init_method) self.key = mpu.ColumnParallelLinear( args.hidden_size, - args.hidden_size, + args.hidden_size // 2 if args.compress else args.hidden_size, gather_output=False, init_method=init_method) self.value = mpu.ColumnParallelLinear( args.hidden_size, - args.hidden_size, + args.hidden_size // 2 if args.compress else args.hidden_size, gather_output=False, init_method=init_method) @@ -149,7 +155,7 @@ class ParallelSelfAttention(MegatronModule): # Output. self.dense = mpu.RowParallelLinear( - args.hidden_size, + args.hidden_size // 2 if args.compress else args.hidden_size, args.hidden_size, input_is_parallel=True if args.tensor_model_parallel_size > 1 else False, init_method=output_layer_init_method, @@ -264,7 +270,7 @@ class ParallelSelfAttention(MegatronModule): if self.attention_softmax_in_fp32: attention_probs = self.softmax(attention_scores.float()).half() else: - attention_probs = self.softmax(attention_scores) + attention_probs = self.softmax(attention_scores.half()) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -485,7 +491,7 @@ class ParallelTopQuerySelfAttention(MegatronModule): if self.attention_softmax_in_fp32: attention_probs = self.softmax(attention_scores.float()).half() else: - attention_probs = self.softmax(attention_scores) + attention_probs = self.softmax(attention_scores.half()) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
@@ -607,7 +613,8 @@ class ParallelTransformerLayer(MegatronModule): self.ln_fp16 = False # MLP self.mlp = ParallelMLP(init_method, - output_layer_init_method) + output_layer_init_method, + scale=2 if args.compress else 4) def forward( self, diff --git a/codegeex/megatron/tools/finetune_codegeex.py b/codegeex/megatron/tools/finetune_codegeex.py new file mode 100644 index 0000000..4a56602 --- /dev/null +++ b/codegeex/megatron/tools/finetune_codegeex.py @@ -0,0 +1,326 @@ +import os +import torch +import logging + +logging.getLogger("torch").setLevel(logging.WARNING) + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +from functools import partial + +from codegeex.megatron import get_args, print_rank_0, get_timers, get_tokenizer, mpu +from codegeex.megatron.data.prompt_dataset import build_train_valid_test_datasets +from codegeex.megatron.model import CodeGeeXModel +from codegeex.megatron.training import pretrain +from codegeex.megatron.utils import get_ltor_masks_and_position_ids +from codegeex.megatron.utils import average_losses_across_data_parallel_group + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0("building GPT model ...") + see_memory_usage(f"Before Building Model", force=True) + + args = get_args() + with deepspeed.zero.Init( + data_parallel_group=mpu.get_data_parallel_group(), + remote_device=None if args.remote_device == "none" else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu, + ): + if args.deepspeed and not args.no_pipeline_parallel: + model = CodeGeeXModelPipe(num_tokentypes=0, parallel_output=True) + # This is a hack to give us a reference to get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + # Precompute the attention mask and store it in args. This avoids having to + # pipeline it as an activation during training. The mask is constant, and thus + # we can reuse it. + attention_mask = torch.tril( + torch.ones( + (1, args.seq_length, args.seq_length), + device=torch.cuda.current_device(), + ) + ).view(1, 1, args.seq_length, args.seq_length) + + # Convert attention mask to binary: + attention_mask = attention_mask < 0.5 + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + # Attention mask must be bool. + args.attn_mask = attention_mask.to(torch.bool) + + else: + model = CodeGeeXModel( + num_tokentypes=0, + parallel_output=True, + ) + + if args.load_state is not None: + timers = get_timers() + print_rank_0("Loading warmstarting model states ...") + timers("load-model-states").start() + mp_rank = mpu.get_tensor_model_parallel_rank() + if os.path.isdir(args.load_state): + model_path = os.path.join( + args.load_state, "mp_rank_{:02d}_model_states.pt".format(mp_rank) + ) + else: + model_path = args.load_state + print_rank_0(f"Loading model from {model_path} ...") + state_dict = torch.load(model_path, map_location="cpu") + if "module" in state_dict: + state_dict = state_dict["module"] # strip other client states + model.load_state_dict(state_dict) + timers("load-model-states").stop() + timers.log(["load-model-states"]) + see_memory_usage(f"After Building Model", force=True) + + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. 
+ keys = ["input_ids", "attention_mask", "labels"] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b["input_ids"].contiguous() + # attn_mask_ = data_b["attention_mask"].contiguous() + labels_ = data_b["labels"].contiguous() + + tokens = tokens_[:, :-1] + labels = labels_[:, 1:] + + # Get the masks and position ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + ) + + # mask loss to avoid predicting prompt and paddings + prompt_loss_mask = labels >= 0 + loss_mask = prompt_loss_mask * loss_mask + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def get_batch_pipe(data): + """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ["input_ids"] + datatype = torch.int64 + + # Broadcast data. + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b["input_ids"].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and position ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + ) + + return (tokens, position_ids, attention_mask), (labels, loss_mask) + + +def loss_func(loss_mask, output_tensor): + args = get_args() + + def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor): + if args.gold: + losses_ = losses.detach() + prob = torch.exp(-losses_) # Pθ(s) + torch.sqrt_(prob) # Pθ(s)ᵃ + torch.clamp_min_(prob, args.gold_beta) # max(Pθ(s)ᵃ,β) + losses = prob * losses + + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min(loss_mask.sum(), 1e-8) + + return loss + + losses = output_tensor.float() + loss = compute_lm_loss(losses, loss_mask) + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {"lm loss": averaged_loss[0]} + + +def valid_loss_func(loss_mask, output_tensor): + args = get_args() + + def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor): + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min(loss_mask.sum(), 1e-8) + + return loss + + losses = output_tensor.float() + loss = compute_lm_loss(losses, loss_mask) + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {"lm loss": averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. + timers("batch-generator").start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator) + timers("batch-generator").stop() + + output_tensor = model(tokens, position_ids, attention_mask, labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def valid_forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. 
+ timers("batch-generator").start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator) + timers("batch-generator").stop() + + output_tensor = model(tokens, position_ids, attention_mask, labels=labels) + + return output_tensor, partial(valid_loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0("> building train, validation, and test datasets " "for GPT ...") + if args.co_evaluation: + def dataset_partition_path_parsing(data_path): + dataset_path = {} + for index in range(len(data_path)): + dataset_path[data_path[index]] = data_path[index] + return dataset_path + assert args.valid_data_path is not None, "Valid data path must be given when --co-evaluation is turned on." + valid_data_path = dataset_partition_path_parsing(args.valid_data_path) + if args.test_data_path is not None: + test_data_path = dataset_partition_path_parsing(args.test_data_path) + else: + test_data_path = None + train_ds, _, _ = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string="1,0,0", + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + ) + valid_ds = {} + for key, value in valid_data_path.items(): + _, valid_ds_item, _ = build_train_valid_test_datasets( + data_prefix=[value], + data_impl=args.data_impl, + splits_string="0,1,0", + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + ) + valid_ds[key] = valid_ds_item + if test_data_path is not None: + test_ds = {} + for key, value in test_data_path.items(): + _, _, test_ds_item = build_train_valid_test_datasets( + data_prefix=[value], + data_impl=args.data_impl, + splits_string="0,0,1", + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + ) + test_ds[key] = test_ds_item + else: + test_ds = None + elif args.valid_data_path is None: + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + ) + else: + train_ds, _, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string="100,0,0", + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + ) + + _, valid_ds, _ = build_train_valid_test_datasets( + data_prefix=args.valid_data_path, + data_impl=args.data_impl, + splits_string="0,100,0", + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + ) + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + pretrain( + train_valid_test_datasets_provider, + model_provider, + forward_step, + valid_forward_step, + args_defaults={"tokenizer_type": "GPT2BPETokenizer"}, + ) \ No newline at end of file diff --git a/codegeex/megatron/training.py b/codegeex/megatron/training.py index 322ce69..593649d 100644 --- a/codegeex/megatron/training.py +++ 
b/codegeex/megatron/training.py @@ -65,12 +65,6 @@ except ImportError: from filelock import FileLock import pathlib -try: - import bmcook - from bmcook import Config -except ImportError: - print("bmcook not imported.") - bmcook = None def print_datetime(string): @@ -80,15 +74,11 @@ def print_datetime(string): print_rank_0("[" + string + "] datetime: {} ".format(time_str)) -def compress_setup(args, model, optimizer): - teacher = get_model(args) - cook_config = ConfigParser(args.cook_config) - CPMAntTrainer.set_compression(cook_config, model, optimizer, teacher=teacher, remove_ckptblock=False, target_linear=Linear) - def pretrain( train_valid_test_dataset_provider, model_provider, forward_step_func, + valid_forward_step_func=None, extra_args_provider=None, args_defaults={}, ): @@ -187,6 +177,7 @@ def pretrain( if args.do_train and args.train_iters > 0: iteration = train( forward_step_func, + valid_forward_step_func, model, optimizer, lr_scheduler, @@ -200,11 +191,11 @@ def pretrain( if args.co_evaluation: for key, value in valid_data_iterator.items(): evaluate_and_print_results( - prefix, forward_step_func, value, model, iteration, False, tag=key + prefix, valid_forward_step_func, value, model, iteration, False, tag=key ) else: evaluate_and_print_results( - prefix, forward_step_func, valid_data_iterator, model, iteration, False + prefix, valid_forward_step_func, valid_data_iterator, model, iteration, False ) if args.save and iteration != 0: @@ -890,6 +881,7 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler): def train( forward_step_func, + valid_forward_step_func, model, optimizer, lr_scheduler, @@ -987,11 +979,15 @@ def train( if args.co_evaluation: for key, value in valid_data_iterator.items(): evaluate_and_print_results( - prefix, forward_step_func, value, model, iteration, False, tag=key + prefix, valid_forward_step_func, value, model, iteration, False, tag=key ) else: + if args.gold: + evaluate_and_print_results_gold( + prefix, forward_step_func, valid_data_iterator, model, iteration, False + ) evaluate_and_print_results( - prefix, forward_step_func, valid_data_iterator, model, iteration, False + prefix, valid_forward_step_func, valid_data_iterator, model, iteration, False ) # Checkpointing @@ -1194,16 +1190,6 @@ def evaluate_and_print_results_gold( total_loss_dict[key].item(), iteration, ) - # writer.add_scalar( - # f"lm-loss-validation/{display_key} validation vs samples", - # total_loss_dict[key].item(), - # args.consumed_train_samples, - # ) - # writer.add_scalar( - # f"lm-loss-validation/{display_key} validation vs tokens", - # total_loss_dict[key].item(), - # args.consumed_train_tokens, - # ) if args.log_validation_ppl_to_tensorboard: writer.add_scalar( f"lm-loss-validation/{display_key} validation ppl", ppl, iteration diff --git a/codegeex/torch/get_ckpt_qkv.py b/codegeex/torch/get_ckpt_qkv.py new file mode 100644 index 0000000..78693a4 --- /dev/null +++ b/codegeex/torch/get_ckpt_qkv.py @@ -0,0 +1,50 @@ +import os +import sys +import torch +import random +import argparse +import numpy as np + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--load-path", + type=str, + default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_fp32_52224.pt") + parser.add_argument("--save-path", + type=str, + default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt") + + args, _ = parser.parse_known_args() + + state_dict_path = args.load_path + print("Loading state dict ...") + sd = torch.load(state_dict_path, map_location="cpu") + + for i in 
range(40): + if i < 39: + query_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.weight', None) + query_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.bias', None) + key_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.weight', None) + key_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.bias', None) + value_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.weight', None) + value_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.bias', None) + qkv_weight = torch.cat([query_weight, key_weight, value_weight], dim=0) + qkv_bias = torch.cat([query_bias, key_bias, value_bias]) + sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.weight'] = qkv_weight + sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.bias'] = qkv_bias + else: + tq_key_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.weight', None) + tq_key_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.bias', None) + tq_value_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.weight', None) + tq_value_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.bias', None) + tq_kv_weight = torch.cat([tq_key_weight, tq_value_weight], dim=0) + tq_kv_bias = torch.cat([tq_key_bias, tq_value_bias]) + sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.weight'] = tq_kv_weight + sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.bias'] = tq_kv_bias + + save_ckpt_path = args.save_path + torch.save(sd, save_ckpt_path) + +if __name__ == '__main__': + main() diff --git a/deployment/server_gradio.py b/deployment/server_gradio.py new file mode 100644 index 0000000..96eb770 --- /dev/null +++ b/deployment/server_gradio.py @@ -0,0 +1,198 @@ +import json +import torch +import argparse +import gradio as gr + +import codegeex +from codegeex.torch import CodeGeeXModel +from codegeex.tokenizer import CodeGeeXTokenizer +from codegeex.quantization import quantize +from codegeex.data.data_utils import LANGUAGE_TAG +from codegeex.megatron.inference import set_random_seed + + +def model_provider(args): + """Build the model.""" + + model = CodeGeeXModel( + args.hidden_size, + args.num_layers, + args.num_attention_heads, + args.padded_vocab_size, + args.max_position_embeddings + ) + return model + + +def add_code_generation_args(parser): + group = parser.add_argument_group(title="code generation") + group.add_argument( + "--num-layers", + type=int, + default=39, + ) + group.add_argument( + "--hidden-size", + type=int, + default=5120, + ) + group.add_argument( + "--num-attention-heads", + type=int, + default=40, + ) + group.add_argument( + "--padded-vocab-size", + type=int, + default=52224, + ) + group.add_argument( + "--max-position-embeddings", + type=int, + default=2048, + ) + group.add_argument( + "--tokenizer-path", + type=str, + default="./tokenizer", + ) + group.add_argument( + "--example-path", + type=str, + default="./", + ) + group.add_argument( + "--load", + type=str, + ) + group.add_argument( + "--state-dict-path", + type=str, + ) + group.add_argument( + "--micro-batch-size", + type=int, + default=1, + ) + group.add_argument( + "--quantize", + action="store_true", + ) + + 
return parser + + +def main(): + parser = argparse.ArgumentParser() + parser = add_code_generation_args(parser) + args, _ = parser.parse_known_args() + + print("Loading tokenizer ...") + tokenizer = CodeGeeXTokenizer( + tokenizer_path=args.tokenizer_path, + mode="codegeex-13b") + + print("Loading state dict ...") + state_dict = torch.load(args.load, map_location="cpu") + state_dict = state_dict["module"] + + print("Building CodeGeeX model ...") + model = model_provider(args) + model.load_state_dict(state_dict) + model.eval() + model.half() + if args.quantize: + model = quantize(model, weight_bit_width=8, backend="torch") + model.cuda() + + def predict( + prompt, + lang, + seed, + out_seq_length, + temperature, + top_k, + top_p, + ): + set_random_seed(seed) + if lang.lower() in LANGUAGE_TAG: + prompt = LANGUAGE_TAG[lang.lower()] + "\n" + prompt + + generated_code = codegeex.generate( + model, + tokenizer, + prompt, + out_seq_length=out_seq_length, + seq_length=args.max_position_embeddings, + top_k=top_k, + top_p=top_p, + temperature=temperature, + micro_batch_size=args.micro_batch_size, + backend="megatron", + verbose=True, + ) + # generate() returns a list with one completion per micro-batch entry + return prompt + generated_code[0] + + examples = [] + with open(args.example_path, "r") as f: + for line in f: + examples.append(list(json.loads(line).values())) + + with gr.Blocks() as demo: + gr.Markdown( + """ + + """) + gr.Markdown( + """ +

+ 🏠 Homepage | 📖 Blog | 🪧 DEMO | 🛠 VS Code or Jetbrains Extensions | 💻 Source code | 🤖 Download Model +

+ """) + gr.Markdown( + """ + We introduce CodeGeeX, a large-scale multilingual code generation model with 13 billion parameters, pre-trained on a large code corpus of more than 20 programming languages. CodeGeeX supports 15+ programming languages for both code generation and translation. CodeGeeX is open source, please refer to our [GitHub](https://github.com/THUDM/CodeGeeX) for more details. This is a minimal-functional DEMO, for other DEMOs like code translation, please visit our [Homepage](https://codegeex.cn). We also offer free [VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex) or [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex) extensions for full functionality. + """) + + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(lines=13, placeholder='Please enter the description or select an example input below.',label='Input') + with gr.Row(): + gen = gr.Button("Generate") + clr = gr.Button("Clear") + + outputs = gr.Textbox(lines=15, label='Output') + + gr.Markdown( + """ + Generation Parameter + """) + with gr.Row(): + with gr.Column(): + lang = gr.Radio( + choices=["C++", "C", "C#", "Python", "Java", "HTML", "PHP", "JavaScript", "TypeScript", "Go", + "Rust", + "SQL", "Kotlin", "R", "Fortran"], value='lang', label='Programming Language', + default="Python") + with gr.Column(): + seed = gr.Slider(maximum=10000, value=8888, step=1, label='Seed') + with gr.Row(): + out_seq_length = gr.Slider(maximum=1024, value=128, minimum=1, step=1, label='Output Sequence Length') + temperature = gr.Slider(maximum=1, value=0.2, minimum=0, label='Temperature') + with gr.Row(): + top_k = gr.Slider(maximum=40, value=0, minimum=0, step=1, label='Top K') + top_p = gr.Slider(maximum=1, value=1.0, minimum=0, label='Top P') + + inputs = [prompt, lang, seed, out_seq_length, temperature, top_k, top_p] + gen.click(fn=predict, inputs=inputs, outputs=outputs) + clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=prompt) + + gr_examples = gr.Examples(examples=examples, inputs=[prompt, lang], + label="Example Inputs (Click to insert an examplet it into the input box)", + examples_per_page=20) + + demo.launch(server_port=6007) + +if __name__ == '__main__': + with torch.no_grad(): + main() \ No newline at end of file diff --git a/tests/test_inference.py b/tests/test_inference.py index be6ed40..c4267ab 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,13 +1,9 @@ - -import os -import copy import time import torch -import random import argparse import numpy as np -from codegeex.torch.inference import get_token_stream +import codegeex from codegeex.torch import CodeGeeXModel from codegeex.tokenizer import CodeGeeXTokenizer from codegeex.quantization import quantize @@ -111,6 +107,10 @@ def add_code_generation_args(parser): "--quantize", action="store_true", ) + group.add_argument( + "--interative", + action="store_true", + ) return parser @@ -143,62 +143,48 @@ def main(): prompt = f.readlines() prompt = "".join(prompt) - times = {} out_seq_lengths = [args.out_seq_length] - micro_batch_size = args.micro_batch_size - seq_length = args.max_position_embeddings for out_seq_length in out_seq_lengths: print(f"Generating with out_seq_len {out_seq_length}...") - - times[out_seq_length] = [] - for prompt in [prompt]: - t0 = time.perf_counter() - tokens = tokenizer.encode_code(prompt) - print(tokens) - print("Current prompt:") - print(prompt) - n_token_prompt = len(tokens) - print("N_token_prompt:", n_token_prompt) - token_stream = 
get_token_stream( - model, - tokenizer, - seq_length, - out_seq_length, - [copy.deepcopy(tokens) for _ in range(micro_batch_size)], - micro_batch_size=micro_batch_size, - topk=args.top_k, - topp=args.top_p, - temperature=args.temperature, - greedy=args.greedy, - ) - is_finished = [False for _ in range(micro_batch_size)] - for i, generated in enumerate(token_stream): - generated_tokens = generated[0] - for j in range(micro_batch_size): - if is_finished[j]: - continue - if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len( - generated_tokens[j]) >= out_seq_length: - is_finished[j] = True - generated_tokens_ = generated_tokens[j].cpu().numpy().tolist() - generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:]) - generated_code = "".join(generated_code) - t1 = time.perf_counter() - print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt) - print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token") - times[out_seq_length].append(t1 - t0) - print("================================= Generated code:") - print(generated_code) - - if all(is_finished): - break - - print(times) - for out_seq_length in times.keys(): - print(out_seq_length, np.mean(times[out_seq_length])) - + while True: + print("\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> ") + prompts = [] + while True: + try: + line = input() + except EOFError: + break + prompts.append(line) + prompt = "\n".join(prompts) + prompt = prompt.strip() + if not prompt: + print('Query should not be empty!') + continue + if prompt == "stop": + return + try: + t0 = time.perf_counter() + generated_code = codegeex.generate( + model, + tokenizer, + prompt, + out_seq_length=out_seq_length, + seq_length=args.max_position_embeddings, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + micro_batch_size=args.micro_batch_size, + backend="megatron", + verbose=True, + ) + t1 = time.perf_counter() + print("Total generation time:", t1 - t0) + except (ValueError, FileNotFoundError) as e: + print(e) + continue + print("Generation finished.") if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/tests/test_inference_megatron.py b/tests/test_inference_megatron.py index ac9c600..66bda37 100644 --- a/tests/test_inference_megatron.py +++ b/tests/test_inference_megatron.py @@ -166,44 +166,57 @@ def main(): with open(args.prompt_file, "r") as f: prompt = f.readlines() prompt = "".join(prompt) - - print_rank_0("Generating ...") - t0 = time.perf_counter() - for prompt in [prompt]: - tokens = tokenizer.tokenize(prompt) - print_rank_0(tokens) - print_rank_0("Current prompt:") - print_rank_0(prompt) - n_token_prompt = len(tokens) - print_rank_0(f"N_token_prompt: {n_token_prompt}") - token_stream = get_token_stream( - model, - [copy.deepcopy(tokens) for _ in range(args.micro_batch_size)], - micro_batch_size=args.micro_batch_size, - bad_ids=args.bad_ids, - ) - is_finished = [False for _ in range(args.micro_batch_size)] - for i, generated in enumerate(token_stream): - generated_tokens = generated[0] - for j in range(args.micro_batch_size): - if is_finished[j]: - continue - if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len( - generated_tokens[j]) >= args.out_seq_length: - is_finished[j] = True - generated_tokens_ = generated_tokens[j].cpu().numpy().tolist() - generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:]) - t1 = time.perf_counter() - print_rank_0(f"Total generation time: {t1 - 
t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}") - print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token") - print_rank_0("================================= Generated code:") - print_rank_0(generated_code) - t0 = time.perf_counter() - if all(is_finished): - break - + + times = {} + out_seq_lengths = [args.out_seq_length] + micro_batch_size = args.micro_batch_size + for out_seq_length in out_seq_lengths: + print_rank_0(f"Generating with out_seq_len {out_seq_length}...") + + times[out_seq_length] = [] + for prompt in [prompt] * args.n_generation: + t0 = time.perf_counter() + tokens = tokenizer.tokenize(prompt) + print_rank_0(tokens) + print_rank_0("Current prompt:") + print_rank_0(prompt) + n_token_prompt = len(tokens) + print_rank_0(f"N_token_prompt:{n_token_prompt}") + token_stream = get_token_stream( + model, + [copy.deepcopy(tokens) for _ in range(micro_batch_size)], + micro_batch_size=micro_batch_size, + topk=args.top_k, + topp=args.top_p, + temperature=args.temperature, + ) + is_finished = [False for _ in range(micro_batch_size)] + for i, generated in enumerate(token_stream): + generated_tokens = generated[0] + for j in range(micro_batch_size): + if is_finished[j]: + continue + if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len( + generated_tokens[j]) >= out_seq_length: + is_finished[j] = True + generated_tokens_ = generated_tokens[j].cpu().numpy().tolist() + generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:]) + t1 = time.perf_counter() + print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}") + print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token") + times[out_seq_length].append(t1 - t0) + print_rank_0("================================= Generated code:") + print_rank_0(generated_code) + t0 = time.perf_counter() + + if all(is_finished): + break + + print_rank_0(times) + for out_seq_length in times.keys(): + print_rank_0(f"{out_seq_length}, {np.mean(times[out_seq_length])}") + print_rank_0("Generation finished.") - if __name__ == "__main__": main()
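
Reference sketch (not part of the patch): the checkpoint-merging hunk at the top of this section fuses layers.{i}.attention.query/key/value into a single query_key_value tensor, and key/value into key_value for the topQueryLayer, by concatenating along dim 0. Below is a minimal, self-contained illustration of that same fusion on dummy tensors; the hidden size and the fuse_qkv helper name are assumptions for illustration, not values or APIs taken from a real CodeGeeX checkpoint.

import torch

def fuse_qkv(q_w, k_w, v_w, q_b, k_b, v_b):
    # Concatenate along dim 0, mirroring the torch.cat(..., dim=0) calls in the merge script.
    return torch.cat([q_w, k_w, v_w], dim=0), torch.cat([q_b, k_b, v_b])

# Dummy sizes for illustration only (not the real 13B dimensions).
hidden = 8
q_w, k_w, v_w = (torch.randn(hidden, hidden) for _ in range(3))
q_b, k_b, v_b = (torch.randn(hidden) for _ in range(3))

qkv_w, qkv_b = fuse_qkv(q_w, k_w, v_w, q_b, k_b, v_b)
assert qkv_w.shape == (3 * hidden, hidden)
assert qkv_b.shape == (3 * hidden,)

# The fused tensor splits back into the original projections:
q2, k2, v2 = torch.split(qkv_w, hidden, dim=0)
assert torch.equal(q2, q_w) and torch.equal(k2, k_w) and torch.equal(v2, v_w)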
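
A hedged end-to-end usage sketch of the codegeex.generate call that both deployment/server_gradio.py and the new interactive tests/test_inference.py rely on. The checkpoint path, prompt text, and sampling values are placeholders; the model dimensions are the argparse defaults from server_gradio.py above, and the state-dict handling mirrors its main().

import torch
import codegeex
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer

# Placeholder paths; substitute the real tokenizer directory and merged checkpoint.
tokenizer = CodeGeeXTokenizer(tokenizer_path="./tokenizer", mode="codegeex-13b")
state_dict = torch.load("codegeex_13b.pt", map_location="cpu")["module"]

# hidden_size, num_layers, num_attention_heads, padded_vocab_size, max_position_embeddings
model = CodeGeeXModel(5120, 39, 40, 52224, 2048)
model.load_state_dict(state_dict)
model.eval().half().cuda()

with torch.no_grad():
    completions = codegeex.generate(
        model,
        tokenizer,
        "# language: Python\n# write a bubble sort function\n",
        out_seq_length=128,
        seq_length=2048,
        top_k=0,
        top_p=0.95,
        temperature=0.2,
        micro_batch_size=1,
        backend="megatron",
        verbose=False,
    )

# generate() returns one completion per micro-batch entry.
print(completions[0])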