load default dictionary from pkg_resources and improve the loading method;

change the serialized models from marshal to pickle
pull/309/head
Dingyuan Wang 9 years ago
parent 70f019b669
commit 8814e08f9b

@@ -20,11 +20,10 @@ if os.name == 'nt':
else:
_replace_file = os.rename
# Path helpers (PEP 8 E731: use ``def`` instead of assigning lambdas).

def _get_module_path(path):
    """Resolve *path* relative to this module's own directory."""
    return os.path.normpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__), path))


def _get_abs_path(path):
    """Resolve a (possibly relative) user-supplied *path* against the cwd."""
    return os.path.normpath(os.path.join(os.getcwd(), path))
# ``None`` marks "use the bundled dictionary", which is loaded lazily as a
# package resource (see get_dict_file) rather than via a pre-computed path.
# The superseded eager path lookup from the previous revision is removed.
DEFAULT_DICT = None
DEFAULT_DICT_NAME = "dict.txt"
# Console logging: a stderr stream handler and the module-scoped logger.
log_console = logging.StreamHandler(sys.stderr)
default_logger = logging.getLogger(__name__)
@@ -54,7 +53,10 @@ class Tokenizer(object):
def __init__(self, dictionary=DEFAULT_DICT):
self.lock = threading.RLock()
self.dictionary = _get_abs_path(dictionary)
if dictionary == DEFAULT_DICT:
self.dictionary = dictionary
else:
self.dictionary = _get_abs_path(dictionary)
self.FREQ = {}
self.total = 0
self.user_word_tag_tab = {}
@@ -65,10 +67,11 @@ class Tokenizer(object):
def __repr__(self):
    # Identify the tokenizer by the dictionary it was constructed with.
    dictionary = self.dictionary
    return '<Tokenizer dictionary=%r>' % dictionary
