From 02241e057c05513e1988cd99aa339654dcfd37e4 Mon Sep 17 00:00:00 2001 From: Stanislas0 Date: Sun, 4 Dec 2022 01:51:41 +0000 Subject: [PATCH] Add megatron data processing --- codegeex/benchmark/utils.py | 2 +- codegeex/data/{utils.py => data_utils.py} | 0 codegeex/data/indexed_dataset.py | 204 ------ codegeex/data/process_pretrain_dataset.py | 6 +- codegeex/data/processor.py | 2 +- codegeex/megatron/data/__init__.py | 0 codegeex/megatron/data/blendable_dataset.py | 69 ++ codegeex/megatron/data/data_samplers.py | 185 +++++ codegeex/megatron/data/dataset_utils.py | 566 +++++++++++++++ codegeex/megatron/data/helpers.cpp | 717 ++++++++++++++++++++ codegeex/megatron/data/indexed_dataset.py | 595 ++++++++++++++++ codegeex/megatron/data/prompt_dataset.py | 332 +++++++++ 12 files changed, 2469 insertions(+), 209 deletions(-) rename codegeex/data/{utils.py => data_utils.py} (100%) delete mode 100644 codegeex/data/indexed_dataset.py create mode 100644 codegeex/megatron/data/__init__.py create mode 100644 codegeex/megatron/data/blendable_dataset.py create mode 100644 codegeex/megatron/data/data_samplers.py create mode 100644 codegeex/megatron/data/dataset_utils.py create mode 100644 codegeex/megatron/data/helpers.cpp create mode 100644 codegeex/megatron/data/indexed_dataset.py create mode 100644 codegeex/megatron/data/prompt_dataset.py diff --git a/codegeex/benchmark/utils.py b/codegeex/benchmark/utils.py index 6083221..de47797 100644 --- a/codegeex/benchmark/utils.py +++ b/codegeex/benchmark/utils.py @@ -1,6 +1,6 @@ import os from typing import * -from codegeex.data.utils import stream_jsonl, LANGUAGE_TAG +from codegeex.data.data_utils import stream_jsonl, LANGUAGE_TAG IMPORT_HELPER = { diff --git a/codegeex/data/utils.py b/codegeex/data/data_utils.py similarity index 100% rename from codegeex/data/utils.py rename to codegeex/data/data_utils.py diff --git a/codegeex/data/indexed_dataset.py b/codegeex/data/indexed_dataset.py deleted file mode 100644 index 87cd6ba..0000000 --- a/codegeex/data/indexed_dataset.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - - -# copied from fairseq/fairseq/data/indexed_dataset.py -# Removed IndexedRawTextDataset since it relied on Fairseq dictionary -# other slight modifications to remove fairseq dependencies -# Added document index to index file and made it accessible. -# An empty sentence no longer separates documents. - - -# copied from Megatron -# Used for building mmap datasets. -from functools import lru_cache -import struct -import numpy as np -import shutil - -__all__ = ["make_mmap_builder"] - -dtypes = { - 1: np.uint8, - 2: np.int8, - 3: np.int16, - 4: np.int32, - 5: np.int64, - 6: np.float, - 7: np.double, - 8: np.uint16, -} - - -def __best_fitting_dtype(vocab_size=None): - if vocab_size is not None and vocab_size < 65500: - return np.uint16 - else: - return np.int32 - - -def make_mmap_builder(out_file, vocab_size=None): - return MMapIndexedDatasetBuilder( - out_file, dtype=__best_fitting_dtype(vocab_size) - ) - - -def code(dtype): - for k in dtypes.keys(): - if dtypes[k] == dtype: - return k - raise ValueError(dtype) - - -def index_file_path(prefix_path): - return prefix_path + ".idx" - - -def data_file_path(prefix_path): - return prefix_path + ".bin" - - -class MMapIndexedDataset: - class Index(object): - _HDR_MAGIC = b"MMIDIDX\x00\x00" - - @classmethod - def writer(cls, path, dtype): - class _Writer(object): - def __enter__(self): - self._file = open(path, "wb") - - self._file.write(cls._HDR_MAGIC) - self._file.write(struct.pack(" Dict: p_list = glob.glob(p + "/*") for p_ in p_list: if p_.endswith(".gz") or p_.endswith(".jsonl"): - print(f"loading from {p}") + print(f"loading from {p_}") yield from stream_jsonl(p_) diff --git a/codegeex/data/processor.py b/codegeex/data/processor.py index a650acd..4ea507f 100644 --- a/codegeex/data/processor.py +++ b/codegeex/data/processor.py @@ -1,7 +1,7 @@ from typing import * from time import perf_counter -from codegeex.data.utils import sliding_window +from codegeex.data.data_utils import sliding_window from codegeex.data.types import PromptSample, LabelSample diff --git a/codegeex/megatron/data/__init__.py b/codegeex/megatron/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/codegeex/megatron/data/blendable_dataset.py b/codegeex/megatron/data/blendable_dataset.py new file mode 100644 index 0000000..9b176f9 --- /dev/null +++ b/codegeex/megatron/data/blendable_dataset.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blendable dataset.""" + +import time +import torch +import numpy as np + +from codegeex.megatron import print_rank_0 + + +class BlendableDataset(torch.utils.data.Dataset): + def __init__(self, datasets, weights): + + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = 0 + for dataset in self.datasets: + self.size += len(dataset) + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indecies. + start_time = time.time() + assert num_datasets < 255 + self.dataset_index = np.zeros(self.size, dtype=np.uint8) + self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from megatron.data import helpers + + helpers.build_blending_indices( + self.dataset_index, + self.dataset_sample_index, + weights, + num_datasets, + self.size, + torch.distributed.get_rank() == 0, + ) + print_rank_0( + "> elapsed time for building blendable dataset indices: " + "{:.2f} (sec)".format(time.time() - start_time) + ) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return self.datasets[dataset_idx][sample_idx] diff --git a/codegeex/megatron/data/data_samplers.py b/codegeex/megatron/data/data_samplers.py new file mode 100644 index 0000000..2caffd1 --- /dev/null +++ b/codegeex/megatron/data/data_samplers.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataloaders.""" + + +import torch +from codegeex.megatron import get_args +from codegeex.megatron import mpu + + +def build_pretraining_data_loader(dataset, consumed_samples): + """Buld dataloader given an input dataset.""" + + if dataset is None: + return None + args = get_args() + + # Megatron sampler + if args.dataloader_type == "single": + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + ) + elif args.dataloader_type == "cyclic": + batch_sampler = MegatronPretrainingRandomSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + ) + else: + raise Exception( + "{} dataloader type is not supported.".format(args.dataloader_type) + ) + + # Torch dataloader. + return torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=args.num_workers, + pin_memory=True, + ) + + +class MegatronPretrainingSampler: + def __init__( + self, + total_samples, + consumed_samples, + micro_batch_size, + data_parallel_rank, + data_parallel_size, + drop_last=True, + ): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = ( + self.micro_batch_size * data_parallel_size + ) + self.drop_last = drop_last + + # Sanity checks. + assert self.total_samples > 0, "no sample to consume: {}".format( + self.total_samples + ) + assert ( + self.consumed_samples < self.total_samples + ), "no samples left to consume: {}, {}".format( + self.consumed_samples, self.total_samples + ) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert ( + self.data_parallel_rank < data_parallel_size + ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( + self.data_parallel_rank, data_parallel_size + ) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +class MegatronPretrainingRandomSampler: + def __init__( + self, + total_samples, + consumed_samples, + micro_batch_size, + data_parallel_rank, + data_parallel_size, + ): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + self.micro_batch_times_data_parallel_size = ( + self.micro_batch_size * data_parallel_size + ) + self.last_batch_size = ( + self.total_samples % self.micro_batch_times_data_parallel_size + ) + + # Sanity checks. + assert self.total_samples > 0, "no sample to consume: {}".format( + self.total_samples + ) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert ( + self.data_parallel_rank < data_parallel_size + ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( + self.data_parallel_rank, data_parallel_size + ) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + # data sharding and random sampling + bucket_size = ( + self.total_samples // self.micro_batch_times_data_parallel_size + ) * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + + batch = [] + # Last batch if not complete will be dropped. + for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/codegeex/megatron/data/dataset_utils.py b/codegeex/megatron/data/dataset_utils.py new file mode 100644 index 0000000..d7676ca --- /dev/null +++ b/codegeex/megatron/data/dataset_utils.py @@ -0,0 +1,566 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, and NVIDIA. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Most of the code here has been copied from: +# https://github.com/google-research/albert/blob/master/create_pretraining_data.py +# with some modifications. + +import math +import os +import time +import collections + +import numpy as np +import torch + +from codegeex.megatron import mpu, print_rank_0 +from codegeex.megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + + +def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): + + # The data prefix should be in the format of: + # weight-1, data-prefix-1, weight-2, data-prefix-2, .. + assert len(data_prefix) % 2 == 0 + num_datasets = len(data_prefix) // 2 + weights = [0] * num_datasets + prefixes = [0] * num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2 * i]) + prefixes[i] = (data_prefix[2 * i + 1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + datasets_train_valid_test_num_samples = [] + for weight in weights: + datasets_train_valid_test_num_samples.append( + [ + int(math.ceil(val * weight * 1.005)) + for val in train_valid_test_num_samples + ] + ) + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def compile_helper(): + """Compile helper function ar runtime. Make sure this + is invoked on a single process.""" + import os + import subprocess + + path = os.path.abspath(os.path.dirname(__file__)) + ret = subprocess.run(["make", "-C", path]) + if ret.returncode != 0: + print("Making C++ dataset helpers module failed, exiting.") + import sys + + sys.exit(1) + + +def get_a_and_b_segments(sample, np_rng): + """Divide sample into a and b segments.""" + + # Number of sentences in the sample. + n_sentences = len(sample) + # Make sure we always have two sentences. + assert n_sentences > 1, "make sure each sample has at least two sentences." + + # First part: + # `a_end` is how many sentences go into the `A`. + a_end = 1 + if n_sentences >= 3: + # Note that randin in numpy is exclusive. + a_end = np_rng.randint(1, n_sentences) + tokens_a = [] + for j in range(a_end): + tokens_a.extend(sample[j]) + + # Second part: + tokens_b = [] + for j in range(a_end, n_sentences): + tokens_b.extend(sample[j]) + + # Random next: + is_next_random = False + if np_rng.random() < 0.5: + is_next_random = True + tokens_a, tokens_b = tokens_b, tokens_a + + return tokens_a, tokens_b, is_next_random + + +def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): + """Truncates a pair of sequences to a maximum sequence length.""" + # print(len_a, len_b, max_num_tokens) + assert len_a > 0 + if len_a + len_b <= max_num_tokens: + return False + while len_a + len_b > max_num_tokens: + if len_a > len_b: + len_a -= 1 + tokens = tokens_a + else: + len_b -= 1 + tokens = tokens_b + if np_rng.random() < 0.5: + del tokens[0] + else: + tokens.pop() + return True + + +def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): + """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" + + tokens = [] + tokentypes = [] + # [CLS]. + tokens.append(cls_id) + tokentypes.append(0) + # Segment A. + for token in tokens_a: + tokens.append(token) + tokentypes.append(0) + # [SEP]. + tokens.append(sep_id) + tokentypes.append(0) + # Segment B. + for token in tokens_b: + tokens.append(token) + tokentypes.append(1) + if tokens_b: + # [SEP]. + tokens.append(sep_id) + tokentypes.append(1) + + return tokens, tokentypes + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) + + +def is_start_piece(piece): + """Check if the current word piece is the starting piece (BERT).""" + # When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + return not piece.startswith("##") + + +def create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False, + geometric_dist=False, + masking_style="bert", +): + """Creates the predictions for the masked LM objective. + Note: Tokens here are vocab ids and not text tokens.""" + + cand_indexes = [] + # Note(mingdachen): We create a list for recording if the piece is + # the starting piece of current token, where 1 means true, so that + # on-the-fly whole word masking is possible. + token_boundary = [0] * len(tokens) + + for (i, token) in enumerate(tokens): + if token == cls_id or token == sep_id: + token_boundary[i] = 1 + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if ( + do_whole_word_mask + and len(cand_indexes) >= 1 + and not is_start_piece(vocab_id_to_token_dict[token]) + ): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + if is_start_piece(vocab_id_to_token_dict[token]): + token_boundary[i] = 1 + + output_tokens = list(tokens) + + masked_lm_positions = [] + masked_lm_labels = [] + + if masked_lm_prob == 0: + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) + + num_to_predict = min( + max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))) + ) + + ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) + if not geometric_dist: + # Note(mingdachen): + # By default, we set the probilities to favor shorter ngram sequences. + pvals = 1.0 / np.arange(1, max_ngrams + 1) + pvals /= pvals.sum(keepdims=True) + if favor_longer_ngram: + pvals = pvals[::-1] + + ngram_indexes = [] + for idx in range(len(cand_indexes)): + ngram_index = [] + for n in ngrams: + ngram_index.append(cand_indexes[idx : idx + n]) + ngram_indexes.append(ngram_index) + + np_rng.shuffle(ngram_indexes) + + (masked_lms, masked_spans) = ([], []) + covered_indexes = set() + for cand_index_set in ngram_indexes: + if len(masked_lms) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes: + continue + + if not geometric_dist: + n = np_rng.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] + / pvals[: len(cand_index_set)].sum(keepdims=True), + ) + else: + # Sampling "n" from the geometric distribution and clipping it to + # the max_ngrams. Using p=0.2 default from the SpanBERT paper + # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1) + n = min(np_rng.geometric(0.2), max_ngrams) + + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # Note(mingdachen): + # Repeatedly looking for a candidate that does not exceed the + # maximum number of predictions by trying shorter ngrams. + while len(masked_lms) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + masked_token = None + if masking_style == "bert": + # 80% of the time, replace with [MASK] + if np_rng.random() < 0.8: + masked_token = mask_id + else: + # 10% of the time, keep original + if np_rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_id_list[ + np_rng.randint(0, len(vocab_id_list)) + ] + elif masking_style == "t5": + masked_token = mask_id + else: + raise ValueError("invalid value of masking style") + + output_tokens[index] = masked_token + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + + masked_spans.append( + MaskedLmInstance( + index=index_set, label=[tokens[index] for index in index_set] + ) + ) + + assert len(masked_lms) <= num_to_predict + np_rng.shuffle(ngram_indexes) + + select_indexes = set() + if do_permutation: + for cand_index_set in ngram_indexes: + if len(select_indexes) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes or index in select_indexes: + continue + + n = np.random.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] + / pvals[: len(cand_index_set)].sum(keepdims=True), + ) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + + while len(select_indexes) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(select_indexes) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes or index in select_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + select_indexes.add(index) + assert len(select_indexes) <= num_to_predict + + select_indexes = sorted(select_indexes) + permute_indexes = list(select_indexes) + np_rng.shuffle(permute_indexes) + orig_token = list(output_tokens) + + for src_i, tgt_i in zip(select_indexes, permute_indexes): + output_tokens[src_i] = orig_token[tgt_i] + masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + # Sort the spans by the index of the first span + masked_spans = sorted(masked_spans, key=lambda x: x.index[0]) + + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + return ( + output_tokens, + masked_lm_positions, + masked_lm_labels, + token_boundary, + masked_spans, + ) + + +def pad_and_convert_to_numpy( + tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length +): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + + print_rank_0(" > building dataset index ...") + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] + print_rank_0( + " > finished creating indexed dataset in {:4f} " + "seconds".format(time.time() - start_time) + ) + + print_rank_0(" > indexed dataset stats:") + print_rank_0( + " number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1) + ) + print_rank_0(" number of sentences: {}".format(indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.0) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index + + +def get_samples_mapping( + indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + name, + binary_head, +): + """Get a list that maps a sample index to a starting sentence index, end sentence index, and length""" + + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += "_{}_indexmap".format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += "_{}ep".format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += "_{}mns".format(max_num_samples) + indexmap_filename += "_{}msl".format(max_seq_length) + indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob) + indexmap_filename += "_{}s".format(seed) + indexmap_filename += ".npy" + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename): + print( + " > WARNING: could not find index map file {}, building " + "the indices on rank 0 ...".format(indexmap_filename) + ) + + # Make sure the types match the helpers input types. + assert indexed_dataset.doc_idx.dtype == np.int64 + assert indexed_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + print_rank_0(" > building sapmles index mapping for {} ...".format(name)) + # First compile and then import. + from megatron.data import helpers + + samples_mapping = helpers.build_mapping( + indexed_dataset.doc_idx, + indexed_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose, + 2 if binary_head else 1, + ) + print_rank_0(" > done building sapmles index maping") + np.save(indexmap_filename, samples_mapping, allow_pickle=True) + print_rank_0(" > saved the index mapping in {}".format(indexmap_filename)) + # Make sure all the ranks have built the mapping + print_rank_0( + " > elasped time to build and save samples mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()) + ) + + # Load indexed dataset. + print_rank_0(" > loading indexed mapping from {}".format(indexmap_filename)) + start_time = time.time() + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r") + print_rank_0( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + print_rank_0(" total number of samples: {}".format(samples_mapping.shape[0])) + + return samples_mapping diff --git a/codegeex/megatron/data/helpers.cpp b/codegeex/megatron/data/helpers.cpp new file mode 100644 index 0000000..e45926a --- /dev/null +++ b/codegeex/megatron/data/helpers.cpp @@ -0,0 +1,717 @@ +/* + coding=utf-8 + Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + int64_t current_samples[num_datasets]; + for(int64_t i = 0; i < num_datasets; ++i) { + current_samples[i] = 0; + } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. + auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; + + // Update the total samples. + current_samples[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(current_samples[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/codegeex/megatron/data/indexed_dataset.py b/codegeex/megatron/data/indexed_dataset.py new file mode 100644 index 0000000..fe8d2a9 --- /dev/null +++ b/codegeex/megatron/data/indexed_dataset.py @@ -0,0 +1,595 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + + +import os +import shutil +import struct + +import torch +import numpy as np + +from functools import lru_cache +from itertools import accumulate +from codegeex.megatron import print_rank_0 + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ["lazy", "cached", "mmap"] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return "cached" + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return "mmap" + else: + return None + else: + print(f"Dataset does not exist: {path}") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) + return None + if impl == "infer": + impl = infer_dataset_impl(path) + if impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == "cached" and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == "mmap": + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float, + 7: np.double, + 8: np.uint16, +} + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def make_mmap_builder(out_file, vocab_size=None): + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." + ) + version = f.read(8) + assert struct.unpack("= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return os.path.exists(index_file_path(path)) and os.path.exists( + data_file_path(path) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float: 4, + np.double: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack(" dataset split:") + + def print_split_stats(name, index): + print_rank_0(" {}:".format(name)) + print_rank_0( + " document indices in [{}, {}) total of {} " + "documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange( + start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 + ) + dataset = PromptDataset( + name, + data_prefix, + documents, + input_ids_indexed_dataset, + attention_mask_indexed_dataset, + labels_indexed_dataset, + train_valid_test_num_samples[index], + seq_length, + seed, + ) + return dataset + + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") + print_rank_0(f"train_dataset:{type(train_dataset)}") + print_rank_0(f"valid_dataset:{type(valid_dataset)}") + print_rank_0(f"test_dataset:{type(test_dataset)}") + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(" > building dataset index ...") + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + print_rank_0( + " > finished creating indexed dataset in {:4f} " + "seconds".format(time.time() - start_time) + ) + print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +class PromptDataset(torch.utils.data.Dataset): + def __init__( + self, + name, + data_prefix, + documents, + input_ids_indexed_dataset, + attention_mask_index_dataset, + labels_indexed_dataset, + num_samples, + seq_length, + seed, + ): + """ + Args: + name: name of the dataset. + data_prefix: prefix of the data. + documents: list of document indices. + input_ids_indexed_dataset: indexed dataset for prompts. + attention_mask_index_dataset: indexed dataset for text. + labels_indexed_dataset: indexed dataset for labels. + num_samples: number of samples to draw from the indexed dataset. + seq_length: sequence length. + seed: seed for random number generator. + """ + + self.name = name + self.input_ids_indexed_dataset = input_ids_indexed_dataset + self.attention_mask_index_dataset = attention_mask_index_dataset + self.labels_indexed_dataset = labels_indexed_dataset + self.seq_length = seq_length + self.eod_token = get_tokenizer().eod + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < input_ids_indexed_dataset.sizes.shape[0] + assert input_ids_indexed_dataset.sizes.shape[0] == attention_mask_index_dataset.sizes.shape[0] + assert attention_mask_index_dataset.sizes.shape[0] == labels_indexed_dataset.sizes.shape[0] + + # Build index mappings. + self.doc_idx = _build_index_mappings( + self.name, + data_prefix, + documents, + self.input_ids_indexed_dataset.sizes, + num_samples, + seq_length, + seed, + ) + + def __len__(self): + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.doc_idx.shape[0] + + def __getitem__(self, idx): + # get the doc index + doc_idx = self.doc_idx[idx] + doc_idx = int(doc_idx) # NumPy int => Python int + + input_ids = self.input_ids_indexed_dataset[doc_idx] + # print_rank_0(f"input_ids={input_ids}") + attention_mask = self.attention_mask_index_dataset[doc_idx] + labels = self.labels_indexed_dataset[doc_idx] + + res = { + "input_ids": np.array(input_ids, dtype=np.int64), + "attention_mask": np.array(attention_mask, dtype=np.int64), + "labels": np.array(labels, dtype=np.int64), + } + + return res + + +def _build_index_mappings( + name, data_prefix, documents, sizes, num_samples, seq_length, seed, +): + """Build index mappings. + We only have to build doc-idx in prompt dataset. + + Args: + name: name of the dataset. + data_prefix: prefix of the data. + documents: list of document indices. + sizes: sizes of the indexed dataset. + num_samples: number of samples to draw from the indexed dataset. + seq_length: sequence length. + seed: seed for random number generator. + """ + num_epochs = _num_epochs(documents.shape[0], num_samples) + np_rng = np.random.RandomState(seed=seed) + + _filename = data_prefix + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + doc_idx_filename = _filename + "_doc_idx.npy" + + if torch.distributed.get_rank() == 0: + if not os.path.isfile(doc_idx_filename): + print_rank_0( + " > WARNING: could not find index map files, building " + "the indices on rank 0 ..." + ) + + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, False)[:num_samples] + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + + print_rank_0( + " > elasped time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()) + ) + # Load mappings. + start_time = time.time() + print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" total number of samples: {}".format(doc_idx.shape[0])) + print_rank_0(" total number of epochs: {}".format(num_epochs)) + + return doc_idx + + +def _num_epochs(samples_per_epoch, num_samples): + """Calculate the epoch needed for so many sample.""" + return int(np.ceil(num_samples / samples_per_epoch)) + + +def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): + """Build an array with length = number-of-epochs * number-of-dcuments. + Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) \ No newline at end of file