mirror of https://github.com/THUDM/CodeGeeX.git
Add megatron data processing
parent 3281f0bcbc
commit 02241e057c
@@ -1,204 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# copied from fairseq/fairseq/data/indexed_dataset.py
|
||||
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
|
||||
# other slight modifications to remove fairseq dependencies
|
||||
# Added document index to index file and made it accessible.
|
||||
# An empty sentence no longer separates documents.
|
||||
|
||||
|
||||
# copied from Megatron
|
||||
# Used for building mmap datasets.
|
||||
from functools import lru_cache
|
||||
import struct
|
||||
import numpy as np
|
||||
import shutil
|
||||
|
||||
__all__ = ["make_mmap_builder"]
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
3: np.int16,
|
||||
4: np.int32,
|
||||
5: np.int64,
|
||||
6: np.float64,
|
||||
7: np.double,
|
||||
8: np.uint16,
|
||||
}
|
||||
|
||||
|
||||
def __best_fitting_dtype(vocab_size=None):
|
||||
if vocab_size is not None and vocab_size < 65500:
|
||||
return np.uint16
|
||||
else:
|
||||
return np.int32
|
||||
|
||||
|
||||
def make_mmap_builder(out_file, vocab_size=None):
|
||||
return MMapIndexedDatasetBuilder(
|
||||
out_file, dtype=__best_fitting_dtype(vocab_size)
|
||||
)
|
||||
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + ".idx"
|
||||
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + ".bin"
|
||||
|
||||
|
||||
class MMapIndexedDataset:
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path, dtype):
|
||||
class _Writer(object):
|
||||
def __enter__(self):
|
||||
self._file = open(path, "wb")
|
||||
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
self._file.write(struct.pack("<Q", 1))
|
||||
self._file.write(struct.pack("<B", code(dtype)))
|
||||
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _get_pointers(sizes):
|
||||
dtype_size = dtype().itemsize
|
||||
address = 0
|
||||
pointers = []
|
||||
|
||||
for size in sizes:
|
||||
pointers.append(address)
|
||||
address += size * dtype_size
|
||||
|
||||
return pointers
|
||||
|
||||
def write(self, sizes, doc_idx):
|
||||
pointers = self._get_pointers(sizes)
|
||||
|
||||
self._file.write(struct.pack("<Q", len(sizes)))
|
||||
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order="C"))
|
||||
del sizes
|
||||
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order="C"))
|
||||
del pointers
|
||||
|
||||
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||
self._file.write(doc_idx.tobytes(order="C"))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, "rb") as stream:
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
version = struct.unpack("<Q", stream.read(8))
|
||||
assert (1,) == version
|
||||
|
||||
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||
self._dtype = dtypes[dtype_code]
|
||||
self._dtype_size = self._dtype().itemsize
|
||||
|
||||
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||
offset = stream.tell()
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||
)
|
||||
self._pointers = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._len,
|
||||
offset=offset + self._sizes.nbytes,
|
||||
)
|
||||
self._doc_idx = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._doc_count,
|
||||
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._doc_idx
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
|
||||
class MMapIndexedDatasetBuilder(object):
|
||||
def __init__(self, out_file, dtype=np.int64):
|
||||
self._data_file = open(out_file, "wb")
|
||||
self._dtype = dtype
|
||||
self._sizes = []
|
||||
self._doc_idx = [0]
|
||||
|
||||
def add_item(self, tensor):
|
||||
np_array = np.array(tensor.numpy(), dtype=self._dtype)
|
||||
self._data_file.write(np_array.tobytes(order="C"))
|
||||
self._sizes.append(np_array.size)
|
||||
|
||||
def end_document(self):
|
||||
self._doc_idx.append(len(self._sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
# Concatenate index
|
||||
index = MMapIndexedDataset.Index(index_file_path(another_file))
|
||||
assert index.dtype == self._dtype
|
||||
|
||||
for size in index.sizes:
|
||||
self._sizes.append(size)
|
||||
|
||||
# Concatenate data
|
||||
with open(data_file_path(another_file), "rb") as f:
|
||||
shutil.copyfileobj(f, self._data_file)
|
||||
|
||||
def finalize(self, index_file):
|
||||
self._data_file.close()
|
||||
|
||||
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
|
||||
index.write(self._sizes, self._doc_idx)
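

# A minimal usage sketch with a hypothetical output prefix and hypothetical
# token ids, assuming documents are already tokenized into integer id lists;
# it shows how the builder above writes a .bin/.idx pair.
def _example_write_mmap_dataset(prefix="toy_corpus",
                                documents=([15496, 995], [31373, 2159, 0])):
    import torch

    builder = make_mmap_builder(data_file_path(prefix), vocab_size=52224)
    for doc in documents:
        # Each add_item call appends one tokenized document to <prefix>.bin
        # and records its length.
        builder.add_item(torch.IntTensor(list(doc)))
        # Mark the document boundary so doc_idx is recorded in the index.
        builder.end_document()
    # Write <prefix>.idx with the sizes, byte pointers and document boundaries.
    builder.finalize(index_file_path(prefix))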
|
@@ -0,0 +1,69 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Blendable dataset."""
|
||||
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from codegeex.megatron import print_rank_0
|
||||
|
||||
|
||||
class BlendableDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, datasets, weights):
|
||||
|
||||
self.datasets = datasets
|
||||
num_datasets = len(datasets)
|
||||
assert num_datasets == len(weights)
|
||||
|
||||
self.size = 0
|
||||
for dataset in self.datasets:
|
||||
self.size += len(dataset)
|
||||
|
||||
# Normalize weights.
|
||||
weights = np.array(weights, dtype=np.float64)
|
||||
sum_weights = np.sum(weights)
|
||||
assert sum_weights > 0.0
|
||||
weights /= sum_weights
|
||||
|
||||
# Build indices.
|
||||
start_time = time.time()
|
||||
assert num_datasets < 255
|
||||
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
|
||||
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
|
||||
|
||||
from megatron.data import helpers
|
||||
|
||||
helpers.build_blending_indices(
|
||||
self.dataset_index,
|
||||
self.dataset_sample_index,
|
||||
weights,
|
||||
num_datasets,
|
||||
self.size,
|
||||
torch.distributed.get_rank() == 0,
|
||||
)
|
||||
print_rank_0(
|
||||
"> elapsed time for building blendable dataset indices: "
|
||||
"{:.2f} (sec)".format(time.time() - start_time)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def __getitem__(self, idx):
|
||||
dataset_idx = self.dataset_index[idx]
|
||||
sample_idx = self.dataset_sample_index[idx]
|
||||
return self.datasets[dataset_idx][sample_idx]
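

# A pure-Python sketch of the greedy blending rule that the compiled
# helpers.build_blending_indices implements (see helpers.cpp in this commit):
# at every step, draw from the dataset whose achieved sample count lags
# furthest behind its target weight.
def _python_blending_indices(weights, size):
    weights = np.array(weights, dtype=np.float64)
    weights /= weights.sum()
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current_samples = np.zeros(len(weights), dtype=np.int64)
    for sample_idx in range(size):
        # Error = how far each dataset is below its target share so far.
        errors = weights * max(sample_idx, 1) - current_samples
        chosen = int(np.argmax(errors))
        dataset_index[sample_idx] = chosen
        dataset_sample_index[sample_idx] = current_samples[chosen]
        current_samples[chosen] += 1
    return dataset_index, dataset_sample_index


# For example, _python_blending_indices([0.7, 0.3], 10)[0] interleaves the
# two datasets roughly 7:3.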
|
@@ -0,0 +1,185 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataloaders."""
|
||||
|
||||
|
||||
import torch
|
||||
from codegeex.megatron import get_args
|
||||
from codegeex.megatron import mpu
|
||||
|
||||
|
||||
def build_pretraining_data_loader(dataset, consumed_samples):
|
||||
"""Buld dataloader given an input dataset."""
|
||||
|
||||
if dataset is None:
|
||||
return None
|
||||
args = get_args()
|
||||
|
||||
# Megatron sampler
|
||||
if args.dataloader_type == "single":
|
||||
batch_sampler = MegatronPretrainingSampler(
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
data_parallel_rank=mpu.get_data_parallel_rank(),
|
||||
data_parallel_size=mpu.get_data_parallel_world_size(),
|
||||
)
|
||||
elif args.dataloader_type == "cyclic":
|
||||
batch_sampler = MegatronPretrainingRandomSampler(
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
data_parallel_rank=mpu.get_data_parallel_rank(),
|
||||
data_parallel_size=mpu.get_data_parallel_world_size(),
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"{} dataloader type is not supported.".format(args.dataloader_type)
|
||||
)
|
||||
|
||||
# Torch dataloader.
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
|
||||
class MegatronPretrainingSampler:
|
||||
def __init__(
|
||||
self,
|
||||
total_samples,
|
||||
consumed_samples,
|
||||
micro_batch_size,
|
||||
data_parallel_rank,
|
||||
data_parallel_size,
|
||||
drop_last=True,
|
||||
):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
self.consumed_samples = consumed_samples
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.micro_batch_times_data_parallel_size = (
|
||||
self.micro_batch_size * data_parallel_size
|
||||
)
|
||||
self.drop_last = drop_last
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, "no sample to consume: {}".format(
|
||||
self.total_samples
|
||||
)
|
||||
assert (
|
||||
self.consumed_samples < self.total_samples
|
||||
), "no samples left to consume: {}, {}".format(
|
||||
self.consumed_samples, self.total_samples
|
||||
)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert (
|
||||
self.data_parallel_rank < data_parallel_size
|
||||
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
|
||||
self.data_parallel_rank, data_parallel_size
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def get_start_end_idx(self):
|
||||
start_idx = self.data_parallel_rank * self.micro_batch_size
|
||||
end_idx = start_idx + self.micro_batch_size
|
||||
return start_idx, end_idx
|
||||
|
||||
def __iter__(self):
|
||||
batch = []
|
||||
# The last incomplete batch is dropped unless drop_last is set to False.
|
||||
for idx in range(self.consumed_samples, self.total_samples):
|
||||
batch.append(idx)
|
||||
if len(batch) == self.micro_batch_times_data_parallel_size:
|
||||
start_idx, end_idx = self.get_start_end_idx()
|
||||
yield batch[start_idx:end_idx]
|
||||
batch = []
|
||||
|
||||
# Yield the last partial batch if drop_last is not set.
|
||||
if len(batch) > 0 and not self.drop_last:
|
||||
start_idx, end_idx = self.get_start_end_idx()
|
||||
yield batch[start_idx:end_idx]
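

# A small sketch with hypothetical sizes showing how the sampler above shards
# one global batch: with micro_batch_size=2 and data_parallel_size=2, indices
# [0, 1, 2, 3] form the first global batch and each rank takes its own slice.
def _example_single_sampler_shards(total_samples=8):
    for rank in range(2):
        sampler = MegatronPretrainingSampler(
            total_samples=total_samples,
            consumed_samples=0,
            micro_batch_size=2,
            data_parallel_rank=rank,
            data_parallel_size=2,
        )
        print(rank, list(sampler))
    # Expected: rank 0 -> [[0, 1], [4, 5]], rank 1 -> [[2, 3], [6, 7]]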
|
||||
|
||||
|
||||
class MegatronPretrainingRandomSampler:
|
||||
def __init__(
|
||||
self,
|
||||
total_samples,
|
||||
consumed_samples,
|
||||
micro_batch_size,
|
||||
data_parallel_rank,
|
||||
data_parallel_size,
|
||||
):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
self.consumed_samples = consumed_samples
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.data_parallel_size = data_parallel_size
|
||||
self.micro_batch_times_data_parallel_size = (
|
||||
self.micro_batch_size * data_parallel_size
|
||||
)
|
||||
self.last_batch_size = (
|
||||
self.total_samples % self.micro_batch_times_data_parallel_size
|
||||
)
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, "no sample to consume: {}".format(
|
||||
self.total_samples
|
||||
)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert (
|
||||
self.data_parallel_rank < data_parallel_size
|
||||
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
|
||||
self.data_parallel_rank, data_parallel_size
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def __iter__(self):
|
||||
active_total_samples = self.total_samples - self.last_batch_size
|
||||
self.epoch = self.consumed_samples // active_total_samples
|
||||
current_epoch_samples = self.consumed_samples % active_total_samples
|
||||
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
|
||||
|
||||
# data sharding and random sampling
|
||||
bucket_size = (
|
||||
self.total_samples // self.micro_batch_times_data_parallel_size
|
||||
) * self.micro_batch_size
|
||||
bucket_offset = current_epoch_samples // self.data_parallel_size
|
||||
start_idx = self.data_parallel_rank * bucket_size
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
random_idx = torch.randperm(bucket_size, generator=g).tolist()
|
||||
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
|
||||
|
||||
batch = []
|
||||
# Last batch if not complete will be dropped.
|
||||
for idx in idx_range:
|
||||
batch.append(idx)
|
||||
if len(batch) == self.micro_batch_size:
|
||||
self.consumed_samples += self.micro_batch_times_data_parallel_size
|
||||
yield batch
|
||||
batch = []
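

# A brief sketch of the property the sampler above relies on: seeding
# torch.Generator with the epoch number makes the shuffle reproducible across
# restarts, so consumed_samples can be used to resume mid-epoch.
def _example_epoch_deterministic_shuffle(bucket_size=6, epoch=3):
    g1, g2 = torch.Generator(), torch.Generator()
    g1.manual_seed(epoch)
    g2.manual_seed(epoch)
    order1 = torch.randperm(bucket_size, generator=g1).tolist()
    order2 = torch.randperm(bucket_size, generator=g2).tolist()
    assert order1 == order2  # the same epoch seed yields the same permutation
    return order1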
|
@@ -0,0 +1,566 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Most of the code here has been copied from:
|
||||
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
|
||||
# with some modifications.
|
||||
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from codegeex.megatron import mpu, print_rank_0
|
||||
from codegeex.megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
|
||||
|
||||
|
||||
def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
|
||||
|
||||
# The data prefix should be in the format of:
|
||||
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
|
||||
assert len(data_prefix) % 2 == 0
|
||||
num_datasets = len(data_prefix) // 2
|
||||
weights = [0] * num_datasets
|
||||
prefixes = [0] * num_datasets
|
||||
for i in range(num_datasets):
|
||||
weights[i] = float(data_prefix[2 * i])
|
||||
prefixes[i] = (data_prefix[2 * i + 1]).strip()
|
||||
# Normalize weights
|
||||
weight_sum = 0.0
|
||||
for weight in weights:
|
||||
weight_sum += weight
|
||||
assert weight_sum > 0.0
|
||||
weights = [weight / weight_sum for weight in weights]
|
||||
|
||||
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
|
||||
# not uniformly distribute the number of samples, we still have
|
||||
# samples left to feed to the network.
|
||||
datasets_train_valid_test_num_samples = []
|
||||
for weight in weights:
|
||||
datasets_train_valid_test_num_samples.append(
|
||||
[
|
||||
int(math.ceil(val * weight * 1.005))
|
||||
for val in train_valid_test_num_samples
|
||||
]
|
||||
)
|
||||
|
||||
return prefixes, weights, datasets_train_valid_test_num_samples
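

# A worked example with hypothetical data prefixes: the 3/7 weights are
# normalized to 0.3/0.7 and each per-dataset sample count is padded by the
# 1.005 oversampling factor described above.
def _example_weights_and_num_samples():
    prefixes, weights, per_dataset = get_datasets_weights_and_num_samples(
        data_prefix=["3", "/data/code", "7", "/data/text"],
        train_valid_test_num_samples=[1000, 100, 10],
    )
    # prefixes    == ["/data/code", "/data/text"]
    # weights     == [0.3, 0.7]
    # per_dataset == [[302, 31, 4], [704, 71, 8]]  # ceil(val * weight * 1.005)
    return prefixes, weights, per_dataset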
|
||||
|
||||
|
||||
def compile_helper():
|
||||
"""Compile helper function ar runtime. Make sure this
|
||||
is invoked on a single process."""
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
path = os.path.abspath(os.path.dirname(__file__))
|
||||
ret = subprocess.run(["make", "-C", path])
|
||||
if ret.returncode != 0:
|
||||
print("Making C++ dataset helpers module failed, exiting.")
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_a_and_b_segments(sample, np_rng):
|
||||
"""Divide sample into a and b segments."""
|
||||
|
||||
# Number of sentences in the sample.
|
||||
n_sentences = len(sample)
|
||||
# Make sure we always have two sentences.
|
||||
assert n_sentences > 1, "make sure each sample has at least two sentences."
|
||||
|
||||
# First part:
|
||||
# `a_end` is how many sentences go into the `A`.
|
||||
a_end = 1
|
||||
if n_sentences >= 3:
|
||||
# Note that randint's upper bound in numpy is exclusive.
|
||||
a_end = np_rng.randint(1, n_sentences)
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(sample[j])
|
||||
|
||||
# Second part:
|
||||
tokens_b = []
|
||||
for j in range(a_end, n_sentences):
|
||||
tokens_b.extend(sample[j])
|
||||
|
||||
# Random next:
|
||||
is_next_random = False
|
||||
if np_rng.random() < 0.5:
|
||||
is_next_random = True
|
||||
tokens_a, tokens_b = tokens_b, tokens_a
|
||||
|
||||
return tokens_a, tokens_b, is_next_random
|
||||
|
||||
|
||||
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
# print(len_a, len_b, max_num_tokens)
|
||||
assert len_a > 0
|
||||
if len_a + len_b <= max_num_tokens:
|
||||
return False
|
||||
while len_a + len_b > max_num_tokens:
|
||||
if len_a > len_b:
|
||||
len_a -= 1
|
||||
tokens = tokens_a
|
||||
else:
|
||||
len_b -= 1
|
||||
tokens = tokens_b
|
||||
if np_rng.random() < 0.5:
|
||||
del tokens[0]
|
||||
else:
|
||||
tokens.pop()
|
||||
return True
|
||||
|
||||
|
||||
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
|
||||
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
|
||||
|
||||
tokens = []
|
||||
tokentypes = []
|
||||
# [CLS].
|
||||
tokens.append(cls_id)
|
||||
tokentypes.append(0)
|
||||
# Segment A.
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
tokentypes.append(0)
|
||||
# [SEP].
|
||||
tokens.append(sep_id)
|
||||
tokentypes.append(0)
|
||||
# Segment B.
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
tokentypes.append(1)
|
||||
if tokens_b:
|
||||
# [SEP].
|
||||
tokens.append(sep_id)
|
||||
tokentypes.append(1)
|
||||
|
||||
return tokens, tokentypes
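

# A worked example with tiny hypothetical ids (cls_id=101, sep_id=102):
# tokens_a=[7, 8] and tokens_b=[9] produce
#   tokens     = [101, 7, 8, 102, 9, 102]
#   tokentypes = [  0, 0, 0,   0, 1,   1]
def _example_tokens_and_tokentypes():
    tokens, tokentypes = create_tokens_and_tokentypes(
        tokens_a=[7, 8], tokens_b=[9], cls_id=101, sep_id=102
    )
    assert tokens == [101, 7, 8, 102, 9, 102]
    assert tokentypes == [0, 0, 0, 0, 1, 1]
    return tokens, tokentypes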
|
||||
|
||||
|
||||
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
|
||||
|
||||
|
||||
def is_start_piece(piece):
|
||||
"""Check if the current word piece is the starting piece (BERT)."""
|
||||
# When a word has been split into
|
||||
# WordPieces, the first token does not have any marker and any subsequence
|
||||
# tokens are prefixed with ##. So whenever we see the ## token, we
|
||||
# append it to the previous set of word indexes.
|
||||
return not piece.startswith("##")
|
||||
|
||||
|
||||
def create_masked_lm_predictions(
|
||||
tokens,
|
||||
vocab_id_list,
|
||||
vocab_id_to_token_dict,
|
||||
masked_lm_prob,
|
||||
cls_id,
|
||||
sep_id,
|
||||
mask_id,
|
||||
max_predictions_per_seq,
|
||||
np_rng,
|
||||
max_ngrams=3,
|
||||
do_whole_word_mask=True,
|
||||
favor_longer_ngram=False,
|
||||
do_permutation=False,
|
||||
geometric_dist=False,
|
||||
masking_style="bert",
|
||||
):
|
||||
"""Creates the predictions for the masked LM objective.
|
||||
Note: Tokens here are vocab ids and not text tokens."""
|
||||
|
||||
cand_indexes = []
|
||||
# Note(mingdachen): We create a list for recording if the piece is
|
||||
# the starting piece of current token, where 1 means true, so that
|
||||
# on-the-fly whole word masking is possible.
|
||||
token_boundary = [0] * len(tokens)
|
||||
|
||||
for (i, token) in enumerate(tokens):
|
||||
if token == cls_id or token == sep_id:
|
||||
token_boundary[i] = 1
|
||||
continue
|
||||
# Whole Word Masking means that if we mask all of the wordpieces
|
||||
# corresponding to an original word.
|
||||
#
|
||||
# Note that Whole Word Masking does *not* change the training code
|
||||
# at all -- we still predict each WordPiece independently, softmaxed
|
||||
# over the entire vocabulary.
|
||||
if (
|
||||
do_whole_word_mask
|
||||
and len(cand_indexes) >= 1
|
||||
and not is_start_piece(vocab_id_to_token_dict[token])
|
||||
):
|
||||
cand_indexes[-1].append(i)
|
||||
else:
|
||||
cand_indexes.append([i])
|
||||
if is_start_piece(vocab_id_to_token_dict[token]):
|
||||
token_boundary[i] = 1
|
||||
|
||||
output_tokens = list(tokens)
|
||||
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
|
||||
if masked_lm_prob == 0:
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
|
||||
|
||||
num_to_predict = min(
|
||||
max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))
|
||||
)
|
||||
|
||||
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
|
||||
if not geometric_dist:
|
||||
# Note(mingdachen):
|
||||
# By default, we set the probabilities to favor shorter ngram sequences.
|
||||
pvals = 1.0 / np.arange(1, max_ngrams + 1)
|
||||
pvals /= pvals.sum(keepdims=True)
|
||||
if favor_longer_ngram:
|
||||
pvals = pvals[::-1]
|
||||
|
||||
ngram_indexes = []
|
||||
for idx in range(len(cand_indexes)):
|
||||
ngram_index = []
|
||||
for n in ngrams:
|
||||
ngram_index.append(cand_indexes[idx : idx + n])
|
||||
ngram_indexes.append(ngram_index)
|
||||
|
||||
np_rng.shuffle(ngram_indexes)
|
||||
|
||||
(masked_lms, masked_spans) = ([], [])
|
||||
covered_indexes = set()
|
||||
for cand_index_set in ngram_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
if not cand_index_set:
|
||||
continue
|
||||
# Note(mingdachen):
|
||||
# Skip the current piece if it is covered by LM masking or previous ngrams.
|
||||
for index_set in cand_index_set[0]:
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
continue
|
||||
|
||||
if not geometric_dist:
|
||||
n = np_rng.choice(
|
||||
ngrams[: len(cand_index_set)],
|
||||
p=pvals[: len(cand_index_set)]
|
||||
/ pvals[: len(cand_index_set)].sum(keepdims=True),
|
||||
)
|
||||
else:
|
||||
# Sampling "n" from the geometric distribution and clipping it to
|
||||
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
|
||||
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
|
||||
n = min(np_rng.geometric(0.2), max_ngrams)
|
||||
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# Note(mingdachen):
|
||||
# Repeatedly looking for a candidate that does not exceed the
|
||||
# maximum number of predictions by trying shorter ngrams.
|
||||
while len(masked_lms) + len(index_set) > num_to_predict:
|
||||
if n == 0:
|
||||
break
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(masked_lms) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
covered_indexes.add(index)
|
||||
masked_token = None
|
||||
if masking_style == "bert":
|
||||
# 80% of the time, replace with [MASK]
|
||||
if np_rng.random() < 0.8:
|
||||
masked_token = mask_id
|
||||
else:
|
||||
# 10% of the time, keep original
|
||||
if np_rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
# 10% of the time, replace with random word
|
||||
else:
|
||||
masked_token = vocab_id_list[
|
||||
np_rng.randint(0, len(vocab_id_list))
|
||||
]
|
||||
elif masking_style == "t5":
|
||||
masked_token = mask_id
|
||||
else:
|
||||
raise ValueError("invalid value of masking style")
|
||||
|
||||
output_tokens[index] = masked_token
|
||||
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
|
||||
|
||||
masked_spans.append(
|
||||
MaskedLmInstance(
|
||||
index=index_set, label=[tokens[index] for index in index_set]
|
||||
)
|
||||
)
|
||||
|
||||
assert len(masked_lms) <= num_to_predict
|
||||
np_rng.shuffle(ngram_indexes)
|
||||
|
||||
select_indexes = set()
|
||||
if do_permutation:
|
||||
for cand_index_set in ngram_indexes:
|
||||
if len(select_indexes) >= num_to_predict:
|
||||
break
|
||||
if not cand_index_set:
|
||||
continue
|
||||
# Note(mingdachen):
|
||||
# Skip the current piece if it is covered by LM masking or previous ngrams.
|
||||
for index_set in cand_index_set[0]:
|
||||
for index in index_set:
|
||||
if index in covered_indexes or index in select_indexes:
|
||||
continue
|
||||
|
||||
n = np.random.choice(
|
||||
ngrams[: len(cand_index_set)],
|
||||
p=pvals[: len(cand_index_set)]
|
||||
/ pvals[: len(cand_index_set)].sum(keepdims=True),
|
||||
)
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
|
||||
while len(select_indexes) + len(index_set) > num_to_predict:
|
||||
if n == 0:
|
||||
break
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(select_indexes) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes or index in select_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
select_indexes.add(index)
|
||||
assert len(select_indexes) <= num_to_predict
|
||||
|
||||
select_indexes = sorted(select_indexes)
|
||||
permute_indexes = list(select_indexes)
|
||||
np_rng.shuffle(permute_indexes)
|
||||
orig_token = list(output_tokens)
|
||||
|
||||
for src_i, tgt_i in zip(select_indexes, permute_indexes):
|
||||
output_tokens[src_i] = orig_token[tgt_i]
|
||||
masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
|
||||
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||
# Sort the spans by the index of the first span
|
||||
masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
|
||||
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p.index)
|
||||
masked_lm_labels.append(p.label)
|
||||
return (
|
||||
output_tokens,
|
||||
masked_lm_positions,
|
||||
masked_lm_labels,
|
||||
token_boundary,
|
||||
masked_spans,
|
||||
)
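

# A quick numeric check of the n-gram length distribution used above: with
# max_ngrams=3 and geometric_dist=False the probabilities are proportional to
# 1/n, i.e. [6/11, 3/11, 2/11], so unigrams are masked most often.
def _example_ngram_pvals(max_ngrams=3):
    pvals = 1.0 / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)
    return pvals  # array([0.5454..., 0.2727..., 0.1818...])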
|
||||
|
||||
|
||||
def pad_and_convert_to_numpy(
|
||||
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
|
||||
):
|
||||
"""Pad sequences and convert them to numpy."""
|
||||
|
||||
# Some checks.
|
||||
num_tokens = len(tokens)
|
||||
padding_length = max_seq_length - num_tokens
|
||||
assert padding_length >= 0
|
||||
assert len(tokentypes) == num_tokens
|
||||
assert len(masked_positions) == len(masked_labels)
|
||||
|
||||
# Tokens and token types.
|
||||
filler = [pad_id] * padding_length
|
||||
tokens_np = np.array(tokens + filler, dtype=np.int64)
|
||||
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
|
||||
|
||||
# Padding mask.
|
||||
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)
|
||||
|
||||
# Labels and loss mask.
|
||||
labels = [-1] * max_seq_length
|
||||
loss_mask = [0] * max_seq_length
|
||||
for i in range(len(masked_positions)):
|
||||
assert masked_positions[i] < num_tokens
|
||||
labels[masked_positions[i]] = masked_labels[i]
|
||||
loss_mask[masked_positions[i]] = 1
|
||||
labels_np = np.array(labels, dtype=np.int64)
|
||||
loss_mask_np = np.array(loss_mask, dtype=np.int64)
|
||||
|
||||
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
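

# A worked example with a hypothetical pad_id of 0 and max_seq_length=6:
# the two masked positions receive their labels and a loss mask of 1, while
# the padded tail is masked out of both attention and the loss.
def _example_pad_and_convert():
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = \
        pad_and_convert_to_numpy(
            tokens=[101, 7, 8, 102],
            tokentypes=[0, 0, 0, 0],
            masked_positions=[1, 2],
            masked_labels=[7, 8],
            pad_id=0,
            max_seq_length=6,
        )
    # tokens_np       -> [101, 7, 8, 102, 0, 0]
    # padding_mask_np -> [1, 1, 1, 1, 0, 0]
    # labels_np       -> [-1, 7, 8, -1, -1, -1]
    # loss_mask_np    -> [0, 1, 1, 0, 0, 0]
    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np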
|
||||
|
||||
|
||||
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
|
||||
|
||||
print_rank_0(" > building dataset index ...")
|
||||
|
||||
start_time = time.time()
|
||||
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
|
||||
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
|
||||
print_rank_0(
|
||||
" > finished creating indexed dataset in {:4f} "
|
||||
"seconds".format(time.time() - start_time)
|
||||
)
|
||||
|
||||
print_rank_0(" > indexed dataset stats:")
|
||||
print_rank_0(
|
||||
" number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1)
|
||||
)
|
||||
print_rank_0(" number of sentences: {}".format(indexed_dataset.sizes.shape[0]))
|
||||
|
||||
return indexed_dataset
|
||||
|
||||
|
||||
def get_train_valid_test_split_(splits_string, size):
|
||||
"""Get dataset splits from comma or '/' separated string list."""
|
||||
|
||||
splits = []
|
||||
if splits_string.find(",") != -1:
|
||||
splits = [float(s) for s in splits_string.split(",")]
|
||||
elif splits_string.find("/") != -1:
|
||||
splits = [float(s) for s in splits_string.split("/")]
|
||||
else:
|
||||
splits = [float(splits_string)]
|
||||
while len(splits) < 3:
|
||||
splits.append(0.0)
|
||||
splits = splits[:3]
|
||||
splits_sum = sum(splits)
|
||||
assert splits_sum > 0.0
|
||||
splits = [split / splits_sum for split in splits]
|
||||
splits_index = [0]
|
||||
for index, split in enumerate(splits):
|
||||
splits_index.append(splits_index[index] + int(round(split * float(size))))
|
||||
diff = splits_index[-1] - size
|
||||
for index in range(1, len(splits_index)):
|
||||
splits_index[index] -= diff
|
||||
assert len(splits_index) == 4
|
||||
assert splits_index[-1] == size
|
||||
return splits_index
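

# A worked example: a "969,30,1" split of 1000 samples yields the boundary
# indices [0, 969, 999, 1000], i.e. train=[0, 969), valid=[969, 999),
# test=[999, 1000).
def _example_split_indices():
    assert get_train_valid_test_split_("969,30,1", 1000) == [0, 969, 999, 1000]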
|
||||
|
||||
|
||||
def get_samples_mapping(
|
||||
indexed_dataset,
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
max_seq_length,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
name,
|
||||
binary_head,
|
||||
):
|
||||
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
|
||||
|
||||
if not num_epochs:
|
||||
if not max_num_samples:
|
||||
raise ValueError("Need to specify either max_num_samples " "or num_epochs")
|
||||
num_epochs = np.iinfo(np.int32).max - 1
|
||||
if not max_num_samples:
|
||||
max_num_samples = np.iinfo(np.int64).max - 1
|
||||
|
||||
# Filename of the index mapping
|
||||
indexmap_filename = data_prefix
|
||||
indexmap_filename += "_{}_indexmap".format(name)
|
||||
if num_epochs != (np.iinfo(np.int32).max - 1):
|
||||
indexmap_filename += "_{}ep".format(num_epochs)
|
||||
if max_num_samples != (np.iinfo(np.int64).max - 1):
|
||||
indexmap_filename += "_{}mns".format(max_num_samples)
|
||||
indexmap_filename += "_{}msl".format(max_seq_length)
|
||||
indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob)
|
||||
indexmap_filename += "_{}s".format(seed)
|
||||
indexmap_filename += ".npy"
|
||||
|
||||
# Build the indexed mapping if not exist.
|
||||
if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
|
||||
print(
|
||||
" > WARNING: could not find index map file {}, building "
|
||||
"the indices on rank 0 ...".format(indexmap_filename)
|
||||
)
|
||||
|
||||
# Make sure the types match the helpers input types.
|
||||
assert indexed_dataset.doc_idx.dtype == np.int64
|
||||
assert indexed_dataset.sizes.dtype == np.int32
|
||||
|
||||
# Build samples mapping
|
||||
verbose = torch.distributed.get_rank() == 0
|
||||
start_time = time.time()
|
||||
print_rank_0(" > building sapmles index mapping for {} ...".format(name))
|
||||
# First compile and then import.
|
||||
from megatron.data import helpers
|
||||
|
||||
samples_mapping = helpers.build_mapping(
|
||||
indexed_dataset.doc_idx,
|
||||
indexed_dataset.sizes,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
max_seq_length,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
verbose,
|
||||
2 if binary_head else 1,
|
||||
)
|
||||
print_rank_0(" > done building sapmles index maping")
|
||||
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
|
||||
print_rank_0(" > saved the index mapping in {}".format(indexmap_filename))
|
||||
# Make sure all the ranks have built the mapping
|
||||
print_rank_0(
|
||||
" > elasped time to build and save samples mapping "
|
||||
"(seconds): {:4f}".format(time.time() - start_time)
|
||||
)
|
||||
# This should be a barrier but nccl barrier assumes
|
||||
# device_index=rank which is not the case for model
|
||||
# parallel case
|
||||
counts = torch.cuda.LongTensor([1])
|
||||
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
|
||||
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
|
||||
assert counts[0].item() == (
|
||||
torch.distributed.get_world_size()
|
||||
// torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())
|
||||
)
|
||||
|
||||
# Load indexed dataset.
|
||||
print_rank_0(" > loading indexed mapping from {}".format(indexmap_filename))
|
||||
start_time = time.time()
|
||||
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r")
|
||||
print_rank_0(
|
||||
" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
|
||||
)
|
||||
print_rank_0(" total number of samples: {}".format(samples_mapping.shape[0]))
|
||||
|
||||
return samples_mapping
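

# For reference, the cache filename built above for hypothetical arguments
# data_prefix="/data/corpus", name="train", num_epochs=None,
# max_num_samples=10000, max_seq_length=512, short_seq_prob=0.1, seed=1234
# resolves to:
#   /data/corpus_train_indexmap_10000mns_512msl_0.10ssp_1234s.npy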
|
@@ -0,0 +1,717 @@
|
||||
/*
|
||||
coding=utf-8
|
||||
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
/* Helper methods for fast index mapping builds */
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <math.h>
|
||||
#include <stdexcept>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <random>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace std;
|
||||
|
||||
const int32_t LONG_SENTENCE_LEN = 512;
|
||||
|
||||
|
||||
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
|
||||
py::array_t<int64_t>& dataset_sample_index,
|
||||
const py::array_t<double>& weights,
|
||||
const int32_t num_datasets,
|
||||
const int64_t size, const bool verbose) {
|
||||
/* Given multiple datasets and a weighting array, build samples
|
||||
such that it follows those weights.*/
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "> building indices for blendable datasets ..." << std::endl;
|
||||
}
|
||||
|
||||
// Get the pointer access without the checks.
|
||||
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
|
||||
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
|
||||
auto weights_ptr = weights.unchecked<1>();
|
||||
|
||||
// Initialize buffer for number of samples used for each dataset.
|
||||
int64_t current_samples[num_datasets];
|
||||
for(int64_t i = 0; i < num_datasets; ++i) {
|
||||
current_samples[i] = 0;
|
||||
}
|
||||
|
||||
// For each sample:
|
||||
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
|
||||
|
||||
// Determine where the max error in sampling is happening.
|
||||
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
|
||||
int64_t max_error_index = 0;
|
||||
double max_error = weights_ptr[0] * sample_idx_double -
|
||||
static_cast<double>(current_samples[0]);
|
||||
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
|
||||
double error = weights_ptr[dataset_idx] * sample_idx_double -
|
||||
static_cast<double>(current_samples[dataset_idx]);
|
||||
if (error > max_error) {
|
||||
max_error = error;
|
||||
max_error_index = dataset_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Populate the indices.
|
||||
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
|
||||
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
|
||||
|
||||
// Update the total samples.
|
||||
current_samples[max_error_index] += 1;
|
||||
|
||||
}
|
||||
|
||||
// print info
|
||||
if (verbose) {
|
||||
std::cout << " > sample ratios:" << std::endl;
|
||||
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
|
||||
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
|
||||
static_cast<double>(size);
|
||||
std::cout << " dataset " << dataset_idx << ", input: " <<
|
||||
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
|
||||
const py::array_t<int32_t>& doc_idx_,
|
||||
const int32_t seq_length,
|
||||
const int32_t num_epochs,
|
||||
const int64_t tokens_per_epoch) {
|
||||
/* Sample index (sample_idx) is used for gpt2 like dataset for which
|
||||
the documents are flattened and the samples are built based on this
|
||||
1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2]
|
||||
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
|
||||
starting offset in that document.*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(seq_length > 1);
|
||||
assert(num_epochs > 0);
|
||||
assert(tokens_per_epoch > 1);
|
||||
|
||||
// Remove bound checks.
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
auto doc_idx = doc_idx_.unchecked<1>();
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
|
||||
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
|
||||
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " <<
|
||||
doc_idx_.shape(0) / num_epochs << endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " sequence length: " << seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " total number of samples: " << num_samples <<
|
||||
endl << std::flush;
|
||||
|
||||
// Index into sample_idx.
|
||||
int64_t sample_index = 0;
|
||||
// Index into doc_idx.
|
||||
int64_t doc_idx_index = 0;
|
||||
// Beginning offset for each document.
|
||||
int32_t doc_offset = 0;
|
||||
// Start with first document and no offset.
|
||||
sample_idx[2 * sample_index] = doc_idx_index;
|
||||
sample_idx[2 * sample_index + 1] = doc_offset;
|
||||
++sample_index;
|
||||
|
||||
while (sample_index <= num_samples) {
|
||||
// Start with a fresh sequence.
|
||||
int32_t remaining_seq_length = seq_length + 1;
|
||||
while (remaining_seq_length != 0) {
|
||||
// Get the document length.
|
||||
auto doc_id = doc_idx[doc_idx_index];
|
||||
auto doc_length = sizes[doc_id] - doc_offset;
|
||||
// And add it to the current sequence.
|
||||
remaining_seq_length -= doc_length;
|
||||
// If we have more than a full sequence, adjust offset and set
|
||||
// remaining length to zero so we return from the while loop.
|
||||
// Note that -1 here is for the same reason we have -1 in
|
||||
// `_num_epochs` calculations.
|
||||
if (remaining_seq_length <= 0) {
|
||||
doc_offset += (remaining_seq_length + doc_length - 1);
|
||||
remaining_seq_length = 0;
|
||||
} else {
|
||||
// Otherwise, start from the beginning of the next document.
|
||||
++doc_idx_index;
|
||||
doc_offset = 0;
|
||||
}
|
||||
}
|
||||
// Record the sequence.
|
||||
sample_idx[2 * sample_index] = doc_idx_index;
|
||||
sample_idx[2 * sample_index + 1] = doc_offset;
|
||||
++sample_index;
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(sample_idx, [](void *mem_) {
|
||||
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(int32_t);
|
||||
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
|
||||
{2*byte_size, byte_size}, // C-style contiguous strides
|
||||
sample_idx, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
|
||||
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
|
||||
const int32_t max_length,
|
||||
std::mt19937& rand32_gen) {
|
||||
/* Training sample length. */
|
||||
if (short_seq_ratio == 0) {
|
||||
return max_length;
|
||||
}
|
||||
const auto random_number = rand32_gen();
|
||||
if ((random_number % short_seq_ratio) == 0) {
|
||||
return 2 + random_number % (max_length - 1);
|
||||
}
|
||||
return max_length;
|
||||
}
|
||||
|
||||
|
||||
template<typename DocIdx>
|
||||
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int32_t>& sizes_,
|
||||
const int32_t num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int32_t max_seq_length,
|
||||
const double short_seq_prob,
|
||||
const int32_t seed,
|
||||
const bool verbose,
|
||||
const int32_t min_num_sent) {
|
||||
/* Build a mapping of (start-index, end-index, sequence-length) where
|
||||
start and end index are the indices of the sentences in the sample
|
||||
and sequence-length is the target sequence length.
|
||||
*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(num_epochs > 0);
|
||||
assert(max_seq_length > 1);
|
||||
assert(short_seq_prob >= 0.0);
|
||||
assert(short_seq_prob <= 1.0);
|
||||
assert(seed > 0);
|
||||
|
||||
// Remove bound checks.
|
||||
auto docs = docs_.unchecked<1>();
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
|
||||
// For efficiency, convert probability to ratio. Note: rand() generates int.
|
||||
int32_t short_seq_ratio = 0;
|
||||
if (short_seq_prob > 0) {
|
||||
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
const auto sent_start_index = docs[0];
|
||||
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
||||
const auto num_sentences = sent_end_index - sent_start_index;
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " << docs_.shape(0) - 1 <<
|
||||
endl << std::flush;
|
||||
cout << " sentences range: [" << sent_start_index <<
|
||||
", " << sent_end_index << ")" << endl << std::flush;
|
||||
cout << " total number of sentences: " << num_sentences <<
|
||||
endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " maximum number of samples: " << max_num_samples <<
|
||||
endl << std::flush;
|
||||
cout << " maximum sequence length: " << max_seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " short sequence probability: " << short_seq_prob <<
|
||||
endl << std::flush;
|
||||
cout << " short sequence ration (1/prob): " << short_seq_ratio <<
|
||||
endl << std::flush;
|
||||
cout << " seed: " << seed << endl <<
|
||||
std::flush;
|
||||
}
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = -1;
|
||||
DocIdx* maps = NULL;
|
||||
|
||||
// Perform two iterations, in the first iteration get the size
|
||||
// and allocate memory and in the second iteration populate the map.
|
||||
bool second = false;
|
||||
for (int32_t iteration=0; iteration<2; ++iteration) {
|
||||
|
||||
// Set the seed so both iterations produce the same results.
|
||||
std::mt19937 rand32_gen(seed);
|
||||
|
||||
// Set the flag on second iteration.
|
||||
second = (iteration == 1);
|
||||
|
||||
// Counters:
|
||||
uint64_t empty_docs = 0;
|
||||
uint64_t one_sent_docs = 0;
|
||||
uint64_t long_sent_docs = 0;
|
||||
|
||||
// Current map index.
|
||||
uint64_t map_index = 0;
|
||||
|
||||
// For each epoch:
|
||||
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
|
||||
if (map_index >= max_num_samples) {
|
||||
if (verbose && (!second)) {
|
||||
cout << " reached " << max_num_samples << " samples after "
|
||||
<< epoch << " epochs ..." << endl << std::flush;
|
||||
}
|
||||
break;
|
||||
}
|
||||
// For each document:
|
||||
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
|
||||
|
||||
// Document sentences are in [sent_index_first, sent_index_last)
|
||||
const auto sent_index_first = docs[doc];
|
||||
const auto sent_index_last = docs[doc + 1];
|
||||
|
||||
// At the beginning of the document the previous index is the
|
||||
// start index.
|
||||
auto prev_start_index = sent_index_first;
|
||||
|
||||
// Remaining sentences in the document.
|
||||
auto num_remain_sent = sent_index_last - sent_index_first;
|
||||
|
||||
// Some bookkeeping
|
||||
if ((epoch == 0) && (!second)) {
|
||||
if (num_remain_sent == 0) {
|
||||
++empty_docs;
|
||||
}
|
||||
if (num_remain_sent == 1) {
|
||||
++one_sent_docs;
|
||||
}
|
||||
}
|
||||
|
||||
// Detect documents with long sentences.
|
||||
bool contains_long_sentence = false;
|
||||
if (num_remain_sent > 1) {
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
if (sizes[sent_index] > LONG_SENTENCE_LEN){
|
||||
if ((epoch == 0) && (!second)) {
|
||||
++long_sent_docs;
|
||||
}
|
||||
contains_long_sentence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have at least the minimum number of sentences and no long sentences.
|
||||
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
|
||||
// Set values.
|
||||
auto seq_len = int32_t{0};
|
||||
auto num_sent = int32_t{0};
|
||||
auto target_seq_len = get_target_sample_len(short_seq_ratio,
|
||||
max_seq_length,
|
||||
rand32_gen);
|
||||
|
||||
// Loop through sentences.
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
|
||||
// Add the size and number of sentences.
|
||||
seq_len += sizes[sent_index];
|
||||
++num_sent;
|
||||
--num_remain_sent;
|
||||
|
||||
// If we have reached the target length.
|
||||
// and if not only one sentence is left in the document.
|
||||
// and if we have at least two sentences.
|
||||
// or if we have reached the end of the document.
|
||||
if (((seq_len >= target_seq_len) &&
|
||||
(num_remain_sent > 1) &&
|
||||
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
|
||||
|
||||
// Check for overflow.
|
||||
if ((3 * map_index + 2) >
|
||||
std::numeric_limits<int64_t>::max()) {
|
||||
cout << "number of samples exceeded maximum "
|
||||
<< "allowed by type int64: "
|
||||
<< std::numeric_limits<int64_t>::max()
|
||||
<< endl;
|
||||
throw std::overflow_error("Number of samples");
|
||||
}
|
||||
|
||||
// Populate the map.
|
||||
if (second) {
|
||||
const auto map_index_0 = 3 * map_index;
|
||||
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
||||
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
||||
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
|
||||
}
|
||||
|
||||
// Update indices / counters.
|
||||
++map_index;
|
||||
prev_start_index = sent_index + 1;
|
||||
target_seq_len = get_target_sample_len(short_seq_ratio,
|
||||
max_seq_length,
|
||||
rand32_gen);
|
||||
seq_len = 0;
|
||||
num_sent = 0;
|
||||
}
|
||||
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
cout << " number of empty documents: " << empty_docs <<
|
||||
endl << std::flush;
|
||||
cout << " number of documents with one sentence: " <<
|
||||
one_sent_docs << endl << std::flush;
|
||||
cout << " number of documents with long sentences: " <<
|
||||
long_sent_docs << endl << std::flush;
|
||||
cout << " will create mapping for " << map_index <<
|
||||
" samples" << endl << std::flush;
|
||||
}
|
||||
assert(maps == NULL);
|
||||
assert(num_samples < 0);
|
||||
maps = new DocIdx[3*map_index];
|
||||
num_samples = static_cast<int64_t>(map_index);
|
||||
}
|
||||
|
||||
} // for (int iteration=0; iteration < 2; ++iteration) {
|
||||
|
||||
// Shuffle.
|
||||
// We need a 64 bit random number generator as we might have more
|
||||
// than 2 billion samples.
|
||||
std::mt19937_64 rand64_gen(seed + 1);
|
||||
for (auto i=(num_samples - 1); i > 0; --i) {
|
||||
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
||||
const auto i0 = 3 * i;
|
||||
const auto j0 = 3 * j;
|
||||
// Swap values.
|
||||
swap(maps[i0], maps[j0]);
|
||||
swap(maps[i0 + 1], maps[j0 + 1]);
|
||||
swap(maps[i0 + 2], maps[j0 + 2]);
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(maps, [](void *mem_) {
|
||||
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(DocIdx);
|
||||
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
|
||||
{3*byte_size, byte_size}, // C-style contiguous strides
|
||||
maps, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
|
||||
py::array build_mapping(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int>& sizes_,
|
||||
const int num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int max_seq_length,
|
||||
const double short_seq_prob,
|
||||
const int seed,
|
||||
const bool verbose,
|
||||
const int32_t min_num_sent) {
|
||||
|
||||
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
|
||||
if (verbose) {
|
||||
cout << " using uint64 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
|
||||
max_num_samples, max_seq_length,
|
||||
short_seq_prob, seed, verbose,
|
||||
min_num_sent);
|
||||
} else {
|
||||
if (verbose) {
|
||||
cout << " using uint32 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
|
||||
max_num_samples, max_seq_length,
|
||||
short_seq_prob, seed, verbose,
|
||||
min_num_sent);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename DocIdx>
|
||||
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int32_t>& sizes_,
|
||||
const py::array_t<int32_t>& titles_sizes_,
|
||||
const int32_t num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int32_t max_seq_length,
|
||||
const int32_t seed,
|
||||
const bool verbose,
|
||||
const bool use_one_sent_blocks) {
|
||||
/* Build a mapping of (start-index, end-index, sequence-length) where
|
||||
start and end index are the indices of the sentences in the sample
|
||||
and sequence-length is the target sequence length.
|
||||
*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(num_epochs > 0);
|
||||
assert(max_seq_length > 1);
|
||||
assert(seed > 0);
|
||||
|
||||
// Remove bound checks.
|
||||
auto docs = docs_.unchecked<1>();
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
auto titles_sizes = titles_sizes_.unchecked<1>();
|
||||
|
||||
if (verbose) {
|
||||
const auto sent_start_index = docs[0];
|
||||
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
||||
const auto num_sentences = sent_end_index - sent_start_index;
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " << docs_.shape(0) - 1 <<
|
||||
endl << std::flush;
|
||||
cout << " sentences range: [" << sent_start_index <<
|
||||
", " << sent_end_index << ")" << endl << std::flush;
|
||||
cout << " total number of sentences: " << num_sentences <<
|
||||
endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " maximum number of samples: " << max_num_samples <<
|
||||
endl << std::flush;
|
||||
cout << " maximum sequence length: " << max_seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " seed: " << seed << endl <<
|
||||
std::flush;
|
||||
}
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = -1;
|
||||
DocIdx* maps = NULL;
|
||||
|
||||
// Acceptable number of sentences per block.
|
||||
int min_num_sent = 2;
|
||||
if (use_one_sent_blocks) {
|
||||
min_num_sent = 1;
|
||||
}
|
||||
|
||||
// Perform two iterations, in the first iteration get the size
|
||||
// and allocate memory and in the second iteration populate the map.
|
||||
bool second = false;
|
||||
for (int32_t iteration=0; iteration<2; ++iteration) {
|
||||
|
||||
// Set the flag on second iteration.
|
||||
second = (iteration == 1);
|
||||
|
||||
// Current map index.
|
||||
uint64_t map_index = 0;
|
||||
|
||||
uint64_t empty_docs = 0;
|
||||
uint64_t one_sent_docs = 0;
|
||||
uint64_t long_sent_docs = 0;
|
||||
// For each epoch:
|
||||
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
|
||||
// assign every block a unique id
|
||||
int32_t block_id = 0;
|
||||
|
||||
if (map_index >= max_num_samples) {
|
||||
if (verbose && (!second)) {
|
||||
cout << " reached " << max_num_samples << " samples after "
|
||||
<< epoch << " epochs ..." << endl << std::flush;
|
||||
}
|
||||
break;
|
||||
}
|
||||
// For each document:
|
||||
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
|
||||
|
||||
// Document sentences are in [sent_index_first, sent_index_last)
|
||||
const auto sent_index_first = docs[doc];
|
||||
const auto sent_index_last = docs[doc + 1];
|
||||
const auto target_seq_len = max_seq_length - titles_sizes[doc];
|
||||
|
||||
// At the beginning of the document the previous index is the
|
||||
// start index.
|
||||
auto prev_start_index = sent_index_first;
|
||||
|
||||
// Remaining sentences in the document.
|
||||
auto num_remain_sent = sent_index_last - sent_index_first;
|
||||
|
||||
// Some bookkeeping
|
||||
if ((epoch == 0) && (!second)) {
|
||||
if (num_remain_sent == 0) {
|
||||
++empty_docs;
|
||||
}
|
||||
if (num_remain_sent == 1) {
|
||||
++one_sent_docs;
|
||||
}
|
||||
}
|
||||
// Detect documents with long sentences.
|
||||
bool contains_long_sentence = false;
|
||||
if (num_remain_sent >= min_num_sent) {
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
if (sizes[sent_index] > LONG_SENTENCE_LEN){
|
||||
if ((epoch == 0) && (!second)) {
|
||||
++long_sent_docs;
|
||||
}
|
||||
contains_long_sentence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we have enough sentences and no long sentences.
|
||||
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
|
||||
// Set values.
|
||||
auto seq_len = int32_t{0};
|
||||
auto num_sent = int32_t{0};
|
||||
|
||||
// Loop through sentences.
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
|
||||
// Add the size and number of sentences.
|
||||
seq_len += sizes[sent_index];
|
||||
++num_sent;
|
||||
--num_remain_sent;
|
||||
|
||||
// If we have reached the target length.
|
||||
// and there are an acceptable number of sentences left
|
||||
// and if we have at least the minimum number of sentences.
|
||||
// or if we have reached end of the document.
|
||||
if (((seq_len >= target_seq_len) &&
|
||||
(num_remain_sent >= min_num_sent) &&
|
||||
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
|
||||
|
||||
// Populate the map.
|
||||
if (second) {
|
||||
const auto map_index_0 = 4 * map_index;
|
||||
// Each sample has 4 items: the starting sentence index, ending sentence index,
|
||||
// the index of the document from which the block comes (used for fetching titles)
|
||||
// and the unique id of the block (used for creating block indexes)
|
||||
|
||||
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
||||
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
||||
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
|
||||
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
|
||||
}
|
||||
|
||||
// Update indices / counters.
|
||||
++map_index;
|
||||
++block_id;
|
||||
prev_start_index = sent_index + 1;
|
||||
seq_len = 0;
|
||||
num_sent = 0;
|
||||
}
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
cout << " number of empty documents: " << empty_docs <<
|
||||
endl << std::flush;
|
||||
cout << " number of documents with one sentence: " <<
|
||||
one_sent_docs << endl << std::flush;
|
||||
cout << " number of documents with long sentences: " <<
|
||||
long_sent_docs << endl << std::flush;
|
||||
cout << " will create mapping for " << map_index <<
|
||||
" samples" << endl << std::flush;
|
||||
}
|
||||
assert(maps == NULL);
|
||||
assert(num_samples < 0);
|
||||
maps = new DocIdx[4*map_index];
|
||||
num_samples = static_cast<int64_t>(map_index);
|
||||
}
|
||||
|
||||
} // for (int iteration=0; iteration < 2; ++iteration) {
|
||||
|
||||
// Shuffle.
|
||||
// We need a 64 bit random number generator as we might have more
|
||||
// than 2 billion samples.
|
||||
std::mt19937_64 rand64_gen(seed + 1);
|
||||
for (auto i=(num_samples - 1); i > 0; --i) {
|
||||
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
||||
const auto i0 = 4 * i;
|
||||
const auto j0 = 4 * j;
|
||||
// Swap values.
|
||||
swap(maps[i0], maps[j0]);
|
||||
swap(maps[i0 + 1], maps[j0 + 1]);
|
||||
swap(maps[i0 + 2], maps[j0 + 2]);
|
||||
swap(maps[i0 + 3], maps[j0 + 3]);
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(maps, [](void *mem_) {
|
||||
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(DocIdx);
|
||||
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
|
||||
{4*byte_size, byte_size}, // C-style contiguous strides
|
||||
maps, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int>& sizes_,
|
||||
const py::array_t<int>& titles_sizes_,
|
||||
const int num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int max_seq_length,
|
||||
const int seed,
|
||||
const bool verbose,
|
||||
const bool use_one_sent_blocks) {
|
||||
|
||||
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
|
||||
if (verbose) {
|
||||
cout << " using uint64 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
|
||||
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
||||
} else {
|
||||
if (verbose) {
|
||||
cout << " using uint32 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
|
||||
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(helpers, m) {
|
||||
m.def("build_mapping", &build_mapping);
|
||||
m.def("build_blocks_mapping", &build_blocks_mapping);
|
||||
m.def("build_sample_idx", &build_sample_idx);
|
||||
m.def("build_blending_indices", &build_blending_indices);
|
||||
}
|
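For orientation (this note and the sketch below are editorial additions, not part of the commit), the compiled `helpers` extension is called from Python. A minimal, hypothetical invocation of the `build_blocks_mapping` binding registered above; the import path and array contents are assumptions made for illustration only:

import numpy as np
from codegeex.megatron.data import helpers  # assumed location of the compiled extension

docs = np.array([0, 3, 5], dtype=np.int64)          # sentence-index boundaries of two documents
sizes = np.array([7, 9, 4, 6, 8], dtype=np.int32)   # token count of each sentence
titles_sizes = np.array([2, 3], dtype=np.int32)     # title length per document

blocks = helpers.build_blocks_mapping(
    docs, sizes, titles_sizes,
    3,      # num_epochs
    100,    # max_num_samples
    64,     # max_seq_length
    1234,   # seed
    True,   # verbose
    False,  # use_one_sent_blocks
)
# Each row is [starting sentence index, ending sentence index, document index, block id].
print(blocks.shape)  # (num_samples, 4)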
@ -0,0 +1,595 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.


import os
import shutil
import struct

import torch
import numpy as np

from functools import lru_cache
from itertools import accumulate
from codegeex.megatron import print_rank_0


def __best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    else:
        return np.int32


def get_available_dataset_impl():
    return ["lazy", "cached", "mmap"]


def infer_dataset_impl(path):
    if IndexedDataset.exists(path):
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return "cached"
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return "mmap"
            else:
                return None
    else:
        print(f"Dataset does not exist: {path}")
        print(
            "Path should be a basename that both .idx and .bin can be appended to get full filenames."
        )
        return None


def make_builder(out_file, impl, vocab_size=None):
    if impl == "mmap":
        return MMapIndexedDatasetBuilder(
            out_file, dtype=__best_fitting_dtype(vocab_size)
        )
    else:
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, skip_warmup=False):
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print(
            "Path should be a basename that both .idx and .bin can be appended to get full filenames."
        )
        return None
    if impl == "infer":
        impl = infer_dataset_impl(path)
    if impl == "lazy" and IndexedDataset.exists(path):
        return IndexedDataset(path)
    elif impl == "cached" and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    elif impl == "mmap" and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)
    print(f"Unknown dataset implementation: {impl}")
    return None


def dataset_exists(path, impl):
    if impl == "mmap":
        return MMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float,
    7: np.double,
    8: np.uint16,
}


def __best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    else:
        return np.int32


def make_mmap_builder(out_file, vocab_size=None):
    return MMapIndexedDatasetBuilder(
        out_file, dtype=__best_fitting_dtype(vocab_size)
    )


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    return prefix_path + ".bin"


def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""

    _HDR_MAGIC = b"TNTIDX\x00\x00"

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            version = f.read(8)
            assert struct.unpack("<Q", version) == (1,)
            code, self.element_size = struct.unpack("<QQ", f.read(16))
            self.dtype = dtypes[code]
            self._len, self.s = struct.unpack("<QQ", f.read(16))
            self.doc_count = struct.unpack("<Q", f.read(8))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), "rb", buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError("index out of range")

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if not self.data_file:
            self.read_data(self.path)
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            offsets = list(accumulate(sizes))
            sents = np.split(a, offsets[:-1])
            return sents

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):
    def __init__(self, path):
        super().__init__(path)
        self.cache = None
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx : ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx : ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work; can optimize later if necessary.
            sents = []
            for i in range(*idx.indices(len(self))):
                sents.append(self[i])
            return sents


class IndexedDatasetBuilder(object):
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float: 4,
        np.double: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]

    def add_item(self, tensor):
        bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def end_document(self):
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        with open(data_file_path(another_file), "rb") as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, "wb")
        index.write(b"TNTIDX\x00\x00")
        index.write(struct.pack("<Q", 1))
        index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
        index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack("<Q", len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()


def _warmup_mmap_file(path):
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack("<Q", len(sizes)))
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        print_rank_0(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
            )
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )


class MMapIndexedDatasetBuilder(object):
    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, tensor):
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order="C"))
        self._sizes.append(np_array.size)

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        for size in index.sizes:
            self._sizes.append(size)

        # Concatenate data
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes, self._doc_idx)
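As a quick usage note (an editorial sketch, not part of the commit): `MMapIndexedDatasetBuilder` writes the `.bin` token stream plus an `.idx` file holding the magic string, version, dtype code, per-sequence sizes, byte pointers, and document index, and `make_dataset` reads the pair back. The basename and tensors below are made up for illustration:

import torch
from codegeex.megatron.data import indexed_dataset

# Write two short sequences forming one document (hypothetical basename "my_dataset").
builder = indexed_dataset.make_builder("my_dataset.bin", impl="mmap", vocab_size=52000)
builder.add_item(torch.tensor([10, 11, 12], dtype=torch.int64))
builder.add_item(torch.tensor([13, 14], dtype=torch.int64))
builder.end_document()
builder.finalize("my_dataset.idx")

# Read the pair back through the factory function defined earlier in this file.
ds = indexed_dataset.make_dataset("my_dataset", impl="mmap", skip_warmup=True)
print(len(ds), ds[0])  # 2 sequences; the first token array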
@ -0,0 +1,332 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT prompting dataset."""

import os
import time

import numpy as np
import torch

from codegeex.megatron import mpu, print_rank_0, get_tokenizer
from codegeex.megatron.data.blendable_dataset import BlendableDataset
from codegeex.megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from codegeex.megatron.data.dataset_utils import get_train_valid_test_split_
from codegeex.megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


def build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    seq_length,
    seed,
    skip_warmup,
):
    """Build train, valid, and test datasets."""

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            seq_length,
            seed,
            skip_warmup,
        )

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(
        data_prefix, train_valid_test_num_samples
    )
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length,
            seed,
            skip_warmup,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)


def _build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    seq_length,
    seed,
    skip_warmup,
):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    assert os.path.exists(data_prefix + "_input_ids.bin"), f"Input tokens datafile not found: {data_prefix}_input_ids.bin"
    assert os.path.exists(data_prefix + "_attention_mask.bin"), f"Attention mask datafile not found: {data_prefix}_attention_mask.bin"
    assert os.path.exists(data_prefix + "_labels.bin"), f"Labels datafile not found: {data_prefix}_labels.bin"

    input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup)
    attention_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_attention_mask", data_impl, skip_warmup)
    labels_indexed_dataset = get_indexed_dataset_(data_prefix + "_labels", data_impl, skip_warmup)

    total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(" > dataset split:")

    def print_split_stats(name, index):
        print_rank_0(" {}:".format(name))
        print_rank_0(
            " document indices in [{}, {}) total of {} "
            "documents".format(
                splits[index], splits[index + 1], splits[index + 1] - splits[index]
            )
        )

    print_split_stats("train", 0)
    print_split_stats("validation", 1)
    print_split_stats("test", 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(
                start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
            )
            dataset = PromptDataset(
                name,
                data_prefix,
                documents,
                input_ids_indexed_dataset,
                attention_mask_indexed_dataset,
                labels_indexed_dataset,
                train_valid_test_num_samples[index],
                seq_length,
                seed,
            )
        return dataset

    train_dataset = build_dataset(0, "train")
    valid_dataset = build_dataset(1, "valid")
    test_dataset = build_dataset(2, "test")
    print_rank_0(f"train_dataset:{type(train_dataset)}")
    print_rank_0(f"valid_dataset:{type(valid_dataset)}")
    print_rank_0(f"test_dataset:{type(test_dataset)}")

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Build indexed dataset."""
    print_rank_0(" > building dataset index ...")

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    print_rank_0(
        " > finished creating indexed dataset in {:4f} "
        "seconds".format(time.time() - start_time)
    )
    print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0]))

    return indexed_dataset


class PromptDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        name,
        data_prefix,
        documents,
        input_ids_indexed_dataset,
        attention_mask_index_dataset,
        labels_indexed_dataset,
        num_samples,
        seq_length,
        seed,
    ):
        """
        Args:
            name: name of the dataset.
            data_prefix: prefix of the data.
            documents: list of document indices.
            input_ids_indexed_dataset: indexed dataset for prompts.
            attention_mask_index_dataset: indexed dataset for attention masks.
            labels_indexed_dataset: indexed dataset for labels.
            num_samples: number of samples to draw from the indexed dataset.
            seq_length: sequence length.
            seed: seed for random number generator.
        """

        self.name = name
        self.input_ids_indexed_dataset = input_ids_indexed_dataset
        self.attention_mask_index_dataset = attention_mask_index_dataset
        self.labels_indexed_dataset = labels_indexed_dataset
        self.seq_length = seq_length
        self.eod_token = get_tokenizer().eod

        # Checks
        assert np.min(documents) >= 0
        assert np.max(documents) < input_ids_indexed_dataset.sizes.shape[0]
        assert input_ids_indexed_dataset.sizes.shape[0] == attention_mask_index_dataset.sizes.shape[0]
        assert attention_mask_index_dataset.sizes.shape[0] == labels_indexed_dataset.sizes.shape[0]

        # Build index mappings.
        self.doc_idx = _build_index_mappings(
            self.name,
            data_prefix,
            documents,
            self.input_ids_indexed_dataset.sizes,
            num_samples,
            seq_length,
            seed,
        )

    def __len__(self):
        # One sample per entry in the doc-idx mapping built in __init__.
        return self.doc_idx.shape[0]

    def __getitem__(self, idx):
        # Get the document index for this sample.
        doc_idx = self.doc_idx[idx]
        doc_idx = int(doc_idx)  # NumPy int => Python int

        input_ids = self.input_ids_indexed_dataset[doc_idx]
        # print_rank_0(f"input_ids={input_ids}")
        attention_mask = self.attention_mask_index_dataset[doc_idx]
        labels = self.labels_indexed_dataset[doc_idx]

        res = {
            "input_ids": np.array(input_ids, dtype=np.int64),
            "attention_mask": np.array(attention_mask, dtype=np.int64),
            "labels": np.array(labels, dtype=np.int64),
        }

        return res


def _build_index_mappings(
    name, data_prefix, documents, sizes, num_samples, seq_length, seed,
):
    """Build index mappings.
    We only have to build doc-idx in the prompt dataset.

    Args:
        name: name of the dataset.
        data_prefix: prefix of the data.
        documents: list of document indices.
        sizes: sizes of the indexed dataset.
        num_samples: number of samples to draw from the indexed dataset.
        seq_length: sequence length.
        seed: seed for random number generator.
    """
    num_epochs = _num_epochs(documents.shape[0], num_samples)
    np_rng = np.random.RandomState(seed=seed)

    _filename = data_prefix
    _filename += "_{}_indexmap".format(name)
    _filename += "_{}ns".format(num_samples)
    _filename += "_{}sl".format(seq_length)
    _filename += "_{}s".format(seed)
    doc_idx_filename = _filename + "_doc_idx.npy"

    if torch.distributed.get_rank() == 0:
        if not os.path.isfile(doc_idx_filename):
            print_rank_0(
                " > WARNING: could not find index map files, building "
                "the indices on rank 0 ..."
            )

            start_time = time.time()
            doc_idx = _build_doc_idx(documents, num_epochs, np_rng, False)[:num_samples]
            np.save(doc_idx_filename, doc_idx, allow_pickle=True)

            print_rank_0(
                " > elapsed time to build and save doc-idx mapping "
                "(seconds): {:4f}".format(time.time() - start_time)
            )

    # This should be a barrier, but the NCCL barrier assumes
    # device_index == rank, which does not hold in the
    # model-parallel case.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size()
        // torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())
    )
    # Load mappings.
    start_time = time.time()
    print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename))
    doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r")
    print_rank_0(" total number of samples: {}".format(doc_idx.shape[0]))
    print_rank_0(" total number of epochs: {}".format(num_epochs))

    return doc_idx


def _num_epochs(samples_per_epoch, num_samples):
    """Calculate the number of epochs needed to cover the requested number of samples."""
    return int(np.ceil(num_samples / samples_per_epoch))


def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
    """Build an array with length = number-of-epochs * number-of-documents.
    Each index is mapped to a corresponding document."""
    if not separate_last_epoch or num_epochs == 1:
        doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1]
        doc_idx[:] = documents
        doc_idx = doc_idx.reshape(-1)
        doc_idx = doc_idx.astype(np.int32)
        np_rng.shuffle(doc_idx)
        return doc_idx

    doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
    doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
    return np.concatenate((doc_idx_first, doc_idx_last))
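For intuition (an editorial sketch, not part of the commit), `_num_epochs` rounds up so that repeating the split's documents covers the requested number of samples, and `_build_doc_idx` tiles and shuffles the document ids before the caller truncates to `num_samples`. A tiny worked example with made-up numbers:

import numpy as np

documents = np.arange(4, dtype=np.int32)                      # 4 documents in this split
num_samples = 10
num_epochs = int(np.ceil(num_samples / documents.shape[0]))   # ceil(10 / 4) = 3

rng = np.random.RandomState(seed=1234)
doc_idx = np.tile(documents, num_epochs).astype(np.int32)     # 12 entries, 3 copies of each id
rng.shuffle(doc_idx)                                          # mirrors _build_doc_idx(..., separate_last_epoch=False)
doc_idx = doc_idx[:num_samples]                               # keep exactly num_samples entries
print(doc_idx.shape)                                          # (10,) -- one document id per sample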