From 90cd4b30148d1d68123d7715a8e20906bfed651b Mon Sep 17 00:00:00 2001 From: fxsjy Date: Tue, 6 Nov 2012 07:17:26 +0800 Subject: [PATCH] improve POS tagging --- jieba/__init__.py | 8 +++- jieba/posseg/__init__.py | 79 ++++++++++++++++++++++++++++++++++++---- jieba/posseg/viterbi.py | 13 +++++-- setup.py | 2 +- test/test.txt | 2 + test/test_file.py | 20 ++++++++++ test/test_pos_file.py | 20 ++++++++++ 7 files changed, 130 insertions(+), 14 deletions(-) create mode 100644 test/test.txt create mode 100644 test/test_file.py create mode 100644 test/test_pos_file.py diff --git a/jieba/__init__.py b/jieba/__init__.py index c994e47..9131ac2 100644 --- a/jieba/__init__.py +++ b/jieba/__init__.py @@ -82,7 +82,7 @@ def calc(sentence,DAG,idx,route): candidates = [ ( FREQ.get(sentence[idx:x+1],min_freq) * route[x+1][0],x ) for x in DAG[idx] ] route[idx] = max(candidates) -def __cut_DAG(sentence): +def get_DAG(sentence): N = len(sentence) i,j=0,0 p = trie @@ -107,11 +107,15 @@ def __cut_DAG(sentence): for i in xrange(len(sentence)): if not i in DAG: DAG[i] =[i] - #pprint.pprint(DAG) + return DAG + +def __cut_DAG(sentence): + DAG = get_DAG(sentence) route ={} calc(sentence,DAG,0,route=route) x = 0 buf =u'' + N = len(sentence) while x0: + if len(buf)==1: + yield list(__cut(buf))[0] + buf=u'' + else: + regognized = __cut(buf) + for t in regognized: + yield t + buf=u'' + for w in __cut(l_word,tags_limited=True): + yield w + x =y + + if len(buf)>0: + if len(buf)==1: + yield list(__cut(buf))[0] + else: + regognized = __cut(buf) + for t in regognized: + yield t + def cut(sentence): if not ( type(sentence) is unicode): @@ -48,10 +108,15 @@ def cut(sentence): for blk in blocks: if re_han.match(blk): - for word in __cut(blk): + for word in __cut_DAG(blk): yield word else: tmp = re_skip.split(blk) for x in tmp: if x!="": - yield x + if re.match(ur"[0-9]+",x): + yield pair(x,'m') + elif re.match(ur"[a-zA-Z+#]+",x): + yield pair(x,'eng') + else: + yield pair(x,'x') diff --git a/jieba/posseg/viterbi.py b/jieba/posseg/viterbi.py index 91698cc..3943af3 100644 --- a/jieba/posseg/viterbi.py +++ b/jieba/posseg/viterbi.py @@ -5,7 +5,7 @@ def get_top_states(t_state_v,K=4): topK= sorted(items,key=operator.itemgetter(1),reverse=True)[:K] return [x[0] for x in topK] -def viterbi(obs, states, start_p, trans_p, emit_p): +def viterbi(obs, states, start_p, trans_p, emit_p,limit_tags): V = [{}] #tabular mem_path = [{}] all_states = trans_p.keys() @@ -15,19 +15,24 @@ def viterbi(obs, states, start_p, trans_p, emit_p): for t in range(1,len(obs)): V.append({}) mem_path.append({}) - prev_states =[ x for x in mem_path[t-1].keys() if len(trans_p[x])>0 ] - #print get_top_states(V[t-1]) prev_states = get_top_states(V[t-1]) + prev_states =[ x for x in mem_path[t-1].keys() if len(trans_p[x])>0 ] + tmp = prev_states + if limit_tags: + prev_states = [x for x in prev_states if x[0]==limit_tags[t-1]] + if len(prev_states)==0: + prev_states = tmp prev_states_expect_next = set( (y for x in prev_states for y in trans_p[x].keys() ) ) obs_states = states.get(obs[t],all_states) obs_states = set(obs_states) & set(prev_states_expect_next) + if limit_tags: + obs_states = [x for x in obs_states if x[0]==limit_tags[t]] if len(obs_states)==0: obs_states = all_states for y in obs_states: (prob,state ) = max([(V[t-1][y0] * trans_p[y0].get(y,0) * emit_p[y].get(obs[t],0) ,y0) for y0 in prev_states]) V[t][y] =prob mem_path[t][y] = state - last = [(V[-1][y], y) for y in mem_path[-1].keys() ] #if len(last)==0: #print obs diff --git a/setup.py b/setup.py index 1d96397..fa3d0dc 100644 --- a/setup.py +++ b/setup.py @@ -7,5 +7,5 @@ setup(name='jieba', url='http://github.com/fxsjy', packages=['jieba'], package_dir={'jieba':'jieba'}, - package_data={'jieba':['*.*','finalseg/*']} + package_data={'jieba':['*.*','finalseg/*','analyse/*','posseg/*']} ) diff --git a/test/test.txt b/test/test.txt new file mode 100644 index 0000000..cbffc45 --- /dev/null +++ b/test/test.txt @@ -0,0 +1,2 @@ +西三旗硅谷先锋小区半地下室出租,便宜可合租硅谷 +工信处女干事每月经过下属科室都要亲口交代24口交换机等技术性器件的安装工作 \ No newline at end of file diff --git a/test/test_file.py b/test/test_file.py new file mode 100644 index 0000000..2107c36 --- /dev/null +++ b/test/test_file.py @@ -0,0 +1,20 @@ +import urllib2 +import sys,time +import sys +sys.path.append("../") +import jieba + +url = sys.argv[1] +content = open(url,"rb").read() +t1 = time.time() +words = list(jieba.cut(content)) + +t2 = time.time() +tm_cost = t2-t1 + +log_f = open("1.log","wb") +for w in words: + print >> log_f, w.encode("gbk"), "/" , + +print 'speed' , len(content)/tm_cost, " bytes/second" + diff --git a/test/test_pos_file.py b/test/test_pos_file.py new file mode 100644 index 0000000..fd14a2d --- /dev/null +++ b/test/test_pos_file.py @@ -0,0 +1,20 @@ +import urllib2 +import sys,time +import sys +sys.path.append("../") +import jieba.posseg as pseg + +url = sys.argv[1] +content = open(url,"rb").read() +t1 = time.time() +words = list(pseg.cut(content)) + +t2 = time.time() +tm_cost = t2-t1 + +log_f = open("1.log","wb") +for w in words: + print >> log_f, w.encode("gbk"), "/" , + +print 'speed' , len(content)/tm_cost, " bytes/second" +