Merge pull request #27 from THUDM/develop

Add quantization, parallelism, data processing, and other minor changes.
Committed by Qinkai via GitHub
commit 593ef9e231

@@ -2,6 +2,7 @@ import os
import sys
import fire
import json
import gzip
import regex
import numpy as np
@@ -27,7 +28,7 @@ def process_humaneval_test(sample, problems, example_test=False):
task_id = sample["task_id"]
language = task_id.split("/")[0].lower()
prompt = problems[task_id]["prompt"]
prompt = sample["prompt"]
if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "":
test = problems[task_id]["example_test"]
else:
@@ -43,17 +44,17 @@ def process_humaneval_test(sample, problems, example_test=False):
code_.append(line)
code = "\n".join(code_)
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
test_string = test_setup + prompt + "\n" + code + "\n" + test + "\n"
test_string = test_setup + prompt + code + "\n" + test + "\n"
elif language == "cpp":
test_set_up = ""
for s in IMPORT_HELPER["cpp"]:
if s not in prompt:
test_set_up += s + "\n"
test_string = test_set_up + "\n" + prompt + "\n" + code + "\n" + test
test_string = test_set_up + "\n" + prompt + code + "\n" + test
elif language == "java":
test_string = prompt + "\n" + code + "\n" + test
test_string = prompt + code + "\n" + test
elif language == "js" or language == "javascript":
test_string = prompt + "\n" + code + "\n" + test
test_string = prompt + code + "\n" + test
elif language == "go":
import_string = problems[task_id]["import"]
prompt = prompt.replace(import_string, "")
@@ -70,19 +71,23 @@ def process_humaneval_test(sample, problems, example_test=False):
other_pkgs.append(f"\"{pkg}\"")
if other_pkgs:
import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + "\n" + code + "\n" + test
test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
else:
test_string = test_setup + "\n" + prompt + "\n" + code + "\n" + test
test_string = test_setup + "\n" + prompt + code + "\n" + test
return test_string
def stream_jsonl_all(filename: str) -> Iterable[Dict]:
results = []
with open(filename, "r") as fp:
for line in fp:
if any(not x.isspace() for x in line):
results.append(json.loads(line))
if filename.endswith(".gz"):
fp = gzip.open(open(filename, "rb"), "rt")
else:
fp = open(filename, "r")
for line in fp:
if any(not x.isspace() for x in line):
results.append(json.loads(line))
fp.close()
return results
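With this change `stream_jsonl_all` transparently reads gzip-compressed result files as well as plain JSONL, so both kinds of paths can be passed to the evaluator. A minimal usage sketch (file names are illustrative):

results = stream_jsonl_all("generations_python.jsonl")
results_gz = stream_jsonl_all("generations_python.jsonl.gz")
assert isinstance(results, list) and isinstance(results_gz, list)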
@@ -116,7 +121,7 @@ def evaluate_functional_correctness(
else:
out_file = os.path.join(input_file.replace(".jsonl", suffix))
if "humaneval_" in input_file:
if "/codegeex/benchmark/humaneval-x/" in input_file:
test_groundtruth = True
if "-to-" in input_file:
@@ -206,11 +211,18 @@ def evaluate_functional_correctness(
print("Total:", np.sum(total))
print("Correct:", np.sum(correct))
with open(out_file, 'w') as fp:
print("Writing to: ", out_file)
print("Writing to: ", out_file)
if out_file.endswith(".gz"):
fp = gzip.GzipFile(fileobj=open(out_file, "wb"), mode="wb")
for res in results.values():
for r in res:
fp.write((json.dumps(r[1]) + "\n").encode("utf-8"))
else:
fp = open(out_file, 'w')
for res in results.values():
for r in res:
fp.write(json.dumps(r[1]) + "\n")
fp.close()
print("Evaluation finished.")

@@ -13,6 +13,7 @@ from codegeex.benchmark.utils import read_dataset, process_extra_prompt
from codegeex.megatron import get_args
from codegeex.megatron.inference import run_generation_distributed, model_provider
from codegeex.megatron.initialize import initialize_megatron
from codegeex.quantization import quantize
logging.getLogger("torch").setLevel(logging.WARNING)
@@ -190,7 +191,11 @@ def add_code_generation_args(parser):
default=None,
help='Identify the type of programming language to generate',
)
group.add_argument(
"--quantize",
action="store_true",
)
return parser
@@ -231,6 +236,8 @@ def main(node_rank: int, local_rank: int, master_port: int, num_devices: int):
model.eval()
if args.fp16 and args.ln_fp16:
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="megatron")
model.cuda()
# Generate samples.
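The `--quantize` flag above routes the model through `codegeex.quantization.quantize`, whose implementation is not part of this diff. As a rough sketch of what weight-only 8-bit quantization of a linear weight typically involves (per-output-channel absmax scaling; the helpers below are illustrative, not the actual CodeGeeX API):

import torch

def quantize_weight_int8(weight: torch.Tensor):
    # Per-output-channel absmax scale so values map into the int8 range [-127, 127].
    scale = weight.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) / 127.0
    q = torch.round(weight / scale).to(torch.int8)
    return q, scale

def dequantize_weight_half(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Recover an approximate fp16 weight at matmul time.
    return q.to(torch.half) * scale.to(torch.half)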

@@ -240,32 +240,36 @@ def inspect_result(
if incompleted:
print(f"Language not supported, aborted. {input_file}")
else:
total, correct = [], []
for k, res in result_stats.items():
total.append(res["n_sample"])
correct.append(res["accepted"])
df_res = pd.DataFrame(res, index=[int(k.split("/")[-1])])
df = pd.concat([df, df_res], axis=0)
try:
total, correct = [], []
for k, res in result_stats.items():
total.append(res["n_sample"])
correct.append(res["accepted"])
df_res = pd.DataFrame(res, index=[int(k.split("/")[-1])])
df = pd.concat([df, df_res], axis=0)
total = np.array(total)
correct = np.array(correct)
total = np.array(total)
correct = np.array(correct)
ks = [1, 10, 100, 1000]
pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
for k in ks if (total >= k).all()}
ks = [1, 10, 100, 1000]
pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
for k in ks if (total >= k).all()}
print(pass_at_k)
pass_at_k["file"] = input_file
pass_at_k["n"] = res["n_sample"]
pass_at_k_outs.append(pass_at_k)
print(pass_at_k)
pass_at_k["file"] = input_file
pass_at_k["n"] = res["n_sample"]
pass_at_k_outs.append(pass_at_k)
output_prefix = input_file.split("/")[-1].split(".jsonl")[0]
output_file = os.path.join(output_dir, output_prefix + "_stats.xlsx")
df = df.sort_index(ascending=True)
df.to_excel(output_file)
print(f"Stats saved in {output_file}")
output_prefix = input_file.split("/")[-1].split(".jsonl")[0]
output_file = os.path.join(output_dir, output_prefix + "_stats.xlsx")
df = df.sort_index(ascending=True)
df.to_excel(output_file)
print(f"Stats saved in {output_file}")
except Exception as e:
print(e)
print(f"Data incompleted, aborted. {input_file}")
if pass_at_k_outpath is not None:
jsonl_path = os.path.join(output_dir, pass_at_k_outpath)
with open(jsonl_path, "w") as f_out:

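The pass@k numbers printed above come from `estimate_pass_at_k`, which is imported elsewhere in the repository and not shown in this diff. It is expected to implement the standard unbiased estimator from the HumanEval evaluation, pass@k = 1 - C(n-c, k) / C(n, k), computed stably as a running product; a sketch for a single problem with n samples and c correct ones:

import numpy as np

def pass_at_k_single(n: int, c: int, k: int) -> float:
    # Probability that at least one of k samples drawn from n (without replacement) is correct.
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))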
@@ -1,36 +1,7 @@
import gzip
import json
import os
from typing import *
from codegeex.data.data_utils import stream_jsonl, LANGUAGE_TAG
LANGUAGE_TAG = {
"c++" : "// language: C++",
"cpp" : "// language: C++",
"c" : "// language: C",
"c#" : "// language: C#",
"cuda" : "// language: Cuda",
"objective-c" : "// language: Objective-C",
"objective-c++": "// language: Objective-C++",
"python" : "# language: Python",
"java" : "// language: Java",
"scala" : "// language: Scala",
"tex" : f"% language: TeX",
"html" : "<!--language: HTML-->",
"php" : "// language: PHP",
"js" : "// language: JavaScript",
"javascript" : "// language: JavaScript",
"typescript" : "// language: TypeScript",
"go" : "// language: Go",
"shell" : "# language: Shell",
"rust" : "// language: Rust",
"css" : "/* language: CSS */",
"sql" : "-- language: SQL",
"kotlin" : "// language: Kotlin",
"pascal" : "// language: Pascal",
"r" : "# language: R",
"fortran" : "!language: Fortran",
"lean" : "-- language: Lean",
}
IMPORT_HELPER = {
"python": [
@@ -78,11 +49,9 @@ IMPORT_HELPER = {
def read_dataset(
data_file: str = None,
dataset_type: str = "humaneval",
split: str = "test",
args=None,
num_shot=None,
data_file: str = None,
dataset_type: str = "humaneval",
num_shot=None,
) -> Dict:
if num_shot is not None:
print(f"{num_shot}-shot setting...")
@@ -98,11 +67,11 @@ def read_dataset(
def read_translation_dataset(
data_file_src: str = None,
data_file_tgt: str = None,
lang_src: str = None,
lang_tgt: str = None,
dataset_type: str = "humaneval",
data_file_src: str = None,
data_file_tgt: str = None,
lang_src: str = None,
lang_tgt: str = None,
dataset_type: str = "humaneval",
) -> Dict:
if "humaneval" in dataset_type.lower():
dataset_src = {task["task_id"]: task for task in stream_jsonl(data_file_src)}
@@ -130,43 +99,6 @@ def read_translation_dataset(
return dataset_src
def stream_jsonl(filename: str) -> Iterable[Dict]:
"""
Parses each jsonl line and yields it as a dictionary
"""
if filename.endswith(".gz"):
with open(filename, "rb") as gzfp:
with gzip.open(gzfp, "rt") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
else:
with open(filename, "r") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
"""
Writes an iterable of dictionaries to jsonl
"""
if append:
mode = "ab"
else:
mode = "wb"
filename = os.path.expanduser(filename)
if filename.endswith(".gz"):
with open(filename, mode) as fp:
with gzip.GzipFile(fileobj=fp, mode="wb") as gzfp:
for x in data:
gzfp.write((json.dumps(x) + "\n").encode("utf-8"))
else:
with open(filename, mode) as fp:
for x in data:
fp.write((json.dumps(x) + "\n").encode("utf-8"))
def process_extra_prompt(prompt: str, language_type: str = None) -> str:
"""
Processes the extra prompt.
@@ -181,9 +113,9 @@ def process_extra_prompt(prompt: str, language_type: str = None) -> str:
def is_code_generation_finished(
code: str,
language_type: str = None,
dataset: str = None,
code: str,
language_type: str = None,
dataset: str = None,
):
"""
Checks whether the generated code is finished.
@@ -216,46 +148,10 @@ def is_code_generation_finished(
return False
def is_code_generation_finished_fix(
code: str,
language_type: str = None,
dataset: str = None,
):
"""
Checks whether the generated code is finished.
"""
if language_type is None or dataset is None:
return False
if "humaneval" in dataset.lower():
if language_type.lower() == "python":
for line in code.split("\n"):
if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
return True
end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
for w in end_words:
if w in code:
return True
elif language_type.lower() == "java":
if code.count("{") == code.count("}"):
return True
elif language_type.lower() == "go":
if code.count("{") == code.count("}"):
return True
elif language_type.lower() == "js":
if code.count("{") == code.count("}"):
return True
elif language_type.lower() == "cpp":
if code.count("{") == code.count("}"):
return True
return False
def cleanup_code(
code: str,
language_type: str = None,
dataset: str = None,
code: str,
language_type: str = None,
dataset: str = None,
):
"""
Cleans up the generated code.

@@ -0,0 +1,100 @@
import gzip
import json
from typing import *
LANGUAGE_TAG = {
"c++" : "// language: C++",
"cpp" : "// language: C++",
"c" : "// language: C",
"c#" : "// language: C#",
"cuda" : "// language: Cuda",
"objective-c" : "// language: Objective-C",
"objective-c++": "// language: Objective-C++",
"python" : "# language: Python",
"java" : "// language: Java",
"scala" : "// language: Scala",
"tex" : f"% language: TeX",
"html" : "<!--language: HTML-->",
"php" : "// language: PHP",
"js" : "// language: JavaScript",
"javascript" : "// language: JavaScript",
"typescript" : "// language: TypeScript",
"go" : "// language: Go",
"shell" : "# language: Shell",
"rust" : "// language: Rust",
"css" : "/* language: CSS */",
"sql" : "-- language: SQL",
"kotlin" : "// language: Kotlin",
"pascal" : "// language: Pascal",
"r" : "# language: R",
"fortran" : "!language: Fortran",
"lean" : "-- language: Lean",
}
def stream_jsonl(filename: str) -> Iterable[Dict]:
"""
Parses each jsonl line and yields it as a dictionary
"""
if filename.endswith(".gz"):
with open(filename, "rb") as gzfp:
with gzip.open(gzfp, "rt") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
else:
with open(filename, "r") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
"""
Writes an iterable of dictionaries to jsonl
"""
if append:
mode = "ab"
else:
mode = "wb"
filename = os.path.expanduser(filename)
if filename.endswith(".gz"):
with open(filename, mode) as fp:
with gzip.GzipFile(fileobj=fp, mode="wb") as gzfp:
for x in data:
gzfp.write((json.dumps(x) + "\n").encode("utf-8"))
else:
with open(filename, mode) as fp:
for x in data:
fp.write((json.dumps(x) + "\n").encode("utf-8"))
def sliding_window(
prompt_tokens: list,
code_tokens: list,
seq_len: int,
sliding_stride: int,
minimum_code_len: int = 1,
) -> Iterable[Tuple[list, list]]:
"""
Generate a series of (prompt, code) pairs by sliding the window over the code.
"""
prompt_len = len(prompt_tokens)
code_len = len(code_tokens)
total_len = prompt_len + code_len
start_idx = max(0, prompt_len - seq_len + minimum_code_len) # at least `minimum_code_len` code token should be in the window
end_idx = max(0, total_len - seq_len)
start_idx = min(start_idx, end_idx)
for i in range(start_idx, end_idx + 1, sliding_stride):
current_prompt = prompt_tokens[i:i + seq_len]
current_code = code_tokens[max(i - prompt_len, 0):i - prompt_len + seq_len]
yield current_prompt, current_code
if (end_idx - start_idx) % sliding_stride != 0:
current_prompt = prompt_tokens[end_idx:end_idx + seq_len]
current_code = code_tokens[max(end_idx - prompt_len, 0):end_idx - prompt_len + seq_len]
yield current_prompt, current_code
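A quick worked example of `sliding_window` (token ids are illustrative): with a 4-token prompt, 6 code tokens, `seq_len=5` and `sliding_stride=2`, it yields overlapping windows over the concatenated sequence, plus a final window because the stride does not land exactly on the last start position:

prompt = [1, 2, 3, 4]
code = [11, 12, 13, 14, 15, 16]
for p, c in sliding_window(prompt, code, seq_len=5, sliding_stride=2):
    print(p, c)
# [1, 2, 3, 4] [11]
# [3, 4] [11, 12, 13]
# [] [11, 12, 13, 14, 15]
# [] [12, 13, 14, 15, 16]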

@@ -0,0 +1,144 @@
import os
import glob
import fire
import torch
import multiprocessing
from typing import *
from tqdm.auto import tqdm
from time import perf_counter
from black import format_str, FileMode
from codegeex.data.types import PromptDataset, PromptSample
from codegeex.data.processor import PromptDatasetProcessor
from codegeex.data.data_utils import stream_jsonl, LANGUAGE_TAG
from codegeex.megatron.data.indexed_dataset import make_mmap_builder
from codegeex.tokenizer import CodeGeeXTokenizer
def try_format_code(code: str):
# Auto-correct to PEP8 format (Change tab to 4-whitespaces;
# add whitespace around some special symbols;
# reformat line length < 100, etc.)
try:
res = format_str(code, mode=FileMode(line_length=200))
except Exception as e:
res = code
print(e)
print("Wrong python format: {}".format(code))
return res
def load_pretrain_dataset(dataset_path: Union[str, List[str]]) -> Dict:
if type(dataset_path) is str:
dataset_path = [dataset_path]
for p in dataset_path:
if not os.path.isdir(p):
if p.endswith(".gz") or p.endswith(".jsonl"):
print(f"loading from {p}")
yield from stream_jsonl(p)
else:
p_list = glob.glob(p + "/*")
for p_ in p_list:
if p_.endswith(".gz") or p_.endswith(".jsonl"):
print(f"loading from {p_}")
yield from stream_jsonl(p_)
def process_sample(
sample: Dict,
language: str=None,
mode: str="pretrain",
) -> Iterable[PromptSample]:
if mode == "pretrain":
prompt = ""
else:
prompt = sample["prompt"]
try:
if language is not None and language in LANGUAGE_TAG.keys():
code = LANGUAGE_TAG[language] + sample["code"]
else:
code = sample["code"]
except Exception as e:
print(e)
print("The key 'code' is missing in data. Aborted")
exit(0)
yield PromptSample(prompt, code)
def generate_prompt_samples(
dataset: Iterable[Dict],
language: str = None,
mode: str = "pretrain",
) -> PromptDataset:
for sample in dataset:
yield from process_sample(sample, language, mode)
def main(
tokenizer_path: str,
dataset_path: Union[str, List[str]],
output_prefix: str,
language: str = None,
mode: str = "pretrain",
discard_overlong: bool = False,
sliding_stride: int = 200,
num_workers: int = 32,
seq_len: int = 2048,
):
DATA_KEYS = ["input_ids", "attention_mask", "labels"]
# create output dir
os.makedirs(os.path.dirname(output_prefix), exist_ok=True)
tokenizer = CodeGeeXTokenizer(tokenizer_path=tokenizer_path)
pad_token_id = tokenizer.eos_token_id
dataset = load_pretrain_dataset(dataset_path)
prompt_dataset = generate_prompt_samples(dataset, language=language, mode=mode)
if num_workers == 0:
num_workers = multiprocessing.cpu_count()
pool = multiprocessing.Pool(num_workers)
output_bin_files = {}
output_idx_files = {}
builders = {}
for key in DATA_KEYS:
output_bin_files[key] = "{}_{}.bin".format(output_prefix, key)
output_idx_files[key] = "{}_{}.idx".format(output_prefix, key)
builders[key] = make_mmap_builder(
output_bin_files[key],
vocab_size=None, # magic number, should change it
)
# NOTE that we use seq_len + 1 instead of seq_len, since the input tokens will be shifted by one.
processor = PromptDatasetProcessor(
tokenize=tokenizer.encode_code,
pad_token=pad_token_id,
max_seq_len=seq_len + 1,
discard_overlong=discard_overlong,
sliding_stride=sliding_stride,
eod_token=pad_token_id)
processor.start_time = perf_counter()
doc_iter = pool.imap_unordered(processor.process_sample_strict,
prompt_dataset,
chunksize=20)
for doc_idx, docs in tqdm(enumerate(doc_iter, start=1)):
processor.doc_processed += 1
for doc in docs:
processor.doc_generated += 1
for key in DATA_KEYS:
builders[key].add_item(torch.IntTensor(doc[key]))
for key in DATA_KEYS:
builders[key].finalize(output_idx_files[key])
if __name__ == "__main__":
fire.Fire(main)
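Because the entry point is wrapped with `fire.Fire`, the preprocessing script is driven from the command line; a hypothetical invocation (the script name and paths are placeholders, not taken from this diff):

python process_pretrain_dataset.py \
    --tokenizer_path /path/to/tokenizer \
    --dataset_path /path/to/jsonl_or_dir \
    --output_prefix /path/to/output/codegeex_train \
    --language python --mode pretrain --seq_len 2048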

@@ -0,0 +1,155 @@
from typing import *
from time import perf_counter
from codegeex.data.data_utils import sliding_window
from codegeex.data.types import PromptSample, LabelSample
class PromptDatasetProcessor(object):
def __init__(
self,
tokenize: Callable,
pad_token: int,
keep_order: bool = False,
max_seq_len: int = 2048,
sliding_stride: int = 200,
discard_overlong: bool = True,
eod_token: int = None,
preprocess: Callable = None,
):
super(PromptDatasetProcessor, self).__init__()
self._keep_order = keep_order
self._max_seq_len = max_seq_len
self._sliding_stride = sliding_stride
self._tokenize = tokenize
self._pad_token = pad_token
self._discard_overlong = discard_overlong
self._eod_token = eod_token
self._preprocess = preprocess
self.doc_processed = 0
self.doc_generated = 0
self.start_time = 0
def pad_seq(self, prompt_tokens: List[int], code_tokens: List[int], extra: dict = None) -> Dict[str, List[int]]:
total_length = len(prompt_tokens) + len(code_tokens)
assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}"
pad_len = self._max_seq_len - total_length
input_ids = prompt_tokens + code_tokens + [self._pad_token] * pad_len
attention_mask = [1] * len(prompt_tokens) + [1] * len(code_tokens) + [0] * pad_len
labels = [-100] * len(prompt_tokens) + code_tokens + [-100] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
def process_sample(self, sample: PromptSample) -> Iterable[Dict[str, List[int]]]:
"""
Process a sample.
"""
prompt_tokens = self._tokenize(sample.prompt)
code_tokens = self._tokenize(sample.code)
if self._eod_token is not None:
code_tokens.append(self._eod_token)
if len(prompt_tokens) + len(code_tokens) > self._max_seq_len:
if self._discard_overlong:
return
for p, t in sliding_window(prompt_tokens, code_tokens, self._max_seq_len, self._sliding_stride, self._sliding_stride):
yield self.pad_seq(p, t)
else:
yield self.pad_seq(prompt_tokens, code_tokens, extra=sample.extra)
def process_sample_strict(self, sample: PromptSample) -> List[Dict[str, List[int]]]:
"""
Instead of processing lazily, we turn the iterable into a list.
"""
return list(self.process_sample(sample))
def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
prompt_sample = self._preprocess(sample)
return self.process_sample_strict(prompt_sample)
def report(self):
duration = perf_counter() - self.start_time
process_speed = self.doc_processed * 1.0 / duration
gen_speed = self.doc_generated * 1.0 / duration
print(f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s")
print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s")
class LabelDatasetProcessor(object):
def __init__(
self,
tokenize: Callable,
pad_token: int,
keep_order: bool = False,
max_seq_len: int = 2048,
sliding_stride: int = 200,
discard_overlong: bool = True,
eod_token: int = None,
preprocess: Callable = None,
):
super(LabelDatasetProcessor, self).__init__()
self._keep_order = keep_order
self._max_seq_len = max_seq_len
self._sliding_stride = sliding_stride
self._tokenize = tokenize
self._pad_token = pad_token
self._discard_overlong = discard_overlong
self._eod_token = eod_token
self._preprocess = preprocess
self.doc_processed = 0
self.doc_generated = 0
self.start_time = 0
def pad_seq(self, prompt_tokens: List[int], label: int, extra: dict = None) -> Dict[str, List[int]]:
total_length = len(prompt_tokens)
assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}"
pad_len = self._max_seq_len - total_length
input_ids = prompt_tokens + [self._pad_token] * pad_len
attention_mask = [1] * len(prompt_tokens) + [0] * pad_len
label = [label]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"length": [len(prompt_tokens)],
"labels": label
}
def process_sample(self, sample: LabelSample) -> Iterable[Dict[str, List[int]]]:
"""
Process a sample.
"""
prompt_tokens = self._tokenize(sample.prompt)
label = sample.label
if len(prompt_tokens) > self._max_seq_len:
if self._discard_overlong:
return
prompt_tokens=prompt_tokens[-self._max_seq_len:]
yield self.pad_seq(prompt_tokens, label, extra=sample.extra)
def process_sample_strict(self, sample: LabelSample) -> List[Dict[str, List[int]]]:
"""
Instead of processing lazily, we turn the iterable into a list.
"""
return list(self.process_sample(sample))
def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
prompt_sample = self._preprocess(sample)
return self.process_sample_strict(prompt_sample)
def report(self):
duration = perf_counter() - self.start_time
process_speed = self.doc_processed * 1.0 / duration
gen_speed = self.doc_generated * 1.0 / duration
print(f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s")
print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s")

@@ -0,0 +1,20 @@
from typing import *
from dataclasses import dataclass
@dataclass
class PromptSample:
prompt: str
code: str
extra: dict = None
PromptDataset = Iterable[PromptSample]
@dataclass
class LabelSample:
prompt: str
label: int
extra: dict = None
LabelDataset = Iterable[LabelSample]

@@ -0,0 +1,99 @@
import pkg_resources
import torch
import ctypes
from typing import List
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
RESOURCE_PACKAGE_NAME = __name__
class Kernel:
def __init__(self, filename: str, function_names: List[str]):
filename = filename + ".fatbin"
if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
self.filename = filename
self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
self._function_names = function_names
self._cmodule = LazyKernelCModule(self.code)
for name in self._function_names:
setattr(self, name, KernelFunction(self._cmodule, name))
kernels = Kernel(
"quantization",
[
"int4WeightCompression",
"int4WeightExtractionFloat",
"int4WeightExtractionHalf",
"int8WeightExtractionFloat",
"int8WeightExtractionHalf",
],
)
def compress_int4_weight(weight: torch.Tensor): # (n, m)
with torch.cuda.device(weight.device):
n, m = weight.size(0), weight.size(1)
assert m % 2 == 0
m = m // 2
out = torch.empty(n, m, dtype=torch.int8, device="cuda")
stream = torch.cuda.current_stream()
gridDim = (n, 1, 1)
blockDim = (min(round_up(m, 32), 1024), 1, 1)
kernels.int4WeightCompression(
gridDim,
blockDim,
0,
stream,
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
)
return out
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
if source_bit_width == 8:
func = kernels.int8WeightExtractionHalf
elif source_bit_width == 4:
func = kernels.int4WeightExtractionHalf
else:
assert False, "Unsupported bit-width"
with torch.cuda.device(weight.device):
n, m = weight.size(0), weight.size(1)
out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
stream = torch.cuda.current_stream()
gridDim = (n, 1, 1)
blockDim = (min(round_up(m, 32), 1024), 1, 1)
func(
gridDim,
blockDim,
0,
stream,
[
ctypes.c_void_p(weight.data_ptr()),
ctypes.c_void_p(scale_list.data_ptr()),
ctypes.c_void_p(out.data_ptr()),
ctypes.c_int32(n),
ctypes.c_int32(m),
],
)
return out
if __name__ == "__main__":
weight = torch.randn(4, 32).to(torch.int8).cuda()
scale = torch.ones(weight.size(0)).to(torch.half).cuda()
print(weight)
b = compress_int4_weight(weight)
print(b)
a = extract_weight_to_half(b, scale, source_bit_width=4)
print(a)
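`compress_int4_weight` packs two signed 4-bit values into a single int8 byte on the GPU; the exact nibble layout lives in the compiled `quantization.fatbin` kernel and is not visible here. Purely to illustrate the idea (the even-column-high / odd-column-low layout below is an assumption, not necessarily what the kernel does):

import torch

def pack_int4_reference(weight: torch.Tensor) -> torch.Tensor:
    # weight: (n, m) int8 tensor whose values already fit the signed 4-bit range [-8, 7].
    n, m = weight.shape
    assert m % 2 == 0
    hi = weight[:, 0::2].to(torch.int16) & 0xF  # even columns -> high nibble (assumed)
    lo = weight[:, 1::2].to(torch.int16) & 0xF  # odd columns  -> low nibble (assumed)
    return ((hi << 4) | lo).to(torch.uint8)     # shape (n, m // 2)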

@@ -23,7 +23,7 @@ from glob import glob
import torch
from megatron import get_args, mpu, print_rank_0, update_num_microbatches, utils
from codegeex.megatron import get_args, mpu, print_rank_0, update_num_microbatches, utils
_CHECKPOINT_VERSION = None
@@ -82,7 +82,7 @@ def ensure_directory_exists(filename):
def get_checkpoint_name(checkpoints_path, iteration, release=False):
"""A unified checkpoint name."""
if release:
directory = "release"
directory = ""
else:
directory = "iter_{:07d}".format(iteration)
# Use both the tensor and pipeline MP rank.
@@ -90,16 +90,14 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False):
return os.path.join(
checkpoints_path,
directory,
"mp_rank_{:02d}".format(mpu.get_tensor_model_parallel_rank()),
"model_optim_rng.pt",
"mp_rank_{:02d}_model_states.pt".format(mpu.get_tensor_model_parallel_rank()),
)
return os.path.join(
checkpoints_path,
directory,
"mp_rank_{:02d}_{:03d}".format(
"mp_rank_{:02d}_{:03d}_model_states.pt".format(
mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank()
),
"model_optim_rng.pt",
)
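With this change checkpoint shards follow the DeepSpeed-style `mp_rank_XX_model_states.pt` naming, and release checkpoints sit directly under the checkpoint directory instead of a `release/` sub-folder. For tensor-parallel rank 0 without pipeline parallelism the paths look like this (directory names illustrative):

# before:  <ckpt_dir>/iter_0001000/mp_rank_00/model_optim_rng.pt
# after:   <ckpt_dir>/iter_0001000/mp_rank_00_model_states.pt
# release: <ckpt_dir>/mp_rank_00_model_states.pt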
@@ -300,7 +298,13 @@ def load_deepspeed_state(model):
model[0].load_state_dict(state_dict, strict=True)
def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load", strict=True):
def load_checkpoint(
model,
optimizer,
lr_scheduler,
load_arg="load",
strict=True,
):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
@@ -323,44 +327,53 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load", strict=True
else:
model = utils.unwrap_model(model)
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir)
# If no tracker file, return iteration zero.
if not os.path.isfile(tracker_filename):
print_rank_0(
"WARNING: could not find the metadata file {} ".format(tracker_filename)
)
print_rank_0(
" will not load any checkpoints and will start from " "random"
)
return 0
if load_dir.endswith(".pt"):
checkpoint_name = load_dir
release = True
else:
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir)
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration = 0
release = False
with open(tracker_filename, "r") as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == "release"
if not release:
# If no tracker file, return iteration zero.
if not os.path.isfile(tracker_filename):
print_rank_0(
"WARNING: could not find the metadata file {} ".format(tracker_filename)
)
iteration = 0
release = True
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if not os.path.isfile(checkpoint_name):
print_rank_0(
"ERROR: Invalid metadata file {}. Exiting".format(
tracker_filename
)
" will not load any checkpoints and will start from random"
)
sys.exit()
assert iteration > 0 or release, "error parsing metadata file {}".format(
tracker_filename
)
return 0
else:
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration = 0
release = False
with open(tracker_filename, "r") as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == "release"
if not release:
print_rank_0(
"ERROR: Invalid metadata file {}. Exiting".format(
tracker_filename
)
)
sys.exit()
assert iteration > 0 or release, "error parsing metadata file {}".format(
tracker_filename
)
# Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
print_rank_0(f" loading checkpoint from {args.load} at iteration {iteration}")
# Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
print_rank_0(f" loading checkpoint from {args.load} at iteration {iteration}")
# Load the checkpoint.
try:
@@ -423,12 +436,20 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load", strict=True
# Model.
if not args.deepspeed:
if len(model) == 1:
model[0].load_state_dict(state_dict["model"], strict=strict)
if release:
if len(model) == 1:
model[0].load_state_dict(state_dict["module"], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict["model%d" % i], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict["model%d" % i], strict=strict)
if len(model) == 1:
model[0].load_state_dict(state_dict["module"], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict["model%d" % i], strict=strict)
# Fix up query/key/value matrix ordering if needed
checkpoint_version = get_checkpoint_version()

@@ -0,0 +1,145 @@
"""Get model parallel partitions."""
import os
import re
import random
import sys
import numpy as np
import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from codegeex.megatron import get_args
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.checkpointing import ensure_directory_exists
def get_change_ckpt_args(parser):
"""Provide extra arguments required for merging."""
group = parser.add_argument_group(title='Mindspore to megatron')
group.add_argument(
'--load-ckpt-path',
type=str,
required=True,
help='path to load ".pt" checkpoint.',
)
group.add_argument(
'--save-ckpt-path',
type=str,
required=True,
help='dir to save converted checkpoints.',
)
group.add_argument(
'--target-tensor-model-parallel-size',
type=int,
default=2,
help='target tensor model parallel size',
)
return parser
def get_element_from_dict_by_path(d, path):
"""
Get element from dictionary by path. If element is not present, recursively add empty dictionaries.
Args:
d (dict): the dictionary to get the element from
path (list): the path to the element which is delimited by "."
"""
path = path.split(".")
for k in path:
if k not in d:
d[k] = {}
d = d[k]
return d
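A small usage example of `get_element_from_dict_by_path`: it walks (and creates) the nested dictionaries along the dotted path and returns the innermost one, so assigning into the returned dict mutates the full output state dict:

d = {}
emb = get_element_from_dict_by_path(d, "module.language_model.embedding.word_embeddings")
emb["weight"] = "sharded tensor goes here"
# d == {"module": {"language_model": {"embedding": {"word_embeddings": {"weight": ...}}}}}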
def main():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))
initialize_megatron(
extra_args_provider=get_change_ckpt_args,
args_defaults={
"tokenizer_type": "GPT2BPETokenizer",
"no_load_rng" : True,
"no_load_optim" : True,
},
)
args = get_args()
print(f"Load ckpt from {args.load_ckpt_path}...")
state_dict = torch.load(args.load_ckpt_path, map_location="cpu")
print(f"Spliting ckpt into {args.target_tensor_model_parallel_size} parts...")
output_state_dict = []
for i in range(args.target_tensor_model_parallel_size):
output_state_dict.append({})
print("Converting Embedding layers...")
word_embeddings = state_dict['module']['language_model']['embedding']['word_embeddings']['weight']
position_embeddings = state_dict['module']['language_model']['embedding']['position_embeddings']['weight']
out_word_embeddings = torch.chunk(word_embeddings, args.target_tensor_model_parallel_size, dim=0)
for i in range(args.target_tensor_model_parallel_size):
pos_emb_dict = get_element_from_dict_by_path(
output_state_dict[i], "module.language_model.embedding.position_embeddings"
)
pos_emb_dict["weight"] = position_embeddings
word_emb_dict = get_element_from_dict_by_path(
output_state_dict[i], "module.language_model.embedding.word_embeddings"
)
word_emb_dict["weight"] = out_word_embeddings[i]
print("Converting QueryEmbedding layers...")
query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
out_query_embeddings = torch.chunk(query_embeddings, args.target_tensor_model_parallel_size, dim=0)
for i in range(args.target_tensor_model_parallel_size):
query_emb_dict = get_element_from_dict_by_path(
output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
)
query_emb_dict["weight"] = out_query_embeddings[i]
print("Converting Transformer layers...")
for layer_name in state_dict['module']['language_model']['transformer'].keys():
params = state_dict['module']['language_model']['transformer'][layer_name]
if "layernorm" in layer_name:
pass
elif "attention" in layer_name and "weight" in layer_name:
if "dense" in layer_name:
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1)
else:
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
elif "weight" in layer_name and "dense" in layer_name:
if "h_to_4h" in layer_name:
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
else:
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1)
elif "bias" in layer_name:
if "dense" not in layer_name or "mlp" in layer_name:
if "4h_to_h" in layer_name:
pass
else:
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
for i in range(args.target_tensor_model_parallel_size):
params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
if type(params) is tuple:
params_dict[layer_name] = params[i]
else:
params_dict[layer_name] = params
os.makedirs(args.save_ckpt_path, exist_ok=True)
for rank in range(args.target_tensor_model_parallel_size):
save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
torch.save(output_state_dict[rank], save_ckpt_path)
print(f"Converted checkpoint saved in {save_ckpt_path}.")
if __name__ == '__main__':
main()
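The chunking dimensions above follow the usual Megatron tensor-parallel layout: column-parallel weights (the fused QKV projection and `dense_h_to_4h`) are split along dim 0, row-parallel weights (the attention output `dense` and `dense_4h_to_h`) along dim 1, and biases are only split where the corresponding weight is split along dim 0. A shape-level illustration with hidden size 1024 and two target ranks:

import torch

tp = 2
qkv_weight = torch.randn(3 * 1024, 1024)          # column-parallel: split the output dimension
out_weight = torch.randn(1024, 1024)              # row-parallel: split the input dimension
qkv_shards = torch.chunk(qkv_weight, tp, dim=0)   # 2 shards of shape (1536, 1024)
out_shards = torch.chunk(out_weight, tp, dim=1)   # 2 shards of shape (1024, 512)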

@@ -0,0 +1,69 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blendable dataset."""
import time
import torch
import numpy as np
from codegeex.megatron import print_rank_0
class BlendableDataset(torch.utils.data.Dataset):
def __init__(self, datasets, weights):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
self.size = 0
for dataset in self.datasets:
self.size += len(dataset)
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
# Build indices.
start_time = time.time()
assert num_datasets < 255
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
from megatron.data import helpers
helpers.build_blending_indices(
self.dataset_index,
self.dataset_sample_index,
weights,
num_datasets,
self.size,
torch.distributed.get_rank() == 0,
)
print_rank_0(
"> elapsed time for building blendable dataset indices: "
"{:.2f} (sec)".format(time.time() - start_time)
)
def __len__(self):
return self.size
def __getitem__(self, idx):
dataset_idx = self.dataset_index[idx]
sample_idx = self.dataset_sample_index[idx]
return self.datasets[dataset_idx][sample_idx]
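A minimal sketch of how `BlendableDataset` is used (toy lists stand in for the real GPT-style indexed datasets; building the blending indices additionally requires the compiled `megatron.data.helpers` extension and an initialized torch.distributed process group):

code_ds = list(range(800))
text_ds = list(range(200))
blended = BlendableDataset([code_ds, text_ds], weights=[0.8, 0.2])
sample = blended[0]   # the global index is routed to one of the underlying datasets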

@@ -0,0 +1,185 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataloaders."""
import torch
from codegeex.megatron import get_args
from codegeex.megatron import mpu
def build_pretraining_data_loader(dataset, consumed_samples):
"""Buld dataloader given an input dataset."""
if dataset is None:
return None
args = get_args()
# Megatron sampler
if args.dataloader_type == "single":
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
)
elif args.dataloader_type == "cyclic":
batch_sampler = MegatronPretrainingRandomSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
)
else:
raise Exception(
"{} dataloader type is not supported.".format(args.dataloader_type)
)
# Torch dataloader.
return torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
)
class MegatronPretrainingSampler:
def __init__(
self,
total_samples,
consumed_samples,
micro_batch_size,
data_parallel_rank,
data_parallel_size,
drop_last=True,
):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = (
self.micro_batch_size * data_parallel_size
)
self.drop_last = drop_last
# Sanity checks.
assert self.total_samples > 0, "no sample to consume: {}".format(
self.total_samples
)
assert (
self.consumed_samples < self.total_samples
), "no samples left to consume: {}, {}".format(
self.consumed_samples, self.total_samples
)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert (
self.data_parallel_rank < data_parallel_size
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
self.data_parallel_rank, data_parallel_size
)
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
def __iter__(self):
batch = []
# Last batch will be dropped if drop_last is not set False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
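Concretely, `MegatronPretrainingSampler` hands each data-parallel rank a contiguous slice of every global batch. With `total_samples=8`, `micro_batch_size=2` and `data_parallel_size=2`:

sampler = MegatronPretrainingSampler(
    total_samples=8, consumed_samples=0, micro_batch_size=2,
    data_parallel_rank=0, data_parallel_size=2,
)
print(list(sampler))   # [[0, 1], [4, 5]]; rank 1 would get [[2, 3], [6, 7]]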
class MegatronPretrainingRandomSampler:
def __init__(
self,
total_samples,
consumed_samples,
micro_batch_size,
data_parallel_rank,
data_parallel_size,
):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = (
self.micro_batch_size * data_parallel_size
)
self.last_batch_size = (
self.total_samples % self.micro_batch_times_data_parallel_size
)
# Sanity checks.
assert self.total_samples > 0, "no sample to consume: {}".format(
self.total_samples
)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert (
self.data_parallel_rank < data_parallel_size
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
self.data_parallel_rank, data_parallel_size
)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (
self.total_samples // self.micro_batch_times_data_parallel_size
) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []

@@ -0,0 +1,566 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Most of the code here has been copied from:
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import math
import os
import time
import collections
import numpy as np
import torch
from codegeex.megatron import mpu, print_rank_0
from codegeex.megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0] * num_datasets
prefixes = [0] * num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2 * i])
prefixes[i] = (data_prefix[2 * i + 1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[
int(math.ceil(val * weight * 1.005))
for val in train_valid_test_num_samples
]
)
return prefixes, weights, datasets_train_valid_test_num_samples
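For example, `data_prefix = ["3", "code", "1", "text"]` with `train_valid_test_num_samples = [1000, 100, 10]` normalizes the weights to [0.75, 0.25] and pads each per-split request by the 0.5% headroom factor:

prefixes, weights, nums = get_datasets_weights_and_num_samples(
    ["3", "code", "1", "text"], [1000, 100, 10]
)
# prefixes == ["code", "text"], weights == [0.75, 0.25]
# nums == [[754, 76, 8], [252, 26, 3]]   (each entry is ceil(n * weight * 1.005))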
def compile_helper():
"""Compile helper function ar runtime. Make sure this
is invoked on a single process."""
import os
import subprocess
path = os.path.abspath(os.path.dirname(__file__))
ret = subprocess.run(["make", "-C", path])
if ret.returncode != 0:
print("Making C++ dataset helpers module failed, exiting.")
import sys
sys.exit(1)
def get_a_and_b_segments(sample, np_rng):
"""Divide sample into a and b segments."""
# Number of sentences in the sample.
n_sentences = len(sample)
# Make sure we always have two sentences.
assert n_sentences > 1, "make sure each sample has at least two sentences."
# First part:
# `a_end` is how many sentences go into the `A`.
a_end = 1
if n_sentences >= 3:
# Note that randint in numpy is exclusive.
a_end = np_rng.randint(1, n_sentences)
tokens_a = []
for j in range(a_end):
tokens_a.extend(sample[j])
# Second part:
tokens_b = []
for j in range(a_end, n_sentences):
tokens_b.extend(sample[j])
# Random next:
is_next_random = False
if np_rng.random() < 0.5:
is_next_random = True
tokens_a, tokens_b = tokens_b, tokens_a
return tokens_a, tokens_b, is_next_random
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length."""
# print(len_a, len_b, max_num_tokens)
assert len_a > 0
if len_a + len_b <= max_num_tokens:
return False
while len_a + len_b > max_num_tokens:
if len_a > len_b:
len_a -= 1
tokens = tokens_a
else:
len_b -= 1
tokens = tokens_b
if np_rng.random() < 0.5:
del tokens[0]
else:
tokens.pop()
return True
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
tokens = []
tokentypes = []
# [CLS].
tokens.append(cls_id)
tokentypes.append(0)
# Segment A.
for token in tokens_a:
tokens.append(token)
tokentypes.append(0)
# [SEP].
tokens.append(sep_id)
tokentypes.append(0)
# Segment B.
for token in tokens_b:
tokens.append(token)
tokentypes.append(1)
if tokens_b:
# [SEP].
tokens.append(sep_id)
tokentypes.append(1)
return tokens, tokentypes
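A tiny example of the resulting layout: with `tokens_a=[11, 12]`, `tokens_b=[21]`, `cls_id=101` and `sep_id=102`, the merged sequence and its token types are:

tokens, tokentypes = create_tokens_and_tokentypes([11, 12], [21], cls_id=101, sep_id=102)
# tokens     == [101, 11, 12, 102, 21, 102]
# tokentypes == [0,   0,  0,  0,   1,  1]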
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
def is_start_piece(piece):
"""Check if the current word piece is the starting piece (BERT)."""
# When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
return not piece.startswith("##")
def create_masked_lm_predictions(
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
max_ngrams=3,
do_whole_word_mask=True,
favor_longer_ngram=False,
do_permutation=False,
geometric_dist=False,
masking_style="bert",
):
"""Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens."""
cand_indexes = []
# Note(mingdachen): We create a list for recording if the piece is
# the starting piece of current token, where 1 means true, so that
# on-the-fly whole word masking is possible.
token_boundary = [0] * len(tokens)
for (i, token) in enumerate(tokens):
if token == cls_id or token == sep_id:
token_boundary[i] = 1
continue
# Whole Word Masking means that we mask all of the wordpieces
# corresponding to an original word.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (
do_whole_word_mask
and len(cand_indexes) >= 1
and not is_start_piece(vocab_id_to_token_dict[token])
):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
if is_start_piece(vocab_id_to_token_dict[token]):
token_boundary[i] = 1
output_tokens = list(tokens)
masked_lm_positions = []
masked_lm_labels = []
if masked_lm_prob == 0:
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
num_to_predict = min(
max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))
)
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
if not geometric_dist:
# Note(mingdachen):
# By default, we set the probabilities to favor shorter ngram sequences.
pvals = 1.0 / np.arange(1, max_ngrams + 1)
pvals /= pvals.sum(keepdims=True)
if favor_longer_ngram:
pvals = pvals[::-1]
ngram_indexes = []
for idx in range(len(cand_indexes)):
ngram_index = []
for n in ngrams:
ngram_index.append(cand_indexes[idx : idx + n])
ngram_indexes.append(ngram_index)
np_rng.shuffle(ngram_indexes)
(masked_lms, masked_spans) = ([], [])
covered_indexes = set()
for cand_index_set in ngram_indexes:
if len(masked_lms) >= num_to_predict:
break
if not cand_index_set:
continue
# Note(mingdachen):
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes:
continue
if not geometric_dist:
n = np_rng.choice(
ngrams[: len(cand_index_set)],
p=pvals[: len(cand_index_set)]
/ pvals[: len(cand_index_set)].sum(keepdims=True),
)
else:
# Sampling "n" from the geometric distribution and clipping it to
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
n = min(np_rng.geometric(0.2), max_ngrams)
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# Note(mingdachen):
# Repeatedly looking for a candidate that does not exceed the
# maximum number of predictions by trying shorter ngrams.
while len(masked_lms) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
if masking_style == "bert":
# 80% of the time, replace with [MASK]
if np_rng.random() < 0.8:
masked_token = mask_id
else:
# 10% of the time, keep original
if np_rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_id_list[
np_rng.randint(0, len(vocab_id_list))
]
elif masking_style == "t5":
masked_token = mask_id
else:
raise ValueError("invalid value of masking style")
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
masked_spans.append(
MaskedLmInstance(
index=index_set, label=[tokens[index] for index in index_set]
)
)
assert len(masked_lms) <= num_to_predict
np_rng.shuffle(ngram_indexes)
select_indexes = set()
if do_permutation:
for cand_index_set in ngram_indexes:
if len(select_indexes) >= num_to_predict:
break
if not cand_index_set:
continue
# Note(mingdachen):
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes or index in select_indexes:
continue
n = np.random.choice(
ngrams[: len(cand_index_set)],
p=pvals[: len(cand_index_set)]
/ pvals[: len(cand_index_set)].sum(keepdims=True),
)
index_set = sum(cand_index_set[n - 1], [])
n -= 1
while len(select_indexes) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(select_indexes) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes or index in select_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
select_indexes.add(index)
assert len(select_indexes) <= num_to_predict
select_indexes = sorted(select_indexes)
permute_indexes = list(select_indexes)
np_rng.shuffle(permute_indexes)
orig_token = list(output_tokens)
for src_i, tgt_i in zip(select_indexes, permute_indexes):
output_tokens[src_i] = orig_token[tgt_i]
masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
# Sort the spans by the index of the first span
masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (
output_tokens,
masked_lm_positions,
masked_lm_labels,
token_boundary,
masked_spans,
)
def pad_and_convert_to_numpy(
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)
# Labels and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
print_rank_0(" > building dataset index ...")
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
print_rank_0(
" > finished creating indexed dataset in {:4f} "
"seconds".format(time.time() - start_time)
)
print_rank_0(" > indexed dataset stats:")
print_rank_0(
" number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1)
)
print_rank_0(" number of sentences: {}".format(indexed_dataset.sizes.shape[0]))
return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
"""Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(",") != -1:
splits = [float(s) for s in splits_string.split(",")]
elif splits_string.find("/") != -1:
splits = [float(s) for s in splits_string.split("/")]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.0)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
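For instance, the common "949,50,1" split over 1000 documents produces the boundary indices used to carve out the train/validation/test ranges:

print(get_train_valid_test_split_("949,50,1", 1000))
# [0, 949, 999, 1000]  -> train [0, 949), valid [949, 999), test [999, 1000)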
def get_samples_mapping(
indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
name,
binary_head,
):
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples " "or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += "_{}_indexmap".format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += "_{}ep".format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += "_{}mns".format(max_num_samples)
indexmap_filename += "_{}msl".format(max_seq_length)
indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob)
indexmap_filename += "_{}s".format(seed)
indexmap_filename += ".npy"
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
print(
" > WARNING: could not find index map file {}, building "
"the indices on rank 0 ...".format(indexmap_filename)
)
# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
assert indexed_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(" > building sapmles index mapping for {} ...".format(name))
# First compile and then import.
from megatron.data import helpers
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
indexed_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
verbose,
2 if binary_head else 1,
)
print_rank_0(" > done building sapmles index maping")
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(" > saved the index mapping in {}".format(indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(
" > elasped time to build and save samples mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size()
// torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())
)
# Load indexed dataset.
print_rank_0(" > loading indexed mapping from {}".format(indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r")
print_rank_0(
" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
)
print_rank_0(" total number of samples: {}".format(samples_mapping.shape[0]))
return samples_mapping
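For reference, a sketch of the cache-file naming used above, assuming num_epochs is given and max_num_samples is left at its sentinel (the prefix and values below are illustrative, not real paths):

data_prefix, name = "/data/corpus", "train"          # hypothetical prefix
num_epochs, max_seq_length, short_seq_prob, seed = 3, 512, 0.10, 1234
indexmap_filename = (
    f"{data_prefix}_{name}_indexmap_{num_epochs}ep"
    f"_{max_seq_length}msl_{short_seq_prob:0.2f}ssp_{seed}s.npy"
)
# -> /data/corpus_train_indexmap_3ep_512msl_0.10ssp_1234s.npy
# Rank 0 writes this file once; every rank then memory-maps it.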

@ -0,0 +1,717 @@
/*
coding=utf-8
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/* Helper methods for fast index mapping builds */
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
namespace py = pybind11;
using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
py::array_t<int64_t>& dataset_sample_index,
const py::array_t<double>& weights,
const int32_t num_datasets,
const int64_t size, const bool verbose) {
/* Given multiple datasets and a weighting array, build samples
such that it follows those weights.*/
if (verbose) {
std::cout << "> building indices for blendable datasets ..." << std::endl;
}
// Get the pointer access without the checks.
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
auto weights_ptr = weights.unchecked<1>();
// Initialize buffer for number of samples used for each dataset.
int64_t current_samples[num_datasets];
for(int64_t i = 0; i < num_datasets; ++i) {
current_samples[i] = 0;
}
// For each sample:
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
// Determine where the max error in sampling is happening.
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
int64_t max_error_index = 0;
double max_error = weights_ptr[0] * sample_idx_double -
static_cast<double>(current_samples[0]);
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
double error = weights_ptr[dataset_idx] * sample_idx_double -
static_cast<double>(current_samples[dataset_idx]);
if (error > max_error) {
max_error = error;
max_error_index = dataset_idx;
}
}
// Populate the indices.
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
// Update the total samples.
current_samples[max_error_index] += 1;
}
// print info
if (verbose) {
std::cout << " > sample ratios:" << std::endl;
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
static_cast<double>(size);
std::cout << " dataset " << dataset_idx << ", input: " <<
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
}
}
}
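The same greedy rule in pure Python, useful for cross-checking the compiled helper (a sketch only, not the function exported by this module):

import numpy as np

def blend_indices_py(weights, size):
    # At every step pick the dataset whose achieved sample count lags its
    # target share (weight * step) the most, mirroring build_blending_indices.
    weights = np.asarray(weights, dtype=np.float64)
    current = np.zeros(len(weights), dtype=np.int64)
    dataset_index = np.empty(size, dtype=np.uint8)
    dataset_sample_index = np.empty(size, dtype=np.int64)
    for sample_idx in range(size):
        errors = weights * max(float(sample_idx), 1.0) - current
        choice = int(np.argmax(errors))
        dataset_index[sample_idx] = choice
        dataset_sample_index[sample_idx] = current[choice]
        current[choice] += 1
    return dataset_index, dataset_sample_index

# blend_indices_py([0.7, 0.3], 10) interleaves the two datasets roughly 7:3.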
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch) {
/* Sample index (sample_idx) is used for GPT-2-like datasets for which
the documents are flattened and the samples are built based on this
1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2]
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
starting offset in that document.*/
// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);
// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();
// Mapping and its length (1D).
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
cout << " using:" << endl << std::flush;
cout << " number of documents: " <<
doc_idx_.shape(0) / num_epochs << endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " sequence length: " << seq_length <<
endl << std::flush;
cout << " total number of samples: " << num_samples <<
endl << std::flush;
// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
// Beginning offset for each document.
int32_t doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
while (sample_index <= num_samples) {
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + 1;
while (remaining_seq_length != 0) {
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= doc_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0) {
doc_offset += (remaining_seq_length + doc_length - 1);
remaining_seq_length = 0;
} else {
// Otherwise, start from the beginning of the next document.
++doc_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
}
// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void *mem_) {
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(int32_t);
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
{2*byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done); // numpy array references
}
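How a GPT-style dataset typically consumes the returned array: rows i and i+1 bracket sample i. The sketch below assumes a per-document token store and the doc_idx array passed in above; the names are illustrative, not part of this module.

import numpy as np

def get_sample_tokens(tokens_by_doc, doc_idx, sample_idx, i):
    # Sample i spans (doc_f, offset_f) .. (doc_l, offset_l) inclusive,
    # as encoded in rows i and i+1 of sample_idx.
    doc_f, offset_f = sample_idx[i]
    doc_l, offset_l = sample_idx[i + 1]
    if doc_f == doc_l:
        return tokens_by_doc[doc_idx[doc_f]][offset_f : offset_l + 1]
    pieces = [tokens_by_doc[doc_idx[doc_f]][offset_f:]]
    for d in range(doc_f + 1, doc_l):
        pieces.append(tokens_by_doc[doc_idx[d]])
    pieces.append(tokens_by_doc[doc_idx[doc_l]][: offset_l + 1])
    return np.concatenate(pieces)  # seq_length + 1 tokens (inputs plus shifted labels)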
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937& rand32_gen) {
/* Training sample length. */
if (short_seq_ratio == 0) {
return max_length;
}
const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0) {
return 2 + random_number % (max_length - 1);
}
return max_length;
}
template<typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const double short_seq_prob,
const int32_t seed,
const bool verbose,
const int32_t min_num_sent) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int.
int32_t short_seq_ratio = 0;
if (short_seq_prob > 0) {
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " short sequence probability: " << short_seq_prob <<
endl << std::flush;
cout << " short sequence ration (1/prob): " << short_seq_ratio <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the seed so both iterations produce the same results.
std::mt19937 rand32_gen(seed);
// Set the flag on second iteration.
second = (iteration == 1);
// Counters:
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// Current map index.
uint64_t map_index = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
// At the beginning of the document the previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining sentences in the document.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent > 1) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have enough sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
auto target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length,
// and more than one sentence is left in the document,
// and we have at least the minimum number of sentences;
// or if we have reached the end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow.
if ((3 * map_index + 2) >
std::numeric_limits<int64_t>::max()) {
cout << "number of samples exceeded maximum "
<< "allowed by type int64: "
<< std::numeric_limits<int64_t>::max()
<< endl;
throw std::overflow_error("Number of samples");
}
// Populate the map.
if (second) {
const auto map_index_0 = 3 * map_index;
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
}
// Update indices / counters.
++map_index;
prev_start_index = sent_index + 1;
target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[3*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 3 * i;
const auto j0 = 3 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
{3*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
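Each row of the returned array is [start_sentence_index, end_sentence_index (exclusive), target_sequence_length]. A toy illustration of how a caller reads it (the values are made up):

import numpy as np

# Two samples: sentences [0, 3) built toward length 128, sentences [3, 7)
# built toward length 512.
samples_mapping = np.array([[0, 3, 128], [3, 7, 512]], dtype=np.int64)
start, end, target_len = samples_mapping[1]
# A BERT-style dataset concatenates sentences start..end-1 and
# truncates/pads the result toward target_len (here 512).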
py::array build_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const double short_seq_prob,
const int seed,
const bool verbose,
const int32_t min_num_sent) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
}
}
template<typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& titles_sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const int32_t seed,
const bool verbose,
const bool use_one_sent_blocks) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
auto titles_sizes = titles_sizes_.unchecked<1>();
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Acceptable number of sentences per block.
int min_num_sent = 2;
if (use_one_sent_blocks) {
min_num_sent = 1;
}
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the flag on second iteration.
second = (iteration == 1);
// Current map index.
uint64_t map_index = 0;
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
// assign every block a unique id
int32_t block_id = 0;
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
const auto target_seq_len = max_seq_length - titles_sizes[doc];
// At the beginning of the document the previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining sentences in the document.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent >= min_num_sent) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have enough sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length,
// and an acceptable number of sentences are left,
// and we have at least the minimum number of sentences;
// or if we have reached the end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent >= min_num_sent) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Populate the map.
if (second) {
const auto map_index_0 = 4 * map_index;
// Each sample has 4 items: the starting sentence index, ending sentence index,
// the index of the document from which the block comes (used for fetching titles)
// and the unique id of the block (used for creating block indexes)
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
}
// Update indices / counters.
++map_index;
++block_id;
prev_start_index = sent_index + 1;
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[4*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 4 * i;
const auto j0 = 4 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
swap(maps[i0 + 3], maps[j0 + 3]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
{4*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const py::array_t<int>& titles_sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const int seed,
const bool verbose,
const bool use_one_sent_blocks) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
}
}
PYBIND11_MODULE(helpers, m) {
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx", &build_sample_idx);
m.def("build_blending_indices", &build_blending_indices);
}
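The Python side notes "first compile and then import". One generic way to build a pybind11 module like this one is the pybind11 setuptools helper shown below; this is a sketch of a standard approach, not necessarily the build the repository itself uses (it may ship its own Makefile).

# setup.py sketch (assumed, generic pybind11 build)
from setuptools import setup
from pybind11.setup_helpers import Pybind11Extension, build_ext

setup(
    name="helpers",
    ext_modules=[Pybind11Extension("helpers", ["helpers.cpp"], cxx_std=14)],
    cmdclass={"build_ext": build_ext},
)
# Then: python setup.py build_ext --inplace, after which the compiled module
# can be imported (here it must end up importable as megatron.data.helpers).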

@ -0,0 +1,595 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.
import os
import shutil
import struct
import torch
import numpy as np
from functools import lru_cache
from itertools import accumulate
from codegeex.megatron import print_rank_0
def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
return np.uint16
else:
return np.int32
def get_available_dataset_impl():
return ["lazy", "cached", "mmap"]
def infer_dataset_impl(path):
if IndexedDataset.exists(path):
with open(index_file_path(path), "rb") as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return "cached"
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return "mmap"
else:
return None
else:
print(f"Dataset does not exist: {path}")
print(
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
)
return None
def make_builder(out_file, impl, vocab_size=None):
if impl == "mmap":
return MMapIndexedDatasetBuilder(
out_file, dtype=__best_fitting_dtype(vocab_size)
)
else:
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print(
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
)
return None
if impl == "infer":
impl = infer_dataset_impl(path)
if impl == "lazy" and IndexedDataset.exists(path):
return IndexedDataset(path)
elif impl == "cached" and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == "mmap" and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None
def dataset_exists(path, impl):
if impl == "mmap":
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float32,  # np.float was removed from NumPy; float32 matches the 4-byte element size
7: np.double,
8: np.uint16,
}
def make_mmap_builder(out_file, vocab_size=None):
return MMapIndexedDatasetBuilder(
out_file, dtype=__best_fitting_dtype(vocab_size)
)
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + ".idx"
def data_file_path(prefix_path):
return prefix_path + ".bin"
def create_doc_idx(sizes):
doc_idx = [0]
for i, s in enumerate(sizes):
if s == 0:
doc_idx.append(i + 1)
return doc_idx
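A worked example of the zero-size sentinel convention (sizes are illustrative):

# A size of 0 marks the end of a document; the next document starts
# at the following index.
assert create_doc_idx([5, 0, 7, 0, 2]) == [0, 2, 4]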
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b"TNTIDX\x00\x00"
def __init__(self, path):
super().__init__()
self.path = path
self.data_file = None
self.read_index(path)
def read_index(self, path):
with open(index_file_path(path), "rb") as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
version = f.read(8)
assert struct.unpack("<Q", version) == (1,)
code, self.element_size = struct.unpack("<QQ", f.read(16))
self.dtype = dtypes[code]
self._len, self.s = struct.unpack("<QQ", f.read(16))
self.doc_count = struct.unpack("<Q", f.read(8))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
self.doc_idx = read_longs(f, self.doc_count)
def read_data(self, path):
self.data_file = open(data_file_path(path), "rb", buffering=0)
def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError("index out of range")
def __del__(self):
if self.data_file:
self.data_file.close()
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if not self.data_file:
self.read_data(self.path)
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return a
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
size = sum(sizes)
a = np.empty(size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[start] * self.element_size)
self.data_file.readinto(a)
offsets = list(accumulate(sizes))
sents = np.split(a, offsets[:-1])
return sents
def __len__(self):
return self._len
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(
data_file_path(path)
)
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path):
super().__init__(path)
self.cache = None
self.cache_index = {}
@property
def supports_prefetch(self):
return True
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
self.cache = np.empty(total_size, dtype=self.dtype)
ptx = 0
self.cache_index.clear()
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx : ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
if self.data_file:
# close and delete data file after prefetch so we can pickle
self.data_file.close()
self.data_file = None
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx : ptx + a.size])
return a
elif isinstance(idx, slice):
# Hack just to make this work; can optimize later if necessary
sents = []
for i in range(*idx.indices(len(self))):
sents.append(self[i])
return sents
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float32: 4,
np.double: 8,
}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, "wb")
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
self.doc_idx = [0]
def add_item(self, tensor):
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def end_document(self):
self.doc_idx.append(len(self.sizes))
def merge_file_(self, another_file):
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
begin = self.data_offsets[-1]
for offset in index.data_offsets[1:]:
self.data_offsets.append(begin + offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
with open(data_file_path(another_file), "rb") as f:
while True:
data = f.read(1024)
if data:
self.out_file.write(data)
else:
break
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, "wb")
index.write(b"TNTIDX\x00\x00")
index.write(struct.pack("<Q", 1))
index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack("<Q", len(self.doc_idx)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
write_longs(index, self.doc_idx)
index.close()
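An end-to-end round trip with the builder above, writing two documents of token ids and reading them back lazily (file names are illustrative):

import numpy as np
import torch

builder = IndexedDatasetBuilder("toy.bin", dtype=np.int32)
builder.add_item(torch.tensor([10, 11, 12]))
builder.end_document()
builder.add_item(torch.tensor([20, 21]))
builder.end_document()
builder.finalize("toy.idx")

ds = IndexedDataset("toy")                 # expects toy.idx / toy.bin side by side
print(len(ds), ds[0], list(ds.doc_idx))    # 2  [10 11 12]  [0, 1, 2]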
def _warmup_mmap_file(path):
with open(path, "rb") as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, "wb")
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack("<Q", 1))
self._file.write(struct.pack("<B", code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack("<Q", len(sizes)))
self._file.write(struct.pack("<Q", len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order="C"))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state, skip_warmup=True)  # only the path is pickled; skip warmup on restore
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(
data_file_path(self._path), mode="r", order="C"
)
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
"""Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(
data_file_path(path)
)
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, "wb")
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]
def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order="C"))
self._sizes.append(np_array.size)
def end_document(self):
self._doc_idx.append(len(self._sizes))
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
for size in index.sizes:
self._sizes.append(size)
# Concatenate data
with open(data_file_path(another_file), "rb") as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes, self._doc_idx)
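The same round trip for the mmap format, going through the factory functions defined earlier in this file (file names and vocab size are illustrative):

import torch

builder = make_builder("corpus.bin", impl="mmap", vocab_size=52000)  # uint16 tokens
builder.add_item(torch.tensor([101, 7592, 102]))
builder.end_document()
builder.finalize("corpus.idx")

ds = make_dataset("corpus", impl="mmap", skip_warmup=True)
print(ds[0])                          # array([ 101, 7592,  102], dtype=uint16)
print(ds.get(0, offset=1, length=2))  # partial read without slicing: [7592, 102]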

@ -0,0 +1,332 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT prompting dataset."""
import os
import time
import numpy as np
import torch
from codegeex.megatron import mpu, print_rank_0, get_tokenizer
from codegeex.megatron.data.blendable_dataset import BlendableDataset
from codegeex.megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from codegeex.megatron.data.dataset_utils import get_train_valid_test_split_
from codegeex.megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
def build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
seq_length,
seed,
skip_warmup,
):
"""Build train, valid, and test datasets."""
# Single dataset.
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(
data_prefix[0],
data_impl,
splits_string,
train_valid_test_num_samples,
seq_length,
seed,
skip_warmup,
)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(
data_prefix, train_valid_test_num_samples
)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i],
data_impl,
splits_string,
datasets_train_valid_test_num_samples[i],
seq_length,
seed,
skip_warmup,
)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
def _build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
seq_length,
seed,
skip_warmup,
):
"""Build train, valid, and test datasets."""
# Indexed dataset.
assert os.path.exists(data_prefix + "_input_ids.bin"), f"Input tokens datafile not found: {data_prefix}_input_ids.bin"
assert os.path.exists(data_prefix + "_attention_mask.bin"), f"Attention mask datafile not found: {data_prefix}_attention_mask.bin"
assert os.path.exists(data_prefix + "_labels.bin"), f"Labels datafile not found: {data_prefix}_labels.bin"
input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup)
attention_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_attention_mask", data_impl, skip_warmup)
labels_indexed_dataset = get_indexed_dataset_(data_prefix + "_labels", data_impl, skip_warmup)
total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(" > dataset split:")
def print_split_stats(name, index):
print_rank_0(" {}:".format(name))
print_rank_0(
" document indices in [{}, {}) total of {} "
"documents".format(
splits[index], splits[index + 1], splits[index + 1] - splits[index]
)
)
print_split_stats("train", 0)
print_split_stats("validation", 1)
print_split_stats("test", 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(
start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
)
dataset = PromptDataset(
name,
data_prefix,
documents,
input_ids_indexed_dataset,
attention_mask_indexed_dataset,
labels_indexed_dataset,
train_valid_test_num_samples[index],
seq_length,
seed,
)
return dataset
train_dataset = build_dataset(0, "train")
valid_dataset = build_dataset(1, "valid")
test_dataset = build_dataset(2, "test")
print_rank_0(f"train_dataset:{type(train_dataset)}")
print_rank_0(f"valid_dataset:{type(valid_dataset)}")
print_rank_0(f"test_dataset:{type(test_dataset)}")
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
"""Build indexed dataset."""
print_rank_0(" > building dataset index ...")
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
print_rank_0(
" > finished creating indexed dataset in {:4f} "
"seconds".format(time.time() - start_time)
)
print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0]))
return indexed_dataset
class PromptDataset(torch.utils.data.Dataset):
def __init__(
self,
name,
data_prefix,
documents,
input_ids_indexed_dataset,
attention_mask_index_dataset,
labels_indexed_dataset,
num_samples,
seq_length,
seed,
):
"""
Args:
name: name of the dataset.
data_prefix: prefix of the data.
documents: list of document indices.
input_ids_indexed_dataset: indexed dataset for input token ids.
attention_mask_index_dataset: indexed dataset for attention masks.
labels_indexed_dataset: indexed dataset for labels.
num_samples: number of samples to draw from the indexed dataset.
seq_length: sequence length.
seed: seed for random number generator.
"""
self.name = name
self.input_ids_indexed_dataset = input_ids_indexed_dataset
self.attention_mask_index_dataset = attention_mask_index_dataset
self.labels_indexed_dataset = labels_indexed_dataset
self.seq_length = seq_length
self.eod_token = get_tokenizer().eod
# Checks
assert np.min(documents) >= 0
assert np.max(documents) < input_ids_indexed_dataset.sizes.shape[0]
assert input_ids_indexed_dataset.sizes.shape[0] == attention_mask_index_dataset.sizes.shape[0]
assert attention_mask_index_dataset.sizes.shape[0] == labels_indexed_dataset.sizes.shape[0]
# Build index mappings.
self.doc_idx = _build_index_mappings(
self.name,
data_prefix,
documents,
self.input_ids_indexed_dataset.sizes,
num_samples,
seq_length,
seed,
)
def __len__(self):
# each sample maps to exactly one document index in doc_idx
return self.doc_idx.shape[0]
def __getitem__(self, idx):
# get the doc index
doc_idx = self.doc_idx[idx]
doc_idx = int(doc_idx) # NumPy int => Python int
input_ids = self.input_ids_indexed_dataset[doc_idx]
# print_rank_0(f"input_ids={input_ids}")
attention_mask = self.attention_mask_index_dataset[doc_idx]
labels = self.labels_indexed_dataset[doc_idx]
res = {
"input_ids": np.array(input_ids, dtype=np.int64),
"attention_mask": np.array(attention_mask, dtype=np.int64),
"labels": np.array(labels, dtype=np.int64),
}
return res
def _build_index_mappings(
name, data_prefix, documents, sizes, num_samples, seq_length, seed,
):
"""Build index mappings.
We only have to build the doc-idx mapping for the prompt dataset.
Args:
name: name of the dataset.
data_prefix: prefix of the data.
documents: list of document indices.
sizes: sizes of the indexed dataset.
num_samples: number of samples to draw from the indexed dataset.
seq_length: sequence length.
seed: seed for random number generator.
"""
num_epochs = _num_epochs(documents.shape[0], num_samples)
np_rng = np.random.RandomState(seed=seed)
_filename = data_prefix
_filename += "_{}_indexmap".format(name)
_filename += "_{}ns".format(num_samples)
_filename += "_{}sl".format(seq_length)
_filename += "_{}s".format(seed)
doc_idx_filename = _filename + "_doc_idx.npy"
if torch.distributed.get_rank() == 0:
if not os.path.isfile(doc_idx_filename):
print_rank_0(
" > WARNING: could not find index map files, building "
"the indices on rank 0 ..."
)
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng, False)[:num_samples]
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(
" > elasped time to build and save doc-idx mapping "
"(seconds): {:4f}".format(time.time() - start_time)
)
# This should be a barrier, but the NCCL barrier assumes
# device_index=rank, which does not hold in the model-parallel case.
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size()
// torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())
)
# Load mappings.
start_time = time.time()
print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename))
doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r")
print_rank_0(" total number of samples: {}".format(doc_idx.shape[0]))
print_rank_0(" total number of epochs: {}".format(num_epochs))
return doc_idx
def _num_epochs(samples_per_epoch, num_samples):
"""Calculate the epoch needed for so many sample."""
return int(np.ceil(num_samples / samples_per_epoch))
def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
"""Build an array with length = number-of-epochs * number-of-dcuments.
Each index is mapped to a corresponding document."""
if not separate_last_epoch or num_epochs == 1:
doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1]
doc_idx[:] = documents
doc_idx = doc_idx.reshape(-1)
doc_idx = doc_idx.astype(np.int32)
np_rng.shuffle(doc_idx)
return doc_idx
doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
return np.concatenate((doc_idx_first, doc_idx_last))
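For reference, the dataset above expects three parallel indexed datasets sharing one prefix (<prefix>_input_ids, <prefix>_attention_mask, <prefix>_labels), and _build_doc_idx simply repeats and shuffles the document ids. A small check of the latter (values illustrative):

import numpy as np

rng = np.random.RandomState(1234)
doc_idx = _build_doc_idx(np.arange(4, dtype=np.int32), num_epochs=2,
                         np_rng=rng, separate_last_epoch=False)
# Every document index appears exactly num_epochs times, in shuffled order.
assert sorted(doc_idx.tolist()) == [0, 0, 1, 1, 2, 2, 3, 3]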

@ -0,0 +1,31 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2

@ -22,6 +22,8 @@ import torch
from codegeex.megatron.tokenizer import build_tokenizer
from codegeex.megatron.arguments import parse_args
from codegeex.megatron.microbatches import build_num_microbatches_calculator
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@ -82,6 +84,7 @@ def set_global_variables(
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args,
)
_build_num_microbatches_calculator(args)
if args.vocab_file or args.tokenizer_path:
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
@ -101,6 +104,16 @@ def _parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False
return _GLOBAL_ARGS
def _build_num_microbatches_calculator(args):
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_ensure_var_is_not_initialized(
_GLOBAL_NUM_MICROBATCHES_CALCULATOR, "num microbatches calculator"
)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(args)
def _build_tokenizer(args):
"""Initialize tokenizer."""
global _GLOBAL_TOKENIZER

@ -0,0 +1,194 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Learning rate decay functions."""
import math
from codegeex.megatron import print_rank_0, get_args
class AnnealingLR(object):
"""Anneals the learning rate."""
def __init__(
self,
optimizer,
max_lr,
min_lr,
warmup_steps,
decay_steps,
decay_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False,
):
args = get_args()
# Class values.
self.optimizer = optimizer
self.max_lr = float(max_lr)
self.min_lr = min_lr
assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr
self.warmup_steps = warmup_steps
self.num_steps = 0
self.decay_steps = decay_steps
assert self.decay_steps > 0
assert self.warmup_steps < self.decay_steps
self.decay_tokens = args.lr_decay_tokens
self.num_tokens = 0
self.warmup_tokens = 0
self.decay_style = decay_style
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, (
"both override and " "use-checkpoint are set."
)
# Set the learning rate
self.step(0)
print_rank_0("> learning rate decay style: {}".format(self.decay_style))
def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
# Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
if self.num_steps == self.warmup_steps and self.decay_tokens is not None:
self.warmup_tokens = self.num_tokens
return self.max_lr * float(self.num_steps) / float(self.warmup_steps)
# If the learning rate is constant, just return the initial value.
if self.decay_style == "constant":
return self.max_lr
if self.decay_tokens is None:
# step-based decay
# For any steps larger than `self.decay_steps`, use `self.min_lr`.
if self.num_steps > self.decay_steps:
return self.min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = self.num_steps - self.warmup_steps
decay_steps_ = self.decay_steps - self.warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
else:
# token-based decay
if self.num_tokens > self.decay_tokens:
return self.min_lr
num_tokens_ = self.num_tokens - self.warmup_tokens
decay_tokens_ = self.decay_tokens - self.warmup_tokens
decay_ratio = float(num_tokens_) / float(decay_tokens_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr
if self.decay_style == "linear":
coeff = 1.0 - decay_ratio
elif self.decay_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else:
raise Exception("{} decay style is not supported.".format(self.decay_style))
return self.min_lr + coeff * delta_lr
def step(self, increment, token_num=None):
"""Set lr for all parameters groups."""
if token_num is None:
args = get_args()
token_num = args.consumed_train_tokens
self.num_tokens = token_num
self.num_steps += increment
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group["lr"] = new_lr
def state_dict(self):
state_dict = {
"max_lr": self.max_lr,
"warmup_steps": self.warmup_steps,
"num_steps": self.num_steps,
"warmup_tokens": self.warmup_tokens,
"num_tokens": self.num_tokens,
"decay_style": self.decay_style,
"decay_steps": self.decay_steps,
"min_lr": self.min_lr,
}
return state_dict
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
if self.override_lr_scheduler:
print_rank_0(" > overriding {} value to {}".format(name, cls_value))
return cls_value
if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, (
f"AnnealingLR: class input value {cls_value} and checkpoint"
f"value {sd_value} for {name} do not match"
)
print_rank_0(" > using checkpoint value {} for {}".format(sd_value, name))
return sd_value
def load_state_dict(self, sd):
if "start_lr" in sd:
max_lr_ = sd["start_lr"]
else:
max_lr_ = sd["max_lr"]
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
self.min_lr = self._check_and_set(
self.min_lr, sd["min_lr"], "minimum learning rate"
)
if "warmup_iter" in sd:
warmup_steps_ = sd["warmup_iter"]
else:
warmup_steps_ = sd["warmup_steps"]
self.warmup_steps = self._check_and_set(
self.warmup_steps, warmup_steps_, "warmup iterations"
)
if "end_iter" in sd:
decay_steps_ = sd["end_iter"]
else:
decay_steps_ = sd["decay_steps"]
self.decay_steps = self._check_and_set(
self.decay_steps, decay_steps_, "total number of iterations"
)
self.decay_style = self._check_and_set(
self.decay_style, sd["decay_style"], "decay style"
)
if "num_iters" in sd:
num_steps = sd["num_iters"]
else:
num_steps = sd["num_steps"]
if "warmup_tokens" in sd:
self.warmup_tokens = sd["warmup_tokens"]
if "num_tokens" in sd:
self.num_tokens = sd["num_tokens"]
self.step(num_steps, self.num_tokens)
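The step-based branch of get_lr in isolation, for eyeballing a schedule without building the optimizer or args (a sketch with illustrative constants):

import math

def cosine_lr(step, max_lr=1e-4, min_lr=1e-5, warmup_steps=100, decay_steps=1000):
    # Linear warmup, then cosine decay from max_lr to min_lr, as in AnnealingLR.
    if warmup_steps > 0 and step <= warmup_steps:
        return max_lr * step / warmup_steps
    if step > decay_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)

# cosine_lr(50) == 5e-5 (mid-warmup); cosine_lr(1000) == 1e-5 (fully decayed).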

@ -0,0 +1,186 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron number of micro-batches calculators."""
from abc import ABC
from abc import abstractmethod
def build_num_microbatches_calculator(args):
# Constant num micro-batches.
if args.rampup_batch_size is None:
num_microbatches_calculator = ConstantNumMicroBatches(
args.global_batch_size, args.micro_batch_size, args.data_parallel_size
)
if args.rank == 0:
print(
"setting number of micro-batches to constant {}".format(
num_microbatches_calculator.get()
),
flush=True,
)
else:
assert len(args.rampup_batch_size) == 3, (
"expected the following "
"format: --rampup-batch-size <start batch size> "
"<batch size incerement> <ramp-up samples>"
)
start_batch_size = int(args.rampup_batch_size[0])
batch_size_increment = int(args.rampup_batch_size[1])
ramup_samples = int(args.rampup_batch_size[2])
if args.rank == 0:
print(
"will use batch size rampup starting from global batch "
"size {} to global batch size {} with batch size increments "
"{} over {} samples.".format(
start_batch_size,
args.global_batch_size,
batch_size_increment,
ramup_samples,
),
flush=True,
)
num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
start_batch_size,
batch_size_increment,
ramup_samples,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
)
return num_microbatches_calculator
class NumMicroBatchesCalculator(ABC):
def __init__(self):
self.num_micro_batches = None
self.current_global_batch_size = None
def get(self):
return self.num_micro_batches
def get_current_global_batch_size(self):
return self.current_global_batch_size
@abstractmethod
def update(self, consumed_samples, consistency_check):
pass
class ConstantNumMicroBatches(NumMicroBatchesCalculator):
def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
assert global_batch_size % micro_batch_times_data_parallel == 0, (
"global batch size ({}) is not divisible by micro batch size ({})"
" times data parallel size ({})".format(
global_batch_size, micro_batch_size, data_parallel_size
)
)
self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
assert self.num_micro_batches >= 1
self.current_global_batch_size = global_batch_size
def update(self, consumed_samples, consistency_check):
pass
class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
def __init__(
self,
start_batch_size,
batch_size_increment,
ramup_samples,
global_batch_size,
micro_batch_size,
data_parallel_size,
):
"""Batch size ramp up.
Over
steps = (global-batch-size - start-batch-size) / batch_size_increment
increments, the batch size ramps up from start-batch-size to
global-batch-size, advancing one increment every
rampup-samples / steps
samples.
Arguments:
start_batch_size: global batch size to start with
batch_size_increment: global batch size increments
ramup_samples: number of samples used to ramp up the global
batch size from `start_batch_size` to `global_batch_size`
global_batch_size: global batch size post rampup
micro_batch_size: micro batch size
data_parallel_size: data parallel size.
"""
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = (
self.micro_batch_size * self.data_parallel_size
)
assert self.micro_batch_times_data_parallel_size > 0
assert start_batch_size > 0
self.start_batch_size = start_batch_size
assert global_batch_size > 0
self.global_batch_size = global_batch_size
diff_batch_size = self.global_batch_size - self.start_batch_size
assert diff_batch_size >= 0
assert batch_size_increment > 0
self.batch_size_increment = batch_size_increment
assert diff_batch_size % batch_size_increment == 0, (
"expected "
"global batch size interval ({}) to be divisible by global batch "
"size increment ({})".format(diff_batch_size, batch_size_increment)
)
num_increments = diff_batch_size // self.batch_size_increment
self.ramup_samples = ramup_samples
assert self.ramup_samples >= 0
self.rampup_samples_per_increment = self.ramup_samples / num_increments
# Initialize number of microbatches.
self.update(0, False)
def update(self, consumed_samples, consistency_check):
if consumed_samples > self.ramup_samples:
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
self.current_global_batch_size = (
self.start_batch_size + steps * self.batch_size_increment
)
assert self.current_global_batch_size <= self.global_batch_size
if consistency_check:
assert (
self.current_global_batch_size
% self.micro_batch_times_data_parallel_size
== 0
), (
"current global "
"batch size ({}) is not divisible by micro-batch-size ({}) times"
"data parallel size ({})".format(
self.current_global_batch_size,
self.micro_batch_size,
self.data_parallel_size,
)
)
self.num_micro_batches = (
self.current_global_batch_size // self.micro_batch_times_data_parallel_size
)

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.nn import LayerNorm
from .distributed import DistributedDataParallel
from .codegeex_model import CodeGeeXModel
from .language_model import get_language_model

@ -14,18 +14,18 @@
# limitations under the License.
import torch
from codegeex.megatron import get_args
from codegeex.megatron import mpu
from .module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal
from codegeex.megatron import get_args, mpu
from codegeex.megatron.model import LayerNorm
from codegeex.megatron.enums import AttnMaskType
from codegeex.megatron.model.module import MegatronModule
from codegeex.megatron.model.language_model import parallel_lm_logits, get_language_model, EmbeddingPipe, QueryEmbeddingPipe
from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal
from codegeex.megatron.model.transformer import ParallelTransformerLayerPipe, ParallelTopQueryLayerPipe
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
class CodeGeeXModel(MegatronModule):
"""Code Generative Model for Multilingual Program Synthesis."""
"""Code Generation Model for Multilingual Program Synthesis."""
def __init__(self, num_tokentypes=0, parallel_output=False):
super(CodeGeeXModel, self).__init__()
@ -41,6 +41,10 @@ class CodeGeeXModel(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(
self,
input_ids,
@ -107,3 +111,106 @@ class CodeGeeXModel(MegatronModule):
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
def CrossEntropy(output, labels):
labels, loss_mask = labels[0], labels[1]
args = get_args()
losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
return loss
class CodeGeeXModelPipe(PipelineModule, MegatronModule):
"""Pipeline version of CodeGeeX."""
def __init__(self, num_tokentypes=0, parallel_output=True):
args = get_args()
self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
self.specs = []
# Embedding layer
self.specs.append(
TiedLayerSpec(
"embed",
EmbeddingPipe,
args.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
init_method=init_method,
num_tokentypes=num_tokentypes,
tied_weight_attr="word_embeddings_weight",
)
)
self.specs.append(lambda x: x.transpose(0, 1).contiguous())
for layer_idx in range(args.num_layers):
self.specs.append(
LayerSpec(
ParallelTransformerLayerPipe,
init_method=init_method,
output_layer_init_method=scaled_init_method_normal(
args.init_method_std, args.num_layers
),
layer_number=layer_idx,
self_attn_mask_type=AttnMaskType.causal,
)
)
# Undo data format change
self.specs.append(lambda x: x.transpose(0, 1).contiguous())
# Final layernorm after transformer layers
self.specs.append(
LayerSpec(LayerNorm, args.hidden_size, eps=args.layernorm_epsilon)
)
def _logits_helper(embedding, lm_output):
"""A wrapper to massage inputs/outputs from pipeline."""
return parallel_lm_logits(
lm_output, embedding.word_embeddings_weight, self.parallel_output
)
self.specs.append(
TiedLayerSpec(
"embed",
EmbeddingPipe,
args.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
init_method=init_method,
num_tokentypes=num_tokentypes,
forward_fn=_logits_helper,
tied_weight_attr="word_embeddings_weight",
)
)
if args.checkpoint_activations:
interval = args.checkpoint_num_layers
else:
interval = 0
from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(
num_pp=mpu.get_pipeline_model_parallel_world_size(),
num_mp=mpu.get_tensor_model_parallel_world_size(),
num_dp=mpu.get_data_parallel_world_size(),
)
super().__init__(
layers=self.specs,
loss_fn=CrossEntropy,
topology=topo,
activation_checkpoint_interval=interval,
partition_method="type:transformer",
)
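# Hedged usage sketch (not in the original source; assumes Megatron args, mpu
# groups and a DeepSpeed config dict `ds_config` are already initialized, and
# that `train_data_iterator` yields pipeline-compatible batches):
#
#   model = CodeGeeXModelPipe(num_tokentypes=0, parallel_output=True)
#   engine, _, _, _ = deepspeed.initialize(model=model,
#                                          model_parameters=model.parameters(),
#                                          config=ds_config)
#   loss = engine.train_batch(data_iter=train_data_iterator)
#
# The tied "embed" TiedLayerSpec makes the first and last stages share
# word_embeddings_weight, and CrossEntropy is applied on the final stage.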

@ -82,13 +82,15 @@ class Embedding(MegatronModule):
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0):
def __init__(
self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0,
):
super(Embedding, self).__init__()
self.hidden_size = hidden_size
@ -157,8 +159,9 @@ class Embedding(MegatronModule):
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
def state_dict_for_save_checkpoint(
self, destination=None, prefix='', keep_vars=False,
):
"""For easy load."""
state_dict_ = {}
@ -221,6 +224,40 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True)
class EmbeddingPipe(Embedding):
def forward(self, inputs, **kwargs):
if not hasattr(self, "_args"):
self._args = get_args()
input_ids = inputs[0]
position_ids = inputs[1]
if hasattr(self._args, "attn_mask"):
attention_mask = None
else:
attention_mask = inputs[2]
if len(inputs) == 4:
tokentype_ids = inputs[3]
else:
tokentype_ids = None
embeddings = super().forward(
input_ids, position_ids, tokentype_ids=tokentype_ids
)
# If cmd args has attn_mask, we don't forward it as an activation.
if hasattr(self._args, "attn_mask"):
return embeddings
else:
assert False
return embeddings, attention_mask
@property
def word_embeddings_weight(self):
"""Easy accessory for the DeepSpeed pipeline engine to tie embeddings across stages."""
return self.word_embeddings.weight
class QueryEmbedding(MegatronModule):
"""Language model embeddings.
@ -249,8 +286,8 @@ class QueryEmbedding(MegatronModule):
self.num_tokentypes = num_tokentypes
# Top query position embedding (serial).
self.top_query_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self.top_query_embeddings = mpu.VocabParallelEmbedding(
max_sequence_length, self.hidden_size, init_method=self.init_method)
self.top_query_embeddings = self.top_query_embeddings.half()
self._top_query_embeddings_key = 'top_query_embeddings'
# Initialize the top query position embeddings.
@ -352,6 +389,39 @@ class QueryEmbedding(MegatronModule):
'checkpoint but could not find it', flush=True)
class QueryEmbeddingPipe(QueryEmbedding):
def forward(self, inputs, **kwargs):
if not hasattr(self, "_args"):
self._args = get_args()
position_ids = inputs[0]
if hasattr(self._args, "attn_mask"):
attention_mask = None
else:
attention_mask = inputs[1]
if len(inputs) == 3:
tokentype_ids = inputs[2]
else:
tokentype_ids = None
embeddings = super().forward(
position_ids, tokentype_ids=tokentype_ids,
)
# If cmd args has attn_mask, we don't forward it as an activation.
if hasattr(self._args, "attn_mask"):
return embeddings
else:
assert False
return embeddings, attention_mask
@property
def word_embeddings_weight(self):
"""Easy accessory for the DeepSpeed pipeline engine to tie embeddings across stages."""
return self.top_query_embeddings.weight
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
@ -408,6 +478,10 @@ class TransformerLanguageModel(MegatronModule):
output_layer_init_method)
self._transformer_key = 'transformer'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.transformer.set_input_tensor(input_tensor)
def forward(
self,
input_ids,

@ -80,7 +80,7 @@ class ParallelMLP(MegatronModule):
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.hidden_size,
input_is_parallel=False,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
# skip_bias_add=True,
)
@ -151,7 +151,7 @@ class ParallelSelfAttention(MegatronModule):
self.dense = mpu.RowParallelLinear(
args.hidden_size,
args.hidden_size,
input_is_parallel=False,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
skip_bias_add=True)
@ -374,7 +374,7 @@ class ParallelTopQuerySelfAttention(MegatronModule):
self.dense = mpu.RowParallelLinear(
args.hidden_size,
args.hidden_size,
input_is_parallel=False,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
skip_bias_add=True)
@ -685,6 +685,43 @@ class ParallelTransformerLayer(MegatronModule):
return output
class ParallelTransformerLayerPipe(ParallelTransformerLayer):
"""Extends ParallelTransformerLayer to forward attention_mask through the pipeline.
Forward has two usages that affect attention mask communication:
1) forward((input, attn_mask) , **kwargs) -> (output, mask)
When the attention mask is provided as the second positional
argument, typical pipeline behavior is used and both the output
*and* mask are returned in a tuple. This tuple is then forwarded
to the next stage in the pipeline.
This version is useful if masks are dynamic.
2) forward(input, **kwargs) -> output
When the mask is static over all samples, it is advantageous to
cache the mask and avoid communicating it.
If no mask is provided, the module will query `self._args.attn_mask`
for the mask and only return `super().forward(...)`
"""
def forward(self, inputs, **kwargs):
assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
if torch.is_tensor(inputs) or len(inputs) == 1:
# No attention mask forwarded, search for args.attn_mask
if not hasattr(self, "_args"):
self._args = get_args()
hidden_states, attention_mask = inputs, self._args.attn_mask
return super().forward(hidden_states, attention_mask, **kwargs)
elif len(inputs) == 2:
# Attention mask is an activation.
hidden_states, attention_mask = inputs[0], inputs[1]
return super().forward(*inputs, **kwargs), attention_mask
else:
raise RuntimeError("Received more inputs than understood.")
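# Illustrative note (added, not part of the original source): inside the
# pipeline this layer is typically invoked as
# `layer((hidden_states, attention_mask))` when the mask travels between
# stages as an activation, or as `layer(hidden_states)` when the mask has
# been cached on the command-line args as `args.attn_mask`; in the cached
# case only the hidden states are communicated between pipeline stages.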
class ParallelTopQueryLayer(MegatronModule):
"""A single top query layer.
@ -810,6 +847,44 @@ class ParallelTopQueryLayer(MegatronModule):
return output
class ParallelTopQueryLayerPipe(ParallelTopQueryLayer):
"""Extends ParallelTopQueryLayer to forward attention_mask through the pipeline.
Forward has two usages that affect attention mask communication:
1) forward((input, attn_mask) , **kwargs) -> (output, mask)
When the attention mask is provided as the second positional
argument, typical pipeline behavior is used and both the output
*and* mask are returned in a tuple. This tuple is then forwarded
to the next stage in the pipeline.
This version is useful if masks are dynamic.
2) forward(input, **kwargs) -> output
When the mask is static over all samples, it is advantageous to
cache the mask and avoid communicating it.
If no mask is provided, the module will query `self._args.attn_mask`
for the mask and only return `super().forward(...)`
"""
def forward(self, inputs, **kwargs):
assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
if torch.is_tensor(inputs) or len(inputs) == 2:
# No attention mask forwarded, search for args.attn_mask
if not hasattr(self, "_args"):
self._args = get_args()
hidden_states, query_hidden_state = inputs
attention_mask = self._args.attn_mask
return super().forward(hidden_states, query_hidden_state, attention_mask, **kwargs)
elif len(inputs) == 3:
# Attention mask is an activation.
hidden_states, query_hidden_state, attention_mask = inputs[0], inputs[1], inputs[2]
return super().forward(*inputs, **kwargs), attention_mask
else:
raise RuntimeError("Received more inputs than understood.")
class ParallelTransformer(MegatronModule):
"""Transformer class."""
@ -892,6 +967,16 @@ class ParallelTransformer(MegatronModule):
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states,

@ -271,6 +271,9 @@ class ColumnParallelLinear(torch.nn.Module):
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
params_dtype=None,
skip_init=False,
device=None,
):
super(ColumnParallelLinear, self).__init__()
@ -282,54 +285,60 @@ class ColumnParallelLinear(torch.nn.Module):
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
self.params_dtype = params_dtype
self.device = device
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
args = get_args()
if args.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
dtype=args.params_dtype,
if not skip_init:
if args.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
)
)
)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
device=torch.cuda.current_device(),
dtype=args.params_dtype,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=stride
)
else:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
device=self.device if self.device is not None else torch.cuda.current_device(),
dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=stride
)
else:
self.register_parameter("weight", None)
if bias:
if bias and not skip_init:
if args.use_cpu_initialization:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=args.params_dtype)
torch.empty(self.output_size_per_partition,
dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype)
)
else:
self.bias = Parameter(
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype,
device=self.device if self.device is not None else torch.cuda.current_device(),
dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
@ -395,6 +404,9 @@ class RowParallelLinear(torch.nn.Module):
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
params_dtype=None,
skip_init=False,
device=None,
):
super(RowParallelLinear, self).__init__()
@ -406,53 +418,60 @@ class RowParallelLinear(torch.nn.Module):
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.params_dtype = params_dtype
self.device = device
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
args = get_args()
if args.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
dtype=args.params_dtype,
if not skip_init:
if args.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
)
)
)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
else:
self.weight = Parameter(
torch.empty(
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=1, stride=stride
)
if bias:
else:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
device=self.device if self.device is not None else torch.cuda.current_device(),
dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=1, stride=stride
)
else:
self.register_parameter("weight", None)
if bias and not skip_init:
if args.use_cpu_initialization:
self.bias = Parameter(
torch.empty(self.output_size, dtype=args.params_dtype)
torch.empty(self.output_size,
dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype)
)
else:
self.bias = Parameter(
torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=args.params_dtype,
device=self.device if self.device is not None else torch.cuda.current_device(),
dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
)
)
# Always initialize bias to zero.

@ -0,0 +1,129 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from codegeex.megatron import get_args
from codegeex.megatron.model import LayerNorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules):
"""Divide params into with-weight-decay and without-weight-decay groups.
LayerNorms and biases will have no weight decay but the rest will.
"""
weight_decay_params = {"params": []}
no_weight_decay_params = {"params": [], "weight_decay": 0.0}
for module in modules:
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params["params"].extend(
[p for p in list(module_._parameters.values()) if p is not None]
)
else:
weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n != "bias"
]
)
no_weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n == "bias"
]
)
return weight_decay_params, no_weight_decay_params
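# Example of the resulting structure (added for clarity, not original): the
# two returned dicts are passed straight to the optimizer as parameter
# groups, e.g.
#   wd_group, no_wd_group = _get_params_for_weight_decay_optimization([model])
#   optimizer = Adam([wd_group, no_wd_group], lr=1e-4, weight_decay=0.1)
# so LayerNorm parameters and all biases end up in the group with
# weight_decay=0.0 while every other parameter uses the configured decay.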
def get_megatron_optimizer(model):
args = get_args()
if args.cpu_optimizer:
raise NotImplementedError("need to add cpu adam")
param_groups = _get_params_for_weight_decay_optimization(model)
if args.optimizer == "adam":
optimizer = Adam(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
)
elif args.optimizer == "sgd":
optimizer = SGD(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
momentum=args.sgd_momentum,
)
else:
raise Exception("{} optimizer is not supported.".format(args.optimizer))
if args.deepspeed:
return optimizer
# Determine whether the params have main-grad field.
params_have_main_grad = False
if args.DDP_impl == "local":
params_have_main_grad = True
if args.fp16 or args.bf16:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale.
else:
if args.fp16:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis,
)
# Megatron optimizer.
return Float16OptimizerWithFloat16Params(
optimizer,
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.bf16,
grad_scaler,
)
# FP32.
return FP32Optimizer(
optimizer, args.clip_grad, args.log_num_zeros_in_grad, params_have_main_grad
)

@ -0,0 +1,153 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gradient clipping."""
import torch
from torch._six import inf
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from codegeex.megatron import mpu
from codegeex.megatron.model.module import param_is_not_shared
from codegeex.megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with
added functionality to handle model-parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none:
grad = param.grad.detach()
# Make sure the grads are in fp32
assert param.grad.type() == "torch.cuda.FloatTensor"
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0.0
# Calculate norm.
if norm_type == inf:
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group(),
)
total_norm = total_norm_cuda[0].item()
else:
if norm_type == 2.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of list
# and performs the operation on that list all in one kernel.
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False, # no per-parameter norm
)
# Since we will be summing across data parallel groups,
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group(),
)
total_norm = total_norm.item() ** (1.0 / norm_type)
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(
amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff
)
return total_norm
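# Worked note (added, not original): for norm_type == 2 each model-parallel
# rank computes its local L2 norm, the squared norms are all-reduced (summed)
# across the model-parallel group, and total_norm = sqrt(sum of squares).
# E.g. two ranks with local norms 3.0 and 4.0 give total_norm = 5.0; with
# max_norm = 1.0 the clip coefficient is 1.0 / (5.0 + 1e-6) ≈ 0.2, so every
# gradient is scaled by roughly 0.2 in place.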
def count_zeros_fp32(parameters):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros = 0.0
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grad = param.grad.detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group(),
)
total_num_zeros = total_num_zeros.item()
return total_num_zeros

@ -0,0 +1,132 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron grad scaler."""
from abc import ABC
from abc import abstractmethod
import torch
class MegatronGradScaler(ABC):
def __init__(self, initial_scale):
"""Initialize scale value with the input initial scale."""
assert initial_scale > 0.0
self._scale = torch.cuda.FloatTensor([initial_scale])
@property
def scale(self):
return self._scale
@property
def inv_scale(self):
return self._scale.double().reciprocal().float()
@abstractmethod
def update(self, found_inf):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
class ConstantGradScaler(MegatronGradScaler):
def update(self, found_inf):
pass
def state_dict(self):
return dict()
def load_state_dict(self, state_dict):
pass
class DynamicGradScaler(MegatronGradScaler):
def __init__(
self,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
):
""" "Grad scaler with dynamic scale that gets adjusted
during training."""
super(DynamicGradScaler, self).__init__(initial_scale)
# Lower bound on the scale.
assert min_scale > 0.0
assert min_scale <= initial_scale
self.min_scale = torch.cuda.FloatTensor([min_scale])
# Growth and backoff factors for the scale.
assert growth_factor > 1.0
self.growth_factor = torch.cuda.FloatTensor([growth_factor])
assert backoff_factor < 1.0
assert backoff_factor > 0.0
self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
# Interval over which if we don't see any inf/nan,
# we will scale the grad scale by the growth factor.
assert growth_interval > 0
self.growth_interval = growth_interval
# Number of inf/nans we should see before scaling down
# the grad scale by the backoff factor.
assert hysteresis > 0
self.hysteresis = hysteresis
# Trackers.
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
def update(self, found_inf):
# If we have an inf/nan, growth tracker is set to 0
# and hysteresis tracker is reduced by 1.
if found_inf:
self._growth_tracker = 0
self._hysteresis_tracker -= 1
# Now if we are out of hysteresis count, scale down the loss.
if self._hysteresis_tracker <= 0:
self._scale = torch.max(
self._scale * self.backoff_factor, self.min_scale
)
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
# If we have had enough consecutive intervals with no nan/inf:
if self._growth_tracker == self.growth_interval:
# Reset the tracker and hysteresis trackers,
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
# and scale up the loss scale.
self._scale = self._scale * self.growth_factor
def state_dict(self):
state_dict = {}
state_dict["scale"] = self._scale
state_dict["growth_tracker"] = self._growth_tracker
state_dict["hysteresis_tracker"] = self._hysteresis_tracker
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
self._growth_tracker = state_dict["growth_tracker"]
self._hysteresis_tracker = state_dict["hysteresis_tracker"]
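# Worked example (added, not part of the original source): with
# initial_scale=2**16, growth_factor=2.0, backoff_factor=0.5,
# growth_interval=1000 and hysteresis=2, a single overflow resets the growth
# tracker and decrements the hysteresis counter; only after two overflows
# without an intervening run of 1000 clean steps is the scale halved to
# 2**15 (never below min_scale). After 1000 consecutive clean steps the
# trackers reset and the scale doubles again.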

@ -0,0 +1,505 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron optimizer."""
from abc import ABC
from abc import abstractmethod
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from codegeex.megatron import get_timers
from codegeex.megatron import mpu
from codegeex.megatron import print_rank_0
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
def _zero_grad_group_helper(group, set_to_none):
"""Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer."""
for param in group:
if param.grad is not None:
if set_to_none:
param.grad = None
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
"""Use multi-tensor-applier to copy values from one list to another.
We don't have a bfloat16 implementation so for now if the overflow_buf
is not provided, we default back to simple loop copy to be compatible
with bfloat16."""
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
else:
for this_, that_ in zip(this, that):
that_.copy_(this_)
class MegatronOptimizer(ABC):
def __init__(
self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad
):
"""Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer
assert self.optimizer, "no optimizer is provided."
# Set gradient clipping and logging params.
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self.params_have_main_grad = params_have_main_grad
def get_parameters(self):
params = []
for param_group in self.optimizer.param_groups:
for param in param_group["params"]:
params.append(param)
return params
def clip_grad_norm(self, clip_grad):
params = self.get_parameters()
return clip_grad_norm_fp32(params, clip_grad)
def count_zeros(self):
params = self.get_parameters()
return count_zeros_fp32(params)
@abstractmethod
def zero_grad(self, set_to_none=True):
pass
@abstractmethod
def get_loss_scale(self):
"""The output should be a cuda tensor of size 1."""
pass
def scale_loss(self, loss):
"""Simple scaling."""
return self.get_loss_scale() * loss
@abstractmethod
def step(self):
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
Call whenever the parameters are changed outside of the optimizer.
For example, when we load a model from a checkpoint without loading
the optimizer, the model parameters are updated but for fp16 optimizer
with main parameters, the main parameters need to also be updated."""
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
# Promote state so it can be retrieved or set via
# "optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradients with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the gradients of the model parameters are stored in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a contiguous buffer
holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
Note that main grad is not necessarily in float32.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
"""
def __init__(
self,
optimizer,
clip_grad,
log_num_zeros_in_grad,
params_have_main_grad,
bf16,
grad_scaler,
):
super(Float16OptimizerWithFloat16Params, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad
)
self.bf16 = bf16
self.grad_scaler = grad_scaler
# None grad scaler is only supported for bf16.
if self.grad_scaler is None:
assert self.bf16, "fp16 expects a grad scaler."
# Tensor used to determine if a nan/inf has happened.
# Any non-zero value indicates inf/nan.
# Note that we keep this for the cases that grad scaler is none.
# We still record nan/inf if we have a bfloat16 with a grad scaler.
if self.grad_scaler:
self.found_inf = torch.cuda.FloatTensor([0.0])
# Dummy tensor needed for apex multi-apply tensor.
# For bfloat, we don't have multi-tensor apply and for now
# we set it to none so the multi-tensor apply gets ignored.
if bf16:
self._dummy_overflow_buf = None
else:
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
# In case grad scaler is not passed, define the unity scale.
if self.grad_scaler is None:
self._scale_one = torch.cuda.FloatTensor([1.0])
# ======================
# main parameter stuff
# ======================
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group["params"]):
if param.requires_grad:
# float16 params:
if param.type() in [
"torch.cuda.HalfTensor",
"torch.cuda.BFloat16Tensor",
]:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
mpu.copy_tensor_model_parallel_attributes(main_param, param)
if hasattr(param, "shared"):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group["params"][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] = self.optimizer.state.pop(
param
)
# fp32 params.
elif param.type() == "torch.cuda.FloatTensor":
fp32_params_this_group.append(param)
param_group["params"][i] = param
else:
raise TypeError(
"Wrapped parameters must be one of "
"torch.cuda.FloatTensor, "
"torch.cuda.HalfTensor, or "
"torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type())
)
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_from_fp32_groups."""
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_fp32_groups:
_zero_grad_group_helper(group, set_to_none)
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the float16 group.
for model_group, main_group in zip(
self.float16_groups, self.fp32_from_float16_groups
):
for model_param, main_param in zip(model_group, main_group):
if self.params_have_main_grad:
main_param.grad = model_param.main_grad.float()
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
# For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups:
for model_param in model_group:
model_param.grad = model_param.main_grad
def _unscale_main_grads_and_check_for_nan(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale
)
# Update across all model parallel instances.
torch.distributed.all_reduce(
self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group(),
)
# Check for nan.
found_inf_flag = self.found_inf.item() > 0
return found_inf_flag
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(
self.float16_groups, self.fp32_from_float16_groups
):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_main_params_to_model_params(self):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(
this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf
)
def _copy_model_params_to_main_params(self):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(
this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf
)
def reload_model_params(self):
self._copy_model_params_to_main_params()
@torch.no_grad()
def step(self):
timers = get_timers()
# Copy gradients from model params to main params.
timers("optimizer-copy-to-main-grad").start()
self._copy_model_grads_to_main_grads()
timers("optimizer-copy-to-main-grad").stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# Unscale and check for inf/nan.
timers("optimizer-unscale-and-check-inf").start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers("optimizer-unscale-and-check-inf").stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# Clip the main gradients.
timers("optimizer-clip-main-grad").start()
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers("optimizer-clip-main-grad").stop()
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
# Step the optimizer.
self.optimizer.step()
# Update params from main params.
timers("optimizer-copy-main-to-model-params").start()
self._copy_main_params_to_model_params()
timers("optimizer-copy-main-to-model-params").stop()
# Successful update.
return True, grad_norm, num_zeros_in_grad
def state_dict(self):
state_dict = {}
state_dict["optimizer"] = self.optimizer.state_dict()
if self.grad_scaler:
state_dict["grad_scaler"] = self.grad_scaler.state_dict()
state_dict["fp32_from_fp16_params"] = self.fp32_from_float16_groups
return state_dict
def load_state_dict(self, state_dict):
# Optimizer.
optimizer_key = "optimizer"
if optimizer_key not in state_dict:
optimizer_key = "optimizer_state_dict"
print_rank_0(
"***WARNING*** loading optimizer from " "an old checkpoint ..."
)
self.optimizer.load_state_dict(state_dict[optimizer_key])
# Grad scaler.
if "grad_scaler" not in state_dict:
print_rank_0(
"***WARNING*** found an old checkpoint, will not "
"load grad scaler ..."
)
else:
if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
else:
print_rank_0(
"***WARNING*** fould the grad scaler in the "
"checkpoint but it is None in the class. "
"Skipping loading grad scaler ..."
)
# Copy data for the main params.
fp32_from_float16_params_key = "fp32_from_fp16_params"
if fp32_from_float16_params_key not in state_dict:
fp32_from_float16_params_key = "fp32_from_fp16"
for current_group, saved_group in zip(
self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]
):
for current_param, saved_param in zip(current_group, saved_group):
current_param.data.copy_(saved_param.data)
class FP32Optimizer(MegatronOptimizer):
def __init__(
self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad
):
super(FP32Optimizer, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad
)
self._scale = torch.cuda.FloatTensor([1.0])
def zero_grad(self, set_to_none=True):
"""Copied from torch.optim.optimizer"""
for group in self.optimizer.param_groups:
_zero_grad_group_helper(group["params"], set_to_none)
def get_loss_scale(self):
"""FP32 optimizer does not do any scaling."""
return self._scale
@torch.no_grad()
def step(self):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
# Copy main_grads to grads.
if self.params_have_main_grad:
for param_group in self.optimizer.param_groups:
for param in param_group["params"]:
param.grad = param.main_grad
# Clip gradients.
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
# Update parameters.
self.optimizer.step()
# No overflow for FP32 optimizer.
return True, grad_norm, num_zeros_in_grad
def reload_model_params(self):
pass
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)

@ -0,0 +1,510 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from contextlib import contextmanager
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from codegeex.megatron import get_args
from codegeex.megatron import get_num_microbatches
from codegeex.megatron import get_timers
from codegeex.megatron import mpu
from codegeex.megatron import p2p_communication
from codegeex.megatron.utils import unwrap_model
from codegeex.megatron.model import DistributedDataParallel as LocalDDP
from codegeex.megatron.model import Float16Module
def get_forward_backward_func():
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor."""
timers = get_timers()
args = get_args()
timers("forward-compute").start()
unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
if not args.deepspeed:
unwrapped_model.set_input_tensor(input_tensor)
else:
unwrapped_model.module.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers("forward-compute").stop()
return output_tensor
def backward_step(
optimizer, input_tensor, output_tensor, output_tensor_grad, model=None
):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
args = get_args()
if args.deepspeed:
assert model is not None
timers = get_timers()
timers("backward-compute").start()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
if args.deepspeed:
model.backward(output_tensor)
else:
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
timers("backward-compute").stop()
return input_tensor_grad
@contextmanager
def dummy_handler():
try:
yield
finally:
pass
def forward_backward_no_pipelining(
forward_step_func, data_iterator, model, optimizer, timers, forward_only
):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
Returns dictionary with losses."""
assert len(model) == 1
model = model[0]
args = get_args()
context_handler = dummy_handler
if isinstance(model, torchDDP):
context_handler = model.no_sync
if args.deepspeed:
model.set_gradient_accumulation_boundary(False)
losses_reduced = []
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
# print_rank_0("====> start of microstep {i}")
# print_rank_0("====> forward")
output_tensor = forward_step(
forward_step_func, data_iterator, model, input_tensor, losses_reduced
)
# print_rank_0("====> backward")
if not forward_only:
backward_step(
optimizer, input_tensor, output_tensor, output_tensor_grad, model
)
# print_rank_0("====> end of microstep {i}")
if args.deepspeed:
model.set_gradient_accumulation_boundary(True)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
# print_rank_0("====> start of the last microstep")
# print_rank_0("====> forward")
output_tensor = forward_step(
forward_step_func, data_iterator, model, input_tensor, losses_reduced
)
# print_rank_0("====> backward")
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model)
# print_rank_0("====> end of the last microstep")
return losses_reduced
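# Hedged note (added, not original): with gradient accumulation the first
# get_num_microbatches() - 1 micro-batches run inside model.no_sync() (for
# torch DDP), so their gradients are only accumulated locally; the final
# micro-batch runs outside the context handler and triggers the gradient
# all-reduce, which is why it is executed after the `with` block above.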
def forward_backward_pipelining_with_interleaving(
forward_step_func, data_iterator, model, optimizer, timers, forward_only
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size warmup
# microbatches on all workers, followed by more microbatches depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = (
pipeline_parallel_size - pipeline_parallel_rank - 1
) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
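# Worked example (added, not original): with pipeline_parallel_size=4,
# num_model_chunks=2 and get_num_microbatches()=8, rank 0 computes
# (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 warmup micro-batches out of
# num_microbatches = 8 * 2 = 16, leaving 6 for the steady-state 1F1B phase;
# later ranks have smaller warmup counts and start 1F1B earlier.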
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (
pipeline_parallel_size * num_model_chunks
)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
def forward_step_helper(microbatch_id):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == len(
output_tensors[model_chunk_id]
):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(
forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor,
losses_reduced,
)
output_tensors[model_chunk_id].append(output_tensor)
return output_tensor
def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(
optimizer, input_tensor, output_tensor, output_tensor_grad
)
return input_tensor_grad
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(timers))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if mpu.is_pipeline_last_stage():
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
timers=timers,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev, timers
)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if mpu.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if mpu.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True
)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(
forward_k + 1, forward=True
)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(
backward_k + 1, forward=False
)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
timers=timers,
)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(timers)
)
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next, timers
)
)
return losses_reduced
def forward_backward_pipelining_without_interleaving(
forward_step_func, data_iterator, model, optimizer, timers, forward_only
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
Returns a list of loss dicts from the last pipeline stage, an empty list otherwise."""
timers = get_timers()
assert len(model) == 1
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = (
mpu.get_pipeline_model_parallel_world_size()
- mpu.get_pipeline_model_parallel_rank()
- 1
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
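# Example (illustrative): with a pipeline-parallel world size of 4 and 8 microbatches,
# rank 0 runs 3 warmup microbatches, rank 3 runs 0, and the remaining microbatches
# are processed in the 1F1B steady state below.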
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers)
output_tensor = forward_step(
forward_step_func, data_iterator, model, input_tensor, losses_reduced
)
p2p_communication.send_forward(output_tensor, timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = p2p_communication.recv_forward(timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = i == (num_microbatches_remaining - 1)
output_tensor = forward_step(
forward_step_func, data_iterator, model, input_tensor, losses_reduced
)
if forward_only:
p2p_communication.send_forward(output_tensor, timers)
else:
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, timers
)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = backward_step(
optimizer, input_tensor, output_tensor, output_tensor_grad, model
)
if last_iteration:
input_tensor = None
p2p_communication.send_backward(input_tensor_grad, timers)
else:
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, timers
)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers)
input_tensor_grad = backward_step(
optimizer, input_tensor, output_tensor, output_tensor_grad, model
)
p2p_communication.send_backward(input_tensor_grad, timers)
return losses_reduced

@ -125,7 +125,10 @@ def _vocab_size_with_padding(orig_vocab_size, args):
still having GPU friendly size."""
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size
if args.make_vocab_size_divisible_by > orig_vocab_size:
multiple = args.make_vocab_size_divisible_by
else:
multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:

@ -0,0 +1,23 @@
import os
ENV_NAMES = ["CUDA_HOME", "LD_LIBRARY_PATH", "PATH", "TORCH_EXTENSIONS_DIR", "CUDA_LAUNCH_BLOCKING"]
def main():
s = ""
for name in ENV_NAMES:
if name in os.environ:
value = os.environ[name]
s += "{}={}\n".format(name, value)
print(f"{name}={value}")
else:
print(f"{name} is not set")
# write env vars to .deepspeed_env
with open(".deepspeed_env", "w") as f:
f.write(s)
if __name__ == "__main__":
main()
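# Example .deepspeed_env produced by this script (values are illustrative, not real paths):
# CUDA_HOME=/usr/local/cuda-11.1
# LD_LIBRARY_PATH=/usr/local/cuda-11.1/lib64
# PATH=/usr/local/bin:/usr/bin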

@ -0,0 +1,208 @@
import os
import subprocess
import torch
import logging
logging.getLogger("torch").setLevel(logging.WARNING)
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
from functools import partial
from codegeex.megatron import get_args, print_rank_0, get_timers, get_tokenizer, mpu
from codegeex.megatron.data.prompt_dataset import build_train_valid_test_datasets
from codegeex.megatron.model import CodeGeeXModel #, CodeGeeXModelPipe
from codegeex.megatron.training import pretrain
from codegeex.megatron.utils import get_ltor_masks_and_position_ids
from codegeex.megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building GPT model ...")
see_memory_usage(f"Before Building Model", force=True)
args = get_args()
with deepspeed.zero.Init(
data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device == "none" else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu,
):
if args.deepspeed and not args.no_pipeline_parallel:
model = CodeGeeXModelPipe(num_tokentypes=0, parallel_output=True)
# This is a hack to give us a reference to get_batch_pipe from within training.py
# We need to call model.set_batch_fn after deepspeed.initialize
model._megatron_batch_fn = get_batch_pipe
# Precompute the attention mask and store it in args. This avoids having to
# pipeline it as an activation during training. The mask is constant, and thus
# we can reuse it.
attention_mask = torch.tril(
torch.ones(
(1, args.seq_length, args.seq_length),
device=torch.cuda.current_device(),
)
).view(1, 1, args.seq_length, args.seq_length)
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
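# After the comparison, True marks positions that should not be attended to
# (the upper triangle of the causal mask).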
if args.fp16:
attention_mask = attention_mask.half()
elif args.bf16:
attention_mask = attention_mask.bfloat16()
# Attention mask must be bool.
args.attn_mask = attention_mask.to(torch.bool)
else:
model = CodeGeeXModel(
num_tokentypes=0,
parallel_output=True,
)
if args.load_state is not None:
timers = get_timers()
print_rank_0("Loading warmstarting model states ...")
timers("load-model-states").start()
mp_rank = mpu.get_tensor_model_parallel_rank()
if os.path.isdir(args.load_state):
model_path = os.path.join(
args.load_state, f"model_mp_rank_{mp_rank}.pt"
)
else:
model_path = args.load_state
print_rank_0(f"Loading model from {model_path} ...")
state_dict = torch.load(model_path, map_location="cpu")
if "module" in state_dict:
state_dict = state_dict["module"] # strip other client states
model.load_state_dict(state_dict)
timers("load-model-states").stop()
timers.log(["load-model-states"])
see_memory_usage(f"After Building Model", force=True)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ["input_ids"]
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b["input_ids"].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
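# e.g. input_ids [t0, t1, t2, t3] yields tokens [t0, t1, t2] and labels [t1, t2, t3],
# so each position predicts the next token.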
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
)
return tokens, labels, loss_mask, attention_mask, position_ids
def get_batch_pipe(data):
"""Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ["input_ids"]
datatype = torch.int64
# Broadcast data.
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b["input_ids"].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
)
return (tokens, position_ids, attention_mask), (labels, loss_mask)
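# The DeepSpeed pipeline engine consumes (inputs, labels) pairs from its batch function,
# which is why the batch is returned as ((tokens, position_ids, attention_mask), (labels, loss_mask)).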
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers("batch-generator").stop()
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0("> building train, validation, and test datasets " "for GPT ...")
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
)
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
def command_exists(cmd):
result = subprocess.Popen(f"type {cmd}", stdout=subprocess.PIPE, shell=True)
return result.wait() == 0
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step,
args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
)

File diff suppressed because it is too large

@ -217,3 +217,26 @@ def get_parameters_in_billions(model):
)
return approx_parameters_in_billions * gpus_per_model / (1e9)
def flops_calculator(model, args, iteration_time):
return # currently broken
gpus_per_model = torch.distributed.get_world_size(
group=mpu.get_model_parallel_group()
)
approx_parameters_in_billions = get_parameters_in_billions(model)
batch_size = args.micro_batch_size * get_num_microbatches()
giga_flops_per_model_per_train_step = (
approx_parameters_in_billions * batch_size * args.seq_length * 2.0 * 4.0
)
effective_tera_flops_per_gpu = giga_flops_per_model_per_train_step / (
iteration_time * 1000.0 * gpus_per_model
)
print_rank_0(
f"Effective Tera Flops per GPU: {round(effective_tera_flops_per_gpu, 2)} and total parameters {round(approx_parameters_in_billions, 3)} B"
)

@ -0,0 +1 @@
from .quantize import quantize

@ -0,0 +1,329 @@
import torch
from torch.nn.parameter import Parameter
from codegeex.kernels import extract_weight_to_half, compress_int4_weight  # compress_int4_weight is required by the 4-bit branches below
from codegeex.megatron.mpu.layers import RowParallelLinear, ColumnParallelLinear
from codegeex.megatron.mpu.mappings import copy_to_tensor_model_parallel_region, gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region
class W8A16Linear(torch.autograd.Function):
@staticmethod
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
ctx.inp_shape = inp.size()
ctx.weight_shape = quant_w.size()
ctx.weight_bit_width = weight_bit_width
out_features = quant_w.size(0)
inp = inp.contiguous().view(-1, inp.size(-1))
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
output = inp.mm(weight.t())
ctx.save_for_backward(inp, quant_w, scale_w)
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
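# The stored weights (int8, or int4 packed two per byte) are dequantized to fp16 on the fly
# by extract_weight_to_half and multiplied with a regular GEMM, so only the weights are
# compressed while activations stay in half precision (hence "W8A16").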
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
inp, quant_w, scale_w = ctx.saved_tensors
weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
grad_output = grad_output.contiguous().view(-1, weight.size(0))
grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(inp)
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
class QuantizedLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
weight_bit_width: int,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
*args,
**kwargs
):
super(QuantizedLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight_bit_width = weight_bit_width
if weight is None:
self.weight = torch.empty(
self.out_features, self.in_features * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
)
self.weight_scale = torch.empty(self.out_features, dtype=kwargs["params_dtype"], device=kwargs["device"])
else:
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
if weight_bit_width == 4:
self.weight = compress_int4_weight(self.weight)
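# Worked example (8-bit): if a weight row has max |w| = 0.5, its per-row scale is
# 0.5 / 127; each weight is stored as round(w / scale) in int8 and reconstructed by
# multiplying with the scale at inference time.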
if bias is None:
self.register_parameter('bias', None)
else:
self.bias = bias
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
def forward(self, input_):
# Matrix multiply.
output = W8A16Linear.apply(input_, self.weight, self.weight_scale, self.weight_bit_width)
if self.bias is not None:
output = output + self.bias
return output
class QuantizedColumnParallelLinear(ColumnParallelLinear):
def __init__(
self,
input_size: int,
output_size: int,
weight_bit_width: int,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
*args,
**kwargs,
):
super(QuantizedColumnParallelLinear, self).__init__(input_size, output_size, *args, **kwargs)
self.input_size = input_size
self.output_size = output_size
self.weight_bit_width = weight_bit_width
if "skip_bias_add" in kwargs:
self.skip_bias_add = kwargs["skip_bias_add"]
else:
self.skip_bias_add = False
del self.weight
if weight is None:
self.weight = torch.empty(
self.output_size, self.input_size * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
)
self.weight_scale = torch.empty(self.output_size, dtype=kwargs["params_dtype"], device=kwargs["device"])
else:
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
if weight_bit_width == 4:
self.weight = compress_int4_weight(self.weight)
if bias is None:
self.register_parameter('bias', None)
else:
del self.bias
self.bias = bias
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
if self.bias is not None and not self.skip_bias_add:
output_parallel = output_parallel + self.bias
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class QuantizedRowParallelLinear(RowParallelLinear):
def __init__(
self,
input_size: int,
output_size: int,
weight_bit_width: int,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
*args,
**kwargs,
):
super(QuantizedRowParallelLinear, self).__init__(input_size, output_size, *args, **kwargs)
self.input_size = input_size
self.output_size = output_size
self.weight_bit_width = weight_bit_width
if "skip_bias_add" in kwargs:
self.skip_bias_add = kwargs["skip_bias_add"]
else:
self.skip_bias_add = False
del self.weight
if weight is None:
self.weight = torch.empty(
self.output_size, self.input_size * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
)
self.weight_scale = torch.empty(self.output_size, dtype=kwargs["params_dtype"], device=kwargs["device"])
else:
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
if weight_bit_width == 4:
self.weight = compress_int4_weight(self.weight)
if bias is None:
self.register_parameter('bias', None)
else:
del self.bias
self.bias = bias
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
def forward(self, input_):
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
# All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if self.bias is not None and not self.skip_bias_add:
output = output_ + self.bias
else:
output = output_
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def quantize(model, weight_bit_width, backend="torch"):
"""Replace fp16 linear with quantized linear"""
for i in range(len(model.language_model.transformer.layers) + 1):
if i == len(model.language_model.transformer.layers):
layer = model.language_model.transformer.topQueryLayer
else:
layer = model.language_model.transformer.layers[i]
if backend == "torch":
layer.attention.query = QuantizedLinear(
in_features=layer.attention.query.in_features,
out_features=layer.attention.query.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.query.weight.device,
)
layer.attention.value = QuantizedLinear(
in_features=layer.attention.value.in_features,
out_features=layer.attention.value.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.value.weight.device,
)
layer.attention.key = QuantizedLinear(
in_features=layer.attention.key.in_features,
out_features=layer.attention.key.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.key.weight.device,
)
layer.attention.dense = QuantizedLinear(
in_features=layer.attention.dense.in_features,
out_features=layer.attention.dense.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.dense.weight.device,
)
layer.mlp.dense_h_to_4h = QuantizedLinear(
in_features=layer.mlp.dense_h_to_4h.in_features,
out_features=layer.mlp.dense_h_to_4h.out_features,
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_h_to_4h.weight.device,
)
layer.mlp.dense_4h_to_h = QuantizedLinear(
in_features=layer.mlp.dense_4h_to_h.in_features,
out_features=layer.mlp.dense_4h_to_h.out_features,
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_4h_to_h.weight.device,
)
elif backend == "megatron":
layer.attention.query = QuantizedColumnParallelLinear(
weight_bit_width=weight_bit_width,
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
input_size=layer.attention.query.input_size,
output_size=layer.attention.query.output_size,
gather_output=False,
skip_init=True,
params_dtype=torch.half,
device=layer.attention.query.weight.device,
)
layer.attention.value = QuantizedColumnParallelLinear(
weight_bit_width=weight_bit_width,
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
input_size=layer.attention.value.input_size,
output_size=layer.attention.value.output_size,
gather_output=False,
skip_init=True,
params_dtype=torch.half,
device=layer.attention.value.weight.device,
)
layer.attention.key = QuantizedColumnParallelLinear(
weight_bit_width=weight_bit_width,
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
input_size=layer.attention.key.input_size,
output_size=layer.attention.key.output_size,
gather_output=False,
skip_init=True,
params_dtype=torch.half,
device=layer.attention.key.weight.device,
)
layer.attention.dense = QuantizedRowParallelLinear(
weight_bit_width=weight_bit_width,
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
input_size=layer.attention.dense.input_size,
output_size=layer.attention.dense.output_size,
input_is_parallel=False,
skip_init=True,
skip_bias_add=True,
params_dtype=torch.half,
device=layer.attention.dense.weight.device,
)
layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
input_size=layer.mlp.dense_h_to_4h.input_size,
output_size=layer.mlp.dense_h_to_4h.output_size,
gather_output=False,
skip_init=True,
params_dtype=torch.half,
device=layer.mlp.dense_h_to_4h.weight.device,
)
layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
input_size=layer.mlp.dense_4h_to_h.input_size,
output_size=layer.mlp.dense_4h_to_h.output_size,
input_is_parallel=False,
skip_init=True,
params_dtype=torch.half,
device=layer.mlp.dense_4h_to_h.weight.device,
)
return model
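# Illustrative usage (the `model` variable is an assumption; mirrors the test scripts in this PR):
#   model = quantize(model, weight_bit_width=8, backend="torch")     # single-process inference
#   model = quantize(model, weight_bit_width=8, backend="megatron")  # tensor-parallel inference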

@ -0,0 +1 @@
from .tokenizer import CodeGeeXTokenizer

@ -0,0 +1,86 @@
from typing import *
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
def encode_whitespaces(text, start_extra_id: int, max_len: int):
""" Encode whitespaces to extra tokens in GPT-J.
>>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
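>>> encode_whitespaces('a    b', 10, 10)   # four spaces map to a single extra token
'a<|extratoken_12|>b'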
"""
def push_acc_space(acc_len: int, text: str):
if acc_len == 0:
return text
if acc_len == 1:
return text + ' '
assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
extra_id = start_extra_id - 2 + acc_len
extra_token = f'<|extratoken_{extra_id}|>'
return text + extra_token
acc_len = 0
res = ''
for ch in text:
if ch == ' ':
acc_len += 1
if acc_len == max_len:
res = push_acc_space(acc_len, res)
acc_len = 0
else:
res = push_acc_space(acc_len, res)
acc_len = 0
res = res + ch
res = push_acc_space(acc_len, res)
return res
def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
""" Decode the whitespace-encoded strings produced by encode_whitespace.
>>> text = 'a\\n b\\n c'
>>> s, l = 10, 10
>>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
True
"""
for l in range(2, max_len + 1):
token_id = start_extra_id - 2 + l
token = f'<|extratoken_{token_id}|>'
text = text.replace(token, ' ' * l)
return text
class CodeGeeXTokenizer(object):
def __init__(
self,
tokenizer: GPT2TokenizerFast = None,
tokenizer_path: str = "EleutherAI/gpt-j-6B",
start_extra_id: int = 10,
max_len: int = 10,
mode='codegeex-13b',
dict_file: str = None,
):
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_path)
if mode not in ['codegeex-13b']:
raise ValueError(f"Invalid mode {mode}, choose from ['codegeex-13b']")
self.start_extra_id = start_extra_id
self.max_len = max_len
self.mode = mode
self.eos_token_id = self.tokenizer.eos_token_id
def encode_code(self, code: str):
if self.mode == 'codegeex-13b':
code = encode_whitespaces(code, self.start_extra_id, self.max_len)
input_ids = self.tokenizer(code, is_split_into_words=False, verbose=False).input_ids
return input_ids
def decode_code(self, input_ids):
if self.mode == 'codegeex-13b':
text = self.tokenizer.decode(input_ids, skip_special_tokens=False, verbose=False)
output_code = decode_whitespaces(text, self.start_extra_id, self.max_len)
return output_code
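# Illustrative usage (paths are assumptions):
#   tokenizer = CodeGeeXTokenizer(tokenizer_path="EleutherAI/gpt-j-6B", mode="codegeex-13b")
#   ids = tokenizer.encode_code("def add(a, b):\n    return a + b")
#   code = tokenizer.decode_code(ids)  # round-trips the whitespace encoding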

@ -0,0 +1 @@
from .codegeex_model import CodeGeeXModel

File diff suppressed because it is too large

@ -0,0 +1,326 @@
import copy
import json
import os
import time
from typing import *
import torch
import torch.nn.functional as F
from dataclasses import dataclass
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, position_ids
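# e.g. with the EOD token at index 2 of a 4-token row and reset_position_ids=True,
# position_ids go from [0, 1, 2, 3] to [0, 1, 2, 0]: positions restart after each EOD.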
def get_batch(
context_tokens,
micro_batch_size,
eod_token,
reset_position_ids=False,
reset_attention_mask=False,
):
"""Generate batch from context tokens."""
tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_token,
reset_position_ids,
reset_attention_mask,
)
return tokens, attention_mask, position_ids
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313"""
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# Convert to 1D
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
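# e.g. top_k=2 on logits [1.0, 3.0, 2.0, 0.5] keeps indices 1 and 2 and sets the others to
# -inf; top_p keeps the highest-probability tokens until their cumulative probability
# exceeds the threshold and filters the rest.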
def pad_batch(batch, pad_id, seq_length):
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
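# e.g. pad_batch([[1, 2, 3]], pad_id=0, seq_length=5) -> ([[1, 2, 3, 0, 0]], [3])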
def forward_step(
model,
tokens,
seq_length,
position_ids,
attention_mask,
layer_past=None,
get_key_value=None,
prompt_length=None,
context_length=None,
):
# Forward pass through the model.
output_tensor = model(
tokens,
position_ids,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length,
)
if get_key_value:
output_tensor, layer_past = output_tensor
if get_key_value:
return output_tensor, layer_past
return output_tensor
def get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
context_tokens,
return_scores: bool = False,
prompt_length: int = None,
micro_batch_size: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
topk: int = 0,
greedy: bool = False,
recompute: bool = False,
):
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(
context_tokens_tensor,
micro_batch_size,
tokenizer.eos_token_id,
)
batch_token_iterator = sample_sequence_batch(
model,
tokenizer,
context_tokens_tensor,
context_length_tensor,
attention_mask,
position_ids,
seq_length=seq_length,
out_seq_length=out_seq_length,
return_scores=return_scores,
prompt_length=prompt_length,
bad_ids=bad_ids,
temperature=temperature,
topp=topp,
topk=topk,
greedy=greedy,
recompute=recompute,
)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
def sample_sequence_batch(
model,
tokenizer,
context_tokens,
context_lengths,
attention_mask,
position_ids,
seq_length,
out_seq_length,
maxlen=None,
return_scores: bool = False,
prompt_length: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
topk: int = 0,
recompute: bool = False,
greedy: bool = False,
):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
eos_id = tokenizer.eos_token_id
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = seq_length - 1
if maxlen > (org_context_length + out_seq_length):
maxlen = org_context_length + out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
if return_scores:
scores = torch.zeros([batch_size]).float().cuda()
while context_length <= (maxlen):
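# With recompute=False, cached key/values (layer_past) are reused and only the newest
# token is fed after the first step; with recompute=True the full prefix is re-run
# through the model at every iteration.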
if recompute:
logits = model(tokens,
position_ids,
attention_mask,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, context_length - 1, :]
else:
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if bad_ids is not None:
for bad_id in bad_ids:
logits[:, bad_id] = -10000
if greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
if return_scores:
orig_log_probs = torch.log_softmax(logits, dim=-1)
logits /= temperature
logits = top_k_logits(logits, top_k=topk, top_p=topp)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
if not greedy and return_scores:
indices = prev.view(-1, 1)
new_scores = orig_log_probs.gather(1, indices).view(-1)
new_scores = new_scores * started
new_scores = new_scores * is_done.bool().logical_not()
scores += new_scores
tokens[:, context_length] = new_tokens
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
if return_scores:
yield tokens, (lengths, scores)
else:
yield tokens, lengths
context_length += 1
counter += 1
if done:
break

@ -0,0 +1,17 @@
# CodeGeeX-13B parallel configuration
# Parallel checkpoints are named under the format "mp_rank_0{i}_model_states.pt", where i is the rank, starting from 0.
CHECKPOINT_PATH="<path where you put all parallel checkpoints (e.g., XXX/tp4/)>"
MODEL_ARGS="--num-layers 39 \
--hidden-size 5120 \
--num-attention-heads 40 \
--max-position-embeddings 2048 \
--attention-softmax-in-fp32 \
--load "$CHECKPOINT_PATH" \
--layernorm-epsilon 1e-5 \
--fp16 \
--ws-encoding-start-id 10 \
--ws-encoding-length 10 \
--make-vocab-size-divisible-by 52224 \
--seq-length 2048"

@ -0,0 +1,37 @@
# This script is used to convert a Mindspore checkpoint to the Megatron format.
LOAD_CKPT_PATH=$1 # Path to weights in .pt format.
SAVE_CKPT_PATH=$2 # Path to save the output MP checkpoints.
MP_SIZE=$3 # Model parallel size
SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
if [ -z "$MP_SIZE" ]; then
MP_SIZE=1
fi
# export CUDA settings
export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=0,1
CMD="python $MAIN_DIR/codegeex/megatron/convert_ckpt_parallel.py \
--load-ckpt-path $LOAD_CKPT_PATH \
--save-ckpt-path $SAVE_CKPT_PATH \
--tokenizer-path $TOKENIZER_PATH \
--target-tensor-model-parallel-size $MP_SIZE \
--num-layers 39 \
--hidden-size 5120 \
--num-attention-heads 40 \
--max-position-embeddings 2048 \
--attention-softmax-in-fp32 \
--fp16 \
--micro-batch-size 1 \
--make-vocab-size-divisible-by 52224 \
--seq-length 2048"
echo "$CMD"
eval "$CMD"

@ -0,0 +1,25 @@
# Process dataset for CodeGeeX pretraining
DATASET_PATH=$1
OUTPUT_PATH=$2
LANGUAGE=$3
SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
if [ -z "$LANGUAGE" ]; then
LANGUAGE=python
fi
CMD="python $MAIN_DIR/codegeex/data/process_pretrain_dataset.py \
--dataset_path $DATASET_PATH \
--tokenizer_path $TOKENIZER_PATH \
--output_prefix $OUTPUT_PATH \
--language $LANGUAGE \
--mode pretrain \
--seq_len 2048"
echo "$CMD"
eval "$CMD"

@ -31,7 +31,7 @@ CMD="python $MAIN_DIR/tests/test_inference.py \
--out-seq-length 1024 \
--temperature 0.8 \
--top-p 0.95 \
--top-k 100 \
--top-k 0 \
--greedy \
$MODEL_ARGS"

@ -0,0 +1,46 @@
# This script is used to test the inference of CodeGeeX.
MP_SIZE=$1
SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
if [ -z "$MP_SIZE" ]; then
MP_SIZE=1
fi
if [ "$MP_SIZE" -eq 1 ]; then
source "$MAIN_DIR/configs/codegeex_13b.sh"
echo "Load config from $MAIN_DIR/configs/codegeex_13b.sh"
else
source "$MAIN_DIR/configs/codegeex_13b_parallel.sh"
echo "Load config from $MAIN_DIR/configs/codegeex_13b_parallel.sh"
fi
# export CUDA settings
export CUDA_HOME=/usr/local/cuda-11.1/
# export CUDA_VISIBLE_DEVICES=0,1
if [ -z "$PROMPT_FILE" ]; then
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
fi
# remove --greedy if using sampling
CMD="torchrun --nproc_per_node $MP_SIZE $MAIN_DIR/tests/test_inference_megatron.py \
--tensor-model-parallel-size $MP_SIZE \
--prompt-file $PROMPT_FILE \
--tokenizer-path $TOKENIZER_PATH \
--micro-batch-size 1 \
--out-seq-length 1024 \
--temperature 0.8 \
--top-p 0.95 \
--top-k 0 \
--greedy \
--use-cpu-initialization \
--ln-fp16 \
$MODEL_ARGS"
echo "$CMD"
eval "$CMD"

@ -0,0 +1,39 @@
# This script is used to test the inference of CodeGeeX.
GPU=$1
PROMPT_FILE=$2
SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"
# export CUDA settings
if [ -z "$GPU" ]; then
GPU=0
fi
export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=$GPU
if [ -z "$PROMPT_FILE" ]; then
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
fi
# remove --greedy if using sampling
CMD="python $MAIN_DIR/tests/test_inference.py \
--prompt-file $PROMPT_FILE \
--tokenizer-path $TOKENIZER_PATH \
--micro-batch-size 1 \
--out-seq-length 1024 \
--temperature 0.2 \
--top-p 0.95 \
--top-k 0 \
--quantize \
$MODEL_ARGS"
echo "$CMD"
eval "$CMD"

@ -1,40 +1,59 @@
import os
import copy
import time
import torch
import random
import argparse
import numpy as np
from codegeex.megatron import get_tokenizer, get_args
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.code_generation_utils import get_token_stream
torch.set_printoptions(precision=8)
def set_random_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
from codegeex.torch.inference import get_token_stream
from codegeex.torch import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def model_provider():
def model_provider(args):
"""Build the model."""
model = CodeGeeXModel(num_tokentypes=0,
parallel_output=False)
model = CodeGeeXModel(
args.hidden_size,
args.num_layers,
args.num_attention_heads,
args.padded_vocab_size,
args.max_position_embeddings
)
return model
def add_code_generation_args(parser):
"""Code generation arguments."""
group = parser.add_argument_group(title="code generation")
group.add_argument(
"--num-layers",
type=int,
default=39,
)
group.add_argument(
"--hidden-size",
type=int,
default=5120,
)
group.add_argument(
"--num-attention-heads",
type=int,
default=40,
)
group.add_argument(
"--padded-vocab-size",
type=int,
default=52224,
)
group.add_argument(
"--max-position-embeddings",
type=int,
default=2048,
)
group.add_argument(
"--temperature",
type=float,
@ -65,133 +84,118 @@ def add_code_generation_args(parser):
default=2048,
help="Size of the output generated text.",
)
group.add_argument(
"--recompute",
action="store_true",
help="During generation recompute all attention "
"instead of using previously computed keys/values.",
)
group.add_argument(
"--ws-encoding-start-id",
type=int,
default=10,
help="Start id for whitespace encoding",
)
group.add_argument(
"--ws-encoding-length",
type=int,
default=80,
help="Length of whitespace encoding",
)
group.add_argument(
"--n-generation",
type=int,
default=10,
)
group.add_argument(
"--eos-id",
type=int,
default=50256,
)
group.add_argument(
"--prompt-file",
type=str,
default="./test_prompt.txt",
)
group.add_argument(
"--perf-file",
"--tokenizer-path",
type=str,
default="./perf_out.txt",
default="./tokenizer",
)
group.add_argument(
"--perf-trace",
"--load",
type=str,
default="./perf_out.txt",
)
group.add_argument(
"--use-torch-profile",
action="store_true",
"--state-dict-path",
type=str,
)
group.add_argument(
"--ln-fp32",
action="store_true",
"--micro-batch-size",
type=int,
default=1,
)
group.add_argument(
'--bad-ids',
nargs="*",
type=int,
default=None,
help='Identify the type of programming language to generate',
"--quantize",
action="store_true",
)
return parser
def main():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))
initialize_megatron(
extra_args_provider=add_code_generation_args,
)
args = get_args()
set_random_seed(args.seed)
parser = argparse.ArgumentParser()
parser = add_code_generation_args(parser)
args, _ = parser.parse_known_args()
print("Loading tokenizer ...")
tokenizer = get_tokenizer()
tokenizer = CodeGeeXTokenizer(
tokenizer_path=args.tokenizer_path,
mode="codegeex-13b")
print("Loading state dict ...")
state_dict = torch.load(args.load, map_location="cpu")
state_dict = state_dict["module"]
print("Building CodeGeeX model ...")
model = model_provider()
model = model_provider(args)
model.load_state_dict(state_dict)
model.eval()
if args.fp16 and args.ln_fp16:
model.half()
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="torch")
model.cuda()
with open(args.prompt_file, "r") as f:
prompt = f.readlines()
prompt = "".join(prompt)
print("Generating ...")
t0 = time.perf_counter()
for prompt in [prompt]:
tokens = tokenizer.tokenize(prompt)
print(tokens)
print("Current prompt:")
print(prompt)
n_token_prompt = len(tokens)
print("N_token_prompt:", n_token_prompt)
token_stream = get_token_stream(
model,
[copy.deepcopy(tokens) for _ in range(args.micro_batch_size)],
micro_batch_size=args.micro_batch_size,
bad_ids=args.bad_ids,
)
is_finished = [False for _ in range(args.micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(args.micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
generated_tokens[j]) >= args.out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
t1 = time.perf_counter()
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
print("================================= Generated code:")
print(generated_code)
t0 = time.perf_counter()
if all(is_finished):
break
times = {}
out_seq_lengths = [args.out_seq_length]
micro_batch_size = args.micro_batch_size
seq_length = args.max_position_embeddings
for out_seq_length in out_seq_lengths:
print(f"Generating with out_seq_len {out_seq_length}...")
times[out_seq_length] = []
for prompt in [prompt]:
t0 = time.perf_counter()
tokens = tokenizer.encode_code(prompt)
print(tokens)
print("Current prompt:")
print(prompt)
n_token_prompt = len(tokens)
print("N_token_prompt:", n_token_prompt)
token_stream = get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
topk=args.top_k,
topp=args.top_p,
temperature=args.temperature,
greedy=args.greedy,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(
generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
generated_code = "".join(generated_code)
t1 = time.perf_counter()
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
times[out_seq_length].append(t1 - t0)
print("================================= Generated code:")
print(generated_code)
if all(is_finished):
break
print(times)
for out_seq_length in times.keys():
print(out_seq_length, np.mean(times[out_seq_length]))
print("Generation finished.")

@ -0,0 +1,209 @@
import copy
import time
import torch
import numpy as np
from codegeex.megatron import get_tokenizer, get_args, print_rank_0
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.code_generation_utils import get_token_stream
from codegeex.quantization import quantize
from codegeex.megatron.training import get_model
from codegeex.megatron.checkpointing import load_checkpoint
torch.set_printoptions(precision=8)
def set_random_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("Building CodeGeeX model ...")
model = CodeGeeXModel(num_tokentypes=0,
parallel_output=False)
return model
def add_code_generation_args(parser):
"""Code generation arguments."""
group = parser.add_argument_group(title="code generation")
group.add_argument(
"--temperature",
type=float,
default=1.0,
help="Sampling temperature.",
)
group.add_argument(
"--greedy",
action="store_true",
default=False,
help="Use greedy sampling.",
)
group.add_argument(
"--top-p",
type=float,
default=0.0,
help="Top p sampling.",
)
group.add_argument(
"--top-k",
type=int,
default=0,
help="Top k sampling.",
)
group.add_argument(
"--out-seq-length",
type=int,
default=2048,
help="Size of the output generated text.",
)
group.add_argument(
"--recompute",
action="store_true",
help="During generation recompute all attention "
"instead of using previously computed keys/values.",
)
group.add_argument(
"--ws-encoding-start-id",
type=int,
default=10,
help="Start id for whitespace encoding",
)
group.add_argument(
"--ws-encoding-length",
type=int,
default=10,
help="Length of whitespace encoding",
)
group.add_argument(
"--n-generation",
type=int,
default=10,
)
group.add_argument(
"--eos-id",
type=int,
default=50256,
)
group.add_argument(
"--prompt-file",
type=str,
default="./test_prompt.txt",
)
group.add_argument(
"--perf-file",
type=str,
default="./perf_out.txt",
)
group.add_argument(
"--perf-trace",
type=str,
default="./perf_out.txt",
)
group.add_argument(
"--use-torch-profile",
action="store_true",
)
group.add_argument(
"--ln-fp32",
action="store_true",
)
group.add_argument(
'--bad-ids',
nargs="*",
type=int,
default=None,
help='Token ids to suppress during generation (their logits are set to a large negative value).',
)
group.add_argument(
"--quantize",
action="store_true",
)
return parser
def main():
initialize_megatron(
extra_args_provider=add_code_generation_args,
args_defaults={
'no_load_rng': True,
'no_load_optim': True,
}
)
args = get_args()
set_random_seed(args.seed)
print_rank_0("Loading tokenizer ...")
tokenizer = get_tokenizer()
print_rank_0("Loading state dict ...")
model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
model.eval()
if args.fp16 and args.ln_fp16:
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="megatron")
with open(args.prompt_file, "r") as f:
prompt = f.readlines()
prompt = "".join(prompt)
print_rank_0("Generating ...")
t0 = time.perf_counter()
for prompt in [prompt]:
tokens = tokenizer.tokenize(prompt)
print_rank_0(tokens)
print_rank_0("Current prompt:")
print_rank_0(prompt)
n_token_prompt = len(tokens)
print_rank_0(f"N_token_prompt: {n_token_prompt}")
token_stream = get_token_stream(
model,
[copy.deepcopy(tokens) for _ in range(args.micro_batch_size)],
micro_batch_size=args.micro_batch_size,
bad_ids=args.bad_ids,
)
is_finished = [False for _ in range(args.micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(args.micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
generated_tokens[j]) >= args.out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
t1 = time.perf_counter()
print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}")
print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
print_rank_0("================================= Generated code:")
print_rank_0(generated_code)
t0 = time.perf_counter()
if all(is_finished):
break
print_rank_0("Generation finished.")
if __name__ == "__main__":
main()