forked from sleepinyourhat/vector-entailment
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTestModel.m
More file actions
61 lines (49 loc) · 2.5 KB
/
TestModel.m
File metadata and controls
61 lines (49 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
% Want to distribute this code? Have other questions? -> sbowman@stanford.edu
function [combinedAcc, combinedMf1, aggConfusion, combinedConAcc] = TestModel(CostGradFunc, theta, thetaDecoder, testDatasets, separateWordFeatures, hyperParams)
% Test on a collection of test sets.
% NOTE: Currently results are reported in three accuracy figures and three MacroAvgF1 figures:
% 1. Performance on the #1 test set. Used, for instance, with SICK.
% 2. Aggregate performance on the first hyperParams.firstSplit test sets. This is not used much right now.
% 3. Aggregate performance across all of the test sets. Used, for instance, with and/or and quantification.
% TODO: Currently, I only aggregate average test statistics across test datasets that use the no. 1 set
% of labels.
%
% Inputs:
%   CostGradFunc         - function handle evaluating a dataset; returns
%                          (among others) accuracy, connection accuracy, and a
%                          confusion matrix as outputs 4-6.
%   theta, thetaDecoder  - model parameters and their decoder.
%   testDatasets         - cell array: {1} holds dataset names, {2} holds the
%                          datasets themselves.
%   separateWordFeatures - word feature storage passed through to CostGradFunc.
%   hyperParams          - configuration struct (numLabels, showDetailedStats,
%                          examplelog, optional testLabelIndices, ...).
% Outputs:
%   combinedAcc    - [target accuracy, aggregate accuracy].
%   combinedMf1    - [target macro-F1, aggregate macro-F1].
%   aggConfusion   - aggregate confusion matrix over datasets sharing the
%                    target label set.
%   combinedConAcc - [target connection accuracy, mean aggregate connection
%                    accuracy].

% Choose which label set aggregate statistics are computed over
% (defaults to label set 1 when testLabelIndices is absent).
if isfield(hyperParams, 'testLabelIndices')
    targetLabelSet = hyperParams.testLabelIndices(1);
else
    targetLabelSet = 1;
end

aggConfusion = zeros(hyperParams.numLabels(targetLabelSet));
targetConfusion = zeros(hyperParams.numLabels(targetLabelSet));
% Initialize so all outputs are defined even if the first dataset is
% empty (the loop below skips empty datasets before assigning these).
targetConAcc = [];
aggConAcc = [];

for i = 1:length(testDatasets{1})
    % Skip datasets with no examples.
    if isempty(testDatasets{2}{i})
        continue
    end

    [ ~, ~, ~, acc, conAcc, confusion ] = CostGradFunc(theta, thetaDecoder, testDatasets{2}{i}, separateWordFeatures, hyperParams, 0);

    % conAcc(1) == -1 is the sentinel for "no connection accuracy available".
    if conAcc(1) ~= -1
        aggConAcc = [aggConAcc, conAcc];
    end

    % The first dataset is the designated target dataset (statistic no. 1).
    if i == 1
        targetConfusion = confusion;
        targetConAcc = conAcc;
    end

    if hyperParams.showDetailedStats && acc > 0
        Log(hyperParams.examplelog, ['For test data: ', testDatasets{1}{i}, ': ', num2str(acc), ' (', num2str(GetMacroF1(confusion)), ')']);
        % BUGFIX: test the sentinel element, matching the check above.
        % The old `if conAcc ~= -1` compared every element of the conAcc
        % matrix at once (MATLAB `if` requires all elements true), so the
        % log line was wrongly suppressed whenever any element was -1.
        if conAcc(1) ~= -1
            Log(hyperParams.examplelog, ['Connection accuracy: ', num2str(conAcc(1, :)), ' std ', num2str(conAcc(2, :))]);
        end
        conf_msg = sprintf('\n%s', evalc('disp(confusion)'));
        Log(hyperParams.examplelog, conf_msg);
    end

    % Only fold in confusion matrices from datasets that share the
    % target label set; others use incompatible label inventories.
    if (~isfield(hyperParams, 'testLabelIndices') || hyperParams.testLabelIndices(i) == targetLabelSet)
        aggConfusion = aggConfusion + confusion;
    end
end

% Compute accuracy (trace / total count) from the confusion matrices.
targetAcc = sum(sum(eye(hyperParams.numLabels(targetLabelSet)) .* targetConfusion)) / sum(sum(targetConfusion));
aggAcc = sum(sum(eye(hyperParams.numLabels(targetLabelSet)) .* aggConfusion)) / sum(sum(aggConfusion));
% Average connection accuracy across the datasets that reported one.
aggConAcc = mean(aggConAcc, 2);

combinedMf1 = [GetMacroF1(targetConfusion), GetMacroF1(aggConfusion)];
combinedAcc = [targetAcc, aggAcc];
combinedConAcc = [targetConAcc, aggConAcc];
end