def gen_pfdict(self, f_name):
def gen_pfdict(self, f):
lfreq = {}
ltotal = 0
with open(f_name, 'rb') as f:
f_name = resolve_filename(f)
with f:
for lineno, line in enumerate(f, 1):
try:
line = line.strip().decode('utf-8')
@@ -105,7 +108,7 @@ class Tokenizer(object):
if self.initialized:
return
default_logger.debug("Building prefix dict from %s ..." % abs_path)
default_logger.debug("Building prefix dict from %s ..." % (abs_path or 'the default dictionary'))
t1 = time.time()
if self.cache_file:
cache_file = self.cache_file
@@ -122,7 +125,8 @@ class Tokenizer(object):
tmpdir = os.path.dirname(cache_file)
load_from_cache_fail = True
if os.path.isfile(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
if os.path.isfile(cache_file) and (abs_path == DEFAULT_DICT or
os.path.getmtime(cache_file) > os.path.getmtime(abs_path)):
default_logger.debug(
"Loading model from cache %s" % cache_file)
try:
@@ -136,7 +140,7 @@ class Tokenizer(object):
wlock = DICT_WRITING.get(abs_path, threading.RLock())
DICT_WRITING[abs_path] = wlock
with wlock:
self.FREQ, self.total = self.gen_pfdict(abs_path)
self.FREQ, self.total = self.gen_pfdict(self.get_dict_file())
default_logger.debug(
"Dumping model to file cache %s" % cache_file)
try:
@@ -343,8 +347,11 @@ class Tokenizer(object):
def _lcut_for_search_no_hmm(self, sentence):
    """Search-mode segmentation into a list, with the HMM disabled."""
    use_hmm = False
    return self.lcut_for_search(sentence, use_hmm)
def get_abs_path_dict(self):
    """Absolute filesystem path of the dictionary backing this tokenizer."""
    dictionary = self.dictionary
    return _get_abs_path(dictionary)
def get_dict_file(self):
    """Open the active dictionary as a binary file object.

    The bundled default dictionary is served via get_module_res (so it
    works from zipped installs); a user-supplied path is opened directly.
    """
    if self.dictionary != DEFAULT_DICT:
        return open(self.dictionary, 'rb')
    return get_module_res(DEFAULT_DICT_NAME)
def load_userdict(self, f):
'''
@@ -363,14 +370,17 @@
'''
self.check_initialized()
if isinstance(f, string_types):
f_name = f
f = open(f, 'rb')
else:
f_name = resolve_filename(f)
for lineno, ln in enumerate(f, 1):
line = ln.strip()
if not isinstance(line, text_type):
try:
line = line.decode('utf-8').lstrip('\ufeff')
except UnicodeDecodeError:
raise ValueError('dictionary file %s must be utf-8' % f.name)
raise ValueError('dictionary file %s must be utf-8' % f_name)
if not line:
continue
# match won't be None because there's at least one character
@@ -494,7 +504,7 @@ cut_for_search = dt.cut_for_search
# Re-export the default Tokenizer instance's bound methods at module level.
lcut_for_search = dt.lcut_for_search
del_word = dt.del_word
get_DAG = dt.get_DAG
# Keep the old alias alongside the new file-object accessor for backward
# compatibility with callers that still use get_abs_path_dict().
get_abs_path_dict = dt.get_abs_path_dict
get_dict_file = dt.get_dict_file
initialize = dt.initialize
load_userdict = dt.load_userdict
set_dictionary = dt.set_dictionary

@@ -1,6 +1,15 @@
# -*- coding: utf-8 -*-
import os
import sys
try:
    import pkg_resources

    def get_module_res(*res):
        """Open a package resource as a binary stream (zip-safe)."""
        return pkg_resources.resource_stream(__name__, os.path.join(*res))
except ImportError:
    # pkg_resources unavailable: fall back to a plain filesystem open
    # relative to this module's directory.
    def get_module_res(*res):
        """Open a package resource as a binary stream (filesystem only)."""
        return open(os.path.normpath(os.path.join(
            os.getcwd(), os.path.dirname(__file__), *res)), 'rb')
# True when running under Python 2 (this module supports both major versions).
PY2 = sys.version_info[0] == 2
# The platform's filesystem encoding, as reported by the interpreter.
default_encoding = sys.getfilesystemencoding()
@@ -29,3 +38,9 @@ def strdecode(sentence):
except UnicodeDecodeError:
sentence = sentence.decode('gbk', 'ignore')
return sentence
def resolve_filename(f):
    """Best-effort display name for a file-like object.

    Returns ``f.name`` when the attribute exists, otherwise ``repr(f)``.
    """
    missing = object()
    name = getattr(f, 'name', missing)
    return repr(f) if name is missing else name

@@ -1,8 +1,8 @@
from __future__ import absolute_import, unicode_literals
import re
import os
import marshal
import sys
import pickle
from .._compat import *
MIN_FLOAT = -3.14e100
@@ -21,24 +21,9 @@ PrevStatus = {
def load_model():
    """Load the finalseg HMM tables (start/transition/emission probabilities).

    Tables are unpickled from package resources via get_module_res so that
    loading also works when the package is installed zipped.  The superseded
    marshal-from-filesystem loader lines interleaved from the previous
    revision are dropped.
    """
    # NOTE(review): pickle.load is only safe because these are bundled,
    # trusted resources — never point these constants at untrusted files.
    start_p = pickle.load(get_module_res("finalseg", PROB_START_P))
    trans_p = pickle.load(get_module_res("finalseg", PROB_TRANS_P))
    emit_p = pickle.load(get_module_res("finalseg", PROB_EMIT_P))
    return start_p, trans_p, emit_p
if sys.platform.startswith("java"):

Binary file not shown.

Binary file not shown.

Binary file not shown.

@@ -3,7 +3,7 @@ import os
import re
import sys
import jieba
import marshal
import pickle
from .._compat import *
from .viterbi import viterbi
@@ -23,36 +23,17 @@ re_num = re.compile("[\.0-9]+")
re_eng1 = re.compile('^[a-zA-Z0-9]$', re.U)
def load_model():
    # For Jython
    """Load the POS-tagging HMM model used when compiled tables are absent.

    Returns ``(char_state_tab, start_p, trans_p, emit_p)``, all unpickled
    from package resources.  The superseded marshal-based loader lines from
    the previous revision are dropped (that version also referenced an
    undefined ``result`` in its return and used a no-op ``f.closed``).
    """
    start_p = pickle.load(get_module_res("posseg", PROB_START_P))
    trans_p = pickle.load(get_module_res("posseg", PROB_TRANS_P))
    emit_p = pickle.load(get_module_res("posseg", PROB_EMIT_P))
    state = pickle.load(get_module_res("posseg", CHAR_STATE_TAB_P))
    return state, start_p, trans_p, emit_p
if sys.platform.startswith("java"):
char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model()
char_state_tab_P, start_P, trans_P, emit_P = load_model()
else:
from .char_state_tab import P as char_state_tab_P
from .prob_start import P as start_P
@@ -89,7 +70,7 @@ class POSTokenizer(object):
def __init__(self, tokenizer=None):
    """Wrap *tokenizer* (or a fresh jieba.Tokenizer) for POS tagging.

    The word/tag table is read from the tokenizer's dictionary file object;
    the duplicated call to the removed path-based loader interleaved from
    the previous revision is dropped.
    """
    self.tokenizer = tokenizer or jieba.Tokenizer()
    self.load_word_tag(self.tokenizer.get_dict_file())
def __repr__(self):
    # Identify this POS tokenizer by the tokenizer it wraps.
    inner = self.tokenizer
    return '<POSTokenizer tokenizer=%r>' % inner
@@ -102,11 +83,12 @@ class POSTokenizer(object):
def initialize(self, dictionary=None):
    """(Re)initialize the underlying tokenizer and reload the word/tag table.

    Only the file-object based reload is kept; the stale path-based call
    interleaved from the previous revision is dropped.
    """
    self.tokenizer.initialize(dictionary)
    self.load_word_tag(self.tokenizer.get_dict_file())
def load_word_tag(self, f_name):
def load_word_tag(self, f):
self.word_tag_tab = {}
with open(f_name, "rb") as f:
f_name = resolve_filename(f)
with f:
for lineno, line in enumerate(f, 1):
try:
line = line.strip().decode("utf-8")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.
Loading…
Cancel
Save