-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathtest.py
More file actions
131 lines (93 loc) · 3.53 KB
/
test.py
File metadata and controls
131 lines (93 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# from learner_impl import *
import os

from stclassify.classifier import *
from stclassify.text_process import *
# from bert_serving.client import BertClient
# --- Script configuration -------------------------------------------------
custom_tokenize = None  # optional tokenizer passed to GroceryTextConverter; None -> its default
train_svm_file = None  # placeholder; reassigned below to '%s_train.svm' % name
delimiter='\t'  # NOTE(review): defined here but the calls below pass '\t' literally — confirm intent
name = 'test_model'  # base name used to derive the generated .svm training file
# Inline (label, text) sample pairs kept for quick experiments:
# train_src = [
# ('education', '名师指导托福语法技巧:名词的复数形式'),
# ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
# ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
# ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与')
# ]
# NOTE(review): 'dir' shadows the builtin of the same name; harmless here but
# consider renaming if this script grows.
dir = './data'
# Alternative train/test splits kept for convenience; exactly one pair active.
# train_src = os.path.join(dir, 'train_amber')
# test_src = os.path.join(dir, 'test_amber')
train_src = os.path.join(dir, 'insurance_train')
test_src = os.path.join(dir, 'insurance_test')
# train_src = os.path.join(dir, 'train_chs')
# test_src = os.path.join(dir, 'test_chs')
# train_src = os.path.join(dir, 'train_src_youdao')
# test_src = os.path.join(dir, 'test_src')
# Disabled data-augmentation step: when enabled it rewrites train_src/test_src
# into *_exp files where each sentence is expanded into several augmented
# variants via data_expansion(..., num_aug=2). Kept for reference.
'''
data expansion
'''
# from data_expansion import data_expansion
# lines = []
# with open(train_src, mode='r') as file:
# lines_train = file.readlines()
#
# with open(test_src, mode='r') as file:
# lines_test = file.readlines()
#
# train_src = train_src + '_exp'
# test_src = test_src + '_exp'
#
# with open(train_src, mode='w') as file:
# for line in lines_train:
# line = line.split('\t')
# # file.write(line[0] + '\t' + line[1].strip() + '\n')
# for sent in data_expansion(line[1].strip(), num_aug=2):
# file.write(line[0] + '\t' + sent + '\n')
#
# with open(test_src, mode='w') as file:
# for line in lines_test:
# line = line.split('\t')
# # file.write(line[0] + '\t' + line[1].strip() + '\n')
# for sent in data_expansion(line[1].strip(), num_aug=2):
# file.write(line[0] + '\t' + sent + '\n')
# Build the text -> svm-format converter and produce the training file.
text_converter = GroceryTextConverter(custom_tokenize=custom_tokenize)
train_svm_file = '%s_train.svm' % name  # e.g. 'test_model_train.svm'
#
#
# text_converter.convert_text(train_src, output=train_svm_file, delimiter=' ')
# All feature options are currently disabled, so this call is effectively a
# no-op; the flags are kept as a menu of available feature modes.
text_converter.set_text_parameters(
# keywords_mode=True,
# POS_mode=True,
# extend_new_text=True,
# ngram_extend_mode=True
)
# Convert tab-delimited "label\ttext" lines into svm-format features on disk.
text_converter.convert_text(train_src, output=train_svm_file, delimiter='\t')
# Experiment notes below (bare string, intentionally a no-op at runtime):
# per-label accuracy/recall for solver '-s 4' on a large dataset, and the
# observation that '-s 5 -c 1.1' worked best on small data with augmentation.
'''
-s 4 多分类 大数据量
accuracy recall
neg 76.76% 79.43%
pos 94.07% 95.86%
neu 59.28% 41.95%
0.9003548479907931
0.9005466577155462
-N 1 -T 0 , '-s 4 -c 1.' extend enabled
-s 5 -c 1.1 小数据量最好,加数据扩展最好
'''
# Train on the generated svm file. Option strings look liblinear-style
# ('-s' solver, '-c' cost) — see stclassify docs for '-N'/'-T' semantics.
model = train(train_svm_file, '-N 1 -T 0', '-s 4 -c 1.') #4, 5
# Wrap the raw learner together with its converter so predictions can
# re-use the same text preprocessing.
model = GroceryTextModel(text_converter, model)
model.save('sentiment', force=True)  # persist; force=True overwrites any existing 'sentiment'
def load(name):
    """Restore a previously saved GroceryTextModel.

    Builds a fresh converter with default settings, wraps it in an empty
    GroceryTextModel, and loads the persisted state stored under *name*.
    Returns the ready-to-use model.
    """
    converter = GroceryTextConverter()
    restored = GroceryTextModel(converter)
    restored.load(name)
    return restored
# Reload the model from disk to exercise the full save/load round trip.
model = load('sentiment')
# model = load('test')
single_text = '中国高考成绩海外认可 是“狼来了”吗'  # sample sentence for a single prediction (currently unused)
# r = model.predict_text(single_text)
# print(r)
# test_src = preprocess_data('./text_feature_extract/data/eye_shadow/')
# Evaluate on the held-out tab-delimited test file and report per-label metrics.
test_result = GroceryTest(model).test(text_src=test_src,delimiter='\t')
# test_result = GroceryTest(model).test(text_src=test_src,delimiter='\t')
print(test_result.accuracy_labels)
print(test_result.recall_labels)
test_result.show_result()
print(test_result)