Refactor encode_whitespaces

pull/53/head
Stanislas0
parent 5dcecd07bf
commit b7dde72832

@@ -1,3 +1,4 @@
+import numpy as np
 from typing import *
 from transformers import AutoTokenizer
 from transformers.models.gpt2 import GPT2TokenizerFast
@@ -9,33 +10,9 @@ def encode_whitespaces(text, start_extra_id: int, max_len: int):
     >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
     'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
     """
-
-    def push_acc_space(acc_len: int, text: str):
-        if acc_len == 0:
-            return text
-        if acc_len == 1:
-            return text + ' '
-        assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
-        extra_id = start_extra_id - 2 + acc_len
-        extra_token = f'<|extratoken_{extra_id}|>'
-        return text + extra_token
-
-    acc_len = 0
-    res = ''
-    for ch in text:
-        if ch == ' ':
-            acc_len += 1
-            if acc_len == max_len:
-                res = push_acc_space(acc_len, res)
-                acc_len = 0
-        else:
-            res = push_acc_space(acc_len, res)
-            acc_len = 0
-            res = res + ch
-    res = push_acc_space(acc_len, res)
-    return res
+    for i in np.arange(max_len, 1, -1):
+        text = text.replace(" " * i, f"<|extratoken_{start_extra_id + i - 2}|>")
+    return text
 
 
 def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
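Reviewer note: the commit swaps the character-by-character scanner for a longest-run-first replace loop, which should be behaviorally equivalent (single spaces stay literal in both versions, and runs of 2 to max_len spaces map to one extra token each). Below is a minimal round-trip sketch for checking the new logic locally; the _decode helper is hypothetical, written here only to invert the encoding, and is not the repository's decode_whitespaces:

import numpy as np

def encode_whitespaces(text, start_extra_id: int, max_len: int):
    # Replace-based version from this commit: run lengths are tried from
    # max_len down to 2, so each run is consumed by its exact-length pattern.
    for i in np.arange(max_len, 1, -1):
        text = text.replace(" " * i, f"<|extratoken_{start_extra_id + i - 2}|>")
    return text

def _decode(text: str, start_extra_id: int, max_len: int):
    # Hypothetical inverse, for round-trip checking only.
    for k in range(2, max_len + 1):
        text = text.replace(f"<|extratoken_{start_extra_id + k - 2}|>", " " * k)
    return text

sample = 'a\n  b\n   c'
encoded = encode_whitespaces(sample, 10, 10)
assert encoded == 'a\n<|extratoken_10|>b\n<|extratoken_11|>c'
assert _decode(encoded, 10, 10) == sample

Iterating downward matters: going upward from 2 would split a 3-space run into a 2-space token plus a stray space, breaking the one-token-per-run mapping that the old scanner produced.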
