99
1010from fastNLP .api .utils import load_url
1111from fastNLP .api .processor import ModelProcessor
12- from reproduction .chinese_word_segment .cws_io .cws_reader import ConllCWSReader
13- from reproduction .pos_tag_model .pos_reader import ZhConllPOSReader
14- from reproduction .Biaffine_parser .util import ConllxDataLoader , add_seg_tag
12+ from fastNLP .io .dataset_loader import ConllCWSReader , ConllxDataLoader
1513from fastNLP .core .instance import Instance
1614from fastNLP .api .pipeline import Pipeline
1715from fastNLP .core .metrics import SpanFPreRecMetric
1816from fastNLP .api .processor import IndexerProcessor
1917
2018# TODO add pretrain urls
2119model_urls = {
22- "cws" : "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899 .pkl" ,
23- "pos" : "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5 .pkl" ,
24- "parser" : "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c .pkl"
20+ "cws" : "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656 .pkl" ,
21+ "pos" : "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435 .pkl" ,
22+ "parser" : "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0 .pkl"
2523}
2624
2725
@@ -31,6 +29,16 @@ def __init__(self):
3129 self ._dict = None
3230
3331 def predict (self , * args , ** kwargs ):
32+ """Do prediction for the given input.
33+ """
34+ raise NotImplementedError
35+
36+ def test (self , file_path ):
37+ """Test performance over the given data set.
38+
39+ :param str file_path:
40+ :return: a dictionary of metric values
41+ """
3442 raise NotImplementedError
3543
3644 def load (self , path , device ):
@@ -69,12 +77,11 @@ def predict(self, content):
6977 if not hasattr (self , "pipeline" ):
7078 raise ValueError ("You have to load model first." )
7179
72- sentence_list = []
80+ sentence_list = content
7381 # 1. 检查sentence的类型
74- if isinstance (content , str ):
75- sentence_list .append (content )
76- elif isinstance (content , list ):
77- sentence_list = content
82+ for sentence in sentence_list :
83+ if not all ((type (obj ) == str for obj in sentence )):
84+ raise ValueError ("Input must be list of list of string." )
7885
7986 # 2. 组建dataset
8087 dataset = DataSet ()
@@ -83,36 +90,28 @@ def predict(self, content):
8390 # 3. 使用pipeline
8491 self .pipeline (dataset )
8592
86- def decode_tags (ins ):
87- pred_tags = ins ["tag" ]
88- chars = ins ["words" ]
89- words = []
90- start_idx = 0
91- for idx , tag in enumerate (pred_tags ):
92- if tag [0 ] == "S" :
93- words .append (chars [start_idx :idx + 1 ] + "/" + tag [2 :])
94- start_idx = idx + 1
95- elif tag [0 ] == "E" :
96- words .append ("" .join (chars [start_idx :idx + 1 ]) + "/" + tag [2 :])
97- start_idx = idx + 1
98- return words
99-
100- dataset .apply (decode_tags , new_field_name = "tag_output" )
101-
102- output = dataset .field_arrays ["tag_output" ].content
93+ def merge_tag (words_list , tags_list ):
94+ rtn = []
95+ for words , tags in zip (words_list , tags_list ):
96+ rtn .append ([w + "/" + t for w , t in zip (words , tags )])
97+ return rtn
98+
99+ output = dataset .field_arrays ["tag" ].content
103100 if isinstance (content , str ):
104101 return output [0 ]
105102 elif isinstance (content , list ):
106- return output
103+ return merge_tag ( content , output )
107104
108105 def test (self , file_path ):
109- test_data = ZhConllPOSReader ().load (file_path )
106+ test_data = ConllxDataLoader ().load (file_path )
110107
111- tag_vocab = self ._dict ["tag_vocab" ]
112- pipeline = self ._dict ["pipeline" ]
108+ save_dict = self ._dict
109+ tag_vocab = save_dict ["tag_vocab" ]
110+ pipeline = save_dict ["pipeline" ]
113111 index_tag = IndexerProcessor (vocab = tag_vocab , field_name = "tag" , new_added_field_name = "truth" , is_input = False )
114112 pipeline .pipeline = [index_tag ] + pipeline .pipeline
115113
114+ test_data .rename_field ("pos_tags" , "tag" )
116115 pipeline (test_data )
117116 test_data .set_target ("truth" )
118117 prediction = test_data .field_arrays ["predict" ].content
@@ -226,7 +225,7 @@ def test(self, filepath):
226225 rec = eval_res ['BMESF1PreRecMetric' ]['rec' ]
227226 # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
228227
229- return f1 , pre , rec
228+ return { "F1" : f1 , "precision" : pre , "recall" : rec }
230229
231230
232231class Parser (API ):
@@ -251,6 +250,7 @@ def predict(self, content):
251250 dataset .add_field ('wp' , pos_out )
252251 dataset .apply (lambda x : ['<BOS>' ] + [w .split ('/' )[0 ] for w in x ['wp' ]], new_field_name = 'words' )
253252 dataset .apply (lambda x : ['<BOS>' ] + [w .split ('/' )[1 ] for w in x ['wp' ]], new_field_name = 'pos' )
253+ dataset .rename_field ("words" , "raw_words" )
254254
255255 # 3. 使用pipeline
256256 self .pipeline (dataset )
@@ -260,39 +260,82 @@ def predict(self, content):
260260 # output like: [['2/top', '0/root', '4/nn', '2/dep']]
261261 return dataset .field_arrays ['output' ].content
262262
263- def test (self , filepath ):
264- data = ConllxDataLoader ().load (filepath )
265- ds = DataSet ()
266- for ins1 , ins2 in zip (add_seg_tag (data ), data ):
267- ds .append (Instance (words = ins1 [0 ], tag = ins1 [1 ],
268- gold_words = ins2 [0 ], gold_pos = ins2 [1 ],
269- gold_heads = ins2 [2 ], gold_head_tags = ins2 [3 ]))
263+ def load_test_file (self , path ):
264+ def get_one (sample ):
265+ sample = list (map (list , zip (* sample )))
266+ if len (sample ) == 0 :
267+ return None
268+ for w in sample [7 ]:
269+ if w == '_' :
270+ print ('Error Sample {}' .format (sample ))
271+ return None
272+ # return word_seq, pos_seq, head_seq, head_tag_seq
273+ return sample [1 ], sample [3 ], list (map (int , sample [6 ])), sample [7 ]
274+
275+ datalist = []
276+ with open (path , 'r' , encoding = 'utf-8' ) as f :
277+ sample = []
278+ for line in f :
279+ if line .startswith ('\n ' ):
280+ datalist .append (sample )
281+ sample = []
282+ elif line .startswith ('#' ):
283+ continue
284+ else :
285+ sample .append (line .split ('\t ' ))
286+ if len (sample ) > 0 :
287+ datalist .append (sample )
288+
289+ data = [get_one (sample ) for sample in datalist ]
290+ data_list = list (filter (lambda x : x is not None , data ))
291+ return data_list
270292
293+ def test (self , filepath ):
294+ data = self .load_test_file (filepath )
295+
296+ def convert (data ):
297+ BOS = '<BOS>'
298+ dataset = DataSet ()
299+ for sample in data :
300+ word_seq = [BOS ] + sample [0 ]
301+ pos_seq = [BOS ] + sample [1 ]
302+ heads = [0 ] + sample [2 ]
303+ head_tags = [BOS ] + sample [3 ]
304+ dataset .append (Instance (raw_words = word_seq ,
305+ pos = pos_seq ,
306+ gold_heads = heads ,
307+ arc_true = heads ,
308+ tags = head_tags ))
309+ return dataset
310+
311+ ds = convert (data )
271312 pp = self .pipeline
272313 for p in pp :
273314 if p .field_name == 'word_list' :
274315 p .field_name = 'gold_words'
275316 elif p .field_name == 'pos_list' :
276317 p .field_name = 'gold_pos'
318+ # ds.rename_field("words", "raw_words")
319+ # ds.rename_field("tag", "pos")
277320 pp (ds )
278321 head_cor , label_cor , total = 0 , 0 , 0
279322 for ins in ds :
280323 head_gold = ins ['gold_heads' ]
281- head_pred = ins ['heads ' ]
324+ head_pred = ins ['arc_pred ' ]
282325 length = len (head_gold )
283326 total += length
284327 for i in range (length ):
285328 head_cor += 1 if head_pred [i ] == head_gold [i ] else 0
286329 uas = head_cor / total
287- print ('uas:{:.2f}' .format (uas ))
330+ # print('uas:{:.2f}'.format(uas))
288331
289332 for p in pp :
290333 if p .field_name == 'gold_words' :
291334 p .field_name = 'word_list'
292335 elif p .field_name == 'gold_pos' :
293336 p .field_name = 'pos_list'
294337
295- return uas
338+ return { "USA" : round ( uas , 5 )}
296339
297340
298341class Analyzer :
0 commit comments