# Dataset processors for fine-tuning data: PromptDatasetProcessor builds
# padded prompt+code language-model examples; LabelDatasetProcessor builds
# padded prompt+label classification examples.
from typing import *
from time import perf_counter
from codegeex.data.data_utils import sliding_window
from codegeex.data.types import PromptSample, LabelSample
class PromptDatasetProcessor(object):
    """Turn ``PromptSample`` objects into fixed-length LM training examples.

    Each sample's prompt and code are tokenized, concatenated, right-padded
    to ``max_seq_len``, and labeled so that only the code tokens (and the
    optional EOD token) contribute to the loss; prompt and padding positions
    get the ignore index ``-100``. Overlong samples are either discarded or
    split into overlapping windows, depending on ``discard_overlong``.
    """

    def __init__(
        self,
        tokenize: Callable,
        pad_token: int,
        keep_order: bool = False,
        max_seq_len: int = 2048,
        sliding_stride: int = 200,
        discard_overlong: bool = True,
        eod_token: Optional[int] = None,  # was annotated `int = None`
        preprocess: Optional[Callable] = None,  # was annotated `Callable = None`
    ):
        """
        Args:
            tokenize: callable mapping a string to a list of token ids.
            pad_token: token id used to right-pad sequences.
            keep_order: stored but not used within this class.
            max_seq_len: fixed output sequence length.
            sliding_stride: stride for the sliding window over long samples.
            discard_overlong: drop samples longer than ``max_seq_len`` instead
                of windowing them.
            eod_token: if given, appended to the code tokens.
            preprocess: callable converting a raw sample into a PromptSample
                (used by ``process_sample_``).
        """
        super(PromptDatasetProcessor, self).__init__()
        self._keep_order = keep_order
        self._max_seq_len = max_seq_len
        self._sliding_stride = sliding_stride
        self._tokenize = tokenize
        self._pad_token = pad_token
        self._discard_overlong = discard_overlong
        self._eod_token = eod_token
        self._preprocess = preprocess
        # Throughput counters for report(); start_time is expected to be set
        # by the driver before processing begins (0 until then).
        self.doc_processed = 0
        self.doc_generated = 0
        self.start_time = 0

    def pad_seq(self, prompt_tokens: List[int], code_tokens: List[int], extra: dict = None) -> Dict[str, List[int]]:
        """Pad prompt+code to ``max_seq_len`` and build loss labels.

        Labels mask the prompt and padding with ``-100`` so only code tokens
        are trained on.

        NOTE(review): ``extra`` is accepted for interface compatibility but is
        currently unused.
        """
        total_length = len(prompt_tokens) + len(code_tokens)
        assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}"
        pad_len = self._max_seq_len - total_length
        input_ids = prompt_tokens + code_tokens + [self._pad_token] * pad_len
        attention_mask = [1] * (len(prompt_tokens) + len(code_tokens)) + [0] * pad_len
        labels = [-100] * len(prompt_tokens) + code_tokens + [-100] * pad_len
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def process_sample(self, sample: PromptSample) -> Iterable[Dict[str, List[int]]]:
        """Yield zero or more padded examples for one ``PromptSample``.

        Appends ``eod_token`` to the code when configured. Overlong samples
        are dropped when ``discard_overlong`` is set; otherwise they are
        split into overlapping windows via ``sliding_window``.
        """
        prompt_tokens = self._tokenize(sample.prompt)
        code_tokens = self._tokenize(sample.code)
        if self._eod_token is not None:
            code_tokens.append(self._eod_token)
        if len(prompt_tokens) + len(code_tokens) > self._max_seq_len:
            if self._discard_overlong:
                return
            for p, t in sliding_window(prompt_tokens, code_tokens, self._max_seq_len, self._sliding_stride, self._sliding_stride):
                # Fix: forward `extra` here too, for consistency with the
                # short-sample path below (pad_seq currently ignores it).
                yield self.pad_seq(p, t, extra=sample.extra)
        else:
            yield self.pad_seq(prompt_tokens, code_tokens, extra=sample.extra)

    def process_sample_strict(self, sample: PromptSample) -> List[Dict[str, List[int]]]:
        """Eagerly materialize ``process_sample``; pass ``None`` through."""
        if sample is None:
            return None
        return list(self.process_sample(sample))

    def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
        """Preprocess a raw sample, then process it strictly."""
        prompt_sample = self._preprocess(sample)
        return self.process_sample_strict(prompt_sample)

    def report(self):
        """Print document throughput since ``start_time``."""
        duration = perf_counter() - self.start_time
        process_speed = self.doc_processed * 1.0 / duration
        gen_speed = self.doc_generated * 1.0 / duration
        print(f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s")
        print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s")
class LabelDatasetProcessor(object):
    """Turn ``LabelSample`` objects into fixed-length classification examples.

    The prompt is tokenized and right-padded to ``max_seq_len``; the sample's
    label is emitted as a single-element ``labels`` list, and the unpadded
    prompt length is emitted under ``length``. Overlong prompts are either
    discarded or truncated to their last ``max_seq_len`` tokens, depending on
    ``discard_overlong``.
    """

    def __init__(
        self,
        tokenize: Callable,
        pad_token: int,
        keep_order: bool = False,
        max_seq_len: int = 2048,
        sliding_stride: int = 200,
        discard_overlong: bool = True,
        eod_token: Optional[int] = None,  # was annotated `int = None`
        preprocess: Optional[Callable] = None,  # was annotated `Callable = None`
    ):
        """
        Args:
            tokenize: callable mapping a string to a list of token ids.
            pad_token: token id used to right-pad sequences.
            keep_order: stored but not used within this class.
            max_seq_len: fixed output sequence length.
            sliding_stride: stored but not used within this class.
            discard_overlong: drop prompts longer than ``max_seq_len`` instead
                of truncating them.
            eod_token: stored but not used within this class.
            preprocess: callable converting a raw sample into a LabelSample
                (used by ``process_sample_``).
        """
        super(LabelDatasetProcessor, self).__init__()
        self._keep_order = keep_order
        self._max_seq_len = max_seq_len
        self._sliding_stride = sliding_stride
        self._tokenize = tokenize
        self._pad_token = pad_token
        self._discard_overlong = discard_overlong
        self._eod_token = eod_token
        self._preprocess = preprocess
        # Throughput counters for report(); start_time is expected to be set
        # by the driver before processing begins (0 until then).
        self.doc_processed = 0
        self.doc_generated = 0
        self.start_time = 0

    def pad_seq(self, prompt_tokens: List[int], label: int, extra: dict = None) -> Dict[str, List[int]]:
        """Pad the prompt to ``max_seq_len`` and attach the scalar label.

        NOTE(review): ``extra`` is accepted for interface compatibility but is
        currently unused.
        """
        total_length = len(prompt_tokens)
        assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}"
        pad_len = self._max_seq_len - total_length
        input_ids = prompt_tokens + [self._pad_token] * pad_len
        attention_mask = [1] * len(prompt_tokens) + [0] * pad_len
        # Return [label] directly rather than rebinding the `label` parameter
        # (the original shadowed it with a list).
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "length": [len(prompt_tokens)],
            "labels": [label],
        }

    def process_sample(self, sample: LabelSample) -> Iterable[Dict[str, List[int]]]:
        """Yield at most one padded example for one ``LabelSample``.

        Overlong prompts are dropped when ``discard_overlong`` is set;
        otherwise the last ``max_seq_len`` tokens are kept.
        """
        prompt_tokens = self._tokenize(sample.prompt)
        label = sample.label
        if len(prompt_tokens) > self._max_seq_len:
            if self._discard_overlong:
                return
            prompt_tokens = prompt_tokens[-self._max_seq_len:]
        yield self.pad_seq(prompt_tokens, label, extra=sample.extra)

    def process_sample_strict(self, sample: LabelSample) -> List[Dict[str, List[int]]]:
        """Eagerly materialize ``process_sample``; pass ``None`` through."""
        if sample is None:
            return None
        return list(self.process_sample(sample))

    def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
        """Preprocess a raw sample, then process it strictly."""
        prompt_sample = self._preprocess(sample)
        return self.process_sample_strict(prompt_sample)

    def report(self):
        """Print document throughput since ``start_time``."""
        duration = perf_counter() - self.start_time
        process_speed = self.doc_processed * 1.0 / duration
        gen_speed = self.doc_generated * 1.0 / duration
        print(f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s")
        print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s")