From 751ff35eb5faa6460038bb20a1ef6bfcf29f440a Mon Sep 17 00:00:00 2001 From: Dingyuan Wang Date: Fri, 31 Oct 2014 23:15:51 +0800 Subject: [PATCH] improve extract_tags; unify extract_tags and testrank --- README.md | 21 ++++++++++----------- jieba/analyse/__init__.py | 36 ++++++++++++++++++++---------------- jieba/analyse/textrank.py | 31 +++++++++++++++++++++++-------- jieba/finalseg/__init__.py | 11 ++++------- jieba/posseg/__init__.py | 11 ++++------- 5 files changed, 61 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 21ad8ec..d63bb53 100644 --- a/README.md +++ b/README.md @@ -153,17 +153,16 @@ jieba.analyse.textrank(raw_text) 来自`__main__`的示例结果: ``` -吉林 100.0 -欧亚 86.4592606421 -置业 55.3262889963 -实现 52.0353476663 -收入 37.9475518129 -增资 35.5042189944 -子公司 34.9286032861 -全资 30.8154823412 -城市 30.6031961172 -商业 30.4779050167 - +吉林 1.0 +欧亚 0.864834432786 +置业 0.553465925497 +实现 0.520660869531 +收入 0.379699688954 +增资 0.355086023683 +子公司 0.349758490263 +全资 0.308537396283 +城市 0.306103738053 +商业 0.304837414946 ``` 4) : 词性标注 diff --git a/jieba/analyse/__init__.py b/jieba/analyse/__init__.py index 535b6f6..95c5b14 100755 --- a/jieba/analyse/__init__.py +++ b/jieba/analyse/__init__.py @@ -1,6 +1,7 @@ #encoding=utf-8 import jieba import os +from operator import itemgetter try: from analyzer import ChineseAnalyzer except ImportError: @@ -26,13 +27,11 @@ class IDFLoader: if self.path != new_idf_path: content = open(new_idf_path, 'rb').read().decode('utf-8') idf_freq = {} - lines = content.split('\n') - if lines and not lines[-1]: - lines.pop(-1) + lines = content.rstrip('\n').split('\n') for line in lines: word, freq = line.split(' ') idf_freq[word] = float(freq) - median_idf = sorted(idf_freq.values())[len(idf_freq)/2] + median_idf = sorted(idf_freq.values())[len(idf_freq)//2] self.idf_freq = idf_freq self.median_idf = median_idf self.path = new_idf_path @@ -60,27 +59,32 @@ def set_stop_words(stop_words_path): STOP_WORDS.add(line) def extract_tags(sentence, topK=20, withWeight=False): - global STOP_WORDS + """ + Extract keywords from sentence using TF-IDF algorithm. + Parameter: + - topK: return how many top keywords. `None` for all possible words. + - withWeight: if True, return a list of (word, weight); + if False, return a list of words. + """ + global STOP_WORDS, idf_loader idf_freq, median_idf = idf_loader.get_idf() words = jieba.cut(sentence) freq = {} for w in words: - if len(w.strip()) < 2: - continue - if w.lower() in STOP_WORDS: + if len(w.strip()) < 2 or w.lower() in STOP_WORDS: continue freq[w] = freq.get(w, 0.0) + 1.0 total = sum(freq.values()) - freq = [(k,v/total) for k,v in freq.iteritems()] - - tf_idf_list = [(v*idf_freq.get(k,median_idf), k) for k,v in freq] - st_list = sorted(tf_idf_list, reverse=True) + for k in freq: + freq[k] *= idf_freq.get(k, median_idf) / total if withWeight: - tags = st_list[:topK] + tags = sorted(freq.items(), key=itemgetter(1), reverse=True) + else: + tags = sorted(freq, key=freq.__getitem__, reverse=True) + if topK: + return tags[:topK] else: - top_tuples = st_list[:topK] - tags = [a[1] for a in top_tuples] - return tags + return tags diff --git a/jieba/analyse/textrank.py b/jieba/analyse/textrank.py index 46a8527..9ac9ece 100644 --- a/jieba/analyse/textrank.py +++ b/jieba/analyse/textrank.py @@ -1,9 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import jieba.posseg as pseg -import collections import sys +import collections +from operator import itemgetter +import jieba.posseg as pseg class UndirectWeightedGraph: d = 0.85 @@ -41,17 +42,25 @@ class UndirectWeightedGraph: max_rank = w for n, w in ws.items(): - ws[n] = (w - min_rank / 10.0) / (max_rank - min_rank / 10.0) * 100 + # to unify the weights, don't *100. + ws[n] = (w - min_rank / 10.0) / (max_rank - min_rank / 10.0) return ws -def textrank(raw, topk=10): +def textrank(sentence, topK=10, withWeight=False): + """ + Extract keywords from sentence using TextRank algorithm. + Parameter: + - topK: return how many top keywords. `None` for all possible words. + - withWeight: if True, return a list of (word, weight); + if False, return a list of words. + """ pos_filt = frozenset(('ns', 'n', 'vn', 'v')) g = UndirectWeightedGraph() cm = collections.defaultdict(int) span = 5 - words = [x for x in pseg.cut(raw)] + words = list(pseg.cut(sentence)) for i in xrange(len(words)): if words[i].flag in pos_filt: for j in xrange(i + 1, i + span): @@ -65,10 +74,16 @@ def textrank(raw, topk=10): g.addEdge(terms[0], terms[1], w) nodes_rank = g.rank() - nrs = sorted(nodes_rank.items(), key=lambda x: x[1], reverse=True) - return nrs[:topk] + if withWeight: + tags = sorted(nodes_rank.items(), key=itemgetter(1), reverse=True) + else: + tags = sorted(nodes_rank, key=nodes_rank.__getitem__, reverse=True) + if topK: + return tags[:topK] + else: + return tags if __name__ == '__main__': s = "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。" - for x, w in textrank(s): + for x, w in textrank(s, withWeight=True): print x, w diff --git a/jieba/finalseg/__init__.py b/jieba/finalseg/__init__.py index fa47268..3f24e51 100644 --- a/jieba/finalseg/__init__.py +++ b/jieba/finalseg/__init__.py @@ -19,25 +19,22 @@ PrevStatus = { } def load_model(): - _curpath=os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) start_p = {} abs_path = os.path.join(_curpath, PROB_START_P) - with open(abs_path, mode='r') as f: + with open(abs_path, 'rb') as f: start_p = marshal.load(f) - f.closed trans_p = {} abs_path = os.path.join(_curpath, PROB_TRANS_P) - with open(abs_path, 'r') as f: + with open(abs_path, 'rb') as f: trans_p = marshal.load(f) - f.closed emit_p = {} abs_path = os.path.join(_curpath, PROB_EMIT_P) - with open(abs_path, 'r') as f: + with open(abs_path, 'rb') as f: emit_p = marshal.load(f) - f.closed return start_p, trans_p, emit_p diff --git a/jieba/posseg/__init__.py b/jieba/posseg/__init__.py index a048d22..30160d4 100644 --- a/jieba/posseg/__init__.py +++ b/jieba/posseg/__init__.py @@ -25,27 +25,24 @@ def load_model(f_name, isJython=True): continue word, _, tag = line.split(' ') result[word.decode('utf-8')] = tag - f.closed + if not isJython: return result start_p = {} abs_path = os.path.join(_curpath, PROB_START_P) - with open(abs_path, mode='r') as f: + with open(abs_path, 'rb') as f: start_p = marshal.load(f) - f.closed trans_p = {} abs_path = os.path.join(_curpath, PROB_TRANS_P) - with open(abs_path, 'r') as f: + with open(abs_path, 'rb') as f: trans_p = marshal.load(f) - f.closed emit_p = {} abs_path = os.path.join(_curpath, PROB_EMIT_P) - with open(abs_path, 'r') as f: + with open(abs_path, 'rb') as f: emit_p = marshal.load(f) - f.closed state = {} abs_path = os.path.join(_curpath, CHAR_STATE_TAB_P)