forked from thunlp/OpenPrompt
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_text_classification_dataset.py
More file actions
75 lines (63 loc) · 3.06 KB
/
test_text_classification_dataset.py
File metadata and controls
75 lines (63 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os, sys
from os.path import dirname as d
from os.path import abspath, join
# Project root is three directory levels above this test file
# (repo_root/tests/<subdir>/this_file.py).
root_dir = d(d(d(abspath(__file__))))
# Side effect: make the local `openprompt` package importable when this file
# is executed directly as a script (outside an installed environment).
sys.path.append(root_dir)
from openprompt.data_utils.text_classification_dataset import PROCESSORS
# All text-classification datasets are expected to live under
# <repo_root>/datasets/TextClassification/<dataset_name>.
base_path = os.path.join(root_dir, "datasets/TextClassification")
def test_AgnewsProcessor():
    """Smoke-test the agnews processor: labels, split sizes, first example."""
    data_dir = os.path.join(base_path, "agnews")
    proc = PROCESSORS["agnews"]()
    train_examples = proc.get_train_examples(data_dir)
    test_examples = proc.get_test_examples(data_dir)
    # Four topic classes, fixed label order.
    assert proc.get_num_labels() == 4
    assert proc.get_labels() == ["World", "Sports", "Business", "Tech"]
    # Standard AG News split sizes.
    assert len(train_examples) == 120000
    assert len(test_examples) == 7600
    # Spot-check the first test example (headline / body / label id).
    first = test_examples[0]
    assert first.text_a == "Fears for T N pension after talks"
    assert first.text_b == "Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."
    assert first.label == 2
def test_DBpediaProcessor():
    """Smoke-test the dbpedia processor: label count and split sizes."""
    data_dir = os.path.join(base_path, "dbpedia")
    proc = PROCESSORS["dbpedia"]()
    train_examples = proc.get_train_examples(data_dir)
    test_examples = proc.get_test_examples(data_dir)
    # DBpedia ontology classification has 14 classes.
    assert proc.get_num_labels() == 14
    assert len(train_examples) == 560000
    assert len(test_examples) == 70000
def test_ImdbProcessor():
    """Smoke-test the imdb processor: binary labels and split sizes."""
    data_dir = os.path.join(base_path, "imdb")
    proc = PROCESSORS["imdb"]()
    train_examples = proc.get_train_examples(data_dir)
    test_examples = proc.get_test_examples(data_dir)
    # IMDB sentiment is binary with 25k/25k train/test examples.
    assert proc.get_num_labels() == 2
    assert len(train_examples) == 25000
    assert len(test_examples) == 25000
# def test_AmazonProcessor():
# dataset_name = "amazon"
# dataset_path = os.path.join(base_path, dataset_name)
# processor = PROCESSORS[dataset_name.lower()](dataset_path)
# trainvalid_dataset = processor.get_train_examples(dataset_path)
# test_dataset = processor.get_test_examples(dataset_path)
# assert processor.get_num_labels() == 2
# assert len(trainvalid_dataset) == 3600000
# assert len(test_dataset) == 400000
def test_SST2Processor():
    """Smoke-test the SST-2 processor: labels, all three splits, first train example."""
    name = "SST-2"
    data_dir = os.path.join(base_path, name)
    # Registry keys are lower-cased ("sst-2"), while the directory keeps the
    # original casing ("SST-2").
    proc = PROCESSORS[name.lower()]()
    train_examples = proc.get_train_examples(data_dir)
    dev_examples = proc.get_dev_examples(data_dir)
    test_examples = proc.get_test_examples(data_dir)
    # Binary sentiment with integer labels.
    assert proc.get_num_labels() == 2
    assert proc.get_labels() == [0, 1]
    # Standard SST-2 split sizes.
    assert len(train_examples) == 6920
    assert len(dev_examples) == 872
    assert len(test_examples) == 1821
    # Spot-check the first training example.
    first = train_examples[0]
    assert first.text_a == 'a stirring , funny and finally transporting re-imagining of beauty and the beast and 1930s horror films'
    assert first.label == 1
if __name__ == "__main__":
    # Run every dataset smoke test when executed as a script. Previously only
    # test_SST2Processor() was called here, so direct execution silently
    # skipped the agnews, dbpedia, and imdb checks (pytest discovery still
    # collects all of them either way).
    test_AgnewsProcessor()
    test_DBpediaProcessor()
    test_ImdbProcessor()
    test_SST2Processor()