mirror of https://github.com/THUDM/CodeGeeX.git
Merge pull request #27 from THUDM/develop
Add quantization, parallelism, data processing, and other minor changes.
commit 593ef9e231
@@ -0,0 +1,100 @@
import gzip
import json
import os

from typing import *


LANGUAGE_TAG = {
    "c++"          : "// language: C++",
    "cpp"          : "// language: C++",
    "c"            : "// language: C",
    "c#"           : "// language: C#",
    "cuda"         : "// language: Cuda",
    "objective-c"  : "// language: Objective-C",
    "objective-c++": "// language: Objective-C++",
    "python"       : "# language: Python",
    "java"         : "// language: Java",
    "scala"        : "// language: Scala",
    "tex"          : "% language: TeX",
    "html"         : "<!--language: HTML-->",
    "php"          : "// language: PHP",
    "js"           : "// language: JavaScript",
    "javascript"   : "// language: JavaScript",
    "typescript"   : "// language: TypeScript",
    "go"           : "// language: Go",
    "shell"        : "# language: Shell",
    "rust"         : "// language: Rust",
    "css"          : "/* language: CSS */",
    "sql"          : "-- language: SQL",
    "kotlin"       : "// language: Kotlin",
    "pascal"       : "// language: Pascal",
    "r"            : "# language: R",
    "fortran"      : "!language: Fortran",
    "lean"         : "-- language: Lean",
}


def stream_jsonl(filename: str) -> Iterable[Dict]:
    """
    Parses each jsonl line and yields it as a dictionary.
    """
    if filename.endswith(".gz"):
        with open(filename, "rb") as gzfp:
            with gzip.open(gzfp, "rt") as fp:
                for line in fp:
                    if any(not x.isspace() for x in line):
                        yield json.loads(line)
    else:
        with open(filename, "r") as fp:
            for line in fp:
                if any(not x.isspace() for x in line):
                    yield json.loads(line)


def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
    """
    Writes an iterable of dictionaries to jsonl.
    """
    if append:
        mode = "ab"
    else:
        mode = "wb"
    filename = os.path.expanduser(filename)
    if filename.endswith(".gz"):
        with open(filename, mode) as fp:
            with gzip.GzipFile(fileobj=fp, mode="wb") as gzfp:
                for x in data:
                    gzfp.write((json.dumps(x) + "\n").encode("utf-8"))
    else:
        with open(filename, mode) as fp:
            for x in data:
                fp.write((json.dumps(x) + "\n").encode("utf-8"))


def sliding_window(
    prompt_tokens: list,
    code_tokens: list,
    seq_len: int,
    sliding_stride: int,
    minimum_code_len: int = 1,
) -> Iterable[Tuple[list, list]]:
    """
    Generate a series of (prompt, code) pairs by sliding the window over the code.
    """
    prompt_len = len(prompt_tokens)
    code_len = len(code_tokens)
    total_len = prompt_len + code_len

    start_idx = max(0, prompt_len - seq_len + minimum_code_len)  # at least `minimum_code_len` code tokens should be in the window
    end_idx = max(0, total_len - seq_len)
    start_idx = min(start_idx, end_idx)

    for i in range(start_idx, end_idx + 1, sliding_stride):
        current_prompt = prompt_tokens[i:i + seq_len]
        current_code = code_tokens[max(i - prompt_len, 0):i - prompt_len + seq_len]
        yield current_prompt, current_code

    if (end_idx - start_idx) % sliding_stride != 0:
        # also emit the final window when the stride does not land exactly on end_idx
        current_prompt = prompt_tokens[end_idx:end_idx + seq_len]
        current_code = code_tokens[max(end_idx - prompt_len, 0):end_idx - prompt_len + seq_len]
        yield current_prompt, current_code
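# A minimal usage sketch of `sliding_window` (toy token ids, not real tokenizer
# output): each yielded pair covers at most `seq_len` tokens of the concatenated
# prompt + code, shifted by `sliding_stride`.
if __name__ == "__main__":
    prompt_tokens = list(range(10))        # 10 "prompt" tokens
    code_tokens = list(range(100, 130))    # 30 "code" tokens
    for p, c in sliding_window(prompt_tokens, code_tokens, seq_len=16, sliding_stride=8):
        print(len(p), len(c))              # prints: 10 6, 2 14, 0 16, 0 16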
@@ -0,0 +1,144 @@
|
||||
import os
|
||||
import glob
|
||||
import fire
|
||||
import torch
|
||||
import multiprocessing
|
||||
|
||||
from typing import *
|
||||
from tqdm.auto import tqdm
|
||||
from time import perf_counter
|
||||
from black import format_str, FileMode
|
||||
|
||||
from codegeex.data.types import PromptDataset, PromptSample
|
||||
from codegeex.data.processor import PromptDatasetProcessor
|
||||
from codegeex.data.data_utils import stream_jsonl, LANGUAGE_TAG
|
||||
from codegeex.megatron.data.indexed_dataset import make_mmap_builder
|
||||
from codegeex.tokenizer import CodeGeeXTokenizer
|
||||
|
||||
|
||||
def try_format_code(code: str):
|
||||
# Auto-correct to PEP8 format (Change tab to 4-whitespaces;
|
||||
# add whitespace around some special symbols;
|
||||
# reformat line length < 100, etc.)
|
||||
try:
|
||||
res = format_str(code, mode=FileMode(line_length=200))
|
||||
except Exception as e:
|
||||
res = code
|
||||
print(e)
|
||||
print("Wrong python format: {}".format(code))
|
||||
return res
|
||||
|
||||
|
||||
def load_pretrain_dataset(dataset_path: Union[str, List[str]]) -> Dict:
|
||||
if type(dataset_path) is str:
|
||||
dataset_path = [dataset_path]
|
||||
|
||||
for p in dataset_path:
|
||||
if not os.path.isdir(p):
|
||||
if p.endswith(".gz") or p.endswith(".jsonl"):
|
||||
print(f"loading from {p}")
|
||||
yield from stream_jsonl(p)
|
||||
else:
|
||||
p_list = glob.glob(p + "/*")
|
||||
for p_ in p_list:
|
||||
if p_.endswith(".gz") or p_.endswith(".jsonl"):
|
||||
print(f"loading from {p_}")
|
||||
yield from stream_jsonl(p_)
|
||||
|
||||
|
||||
def process_sample(
|
||||
sample: Dict,
|
||||
language: str=None,
|
||||
mode: str="pretrain",
|
||||
) -> Iterable[PromptSample]:
|
||||
if mode == "pretrain":
|
||||
prompt = ""
|
||||
else:
|
||||
prompt = sample["prompt"]
|
||||
|
||||
try:
|
||||
if language is not None and language in LANGUAGE_TAG.keys():
|
||||
code = LANGUAGE_TAG[language] + sample["code"]
|
||||
else:
|
||||
code = sample["code"]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("The key 'code' is missing in data. Aborted")
|
||||
exit(0)
|
||||
|
||||
yield PromptSample(prompt, code)
|
||||
|
||||
|
||||
def generate_prompt_samples(
|
||||
dataset: Iterable[Dict],
|
||||
language: str = None,
|
||||
mode: str = "pretrain",
|
||||
) -> PromptDataset:
|
||||
for sample in dataset:
|
||||
yield from process_sample(sample, language, mode)
|
||||
|
||||
|
||||
def main(
|
||||
tokenizer_path: str,
|
||||
dataset_path: Union[str, List[str]],
|
||||
output_prefix: str,
|
||||
language: str = None,
|
||||
mode: str = "pretrain",
|
||||
discard_overlong: bool = False,
|
||||
sliding_stride: int = 200,
|
||||
num_workers: int = 32,
|
||||
seq_len: int = 2048,
|
||||
):
|
||||
DATA_KEYS = ["input_ids", "attention_mask", "labels"]
|
||||
|
||||
# create output dir
|
||||
os.makedirs(os.path.dirname(output_prefix), exist_ok=True)
|
||||
|
||||
tokenizer = CodeGeeXTokenizer(tokenizer_path=tokenizer_path)
|
||||
pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
dataset = load_pretrain_dataset(dataset_path)
|
||||
prompt_dataset = generate_prompt_samples(dataset, language=language, mode=mode)
|
||||
|
||||
if num_workers == 0:
|
||||
num_workers = multiprocessing.cpu_count()
|
||||
pool = multiprocessing.Pool(num_workers)
|
||||
output_bin_files = {}
|
||||
output_idx_files = {}
|
||||
builders = {}
|
||||
|
||||
for key in DATA_KEYS:
|
||||
output_bin_files[key] = "{}_{}.bin".format(output_prefix, key)
|
||||
output_idx_files[key] = "{}_{}.idx".format(output_prefix, key)
|
||||
builders[key] = make_mmap_builder(
|
||||
output_bin_files[key],
|
||||
vocab_size=None, # magic number, should change it
|
||||
)
|
||||
|
||||
# NOTE that we use seq_len + 1 instead of seq_len, since the input tokens will be shifted by one.
|
||||
processor = PromptDatasetProcessor(
|
||||
tokenize=tokenizer.encode_code,
|
||||
pad_token=pad_token_id,
|
||||
max_seq_len=seq_len + 1,
|
||||
discard_overlong=discard_overlong,
|
||||
sliding_stride=sliding_stride,
|
||||
eod_token=pad_token_id)
|
||||
|
||||
processor.start_time = perf_counter()
|
||||
doc_iter = pool.imap_unordered(processor.process_sample_strict,
|
||||
prompt_dataset,
|
||||
chunksize=20)
|
||||
|
||||
for doc_idx, docs in tqdm(enumerate(doc_iter, start=1)):
|
||||
processor.doc_processed += 1
|
||||
for doc in docs:
|
||||
processor.doc_generated += 1
|
||||
for key in DATA_KEYS:
|
||||
builders[key].add_item(torch.IntTensor(doc[key]))
|
||||
|
||||
for key in DATA_KEYS:
|
||||
builders[key].finalize(output_idx_files[key])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
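# A hedged example invocation of this preprocessing script (the script filename
# and all paths below are placeholders, not part of this PR). fire.Fire(main)
# exposes each keyword argument of `main` as a --flag of the same name, e.g.:
#
#   python <this_script>.py \
#       --tokenizer_path <path/to/tokenizer> \
#       --dataset_path <path/to/code_corpus_dir> \
#       --output_prefix <path/to/output/codegeex> \
#       --language python \
#       --mode pretrain \
#       --seq_len 2048
#
# which would write <output_prefix>_input_ids.bin/.idx, _attention_mask.bin/.idx
# and _labels.bin/.idx through the mmap builders created in main().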
@@ -0,0 +1,155 @@
from typing import *
from time import perf_counter

from codegeex.data.data_utils import sliding_window
from codegeex.data.types import PromptSample, LabelSample


class PromptDatasetProcessor(object):
    def __init__(
        self,
        tokenize: Callable,
        pad_token: int,
        keep_order: bool = False,
        max_seq_len: int = 2048,
        sliding_stride: int = 200,
        discard_overlong: bool = True,
        eod_token: int = None,
        preprocess: Callable = None,
    ):
        super(PromptDatasetProcessor, self).__init__()
        self._keep_order = keep_order
        self._max_seq_len = max_seq_len
        self._sliding_stride = sliding_stride
        self._tokenize = tokenize
        self._pad_token = pad_token
        self._discard_overlong = discard_overlong
        self._eod_token = eod_token
        self._preprocess = preprocess

        self.doc_processed = 0
        self.doc_generated = 0
        self.start_time = 0

    def pad_seq(self, prompt_tokens: List[int], code_tokens: List[int], extra: dict = None) -> Dict[str, List[int]]:
        total_length = len(prompt_tokens) + len(code_tokens)
        assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}"
        pad_len = self._max_seq_len - total_length
        input_ids = prompt_tokens + code_tokens + [self._pad_token] * pad_len
        attention_mask = [1] * len(prompt_tokens) + [1] * len(code_tokens) + [0] * pad_len
        labels = [-100] * len(prompt_tokens) + code_tokens + [-100] * pad_len

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def process_sample(self, sample: PromptSample) -> Iterable[Dict[str, List[int]]]:
        """
        Process a sample.
        """
        prompt_tokens = self._tokenize(sample.prompt)
        code_tokens = self._tokenize(sample.code)

        if self._eod_token is not None:
            code_tokens.append(self._eod_token)

        if len(prompt_tokens) + len(code_tokens) > self._max_seq_len:
            if self._discard_overlong:
                return
            for p, t in sliding_window(prompt_tokens, code_tokens, self._max_seq_len, self._sliding_stride, self._sliding_stride):
                yield self.pad_seq(p, t)
        else:
            yield self.pad_seq(prompt_tokens, code_tokens, extra=sample.extra)

    def process_sample_strict(self, sample: PromptSample) -> List[Dict[str, List[int]]]:
        """
        Instead of processing lazily, we turn the iterable into a list.
        """
        return list(self.process_sample(sample))

    def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
        prompt_sample = self._preprocess(sample)
        return self.process_sample_strict(prompt_sample)

    def report(self):
        duration = perf_counter() - self.start_time
        process_speed = self.doc_processed * 1.0 / duration
        gen_speed = self.doc_generated * 1.0 / duration
        print(f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s")
        print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s")


class LabelDatasetProcessor(object):
    def __init__(
        self,
        tokenize: Callable,
        pad_token: int,
        keep_order: bool = False,
        max_seq_len: int = 2048,
        sliding_stride: int = 200,
        discard_overlong: bool = True,
        eod_token: int = None,
        preprocess: Callable = None,
    ):
        super(LabelDatasetProcessor, self).__init__()
        self._keep_order = keep_order
        self._max_seq_len = max_seq_len
        self._sliding_stride = sliding_stride
        self._tokenize = tokenize
        self._pad_token = pad_token
        self._discard_overlong = discard_overlong
        self._eod_token = eod_token
        self._preprocess = preprocess

        self.doc_processed = 0
        self.doc_generated = 0
        self.start_time = 0

    def pad_seq(self, prompt_tokens: List[int], label: int, extra: dict = None) -> Dict[str, List[int]]:
        total_length = len(prompt_tokens)
        assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}"
        pad_len = self._max_seq_len - total_length
        input_ids = prompt_tokens + [self._pad_token] * pad_len
        attention_mask = [1] * len(prompt_tokens) + [0] * pad_len
        label = [label]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "length": [len(prompt_tokens)],
            "labels": label,
        }

    def process_sample(self, sample: LabelSample) -> Iterable[Dict[str, List[int]]]:
        """
        Process a sample.
        """
        prompt_tokens = self._tokenize(sample.prompt)
        label = sample.label

        if len(prompt_tokens) > self._max_seq_len:
            if self._discard_overlong:
                return
            # keep the most recent tokens if the prompt is too long
            prompt_tokens = prompt_tokens[-self._max_seq_len:]

        yield self.pad_seq(prompt_tokens, label, extra=sample.extra)

    def process_sample_strict(self, sample: LabelSample) -> List[Dict[str, List[int]]]:
        """
        Instead of processing lazily, we turn the iterable into a list.
        """
        return list(self.process_sample(sample))

    def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
        prompt_sample = self._preprocess(sample)
        return self.process_sample_strict(prompt_sample)

    def report(self):
        duration = perf_counter() - self.start_time
        process_speed = self.doc_processed * 1.0 / duration
        gen_speed = self.doc_generated * 1.0 / duration
        print(f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s")
        print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s")
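# A hedged usage sketch of PromptDatasetProcessor with a toy whitespace
# "tokenizer" (token id = word length), just to show the shape of the output;
# the real pipeline passes CodeGeeXTokenizer.encode_code as `tokenize`.
if __name__ == "__main__":
    proc = PromptDatasetProcessor(
        tokenize=lambda text: [len(w) for w in text.split()],
        pad_token=0,
        max_seq_len=16,
        sliding_stride=4,
        discard_overlong=False,
        eod_token=0,
    )
    sample = PromptSample(prompt="# add two numbers", code="def add(a, b):\n    return a + b")
    for item in proc.process_sample(sample):
        # input_ids / attention_mask / labels are all padded to max_seq_len
        print({k: len(v) for k, v in item.items()})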
@@ -0,0 +1,20 @@
from typing import *
from dataclasses import dataclass


@dataclass
class PromptSample:
    prompt: str
    code: str
    extra: dict = None


PromptDataset = Iterable[PromptSample]

@dataclass
class LabelSample:
    prompt: str
    label: int
    extra: dict = None

LabelDataset = Iterable[LabelSample]
@@ -0,0 +1,99 @@
import pkg_resources
import torch
import ctypes

from typing import List
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up

RESOURCE_PACKAGE_NAME = __name__


class Kernel:
    def __init__(self, filename: str, function_names: List[str]):
        filename = filename + ".fatbin"
        if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
            raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
        self.filename = filename
        self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
        self._function_names = function_names
        self._cmodule = LazyKernelCModule(self.code)

        for name in self._function_names:
            setattr(self, name, KernelFunction(self._cmodule, name))


kernels = Kernel(
    "quantization",
    [
        "int4WeightCompression",
        "int4WeightExtractionFloat",
        "int4WeightExtractionHalf",
        "int8WeightExtractionFloat",
        "int8WeightExtractionHalf",
    ],
)


def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        assert m % 2 == 0
        m = m // 2
        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        kernels.int4WeightCompression(
            gridDim,
            blockDim,
            0,
            stream,
            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
        )
        return out


def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    if source_bit_width == 8:
        func = kernels.int8WeightExtractionHalf
    elif source_bit_width == 4:
        func = kernels.int4WeightExtractionHalf
    else:
        assert False, "Unsupported bit-width"

    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out


if __name__ == "__main__":
    weight = torch.randn(4, 32).to(torch.int8).cuda()
    scale = torch.ones(weight.size(0)).to(torch.half).cuda()

    print(weight)
    b = compress_int4_weight(weight)
    print(b)

    a = extract_weight_to_half(b, scale, source_bit_width=4)
    print(a)
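# Hedged reference sketch (an assumption about the kernel semantics, for
# illustration only -- the tensor shapes come from this file, the formula does
# not): int8WeightExtractionHalf is expected to behave like a per-row rescale
# of the int8 weights, and the int4 variants pack/unpack two 4-bit values per
# byte, which is why compress_int4_weight halves the second dimension.
def reference_extract_int8_to_half(weight, scale_list):
    # weight: (n, m) int8, scale_list: (n,) half -> (n, m) half
    return weight.half() * scale_list.half()[:, None]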
Binary file not shown.
@@ -0,0 +1,145 @@
|
||||
"""Get model parallel partitions."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
os.path.pardir)))
|
||||
|
||||
from codegeex.megatron import get_args
|
||||
from codegeex.megatron.model import CodeGeeXModel
|
||||
from codegeex.megatron.initialize import initialize_megatron
|
||||
from codegeex.megatron.checkpointing import ensure_directory_exists
|
||||
|
||||
|
||||
def get_change_ckpt_args(parser):
|
||||
"""Provide extra arguments required for merging."""
|
||||
group = parser.add_argument_group(title='Mindspore to megatron')
|
||||
group.add_argument(
|
||||
'--load-ckpt-path',
|
||||
type=str,
|
||||
required=True,
|
||||
help='path to load ".pt" checkpoint.',
|
||||
)
|
||||
group.add_argument(
|
||||
'--save-ckpt-path',
|
||||
type=str,
|
||||
required=True,
|
||||
help='dir to save converted checkpoints.',
|
||||
)
|
||||
group.add_argument(
|
||||
'--target-tensor-model-parallel-size',
|
||||
type=int,
|
||||
default=2,
|
||||
help='target tensor model parallel size',
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_element_from_dict_by_path(d, path):
|
||||
"""
|
||||
Get element from dictionary by path. If element is not present, recursively add empty dictionaries.
|
||||
Args:
|
||||
d (dict): the dictionary to get the element from
|
||||
path (list): the path to the element which is delimited by "."
|
||||
"""
|
||||
path = path.split(".")
|
||||
for k in path:
|
||||
if k not in d:
|
||||
d[k] = {}
|
||||
d = d[k]
|
||||
return d
|
||||
|
||||
|
||||
def main():
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(random.randint(10000, 20000))
|
||||
|
||||
initialize_megatron(
|
||||
extra_args_provider=get_change_ckpt_args,
|
||||
args_defaults={
|
||||
"tokenizer_type": "GPT2BPETokenizer",
|
||||
"no_load_rng" : True,
|
||||
"no_load_optim" : True,
|
||||
},
|
||||
)
|
||||
|
||||
args = get_args()
|
||||
print(f"Load ckpt from {args.load_ckpt_path}...")
|
||||
state_dict = torch.load(args.load_ckpt_path, map_location="cpu")
|
||||
|
||||
print(f"Spliting ckpt into {args.target_tensor_model_parallel_size} parts...")
|
||||
output_state_dict = []
|
||||
for i in range(args.target_tensor_model_parallel_size):
|
||||
output_state_dict.append({})
|
||||
|
||||
print("Converting Embedding layers...")
|
||||
word_embeddings = state_dict['module']['language_model']['embedding']['word_embeddings']['weight']
|
||||
position_embeddings = state_dict['module']['language_model']['embedding']['position_embeddings']['weight']
|
||||
out_word_embeddings = torch.chunk(word_embeddings, args.target_tensor_model_parallel_size, dim=0)
|
||||
|
||||
for i in range(args.target_tensor_model_parallel_size):
|
||||
pos_emb_dict = get_element_from_dict_by_path(
|
||||
output_state_dict[i], "module.language_model.embedding.position_embeddings"
|
||||
)
|
||||
pos_emb_dict["weight"] = position_embeddings
|
||||
|
||||
word_emb_dict = get_element_from_dict_by_path(
|
||||
output_state_dict[i], "module.language_model.embedding.word_embeddings"
|
||||
)
|
||||
word_emb_dict["weight"] = out_word_embeddings[i]
|
||||
|
||||
print("Converting QueryEmbedding layers...")
|
||||
query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
|
||||
out_query_embeddings = torch.chunk(query_embeddings, args.target_tensor_model_parallel_size, dim=0)
|
||||
|
||||
for i in range(args.target_tensor_model_parallel_size):
|
||||
query_emb_dict = get_element_from_dict_by_path(
|
||||
output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
|
||||
)
|
||||
query_emb_dict["weight"] = out_query_embeddings[i]
|
||||
|
||||
print("Converting Transformer layers...")
|
||||
for layer_name in state_dict['module']['language_model']['transformer'].keys():
|
||||
params = state_dict['module']['language_model']['transformer'][layer_name]
|
||||
if "layernorm" in layer_name:
|
||||
pass
|
||||
elif "attention" in layer_name and "weight" in layer_name:
|
||||
if "dense" in layer_name:
|
||||
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1)
|
||||
else:
|
||||
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
|
||||
elif "weight" in layer_name and "dense" in layer_name:
|
||||
if "h_to_4h" in layer_name:
|
||||
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
|
||||
else:
|
||||
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1)
|
||||
elif "bias" in layer_name:
|
||||
if "dense" not in layer_name or "mlp" in layer_name:
|
||||
if "4h_to_h" in layer_name:
|
||||
pass
|
||||
else:
|
||||
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0)
|
||||
|
||||
for i in range(args.target_tensor_model_parallel_size):
|
||||
params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
|
||||
if type(params) is tuple:
|
||||
params_dict[layer_name] = params[i]
|
||||
else:
|
||||
params_dict[layer_name] = params
|
||||
|
||||
os.makedirs(args.save_ckpt_path, exist_ok=True)
|
||||
for rank in range(args.target_tensor_model_parallel_size):
|
||||
save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
|
||||
torch.save(output_state_dict[rank], save_ckpt_path)
|
||||
print(f"Converted checkpoint saved in {save_ckpt_path}.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
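# Hedged illustration of get_element_from_dict_by_path (toy values, not from
# this PR): missing keys are created as empty dicts on the way down, so tensors
# can be assigned into the nested output state dict without pre-building it.
#
#     d = {}
#     leaf = get_element_from_dict_by_path(d, "module.language_model.embedding.word_embeddings")
#     leaf["weight"] = torch.zeros(2, 2)
#     # d == {"module": {"language_model": {"embedding": {"word_embeddings": {"weight": <2x2 tensor>}}}}}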
@@ -0,0 +1,69 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blendable dataset."""

import time
import torch
import numpy as np

from codegeex.megatron import print_rank_0


class BlendableDataset(torch.utils.data.Dataset):
    def __init__(self, datasets, weights):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = 0
        for dataset in self.datasets:
            self.size += len(dataset)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        from megatron.data import helpers

        helpers.build_blending_indices(
            self.dataset_index,
            self.dataset_sample_index,
            weights,
            num_datasets,
            self.size,
            torch.distributed.get_rank() == 0,
        )
        print_rank_0(
            "> elapsed time for building blendable dataset indices: "
            "{:.2f} (sec)".format(time.time() - start_time)
        )

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return self.datasets[dataset_idx][sample_idx]
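# A hedged pure-Python sketch of what helpers.build_blending_indices computes
# (it mirrors the C++ helper added later in this PR): at every step, draw the
# next sample from the dataset whose achieved count lags its weighted target
# the most. The function name is illustrative, not part of the codebase.
def _blending_indices_py(weights, size):
    import numpy as np
    weights = np.asarray(weights, dtype=np.float64)
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current = np.zeros(len(weights), dtype=np.int64)
    for i in range(size):
        errors = weights * max(i, 1) - current   # deficit w.r.t. the target ratio
        d = int(np.argmax(errors))
        dataset_index[i] = d
        dataset_sample_index[i] = current[d]
        current[d] += 1
    return dataset_index, dataset_sample_index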
@@ -0,0 +1,185 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataloaders."""
|
||||
|
||||
|
||||
import torch
|
||||
from codegeex.megatron import get_args
|
||||
from codegeex.megatron import mpu
|
||||
|
||||
|
||||
def build_pretraining_data_loader(dataset, consumed_samples):
|
||||
"""Buld dataloader given an input dataset."""
|
||||
|
||||
if dataset is None:
|
||||
return None
|
||||
args = get_args()
|
||||
|
||||
# Megatron sampler
|
||||
if args.dataloader_type == "single":
|
||||
batch_sampler = MegatronPretrainingSampler(
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
data_parallel_rank=mpu.get_data_parallel_rank(),
|
||||
data_parallel_size=mpu.get_data_parallel_world_size(),
|
||||
)
|
||||
elif args.dataloader_type == "cyclic":
|
||||
batch_sampler = MegatronPretrainingRandomSampler(
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
data_parallel_rank=mpu.get_data_parallel_rank(),
|
||||
data_parallel_size=mpu.get_data_parallel_world_size(),
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"{} dataloader type is not supported.".format(args.dataloader_type)
|
||||
)
|
||||
|
||||
# Torch dataloader.
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
|
||||
class MegatronPretrainingSampler:
|
||||
def __init__(
|
||||
self,
|
||||
total_samples,
|
||||
consumed_samples,
|
||||
micro_batch_size,
|
||||
data_parallel_rank,
|
||||
data_parallel_size,
|
||||
drop_last=True,
|
||||
):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
self.consumed_samples = consumed_samples
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.micro_batch_times_data_parallel_size = (
|
||||
self.micro_batch_size * data_parallel_size
|
||||
)
|
||||
self.drop_last = drop_last
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, "no sample to consume: {}".format(
|
||||
self.total_samples
|
||||
)
|
||||
assert (
|
||||
self.consumed_samples < self.total_samples
|
||||
), "no samples left to consume: {}, {}".format(
|
||||
self.consumed_samples, self.total_samples
|
||||
)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert (
|
||||
self.data_parallel_rank < data_parallel_size
|
||||
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
|
||||
self.data_parallel_rank, data_parallel_size
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def get_start_end_idx(self):
|
||||
start_idx = self.data_parallel_rank * self.micro_batch_size
|
||||
end_idx = start_idx + self.micro_batch_size
|
||||
return start_idx, end_idx
|
||||
|
||||
def __iter__(self):
|
||||
batch = []
|
||||
# The last batch will be dropped unless drop_last is set to False
|
||||
for idx in range(self.consumed_samples, self.total_samples):
|
||||
batch.append(idx)
|
||||
if len(batch) == self.micro_batch_times_data_parallel_size:
|
||||
start_idx, end_idx = self.get_start_end_idx()
|
||||
yield batch[start_idx:end_idx]
|
||||
batch = []
|
||||
|
||||
# Check the last partial batch and see if drop_last is set
|
||||
if len(batch) > 0 and not self.drop_last:
|
||||
start_idx, end_idx = self.get_start_end_idx()
|
||||
yield batch[start_idx:end_idx]
|
||||
|
||||
|
||||
class MegatronPretrainingRandomSampler:
|
||||
def __init__(
|
||||
self,
|
||||
total_samples,
|
||||
consumed_samples,
|
||||
micro_batch_size,
|
||||
data_parallel_rank,
|
||||
data_parallel_size,
|
||||
):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
self.consumed_samples = consumed_samples
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.data_parallel_size = data_parallel_size
|
||||
self.micro_batch_times_data_parallel_size = (
|
||||
self.micro_batch_size * data_parallel_size
|
||||
)
|
||||
self.last_batch_size = (
|
||||
self.total_samples % self.micro_batch_times_data_parallel_size
|
||||
)
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, "no sample to consume: {}".format(
|
||||
self.total_samples
|
||||
)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert (
|
||||
self.data_parallel_rank < data_parallel_size
|
||||
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
|
||||
self.data_parallel_rank, data_parallel_size
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def __iter__(self):
|
||||
active_total_samples = self.total_samples - self.last_batch_size
|
||||
self.epoch = self.consumed_samples // active_total_samples
|
||||
current_epoch_samples = self.consumed_samples % active_total_samples
|
||||
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
|
||||
|
||||
# data sharding and random sampling
|
||||
bucket_size = (
|
||||
self.total_samples // self.micro_batch_times_data_parallel_size
|
||||
) * self.micro_batch_size
|
||||
bucket_offset = current_epoch_samples // self.data_parallel_size
|
||||
start_idx = self.data_parallel_rank * bucket_size
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
random_idx = torch.randperm(bucket_size, generator=g).tolist()
|
||||
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
|
||||
|
||||
batch = []
|
||||
# Last batch if not complete will be dropped.
|
||||
for idx in idx_range:
|
||||
batch.append(idx)
|
||||
if len(batch) == self.micro_batch_size:
|
||||
self.consumed_samples += self.micro_batch_times_data_parallel_size
|
||||
yield batch
|
||||
batch = []
|
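# A hedged illustration of the "single" sampler (toy sizes): with
# micro_batch_size=2 and data_parallel_size=2, every global step consumes 4
# consecutive sample indices and each data-parallel rank keeps its own slice.
if __name__ == "__main__":
    ranks = [
        MegatronPretrainingSampler(
            total_samples=8,
            consumed_samples=0,
            micro_batch_size=2,
            data_parallel_rank=rank,
            data_parallel_size=2,
        )
        for rank in (0, 1)
    ]
    print(list(ranks[0]))  # [[0, 1], [4, 5]]
    print(list(ranks[1]))  # [[2, 3], [6, 7]]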
@@ -0,0 +1,566 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Most of the code here has been copied from:
|
||||
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
|
||||
# with some modifications.
|
||||
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from codegeex.megatron import mpu, print_rank_0
|
||||
from codegeex.megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
|
||||
|
||||
|
||||
def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
|
||||
|
||||
# The data prefix should be in the format of:
|
||||
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
|
||||
assert len(data_prefix) % 2 == 0
|
||||
num_datasets = len(data_prefix) // 2
|
||||
weights = [0] * num_datasets
|
||||
prefixes = [0] * num_datasets
|
||||
for i in range(num_datasets):
|
||||
weights[i] = float(data_prefix[2 * i])
|
||||
prefixes[i] = (data_prefix[2 * i + 1]).strip()
|
||||
# Normalize weights
|
||||
weight_sum = 0.0
|
||||
for weight in weights:
|
||||
weight_sum += weight
|
||||
assert weight_sum > 0.0
|
||||
weights = [weight / weight_sum for weight in weights]
|
||||
|
||||
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
|
||||
# not uniformly distribute the number of samples, we still have
|
||||
# samples left to feed to the network.
|
||||
datasets_train_valid_test_num_samples = []
|
||||
for weight in weights:
|
||||
datasets_train_valid_test_num_samples.append(
|
||||
[
|
||||
int(math.ceil(val * weight * 1.005))
|
||||
for val in train_valid_test_num_samples
|
||||
]
|
||||
)
|
||||
|
||||
return prefixes, weights, datasets_train_valid_test_num_samples
|
||||
|
||||
|
||||
def compile_helper():
|
||||
"""Compile helper function ar runtime. Make sure this
|
||||
is invoked on a single process."""
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
path = os.path.abspath(os.path.dirname(__file__))
|
||||
ret = subprocess.run(["make", "-C", path])
|
||||
if ret.returncode != 0:
|
||||
print("Making C++ dataset helpers module failed, exiting.")
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_a_and_b_segments(sample, np_rng):
|
||||
"""Divide sample into a and b segments."""
|
||||
|
||||
# Number of sentences in the sample.
|
||||
n_sentences = len(sample)
|
||||
# Make sure we always have two sentences.
|
||||
assert n_sentences > 1, "make sure each sample has at least two sentences."
|
||||
|
||||
# First part:
|
||||
# `a_end` is how many sentences go into the `A`.
|
||||
a_end = 1
|
||||
if n_sentences >= 3:
|
||||
# Note that randint in numpy is exclusive.
|
||||
a_end = np_rng.randint(1, n_sentences)
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(sample[j])
|
||||
|
||||
# Second part:
|
||||
tokens_b = []
|
||||
for j in range(a_end, n_sentences):
|
||||
tokens_b.extend(sample[j])
|
||||
|
||||
# Random next:
|
||||
is_next_random = False
|
||||
if np_rng.random() < 0.5:
|
||||
is_next_random = True
|
||||
tokens_a, tokens_b = tokens_b, tokens_a
|
||||
|
||||
return tokens_a, tokens_b, is_next_random
|
||||
|
||||
|
||||
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
# print(len_a, len_b, max_num_tokens)
|
||||
assert len_a > 0
|
||||
if len_a + len_b <= max_num_tokens:
|
||||
return False
|
||||
while len_a + len_b > max_num_tokens:
|
||||
if len_a > len_b:
|
||||
len_a -= 1
|
||||
tokens = tokens_a
|
||||
else:
|
||||
len_b -= 1
|
||||
tokens = tokens_b
|
||||
if np_rng.random() < 0.5:
|
||||
del tokens[0]
|
||||
else:
|
||||
tokens.pop()
|
||||
return True
|
||||
|
||||
|
||||
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
|
||||
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
|
||||
|
||||
tokens = []
|
||||
tokentypes = []
|
||||
# [CLS].
|
||||
tokens.append(cls_id)
|
||||
tokentypes.append(0)
|
||||
# Segment A.
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
tokentypes.append(0)
|
||||
# [SEP].
|
||||
tokens.append(sep_id)
|
||||
tokentypes.append(0)
|
||||
# Segment B.
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
tokentypes.append(1)
|
||||
if tokens_b:
|
||||
# [SEP].
|
||||
tokens.append(sep_id)
|
||||
tokentypes.append(1)
|
||||
|
||||
return tokens, tokentypes
|
||||
|
||||
|
||||
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
|
||||
|
||||
|
||||
def is_start_piece(piece):
|
||||
"""Check if the current word piece is the starting piece (BERT)."""
|
||||
# When a word has been split into
|
||||
# WordPieces, the first token does not have any marker and any subsequence
|
||||
# tokens are prefixed with ##. So whenever we see the ## token, we
|
||||
# append it to the previous set of word indexes.
|
||||
return not piece.startswith("##")
|
||||
|
||||
|
||||
def create_masked_lm_predictions(
|
||||
tokens,
|
||||
vocab_id_list,
|
||||
vocab_id_to_token_dict,
|
||||
masked_lm_prob,
|
||||
cls_id,
|
||||
sep_id,
|
||||
mask_id,
|
||||
max_predictions_per_seq,
|
||||
np_rng,
|
||||
max_ngrams=3,
|
||||
do_whole_word_mask=True,
|
||||
favor_longer_ngram=False,
|
||||
do_permutation=False,
|
||||
geometric_dist=False,
|
||||
masking_style="bert",
|
||||
):
|
||||
"""Creates the predictions for the masked LM objective.
|
||||
Note: Tokens here are vocab ids and not text tokens."""
|
||||
|
||||
cand_indexes = []
|
||||
# Note(mingdachen): We create a list for recording if the piece is
|
||||
# the starting piece of current token, where 1 means true, so that
|
||||
# on-the-fly whole word masking is possible.
|
||||
token_boundary = [0] * len(tokens)
|
||||
|
||||
for (i, token) in enumerate(tokens):
|
||||
if token == cls_id or token == sep_id:
|
||||
token_boundary[i] = 1
|
||||
continue
|
||||
# Whole Word Masking means that if we mask all of the wordpieces
|
||||
# corresponding to an original word.
|
||||
#
|
||||
# Note that Whole Word Masking does *not* change the training code
|
||||
# at all -- we still predict each WordPiece independently, softmaxed
|
||||
# over the entire vocabulary.
|
||||
if (
|
||||
do_whole_word_mask
|
||||
and len(cand_indexes) >= 1
|
||||
and not is_start_piece(vocab_id_to_token_dict[token])
|
||||
):
|
||||
cand_indexes[-1].append(i)
|
||||
else:
|
||||
cand_indexes.append([i])
|
||||
if is_start_piece(vocab_id_to_token_dict[token]):
|
||||
token_boundary[i] = 1
|
||||
|
||||
output_tokens = list(tokens)
|
||||
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
|
||||
if masked_lm_prob == 0:
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
|
||||
|
||||
num_to_predict = min(
|
||||
max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))
|
||||
)
|
||||
|
||||
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
|
||||
if not geometric_dist:
|
||||
# Note(mingdachen):
|
||||
# By default, we set the probabilities to favor shorter ngram sequences.
|
||||
pvals = 1.0 / np.arange(1, max_ngrams + 1)
|
||||
pvals /= pvals.sum(keepdims=True)
|
||||
if favor_longer_ngram:
|
||||
pvals = pvals[::-1]
|
||||
|
||||
ngram_indexes = []
|
||||
for idx in range(len(cand_indexes)):
|
||||
ngram_index = []
|
||||
for n in ngrams:
|
||||
ngram_index.append(cand_indexes[idx : idx + n])
|
||||
ngram_indexes.append(ngram_index)
|
||||
|
||||
np_rng.shuffle(ngram_indexes)
|
||||
|
||||
(masked_lms, masked_spans) = ([], [])
|
||||
covered_indexes = set()
|
||||
for cand_index_set in ngram_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
if not cand_index_set:
|
||||
continue
|
||||
# Note(mingdachen):
|
||||
# Skip current piece if they are covered in lm masking or previous ngrams.
|
||||
for index_set in cand_index_set[0]:
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
continue
|
||||
|
||||
if not geometric_dist:
|
||||
n = np_rng.choice(
|
||||
ngrams[: len(cand_index_set)],
|
||||
p=pvals[: len(cand_index_set)]
|
||||
/ pvals[: len(cand_index_set)].sum(keepdims=True),
|
||||
)
|
||||
else:
|
||||
# Sampling "n" from the geometric distribution and clipping it to
|
||||
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
|
||||
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
|
||||
n = min(np_rng.geometric(0.2), max_ngrams)
|
||||
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# Note(mingdachen):
|
||||
# Repeatedly looking for a candidate that does not exceed the
|
||||
# maximum number of predictions by trying shorter ngrams.
|
||||
while len(masked_lms) + len(index_set) > num_to_predict:
|
||||
if n == 0:
|
||||
break
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(masked_lms) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
covered_indexes.add(index)
|
||||
masked_token = None
|
||||
if masking_style == "bert":
|
||||
# 80% of the time, replace with [MASK]
|
||||
if np_rng.random() < 0.8:
|
||||
masked_token = mask_id
|
||||
else:
|
||||
# 10% of the time, keep original
|
||||
if np_rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
# 10% of the time, replace with random word
|
||||
else:
|
||||
masked_token = vocab_id_list[
|
||||
np_rng.randint(0, len(vocab_id_list))
|
||||
]
|
||||
elif masking_style == "t5":
|
||||
masked_token = mask_id
|
||||
else:
|
||||
raise ValueError("invalid value of masking style")
|
||||
|
||||
output_tokens[index] = masked_token
|
||||
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
|
||||
|
||||
masked_spans.append(
|
||||
MaskedLmInstance(
|
||||
index=index_set, label=[tokens[index] for index in index_set]
|
||||
)
|
||||
)
|
||||
|
||||
assert len(masked_lms) <= num_to_predict
|
||||
np_rng.shuffle(ngram_indexes)
|
||||
|
||||
select_indexes = set()
|
||||
if do_permutation:
|
||||
for cand_index_set in ngram_indexes:
|
||||
if len(select_indexes) >= num_to_predict:
|
||||
break
|
||||
if not cand_index_set:
|
||||
continue
|
||||
# Note(mingdachen):
|
||||
# Skip current piece if they are covered in lm masking or previous ngrams.
|
||||
for index_set in cand_index_set[0]:
|
||||
for index in index_set:
|
||||
if index in covered_indexes or index in select_indexes:
|
||||
continue
|
||||
|
||||
n = np.random.choice(
|
||||
ngrams[: len(cand_index_set)],
|
||||
p=pvals[: len(cand_index_set)]
|
||||
/ pvals[: len(cand_index_set)].sum(keepdims=True),
|
||||
)
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
|
||||
while len(select_indexes) + len(index_set) > num_to_predict:
|
||||
if n == 0:
|
||||
break
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(select_indexes) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes or index in select_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
select_indexes.add(index)
|
||||
assert len(select_indexes) <= num_to_predict
|
||||
|
||||
select_indexes = sorted(select_indexes)
|
||||
permute_indexes = list(select_indexes)
|
||||
np_rng.shuffle(permute_indexes)
|
||||
orig_token = list(output_tokens)
|
||||
|
||||
for src_i, tgt_i in zip(select_indexes, permute_indexes):
|
||||
output_tokens[src_i] = orig_token[tgt_i]
|
||||
masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
|
||||
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||
# Sort the spans by the index of the first span
|
||||
masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
|
||||
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p.index)
|
||||
masked_lm_labels.append(p.label)
|
||||
return (
|
||||
output_tokens,
|
||||
masked_lm_positions,
|
||||
masked_lm_labels,
|
||||
token_boundary,
|
||||
masked_spans,
|
||||
)
|
||||
|
||||
|
||||
def pad_and_convert_to_numpy(
|
||||
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
|
||||
):
|
||||
"""Pad sequences and convert them to numpy."""
|
||||
|
||||
# Some checks.
|
||||
num_tokens = len(tokens)
|
||||
padding_length = max_seq_length - num_tokens
|
||||
assert padding_length >= 0
|
||||
assert len(tokentypes) == num_tokens
|
||||
assert len(masked_positions) == len(masked_labels)
|
||||
|
||||
# Tokens and token types.
|
||||
filler = [pad_id] * padding_length
|
||||
tokens_np = np.array(tokens + filler, dtype=np.int64)
|
||||
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
|
||||
|
||||
# Padding mask.
|
||||
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)
|
||||
|
||||
# Labels and loss mask.
|
||||
labels = [-1] * max_seq_length
|
||||
loss_mask = [0] * max_seq_length
|
||||
for i in range(len(masked_positions)):
|
||||
assert masked_positions[i] < num_tokens
|
||||
labels[masked_positions[i]] = masked_labels[i]
|
||||
loss_mask[masked_positions[i]] = 1
|
||||
labels_np = np.array(labels, dtype=np.int64)
|
||||
loss_mask_np = np.array(loss_mask, dtype=np.int64)
|
||||
|
||||
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
|
||||
|
||||
|
||||
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
|
||||
|
||||
print_rank_0(" > building dataset index ...")
|
||||
|
||||
start_time = time.time()
|
||||
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
|
||||
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
|
||||
print_rank_0(
|
||||
" > finished creating indexed dataset in {:4f} "
|
||||
"seconds".format(time.time() - start_time)
|
||||
)
|
||||
|
||||
print_rank_0(" > indexed dataset stats:")
|
||||
print_rank_0(
|
||||
" number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1)
|
||||
)
|
||||
print_rank_0(" number of sentences: {}".format(indexed_dataset.sizes.shape[0]))
|
||||
|
||||
return indexed_dataset
|
||||
|
||||
|
||||
def get_train_valid_test_split_(splits_string, size):
|
||||
"""Get dataset splits from comma or '/' separated string list."""
|
||||
|
||||
splits = []
|
||||
if splits_string.find(",") != -1:
|
||||
splits = [float(s) for s in splits_string.split(",")]
|
||||
elif splits_string.find("/") != -1:
|
||||
splits = [float(s) for s in splits_string.split("/")]
|
||||
else:
|
||||
splits = [float(splits_string)]
|
||||
while len(splits) < 3:
|
||||
splits.append(0.0)
|
||||
splits = splits[:3]
|
||||
splits_sum = sum(splits)
|
||||
assert splits_sum > 0.0
|
||||
splits = [split / splits_sum for split in splits]
|
||||
splits_index = [0]
|
||||
for index, split in enumerate(splits):
|
||||
splits_index.append(splits_index[index] + int(round(split * float(size))))
|
||||
diff = splits_index[-1] - size
|
||||
for index in range(1, len(splits_index)):
|
||||
splits_index[index] -= diff
|
||||
assert len(splits_index) == 4
|
||||
assert splits_index[-1] == size
|
||||
return splits_index
|
||||
|
||||
|
||||
def get_samples_mapping(
|
||||
indexed_dataset,
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
max_seq_length,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
name,
|
||||
binary_head,
|
||||
):
|
||||
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
|
||||
|
||||
if not num_epochs:
|
||||
if not max_num_samples:
|
||||
raise ValueError("Need to specify either max_num_samples " "or num_epochs")
|
||||
num_epochs = np.iinfo(np.int32).max - 1
|
||||
if not max_num_samples:
|
||||
max_num_samples = np.iinfo(np.int64).max - 1
|
||||
|
||||
# Filename of the index mapping
|
||||
indexmap_filename = data_prefix
|
||||
indexmap_filename += "_{}_indexmap".format(name)
|
||||
if num_epochs != (np.iinfo(np.int32).max - 1):
|
||||
indexmap_filename += "_{}ep".format(num_epochs)
|
||||
if max_num_samples != (np.iinfo(np.int64).max - 1):
|
||||
indexmap_filename += "_{}mns".format(max_num_samples)
|
||||
indexmap_filename += "_{}msl".format(max_seq_length)
|
||||
indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob)
|
||||
indexmap_filename += "_{}s".format(seed)
|
||||
indexmap_filename += ".npy"
|
||||
|
||||
# Build the indexed mapping if not exist.
|
||||
if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
|
||||
print(
|
||||
" > WARNING: could not find index map file {}, building "
|
||||
"the indices on rank 0 ...".format(indexmap_filename)
|
||||
)
|
||||
|
||||
# Make sure the types match the helpers input types.
|
||||
assert indexed_dataset.doc_idx.dtype == np.int64
|
||||
assert indexed_dataset.sizes.dtype == np.int32
|
||||
|
||||
# Build samples mapping
|
||||
verbose = torch.distributed.get_rank() == 0
|
||||
start_time = time.time()
|
||||
print_rank_0(" > building sapmles index mapping for {} ...".format(name))
|
||||
# First compile and then import.
|
||||
from megatron.data import helpers
|
||||
|
||||
samples_mapping = helpers.build_mapping(
|
||||
indexed_dataset.doc_idx,
|
||||
indexed_dataset.sizes,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
max_seq_length,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
verbose,
|
||||
2 if binary_head else 1,
|
||||
)
|
||||
print_rank_0(" > done building sapmles index maping")
|
||||
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
|
||||
print_rank_0(" > saved the index mapping in {}".format(indexmap_filename))
|
||||
# Make sure all the ranks have built the mapping
|
||||
print_rank_0(
|
||||
" > elasped time to build and save samples mapping "
|
||||
"(seconds): {:4f}".format(time.time() - start_time)
|
||||
)
|
||||
# This should be a barrier but nccl barrier assumes
|
||||
# device_index=rank which is not the case for model
|
||||
# parallel case
|
||||
counts = torch.cuda.LongTensor([1])
|
||||
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
|
||||
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
|
||||
assert counts[0].item() == (
|
||||
torch.distributed.get_world_size()
|
||||
// torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())
|
||||
)
|
||||
|
||||
# Load indexed dataset.
|
||||
print_rank_0(" > loading indexed mapping from {}".format(indexmap_filename))
|
||||
start_time = time.time()
|
||||
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r")
|
||||
print_rank_0(
|
||||
" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
|
||||
)
|
||||
print_rank_0(" total number of samples: {}".format(samples_mapping.shape[0]))
|
||||
|
||||
return samples_mapping
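# A hedged illustration of get_train_valid_test_split_ (toy numbers): the
# split string is normalized to fractions and turned into cumulative index
# boundaries over `size` samples, e.g.
#
#     get_train_valid_test_split_("949,50,1", 1000)  ->  [0, 949, 999, 1000]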
|
@@ -0,0 +1,717 @@
|
||||
/*
|
||||
coding=utf-8
|
||||
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
/* Helper methods for fast index mapping builds */
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <math.h>
|
||||
#include <stdexcept>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <random>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace std;
|
||||
|
||||
const int32_t LONG_SENTENCE_LEN = 512;
|
||||
|
||||
|
||||
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
|
||||
py::array_t<int64_t>& dataset_sample_index,
|
||||
const py::array_t<double>& weights,
|
||||
const int32_t num_datasets,
|
||||
const int64_t size, const bool verbose) {
|
||||
/* Given multiple datasets and a weighting array, build samples
|
||||
such that it follows those weights.*/
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "> building indices for blendable datasets ..." << std::endl;
|
||||
}
|
||||
|
||||
// Get the pointer access without the checks.
|
||||
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
|
||||
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
|
||||
auto weights_ptr = weights.unchecked<1>();
|
||||
|
||||
// Initialize buffer for number of samples used for each dataset.
|
||||
int64_t current_samples[num_datasets];
|
||||
for(int64_t i = 0; i < num_datasets; ++i) {
|
||||
current_samples[i] = 0;
|
||||
}
|
||||
|
||||
// For each sample:
|
||||
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
|
||||
|
||||
// Determine where the max error in sampling is happening.
|
||||
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
|
||||
int64_t max_error_index = 0;
|
||||
double max_error = weights_ptr[0] * sample_idx_double -
|
||||
static_cast<double>(current_samples[0]);
|
||||
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
|
||||
double error = weights_ptr[dataset_idx] * sample_idx_double -
|
||||
static_cast<double>(current_samples[dataset_idx]);
|
||||
if (error > max_error) {
|
||||
max_error = error;
|
||||
max_error_index = dataset_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Populate the indices.
|
||||
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
|
||||
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
|
||||
|
||||
// Update the total samples.
|
||||
current_samples[max_error_index] += 1;
|
||||
|
||||
}
|
||||
|
||||
// print info
|
||||
if (verbose) {
|
||||
std::cout << " > sample ratios:" << std::endl;
|
||||
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
|
||||
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
|
||||
static_cast<double>(size);
|
||||
std::cout << " dataset " << dataset_idx << ", input: " <<
|
||||
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
|
||||
const py::array_t<int32_t>& doc_idx_,
|
||||
const int32_t seq_length,
|
||||
const int32_t num_epochs,
|
||||
const int64_t tokens_per_epoch) {
|
||||
/* Sample index (sample_idx) is used for gpt2 like dataset for which
|
||||
the documents are flattened and the samples are built based on this
|
||||
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
|
||||
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
|
||||
starting offset in that document.*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(seq_length > 1);
|
||||
assert(num_epochs > 0);
|
||||
assert(tokens_per_epoch > 1);
|
||||
|
||||
// Remove bound checks.
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
auto doc_idx = doc_idx_.unchecked<1>();
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
|
||||
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
|
||||
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " <<
|
||||
doc_idx_.shape(0) / num_epochs << endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " sequence length: " << seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " total number of samples: " << num_samples <<
|
||||
endl << std::flush;
|
||||
|
||||
// Index into sample_idx.
|
||||
int64_t sample_index = 0;
|
||||
// Index into doc_idx.
|
||||
int64_t doc_idx_index = 0;
|
||||
// Beginning offset for each document.
|
||||
int32_t doc_offset = 0;
|
||||
// Start with first document and no offset.
|
||||
sample_idx[2 * sample_index] = doc_idx_index;
|
||||
sample_idx[2 * sample_index + 1] = doc_offset;
|
||||
++sample_index;
|
||||
|
||||
while (sample_index <= num_samples) {
|
||||
// Start with a fresh sequence.
|
||||
int32_t remaining_seq_length = seq_length + 1;
|
||||
while (remaining_seq_length != 0) {
|
||||
// Get the document length.
|
||||
auto doc_id = doc_idx[doc_idx_index];
|
||||
auto doc_length = sizes[doc_id] - doc_offset;
|
||||
// And add it to the current sequence.
|
||||
remaining_seq_length -= doc_length;
|
||||
// If we have more than a full sequence, adjust offset and set
|
||||
// remaining length to zero so we return from the while loop.
|
||||
// Note that -1 here is for the same reason we have -1 in
|
||||
// `_num_epochs` calculations.
|
||||
if (remaining_seq_length <= 0) {
|
||||
doc_offset += (remaining_seq_length + doc_length - 1);
|
||||
remaining_seq_length = 0;
|
||||
} else {
|
||||
// Otherwise, start from the beginning of the next document.
|
||||
++doc_idx_index;
|
||||
doc_offset = 0;
|
||||
}
|
||||
}
|
||||
// Record the sequence.
|
||||
sample_idx[2 * sample_index] = doc_idx_index;
|
||||
sample_idx[2 * sample_index + 1] = doc_offset;
|
||||
++sample_index;
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(sample_idx, [](void *mem_) {
|
||||
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(int32_t);
|
||||
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
|
||||
{2*byte_size, byte_size}, // C-style contiguous strides
|
||||
sample_idx, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
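// Worked example (illustrative, not from the original source): with one epoch,
// sizes = {5, 3, 8} (tokens_per_epoch = 16), doc_idx = {0, 1, 2} and
// seq_length = 4, num_samples = (1 * 16 - 1) / 4 = 3 and the returned array is
//   [[0, 0],
//    [0, 4],
//    [2, 0],
//    [2, 4]]
// Rows i and i+1 delimit sample i; each sample covers seq_length + 1 = 5
// consecutive tokens of the flattened corpus.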
|
||||
|
||||
|
||||
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
|
||||
const int32_t max_length,
|
||||
std::mt19937& rand32_gen) {
|
||||
/* Training sample length. */
|
||||
if (short_seq_ratio == 0) {
|
||||
return max_length;
|
||||
}
|
||||
const auto random_number = rand32_gen();
|
||||
if ((random_number % short_seq_ratio) == 0) {
|
||||
return 2 + random_number % (max_length - 1);
|
||||
}
|
||||
return max_length;
|
||||
}
|
||||
|
||||
|
||||
template<typename DocIdx>
|
||||
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int32_t>& sizes_,
|
||||
const int32_t num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int32_t max_seq_length,
|
||||
const double short_seq_prob,
|
||||
const int32_t seed,
|
||||
const bool verbose,
|
||||
const int32_t min_num_sent) {
|
||||
/* Build a mapping of (start-index, end-index, sequence-length) where
|
||||
start and end index are the indices of the sentences in the sample
|
||||
and sequence-length is the target sequence length.
|
||||
*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(num_epochs > 0);
|
||||
assert(max_seq_length > 1);
|
||||
assert(short_seq_prob >= 0.0);
|
||||
assert(short_seq_prob <= 1.0);
|
||||
assert(seed > 0);
|
||||
|
||||
// Remove bound checks.
|
||||
auto docs = docs_.unchecked<1>();
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
|
||||
// For efficiency, convert probability to ratio. Note: rand() generates int.
|
||||
int32_t short_seq_ratio = 0;
|
||||
if (short_seq_prob > 0) {
|
||||
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
const auto sent_start_index = docs[0];
|
||||
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
||||
const auto num_sentences = sent_end_index - sent_start_index;
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " << docs_.shape(0) - 1 <<
|
||||
endl << std::flush;
|
||||
cout << " sentences range: [" << sent_start_index <<
|
||||
", " << sent_end_index << ")" << endl << std::flush;
|
||||
cout << " total number of sentences: " << num_sentences <<
|
||||
endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " maximum number of samples: " << max_num_samples <<
|
||||
endl << std::flush;
|
||||
cout << " maximum sequence length: " << max_seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " short sequence probability: " << short_seq_prob <<
|
||||
endl << std::flush;
|
||||
cout << " short sequence ration (1/prob): " << short_seq_ratio <<
|
||||
endl << std::flush;
|
||||
cout << " seed: " << seed << endl <<
|
||||
std::flush;
|
||||
}
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = -1;
|
||||
DocIdx* maps = NULL;
|
||||
|
||||
// Perform two iterations, in the first iteration get the size
|
||||
// and allocate memory and in the second iteration populate the map.
|
||||
bool second = false;
|
||||
for (int32_t iteration=0; iteration<2; ++iteration) {
|
||||
|
||||
// Set the seed so both iterations produce the same results.
|
||||
std::mt19937 rand32_gen(seed);
|
||||
|
||||
// Set the flag on second iteration.
|
||||
second = (iteration == 1);
|
||||
|
||||
// Counters:
|
||||
uint64_t empty_docs = 0;
|
||||
uint64_t one_sent_docs = 0;
|
||||
uint64_t long_sent_docs = 0;
|
||||
|
||||
// Current map index.
|
||||
uint64_t map_index = 0;
|
||||
|
||||
// For each epoch:
|
||||
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
|
||||
if (map_index >= max_num_samples) {
|
||||
if (verbose && (!second)) {
|
||||
cout << " reached " << max_num_samples << " samples after "
|
||||
<< epoch << " epochs ..." << endl << std::flush;
|
||||
}
|
||||
break;
|
||||
}
|
||||
// For each document:
|
||||
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
|
||||
|
||||
// Document sentences are in [sent_index_first, sent_index_last)
|
||||
const auto sent_index_first = docs[doc];
|
||||
const auto sent_index_last = docs[doc + 1];
|
||||
|
||||
// At the beginning of the document previous index is the
|
||||
// start index.
|
||||
auto prev_start_index = sent_index_first;
|
||||
|
||||
// Remaining sentences in the document.
|
||||
auto num_remain_sent = sent_index_last - sent_index_first;
|
||||
|
||||
// Some bookkeeping
|
||||
if ((epoch == 0) && (!second)) {
|
||||
if (num_remain_sent == 0) {
|
||||
++empty_docs;
|
||||
}
|
||||
if (num_remain_sent == 1) {
|
||||
++one_sent_docs;
|
||||
}
|
||||
}
|
||||
|
||||
// Detect documents with long sentences.
|
||||
bool contains_long_sentence = false;
|
||||
if (num_remain_sent > 1) {
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
if (sizes[sent_index] > LONG_SENTENCE_LEN){
|
||||
if ((epoch == 0) && (!second)) {
|
||||
++long_sent_docs;
|
||||
}
|
||||
contains_long_sentence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have enough sentences and no long sentences.
|
||||
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
|
||||
// Set values.
|
||||
auto seq_len = int32_t{0};
|
||||
auto num_sent = int32_t{0};
|
||||
auto target_seq_len = get_target_sample_len(short_seq_ratio,
|
||||
max_seq_length,
|
||||
rand32_gen);
|
||||
|
||||
// Loop through sentences.
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
|
||||
// Add the size and number of sentences.
|
||||
seq_len += sizes[sent_index];
|
||||
++num_sent;
|
||||
--num_remain_sent;
|
||||
|
||||
// If we have reached the target length,
|
||||
// and more than one sentence is left in the document,
|
||||
// and we have at least the minimum number of sentences,
|
||||
// or if we have reached the end of the document.
|
||||
if (((seq_len >= target_seq_len) &&
|
||||
(num_remain_sent > 1) &&
|
||||
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
|
||||
|
||||
// Check for overflow.
|
||||
if ((3 * map_index + 2) >
|
||||
std::numeric_limits<int64_t>::max()) {
|
||||
cout << "number of samples exceeded maximum "
|
||||
<< "allowed by type int64: "
|
||||
<< std::numeric_limits<int64_t>::max()
|
||||
<< endl;
|
||||
throw std::overflow_error("Number of samples");
|
||||
}
|
||||
|
||||
// Populate the map.
|
||||
if (second) {
|
||||
const auto map_index_0 = 3 * map_index;
|
||||
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
||||
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
||||
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
|
||||
}
|
||||
|
||||
// Update indices / counters.
|
||||
++map_index;
|
||||
prev_start_index = sent_index + 1;
|
||||
target_seq_len = get_target_sample_len(short_seq_ratio,
|
||||
max_seq_length,
|
||||
rand32_gen);
|
||||
seq_len = 0;
|
||||
num_sent = 0;
|
||||
}
|
||||
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent >= min_num_sent) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
cout << " number of empty documents: " << empty_docs <<
|
||||
endl << std::flush;
|
||||
cout << " number of documents with one sentence: " <<
|
||||
one_sent_docs << endl << std::flush;
|
||||
cout << " number of documents with long sentences: " <<
|
||||
long_sent_docs << endl << std::flush;
|
||||
cout << " will create mapping for " << map_index <<
|
||||
" samples" << endl << std::flush;
|
||||
}
|
||||
assert(maps == NULL);
|
||||
assert(num_samples < 0);
|
||||
maps = new DocIdx[3*map_index];
|
||||
num_samples = static_cast<int64_t>(map_index);
|
||||
}
|
||||
|
||||
} // for (int iteration=0; iteration < 2; ++iteration) {
|
||||
|
||||
// Shuffle.
|
||||
// We need a 64 bit random number generator as we might have more
|
||||
// than 2 billion samples.
|
||||
std::mt19937_64 rand64_gen(seed + 1);
|
||||
for (auto i=(num_samples - 1); i > 0; --i) {
|
||||
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
||||
const auto i0 = 3 * i;
|
||||
const auto j0 = 3 * j;
|
||||
// Swap values.
|
||||
swap(maps[i0], maps[j0]);
|
||||
swap(maps[i0 + 1], maps[j0 + 1]);
|
||||
swap(maps[i0 + 2], maps[j0 + 2]);
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(maps, [](void *mem_) {
|
||||
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(DocIdx);
|
||||
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
|
||||
{3*byte_size, byte_size}, // C-style contiguous strides
|
||||
maps, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
|
||||
py::array build_mapping(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int>& sizes_,
|
||||
const int num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int max_seq_length,
|
||||
const double short_seq_prob,
|
||||
const int seed,
|
||||
const bool verbose,
|
||||
const int32_t min_num_sent) {
|
||||
|
||||
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
|
||||
if (verbose) {
|
||||
cout << " using uint64 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
|
||||
max_num_samples, max_seq_length,
|
||||
short_seq_prob, seed, verbose,
|
||||
min_num_sent);
|
||||
} else {
|
||||
if (verbose) {
|
||||
cout << " using uint32 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
|
||||
max_num_samples, max_seq_length,
|
||||
short_seq_prob, seed, verbose,
|
||||
min_num_sent);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename DocIdx>
|
||||
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int32_t>& sizes_,
|
||||
const py::array_t<int32_t>& titles_sizes_,
|
||||
const int32_t num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int32_t max_seq_length,
|
||||
const int32_t seed,
|
||||
const bool verbose,
|
||||
const bool use_one_sent_blocks) {
|
||||
/* Build a mapping of (start-index, end-index, sequence-length) where
|
||||
start and end index are the indices of the sentences in the sample
|
||||
and sequence-length is the target sequence length.
|
||||
*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(num_epochs > 0);
|
||||
assert(max_seq_length > 1);
|
||||
assert(seed > 0);
|
||||
|
||||
// Remove bound checks.
|
||||
auto docs = docs_.unchecked<1>();
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
auto titles_sizes = titles_sizes_.unchecked<1>();
|
||||
|
||||
if (verbose) {
|
||||
const auto sent_start_index = docs[0];
|
||||
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
||||
const auto num_sentences = sent_end_index - sent_start_index;
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " << docs_.shape(0) - 1 <<
|
||||
endl << std::flush;
|
||||
cout << " sentences range: [" << sent_start_index <<
|
||||
", " << sent_end_index << ")" << endl << std::flush;
|
||||
cout << " total number of sentences: " << num_sentences <<
|
||||
endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " maximum number of samples: " << max_num_samples <<
|
||||
endl << std::flush;
|
||||
cout << " maximum sequence length: " << max_seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " seed: " << seed << endl <<
|
||||
std::flush;
|
||||
}
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = -1;
|
||||
DocIdx* maps = NULL;
|
||||
|
||||
// Acceptable number of sentences per block.
|
||||
int min_num_sent = 2;
|
||||
if (use_one_sent_blocks) {
|
||||
min_num_sent = 1;
|
||||
}
|
||||
|
||||
// Perform two iterations, in the first iteration get the size
|
||||
// and allocate memory and in the second iteration populate the map.
|
||||
bool second = false;
|
||||
for (int32_t iteration=0; iteration<2; ++iteration) {
|
||||
|
||||
// Set the flag on second iteration.
|
||||
second = (iteration == 1);
|
||||
|
||||
// Current map index.
|
||||
uint64_t map_index = 0;
|
||||
|
||||
uint64_t empty_docs = 0;
|
||||
uint64_t one_sent_docs = 0;
|
||||
uint64_t long_sent_docs = 0;
|
||||
// For each epoch:
|
||||
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
|
||||
// assign every block a unique id
|
||||
int32_t block_id = 0;
|
||||
|
||||
if (map_index >= max_num_samples) {
|
||||
if (verbose && (!second)) {
|
||||
cout << " reached " << max_num_samples << " samples after "
|
||||
<< epoch << " epochs ..." << endl << std::flush;
|
||||
}
|
||||
break;
|
||||
}
|
||||
// For each document:
|
||||
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
|
||||
|
||||
// Document sentences are in [sent_index_first, sent_index_last)
|
||||
const auto sent_index_first = docs[doc];
|
||||
const auto sent_index_last = docs[doc + 1];
|
||||
const auto target_seq_len = max_seq_length - titles_sizes[doc];
|
||||
|
||||
// At the beginning of the document previous index is the
|
||||
// start index.
|
||||
auto prev_start_index = sent_index_first;
|
||||
|
||||
// Remaining sentences in the document.
|
||||
auto num_remain_sent = sent_index_last - sent_index_first;
|
||||
|
||||
// Some bookkeeping
|
||||
if ((epoch == 0) && (!second)) {
|
||||
if (num_remain_sent == 0) {
|
||||
++empty_docs;
|
||||
}
|
||||
if (num_remain_sent == 1) {
|
||||
++one_sent_docs;
|
||||
}
|
||||
}
|
||||
// Detect documents with long sentences.
|
||||
bool contains_long_sentence = false;
|
||||
if (num_remain_sent >= min_num_sent) {
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
if (sizes[sent_index] > LONG_SENTENCE_LEN){
|
||||
if ((epoch == 0) && (!second)) {
|
||||
++long_sent_docs;
|
||||
}
|
||||
contains_long_sentence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we have enough sentences and no long sentences.
|
||||
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
|
||||
// Set values.
|
||||
auto seq_len = int32_t{0};
|
||||
auto num_sent = int32_t{0};
|
||||
|
||||
// Loop through sentences.
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
|
||||
// Add the size and number of sentences.
|
||||
seq_len += sizes[sent_index];
|
||||
++num_sent;
|
||||
--num_remain_sent;
|
||||
|
||||
// If we have reached the target length.
|
||||
// and there are an acceptable number of sentences left
|
||||
// and if we have at least the minimum number of sentences.
|
||||
// or if we have reached end of the document.
|
||||
if (((seq_len >= target_seq_len) &&
|
||||
(num_remain_sent >= min_num_sent) &&
|
||||
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
|
||||
|
||||
// Populate the map.
|
||||
if (second) {
|
||||
const auto map_index_0 = 4 * map_index;
|
||||
// Each sample has 4 items: the starting sentence index, ending sentence index,
|
||||
// the index of the document from which the block comes (used for fetching titles)
|
||||
// and the unique id of the block (used for creating block indexes)
|
||||
|
||||
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
||||
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
||||
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
|
||||
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
|
||||
}
|
||||
|
||||
// Update indices / counters.
|
||||
++map_index;
|
||||
++block_id;
|
||||
prev_start_index = sent_index + 1;
|
||||
seq_len = 0;
|
||||
num_sent = 0;
|
||||
}
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent >= min_num_sent) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
cout << " number of empty documents: " << empty_docs <<
|
||||
endl << std::flush;
|
||||
cout << " number of documents with one sentence: " <<
|
||||
one_sent_docs << endl << std::flush;
|
||||
cout << " number of documents with long sentences: " <<
|
||||
long_sent_docs << endl << std::flush;
|
||||
cout << " will create mapping for " << map_index <<
|
||||
" samples" << endl << std::flush;
|
||||
}
|
||||
assert(maps == NULL);
|
||||
assert(num_samples < 0);
|
||||
maps = new DocIdx[4*map_index];
|
||||
num_samples = static_cast<int64_t>(map_index);
|
||||
}
|
||||
|
||||
} // for (int iteration=0; iteration < 2; ++iteration) {
|
||||
|
||||
// Shuffle.
|
||||
// We need a 64 bit random number generator as we might have more
|
||||
// than 2 billion samples.
|
||||
std::mt19937_64 rand64_gen(seed + 1);
|
||||
for (auto i=(num_samples - 1); i > 0; --i) {
|
||||
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
||||
const auto i0 = 4 * i;
|
||||
const auto j0 = 4 * j;
|
||||
// Swap values.
|
||||
swap(maps[i0], maps[j0]);
|
||||
swap(maps[i0 + 1], maps[j0 + 1]);
|
||||
swap(maps[i0 + 2], maps[j0 + 2]);
|
||||
swap(maps[i0 + 3], maps[j0 + 3]);
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(maps, [](void *mem_) {
|
||||
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(DocIdx);
|
||||
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
|
||||
{4*byte_size, byte_size}, // C-style contiguous strides
|
||||
maps, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int>& sizes_,
|
||||
const py::array_t<int>& titles_sizes_,
|
||||
const int num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int max_seq_length,
|
||||
const int seed,
|
||||
const bool verbose,
|
||||
const bool use_one_sent_blocks) {
|
||||
|
||||
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
|
||||
if (verbose) {
|
||||
cout << " using uint64 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
|
||||
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
||||
} else {
|
||||
if (verbose) {
|
||||
cout << " using uint32 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
|
||||
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(helpers, m) {
|
||||
m.def("build_mapping", &build_mapping);
|
||||
m.def("build_blocks_mapping", &build_blocks_mapping);
|
||||
m.def("build_sample_idx", &build_sample_idx);
|
||||
m.def("build_blending_indices", &build_blending_indices);
|
||||
}
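// Hedged usage sketch (not part of the original source). Once this file is
// compiled into the `helpers` extension defined above, the Python side can
// call it roughly as follows; the import path is an assumption based on the
// surrounding package layout, not something stated in this file:
//
//   import numpy as np
//   from codegeex.megatron.data import helpers   # assumed module path
//
//   dataset_index = np.zeros(size, dtype=np.uint8)
//   dataset_sample_index = np.zeros(size, dtype=np.int64)
//   helpers.build_blending_indices(dataset_index, dataset_sample_index,
//                                  weights,        # float64 array
//                                  len(weights), size, verbose)
//
//   sample_idx = helpers.build_sample_idx(sizes, doc_idx,   # int32 arrays
//                                         seq_length, num_epochs,
//                                         tokens_per_epoch)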
|
@ -0,0 +1,595 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# copied from fairseq/fairseq/data/indexed_dataset.py
|
||||
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
|
||||
# other slight modifications to remove fairseq dependencies
|
||||
# Added document index to index file and made it accessible.
|
||||
# An empty sentence no longer separates documents.
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import struct
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from functools import lru_cache
|
||||
from itertools import accumulate
|
||||
from codegeex.megatron import print_rank_0
|
||||
|
||||
|
||||
def __best_fitting_dtype(vocab_size=None):
|
||||
if vocab_size is not None and vocab_size < 65500:
|
||||
return np.uint16
|
||||
else:
|
||||
return np.int32
|
||||
|
||||
|
||||
def get_available_dataset_impl():
|
||||
return ["lazy", "cached", "mmap"]
|
||||
|
||||
|
||||
def infer_dataset_impl(path):
|
||||
if IndexedDataset.exists(path):
|
||||
with open(index_file_path(path), "rb") as f:
|
||||
magic = f.read(8)
|
||||
if magic == IndexedDataset._HDR_MAGIC:
|
||||
return "cached"
|
||||
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
|
||||
return "mmap"
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print(
|
||||
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def make_builder(out_file, impl, vocab_size=None):
|
||||
if impl == "mmap":
|
||||
return MMapIndexedDatasetBuilder(
|
||||
out_file, dtype=__best_fitting_dtype(vocab_size)
|
||||
)
|
||||
else:
|
||||
return IndexedDatasetBuilder(out_file)
|
||||
|
||||
|
||||
def make_dataset(path, impl, skip_warmup=False):
|
||||
if not IndexedDataset.exists(path):
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print(
|
||||
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
|
||||
)
|
||||
return None
|
||||
if impl == "infer":
|
||||
impl = infer_dataset_impl(path)
|
||||
if impl == "lazy" and IndexedDataset.exists(path):
|
||||
return IndexedDataset(path)
|
||||
elif impl == "cached" and IndexedDataset.exists(path):
|
||||
return IndexedCachedDataset(path)
|
||||
elif impl == "mmap" and MMapIndexedDataset.exists(path):
|
||||
return MMapIndexedDataset(path, skip_warmup)
|
||||
print(f"Unknown dataset implementation: {impl}")
|
||||
return None
|
||||
|
||||
|
||||
def dataset_exists(path, impl):
|
||||
if impl == "mmap":
|
||||
return MMapIndexedDataset.exists(path)
|
||||
else:
|
||||
return IndexedDataset.exists(path)
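# Hedged usage sketch (not part of the original file). The paths and token ids
# below are made up; the calls only use helpers defined in this module:
#
#   import torch
#   prefix = "/tmp/my_corpus"                                 # hypothetical basename
#   builder = make_builder(data_file_path(prefix), impl="mmap", vocab_size=52000)
#   for doc in [[10, 11, 12], [13, 14]]:                      # token ids per document
#       builder.add_item(torch.IntTensor(doc))
#       builder.end_document()
#   builder.finalize(index_file_path(prefix))
#
#   ds = make_dataset(prefix, impl="mmap", skip_warmup=True)
#   assert len(ds) == 2 and list(ds[0]) == [10, 11, 12]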
|
||||
|
||||
|
||||
def read_longs(f, n):
|
||||
a = np.empty(n, dtype=np.int64)
|
||||
f.readinto(a)
|
||||
return a
|
||||
|
||||
|
||||
def write_longs(f, a):
|
||||
f.write(np.array(a, dtype=np.int64))
|
||||
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
3: np.int16,
|
||||
4: np.int32,
|
||||
5: np.int64,
|
||||
6: np.float32,  # was np.float, an alias removed from newer NumPy; float32 matches the 4-byte element size
|
||||
7: np.double,
|
||||
8: np.uint16,
|
||||
}
|
||||
|
||||
|
||||
def __best_fitting_dtype(vocab_size=None):
|
||||
if vocab_size is not None and vocab_size < 65500:
|
||||
return np.uint16
|
||||
else:
|
||||
return np.int32
|
||||
|
||||
|
||||
def make_mmap_builder(out_file, vocab_size=None):
|
||||
return MMapIndexedDatasetBuilder(
|
||||
out_file, dtype=__best_fitting_dtype(vocab_size)
|
||||
)
|
||||
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + ".idx"
|
||||
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + ".bin"
|
||||
|
||||
|
||||
def create_doc_idx(sizes):
|
||||
doc_idx = [0]
|
||||
for i, s in enumerate(sizes):
|
||||
if s == 0:
|
||||
doc_idx.append(i + 1)
|
||||
return doc_idx
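# Example (illustrative): create_doc_idx([3, 4, 0, 2, 0]) returns [0, 3, 5].
# A zero-length sentence marks the end of a document, and document d spans
# sentences [doc_idx[d], doc_idx[d + 1]).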
|
||||
|
||||
|
||||
class IndexedDataset(torch.utils.data.Dataset):
|
||||
"""Loader for IndexedDataset"""
|
||||
|
||||
_HDR_MAGIC = b"TNTIDX\x00\x00"
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.data_file = None
|
||||
self.read_index(path)
|
||||
|
||||
def read_index(self, path):
|
||||
with open(index_file_path(path), "rb") as f:
|
||||
magic = f.read(8)
|
||||
assert magic == self._HDR_MAGIC, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
version = f.read(8)
|
||||
assert struct.unpack("<Q", version) == (1,)
|
||||
code, self.element_size = struct.unpack("<QQ", f.read(16))
|
||||
self.dtype = dtypes[code]
|
||||
self._len, self.s = struct.unpack("<QQ", f.read(16))
|
||||
self.doc_count = struct.unpack("<Q", f.read(8))
|
||||
self.dim_offsets = read_longs(f, self._len + 1)
|
||||
self.data_offsets = read_longs(f, self._len + 1)
|
||||
self.sizes = read_longs(f, self.s)
|
||||
self.doc_idx = read_longs(f, self.doc_count)
|
||||
|
||||
def read_data(self, path):
|
||||
self.data_file = open(data_file_path(path), "rb", buffering=0)
|
||||
|
||||
def check_index(self, i):
|
||||
if i < 0 or i >= self._len:
|
||||
raise IndexError("index out of range")
|
||||
|
||||
def __del__(self):
|
||||
if self.data_file:
|
||||
self.data_file.close()
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if not self.data_file:
|
||||
self.read_data(self.path)
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
|
||||
size = sum(sizes)
|
||||
a = np.empty(size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[start] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
offsets = list(accumulate(sizes))
|
||||
sents = np.split(a, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.sizes[index]
|
||||
|
||||
def size(self, index):
|
||||
return self.sizes[index]
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False # avoid prefetching to save memory
|
||||
|
||||
|
||||
class IndexedCachedDataset(IndexedDataset):
|
||||
def __init__(self, path):
|
||||
super().__init__(path)
|
||||
self.cache = None
|
||||
self.cache_index = {}
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return True
|
||||
|
||||
def prefetch(self, indices):
|
||||
if all(i in self.cache_index for i in indices):
|
||||
return
|
||||
if not self.data_file:
|
||||
self.read_data(self.path)
|
||||
indices = sorted(set(indices))
|
||||
total_size = 0
|
||||
for i in indices:
|
||||
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
self.cache = np.empty(total_size, dtype=self.dtype)
|
||||
ptx = 0
|
||||
self.cache_index.clear()
|
||||
for i in indices:
|
||||
self.cache_index[i] = ptx
|
||||
size = self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
a = self.cache[ptx : ptx + size]
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
ptx += size
|
||||
if self.data_file:
|
||||
# close and delete data file after prefetch so we can pickle
|
||||
self.data_file.close()
|
||||
self.data_file = None
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
ptx = self.cache_index[i]
|
||||
np.copyto(a, self.cache[ptx : ptx + a.size])
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
# Hack just to make this work; can optimize later if necessary
|
||||
sents = []
|
||||
for i in range(*idx.indices(len(self))):
|
||||
sents.append(self[i])
|
||||
return sents
|
||||
|
||||
|
||||
class IndexedDatasetBuilder(object):
|
||||
element_sizes = {
|
||||
np.uint8: 1,
|
||||
np.int8: 1,
|
||||
np.int16: 2,
|
||||
np.int32: 4,
|
||||
np.int64: 8,
|
||||
np.float32: 4,
|
||||
np.double: 8,
|
||||
}
|
||||
|
||||
def __init__(self, out_file, dtype=np.int32):
|
||||
self.out_file = open(out_file, "wb")
|
||||
self.dtype = dtype
|
||||
self.data_offsets = [0]
|
||||
self.dim_offsets = [0]
|
||||
self.sizes = []
|
||||
self.element_size = self.element_sizes[self.dtype]
|
||||
self.doc_idx = [0]
|
||||
|
||||
def add_item(self, tensor):
|
||||
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
|
||||
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
|
||||
for s in tensor.size():
|
||||
self.sizes.append(s)
|
||||
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
|
||||
|
||||
def end_document(self):
|
||||
self.doc_idx.append(len(self.sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
index = IndexedDataset(another_file)
|
||||
assert index.dtype == self.dtype
|
||||
|
||||
begin = self.data_offsets[-1]
|
||||
for offset in index.data_offsets[1:]:
|
||||
self.data_offsets.append(begin + offset)
|
||||
self.sizes.extend(index.sizes)
|
||||
begin = self.dim_offsets[-1]
|
||||
for dim_offset in index.dim_offsets[1:]:
|
||||
self.dim_offsets.append(begin + dim_offset)
|
||||
|
||||
with open(data_file_path(another_file), "rb") as f:
|
||||
while True:
|
||||
data = f.read(1024)
|
||||
if data:
|
||||
self.out_file.write(data)
|
||||
else:
|
||||
break
|
||||
|
||||
def finalize(self, index_file):
|
||||
self.out_file.close()
|
||||
index = open(index_file, "wb")
|
||||
index.write(b"TNTIDX\x00\x00")
|
||||
index.write(struct.pack("<Q", 1))
|
||||
index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
|
||||
index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
|
||||
index.write(struct.pack("<Q", len(self.doc_idx)))
|
||||
write_longs(index, self.dim_offsets)
|
||||
write_longs(index, self.data_offsets)
|
||||
write_longs(index, self.sizes)
|
||||
write_longs(index, self.doc_idx)
|
||||
index.close()
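# Resulting .idx layout, as written above (all fields little-endian):
#   bytes  0-7    magic    b"TNTIDX\x00\x00"
#   bytes  8-15   <Q       version (always 1)
#   bytes 16-31   <QQ      dtype code, element size in bytes
#   bytes 32-47   <QQ      number of items, total number of sizes
#   bytes 48-55   <Q       number of doc_idx entries
#   then four int64 arrays: dim_offsets, data_offsets, sizes, doc_idx,
#   read back in exactly this order by IndexedDataset.read_index().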
|
||||
|
||||
|
||||
def _warmup_mmap_file(path):
|
||||
with open(path, "rb") as stream:
|
||||
while stream.read(100 * 1024 * 1024):
|
||||
pass
|
||||
|
||||
|
||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path, dtype):
|
||||
class _Writer(object):
|
||||
def __enter__(self):
|
||||
self._file = open(path, "wb")
|
||||
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
self._file.write(struct.pack("<Q", 1))
|
||||
self._file.write(struct.pack("<B", code(dtype)))
|
||||
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _get_pointers(sizes):
|
||||
dtype_size = dtype().itemsize
|
||||
address = 0
|
||||
pointers = []
|
||||
|
||||
for size in sizes:
|
||||
pointers.append(address)
|
||||
address += size * dtype_size
|
||||
|
||||
return pointers
|
||||
|
||||
def write(self, sizes, doc_idx):
|
||||
pointers = self._get_pointers(sizes)
|
||||
|
||||
self._file.write(struct.pack("<Q", len(sizes)))
|
||||
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order="C"))
|
||||
del sizes
|
||||
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order="C"))
|
||||
del pointers
|
||||
|
||||
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||
self._file.write(doc_idx.tobytes(order="C"))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
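# Example (illustrative) of how this writer is used by
# MMapIndexedDatasetBuilder.finalize() further below:
#
#   with MMapIndexedDataset.Index.writer("/tmp/corpus.idx", np.uint16) as idx:
#       idx.write(sizes=[3, 2], doc_idx=[0, 2])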
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, "rb") as stream:
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
version = struct.unpack("<Q", stream.read(8))
|
||||
assert (1,) == version
|
||||
|
||||
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||
self._dtype = dtypes[dtype_code]
|
||||
self._dtype_size = self._dtype().itemsize
|
||||
|
||||
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||
offset = stream.tell()
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||
)
|
||||
self._pointers = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._len,
|
||||
offset=offset + self._sizes.nbytes,
|
||||
)
|
||||
self._doc_idx = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._doc_count,
|
||||
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._doc_idx
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
super().__init__()
|
||||
|
||||
self._path = None
|
||||
self._index = None
|
||||
self._bin_buffer = None
|
||||
|
||||
self._do_init(path, skip_warmup)
|
||||
|
||||
def __getstate__(self):
|
||||
return self._path
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._do_init(state)
|
||||
|
||||
def _do_init(self, path, skip_warmup):
|
||||
self._path = path
|
||||
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||
|
||||
if not skip_warmup:
|
||||
print_rank_0(" warming up data mmap file...")
|
||||
_warmup_mmap_file(data_file_path(self._path))
|
||||
print_rank_0(" creating numpy buffer of mmap...")
|
||||
self._bin_buffer_mmap = np.memmap(
|
||||
data_file_path(self._path), mode="r", order="C"
|
||||
)
|
||||
print_rank_0(" creating memory view of numpy buffer...")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
del self._index
|
||||
|
||||
def __len__(self):
|
||||
return len(self._index)
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
ptr, size = self._index[idx]
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
ptr = self._index._pointers[start]
|
||||
sizes = self._index._sizes[idx]
|
||||
offsets = list(accumulate(sizes))
|
||||
total_size = sum(sizes)
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
||||
)
|
||||
sents = np.split(np_array, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def get(self, idx, offset=0, length=None):
|
||||
"""Retrieves a single item from the dataset with the option to only
|
||||
return a portion of the item.
|
||||
|
||||
get(idx) is the same as [idx] but get() does not support slicing.
|
||||
"""
|
||||
ptr, size = self._index[idx]
|
||||
if length is None:
|
||||
length = size - offset
|
||||
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||
)
|
||||
return np_array
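# Example (illustrative): if document `idx` holds token ids [5, 6, 7, 8, 9],
# then get(idx, offset=1, length=3) returns array([6, 7, 8]) while reading
# only that slice of the memory-mapped .bin file.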
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._index.sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._index.doc_idx
|
||||
|
||||
def get_doc_idx(self):
|
||||
return self._index._doc_idx
|
||||
|
||||
def set_doc_idx(self, doc_idx_):
|
||||
self._index._doc_idx = doc_idx_
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
||||
|
||||
|
||||
class MMapIndexedDatasetBuilder(object):
|
||||
def __init__(self, out_file, dtype=np.int64):
|
||||
self._data_file = open(out_file, "wb")
|
||||
self._dtype = dtype
|
||||
self._sizes = []
|
||||
self._doc_idx = [0]
|
||||
|
||||
def add_item(self, tensor):
|
||||
np_array = np.array(tensor.numpy(), dtype=self._dtype)
|
||||
self._data_file.write(np_array.tobytes(order="C"))
|
||||
self._sizes.append(np_array.size)
|
||||
|
||||
def end_document(self):
|
||||
self._doc_idx.append(len(self._sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
# Concatenate index
|
||||
index = MMapIndexedDataset.Index(index_file_path(another_file))
|
||||
assert index.dtype == self._dtype
|
||||
|
||||
for size in index.sizes:
|
||||
self._sizes.append(size)
|
||||
|
||||
# Concatenate data
|
||||
with open(data_file_path(another_file), "rb") as f:
|
||||
shutil.copyfileobj(f, self._data_file)
|
||||
|
||||
def finalize(self, index_file):
|
||||
self._data_file.close()
|
||||
|
||||
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
|
||||
index.write(self._sizes, self._doc_idx)
|
@ -0,0 +1,332 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""GPT prompting dataset."""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from codegeex.megatron import mpu, print_rank_0, get_tokenizer
|
||||
from codegeex.megatron.data.blendable_dataset import BlendableDataset
|
||||
from codegeex.megatron.data.dataset_utils import get_datasets_weights_and_num_samples
|
||||
from codegeex.megatron.data.dataset_utils import get_train_valid_test_split_
|
||||
from codegeex.megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
|
||||
|
||||
|
||||
def build_train_valid_test_datasets(
|
||||
data_prefix,
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
seq_length,
|
||||
seed,
|
||||
skip_warmup,
|
||||
):
|
||||
"""Build train, valid, and test datasets."""
|
||||
|
||||
# Single dataset.
|
||||
if len(data_prefix) == 1:
|
||||
return _build_train_valid_test_datasets(
|
||||
data_prefix[0],
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
seq_length,
|
||||
seed,
|
||||
skip_warmup,
|
||||
)
|
||||
|
||||
# Blending dataset.
|
||||
# Parse the values.
|
||||
output = get_datasets_weights_and_num_samples(
|
||||
data_prefix, train_valid_test_num_samples
|
||||
)
|
||||
prefixes, weights, datasets_train_valid_test_num_samples = output
|
||||
|
||||
# Build individual datasets.
|
||||
train_datasets = []
|
||||
valid_datasets = []
|
||||
test_datasets = []
|
||||
for i in range(len(prefixes)):
|
||||
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
|
||||
prefixes[i],
|
||||
data_impl,
|
||||
splits_string,
|
||||
datasets_train_valid_test_num_samples[i],
|
||||
seq_length,
|
||||
seed,
|
||||
skip_warmup,
|
||||
)
|
||||
if train_ds:
|
||||
train_datasets.append(train_ds)
|
||||
if valid_ds:
|
||||
valid_datasets.append(valid_ds)
|
||||
if test_ds:
|
||||
test_datasets.append(test_ds)
|
||||
|
||||
# Blend.
|
||||
blending_train_dataset = None
|
||||
if train_datasets:
|
||||
blending_train_dataset = BlendableDataset(train_datasets, weights)
|
||||
blending_valid_dataset = None
|
||||
if valid_datasets:
|
||||
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
|
||||
blending_test_dataset = None
|
||||
if test_datasets:
|
||||
blending_test_dataset = BlendableDataset(test_datasets, weights)
|
||||
|
||||
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
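# Hedged call sketch (not from the original file). Following the usual Megatron
# convention for `data_prefix` (weight followed by path prefix, repeated) and
# `splits_string`, a blended 70/30 mixture would be requested roughly as:
#
#   train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#       data_prefix=["0.7", "/data/corpusA", "0.3", "/data/corpusB"],
#       data_impl="mmap",
#       splits_string="949,50,1",
#       train_valid_test_num_samples=[1000000, 10000, 1000],
#       seq_length=2048,
#       seed=1234,
#       skip_warmup=True,
#   )
#
# The exact data_prefix / splits_string formats are assumptions based on the
# upstream Megatron helpers imported above, not something defined in this file.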
|
||||
|
||||
|
||||
def _build_train_valid_test_datasets(
|
||||
data_prefix,
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
seq_length,
|
||||
seed,
|
||||
skip_warmup,
|
||||
):
|
||||
"""Build train, valid, and test datasets."""
|
||||
|
||||
# Indexed dataset.
|
||||
assert os.path.exists(data_prefix + "_input_ids.bin"), f"Input tokens datafile not found: {data_prefix}_input_ids.bin"
|
||||
assert os.path.exists(data_prefix + "_attention_mask.bin"), f"Attention mask datafile not found: {data_prefix}_attention_mask.bin"
|
||||
assert os.path.exists(data_prefix + "_labels.bin"), f"Labels datafile not found: {data_prefix}_labels.bin"
|
||||
|
||||
input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup)
|
||||
attention_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_attention_mask", data_impl, skip_warmup)
|
||||
labels_indexed_dataset = get_indexed_dataset_(data_prefix + "_labels", data_impl, skip_warmup)
|
||||
|
||||
total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0]
|
||||
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
|
||||
|
||||
# Print stats about the splits.
|
||||
print_rank_0(" > dataset split:")
|
||||
|
||||
def print_split_stats(name, index):
|
||||
print_rank_0(" {}:".format(name))
|
||||
print_rank_0(
|
||||
" document indices in [{}, {}) total of {} "
|
||||
"documents".format(
|
||||
splits[index], splits[index + 1], splits[index + 1] - splits[index]
|
||||
)
|
||||
)
|
||||
|
||||
print_split_stats("train", 0)
|
||||
print_split_stats("validation", 1)
|
||||
print_split_stats("test", 2)
|
||||
|
||||
def build_dataset(index, name):
|
||||
dataset = None
|
||||
if splits[index + 1] > splits[index]:
|
||||
documents = np.arange(
|
||||
start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
|
||||
)
|
||||
dataset = PromptDataset(
|
||||
name,
|
||||
data_prefix,
|
||||
documents,
|
||||
input_ids_indexed_dataset,
|
||||
attention_mask_indexed_dataset,
|
||||
labels_indexed_dataset,
|
||||
train_valid_test_num_samples[index],
|
||||
seq_length,
|
||||
seed,
|
||||
)
|
||||
return dataset
|
||||
|
||||
train_dataset = build_dataset(0, "train")
|
||||
valid_dataset = build_dataset(1, "valid")
|
||||
test_dataset = build_dataset(2, "test")
|
||||
print_rank_0(f"train_dataset:{type(train_dataset)}")
|
||||
print_rank_0(f"valid_dataset:{type(valid_dataset)}")
|
||||
print_rank_0(f"test_dataset:{type(test_dataset)}")
|
||||
|
||||
return (train_dataset, valid_dataset, test_dataset)
|
||||
|
||||
|
||||
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
|
||||
"""Build indexed dataset."""
|
||||
print_rank_0(" > building dataset index ...")
|
||||
|
||||
start_time = time.time()
|
||||
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
|
||||
print_rank_0(
|
||||
" > finished creating indexed dataset in {:4f} "
|
||||
"seconds".format(time.time() - start_time)
|
||||
)
|
||||
print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0]))
|
||||
|
||||
return indexed_dataset
|
||||
|
||||
|
||||
class PromptDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
data_prefix,
|
||||
documents,
|
||||
input_ids_indexed_dataset,
|
||||
attention_mask_index_dataset,
|
||||
labels_indexed_dataset,
|
||||
num_samples,
|
||||
seq_length,
|
||||
seed,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
name: name of the dataset.
|
||||
data_prefix: prefix of the data.
|
||||
documents: list of document indices.
|
||||
input_ids_indexed_dataset: indexed dataset for prompts.
|
||||
attention_mask_index_dataset: indexed dataset for text.
|
||||
labels_indexed_dataset: indexed dataset for labels.
|
||||
num_samples: number of samples to draw from the indexed dataset.
|
||||
seq_length: sequence length.
|
||||
seed: seed for random number generator.
|
||||
"""
|
||||
|
||||
self.name = name
|
||||
self.input_ids_indexed_dataset = input_ids_indexed_dataset
|
||||
self.attention_mask_index_dataset = attention_mask_index_dataset
|
||||
self.labels_indexed_dataset = labels_indexed_dataset
|
||||
self.seq_length = seq_length
|
||||
self.eod_token = get_tokenizer().eod
|
||||
|
||||
# Checks
|
||||
assert np.min(documents) >= 0
|
||||
assert np.max(documents) < input_ids_indexed_dataset.sizes.shape[0]
|
||||
assert input_ids_indexed_dataset.sizes.shape[0] == attention_mask_index_dataset.sizes.shape[0]
|
||||
assert attention_mask_index_dataset.sizes.shape[0] == labels_indexed_dataset.sizes.shape[0]
|
||||
|
||||
# Build index mappings.
|
||||
self.doc_idx = _build_index_mappings(
|
||||
self.name,
|
||||
data_prefix,
|
||||
documents,
|
||||
self.input_ids_indexed_dataset.sizes,
|
||||
num_samples,
|
||||
seq_length,
|
||||
seed,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
# Unlike the GPT dataset there is no -1 here: doc_idx is used to
|
||||
# retrieve the document index directly, one entry per sample.
|
||||
return self.doc_idx.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# get the doc index
|
||||
doc_idx = self.doc_idx[idx]
|
||||
doc_idx = int(doc_idx) # NumPy int => Python int
|
||||
|
||||
input_ids = self.input_ids_indexed_dataset[doc_idx]
|
||||
# print_rank_0(f"input_ids={input_ids}")
|
||||
attention_mask = self.attention_mask_index_dataset[doc_idx]
|
||||
labels = self.labels_indexed_dataset[doc_idx]
|
||||
|
||||
res = {
|
||||
"input_ids": np.array(input_ids, dtype=np.int64),
|
||||
"attention_mask": np.array(attention_mask, dtype=np.int64),
|
||||
"labels": np.array(labels, dtype=np.int64),
|
||||
}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _build_index_mappings(
|
||||
name, data_prefix, documents, sizes, num_samples, seq_length, seed,
|
||||
):
|
||||
"""Build index mappings.
|
||||
We only have to build doc-idx in prompt dataset.
|
||||
|
||||
Args:
|
||||
name: name of the dataset.
|
||||
data_prefix: prefix of the data.
|
||||
documents: list of document indices.
|
||||
sizes: sizes of the indexed dataset.
|
||||
num_samples: number of samples to draw from the indexed dataset.
|
||||
seq_length: sequence length.
|
||||
seed: seed for random number generator.
|
||||
"""
|
||||
num_epochs = _num_epochs(documents.shape[0], num_samples)
|
||||
np_rng = np.random.RandomState(seed=seed)
|
||||
|
||||
_filename = data_prefix
|
||||
_filename += "_{}_indexmap".format(name)
|
||||
_filename += "_{}ns".format(num_samples)
|
||||
_filename += "_{}sl".format(seq_length)
|
||||
_filename += "_{}s".format(seed)
|
||||
doc_idx_filename = _filename + "_doc_idx.npy"
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
if not os.path.isfile(doc_idx_filename):
|
||||
print_rank_0(
|
||||
" > WARNING: could not find index map files, building "
|
||||
"the indices on rank 0 ..."
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
doc_idx = _build_doc_idx(documents, num_epochs, np_rng, False)[:num_samples]
|
||||
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
|
||||
|
||||
print_rank_0(
|
||||
" > elasped time to build and save doc-idx mapping "
|
||||
"(seconds): {:4f}".format(time.time() - start_time)
|
||||
)
|
||||
|
||||
# This should be a barrier but nccl barrier assumes
|
||||
# device_index=rank which is not the case for model
|
||||
# parallel case
|
||||
counts = torch.cuda.LongTensor([1])
|
||||
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
|
||||
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
|
||||
assert counts[0].item() == (
|
||||
torch.distributed.get_world_size()
|
||||
// torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())
|
||||
)
|
||||
# Load mappings.
|
||||
start_time = time.time()
|
||||
print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename))
|
||||
doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r")
|
||||
print_rank_0(" total number of samples: {}".format(doc_idx.shape[0]))
|
||||
print_rank_0(" total number of epochs: {}".format(num_epochs))
|
||||
|
||||
return doc_idx
|
||||
|
||||
|
||||
def _num_epochs(samples_per_epoch, num_samples):
|
||||
"""Calculate the epoch needed for so many sample."""
|
||||
return int(np.ceil(num_samples / samples_per_epoch))
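# Example: _num_epochs(samples_per_epoch=1000, num_samples=2500) == 3,
# i.e. the document list must be traversed three times to draw 2500 samples.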
|
||||
|
||||
|
||||
def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
|
||||
"""Build an array with length = number-of-epochs * number-of-dcuments.
|
||||
Each index is mapped to a corresponding document."""
|
||||
if not separate_last_epoch or num_epochs == 1:
|
||||
doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1]
|
||||
doc_idx[:] = documents
|
||||
doc_idx = doc_idx.reshape(-1)
|
||||
doc_idx = doc_idx.astype(np.int32)
|
||||
np_rng.shuffle(doc_idx)
|
||||
return doc_idx
|
||||
|
||||
doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
|
||||
doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
|
||||
return np.concatenate((doc_idx_first, doc_idx_last))
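# Example (illustrative): for documents = [0, 1, 2] and num_epochs = 2 the
# result is a shuffled int32 array of length 6 in which each document index
# appears exactly twice, e.g. [2, 0, 1, 1, 0, 2] for some seed.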
|
@ -0,0 +1,31 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
|
||||
|
||||
class LayerType(enum.Enum):
|
||||
encoder = 1
|
||||
decoder = 2
|
||||
|
||||
|
||||
class AttnType(enum.Enum):
|
||||
self_attn = 1
|
||||
cross_attn = 2
|
||||
|
||||
|
||||
class AttnMaskType(enum.Enum):
|
||||
padding = 1
|
||||
causal = 2
|
@ -0,0 +1,194 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Learning rate decay functions."""
|
||||
|
||||
import math
|
||||
|
||||
from codegeex.megatron import print_rank_0, get_args
|
||||
|
||||
|
||||
class AnnealingLR(object):
|
||||
"""Anneals the learning rate."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
max_lr,
|
||||
min_lr,
|
||||
warmup_steps,
|
||||
decay_steps,
|
||||
decay_style,
|
||||
use_checkpoint_lr_scheduler=True,
|
||||
override_lr_scheduler=False,
|
||||
):
|
||||
args = get_args()
|
||||
# Class values.
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.max_lr = float(max_lr)
|
||||
self.min_lr = min_lr
|
||||
assert self.min_lr >= 0.0
|
||||
assert self.max_lr >= self.min_lr
|
||||
|
||||
self.warmup_steps = warmup_steps
|
||||
self.num_steps = 0
|
||||
self.decay_steps = decay_steps
|
||||
assert self.decay_steps > 0
|
||||
assert self.warmup_steps < self.decay_steps
|
||||
|
||||
self.decay_tokens = args.lr_decay_tokens
|
||||
self.num_tokens = 0
|
||||
self.warmup_tokens = 0
|
||||
|
||||
self.decay_style = decay_style
|
||||
|
||||
self.override_lr_scheduler = override_lr_scheduler
|
||||
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
|
||||
if self.override_lr_scheduler:
|
||||
assert not self.use_checkpoint_lr_scheduler, (
|
||||
"both override and " "use-checkpoint are set."
|
||||
)
|
||||
|
||||
# Set the learning rate
|
||||
self.step(0)
|
||||
|
||||
print_rank_0("> learning rate decay style: {}".format(self.decay_style))
|
||||
|
||||
def get_lr(self):
|
||||
"""Learning rate decay functions from:
|
||||
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
|
||||
|
||||
# Use linear warmup for the initial part.
|
||||
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
|
||||
if self.num_steps == self.warmup_steps and self.decay_tokens is not None:
|
||||
self.warmup_tokens = self.num_tokens
|
||||
return self.max_lr * float(self.num_steps) / float(self.warmup_steps)
|
||||
|
||||
# If the learning rate is constant, just return the initial value.
|
||||
if self.decay_style == "constant":
|
||||
return self.max_lr
|
||||
|
||||
if self.decay_tokens is None:
|
||||
# step-based decay
|
||||
|
||||
# For any steps larger than `self.decay_steps`, use `self.min_lr`.
|
||||
if self.num_steps > self.decay_steps:
|
||||
return self.min_lr
|
||||
|
||||
# If we are done with the warmup period, use the decay style.
|
||||
num_steps_ = self.num_steps - self.warmup_steps
|
||||
decay_steps_ = self.decay_steps - self.warmup_steps
|
||||
decay_ratio = float(num_steps_) / float(decay_steps_)
|
||||
else:
|
||||
# token-based decay
|
||||
|
||||
if self.num_tokens > self.decay_tokens:
|
||||
return self.min_lr
|
||||
num_tokens_ = self.num_tokens - self.warmup_tokens
|
||||
decay_tokens_ = self.decay_tokens - self.warmup_tokens
|
||||
decay_ratio = float(num_tokens_) / float(decay_tokens_)
|
||||
assert decay_ratio >= 0.0
|
||||
assert decay_ratio <= 1.0
|
||||
delta_lr = self.max_lr - self.min_lr
|
||||
|
||||
if self.decay_style == "linear":
|
||||
coeff = 1.0 - decay_ratio
|
||||
elif self.decay_style == "cosine":
|
||||
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
|
||||
else:
|
||||
raise Exception("{} decay style is not supported.".format(self.decay_style))
|
||||
|
||||
return self.min_lr + coeff * delta_lr
|
||||
|
||||
def step(self, increment, token_num=None):
|
||||
"""Set lr for all parameters groups."""
|
||||
if token_num is None:
|
||||
args = get_args()
|
||||
token_num = args.consumed_train_tokens
|
||||
self.num_tokens = token_num
|
||||
self.num_steps += increment
|
||||
new_lr = self.get_lr()
|
||||
for group in self.optimizer.param_groups:
|
||||
group["lr"] = new_lr
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {
|
||||
"max_lr": self.max_lr,
|
||||
"warmup_steps": self.warmup_steps,
|
||||
"num_steps": self.num_steps,
|
||||
"warmup_tokens": self.warmup_tokens,
|
||||
"num_tokens": self.num_tokens,
|
||||
"decay_style": self.decay_style,
|
||||
"decay_steps": self.decay_steps,
|
||||
"min_lr": self.min_lr,
|
||||
}
|
||||
return state_dict
|
||||
|
||||
def _check_and_set(self, cls_value, sd_value, name):
|
||||
"""Auxiliary function for checking the values in the checkpoint and
|
||||
setting them."""
|
||||
if self.override_lr_scheduler:
|
||||
print_rank_0(" > overriding {} value to {}".format(name, cls_value))
|
||||
return cls_value
|
||||
|
||||
if not self.use_checkpoint_lr_scheduler:
|
||||
assert cls_value == sd_value, (
|
||||
f"AnnealingLR: class input value {cls_value} and checkpoint"
|
||||
f"value {sd_value} for {name} do not match"
|
||||
)
|
||||
print_rank_0(" > using checkpoint value {} for {}".format(sd_value, name))
|
||||
return sd_value
|
||||
|
||||
def load_state_dict(self, sd):
|
||||
|
||||
if "start_lr" in sd:
|
||||
max_lr_ = sd["start_lr"]
|
||||
else:
|
||||
max_lr_ = sd["max_lr"]
|
||||
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
|
||||
|
||||
self.min_lr = self._check_and_set(
|
||||
self.min_lr, sd["min_lr"], "minimum learning rate"
|
||||
)
|
||||
|
||||
if "warmup_iter" in sd:
|
||||
warmup_steps_ = sd["warmup_iter"]
|
||||
else:
|
||||
warmup_steps_ = sd["warmup_steps"]
|
||||
self.warmup_steps = self._check_and_set(
|
||||
self.warmup_steps, warmup_steps_, "warmup iterations"
|
||||
)
|
||||
|
||||
if "end_iter" in sd:
|
||||
decay_steps_ = sd["end_iter"]
|
||||
else:
|
||||
decay_steps_ = sd["decay_steps"]
|
||||
self.decay_steps = self._check_and_set(
|
||||
self.decay_steps, decay_steps_, "total number of iterations"
|
||||
)
|
||||
self.decay_style = self._check_and_set(
|
||||
self.decay_style, sd["decay_style"], "decay style"
|
||||
)
|
||||
|
||||
if "num_iters" in sd:
|
||||
num_steps = sd["num_iters"]
|
||||
else:
|
||||
num_steps = sd["num_steps"]
|
||||
if "warmup_tokens" in sd:
|
||||
self.warmup_tokens = sd["warmup_tokens"]
|
||||
if "num_tokens" in sd:
|
||||
self.num_tokens = sd["num_tokens"]
|
||||
self.step(num_steps, self.num_tokens)
|
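# Illustrative sketch (not part of the original file): the step-based schedule computed
# by AnnealingLR.get_lr() above, rewritten as a pure function so the warmup/decay
# arithmetic can be checked in isolation. Function and argument names are hypothetical;
# only the standard library is used.
import math

def annealed_lr(step, max_lr, min_lr, warmup_steps, decay_steps, decay_style="cosine"):
    # Linear warmup from 0 up to max_lr.
    if warmup_steps > 0 and step <= warmup_steps:
        return max_lr * step / warmup_steps
    if decay_style == "constant":
        return max_lr
    # Past the decay horizon, hold at min_lr.
    if step > decay_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (decay_steps - warmup_steps)
    if decay_style == "linear":
        coeff = 1.0 - decay_ratio
    elif decay_style == "cosine":
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    else:
        raise ValueError(f"{decay_style} decay style is not supported.")
    return min_lr + coeff * (max_lr - min_lr)

# Example: with max_lr=1e-4, min_lr=1e-5, 100 warmup steps and 1000 decay steps,
# annealed_lr(50, 1e-4, 1e-5, 100, 1000) returns 5e-5 (halfway through warmup).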
@ -0,0 +1,186 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Megatron number of micro-batches calculators."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
def build_num_microbatches_calculator(args):
|
||||
|
||||
# Constant num micro-batches.
|
||||
if args.rampup_batch_size is None:
|
||||
num_microbatches_calculator = ConstantNumMicroBatches(
|
||||
args.global_batch_size, args.micro_batch_size, args.data_parallel_size
|
||||
)
|
||||
if args.rank == 0:
|
||||
print(
|
||||
"setting number of micro-batches to constant {}".format(
|
||||
num_microbatches_calculator.get()
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
|
||||
else:
|
||||
assert len(args.rampup_batch_size) == 3, (
|
||||
"expected the following "
|
||||
"format: --rampup-batch-size <start batch size> "
|
||||
"<batch size incerement> <ramp-up samples>"
|
||||
)
|
||||
start_batch_size = int(args.rampup_batch_size[0])
|
||||
batch_size_increment = int(args.rampup_batch_size[1])
|
||||
ramup_samples = int(args.rampup_batch_size[2])
|
||||
if args.rank == 0:
|
||||
print(
|
||||
"will use batch size rampup starting from global batch "
|
||||
"size {} to global batch size {} with batch size increments "
|
||||
"{} over {} samples.".format(
|
||||
start_batch_size,
|
||||
args.global_batch_size,
|
||||
batch_size_increment,
|
||||
ramup_samples,
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
|
||||
start_batch_size,
|
||||
batch_size_increment,
|
||||
ramup_samples,
|
||||
args.global_batch_size,
|
||||
args.micro_batch_size,
|
||||
args.data_parallel_size,
|
||||
)
|
||||
|
||||
return num_microbatches_calculator
|
||||
|
||||
|
||||
class NumMicroBatchesCalculator(ABC):
|
||||
def __init__(self):
|
||||
self.num_micro_batches = None
|
||||
self.current_global_batch_size = None
|
||||
|
||||
def get(self):
|
||||
return self.num_micro_batches
|
||||
|
||||
def get_current_global_batch_size(self):
|
||||
return self.current_global_batch_size
|
||||
|
||||
@abstractmethod
|
||||
def update(self, consumed_samples, consistency_check):
|
||||
pass
|
||||
|
||||
|
||||
class ConstantNumMicroBatches(NumMicroBatchesCalculator):
|
||||
def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
|
||||
micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
|
||||
assert global_batch_size % micro_batch_times_data_parallel == 0, (
|
||||
"global batch size ({}) is not divisible by micro batch size ({})"
|
||||
" times data parallel size ({})".format(
|
||||
global_batch_size, micro_batch_size, data_parallel_size
|
||||
)
|
||||
)
|
||||
self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
|
||||
assert self.num_micro_batches >= 1
|
||||
self.current_global_batch_size = global_batch_size
|
||||
|
||||
def update(self, consumed_samples, consistency_check):
|
||||
pass
|
||||
|
||||
|
||||
class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
|
||||
def __init__(
|
||||
self,
|
||||
start_batch_size,
|
||||
batch_size_increment,
|
||||
ramup_samples,
|
||||
global_batch_size,
|
||||
micro_batch_size,
|
||||
data_parallel_size,
|
||||
):
|
||||
"""Batch size ramp up.
|
||||
Over
|
||||
steps = (global-batch-size - start-batch-size) / batch_size_increment
|
||||
increment batch size from start-batch-size to global-batch-size using
|
||||
rampup-samples / steps
|
||||
samples.
|
||||
Arguments:
|
||||
start_batch_size: global batch size to start with
|
||||
batch_size_increment: global batch size increments
|
||||
ramup_samples: number of samples to use to ramp up the global
|
||||
batch size from `start_batch_size` to `global_batch_size`
|
||||
global_batch_size: global batch size post rampup
|
||||
micro_batch_size: micro batch size
|
||||
data_parallel_size: data parallel size.
|
||||
"""
|
||||
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_size = data_parallel_size
|
||||
self.micro_batch_times_data_parallel_size = (
|
||||
self.micro_batch_size * self.data_parallel_size
|
||||
)
|
||||
assert self.micro_batch_times_data_parallel_size > 0
|
||||
|
||||
assert start_batch_size > 0
|
||||
self.start_batch_size = start_batch_size
|
||||
|
||||
assert global_batch_size > 0
|
||||
self.global_batch_size = global_batch_size
|
||||
diff_batch_size = self.global_batch_size - self.start_batch_size
|
||||
assert diff_batch_size >= 0
|
||||
assert batch_size_increment > 0
|
||||
self.batch_size_increment = batch_size_increment
|
||||
assert diff_batch_size % batch_size_increment == 0, (
|
||||
"expected "
|
||||
"global batch size interval ({}) to be divisible by global batch "
|
||||
"size increment ({})".format(diff_batch_size, batch_size_increment)
|
||||
)
|
||||
|
||||
num_increments = diff_batch_size // self.batch_size_increment
|
||||
self.ramup_samples = ramup_samples
|
||||
assert self.ramup_samples >= 0
|
||||
self.rampup_samples_per_increment = self.ramup_samples / num_increments
|
||||
|
||||
# Initialize number of microbatches.
|
||||
self.update(0, False)
|
||||
|
||||
def update(self, consumed_samples, consistency_check):
|
||||
|
||||
if consumed_samples > self.ramup_samples:
|
||||
self.current_global_batch_size = self.global_batch_size
|
||||
else:
|
||||
steps = int(consumed_samples / self.rampup_samples_per_increment)
|
||||
self.current_global_batch_size = (
|
||||
self.start_batch_size + steps * self.batch_size_increment
|
||||
)
|
||||
assert self.current_global_batch_size <= self.global_batch_size
|
||||
|
||||
if consistency_check:
|
||||
assert (
|
||||
self.current_global_batch_size
|
||||
% self.micro_batch_times_data_parallel_size
|
||||
== 0
|
||||
), (
|
||||
"current global "
|
||||
"batch size ({}) is not divisible by micro-batch-size ({}) times"
|
||||
"data parallel size ({})".format(
|
||||
self.current_global_batch_size,
|
||||
self.micro_batch_size,
|
||||
self.data_parallel_size,
|
||||
)
|
||||
)
|
||||
self.num_micro_batches = (
|
||||
self.current_global_batch_size // self.micro_batch_times_data_parallel_size
|
||||
)
|
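# Illustrative sketch (not part of the original file): the sample-based ramp-up rule
# that RampupBatchsizeNumMicroBatches.update() implements above, as a pure function.
# Names and the example numbers below are hypothetical.
def rampup_global_batch_size(consumed_samples, start_bs, increment, rampup_samples, final_bs):
    if consumed_samples > rampup_samples:
        return final_bs
    num_increments = (final_bs - start_bs) // increment
    samples_per_increment = rampup_samples / num_increments
    steps = int(consumed_samples / samples_per_increment)
    return min(start_bs + steps * increment, final_bs)

# Example: ramping from 32 to 512 in steps of 32 over 960_000 samples gives 15
# increments of 64_000 samples each, so after 130_000 consumed samples the current
# global batch size is 32 + 2 * 32 = 96; dividing by micro_batch_size *
# data_parallel_size then gives the number of micro-batches, as in get().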
@ -0,0 +1,129 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from apex.optimizers import FusedAdam as Adam
|
||||
from apex.optimizers import FusedSGD as SGD
|
||||
|
||||
from codegeex.megatron import get_args
|
||||
from codegeex.megatron.model import LayerNorm
|
||||
|
||||
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
|
||||
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
|
||||
|
||||
|
||||
def _get_params_for_weight_decay_optimization(modules):
|
||||
"""Divide params into with-weight-decay and without-weight-decay groups.
|
||||
LayerNorms and biases will have no weight decay but the rest will.
|
||||
"""
|
||||
|
||||
weight_decay_params = {"params": []}
|
||||
no_weight_decay_params = {"params": [], "weight_decay": 0.0}
|
||||
for module in modules:
|
||||
for module_ in module.modules():
|
||||
if isinstance(module_, LayerNorm):
|
||||
no_weight_decay_params["params"].extend(
|
||||
[p for p in list(module_._parameters.values()) if p is not None]
|
||||
)
|
||||
else:
|
||||
weight_decay_params["params"].extend(
|
||||
[
|
||||
p
|
||||
for n, p in list(module_._parameters.items())
|
||||
if p is not None and n != "bias"
|
||||
]
|
||||
)
|
||||
no_weight_decay_params["params"].extend(
|
||||
[
|
||||
p
|
||||
for n, p in list(module_._parameters.items())
|
||||
if p is not None and n == "bias"
|
||||
]
|
||||
)
|
||||
|
||||
return weight_decay_params, no_weight_decay_params
|
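# Illustrative sketch (not part of the original file): the same grouping rule as
# _get_params_for_weight_decay_optimization above, applied to a plain PyTorch module
# (torch.nn.LayerNorm stands in for the Megatron LayerNorm) and fed to the standard
# torch.optim.AdamW rather than apex FusedAdam. All names below are assumptions.
import torch

def split_weight_decay_groups(module, weight_decay=0.1):
    decay, no_decay = [], []
    for submodule in module.modules():
        for name, param in submodule._parameters.items():
            if param is None:
                continue
            # Biases and LayerNorm parameters get no weight decay.
            if isinstance(submodule, torch.nn.LayerNorm) or name == "bias":
                no_decay.append(param)
            else:
                decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Usage: optimizer = torch.optim.AdamW(split_weight_decay_groups(model), lr=1e-4)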
||||
|
||||
|
||||
def get_megatron_optimizer(model):
|
||||
args = get_args()
|
||||
|
||||
if args.cpu_optimizer:
|
||||
raise NotImplementedError("need to add cpu adam")
|
||||
|
||||
param_groups = _get_params_for_weight_decay_optimization(model)
|
||||
|
||||
if args.optimizer == "adam":
|
||||
optimizer = Adam(
|
||||
param_groups,
|
||||
lr=args.lr,
|
||||
weight_decay=args.weight_decay,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
eps=args.adam_eps,
|
||||
)
|
||||
elif args.optimizer == "sgd":
|
||||
optimizer = SGD(
|
||||
param_groups,
|
||||
lr=args.lr,
|
||||
weight_decay=args.weight_decay,
|
||||
momentum=args.sgd_momentum,
|
||||
)
|
||||
else:
|
||||
raise Exception("{} optimizer is not supported.".format(args.optimizer))
|
||||
|
||||
if args.deepspeed:
|
||||
return optimizer
|
||||
|
||||
# Determine whether the params have main-grad field.
|
||||
params_have_main_grad = False
|
||||
if args.DDP_impl == "local":
|
||||
params_have_main_grad = True
|
||||
|
||||
if args.fp16 or args.bf16:
|
||||
|
||||
# Grad scaler:
|
||||
# if loss-scale is provided, instantiate the constant scaler.
|
||||
# if we are using fp16 and loss-scale is not present, use a
|
||||
# dynamic scaler.
|
||||
# otherwise we are running in bf16 with no loss-scale so
|
||||
# leave it as None.
|
||||
grad_scaler = None
|
||||
# Constant loss scale.
|
||||
if args.loss_scale:
|
||||
grad_scaler = ConstantGradScaler(args.loss_scale)
|
||||
# Dynamic loss scale.
|
||||
else:
|
||||
if args.fp16:
|
||||
grad_scaler = DynamicGradScaler(
|
||||
initial_scale=args.initial_loss_scale,
|
||||
min_scale=args.min_loss_scale,
|
||||
growth_factor=2.0,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=args.loss_scale_window,
|
||||
hysteresis=args.hysteresis,
|
||||
)
|
||||
|
||||
# Megatron optimizer.
|
||||
return Float16OptimizerWithFloat16Params(
|
||||
optimizer,
|
||||
args.clip_grad,
|
||||
args.log_num_zeros_in_grad,
|
||||
params_have_main_grad,
|
||||
args.bf16,
|
||||
grad_scaler,
|
||||
)
|
||||
|
||||
# FP32.
|
||||
return FP32Optimizer(
|
||||
optimizer, args.clip_grad, args.log_num_zeros_in_grad, params_have_main_grad
|
||||
)
|
@ -0,0 +1,153 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Gradient clipping."""
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
import amp_C
|
||||
|
||||
from codegeex.megatron import mpu
|
||||
from codegeex.megatron.model.module import param_is_not_shared
|
||||
from codegeex.megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
|
||||
|
||||
|
||||
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||
"""Clips gradient norm of an iterable of parameters whose gradients
|
||||
are in fp32.
|
||||
|
||||
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_, with
|
||||
added functionality to handle model parallel parameters. Note that
|
||||
the gradients are modified in place.
|
||||
|
||||
Arguments:
|
||||
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
||||
single Tensor that will have gradients normalized
|
||||
max_norm (float or int): max norm of the gradients
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the parameters (viewed as a single vector).
|
||||
"""
|
||||
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
grads = []
|
||||
grads_for_norm = []
|
||||
for param in parameters:
|
||||
grad_not_none = param.grad is not None
|
||||
is_not_shared = param_is_not_shared(param)
|
||||
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
|
||||
grad = param.grad.detach() if grad_not_none else None  # avoid touching params with no grad
|
||||
if grad_not_none:
|
||||
# Make sure the grads are in fp32
|
||||
assert param.grad.type() == "torch.cuda.FloatTensor"
|
||||
grads.append(grad)
|
||||
if grad_not_none and is_not_shared and is_not_tp_duplicate:
|
||||
grads_for_norm.append(grad)
|
||||
|
||||
# Norm parameters.
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
total_norm = 0.0
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(grad.abs().max() for grad in grads_for_norm)
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
# Take max across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(
|
||||
total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=mpu.get_model_parallel_group(),
|
||||
)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
|
||||
else:
|
||||
if norm_type == 2.0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
# Use apex's multi-tensor applier for efficiency reasons.
|
||||
# Multi-tensor applier takes a function and a list of list
|
||||
# and performs the operation on that list all in one kernel.
|
||||
grad_norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads_for_norm],
|
||||
False, # no per-parameter norm
|
||||
)
|
||||
# Since we will be summing across data parallel groups,
|
||||
# we need the pow(norm-type).
|
||||
total_norm = grad_norm ** norm_type
|
||||
|
||||
else:
|
||||
for grad in grads_for_norm:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
total_norm += grad_norm ** norm_type
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(
|
||||
total_norm,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=mpu.get_model_parallel_group(),
|
||||
)
|
||||
total_norm = total_norm.item() ** (1.0 / norm_type)
|
||||
|
||||
# Scale.
|
||||
clip_coeff = max_norm / (total_norm + 1.0e-6)
|
||||
if clip_coeff < 1.0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
multi_tensor_applier(
|
||||
amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff
|
||||
)
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
def count_zeros_fp32(parameters):
|
||||
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
total_num_zeros = 0.0
|
||||
for param in parameters:
|
||||
grad_not_none = param.grad is not None
|
||||
is_not_shared = param_is_not_shared(param)
|
||||
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
|
||||
if grad_not_none and is_not_shared and is_not_tp_duplicate:
|
||||
grad = param.grad.detach()
|
||||
num_zeros = grad.numel() - torch.count_nonzero(grad)
|
||||
total_num_zeros = num_zeros + total_num_zeros
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(
|
||||
total_num_zeros,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=mpu.get_model_parallel_group(),
|
||||
)
|
||||
total_num_zeros = total_num_zeros.item()
|
||||
|
||||
return total_num_zeros
|
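# Illustrative sketch (not part of the original file): the clip-by-global-norm
# arithmetic of clip_grad_norm_fp32 above, reduced to a single GPU with plain PyTorch
# (no apex multi_tensor_applier, no model-parallel all-reduce). Names are hypothetical.
import torch

def clip_grads_by_global_norm(parameters, max_norm, norm_type=2.0):
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if norm_type == float("inf"):
        total_norm = max(g.abs().max().item() for g in grads)
    else:
        total_norm = sum(torch.norm(g, norm_type).item() ** norm_type for g in grads)
        total_norm = total_norm ** (1.0 / norm_type)
    # Scale every gradient in place by max_norm / total_norm when the norm is too large.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm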
@ -0,0 +1,132 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Megatron grad scaler."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MegatronGradScaler(ABC):
|
||||
def __init__(self, initial_scale):
|
||||
"""Initialize scale value with the input initial scale."""
|
||||
assert initial_scale > 0.0
|
||||
self._scale = torch.cuda.FloatTensor([initial_scale])
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
return self._scale
|
||||
|
||||
@property
|
||||
def inv_scale(self):
|
||||
return self._scale.double().reciprocal().float()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, found_inf):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
|
||||
class ConstantGradScaler(MegatronGradScaler):
|
||||
def update(self, found_inf):
|
||||
pass
|
||||
|
||||
def state_dict(self):
|
||||
return dict()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
|
||||
class DynamicGradScaler(MegatronGradScaler):
|
||||
def __init__(
|
||||
self,
|
||||
initial_scale,
|
||||
min_scale,
|
||||
growth_factor,
|
||||
backoff_factor,
|
||||
growth_interval,
|
||||
hysteresis,
|
||||
):
|
||||
""" "Grad scaler with dynamic scale that gets adjusted
|
||||
during training."""
|
||||
super(DynamicGradScaler, self).__init__(initial_scale)
|
||||
|
||||
# Lower bound on the scale.
|
||||
assert min_scale > 0.0
|
||||
assert min_scale <= initial_scale
|
||||
self.min_scale = torch.cuda.FloatTensor([min_scale])
|
||||
# Growth and backoff factors for the scale.
|
||||
assert growth_factor > 1.0
|
||||
self.growth_factor = torch.cuda.FloatTensor([growth_factor])
|
||||
assert backoff_factor < 1.0
|
||||
assert backoff_factor > 0.0
|
||||
self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
|
||||
# Interval over which if we don't see any inf/nan,
|
||||
# we will scale the grad scale by the growth factor.
|
||||
assert growth_interval > 0
|
||||
self.growth_interval = growth_interval
|
||||
# Number of inf/nans we should see before scaling down
|
||||
# the grad scale by the backoff factor.
|
||||
assert hysteresis > 0
|
||||
self.hysteresis = hysteresis
|
||||
|
||||
# Trackers.
|
||||
self._growth_tracker = 0
|
||||
self._hysteresis_tracker = self.hysteresis
|
||||
|
||||
def update(self, found_inf):
|
||||
|
||||
# If we have an inf/nan, growth tracker is set to 0
|
||||
# and the hysteresis tracker is reduced by 1.
|
||||
if found_inf:
|
||||
self._growth_tracker = 0
|
||||
self._hysteresis_tracker -= 1
|
||||
# Now if we are out of hysteresis count, scale down the loss.
|
||||
if self._hysteresis_tracker <= 0:
|
||||
self._scale = torch.max(
|
||||
self._scale * self.backoff_factor, self.min_scale
|
||||
)
|
||||
else:
|
||||
# If there is no nan/inf, increment the growth tracker.
|
||||
self._growth_tracker += 1
|
||||
# If we have had enough consecutive intervals with no nan/inf:
|
||||
if self._growth_tracker == self.growth_interval:
|
||||
# Reset the tracker and hysteresis trackers,
|
||||
self._growth_tracker = 0
|
||||
self._hysteresis_tracker = self.hysteresis
|
||||
# and scale up the loss scale.
|
||||
self._scale = self._scale * self.growth_factor
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {}
|
||||
state_dict["scale"] = self._scale
|
||||
state_dict["growth_tracker"] = self._growth_tracker
|
||||
state_dict["hysteresis_tracker"] = self._hysteresis_tracker
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
|
||||
self._growth_tracker = state_dict["growth_tracker"]
|
||||
self._hysteresis_tracker = state_dict["hysteresis_tracker"]
|
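# Illustrative sketch (not part of the original file): how DynamicGradScaler above
# reacts to overflow, simulated with plain Python floats instead of CUDA tensors.
# The sequence of found_inf values and the hyperparameters are made up.
scale, min_scale = 2.0 ** 16, 1.0
growth_factor, backoff_factor = 2.0, 0.5
growth_interval, hysteresis = 3, 2
growth_tracker, hysteresis_tracker = 0, hysteresis

for found_inf in [False, False, False, True, True, False]:
    if found_inf:
        growth_tracker = 0
        hysteresis_tracker -= 1
        if hysteresis_tracker <= 0:
            scale = max(scale * backoff_factor, min_scale)
    else:
        growth_tracker += 1
        if growth_tracker == growth_interval:
            growth_tracker = 0
            hysteresis_tracker = hysteresis
            scale *= growth_factor

# Three clean steps double the scale to 2**17; the two overflow steps exhaust the
# hysteresis budget and back it off to 2**16 again.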
@ -0,0 +1,505 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Megatron optimizer."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
import amp_C
|
||||
|
||||
from codegeex.megatron import get_timers
|
||||
from codegeex.megatron import mpu
|
||||
from codegeex.megatron import print_rank_0
|
||||
|
||||
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
|
||||
|
||||
|
||||
def _zero_grad_group_helper(group, set_to_none):
|
||||
"""Zero out the gradient for a group of parameters.
|
||||
Note: copied from torch.optim.optimizer."""
|
||||
for param in group:
|
||||
if param.grad is not None:
|
||||
if set_to_none:
|
||||
param.grad = None
|
||||
else:
|
||||
if param.grad.grad_fn is not None:
|
||||
param.grad.detach_()
|
||||
else:
|
||||
param.grad.requires_grad_(False)
|
||||
param.grad.zero_()
|
||||
|
||||
|
||||
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
||||
"""Use multi-tensor-applier to copy values from one list to another.
|
||||
We don't have a bfloat16 implementation, so for now if the overflow_buf
|
||||
is not provided, we default back to simple loop copy to be compatible
|
||||
with bfloat16."""
|
||||
if overflow_buf:
|
||||
overflow_buf.fill_(0)
|
||||
# Scaling with factor `1.0` is equivalent to copy.
|
||||
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
|
||||
else:
|
||||
for this_, that_ in zip(this, that):
|
||||
that_.copy_(this_)
|
||||
|
||||
|
||||
class MegatronOptimizer(ABC):
|
||||
def __init__(
|
||||
self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad
|
||||
):
|
||||
"""Input optimizer is the base optimizer for example Adam."""
|
||||
self.optimizer = optimizer
|
||||
assert self.optimizer, "no optimizer is provided."
|
||||
# Set gradient clipping and logging params.
|
||||
self.clip_grad = clip_grad
|
||||
self.log_num_zeros_in_grad = log_num_zeros_in_grad
|
||||
self.params_have_main_grad = params_have_main_grad
|
||||
|
||||
def get_parameters(self):
|
||||
params = []
|
||||
for param_group in self.optimizer.param_groups:
|
||||
for param in param_group["params"]:
|
||||
params.append(param)
|
||||
return params
|
||||
|
||||
def clip_grad_norm(self, clip_grad):
|
||||
params = self.get_parameters()
|
||||
return clip_grad_norm_fp32(params, clip_grad)
|
||||
|
||||
def count_zeros(self):
|
||||
params = self.get_parameters()
|
||||
return count_zeros_fp32(params)
|
||||
|
||||
@abstractmethod
|
||||
def zero_grad(self, set_to_none=True):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_loss_scale(self):
|
||||
"""The output should be a cuda tensor of size 1."""
|
||||
pass
|
||||
|
||||
def scale_loss(self, loss):
|
||||
"""Simple scaling."""
|
||||
return self.get_loss_scale() * loss
|
||||
|
||||
@abstractmethod
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reload_model_params(self):
|
||||
"""Refreshes any internal state from the current model parameters.
|
||||
Call whenever the parameters are changed outside of the optimizer.
|
||||
For example, when we load a model from a checkpoint without loading
|
||||
the optimizer, the model parameters are updated but for fp16 optimizer
|
||||
with main parameters, the main parameters need to also be updated."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
# Promote state so it can be retrieved or set via
|
||||
# "optimizer_instance.state"
|
||||
def _get_state(self):
|
||||
return self.optimizer.state
|
||||
|
||||
def _set_state(self, value):
|
||||
self.optimizer.state = value
|
||||
|
||||
state = property(_get_state, _set_state)
|
||||
|
||||
# Promote param_groups so it can be retrieved or set via
|
||||
# "optimizer_instance.param_groups"
|
||||
# (for example, to adjust the learning rate)
|
||||
def _get_param_groups(self):
|
||||
return self.optimizer.param_groups
|
||||
|
||||
def _set_param_groups(self, value):
|
||||
self.optimizer.param_groups = value
|
||||
|
||||
param_groups = property(_get_param_groups, _set_param_groups)
|
||||
|
||||
|
||||
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
"""Float16 optimizer for fp16 and bf16 data types.
|
||||
|
||||
Arguments:
|
||||
optimizer: base optimizer such as Adam or SGD
|
||||
clip_grad: clip gradients with this global L2 norm. Note
|
||||
that clipping is ignored if clip_grad == 0
|
||||
log_num_zeros_in_grad: return number of zeros in the gradients.
|
||||
params_have_main_grad: flag indicating if parameters have
|
||||
a `main_grad` field. If this is set, we are assuming
|
||||
that the gradients of the model parameters are stored in the `main_grad`
|
||||
field instead of the typical `grad` field. This happens
|
||||
for the DDP cases where there is a contiguous buffer
|
||||
holding the gradients. For example for bfloat16, we want
|
||||
to do gradient accumulation and all-reduces in float32
|
||||
and as a result we store those gradients in the main_grad.
|
||||
Note that main grad is not necessarily in float32.
|
||||
bf16: if true, the model is running in bfloat16.
|
||||
grad_scaler: used for scaling gradients. Note that this can be
|
||||
None. This case happens when `bf16 = True` and we don't
|
||||
use any loss scale. Note that for `bf16 = True`, we can have
|
||||
a constant gradient scaler. Also for `bf16 = False`, we
|
||||
always require a grad scaler.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
clip_grad,
|
||||
log_num_zeros_in_grad,
|
||||
params_have_main_grad,
|
||||
bf16,
|
||||
grad_scaler,
|
||||
):
|
||||
|
||||
super(Float16OptimizerWithFloat16Params, self).__init__(
|
||||
optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad
|
||||
)
|
||||
|
||||
self.bf16 = bf16
|
||||
self.grad_scaler = grad_scaler
|
||||
# None grad scaler is only supported for bf16.
|
||||
if self.grad_scaler is None:
|
||||
assert self.bf16, "fp16 expects a grad scaler."
|
||||
|
||||
# Tensor used to determine if a nan/inf has happened.
|
||||
# Any non-zero value indicates inf/nan.
|
||||
# Note that we keep this for the cases that grad scaler is none.
|
||||
# We still record nan/inf if we have a bfloat16 with a grad scaler.
|
||||
if self.grad_scaler:
|
||||
self.found_inf = torch.cuda.FloatTensor([0.0])
|
||||
|
||||
# Dummy tensor needed for apex multi-tensor apply.
|
||||
# For bfloat, we don't have multi-tensor apply and for now
|
||||
# we set it to none so the multi-tensor apply gets ignored.
|
||||
if bf16:
|
||||
self._dummy_overflow_buf = None
|
||||
else:
|
||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
|
||||
# In case grad scaler is not passed, define the unity scale.
|
||||
if self.grad_scaler is None:
|
||||
self._scale_one = torch.cuda.FloatTensor([1.0])
|
||||
|
||||
# ======================
|
||||
# main parameter stuff
|
||||
# ======================
|
||||
|
||||
# Three groups of parameters:
|
||||
# float16_groups: original float16 parameters
|
||||
# fp32_from_float16_groups: fp32 copy of float16 parameters
|
||||
# fp32_from_fp32_groups: original fp32 parameters
|
||||
self.float16_groups = []
|
||||
self.fp32_from_float16_groups = []
|
||||
self.fp32_from_fp32_groups = []
|
||||
|
||||
# For all the groups in the original optimizer:
|
||||
for param_group in self.optimizer.param_groups:
|
||||
float16_params_this_group = []
|
||||
fp32_params_this_group = []
|
||||
fp32_from_float16_params_this_group = []
|
||||
# For all the parameters in this group:
|
||||
for i, param in enumerate(param_group["params"]):
|
||||
if param.requires_grad:
|
||||
|
||||
# float16 params:
|
||||
if param.type() in [
|
||||
"torch.cuda.HalfTensor",
|
||||
"torch.cuda.BFloat16Tensor",
|
||||
]:
|
||||
float16_params_this_group.append(param)
|
||||
# Create a copy
|
||||
main_param = param.detach().clone().float()
|
||||
# Copy tensor model parallel attributes.
|
||||
mpu.copy_tensor_model_parallel_attributes(main_param, param)
|
||||
if hasattr(param, "shared"):
|
||||
main_param.shared = param.shared
|
||||
# Replace the optimizer params with the new fp32 copy.
|
||||
param_group["params"][i] = main_param
|
||||
fp32_from_float16_params_this_group.append(main_param)
|
||||
# Reset existing state dict key to the new main param.
|
||||
if param in self.optimizer.state:
|
||||
self.optimizer.state[main_param] = self.optimizer.state.pop(
|
||||
param
|
||||
)
|
||||
|
||||
# fp32 params.
|
||||
elif param.type() == "torch.cuda.FloatTensor":
|
||||
fp32_params_this_group.append(param)
|
||||
param_group["params"][i] = param
|
||||
|
||||
else:
|
||||
raise TypeError(
|
||||
"Wrapped parameters must be one of "
|
||||
"torch.cuda.FloatTensor, "
|
||||
"torch.cuda.HalfTensor, or "
|
||||
"torch.cuda.BFloat16Tensor. "
|
||||
"Received {}".format(param.type())
|
||||
)
|
||||
|
||||
self.float16_groups.append(float16_params_this_group)
|
||||
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
|
||||
self.fp32_from_fp32_groups.append(fp32_params_this_group)
|
||||
|
||||
# Leverage state_dict() and load_state_dict() to
|
||||
# recast preexisting per-param state tensors
|
||||
self.optimizer.load_state_dict(self.optimizer.state_dict())
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""We only need to zero the model related parameters, i.e.,
|
||||
float16_groups & fp32_from_fp32_groups."""
|
||||
for group in self.float16_groups:
|
||||
_zero_grad_group_helper(group, set_to_none)
|
||||
for group in self.fp32_from_fp32_groups:
|
||||
_zero_grad_group_helper(group, set_to_none)
|
||||
|
||||
def get_loss_scale(self):
|
||||
if self.grad_scaler is None:
|
||||
return self._scale_one
|
||||
return self.grad_scaler.scale
|
||||
|
||||
def _copy_model_grads_to_main_grads(self):
|
||||
# This only needs to be done for the float16 group.
|
||||
for model_group, main_group in zip(
|
||||
self.float16_groups, self.fp32_from_float16_groups
|
||||
):
|
||||
for model_param, main_param in zip(model_group, main_group):
|
||||
if self.params_have_main_grad:
|
||||
main_param.grad = model_param.main_grad.float()
|
||||
else:
|
||||
if model_param.grad is not None:
|
||||
main_param.grad = model_param.grad.float()
|
||||
# For fp32 grads, we need to reset the grads to main grad.
|
||||
if self.params_have_main_grad:
|
||||
for model_group in self.fp32_from_fp32_groups:
|
||||
for model_param in model_group:
|
||||
model_param.grad = model_param.main_grad
|
||||
|
||||
def _unscale_main_grads_and_check_for_nan(self):
|
||||
main_grads = []
|
||||
# fp32 params from float16 ones.
|
||||
for main_group in self.fp32_from_float16_groups:
|
||||
for main_param in main_group:
|
||||
if main_param.grad is not None:
|
||||
main_grads.append(main_param.grad.data)
|
||||
# Append fp32 parameters.
|
||||
for main_group in self.fp32_from_fp32_groups:
|
||||
for main_param in main_group:
|
||||
if main_param.grad is not None:
|
||||
main_grads.append(main_param.grad.data)
|
||||
# Reset found inf.
|
||||
self.found_inf.fill_(0.0)
|
||||
# Unscale and set found inf/nan
|
||||
torch._amp_foreach_non_finite_check_and_unscale_(
|
||||
main_grads, self.found_inf, self.grad_scaler.inv_scale
|
||||
)
|
||||
# Update across all model parallel instances.
|
||||
torch.distributed.all_reduce(
|
||||
self.found_inf,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=mpu.get_model_parallel_group(),
|
||||
)
|
||||
|
||||
# Check for nan.
|
||||
found_inf_flag = self.found_inf.item() > 0
|
||||
return found_inf_flag
|
||||
|
||||
def _get_model_and_main_params_data_float16(self):
|
||||
model_data = []
|
||||
main_data = []
|
||||
for model_group, main_group in zip(
|
||||
self.float16_groups, self.fp32_from_float16_groups
|
||||
):
|
||||
for model_param, main_param in zip(model_group, main_group):
|
||||
model_data.append(model_param.data)
|
||||
main_data.append(main_param.data)
|
||||
return model_data, main_data
|
||||
|
||||
def _copy_main_params_to_model_params(self):
|
||||
# Only needed for the float16 params.
|
||||
model_data, main_data = self._get_model_and_main_params_data_float16()
|
||||
_multi_tensor_copy_this_to_that(
|
||||
this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf
|
||||
)
|
||||
|
||||
def _copy_model_params_to_main_params(self):
|
||||
# Only needed for the float16 params.
|
||||
model_data, main_data = self._get_model_and_main_params_data_float16()
|
||||
_multi_tensor_copy_this_to_that(
|
||||
this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf
|
||||
)
|
||||
|
||||
def reload_model_params(self):
|
||||
self._copy_model_params_to_main_params()
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
|
||||
timers = get_timers()
|
||||
|
||||
# Copy gradients from model params to main params.
|
||||
timers("optimizer-copy-to-main-grad").start()
|
||||
self._copy_model_grads_to_main_grads()
|
||||
timers("optimizer-copy-to-main-grad").stop()
|
||||
|
||||
# Do unscale, check for inf, and update grad scaler only for
|
||||
# the case that grad scaler is provided.
|
||||
if self.grad_scaler:
|
||||
|
||||
# Unscale and check for inf/nan.
|
||||
timers("optimizer-unscale-and-check-inf").start()
|
||||
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
|
||||
timers("optimizer-unscale-and-check-inf").stop()
|
||||
|
||||
# We are done with scaling gradients
|
||||
# so we can update the loss scale.
|
||||
self.grad_scaler.update(found_inf_flag)
|
||||
|
||||
# If we found inf/nan, skip the update.
|
||||
if found_inf_flag:
|
||||
return False, None, None
|
||||
|
||||
# Clip the main gradients.
|
||||
timers("optimizer-clip-main-grad").start()
|
||||
grad_norm = None
|
||||
if self.clip_grad > 0.0:
|
||||
grad_norm = self.clip_grad_norm(self.clip_grad)
|
||||
timers("optimizer-clip-main-grad").stop()
|
||||
|
||||
# count the zeros in the grads
|
||||
num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
|
||||
|
||||
# Step the optimizer.
|
||||
self.optimizer.step()
|
||||
|
||||
# Update params from main params.
|
||||
timers("optimizer-copy-main-to-model-params").start()
|
||||
self._copy_main_params_to_model_params()
|
||||
timers("optimizer-copy-main-to-model-params").stop()
|
||||
|
||||
# Successful update.
|
||||
return True, grad_norm, num_zeros_in_grad
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {}
|
||||
state_dict["optimizer"] = self.optimizer.state_dict()
|
||||
if self.grad_scaler:
|
||||
state_dict["grad_scaler"] = self.grad_scaler.state_dict()
|
||||
state_dict["fp32_from_fp16_params"] = self.fp32_from_float16_groups
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
# Optimizer.
|
||||
optimizer_key = "optimizer"
|
||||
if optimizer_key not in state_dict:
|
||||
optimizer_key = "optimizer_state_dict"
|
||||
print_rank_0(
|
||||
"***WARNING*** loading optimizer from " "an old checkpoint ..."
|
||||
)
|
||||
self.optimizer.load_state_dict(state_dict[optimizer_key])
|
||||
|
||||
# Grad scaler.
|
||||
if "grad_scaler" not in state_dict:
|
||||
print_rank_0(
|
||||
"***WARNING*** found an old checkpoint, will not "
|
||||
"load grad scaler ..."
|
||||
)
|
||||
else:
|
||||
if self.grad_scaler:
|
||||
self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
|
||||
else:
|
||||
print_rank_0(
|
||||
"***WARNING*** fould the grad scaler in the "
|
||||
"checkpoint but it is None in the class. "
|
||||
"Skipping loading grad scaler ..."
|
||||
)
|
||||
|
||||
# Copy data for the main params.
|
||||
fp32_from_float16_params_key = "fp32_from_fp16_params"
|
||||
if fp32_from_float16_params_key not in state_dict:
|
||||
fp32_from_float16_params_key = "fp32_from_fp16"
|
||||
for current_group, saved_group in zip(
|
||||
self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]
|
||||
):
|
||||
for current_param, saved_param in zip(current_group, saved_group):
|
||||
current_param.data.copy_(saved_param.data)
|
||||
|
||||
|
||||
class FP32Optimizer(MegatronOptimizer):
|
||||
def __init__(
|
||||
self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad
|
||||
):
|
||||
|
||||
super(FP32Optimizer, self).__init__(
|
||||
optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad
|
||||
)
|
||||
|
||||
self._scale = torch.cuda.FloatTensor([1.0])
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""Copied from torch.optim.optimizer"""
|
||||
for group in self.optimizer.param_groups:
|
||||
_zero_grad_group_helper(group["params"], set_to_none)
|
||||
|
||||
def get_loss_scale(self):
|
||||
"""FP32 optimizer does not do any scaling."""
|
||||
return self._scale
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
"""Clip gradients (if needed) and step the base optimizer.
|
||||
Always return successful since there is no overflow."""
|
||||
|
||||
# Copy main_grads to grads.
|
||||
if self.params_have_main_grad:
|
||||
for param_group in self.optimizer.param_groups:
|
||||
for param in param_group["params"]:
|
||||
param.grad = param.main_grad
|
||||
|
||||
# Clip gradients.
|
||||
grad_norm = None
|
||||
if self.clip_grad > 0.0:
|
||||
grad_norm = self.clip_grad_norm(self.clip_grad)
|
||||
|
||||
# count the zeros in the grads
|
||||
num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
|
||||
|
||||
# Update parameters.
|
||||
self.optimizer.step()
|
||||
|
||||
# No overflow for FP32 optimizer.
|
||||
return True, grad_norm, num_zeros_in_grad
|
||||
|
||||
def reload_model_params(self):
|
||||
pass
|
||||
|
||||
def state_dict(self):
|
||||
return self.optimizer.state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.optimizer.load_state_dict(state_dict)
|
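# Illustrative sketch (not part of the original file): the high-level flow that
# Float16OptimizerWithFloat16Params.step() implements above -- unscale the fp32 main
# grads, skip the update on overflow, otherwise step and copy the fp32 main params
# back into the fp16 model params. Single GPU, plain PyTorch, hypothetical names;
# the real class additionally handles clipping, timers and model parallelism.
import torch

def fp16_step(model_params, main_params, optimizer, grad_scaler):
    inv_scale = 1.0 / grad_scaler.scale
    found_inf = False
    # Copy fp16 grads onto the fp32 main params and unscale them.
    for model_p, main_p in zip(model_params, main_params):
        if model_p.grad is None:
            continue
        main_p.grad = model_p.grad.float() * inv_scale
        if not torch.isfinite(main_p.grad).all():
            found_inf = True
    grad_scaler.update(found_inf)
    if found_inf:
        return False  # skip the update; the scaler has already backed off
    optimizer.step()
    # Copy updated fp32 weights back into the fp16 model params.
    with torch.no_grad():
        for model_p, main_p in zip(model_params, main_params):
            model_p.copy_(main_p)
    return True

# Usage (hypothetical): scale the loss by grad_scaler.scale before loss.backward(),
# then call fp16_step(model_params, main_params, optimizer, grad_scaler).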
@ -0,0 +1,510 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from contextlib import contextmanager
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
|
||||
|
||||
from codegeex.megatron import get_args
|
||||
from codegeex.megatron import get_num_microbatches
|
||||
from codegeex.megatron import get_timers
|
||||
from codegeex.megatron import mpu
|
||||
from codegeex.megatron import p2p_communication
|
||||
from codegeex.megatron.utils import unwrap_model
|
||||
from codegeex.megatron.model import DistributedDataParallel as LocalDDP
|
||||
from codegeex.megatron.model import Float16Module
|
||||
|
||||
|
||||
def get_forward_backward_func():
|
||||
args = get_args()
|
||||
if mpu.get_pipeline_model_parallel_world_size() > 1:
|
||||
if args.virtual_pipeline_model_parallel_size is not None:
|
||||
forward_backward_func = forward_backward_pipelining_with_interleaving
|
||||
else:
|
||||
forward_backward_func = forward_backward_pipelining_without_interleaving
|
||||
else:
|
||||
forward_backward_func = forward_backward_no_pipelining
|
||||
return forward_backward_func
|
||||
|
||||
|
||||
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
|
||||
"""Forward step for passed-in model.
|
||||
|
||||
If first stage, input tensor is obtained from data_iterator, otherwise
|
||||
passed-in input_tensor is used.
|
||||
|
||||
Returns output tensor."""
|
||||
timers = get_timers()
|
||||
|
||||
args = get_args()
|
||||
|
||||
timers("forward-compute").start()
|
||||
unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
|
||||
if not args.deepspeed:
|
||||
unwrapped_model.set_input_tensor(input_tensor)
|
||||
else:
|
||||
unwrapped_model.module.set_input_tensor(input_tensor)
|
||||
|
||||
output_tensor, loss_func = forward_step_func(data_iterator, model)
|
||||
if mpu.is_pipeline_last_stage():
|
||||
output_tensor = loss_func(output_tensor)
|
||||
loss, loss_reduced = output_tensor
|
||||
output_tensor = loss / get_num_microbatches()
|
||||
losses_reduced.append(loss_reduced)
|
||||
timers("forward-compute").stop()
|
||||
|
||||
return output_tensor
|
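# Illustrative sketch (not part of the original file): the contract that forward_step()
# above expects from the user-supplied forward_step_func -- it returns the model output
# together with a loss_func mapping that output to (loss, reduced_losses). The batch
# layout and loss below are hypothetical; torch is already imported at the top of
# this file.
def example_forward_step_func(data_iterator, model):
    tokens, labels = next(data_iterator)
    logits = model(tokens)

    def loss_func(logits):
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        )
        return loss, {"lm loss": loss.detach()}

    return logits, loss_func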
||||
|
||||
|
||||
def backward_step(
|
||||
optimizer, input_tensor, output_tensor, output_tensor_grad, model=None
|
||||
):
|
||||
"""Backward step through passed-in output tensor.
|
||||
|
||||
If last stage, output_tensor_grad is None, otherwise gradient of loss
|
||||
with respect to stage's output tensor.
|
||||
|
||||
Returns gradient of loss with respect to input tensor (None if first
|
||||
stage)."""
|
||||
args = get_args()
|
||||
|
||||
if args.deepspeed:
|
||||
assert model is not None
|
||||
|
||||
timers = get_timers()
|
||||
timers("backward-compute").start()
|
||||
|
||||
# Retain the grad on the input_tensor.
|
||||
if input_tensor is not None:
|
||||
input_tensor.retain_grad()
|
||||
|
||||
if args.deepspeed:
|
||||
model.backward(output_tensor)
|
||||
else:
|
||||
# Backward pass.
|
||||
if output_tensor_grad is None:
|
||||
output_tensor = optimizer.scale_loss(output_tensor)
|
||||
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
|
||||
|
||||
# Collect the grad of the input_tensor.
|
||||
input_tensor_grad = None
|
||||
if input_tensor is not None:
|
||||
input_tensor_grad = input_tensor.grad
|
||||
|
||||
timers("backward-compute").stop()
|
||||
|
||||
return input_tensor_grad
|
||||
|
||||
|
||||
@contextmanager
|
||||
def dummy_handler():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def forward_backward_no_pipelining(
|
||||
forward_step_func, data_iterator, model, optimizer, timers, forward_only
|
||||
):
|
||||
"""Run forward and backward passes with no pipeline parallelism
|
||||
(no inter-stage communication).
|
||||
|
||||
Returns dictionary with losses."""
|
||||
assert len(model) == 1
|
||||
model = model[0]
|
||||
|
||||
args = get_args()
|
||||
|
||||
context_handler = dummy_handler
|
||||
if isinstance(model, torchDDP):
|
||||
context_handler = model.no_sync
|
||||
|
||||
if args.deepspeed:
|
||||
model.set_gradient_accumulation_boundary(False)
|
||||
|
||||
losses_reduced = []
|
||||
input_tensor, output_tensor_grad = None, None
|
||||
with context_handler():
|
||||
for i in range(get_num_microbatches() - 1):
|
||||
# print_rank_0("====> start of microstep {i}")
|
||||
# print_rank_0("====> forward")
|
||||
output_tensor = forward_step(
|
||||
forward_step_func, data_iterator, model, input_tensor, losses_reduced
|
||||
)
|
||||
# print_rank_0("====> backward")
|
||||
if not forward_only:
|
||||
backward_step(
|
||||
optimizer, input_tensor, output_tensor, output_tensor_grad, model
|
||||
)
|
||||
# print_rank_0("====> end of microstep {i}")
|
||||
|
||||
if args.deepspeed:
|
||||
model.set_gradient_accumulation_boundary(True)
|
||||
|
||||
# Run computation for last microbatch out of context handler (want to
|
||||
# synchronize gradients).
|
||||
# print_rank_0("====> start of the last microstep")
|
||||
# print_rank_0("====> forward")
|
||||
output_tensor = forward_step(
|
||||
forward_step_func, data_iterator, model, input_tensor, losses_reduced
|
||||
)
|
||||
# print_rank_0("====> backward")
|
||||
if not forward_only:
|
||||
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model)
|
||||
# print_rank_0("====> end of the last microstep")
|
||||
|
||||
return losses_reduced
|
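# Illustrative sketch (not part of the original file): the accumulate-then-sync pattern
# used by forward_backward_no_pipelining above, shown with a bare torch DDP model. The
# first N-1 microbatches run under no_sync() so gradients are only all-reduced on the
# final backward pass. Names are hypothetical.
def accumulate_microbatches(ddp_model, microbatches, loss_fn):
    num_microbatches = len(microbatches)
    with ddp_model.no_sync():
        for batch in microbatches[:-1]:
            (loss_fn(ddp_model(batch)) / num_microbatches).backward()
    # The last microbatch runs outside no_sync() and triggers the gradient all-reduce.
    (loss_fn(ddp_model(microbatches[-1])) / num_microbatches).backward()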
||||
|
||||
|
||||
def forward_backward_pipelining_with_interleaving(
|
||||
forward_step_func, data_iterator, model, optimizer, timers, forward_only
|
||||
):
|
||||
"""Run interleaved 1F1B schedule (model split into model chunks), with
|
||||
communication between pipeline stages as needed.
|
||||
|
||||
Returns dictionary with losses if the last stage, empty dict otherwise."""
|
||||
input_tensors = [[] for _ in range(len(model))]
|
||||
output_tensors = [[] for _ in range(len(model))]
|
||||
losses_reduced = []
|
||||
if not forward_only:
|
||||
output_tensor_grads = [[] for _ in range(len(model))]
|
||||
|
||||
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
|
||||
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
|
||||
|
||||
# Compute number of warmup and remaining microbatches.
|
||||
num_model_chunks = len(model)
|
||||
num_microbatches = get_num_microbatches() * num_model_chunks
|
||||
all_warmup_microbatches = False
|
||||
if forward_only:
|
||||
num_warmup_microbatches = num_microbatches
|
||||
else:
|
||||
# Run all forward passes and then all backward passes if number of
|
||||
# microbatches is just the number of pipeline stages.
|
||||
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
|
||||
# all workers, followed by more microbatches depending on the
|
||||
# stage ID (more forward passes for earlier stages, later stages can
|
||||
# immediately start with 1F1B).
|
||||
if get_num_microbatches() == pipeline_parallel_size:
|
||||
num_warmup_microbatches = num_microbatches
|
||||
all_warmup_microbatches = True
|
||||
else:
|
||||
num_warmup_microbatches = (
|
||||
pipeline_parallel_size - pipeline_parallel_rank - 1
|
||||
) * 2
|
||||
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
|
||||
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
|
||||
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
|
||||
|
||||
def get_model_chunk_id(microbatch_id, forward):
|
||||
"""Helper method to get the model chunk ID given the iteration number."""
|
||||
microbatch_id_in_group = microbatch_id % (
|
||||
pipeline_parallel_size * num_model_chunks
|
||||
)
|
||||
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
|
||||
if not forward:
|
||||
model_chunk_id = num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
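# Worked example (not part of the original file): with pipeline_parallel_size = 2 and
# num_model_chunks = 2 the group size is 4, so forward microbatches 0, 1, 2, 3 map to
# model chunks 0, 0, 1, 1, while the same ids in the backward direction map to
# chunks 1, 1, 0, 0.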
||||
|
||||
def forward_step_helper(microbatch_id):
|
||||
"""Helper method to run forward step with model split into chunks
|
||||
(run set_virtual_pipeline_model_parallel_rank() before calling
|
||||
forward_step())."""
|
||||
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
|
||||
|
||||
if mpu.is_pipeline_first_stage():
|
||||
if len(input_tensors[model_chunk_id]) == len(
|
||||
output_tensors[model_chunk_id]
|
||||
):
|
||||
input_tensors[model_chunk_id].append(None)
|
||||
input_tensor = input_tensors[model_chunk_id][-1]
|
||||
output_tensor = forward_step(
|
||||
forward_step_func,
|
||||
data_iterator[model_chunk_id],
|
||||
model[model_chunk_id],
|
||||
input_tensor,
|
||||
losses_reduced,
|
||||
)
|
||||
output_tensors[model_chunk_id].append(output_tensor)
|
||||
|
||||
return output_tensor
|
||||
|
||||
def backward_step_helper(microbatch_id):
|
||||
"""Helper method to run backward step with model split into chunks
|
||||
(run set_virtual_pipeline_model_parallel_rank() before calling
|
||||
backward_step())."""
|
||||
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
|
||||
|
||||
if mpu.is_pipeline_last_stage():
|
||||
if len(output_tensor_grads[model_chunk_id]) == 0:
|
||||
output_tensor_grads[model_chunk_id].append(None)
|
||||
input_tensor = input_tensors[model_chunk_id].pop(0)
|
||||
output_tensor = output_tensors[model_chunk_id].pop(0)
|
||||
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
|
||||
input_tensor_grad = backward_step(
|
||||
optimizer, input_tensor, output_tensor, output_tensor_grad
|
||||
)
|
||||
|
||||
return input_tensor_grad
|
||||
|
||||
# Run warmup forward passes.
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(0)
|
||||
input_tensors[0].append(p2p_communication.recv_forward(timers))
|
||||
for k in range(num_warmup_microbatches):
|
||||
output_tensor = forward_step_helper(k)
|
||||
|
||||
# Determine if tensor should be received from previous stage.
|
||||
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
|
||||
recv_prev = True
|
||||
if mpu.is_pipeline_first_stage(ignore_virtual=True):
|
||||
if next_forward_model_chunk_id == 0:
|
||||
recv_prev = False
|
||||
if k == (num_microbatches - 1):
|
||||
recv_prev = False
|
||||
|
||||
# Don't send tensor downstream if on last stage.
|
||||
if mpu.is_pipeline_last_stage():
|
||||
output_tensor = None
|
||||
|
||||
# Send and receive tensors as appropriate (send tensors computed
|
||||
# in this iteration; receive tensors for next iteration).
|
||||
if (
|
||||
k == (num_warmup_microbatches - 1)
|
||||
and not forward_only
|
||||
and not all_warmup_microbatches
|
||||
):
|
||||
input_tensor_grad = None
|
||||
recv_next = True
|
||||
if mpu.is_pipeline_last_stage(ignore_virtual=True):
|
||||
recv_next = False
|
||||
(
|
||||
input_tensor,
|
||||
output_tensor_grad,
|
||||
) = p2p_communication.send_forward_backward_recv_forward_backward(
|
||||
output_tensor,
|
||||
input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
timers=timers,
|
||||
)
|
||||
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
|
||||
else:
|
||||
input_tensor = p2p_communication.send_forward_recv_forward(
|
||||
output_tensor, recv_prev, timers
|
||||
)
|
||||
input_tensors[next_forward_model_chunk_id].append(input_tensor)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for k in range(num_microbatches_remaining):
|
||||
# Forward pass.
|
||||
forward_k = k + num_warmup_microbatches
|
||||
output_tensor = forward_step_helper(forward_k)
|
||||
|
||||
# Backward pass.
|
||||
backward_k = k
|
||||
input_tensor_grad = backward_step_helper(backward_k)
|
||||
|
||||
# Send output_tensor and input_tensor_grad, receive input_tensor
|
||||
# and output_tensor_grad.
|
||||
|
||||
# Determine if current stage has anything to send in either direction,
|
||||
# otherwise set tensor to None.
|
||||
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
|
||||
if mpu.is_pipeline_last_stage():
|
||||
output_tensor = None
|
||||
|
||||
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
|
||||
if mpu.is_pipeline_first_stage():
|
||||
input_tensor_grad = None
|
||||
|
||||
# Determine if peers are sending, and where in data structure to put
|
||||
# received tensors.
|
||||
recv_prev = True
|
||||
if mpu.is_pipeline_first_stage(ignore_virtual=True):
|
||||
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
|
||||
next_forward_model_chunk_id = get_model_chunk_id(
|
||||
forward_k - (pipeline_parallel_size - 1), forward=True
|
||||
)
|
||||
if next_forward_model_chunk_id == (num_model_chunks - 1):
|
||||
recv_prev = False
|
||||
next_forward_model_chunk_id += 1
|
||||
else:
|
||||
next_forward_model_chunk_id = get_model_chunk_id(
|
||||
forward_k + 1, forward=True
|
||||
)
|
||||
|
||||
recv_next = True
|
||||
if mpu.is_pipeline_last_stage(ignore_virtual=True):
|
||||
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
|
||||
next_backward_model_chunk_id = get_model_chunk_id(
|
||||
backward_k - (pipeline_parallel_size - 1), forward=False
|
||||
)
|
||||
if next_backward_model_chunk_id == 0:
|
||||
recv_next = False
|
||||
next_backward_model_chunk_id -= 1
|
||||
else:
|
||||
next_backward_model_chunk_id = get_model_chunk_id(
|
||||
backward_k + 1, forward=False
|
||||
)
|
||||
|
||||
# If last iteration, don't receive; we already received one extra
|
||||
# before the start of the for loop.
|
||||
if k == (num_microbatches_remaining - 1):
|
||||
recv_prev = False
|
||||
|
||||
# Communicate tensors.
|
||||
(
|
||||
input_tensor,
|
||||
output_tensor_grad,
|
||||
) = p2p_communication.send_forward_backward_recv_forward_backward(
|
||||
output_tensor,
|
||||
input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
timers=timers,
|
||||
)
|
||||
|
||||
# Put input_tensor and output_tensor_grad in data structures in the
|
||||
# right location.
|
||||
if recv_prev:
|
||||
input_tensors[next_forward_model_chunk_id].append(input_tensor)
|
||||
if recv_next:
|
||||
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
|
||||
|
||||
# Run cooldown backward passes (flush out pipeline).
|
||||
if not forward_only:
|
||||
if all_warmup_microbatches:
|
||||
output_tensor_grads[num_model_chunks - 1].append(
|
||||
p2p_communication.recv_backward(timers)
|
||||
)
|
||||
for k in range(num_microbatches_remaining, num_microbatches):
|
||||
input_tensor_grad = backward_step_helper(k)
|
||||
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
|
||||
recv_next = True
|
||||
if mpu.is_pipeline_last_stage(ignore_virtual=True):
|
||||
if next_backward_model_chunk_id == (num_model_chunks - 1):
|
||||
recv_next = False
|
||||
if k == (num_microbatches - 1):
|
||||
recv_next = False
|
||||
output_tensor_grads[next_backward_model_chunk_id].append(
|
||||
p2p_communication.send_backward_recv_backward(
|
||||
input_tensor_grad, recv_next, timers
|
||||
)
|
||||
)
|
||||
|
||||
return losses_reduced
|
||||
|
||||
|
||||
def forward_backward_pipelining_without_interleaving(
|
||||
forward_step_func, data_iterator, model, optimizer, timers, forward_only
|
||||
):
|
||||
"""Run non-interleaved 1F1B schedule, with communication between pipeline
|
||||
stages.
|
||||
|
||||
Returns a list of loss dicts on the last pipeline stage, an empty list otherwise."""
|
||||
timers = get_timers()
|
||||
|
||||
assert len(model) == 1
|
||||
model = model[0]
|
||||
|
||||
# Compute number of warmup microbatches.
|
||||
num_microbatches = get_num_microbatches()
|
||||
num_warmup_microbatches = (
|
||||
mpu.get_pipeline_model_parallel_world_size()
|
||||
- mpu.get_pipeline_model_parallel_rank()
|
||||
- 1
|
||||
)
|
||||
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
|
||||
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
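# Illustrative example (not from the original source): with a pipeline world
# size of 4, rank 1 runs 4 - 1 - 1 = 2 warmup microbatches; with 8
# microbatches in total, the remaining 6 are processed in the 1F1B steady
# state below.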
|
||||
|
||||
input_tensors = []
|
||||
output_tensors = []
|
||||
losses_reduced = []
|
||||
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_tensor = p2p_communication.recv_forward(timers)
|
||||
output_tensor = forward_step(
|
||||
forward_step_func, data_iterator, model, input_tensor, losses_reduced
|
||||
)
|
||||
p2p_communication.send_forward(output_tensor, timers)
|
||||
|
||||
input_tensors.append(input_tensor)
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
# Before running 1F1B, need to receive first forward tensor.
|
||||
# If all microbatches are run in warmup / cooldown phase, then no need to
|
||||
# receive this tensor here.
|
||||
if num_microbatches_remaining > 0:
|
||||
input_tensor = p2p_communication.recv_forward(timers)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatches_remaining):
|
||||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_tensor = forward_step(
|
||||
forward_step_func, data_iterator, model, input_tensor, losses_reduced
|
||||
)
|
||||
if forward_only:
|
||||
p2p_communication.send_forward(output_tensor, timers)
|
||||
else:
|
||||
output_tensor_grad = p2p_communication.send_forward_recv_backward(
|
||||
output_tensor, timers
|
||||
)
|
||||
|
||||
# Add input_tensor and output_tensor to end of list, then pop from the
|
||||
# start of the list for backward pass.
|
||||
input_tensors.append(input_tensor)
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
if forward_only:
|
||||
if not last_iteration:
|
||||
input_tensor = p2p_communication.recv_forward(timers)
|
||||
else:
|
||||
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
|
||||
|
||||
input_tensor_grad = backward_step(
|
||||
optimizer, input_tensor, output_tensor, output_tensor_grad, model
|
||||
)
|
||||
|
||||
if last_iteration:
|
||||
input_tensor = None
|
||||
p2p_communication.send_backward(input_tensor_grad, timers)
|
||||
else:
|
||||
input_tensor = p2p_communication.send_backward_recv_forward(
|
||||
input_tensor_grad, timers
|
||||
)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
if not forward_only:
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_tensor = input_tensors.pop(0)
|
||||
output_tensor = output_tensors.pop(0)
|
||||
|
||||
output_tensor_grad = p2p_communication.recv_backward(timers)
|
||||
|
||||
input_tensor_grad = backward_step(
|
||||
optimizer, input_tensor, output_tensor, output_tensor_grad, model
|
||||
)
|
||||
|
||||
p2p_communication.send_backward(input_tensor_grad, timers)
|
||||
|
||||
return losses_reduced
|
@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
|
||||
ENV_NAMES = ["CUDA_HOME", "LD_LIBRARY_PATH", "PATH", "TORCH_EXTENSIONS_DIR", "CUDA_LAUNCH_BLOCKING"]
|
||||
|
||||
|
||||
def main():
|
||||
s = ""
|
||||
for name in ENV_NAMES:
|
||||
if name in os.environ:
|
||||
value = os.environ[name]
|
||||
s += "{}={}\n".format(name, value)
|
||||
print(f"{name}={value}")
|
||||
else:
|
||||
print(f"{name} is not set")
|
||||
|
||||
# write env vars to .deepspeed_env
|
||||
with open(".deepspeed_env", "w") as f:
|
||||
f.write(s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,208 @@
|
||||
import os
|
||||
import subprocess
|
||||
import torch
|
||||
import logging
|
||||
|
||||
logging.getLogger("torch").setLevel(logging.WARNING)
|
||||
|
||||
import deepspeed
|
||||
from deepspeed.runtime.utils import see_memory_usage
|
||||
from functools import partial
|
||||
|
||||
from codegeex.megatron import get_args, print_rank_0, get_timers, get_tokenizer, mpu
|
||||
from codegeex.megatron.data.prompt_dataset import build_train_valid_test_datasets
|
||||
from codegeex.megatron.model import CodeGeeXModel #, CodeGeeXModelPipe
|
||||
from codegeex.megatron.training import pretrain
|
||||
from codegeex.megatron.utils import get_ltor_masks_and_position_ids
|
||||
from codegeex.megatron.utils import average_losses_across_data_parallel_group
|
||||
|
||||
|
||||
def model_provider(pre_process=True, post_process=True):
|
||||
"""Build the model."""
|
||||
|
||||
print_rank_0("building GPT model ...")
|
||||
see_memory_usage(f"Before Building Model", force=True)
|
||||
|
||||
args = get_args()
|
||||
with deepspeed.zero.Init(
|
||||
data_parallel_group=mpu.get_data_parallel_group(),
|
||||
remote_device=None if args.remote_device == "none" else args.remote_device,
|
||||
config_dict_or_path=args.deepspeed_config,
|
||||
enabled=args.zero_stage == 3,
|
||||
mpu=mpu,
|
||||
):
|
||||
if args.deepspeed and not args.no_pipeline_parallel:
|
||||
model = CodeGeeXModelPipe(num_tokentypes=0, parallel_output=True)
|
||||
# This is a hack to give us a reference to get_batch_pipe from within training.py
|
||||
# We need to call model.set_batch_fn after deepspeed.initialize
|
||||
model._megatron_batch_fn = get_batch_pipe
|
||||
|
||||
# Precompute the attention mask and store it in args. This avoids having to
|
||||
# pipeline it as an activation during training. The mask is constant, and thus
|
||||
# we can reuse it.
|
||||
attention_mask = torch.tril(
|
||||
torch.ones(
|
||||
(1, args.seq_length, args.seq_length),
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
).view(1, 1, args.seq_length, args.seq_length)
|
||||
|
||||
# Convert attention mask to binary:
|
||||
attention_mask = attention_mask < 0.5
|
||||
if args.fp16:
|
||||
attention_mask = attention_mask.half()
|
||||
elif args.bf16:
|
||||
attention_mask = attention_mask.bfloat16()
|
||||
|
||||
# Attention mask must be bool.
|
||||
args.attn_mask = attention_mask.to(torch.bool)
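# Note (illustrative, not from the original source): after the comparison and
# dtype round-trip above, the mask is True on the strictly upper-triangular
# positions, i.e. the future tokens each query position must not attend to.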
|
||||
|
||||
else:
|
||||
model = CodeGeeXModel(
|
||||
num_tokentypes=0,
|
||||
parallel_output=True,
|
||||
)
|
||||
|
||||
if args.load_state is not None:
|
||||
timers = get_timers()
|
||||
print_rank_0("Loading warmstarting model states ...")
|
||||
timers("load-model-states").start()
|
||||
mp_rank = mpu.get_tensor_model_parallel_rank()
|
||||
if os.path.isdir(args.load_state):
|
||||
model_path = os.path.join(
|
||||
args.load_state, f"model_mp_rank_{mp_rank}.pt"
|
||||
)
|
||||
else:
|
||||
model_path = args.load_state
|
||||
print_rank_0(f"Loading model from {model_path} ...")
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"] # strip other client states
|
||||
model.load_state_dict(state_dict)
|
||||
timers("load-model-states").stop()
|
||||
timers.log(["load-model-states"])
|
||||
see_memory_usage(f"After Building Model", force=True)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_batch(data_iterator):
|
||||
"""Generate a batch"""
|
||||
args = get_args()
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# Items and their type.
|
||||
keys = ["input_ids"]
|
||||
datatype = torch.int64
|
||||
|
||||
# Broadcast data.
|
||||
if data_iterator is not None:
|
||||
data = next(data_iterator)
|
||||
else:
|
||||
data = None
|
||||
|
||||
data_b = mpu.broadcast_data(keys, data, datatype)
|
||||
|
||||
# Unpack.
|
||||
tokens_ = data_b["input_ids"].long()
|
||||
labels = tokens_[:, 1:].contiguous()
|
||||
tokens = tokens_[:, :-1].contiguous()
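# Illustrative example: for tokens_ = [t0, t1, t2, t3] this shift gives
# tokens = [t0, t1, t2] and labels = [t1, t2, t3], so each position is
# trained to predict the next token.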
|
||||
|
||||
# Get the masks and position ids.
|
||||
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
|
||||
tokens,
|
||||
tokenizer.eod,
|
||||
args.reset_position_ids,
|
||||
args.reset_attention_mask,
|
||||
args.eod_mask_loss,
|
||||
)
|
||||
|
||||
return tokens, labels, loss_mask, attention_mask, position_ids
|
||||
|
||||
|
||||
def get_batch_pipe(data):
|
||||
"""Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
|
||||
args = get_args()
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# Items and their type.
|
||||
keys = ["input_ids"]
|
||||
datatype = torch.int64
|
||||
|
||||
# Broadcast data.
|
||||
data_b = mpu.broadcast_data(keys, data, datatype)
|
||||
|
||||
# Unpack.
|
||||
tokens_ = data_b["input_ids"].long()
|
||||
labels = tokens_[:, 1:].contiguous()
|
||||
tokens = tokens_[:, :-1].contiguous()
|
||||
|
||||
# Get the masks and position ids.
|
||||
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
|
||||
tokens,
|
||||
tokenizer.eod,
|
||||
args.reset_position_ids,
|
||||
args.reset_attention_mask,
|
||||
args.eod_mask_loss,
|
||||
)
|
||||
|
||||
return (tokens, position_ids, attention_mask), (labels, loss_mask)
|
||||
|
||||
|
||||
def loss_func(loss_mask, output_tensor):
|
||||
losses = output_tensor.float()
|
||||
loss_mask = loss_mask.view(-1).float()
|
||||
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
|
||||
|
||||
# Reduce loss for logging.
|
||||
averaged_loss = average_losses_across_data_parallel_group([loss])
|
||||
|
||||
return loss, {"lm loss": averaged_loss[0]}
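# Worked example (illustrative): with per-token losses [1.0, 2.0, 3.0, 4.0]
# and loss_mask [1, 1, 0, 1], the masked mean is (1 + 2 + 4) / 3 ≈ 2.33; the
# masked-out position (e.g. padding or an EOD token) does not contribute.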
|
||||
|
||||
|
||||
def forward_step(data_iterator, model):
|
||||
"""Forward step."""
|
||||
args = get_args()
|
||||
timers = get_timers()
|
||||
|
||||
# Get the batch.
|
||||
timers("batch-generator").start()
|
||||
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
|
||||
timers("batch-generator").stop()
|
||||
|
||||
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
|
||||
|
||||
return output_tensor, partial(loss_func, loss_mask)
|
||||
|
||||
|
||||
def train_valid_test_datasets_provider(train_val_test_num_samples):
|
||||
"""Build train, valid, and test datasets."""
|
||||
args = get_args()
|
||||
|
||||
print_rank_0("> building train, validation, and test datasets " "for GPT ...")
|
||||
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
|
||||
data_prefix=args.data_path,
|
||||
data_impl=args.data_impl,
|
||||
splits_string=args.split,
|
||||
train_valid_test_num_samples=train_val_test_num_samples,
|
||||
seq_length=args.seq_length,
|
||||
seed=args.seed,
|
||||
skip_warmup=(not args.mmap_warmup),
|
||||
)
|
||||
print_rank_0("> finished creating GPT datasets ...")
|
||||
|
||||
return train_ds, valid_ds, test_ds
|
||||
|
||||
|
||||
def command_exists(cmd):
|
||||
result = subprocess.Popen(f"type {cmd}", stdout=subprocess.PIPE, shell=True)
|
||||
return result.wait() == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pretrain(
|
||||
train_valid_test_datasets_provider,
|
||||
model_provider,
|
||||
forward_step,
|
||||
args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
|
||||
)
|
File diff suppressed because it is too large
@ -0,0 +1 @@
|
||||
from .quantize import quantize
|
@ -0,0 +1,329 @@
|
||||
import torch
|
||||
|
||||
from torch.nn.parameter import Parameter
|
||||
from codegeex.kernels import extract_weight_to_half
|
||||
from codegeex.megatron.mpu.layers import RowParallelLinear, ColumnParallelLinear
|
||||
from codegeex.megatron.mpu.mappings import copy_to_tensor_model_parallel_region, gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region
|
||||
|
||||
|
||||
class W8A16Linear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
||||
ctx.inp_shape = inp.size()
|
||||
ctx.weight_shape = quant_w.size()
|
||||
ctx.weight_bit_width = weight_bit_width
|
||||
out_features = quant_w.size(0)
|
||||
inp = inp.contiguous().view(-1, inp.size(-1))
|
||||
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
|
||||
output = inp.mm(weight.t())
|
||||
ctx.save_for_backward(inp, quant_w, scale_w)
|
||||
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output: torch.Tensor):
|
||||
inp, quant_w, scale_w = ctx.saved_tensors
|
||||
weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
|
||||
grad_output = grad_output.contiguous().view(-1, weight.size(0))
|
||||
grad_input = grad_output.mm(weight)
|
||||
grad_weight = grad_output.t().mm(inp)
|
||||
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
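# Background note (illustrative, not from the original source): the int8
# weights used above follow symmetric per-row quantization, i.e. roughly
#   scale  = max(|w_row|) / (2 ** (bit_width - 1) - 1)   # 127 for 8-bit
#   w_int8 = round(w_row / scale)
# and extract_weight_to_half presumably reconstructs w_row ≈ w_int8 * scale
# in fp16 before the matmul.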
|
||||
|
||||
|
||||
class QuantizedLinear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
weight_bit_width: int,
|
||||
weight: torch.Tensor = None,
|
||||
bias: torch.Tensor = None,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super(QuantizedLinear, self).__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight_bit_width = weight_bit_width
|
||||
|
||||
if weight is None:
|
||||
self.weight = torch.empty(
|
||||
self.out_features, self.in_features * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
|
||||
)
|
||||
self.weight_scale = torch.empty(self.out_features, dtype=kwargs["params_dtype"], device=kwargs["device"])
|
||||
else:
|
||||
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
|
||||
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
||||
if weight_bit_width == 4:
|
||||
self.weight = compress_int4_weight(self.weight)
|
||||
|
||||
if bias is None:
|
||||
self.register_parameter('bias', None)
|
||||
else:
|
||||
self.bias = bias
|
||||
|
||||
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
||||
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
|
||||
|
||||
def forward(self, input_):
|
||||
# Matrix multiply.
|
||||
output = W8A16Linear.apply(input_, self.weight, self.weight_scale, self.weight_bit_width)
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class QuantizedColumnParallelLinear(ColumnParallelLinear):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
weight_bit_width: int,
|
||||
weight: torch.Tensor = None,
|
||||
bias: torch.Tensor = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super(QuantizedColumnParallelLinear, self).__init__(input_size, output_size, *args, **kwargs)
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.weight_bit_width = weight_bit_width
|
||||
if "skip_bias_add" in kwargs:
|
||||
self.skip_bias_add = kwargs["skip_bias_add"]
|
||||
else:
|
||||
self.skip_bias_add = False
|
||||
del self.weight
|
||||
|
||||
if weight is None:
|
||||
self.weight = torch.empty(
|
||||
self.output_size, self.input_size * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
|
||||
)
|
||||
self.weight_scale = torch.empty(self.output_size, dtype=kwargs["params_dtype"], device=kwargs["device"])
|
||||
else:
|
||||
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
|
||||
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
||||
if weight_bit_width == 4:
|
||||
self.weight = compress_int4_weight(self.weight)
|
||||
|
||||
if bias is None:
|
||||
self.register_parameter('bias', None)
|
||||
else:
|
||||
del self.bias
|
||||
self.bias = bias
|
||||
|
||||
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
||||
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
|
||||
|
||||
def forward(self, input_):
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = copy_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
|
||||
if self.bias is not None and not self.skip_bias_add:
|
||||
output_parallel = output_parallel + self.bias
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class QuantizedRowParallelLinear(RowParallelLinear):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
weight_bit_width: int,
|
||||
weight: torch.Tensor = None,
|
||||
bias: torch.Tensor = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super(QuantizedRowParallelLinear, self).__init__(input_size, output_size, *args, **kwargs)
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.weight_bit_width = weight_bit_width
|
||||
if "skip_bias_add" in kwargs:
|
||||
self.skip_bias_add = kwargs["skip_bias_add"]
|
||||
else:
|
||||
self.skip_bias_add = False
|
||||
del self.weight
|
||||
|
||||
if weight is None:
|
||||
self.weight = torch.empty(
|
||||
self.output_size, self.input_size * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
|
||||
)
|
||||
self.weight_scale = torch.empty(self.output_size, dtype=kwargs["params_dtype"], device=kwargs["device"])
|
||||
else:
|
||||
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
|
||||
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
||||
if weight_bit_width == 4:
|
||||
self.weight = compress_int4_weight(self.weight)
|
||||
|
||||
if bias is None:
|
||||
self.register_parameter('bias', None)
|
||||
else:
|
||||
del self.bias
|
||||
self.bias = bias
|
||||
|
||||
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
||||
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
|
||||
|
||||
def forward(self, input_):
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
|
||||
# All-reduce across all the partitions.
|
||||
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
if self.bias is not None and not self.skip_bias_add:
|
||||
output = output_ + self.bias
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def quantize(model, weight_bit_width, backend="torch"):
|
||||
"""Replace fp16 linear with quantized linear"""
|
||||
|
||||
for i in range(len(model.language_model.transformer.layers) + 1):
|
||||
if i == len(model.language_model.transformer.layers):
|
||||
layer = model.language_model.transformer.topQueryLayer
|
||||
else:
|
||||
layer = model.language_model.transformer.layers[i]
|
||||
|
||||
if backend == "torch":
|
||||
layer.attention.query = QuantizedLinear(
|
||||
in_features=layer.attention.query.in_features,
|
||||
out_features=layer.attention.query.out_features,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.query.weight.device,
|
||||
)
|
||||
layer.attention.value = QuantizedLinear(
|
||||
in_features=layer.attention.value.in_features,
|
||||
out_features=layer.attention.value.out_features,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.value.weight.device,
|
||||
)
|
||||
layer.attention.key = QuantizedLinear(
|
||||
in_features=layer.attention.key.in_features,
|
||||
out_features=layer.attention.key.out_features,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.key.weight.device,
|
||||
)
|
||||
layer.attention.dense = QuantizedLinear(
|
||||
in_features=layer.attention.dense.in_features,
|
||||
out_features=layer.attention.dense.out_features,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.dense.weight.device,
|
||||
)
|
||||
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
||||
in_features=layer.mlp.dense_h_to_4h.in_features,
|
||||
out_features=layer.mlp.dense_h_to_4h.out_features,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.mlp.dense_h_to_4h.weight.device,
|
||||
)
|
||||
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
||||
in_features=layer.mlp.dense_4h_to_h.in_features,
|
||||
out_features=layer.mlp.dense_4h_to_h.out_features,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
|
||||
params_dtype=torch.half,
|
||||
device=layer.mlp.dense_4h_to_h.weight.device,
|
||||
)
|
||||
elif backend == "megatron":
|
||||
layer.attention.query = QuantizedColumnParallelLinear(
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
|
||||
input_size=layer.attention.query.input_size,
|
||||
output_size=layer.attention.query.output_size,
|
||||
gather_output=False,
|
||||
skip_init=True,
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.query.weight.device,
|
||||
)
|
||||
layer.attention.value = QuantizedColumnParallelLinear(
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
|
||||
input_size=layer.attention.value.input_size,
|
||||
output_size=layer.attention.value.output_size,
|
||||
gather_output=False,
|
||||
skip_init=True,
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.value.weight.device,
|
||||
)
|
||||
layer.attention.key = QuantizedColumnParallelLinear(
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
|
||||
input_size=layer.attention.key.input_size,
|
||||
output_size=layer.attention.key.output_size,
|
||||
gather_output=False,
|
||||
skip_init=True,
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.key.weight.device,
|
||||
)
|
||||
layer.attention.dense = QuantizedRowParallelLinear(
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
|
||||
input_size=layer.attention.dense.input_size,
|
||||
output_size=layer.attention.dense.output_size,
|
||||
input_is_parallel=False,
|
||||
skip_init=True,
|
||||
skip_bias_add=True,
|
||||
params_dtype=torch.half,
|
||||
device=layer.attention.dense.weight.device,
|
||||
)
|
||||
layer.mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
|
||||
input_size=layer.mlp.dense_h_to_4h.input_size,
|
||||
output_size=layer.mlp.dense_h_to_4h.output_size,
|
||||
gather_output=False,
|
||||
skip_init=True,
|
||||
params_dtype=torch.half,
|
||||
device=layer.mlp.dense_h_to_4h.weight.device,
|
||||
)
|
||||
layer.mlp.dense_4h_to_h = QuantizedRowParallelLinear(
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
|
||||
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
|
||||
input_size=layer.mlp.dense_4h_to_h.input_size,
|
||||
output_size=layer.mlp.dense_4h_to_h.output_size,
|
||||
input_is_parallel=False,
|
||||
skip_init=True,
|
||||
params_dtype=torch.half,
|
||||
device=layer.mlp.dense_4h_to_h.weight.device,
|
||||
)
|
||||
|
||||
return model
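# Minimal usage sketch (illustrative; assumes a CodeGeeXModel already loaded
# in fp16 on GPU, with `state_dict` as a placeholder for the loaded weights):
#
#   model = CodeGeeXModel(num_tokentypes=0, parallel_output=False)
#   model.load_state_dict(state_dict)
#   model.eval().half().cuda()
#   model = quantize(model, weight_bit_width=8, backend="torch")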
|
@ -0,0 +1 @@
|
||||
from .tokenizer import CodeGeeXTokenizer
|
@ -0,0 +1,86 @@
|
||||
from typing import *
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.gpt2 import GPT2TokenizerFast
|
||||
|
||||
|
||||
def encode_whitespaces(text, start_extra_id: int, max_len: int):
|
||||
""" Encode whitespaces to extra tokens in GPT-J.
|
||||
|
||||
>>> encode_whitespaces('a\\n b\\n c', 10, 10)
|
||||
'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
|
||||
"""
|
||||
|
||||
def push_acc_space(acc_len: int, text: str):
|
||||
if acc_len == 0:
|
||||
return text
|
||||
if acc_len == 1:
|
||||
return text + ' '
|
||||
assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
|
||||
extra_id = start_extra_id - 2 + acc_len
|
||||
extra_token = f'<|extratoken_{extra_id}|>'
|
||||
return text + extra_token
|
||||
|
||||
acc_len = 0
|
||||
res = ''
|
||||
for ch in text:
|
||||
if ch == ' ':
|
||||
acc_len += 1
|
||||
if acc_len == max_len:
|
||||
res = push_acc_space(acc_len, res)
|
||||
acc_len = 0
|
||||
else:
|
||||
res = push_acc_space(acc_len, res)
|
||||
acc_len = 0
|
||||
res = res + ch
|
||||
|
||||
res = push_acc_space(acc_len, res)
|
||||
|
||||
return res
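# Worked example (illustrative): encode_whitespaces('a    b', 10, 10)
# collapses the run of 4 spaces into extra token 10 - 2 + 4 = 12, giving
# 'a<|extratoken_12|>b'; decode_whitespaces below reverses the mapping.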
|
||||
|
||||
|
||||
def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
|
||||
""" Decode the whitespace-encoded strings produced by encode_whitespace.
|
||||
|
||||
>>> text = 'a\\n b\\n c'
|
||||
>>> s, l = 10, 10
|
||||
>>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
|
||||
True
|
||||
"""
|
||||
for l in range(2, max_len + 1):
|
||||
token_id = start_extra_id - 2 + l
|
||||
token = f'<|extratoken_{token_id}|>'
|
||||
text = text.replace(token, ' ' * l)
|
||||
return text
|
||||
|
||||
|
||||
class CodeGeeXTokenizer(object):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: GPT2TokenizerFast = None,
|
||||
tokenizer_path: str = "EleutherAI/gpt-j-6B",
|
||||
start_extra_id: int = 10,
|
||||
max_len : int = 10,
|
||||
mode='codegeex-13b',
|
||||
dict_file: str = None,
|
||||
):
|
||||
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
if mode not in ['codegeex-13b']:
|
||||
raise ValueError(f"Invalid mode {mode}, choose from ['codegeex-13b']")
|
||||
self.start_extra_id = start_extra_id
|
||||
self.max_len = max_len
|
||||
self.mode = mode
|
||||
self.eos_token_id = self.tokenizer.eos_token_id
|
||||
|
||||
def encode_code(self, code: str):
|
||||
if self.mode == 'codegeex-13b':
|
||||
code = encode_whitespaces(code, self.start_extra_id, self.max_len)
|
||||
input_ids = self.tokenizer(code, is_split_into_words=False, verbose=False).input_ids
|
||||
|
||||
return input_ids
|
||||
|
||||
def decode_code(self, input_ids):
|
||||
if self.mode == 'codegeex-13b':
|
||||
text = self.tokenizer.decode(input_ids, skip_special_tokens=False, verbose=False)
|
||||
output_code = decode_whitespaces(text, self.start_extra_id, self.max_len)
|
||||
|
||||
return output_code
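# Minimal usage sketch (illustrative; instantiating the tokenizer downloads
# the GPT-J vocabulary from Hugging Face, which is an environment assumption):
#
#   tok = CodeGeeXTokenizer(tokenizer_path="EleutherAI/gpt-j-6B")
#   ids = tok.encode_code("def add(a, b):\n    return a + b")
#   code = tok.decode_code(ids)  # round-trips the whitespace encoding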
|
@ -0,0 +1 @@
|
||||
from .codegeex_model import CodeGeeXModel
|
File diff suppressed because it is too large
@ -0,0 +1,326 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import *
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def get_ltor_masks_and_position_ids(
|
||||
data,
|
||||
eod_token,
|
||||
reset_position_ids,
|
||||
reset_attention_mask,
|
||||
):
|
||||
"""Build masks and position id for left to right model."""
|
||||
|
||||
# Extract batch size and sequence length.
|
||||
micro_batch_size, seq_length = data.size()
|
||||
|
||||
# Attention mask (lower triangular).
|
||||
if reset_attention_mask:
|
||||
att_mask_batch = micro_batch_size
|
||||
else:
|
||||
att_mask_batch = 1
|
||||
attention_mask = torch.tril(
|
||||
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
|
||||
).view(att_mask_batch, 1, seq_length, seq_length)
|
||||
|
||||
# Position ids.
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(data)
|
||||
# We need to clone as the ids will be modified based on batch index.
|
||||
if reset_position_ids:
|
||||
position_ids = position_ids.clone()
|
||||
|
||||
if reset_position_ids or reset_attention_mask:
|
||||
# Loop through the batches:
|
||||
for b in range(micro_batch_size):
|
||||
|
||||
# Find indices where EOD token is.
|
||||
eod_index = position_ids[b, data[b] == eod_token]
|
||||
# Detach indices from positions if going to modify positions.
|
||||
if reset_position_ids:
|
||||
eod_index = eod_index.clone()
|
||||
|
||||
# Loop through EOD indices:
|
||||
prev_index = 0
|
||||
for j in range(eod_index.size()[0]):
|
||||
i = eod_index[j]
|
||||
# Mask attention loss.
|
||||
if reset_attention_mask:
|
||||
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
|
||||
# Reset positions.
|
||||
if reset_position_ids:
|
||||
position_ids[b, (i + 1) :] -= i + 1 - prev_index
|
||||
prev_index = i + 1
|
||||
|
||||
# Convert attention mask to binary:
|
||||
attention_mask = attention_mask < 0.5
|
||||
|
||||
return attention_mask, position_ids
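# Worked example (illustrative): with eod_token = 0 and a row of tokens
# [5, 7, 0, 9, 3], reset_position_ids yields positions [0, 1, 2, 0, 1] and
# reset_attention_mask zeroes attention from positions 3-4 back to 0-2, so
# the two documents packed into one sequence cannot attend to each other.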
|
||||
|
||||
|
||||
def get_batch(
|
||||
context_tokens,
|
||||
micro_batch_size,
|
||||
eod_token,
|
||||
reset_position_ids=False,
|
||||
reset_attention_mask=False,
|
||||
):
|
||||
"""Generate batch from context tokens."""
|
||||
tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
|
||||
# Get the attention mask and position ids.
|
||||
attention_mask, position_ids = get_ltor_masks_and_position_ids(
|
||||
tokens,
|
||||
eod_token,
|
||||
reset_position_ids,
|
||||
reset_attention_mask,
|
||||
)
|
||||
|
||||
return tokens, attention_mask, position_ids
|
||||
|
||||
|
||||
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
|
||||
"""This function has been mostly taken from huggingface conversational
|
||||
ai code at
|
||||
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
|
||||
conversational-ai-with-transfer-learning-2d818ac26313"""
|
||||
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the
|
||||
# last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
if top_p > 0.0:
|
||||
# Convert to 1D
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep also the first token
|
||||
# above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
for i in range(sorted_indices.size(0)):
|
||||
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
|
||||
logits[i][indices_to_remove] = filter_value
|
||||
|
||||
return logits
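# Worked example (illustrative): for logits [2.0, 1.0, 0.5, -1.0] and
# top_k = 2, the entries 0.5 and -1.0 are set to -inf; with top_p = 0.9 the
# lowest-probability tail whose cumulative softmax mass lies beyond 0.9 is
# filtered instead.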
|
||||
|
||||
|
||||
def pad_batch(batch, pad_id, seq_length):
|
||||
context_lengths = []
|
||||
for tokens in batch:
|
||||
context_length = len(tokens)
|
||||
if context_length < seq_length:
|
||||
tokens.extend([pad_id] * (seq_length - context_length))
|
||||
context_lengths.append(context_length)
|
||||
return batch, context_lengths
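# Worked example (illustrative): pad_batch([[1, 2, 3]], pad_id=0, seq_length=5)
# returns ([[1, 2, 3, 0, 0]], [3]) -- each prompt is padded in place to
# seq_length and its original context length is recorded.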
|
||||
|
||||
|
||||
def forward_step(
|
||||
model,
|
||||
tokens,
|
||||
seq_length,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
layer_past=None,
|
||||
get_key_value=None,
|
||||
prompt_length=None,
|
||||
context_length=None,
|
||||
):
|
||||
# Forward pass through the model.
|
||||
output_tensor = model(
|
||||
tokens,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
layer_past=layer_past,
|
||||
get_key_value=get_key_value,
|
||||
prompt_length=prompt_length,
|
||||
context_length=context_length,
|
||||
)
|
||||
|
||||
if get_key_value:
|
||||
output_tensor, layer_past = output_tensor
|
||||
|
||||
if get_key_value:
|
||||
return output_tensor, layer_past
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def get_token_stream(
|
||||
model,
|
||||
tokenizer,
|
||||
seq_length,
|
||||
out_seq_length,
|
||||
context_tokens,
|
||||
return_scores: bool = False,
|
||||
prompt_length: int = None,
|
||||
micro_batch_size: int = None,
|
||||
bad_ids: List = None,
|
||||
temperature: float = 1.0,
|
||||
topp: float = 1.0,
|
||||
topk: int = 0,
|
||||
greedy: bool = False,
|
||||
recompute: bool = False,
|
||||
):
|
||||
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length)
|
||||
|
||||
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
|
||||
context_length_tensor = torch.cuda.LongTensor(context_lengths)
|
||||
context_length = context_length_tensor.min().item()
|
||||
tokens, attention_mask, position_ids = get_batch(
|
||||
context_tokens_tensor,
|
||||
micro_batch_size,
|
||||
tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
batch_token_iterator = sample_sequence_batch(
|
||||
model,
|
||||
tokenizer,
|
||||
context_tokens_tensor,
|
||||
context_length_tensor,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
seq_length=seq_length,
|
||||
out_seq_length=out_seq_length,
|
||||
return_scores=return_scores,
|
||||
prompt_length=prompt_length,
|
||||
bad_ids=bad_ids,
|
||||
temperature=temperature,
|
||||
topp=topp,
|
||||
topk=topk,
|
||||
greedy=greedy,
|
||||
recompute=recompute,
|
||||
)
|
||||
|
||||
for tokens, lengths in batch_token_iterator:
|
||||
context_length += 1
|
||||
if tokens is not None:
|
||||
yield tokens[:, :context_length], lengths
|
||||
else:
|
||||
yield None, None
|
||||
|
||||
|
||||
def switch(val1, val2, boolean):
|
||||
boolean = boolean.type_as(val1)
|
||||
return (1 - boolean) * val1 + boolean * val2
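# Worked example (illustrative): switch(val1=[5, 5], val2=[9, 9],
# boolean=[0, 1]) -> [5, 9]; generation below uses it to keep the original
# prompt token where a sequence's context has not started yet and the newly
# sampled token where it has.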
|
||||
|
||||
|
||||
def sample_sequence_batch(
|
||||
model,
|
||||
tokenizer,
|
||||
context_tokens,
|
||||
context_lengths,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
seq_length,
|
||||
out_seq_length,
|
||||
maxlen=None,
|
||||
return_scores: bool = False,
|
||||
prompt_length: int = None,
|
||||
bad_ids: List = None,
|
||||
temperature: float = 1.0,
|
||||
topp: float = 1.0,
|
||||
topk: int = 0,
|
||||
recompute: bool = False,
|
||||
greedy: bool = False,
|
||||
):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
context_length = context_lengths.min().item()
|
||||
eos_id = tokenizer.eos_token_id
|
||||
|
||||
counter = 0
|
||||
org_context_length = context_length
|
||||
|
||||
layer_past = None
|
||||
batch_size = context_tokens.size(0)
|
||||
is_done = torch.zeros([batch_size]).byte().cuda()
|
||||
tokens = context_tokens
|
||||
if maxlen is None:
|
||||
maxlen = seq_length - 1
|
||||
if maxlen > (org_context_length + out_seq_length):
|
||||
maxlen = org_context_length + out_seq_length
|
||||
|
||||
lengths = torch.ones([batch_size]).long().cuda() * maxlen
|
||||
if return_scores:
|
||||
scores = torch.zeros([batch_size]).float().cuda()
|
||||
|
||||
while context_length <= (maxlen):
|
||||
|
||||
if recompute:
|
||||
logits = model(tokens,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
prompt_length=prompt_length,
|
||||
context_length=context_length,
|
||||
)
|
||||
logits = logits[:, context_length - 1, :]
|
||||
else:
|
||||
if counter == 0:
|
||||
tokens2use = tokens[:, :context_length]
|
||||
positions2use = position_ids[:, :context_length]
|
||||
else:
|
||||
tokens2use = tokens[:, context_length - 1].view(
|
||||
batch_size, -1)
|
||||
positions2use = position_ids[:, context_length - 1].view(
|
||||
batch_size, -1)
|
||||
logits, layer_past = model(tokens2use,
|
||||
positions2use,
|
||||
attention_mask,
|
||||
layer_past=layer_past,
|
||||
get_key_value=True,
|
||||
prompt_length=prompt_length,
|
||||
context_length=context_length,
|
||||
)
|
||||
logits = logits[:, -1].view(batch_size, -1).contiguous()
|
||||
|
||||
if bad_ids is not None:
|
||||
for bad_id in bad_ids:
|
||||
logits[:, bad_id] = -10000
|
||||
if greedy:
|
||||
prev = torch.argmax(logits, dim=-1).view(-1)
|
||||
else:
|
||||
logits = logits.float()
|
||||
if return_scores:
|
||||
orig_log_probs = torch.log_softmax(logits, dim=-1)
|
||||
logits /= temperature
|
||||
logits = top_k_logits(logits, top_k=topk, top_p=topp)
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
|
||||
|
||||
started = context_lengths <= context_length
|
||||
|
||||
new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
|
||||
|
||||
if not greedy and return_scores:
|
||||
indices = prev.view(-1, 1)
|
||||
new_scores = orig_log_probs.gather(1, indices).view(-1)
|
||||
new_scores = new_scores * started
|
||||
new_scores = new_scores * is_done.bool().logical_not()
|
||||
scores += new_scores
|
||||
|
||||
tokens[:, context_length] = new_tokens
|
||||
done_token = (prev == eos_id).byte() & started.byte()
|
||||
just_finished = (done_token & ~is_done).bool()
|
||||
lengths[just_finished.view(-1)] = context_length
|
||||
is_done = is_done | done_token
|
||||
done = torch.all(is_done)
|
||||
|
||||
if return_scores:
|
||||
yield tokens, (lengths, scores)
|
||||
else:
|
||||
yield tokens, lengths
|
||||
|
||||
context_length += 1
|
||||
counter += 1
|
||||
if done:
|
||||
break
|
@ -0,0 +1,17 @@
|
||||
# CodeGeeX-13B parallel configuration
|
||||
# Parallel checkpoints are named in the format "mp_rank_0{i}_model_states.pt", where i is the rank, starting from 0.
|
||||
|
||||
CHECKPOINT_PATH="<path where you put all parallel checkpoints (e.g., XXX/tp4/)>"
|
||||
|
||||
MODEL_ARGS="--num-layers 39 \
|
||||
--hidden-size 5120 \
|
||||
--num-attention-heads 40 \
|
||||
--max-position-embeddings 2048 \
|
||||
--attention-softmax-in-fp32 \
|
||||
--load "$CHECKPOINT_PATH" \
|
||||
--layernorm-epsilon 1e-5 \
|
||||
--fp16 \
|
||||
--ws-encoding-start-id 10 \
|
||||
--ws-encoding-length 10 \
|
||||
--make-vocab-size-divisible-by 52224 \
|
||||
--seq-length 2048"
|
@ -0,0 +1,37 @@
|
||||
# This script is used to convert a Mindspore checkpoint to the Megatron format.
|
||||
|
||||
LOAD_CKPT_PATH=$1 # Path to weights in .pt format.
|
||||
SAVE_CKPT_PATH=$2 # Path to save the output MP checkpoints.
|
||||
MP_SIZE=$3 # Model parallel size
|
||||
|
||||
SCRIPT_PATH=$(realpath "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
MAIN_DIR=$(dirname "$SCRIPT_DIR")
|
||||
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
|
||||
|
||||
if [ -z "$MP_SIZE" ]; then
|
||||
MP_SIZE=1
|
||||
fi
|
||||
|
||||
# export CUDA settings
|
||||
export CUDA_HOME=/usr/local/cuda-11.1/
|
||||
export CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
|
||||
CMD="python $MAIN_DIR/codegeex/megatron/convert_ckpt_parallel.py \
|
||||
--load-ckpt-path $LOAD_CKPT_PATH \
|
||||
--save-ckpt-path $SAVE_CKPT_PATH \
|
||||
--tokenizer-path $TOKENIZER_PATH \
|
||||
--target-tensor-model-parallel-size $MP_SIZE \
|
||||
--num-layers 39 \
|
||||
--hidden-size 5120 \
|
||||
--num-attention-heads 40 \
|
||||
--max-position-embeddings 2048 \
|
||||
--attention-softmax-in-fp32 \
|
||||
--fp16 \
|
||||
--micro-batch-size 1 \
|
||||
--make-vocab-size-divisible-by 52224 \
|
||||
--seq-length 2048"
|
||||
|
||||
echo "$CMD"
|
||||
eval "$CMD"
|
@ -0,0 +1,25 @@
|
||||
# Process dataset for CodeGeeX pretraining
|
||||
|
||||
DATASET_PATH=$1
|
||||
OUTPUT_PATH=$2
|
||||
LANGUAGE=$3
|
||||
|
||||
SCRIPT_PATH=$(realpath "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
MAIN_DIR=$(dirname "$SCRIPT_DIR")
|
||||
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
|
||||
|
||||
if [ -z "$LANGUAGE" ]; then
|
||||
LANGUAGE=python
|
||||
fi
|
||||
|
||||
CMD="python $MAIN_DIR/codegeex/data/process_pretrain_dataset.py \
|
||||
--dataset_path $DATASET_PATH \
|
||||
--tokenizer_path $TOKENIZER_PATH \
|
||||
--output_prefix $OUTPUT_PATH \
|
||||
--language $LANGUAGE \
|
||||
--mode pretrain \
|
||||
--seq_len 2048"
|
||||
|
||||
echo "$CMD"
|
||||
eval "$CMD"
|
@ -0,0 +1,46 @@
|
||||
# This script is used to test the inference of CodeGeeX.
|
||||
|
||||
MP_SIZE=$1
|
||||
|
||||
SCRIPT_PATH=$(realpath "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
MAIN_DIR=$(dirname "$SCRIPT_DIR")
|
||||
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
|
||||
|
||||
if [ -z "$MP_SIZE" ]; then
|
||||
MP_SIZE=1
|
||||
fi
|
||||
|
||||
if [ "$MP_SIZE" -eq 1 ]; then
|
||||
source "$MAIN_DIR/configs/codegeex_13b.sh"
|
||||
echo "Load config from $MAIN_DIR/configs/codegeex_13b.sh"
|
||||
else
|
||||
source "$MAIN_DIR/configs/codegeex_13b_parallel.sh"
|
||||
echo "Load config from $MAIN_DIR/configs/codegeex_13b_parallel.sh"
|
||||
fi
|
||||
|
||||
# export CUDA settings
|
||||
export CUDA_HOME=/usr/local/cuda-11.1/
|
||||
# export CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
if [ -z "$PROMPT_FILE" ]; then
|
||||
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
|
||||
fi
|
||||
|
||||
# remove --greedy if using sampling
|
||||
CMD="torchrun --nproc_per_node $MP_SIZE $MAIN_DIR/tests/test_inference_megatron.py \
|
||||
--tensor-model-parallel-size $MP_SIZE \
|
||||
--prompt-file $PROMPT_FILE \
|
||||
--tokenizer-path $TOKENIZER_PATH \
|
||||
--micro-batch-size 1 \
|
||||
--out-seq-length 1024 \
|
||||
--temperature 0.8 \
|
||||
--top-p 0.95 \
|
||||
--top-k 0 \
|
||||
--greedy \
|
||||
--use-cpu-initialization \
|
||||
--ln-fp16 \
|
||||
$MODEL_ARGS"
|
||||
|
||||
echo "$CMD"
|
||||
eval "$CMD"
|
@ -0,0 +1,39 @@
|
||||
# This script is used to test the inference of CodeGeeX.
|
||||
|
||||
GPU=$1
|
||||
PROMPT_FILE=$2
|
||||
|
||||
SCRIPT_PATH=$(realpath "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
MAIN_DIR=$(dirname "$SCRIPT_DIR")
|
||||
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
|
||||
|
||||
# import model configuration
|
||||
source "$MAIN_DIR/configs/codegeex_13b.sh"
|
||||
|
||||
# export CUDA settings
|
||||
if [ -z "$GPU" ]; then
|
||||
GPU=0
|
||||
fi
|
||||
|
||||
export CUDA_HOME=/usr/local/cuda-11.1/
|
||||
export CUDA_VISIBLE_DEVICES=$GPU
|
||||
|
||||
if [ -z "$PROMPT_FILE" ]; then
|
||||
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
|
||||
fi
|
||||
|
||||
# remove --greedy if using sampling
|
||||
CMD="python $MAIN_DIR/tests/test_inference.py \
|
||||
--prompt-file $PROMPT_FILE \
|
||||
--tokenizer-path $TOKENIZER_PATH \
|
||||
--micro-batch-size 1 \
|
||||
--out-seq-length 1024 \
|
||||
--temperature 0.2 \
|
||||
--top-p 0.95 \
|
||||
--top-k 0 \
|
||||
--quantize \
|
||||
$MODEL_ARGS"
|
||||
|
||||
echo "$CMD"
|
||||
eval "$CMD"
|
@ -0,0 +1,209 @@
|
||||
import copy
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from codegeex.megatron import get_tokenizer, get_args, print_rank_0
|
||||
from codegeex.megatron.initialize import initialize_megatron
|
||||
from codegeex.megatron.model import CodeGeeXModel
|
||||
from codegeex.megatron.code_generation_utils import get_token_stream
|
||||
from codegeex.quantization import quantize
|
||||
from codegeex.megatron.training import get_model
|
||||
from codegeex.megatron.checkpointing import load_checkpoint
|
||||
|
||||
torch.set_printoptions(precision=8)
|
||||
|
||||
|
||||
def set_random_seed(seed):
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def model_provider(pre_process=True, post_process=True):
|
||||
"""Build the model."""
|
||||
|
||||
print_rank_0("Building CodeGeeX model ...")
|
||||
model = CodeGeeXModel(num_tokentypes=0,
|
||||
parallel_output=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def add_code_generation_args(parser):
|
||||
"""Code generation arguments."""
|
||||
group = parser.add_argument_group(title="code generation")
|
||||
|
||||
group.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Sampling temperature.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--greedy",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use greedy sampling.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Top p sampling.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Top k sampling.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--out-seq-length",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Size of the output generated text.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--recompute",
|
||||
action="store_true",
|
||||
help="During generation recompute all attention "
|
||||
"instead of using previously computed keys/values.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ws-encoding-start-id",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Start id for whitespace encoding",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ws-encoding-length",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Length of whitespace encoding",
|
||||
)
|
||||
group.add_argument(
|
||||
"--n-generation",
|
||||
type=int,
|
||||
default=10,
|
||||
)
|
||||
group.add_argument(
|
||||
"--eos-id",
|
||||
type=int,
|
||||
default=50256,
|
||||
)
|
||||
group.add_argument(
|
||||
"--prompt-file",
|
||||
type=str,
|
||||
default="./test_prompt.txt",
|
||||
)
|
||||
group.add_argument(
|
||||
"--perf-file",
|
||||
type=str,
|
||||
default="./perf_out.txt",
|
||||
)
|
||||
group.add_argument(
|
||||
"--perf-trace",
|
||||
type=str,
|
||||
default="./perf_out.txt",
|
||||
)
|
||||
group.add_argument(
|
||||
"--use-torch-profile",
|
||||
action="store_true",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ln-fp32",
|
||||
action="store_true",
|
||||
)
|
||||
group.add_argument(
|
||||
'--bad-ids',
|
||||
nargs="*",
|
||||
type=int,
|
||||
default=None,
|
||||
help='Token ids that are banned during generation (their logits are suppressed)',
|
||||
)
|
||||
group.add_argument(
|
||||
"--quantize",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
initialize_megatron(
|
||||
extra_args_provider=add_code_generation_args,
|
||||
args_defaults={
|
||||
'no_load_rng': True,
|
||||
'no_load_optim': True,
|
||||
}
|
||||
)
|
||||
|
||||
args = get_args()
|
||||
set_random_seed(args.seed)
|
||||
|
||||
print_rank_0("Loading tokenizer ...")
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
print_rank_0("Loading state dict ...")
|
||||
|
||||
model = get_model(model_provider)
|
||||
if args.load is not None:
|
||||
_ = load_checkpoint(model, None, None)
|
||||
|
||||
assert len(model) == 1, "Expected a single model module"
|
||||
|
||||
model = model[0]
|
||||
model.eval()
|
||||
if args.fp16 and args.ln_fp16:
|
||||
model.half()
|
||||
if args.quantize:
|
||||
model = quantize(model, weight_bit_width=8, backend="megatron")
|
||||
|
||||
with open(args.prompt_file, "r") as f:
|
||||
prompt = f.readlines()
|
||||
prompt = "".join(prompt)
|
||||
|
||||
print_rank_0("Generating ...")
|
||||
t0 = time.perf_counter()
|
||||
for prompt in [prompt]:
|
||||
tokens = tokenizer.tokenize(prompt)
|
||||
print_rank_0(tokens)
|
||||
print_rank_0("Current prompt:")
|
||||
print_rank_0(prompt)
|
||||
n_token_prompt = len(tokens)
|
||||
print_rank_0(f"N_token_prompt: {n_token_prompt}")
|
||||
token_stream = get_token_stream(
|
||||
model,
|
||||
[copy.deepcopy(tokens) for _ in range(args.micro_batch_size)],
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
bad_ids=args.bad_ids,
|
||||
)
|
||||
is_finished = [False for _ in range(args.micro_batch_size)]
|
||||
for i, generated in enumerate(token_stream):
|
||||
generated_tokens = generated[0]
|
||||
for j in range(args.micro_batch_size):
|
||||
if is_finished[j]:
|
||||
continue
|
||||
if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len(
|
||||
generated_tokens[j]) >= args.out_seq_length:
|
||||
is_finished[j] = True
|
||||
generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
|
||||
generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:])
|
||||
t1 = time.perf_counter()
|
||||
print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}")
|
||||
print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
|
||||
print_rank_0("================================= Generated code:")
|
||||
print_rank_0(generated_code)
|
||||
t0 = time.perf_counter()
|
||||
if all(is_finished):
|
||||
break
|
||||
|
||||
print_rank_0("Generation finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|