from typing import *
from time import perf_counter

from codegeex.data.data_utils import sliding_window
from codegeex.data.types import PromptSample, LabelSample


class _DatasetProcessorBase(object):
    """Configuration, counters and eager/preprocessing drivers shared by processors.

    Subclasses implement ``process_sample`` as a generator of padded feature dicts.
    """

    def __init__(
        self,
        tokenize: Callable,
        pad_token: int,
        keep_order: bool = False,
        max_seq_len: int = 2048,
        sliding_stride: int = 200,
        discard_overlong: bool = True,
        eod_token: Optional[int] = None,
        preprocess: Optional[Callable] = None,
    ):
        """
        Args:
            tokenize: callable mapping text to a sequence of token ids.
            pad_token: id used to right-pad ``input_ids`` up to ``max_seq_len``.
            keep_order: stored on the instance; not read by the classes in this module.
            max_seq_len: length every emitted sequence is padded to.
            sliding_stride: stride forwarded to ``sliding_window`` for overlong
                samples (used by PromptDatasetProcessor only).
            discard_overlong: if True, samples longer than ``max_seq_len`` yield nothing.
            eod_token: optional id appended after the code tokens
                (used by PromptDatasetProcessor only).
            preprocess: optional callable turning a raw record into a sample (or None).
                When not given, ``process_sample_`` passes the record through as-is.
        """
        super().__init__()
        self._keep_order = keep_order
        self._max_seq_len = max_seq_len
        self._sliding_stride = sliding_stride
        self._tokenize = tokenize
        self._pad_token = pad_token
        self._discard_overlong = discard_overlong
        self._eod_token = eod_token
        self._preprocess = preprocess
        # Counters/timestamp read by report(); presumably updated by the code that
        # drives this processor (nothing in this module increments them).
        self.doc_processed = 0
        self.doc_generated = 0
        self.start_time = 0

    def process_sample(self, sample) -> Iterable[Dict[str, List[int]]]:
        raise NotImplementedError

    def process_sample_strict(self, sample) -> Optional[List[Dict[str, List[int]]]]:
        """
        Instead of processing lazily, we turn the iterable into a list.

        Returns None (rather than an empty list) when ``sample`` is None.
        """
        if sample is None:
            return None
        return list(self.process_sample(sample))

    def process_sample_(self, sample) -> Optional[List[Dict[str, List[int]]]]:
        """Apply ``preprocess`` (when configured) to a raw record, then process it eagerly."""
        prompt_sample = sample if self._preprocess is None else self._preprocess(sample)
        return self.process_sample_strict(prompt_sample)

    def report(self):
        """Print document throughput measured from ``start_time`` to now."""
        duration = perf_counter() - self.start_time
        # Avoid ZeroDivisionError if report() runs within the clock's resolution.
        process_speed = self.doc_processed * 1.0 / duration if duration > 0 else 0.0
        gen_speed = self.doc_generated * 1.0 / duration if duration > 0 else 0.0
        print(f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s")
        print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s")


class PromptDatasetProcessor(_DatasetProcessorBase):
    """Turns a PromptSample (``prompt`` + ``code``) into fixed-length training sequences.

    Labels are -100 on prompt and padding positions and equal the token id on code
    positions (-100 is presumably the loss ignore_index — confirm with the trainer).
    """

    def pad_seq(self, prompt_tokens: List[int], code_tokens: List[int], extra: dict = None) -> Dict[str, List[int]]:
        """Right-pad prompt+code to ``max_seq_len``.

        ``extra`` is accepted for interface compatibility and currently ignored.
        Raises AssertionError if prompt+code is longer than ``max_seq_len``.
        """
        total_length = len(prompt_tokens) + len(code_tokens)
        assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}"
        pad_len = self._max_seq_len - total_length
        input_ids = prompt_tokens + code_tokens + [self._pad_token] * pad_len
        attention_mask = [1] * len(prompt_tokens) + [1] * len(code_tokens) + [0] * pad_len
        labels = [-100] * len(prompt_tokens) + code_tokens + [-100] * pad_len
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def process_sample(self, sample: PromptSample) -> Iterable[Dict[str, List[int]]]:
        """
        Process a sample.

        Yields one padded dict when prompt+code fits in ``max_seq_len``. Otherwise it
        yields nothing if ``discard_overlong`` is set, or one dict per window produced
        by ``sliding_window``.
        """
        # Copy into fresh lists: list concatenation below requires lists, and the
        # eod append must not mutate whatever object the tokenizer returned.
        prompt_tokens = list(self._tokenize(sample.prompt))
        code_tokens = list(self._tokenize(sample.code))
        if self._eod_token is not None:
            code_tokens.append(self._eod_token)

        if len(prompt_tokens) + len(code_tokens) <= self._max_seq_len:
            yield self.pad_seq(prompt_tokens, code_tokens, extra=sample.extra)
            return
        if self._discard_overlong:
            return
        # NOTE(review): the same stride is passed for both stride-like arguments of
        # sliding_window — confirm against its signature.
        for p, t in sliding_window(
            prompt_tokens, code_tokens, self._max_seq_len, self._sliding_stride, self._sliding_stride
        ):
            yield self.pad_seq(p, t, extra=sample.extra)


class LabelDatasetProcessor(_DatasetProcessorBase):
    """Turns a LabelSample (``prompt`` + integer ``label``) into a fixed-length sequence."""

    def pad_seq(self, prompt_tokens: List[int], label: int, extra: dict = None) -> Dict[str, List[int]]:
        """Right-pad the prompt to ``max_seq_len`` and attach its label and unpadded length.

        ``extra`` is accepted for interface compatibility and currently ignored.
        Raises AssertionError if the prompt is longer than ``max_seq_len``.
        """
        total_length = len(prompt_tokens)
        assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}"
        pad_len = self._max_seq_len - total_length
        input_ids = prompt_tokens + [self._pad_token] * pad_len
        attention_mask = [1] * len(prompt_tokens) + [0] * pad_len
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "length": [len(prompt_tokens)],
            "labels": [label],
        }

    def process_sample(self, sample: LabelSample) -> Iterable[Dict[str, List[int]]]:
        """
        Process a sample.

        Overlong prompts are dropped when ``discard_overlong`` is set; otherwise only
        their last ``max_seq_len`` tokens are kept.
        """
        prompt_tokens = list(self._tokenize(sample.prompt))
        label = sample.label
        if len(prompt_tokens) > self._max_seq_len:
            if self._discard_overlong:
                return
            # Explicit start index: a negative slice [-0:] would keep the whole
            # list when max_seq_len == 0.
            prompt_tokens = prompt_tokens[len(prompt_tokens) - self._max_seq_len:]
        yield self.pad_seq(prompt_tokens, label, extra=sample.extra)