From dc2b788eb325ba191a463a7d0a38708e93617bfc Mon Sep 17 00:00:00 2001
From: vissssa
Date: Thu, 9 Jan 2020 19:23:11 +0800
Subject: [PATCH] refactor: improvement check_paddle_installed (#806)

---
 jieba/__init__.py        | 39 +++++++++++++++-----------------
 jieba/_compat.py         | 49 ++++++++++++++++------------------
 jieba/posseg/__init__.py | 22 ++++++++----------
 3 files changed, 48 insertions(+), 62 deletions(-)

diff --git a/jieba/__init__.py b/jieba/__init__.py
index 72af03b..992039e 100644
--- a/jieba/__init__.py
+++ b/jieba/__init__.py
@@ -1,26 +1,24 @@
 from __future__ import absolute_import, unicode_literals
+
 __version__ = '0.41'
 __license__ = 'MIT'
-import re
-import os
-import sys
-import time
-import logging
 import marshal
+import re
 import tempfile
 import threading
-from math import log
+import time
 from hashlib import md5
-from ._compat import *
+from math import log
+
 from . import finalseg
+from ._compat import *
 
 if os.name == 'nt':
     from shutil import move as _replace_file
 else:
     _replace_file = os.rename
-
 _get_abs_path = lambda path: os.path.normpath(os.path.join(os.getcwd(), path))
 
 DEFAULT_DICT = None
@@ -47,10 +45,11 @@
 re_han_default = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-]+)", re.U)
 re_skip_default = re.compile("(\r\n|\s)", re.U)
 
+
 def setLogLevel(log_level):
-    global logger
     default_logger.setLevel(log_level)
 
+
 class Tokenizer(object):
 
     def __init__(self, dictionary=DEFAULT_DICT):
@@ -69,7 +68,8 @@ class Tokenizer(object):
     def __repr__(self):
         return '<Tokenizer dictionary=%r>' % self.dictionary
 
-    def gen_pfdict(self, f):
+    @staticmethod
+    def gen_pfdict(f):
         lfreq = {}
         ltotal = 0
         f_name = resolve_filename(f)
@@ -128,7 +128,7 @@ class Tokenizer(object):
 
         load_from_cache_fail = True
         if os.path.isfile(cache_file) and (abs_path == DEFAULT_DICT or
-                os.path.getmtime(cache_file) > os.path.getmtime(abs_path)):
+                                           os.path.getmtime(cache_file) > os.path.getmtime(abs_path)):
             default_logger.debug(
                 "Loading model from cache %s" % cache_file)
             try:
@@ -201,7 +201,7 @@ class Tokenizer(object):
         eng_scan = 0
         eng_buf = u''
         for k, L in iteritems(dag):
-            if eng_scan==1 and not re_eng.match(sentence[k]):
+            if eng_scan == 1 and not re_eng.match(sentence[k]):
                 eng_scan = 0
                 yield eng_buf
             if len(L) == 1 and k > old_j:
@@ -219,7 +219,7 @@ class Tokenizer(object):
                     if j > k:
                         yield sentence[k:j + 1]
                         old_j = j
-        if eng_scan==1:
+        if eng_scan == 1:
             yield eng_buf
 
     def __cut_DAG_NO_HMM(self, sentence):
@@ -285,8 +285,8 @@ class Tokenizer(object):
             for elem in buf:
                 yield elem
 
-    def cut(self, sentence, cut_all = False, HMM = True,use_paddle = False):
-        '''
+    def cut(self, sentence, cut_all=False, HMM=True, use_paddle=False):
+        """
         The main function that segments an entire sentence that contains
         Chinese characters into separated words.
 
         Parameter:
             - sentence: The str(unicode) to be segmented.
             - cut_all: Model type. True for full pattern, False for accurate pattern.
             - HMM: Whether to use the Hidden Markov Model.
-        '''
-        is_paddle_installed = False
-        if use_paddle == True:
-            is_paddle_installed = check_paddle_install()
+        """
+        is_paddle_installed = check_paddle_install['is_paddle_installed']
         sentence = strdecode(sentence)
-        if use_paddle == True and is_paddle_installed == True:
+        if use_paddle and is_paddle_installed:
             if sentence is None or sentence == "" or sentence == u"":
                 yield sentence
-                return
             import jieba.lac_small.predict as predict
             results = predict.get_sent(sentence)
             for sent in results:
diff --git a/jieba/_compat.py b/jieba/_compat.py
index 58137d4..4ea3f7a 100644
--- a/jieba/_compat.py
+++ b/jieba/_compat.py
@@ -1,49 +1,55 @@
 # -*- coding: utf-8 -*-
+import logging
 import os
 import sys
-import logging
 
 log_console = logging.StreamHandler(sys.stderr)
 default_logger = logging.getLogger(__name__)
 default_logger.setLevel(logging.DEBUG)
 
+
 def setLogLevel(log_level):
-    global logger
     default_logger.setLevel(log_level)
 
+
+check_paddle_install = {'is_paddle_installed': False}
+
 try:
     import pkg_resources
+
     get_module_res = lambda *res: pkg_resources.resource_stream(__name__,
                                                                 os.path.join(*res))
 except ImportError:
     get_module_res = lambda *res: open(os.path.normpath(os.path.join(
-        os.getcwd(), os.path.dirname(__file__), *res)), 'rb')
+        os.getcwd(), os.path.dirname(__file__), *res)), 'rb')
 
 
 def enable_paddle():
-    import_paddle_check = False
     try:
         import paddle
     except ImportError:
         default_logger.debug("Installing paddle-tiny, please wait a minute......")
         os.system("pip install paddlepaddle-tiny")
-    try:
-        import paddle
-    except ImportError:
-        default_logger.debug("Import paddle error, please use command to install: pip install paddlepaddle-tiny==1.6.1."
-                             "Now, back to jieba basic cut......")
+        try:
+            import paddle
+        except ImportError:
+            default_logger.debug(
+                "Import paddle error, please use command to install: pip install paddlepaddle-tiny==1.6.1."
+                "Now, back to jieba basic cut......")
     if paddle.__version__ < '1.6.1':
         default_logger.debug("Find your own paddle version doesn't satisfy the minimum requirement (1.6.1), "
                              "please install paddle tiny by 'pip install --upgrade paddlepaddle-tiny', "
-                             "or upgrade paddle full version by 'pip install --upgrade paddlepaddle (-gpu for GPU version)' ")
+                             "or upgrade paddle full version by "
+                             "'pip install --upgrade paddlepaddle (-gpu for GPU version)' ")
     else:
         try:
             import jieba.lac_small.predict as predict
-            import_paddle_check = True
             default_logger.debug("Paddle enabled successfully......")
+            check_paddle_install['is_paddle_installed'] = True
         except ImportError:
             default_logger.debug("Import error, cannot find paddle.fluid and jieba.lac_small.predict module. "
-                                 "Now, back to jieba basic cut......")
+                                 "Now, back to jieba basic cut......")
+
 
 PY2 = sys.version_info[0] == 2
@@ -66,6 +72,7 @@ else:
     itervalues = lambda d: iter(d.values())
     iteritems = lambda d: iter(d.items())
 
+
 def strdecode(sentence):
     if not isinstance(sentence, text_type):
         try:
@@ -74,25 +81,9 @@ def strdecode(sentence):
             sentence = sentence.decode('gbk', 'ignore')
     return sentence
 
+
 def resolve_filename(f):
     try:
         return f.name
     except AttributeError:
         return repr(f)
-
-
-def check_paddle_install():
-    is_paddle_installed = False
-    try:
-        import paddle
-        if paddle.__version__ >= '1.6.1':
-            is_paddle_installed = True
-        else:
-            is_paddle_installed = False
-            default_logger.debug("Check the paddle version is not correct, the current version is "+ paddle.__version__+","
-                                 "please use command to install paddle: pip uninstall paddlepaddle(-gpu), "
-                                 "pip install paddlepaddle-tiny==1.6.1. Now, back to jieba basic cut......")
-    except ImportError:
-        default_logger.debug("Import paddle error, back to jieba basic cut......")
-        is_paddle_installed = False
-    return is_paddle_installed
diff --git a/jieba/posseg/__init__.py b/jieba/posseg/__init__.py
index 248a9a8..05d7c01 100755
--- a/jieba/posseg/__init__.py
+++ b/jieba/posseg/__init__.py
@@ -1,11 +1,11 @@
 from __future__ import absolute_import, unicode_literals
-import os
+
+import pickle
 import re
-import sys
+
 import jieba
-import pickle
-from .._compat import *
 from .viterbi import viterbi
+from .._compat import *
 
 PROB_START_P = "prob_start.p"
 PROB_TRANS_P = "prob_trans.p"
@@ -252,6 +252,7 @@ class POSTokenizer(object):
     def lcut(self, *args, **kwargs):
         return list(self.cut(*args, **kwargs))
 
+
 # default Tokenizer instance
 dt = POSTokenizer(jieba.dt)
 
@@ -276,19 +277,16 @@ def cut(sentence, HMM=True, use_paddle=False):
     Note that this only works using dt, custom POSTokenizer
     instances are not supported.
     """
-    is_paddle_installed = False
-    if use_paddle == True:
-        is_paddle_installed = check_paddle_install()
-    if use_paddle==True and is_paddle_installed == True:
+    is_paddle_installed = check_paddle_install['is_paddle_installed']
+    if use_paddle and is_paddle_installed:
         if sentence is None or sentence == "" or sentence == u"":
             yield pair(None, None)
-            return
        import jieba.lac_small.predict as predict
-        sents,tags = predict.get_result(strdecode(sentence))
-        for i,sent in enumerate(sents):
+        sents, tags = predict.get_result(strdecode(sentence))
+        for i, sent in enumerate(sents):
             if sent is None or tags[i] is None:
                 continue
-            yield pair(sent,tags[i])
+            yield pair(sent, tags[i])
         return
     global dt
     if jieba.pool is None:
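
Note on the refactored flow: paddle availability is now recorded in the module-level check_paddle_install dict, which enable_paddle() sets to True once on success, rather than re-probing the paddle import on every cut() call. One consequence is that callers must call enable_paddle() before passing use_paddle=True; otherwise the flag stays False and jieba quietly falls back to the default segmenter. A minimal usage sketch of the public API touched by this patch (the sample sentences are illustrative only):

    import jieba
    import jieba.posseg as pseg

    jieba.enable_paddle()  # imports (or installs) paddle once; on success it
                           # sets check_paddle_install['is_paddle_installed'] = True
    for word in jieba.cut("我来到北京清华大学", use_paddle=True):
        print(word)
    for word, flag in pseg.cut("我爱北京天安门", use_paddle=True):
        print(word, flag)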