@@ -1,3 +1,4 @@
+import numpy as np
 from typing import *
 from transformers import AutoTokenizer
 from transformers.models.gpt2 import GPT2TokenizerFast
@@ -9,33 +10,9 @@ def encode_whitespaces(text, start_extra_id: int, max_len: int):
     >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
     'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
     """
-
-    def push_acc_space(acc_len: int, text: str):
-        if acc_len == 0:
-            return text
-        if acc_len == 1:
-            return text + ' '
-        assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
-        extra_id = start_extra_id - 2 + acc_len
-        extra_token = f'<|extratoken_{extra_id}|>'
-        return text + extra_token
-
-    acc_len = 0
-    res = ''
-    for ch in text:
-        if ch == ' ':
-            acc_len += 1
-            if acc_len == max_len:
-                res = push_acc_space(acc_len, res)
-                acc_len = 0
-        else:
-            res = push_acc_space(acc_len, res)
-            acc_len = 0
-            res = res + ch
-
-    res = push_acc_space(acc_len, res)
-
-    return res
+    for i in np.arange(max_len, 1, -1):
+        text = text.replace(" " * i, f"<|extratoken_{start_extra_id + i - 2}|>")
+    return text
 
 
 def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
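
As a quick sanity check, the three added lines can be exercised standalone; the snippet below copies them into a bare function and checks the doctest from the docstring (the wrapper and the assert are illustrative, not part of the patch). Note the loop runs from max_len down to 2: longest runs are replaced first, so a run of three spaces becomes a single <|extratoken_11|> instead of being split by the two-space pattern.

import numpy as np

def encode_whitespaces(text, start_extra_id: int, max_len: int):
    # Replace runs of 2..max_len spaces, longest first, with one extra token;
    # a run of length i maps to token id start_extra_id + i - 2.
    for i in np.arange(max_len, 1, -1):
        text = text.replace(" " * i, f"<|extratoken_{start_extra_id + i - 2}|>")
    return text

# Matches the doctest: two spaces -> token 10, three spaces -> token 11.
assert encode_whitespaces('a\n  b\n   c', 10, 10) == 'a\n<|extratoken_10|>b\n<|extratoken_11|>c'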