-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLimitedBackoffStrategy.cpp
More file actions
96 lines (74 loc) · 2.21 KB
/
LimitedBackoffStrategy.cpp
File metadata and controls
96 lines (74 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/*
* LimitedBackoffStrategy.cpp
*
* Created on: Dec 16, 2016
* Author: onrust
*/
#include "LimitedBackoffStrategy.h"
#include "Logging.h"
#include "InterpolationStrategy.h"
namespace SLM {
/**
 * Constructs a limited back-off strategy.
 *
 * @param languageModel          forwarded unchanged to the FullBackoffStrategy base
 * @param baseFileName           forwarded unchanged to the FullBackoffStrategy base
 * @param interpolationStrategy  forwarded unchanged to the FullBackoffStrategy base
 * @param ignoreCache            stored in the member of the same name
 *
 * The constructor body itself performs no further work.
 */
LimitedBackoffStrategy::LimitedBackoffStrategy(
    SLM::LanguageModel& languageModel,
    const std::string& baseFileName,
    SLM::InterpolationStrategy* interpolationStrategy,
    bool ignoreCache)
    : FullBackoffStrategy(languageModel, baseFileName, interpolationStrategy)
    , ignoreCache(ignoreCache)
{
}
/**
 * Opens the files used by this strategy, including the cache output file.
 *
 * First delegates to openFiles(), then builds the cache file name as
 * baseFileName + "_" + name() + "." + cacheExtension, logs that name on the
 * L_V log stream, and opens cacheOutputFile on it.
 *
 * @param languageModel  passed through to openFiles()
 * @param baseFileName   passed through to openFiles() and used as the prefix
 *                       of the cache file name
 */
void LimitedBackoffStrategy::init(SLM::LanguageModel& languageModel, const std::string& baseFileName)
{
    openFiles(languageModel, baseFileName);

    // Assemble "<baseFileName>_<name()>.<cacheExtension>" piece by piece.
    std::string cachePath = baseFileName;
    cachePath += "_";
    cachePath += name();
    cachePath += ".";
    cachePath += cacheExtension;

    L_V << "LimitedBackoffStrategy: (" << name() << ")"
        << std::setw(30) << "Cache output file:"
        << cachePath << "\n";

    cacheOutputFile.open(cachePath);
}
/**
 * Flushes the cache output file and then closes it.
 */
LimitedBackoffStrategy::~LimitedBackoffStrategy()
{
    // Flush explicitly before closing the cache output file.
    cacheOutputFile.flush();
    cacheOutputFile.close();
}
/**
 * Name of this strategy.
 *
 * @return the literal prefix "limited" immediately followed by the name
 *         reported by the configured interpolation strategy
 */
std::string LimitedBackoffStrategy::name() const
{
    std::string label("limited");
    label += interpolationStrategy->name();
    return label;
}
/**
 * Computes the base-2 log probability of a focus pattern given its context
 * and updates the per-sentence statistics.
 *
 * For an out-of-vocabulary focus only sentOovs is incremented and the result
 * stays 0.0. Otherwise the probability comes from languageModel.getProbLS4(),
 * its log2 becomes the result, sentCount is incremented and the result is
 * subtracted from sentLLH.
 *
 * In both cases the result is logged on the L_S stream and handed to
 * writeProbToFile() before returning.
 *
 * @param context  the context pattern
 * @param focus    the pattern whose probability is estimated
 * @param isOOV    true when the focus is out of vocabulary
 * @return log2 of the estimated probability, or 0.0 when isOOV is true
 */
double LimitedBackoffStrategy::prob(const Pattern& context, const Pattern& focus, bool isOOV)
{
    L_S << "LimitedBackoffStrategy: Estimating prob for " << languageModel.toString(context)
        << " " << languageModel.toString(focus) << "\n";

    double logProb = 0.0;

    if (isOOV)
    {
        ++sentOovs;
    }
    else
    {
        // Original author's note: implement skipgrams from layer 3 on
        const double probability =
            languageModel.getProbLS4(focus, context, this, interpolationStrategy, normalisationCache);
        logProb = log2(probability);

        ++sentCount;
        sentLLH -= logProb;
    }

    L_S << "LimitedBackoffStrategy: \t\t" << logProb << "\n";

    writeProbToFile(focus, context, logProb, isOOV);
    return logProb;
}
/**
 * Normalisation factor for a context.
 *
 * Placeholder implementation: the context is ignored and 0.0 is always
 * returned.
 *
 * Note kept from a disabled draft of this function (it was never active
 * code): look the context up in a `weights` map (Pattern -> double) and
 * return the stored value when present; otherwise walk
 * languageModel.getVocabulary() and, for each pattern, count it in N
 * "if in crp" or in B "if not in crp", while accumulating the pattern
 * probabilities into a running sum. The draft loop body was left empty.
 *
 * @param context  currently unused
 * @return always 0.0
 */
double LimitedBackoffStrategy::getNormalisationFactor(const Pattern& context)
{
    static_cast<void>(context); // intentionally unused in the placeholder
    return 0.0;
}
} /* namespace SLM */