import glob
import multiprocessing
import os
from time import perf_counter
from typing import Dict, Iterable, List, Optional, Union

import fire
import torch
from black import FileMode, format_str
from tqdm.auto import tqdm

from codegeex.data.data_utils import LANGUAGE_TAG, stream_jsonl
from codegeex.data.processor import PromptDatasetProcessor
from codegeex.data.types import PromptDataset, PromptSample
from codegeex.megatron.data.indexed_dataset import make_mmap_builder
from codegeex.tokenizer import CodeGeeXTokenizer


def try_format_code(code: str) -> str:
    # Auto-format to PEP8 style with black: tabs become 4 spaces, whitespace
    # is added around some special symbols, and lines are re-wrapped (up to
    # 200 characters here). If black cannot parse the code, return it as-is.
    try:
        res = format_str(code, mode=FileMode(line_length=200))
    except Exception as e:
        res = code
        print(e)
        print("Wrong python format: {}".format(code))
    return res
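

# Example: black normalizes spacing and indentation, e.g.
#     try_format_code("def f(a,b):\n\treturn a+b\n")
# should return "def f(a, b):\n    return a + b\n", while source that black
# cannot parse comes back unchanged.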


def load_pretrain_dataset(dataset_path: Union[str, List[str]]) -> Iterable[Dict]:
    # Stream records from one or more .jsonl/.gz files, or from directories
    # containing such files.
    if isinstance(dataset_path, str):
        dataset_path = [dataset_path]

    for p in dataset_path:
        if not os.path.isdir(p):
            if p.endswith(".gz") or p.endswith(".jsonl"):
                print(f"loading from {p}")
                yield from stream_jsonl(p)
        else:
            p_list = glob.glob(os.path.join(p, "*"))
            for p_ in p_list:
                if p_.endswith(".gz") or p_.endswith(".jsonl"):
                    print(f"loading from {p_}")
                    yield from stream_jsonl(p_)
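

# Each JSONL record is expected to be a dict with at least a "code" field,
# plus a "prompt" field when mode != "pretrain", e.g.
#     {"prompt": "# add two numbers\n", "code": "def add(a, b):\n    return a + b\n"}
# (an illustrative record, inferred from how process_sample reads the keys).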


def process_sample(
    sample: Dict,
    language: Optional[str] = None,
    mode: str = "pretrain",
) -> Iterable[PromptSample]:
    # Pretraining samples carry no separate prompt; other modes read it from
    # the record.
    if mode == "pretrain":
        prompt = ""
    else:
        prompt = sample["prompt"]

    try:
        # Prepend the language tag when the language is known.
        if language is not None and language in LANGUAGE_TAG:
            code = LANGUAGE_TAG[language] + "\n" + sample["code"]
        else:
            code = sample["code"]
    except Exception as e:
        print(e)
        print("The key 'code' is missing in the data. Aborting.")
        exit(1)

    yield PromptSample(prompt, code)


def generate_prompt_samples(
    dataset: Iterable[Dict],
    language: Optional[str] = None,
    mode: str = "pretrain",
) -> PromptDataset:
    for sample in dataset:
        yield from process_sample(sample, language, mode)
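

# The two generators compose into a lazy pipeline, e.g. (path illustrative):
#     samples = generate_prompt_samples(
#         load_pretrain_dataset("data/train.jsonl"),
#         language="python",
#     )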


def main(
    tokenizer_path: str,
    dataset_path: Union[str, List[str]],
    output_prefix: str,
    language: Optional[str] = None,
    mode: str = "pretrain",
    discard_overlong: bool = False,
    sliding_stride: int = 200,
    num_workers: int = 32,
    seq_len: int = 2048,
):
    DATA_KEYS = ["input_ids", "attention_mask", "labels"]

    # Create the output directory.
    os.makedirs(os.path.dirname(output_prefix), exist_ok=True)

    tokenizer = CodeGeeXTokenizer(tokenizer_path=tokenizer_path)
    pad_token_id = tokenizer.eos_token_id

    dataset = load_pretrain_dataset(dataset_path)
    prompt_dataset = generate_prompt_samples(dataset, language=language, mode=mode)

    if num_workers == 0:
        num_workers = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(num_workers)
    output_bin_files = {}
    output_idx_files = {}
    builders = {}

    for key in DATA_KEYS:
        output_bin_files[key] = "{}_{}.bin".format(output_prefix, key)
        output_idx_files[key] = "{}_{}.idx".format(output_prefix, key)
        builders[key] = make_mmap_builder(
            output_bin_files[key],
            vocab_size=None,  # magic number, should change it
        )
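
    # The loop above gives each field its own memory-mapped dataset pair: a
    # .bin file holding the packed token tensors and a .idx file with their
    # offsets (Megatron-LM's indexed-dataset format, an assumption based on
    # where make_mmap_builder is imported from).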

    # NOTE: we use seq_len + 1 instead of seq_len, since the input tokens
    # will be shifted by one for next-token prediction.
    processor = PromptDatasetProcessor(
        tokenize=tokenizer.encode_code,
        pad_token=pad_token_id,
        max_seq_len=seq_len + 1,
        discard_overlong=discard_overlong,
        sliding_stride=sliding_stride,
        eod_token=pad_token_id,
    )

    processor.start_time = perf_counter()
    doc_iter = pool.imap_unordered(
        processor.process_sample_strict,
        prompt_dataset,
        chunksize=20,
    )

    for doc_idx, docs in tqdm(enumerate(doc_iter, start=1)):
        processor.doc_processed += 1
        for doc in docs:
            processor.doc_generated += 1
            for key in DATA_KEYS:
                builders[key].add_item(torch.IntTensor(doc[key]))

    for key in DATA_KEYS:
        builders[key].finalize(output_idx_files[key])


if __name__ == "__main__":
    fire.Fire(main)
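
# Example invocation, with fire exposing main's parameters as CLI flags
# (script name and paths are placeholders):
#     python process_pretrain_dataset.py \
#         --tokenizer_path ./tokenizer \
#         --dataset_path ./data/train.jsonl \
#         --output_prefix ./output/train \
#         --language python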