diff --git a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc index b1f0fa25..2b409e36 100644 --- a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc @@ -144,6 +144,7 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContextInternal(La case TransitionType::BLANK_TO_LABEL: case TransitionType::LABEL_TO_LABEL: case TransitionType::INITIAL_LABEL: + case TransitionType::SENTENCE_END: pushToken = true; timeIncrement = not verticalLabelTransition_; break; diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index d23a5c76..2c681890 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -93,6 +93,7 @@ public: WORD_EXIT, NONWORD_EXIT, SILENCE_EXIT, + SENTENCE_END, numTypes, // must remain at the end }; @@ -173,6 +174,7 @@ protected: {"word-exit", WORD_EXIT}, {"nonword-exit", NONWORD_EXIT}, {"silence-exit", SILENCE_EXIT}, + {"sentence-end", SENTENCE_END}, }); static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values"); diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc index 324c5ae1..29babbdc 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc @@ -233,6 +233,7 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContextI case LabelScorer::TransitionType::BLANK_TO_LABEL: case LabelScorer::TransitionType::LABEL_TO_LABEL: case LabelScorer::TransitionType::INITIAL_LABEL: + case LabelScorer::TransitionType::SENTENCE_END: updateState = true; break; default: diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 644532ef..5b3653df 100644 --- 
a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -38,7 +38,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() : scoringContext(), currentToken(Nn::invalidLabelIndex), score(0.0), - trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {} + trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))), + reachedSentenceEnd(false) {} LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( LexiconfreeTimesyncBeamSearch::LabelHypothesis const& base, @@ -47,7 +48,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( : scoringContext(newScoringContext), currentToken(extension.nextToken), score(extension.score), - trace() { + trace(), + reachedSentenceEnd(base.reachedSentenceEnd or extension.transitionType == Nn::LabelScorer::SENTENCE_END) { Core::Ref predecessor; switch (extension.transitionType) { case Nn::LabelScorer::TransitionType::LABEL_LOOP: @@ -101,6 +103,21 @@ const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramBlankLabelIndex( "Index of the blank label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='blank'`. If not set, the search will not use blank.", Nn::invalidLabelIndex); +const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramSentenceEndLabelIndex( + "sentence-end-label-index", + "Index of the sentence end label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='sentence-end'` or `special='sentence-boundary'`. 
If not set, the search will not use sentence end.", + Nn::invalidLabelIndex); + +const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramAllowBlankAfterSentenceEnd( + "allow-blank-after-sentence-end", + "blanks can still be produced after the sentence-end has been reached", + true); + +const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramSentenceEndFallBack( + "sentence-end-fall-back", + "Allow for fallback solution if no active word-end hypothesis exists at the end of a segment.", + true); + const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramCollapseRepeatedLabels( "collapse-repeated-labels", "Collapse repeated emission of the same label into one output. If false, every emission is treated like a new output.", @@ -121,7 +138,11 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration SearchAlgorithmV2(config), maxBeamSize_(paramMaxBeamSize(config)), scoreThreshold_(paramScoreThreshold(config)), + sentenceEndFallback_(paramSentenceEndFallBack(config)), blankLabelIndex_(paramBlankLabelIndex(config)), + allowBlankAfterSentenceEnd_(paramAllowBlankAfterSentenceEnd(config)), + sentenceEndLemma_(), + sentenceEndLabelIndex_(paramSentenceEndLabelIndex(config)), collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)), logStepwiseStatistics_(paramLogStepwiseStatistics(config)), cacheCleanupInterval_(paramCacheCleanupInterval(config)), @@ -148,6 +169,12 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration if (useBlank_) { log() << "Use blank label with index " << blankLabelIndex_; } + + useSentenceEnd_ = sentenceEndLabelIndex_ != Nn::invalidLabelIndex; + if (useSentenceEnd_) { + log() << "Use sentence end label with index " << sentenceEndLabelIndex_; + } + useScorePruning_ = scoreThreshold_ != Core::Type::max; } @@ -174,6 +201,21 @@ bool LexiconfreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination } } + sentenceEndLemma_ = lexicon_->specialLemma("sentence-end"); + if 
(!sentenceEndLemma_) { + sentenceEndLemma_ = lexicon_->specialLemma("sentence-boundary"); + } + if (sentenceEndLemma_) { + if (sentenceEndLabelIndex_ == Nn::invalidLabelIndex) { + sentenceEndLabelIndex_ = sentenceEndLemma_->id(); + useSentenceEnd_ = true; + log() << "Use sentence-end index " << sentenceEndLabelIndex_ << " inferred from lexicon"; + } + else if (sentenceEndLabelIndex_ != static_cast<Nn::LabelIndex>(sentenceEndLemma_->id())) { + warning() << "SentenceEnd lemma exists in lexicon with id " << sentenceEndLemma_->id() << " but is overwritten by config parameter with value " << sentenceEndLabelIndex_; + } + } + reset(); return true; } @@ -208,6 +250,7 @@ void LexiconfreeTimesyncBeamSearch::finishSegment() { labelScorer_->signalNoMoreFeatures(); featureProcessingTime_.stop(); decodeManySteps(); + finalizeHypotheses(); logStatistics(); finishedSegment_ = true; } @@ -283,6 +326,14 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() { const Bliss::Lemma* lemma(*lemmaIt); Nn::LabelIndex tokenIdx = lemma->id(); + // After the first sentence-end token, only allow looping on that sentence-end or emitting blanks afterwards + if (hyp.reachedSentenceEnd and + not( + (collapseRepeatedLabels_ and hyp.currentToken == sentenceEndLabelIndex_ and tokenIdx == sentenceEndLabelIndex_) // sentence-end-loop + or (allowBlankAfterSentenceEnd_ and tokenIdx == blankLabelIndex_))) { // blank + continue; + } + auto transitionType = inferTransitionType(hyp.currentToken, tokenIdx); extensions_.push_back( @@ -433,13 +484,17 @@ void LexiconfreeTimesyncBeamSearch::logStatistics() const { } Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const { - bool prevIsBlank = (useBlank_ and prevLabel == blankLabelIndex_); - bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_); + bool prevIsBlank = (useBlank_ and prevLabel == blankLabelIndex_); + bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_); + bool nextIsSentenceEnd 
= (useSentenceEnd_ and nextLabel == sentenceEndLabelIndex_); if (prevLabel == Nn::invalidLabelIndex) { if (nextIsBlank) { return Nn::LabelScorer::TransitionType::INITIAL_BLANK; } + else if (nextIsSentenceEnd) { + return Nn::LabelScorer::TransitionType::SENTENCE_END; + } else { return Nn::LabelScorer::TransitionType::INITIAL_LABEL; } @@ -449,6 +504,9 @@ Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionTy if (nextIsBlank) { return Nn::LabelScorer::TransitionType::BLANK_LOOP; } + else if (nextIsSentenceEnd) { + return Nn::LabelScorer::TransitionType::SENTENCE_END; + } else { return Nn::LabelScorer::TransitionType::BLANK_TO_LABEL; } @@ -460,6 +518,9 @@ Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionTy else if (collapseRepeatedLabels_ and prevLabel == nextLabel) { return Nn::LabelScorer::TransitionType::LABEL_LOOP; } + else if (nextIsSentenceEnd) { + return Nn::LabelScorer::TransitionType::SENTENCE_END; + } else { return Nn::LabelScorer::TransitionType::LABEL_TO_LABEL; } @@ -549,4 +610,36 @@ void LexiconfreeTimesyncBeamSearch::recombination(std::vectortime = beam_.front().trace->time; // Retrieve the timeframe from any hyp in the old beam + newBeam_.front().trace->pronunciation = nullptr; + newBeam_.front().trace->predecessor = Core::ref(new LatticeTrace(0, {0, 0}, {})); + newBeam_.front().reachedSentenceEnd = true; + beam_.swap(newBeam_); + } + } + else { + newBeam_.swap(beam_); + } +} + } // namespace Search diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh index 9a46851e..9eb794fa 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -43,6 +43,9 @@ public: static const Core::ParameterInt paramMaxBeamSize; static const Core::ParameterFloat paramScoreThreshold; static const 
Core::ParameterInt paramBlankLabelIndex; + static const Core::ParameterInt paramSentenceEndLabelIndex; + static const Core::ParameterBool paramAllowBlankAfterSentenceEnd; + static const Core::ParameterBool paramSentenceEndFallBack; static const Core::ParameterBool paramCollapseRepeatedLabels; static const Core::ParameterBool paramCacheCleanupInterval; static const Core::ParameterBool paramLogStepwiseStatistics; @@ -87,10 +90,11 @@ protected: * Struct containing all information about a single hypothesis in the beam */ struct LabelHypothesis { - Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis - Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) - Score score; // Full score of hypothesis - Core::Ref trace; // Associated trace for traceback or lattice building off of hypothesis + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + Score score; // Full score of hypothesis + Core::Ref trace; // Associated trace for traceback or lattice building off of hypothesis + bool reachedSentenceEnd; // Flag whether hypothesis trace contains a sentence end emission LabelHypothesis(); LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); @@ -111,8 +115,15 @@ private: bool useScorePruning_; Score scoreThreshold_; + bool sentenceEndFallback_; + bool useBlank_; Nn::LabelIndex blankLabelIndex_; + bool allowBlankAfterSentenceEnd_; + + bool useSentenceEnd_; + Bliss::Lemma const* sentenceEndLemma_; + Nn::LabelIndex sentenceEndLabelIndex_; bool collapseRepeatedLabels_; @@ -171,6 +182,12 @@ private: * Helper function for recombination of hypotheses with the same scoring context */ void recombination(std::vector& hypotheses); + + /* + * Prune away all 
hypotheses that have not reached sentence end. + * If no hypotheses would survive this, either construct an empty one or keep the beam intact if sentence-end fallback is enabled. + */ + void finalizeHypotheses(); }; } // namespace Search diff --git a/src/Search/PersistentStateTree.cc b/src/Search/PersistentStateTree.cc index 9948605c..698ddccb 100644 --- a/src/Search/PersistentStateTree.cc +++ b/src/Search/PersistentStateTree.cc @@ -36,7 +36,7 @@ static const Core::ParameterString paramCacheArchive( "cache archive in which the persistent state-network should be cached", "global-cache"); -static u32 formatVersion = 13; +static u32 formatVersion = 14; namespace Search { struct ConvertTree { @@ -273,7 +273,7 @@ void PersistentStateTree::write(Core::MappedArchiveWriter out) { out << coarticulatedRootStates << unpushedCoarticulatedRootStates; out << rootTransitDescriptions << pushedWordEndNodes << uncoarticulatedWordEndStates; - out << rootState << ciRootState << otherRootStates; + out << rootState << ciRootState << otherRootStates << finalStates; } bool PersistentStateTree::read(Core::MappedArchiveReader in) { @@ -282,8 +282,8 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) { /// @todo Eventually do memory-mapping - if (v != formatVersion) { - Core::Application::us()->log() << "Wrong compressed network format, need " << formatVersion << " got " << v; + if (v < 13) { + Core::Application::us()->log() << "Wrong compressed network format, need version >= 13 got " << v; return false; } @@ -308,6 +308,9 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) { in >> pushedWordEndNodes >> uncoarticulatedWordEndStates; in >> rootState >> ciRootState >> otherRootStates; + if (v >= 14) { + in >> finalStates; + } return in.good(); } diff --git a/src/Search/PersistentStateTree.hh b/src/Search/PersistentStateTree.hh index cda2c9aa..18508ca5 100644 --- a/src/Search/PersistentStateTree.hh +++ b/src/Search/PersistentStateTree.hh @@ -112,6 +112,9 @@ public: // 
Other root nodes (currently used for the wordBoundaryRoot in CtcTreeBuilder) std::set otherRootStates; + // Valid nodes that the search can end in + std::set finalStates; + // The word-end exits std::vector exits; diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index 05915122..bf145cf9 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1284,11 +1284,17 @@ const Core::ParameterBool CtcTreeBuilder::paramForceBlank( "require a blank label between two identical labels (only works if label-loops are disabled)", true); +const Core::ParameterBool CtcTreeBuilder::paramAllowBlankAfterSentenceEnd( + "allow-blank-after-sentence-end", + "blanks can still be produced after the sentence-end has been reached", + true); + CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) : SharedBaseClassTreeBuilder(config, lexicon, acousticModel, network), labelLoop_(paramLabelLoop(config)), blankLoop_(paramBlankLoop(config)), - forceBlank_(paramForceBlank(config)) { + forceBlank_(paramForceBlank(config)), + allowBlankAfterSentenceEnd_(paramAllowBlankAfterSentenceEnd(config)) { auto iters = lexicon.phonemeInventory()->phonemes(); for (auto it = iters.first; it != iters.second; ++it) { require(not(*it)->isContextDependent()); // Context dependent labels are not supported @@ -1310,6 +1316,26 @@ CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& wordBoundaryRoot_ = createRoot(); network_.otherRootStates.insert(wordBoundaryRoot_); } + + // Create a special root for sentence-end + auto sentenceEndLemma = getSentenceEndLemma(); + if (sentenceEndLemma == nullptr or sentenceEndLemma->nPronunciations() == 0) { + if (sentenceEndLemma != nullptr) { + warning() << "Building tree without sentence-end which means it may also not be scored by the LM"; + } + + // If no sentence-end is present, any root state 
is a valid final state + network_.finalStates.insert(network_.rootState); + for (auto const& otherRootState : network_.otherRootStates) { + network_.finalStates.insert(otherRootState); + } + } + else { + // If sentence-end is present, the sink state is the only valid final state + sentenceEndSink_ = createRoot(); + network_.otherRootStates.insert(sentenceEndSink_); + network_.finalStates.insert(sentenceEndSink_); + } } } @@ -1323,15 +1349,19 @@ void CtcTreeBuilder::build() { addWordBoundaryStates(); } + auto sentenceEndLemma = getSentenceEndLemma(); + if (sentenceEndLemma != nullptr and sentenceEndLemma->nPronunciations() != 0) { + addSentenceEndStates(); + } + auto blankLemma = lexicon_.specialLemma("blank"); auto silenceLemma = lexicon_.specialLemma("silence"); auto iters = lexicon_.lemmaPronunciations(); // Iterate over the lemmata and add them to the tree for (auto it = iters.first; it != iters.second; ++it) { - if ((*it)->lemma() == wordBoundaryLemma) { - // The wordBoundaryLemma should be a successor of the wordBoundaryRoot_ - // This is handled separately in addWordBoundaryStates() + if ((*it)->lemma() == wordBoundaryLemma or (*it)->lemma() == sentenceEndLemma) { + // Word-boundary and sentence-end lemmas are handled separately by `addWordBoundaryStates` and `addSentenceEndStates` continue; } @@ -1448,6 +1478,36 @@ void CtcTreeBuilder::addWordBoundaryStates() { } } +void CtcTreeBuilder::addSentenceEndStates() { + auto sentenceEndLemma = getSentenceEndLemma(); + if (sentenceEndLemma == nullptr) { + return; + } + + // Add the sentence-end to the tree, starting from the root. + require(sentenceEndLemma->nPronunciations() == 1); // Sentence-end must have exactly one pronunciation, even if it is empty. + auto const& sentenceEndPron = *sentenceEndLemma->pronunciations().first; + // It may be that sentenceEndLastState == root if the pronunciation has length 0. 
+ StateId sentenceEndLastState = extendPronunciation(network_.rootState, sentenceEndPron.pronunciation()); + verify(sentenceEndLastState != 0); + + addExit(sentenceEndLastState, sentenceEndSink_, sentenceEndPron.id()); + + // Add optional blank after the sentence-end lemma + if (allowBlankAfterSentenceEnd_) { + StateId blankAfter = extendState(sentenceEndSink_, blankDesc_); + addExit(blankAfter, sentenceEndSink_, lexicon_.specialLemma("blank")->id()); + } +} + +Bliss::Lemma const* CtcTreeBuilder::getSentenceEndLemma() const { + auto sentenceEndLemma = lexicon_.specialLemma("sentence-end"); + if (sentenceEndLemma == nullptr) { + sentenceEndLemma = lexicon_.specialLemma("sentence-boundary"); + } + return sentenceEndLemma; +} + // -------------------- RnaTreeBuilder -------------------- const Core::ParameterBool RnaTreeBuilder::paramLabelLoop( diff --git a/src/Search/TreeBuilder.hh b/src/Search/TreeBuilder.hh index f310d4a6..9bc01c66 100644 --- a/src/Search/TreeBuilder.hh +++ b/src/Search/TreeBuilder.hh @@ -274,6 +274,7 @@ public: static const Core::ParameterBool paramLabelLoop; static const Core::ParameterBool paramBlankLoop; static const Core::ParameterBool paramForceBlank; + static const Core::ParameterBool paramAllowBlankAfterSentenceEnd; CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); virtual ~CtcTreeBuilder() = default; @@ -287,8 +288,10 @@ protected: bool labelLoop_; bool blankLoop_; bool forceBlank_; + bool allowBlankAfterSentenceEnd_; StateId wordBoundaryRoot_; + StateId sentenceEndSink_; // Reached after emitting sentence-end with no more outgoing transitions except for blank-looping if `allowBlankAfterSentenceEnd_` is enabled Search::StateTree::StateDesc blankDesc_; Am::AllophoneStateIndex blankAllophoneStateIndex_; @@ -298,6 +301,11 @@ protected: // Build the sub-tree with the word-boundary lemma plus optional blank starting from 
`wordBoundaryRoot_`. void addWordBoundaryStates(); + + // Build the sub-tree with the sentence-end lemma plus optional blank starting from `sentenceEndRoot_`. + void addSentenceEndStates(); + + Bliss::Lemma const* getSentenceEndLemma() const; }; class RnaTreeBuilder : public CtcTreeBuilder { diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 372eb831..7c7b8bd2 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -147,6 +148,8 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config maxWordEndBeamSize_(paramMaxWordEndBeamSize(config)), scoreThreshold_(paramScoreThreshold(config)), wordEndScoreThreshold_(paramWordEndScoreThreshold(config)), + blankLabelIndex_(Nn::invalidLabelIndex), + sentenceEndLabelIndex_(Nn::invalidLabelIndex), cacheCleanupInterval_(paramCacheCleanupInterval(config)), useBlank_(), collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)), @@ -235,6 +238,23 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& useBlank_ = false; } + auto const* sentenceEndLemma = lexicon_->specialLemma("sentence-end"); + if (not sentenceEndLemma) { + sentenceEndLemma = lexicon_->specialLemma("sentence-boundary"); + } + if (sentenceEndLemma and sentenceEndLemma->nPronunciations() != 0 and sentenceEndLemma->pronunciations().first->pronunciation()->length() > 0) { + auto const* pron = sentenceEndLemma->pronunciations().first->pronunciation(); + require(pron->length() == 1); + Am::Allophone allo(acousticModel_->phonology()->allophone(*pron, 0), + Am::Allophone::isInitialPhone | Am::Allophone::isFinalPhone); + Am::AllophoneStateIndex alloStateIdx = acousticModel_->allophoneStateAlphabet()->index(&allo, 0); + + sentenceEndLabelIndex_ = 
acousticModel_->emissionIndex(alloStateIdx); + } + else { + sentenceEndLabelIndex_ = Nn::invalidLabelIndex; + } + for (const auto& lemma : {"silence", "blank"}) { if (lexicon_->specialLemma(lemma) and (lexicon_->specialLemma(lemma)->syntacticTokenSequence()).size() != 0) { warning("Special lemma \"%s\" will be scored by the language model. To prevent the LM from scoring it, set an empty syntactic token sequence for it in the lexicon.", lemma); @@ -288,7 +308,7 @@ void TreeTimesyncBeamSearch::finishSegment() { decodeManySteps(); logStatistics(); finishedSegment_ = true; - finalizeLmScoring(); + finalizeHypotheses(); } void TreeTimesyncBeamSearch::putFeature(Nn::DataView const& feature) { @@ -366,13 +386,15 @@ bool TreeTimesyncBeamSearch::decodeStep() { } auto transitionType = inferTransitionType(hyp.currentToken, tokenIdx); withinWordExtensions_.push_back( - {tokenIdx, - successorState, - 0, - hyp.score, - transitionType, - hypIndex}); - requests_.push_back({beam_[hypIndex].scoringContext, tokenIdx, transitionType}); + {.nextToken = tokenIdx, + .nextState = successorState, + .timeframe = 0, + .score = hyp.score, + .transitionType = transitionType, + .baseHypIndex = hypIndex}); + requests_.push_back({.context = beam_[hypIndex].scoringContext, + .nextToken = tokenIdx, + .transitionType = transitionType}); } } @@ -413,9 +435,9 @@ bool TreeTimesyncBeamSearch::decodeStep() { contextExtensionTime_.start(); auto newScoringContext = labelScorer_->extendedScoringContext( - {baseHyp.scoringContext, - extension.nextToken, - extension.transitionType}); + {.context = baseHyp.scoringContext, + .nextToken = extension.nextToken, + .transitionType = extension.transitionType}); contextExtensionTime_.stop(); newBeam_.push_back({baseHyp, extension, newScoringContext}); @@ -470,7 +492,10 @@ bool TreeTimesyncBeamSearch::decodeStep() { penalty = result->score; } - wordEndExtensions_.push_back({lemmaPron, exit.transitState, hyp.score + lmScore + penalty, hypIndex}); + 
wordEndExtensions_.push_back({.pron = lemmaPron, + .rootState = exit.transitState, + .score = hyp.score + lmScore + penalty, + .baseHypIndex = hypIndex}); } } } @@ -513,6 +538,42 @@ bool TreeTimesyncBeamSearch::decodeStep() { clog() << Core::XmlFull("num-word-end-hyps-after-beam-pruning", wordEndHypotheses_.size()); } + /* + * Take having two exits back-to-back for the word-end hyps into account (usually for sentence-end with zero-length pronunciation after word-end) + */ + auto const origSize = wordEndHypotheses_.size(); + for (size_t hypIndex = 0ul; hypIndex < origSize; ++hypIndex) { + auto& hyp = wordEndHypotheses_[hypIndex]; + + auto exitList = exitLookup_[hyp.currentState]; + // Create one word-end hypothesis for each exit + for (const auto& exit : exitList) { + auto const* lemmaPron = lexicon_->lemmaPronunciation(exit.pronunciation); + auto const* lemma = lemmaPron->lemma(); + + WordEndExtensionCandidate wordEndExtension{.pron = lemmaPron, + .rootState = exit.transitState, // Start from the root node (the exit's transit state) in the next step + .score = hyp.score, + .baseHypIndex = hypIndex}; + + auto const sts = lemma->syntacticTokenSequence(); + auto newLmHistory = hyp.lmHistory; + if (sts.size() != 0) { + require(sts.size() == 1); + auto const* st = sts.front(); + + // Add the LM score + Lm::Score lmScore = languageModel_->score(hyp.lmHistory, st); + wordEndExtension.score += lmScore; + + // Extend the LM history + newLmHistory = languageModel_->extendedHistory(hyp.lmHistory, st); + } + wordEndHypotheses_.push_back({hyp, wordEndExtension, newLmHistory}); + } + } + recombination(wordEndHypotheses_, true); + beam_.swap(newBeam_); beam_.insert(beam_.end(), wordEndHypotheses_.begin(), wordEndHypotheses_.end()); @@ -605,13 +666,17 @@ void TreeTimesyncBeamSearch::logStatistics() const { } Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const { - bool prevIsBlank = (useBlank_ and 
prevLabel == blankLabelIndex_); - bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_); + bool prevIsBlank = (useBlank_ and prevLabel == blankLabelIndex_); + bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_); + bool nextIsSentenceEnd = nextLabel == sentenceEndLabelIndex_; if (prevLabel == Nn::invalidLabelIndex) { if (nextIsBlank) { return Nn::LabelScorer::TransitionType::INITIAL_BLANK; } + else if (nextIsSentenceEnd) { + return Nn::LabelScorer::TransitionType::SENTENCE_END; + } else { return Nn::LabelScorer::TransitionType::INITIAL_LABEL; } @@ -621,6 +686,9 @@ Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn:: if (nextIsBlank) { return Nn::LabelScorer::TransitionType::BLANK_LOOP; } + else if (nextIsSentenceEnd) { + return Nn::LabelScorer::TransitionType::SENTENCE_END; + } else { return Nn::LabelScorer::TransitionType::BLANK_TO_LABEL; } @@ -632,6 +700,9 @@ Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn:: else if (collapseRepeatedLabels_ and prevLabel == nextLabel) { return Nn::LabelScorer::TransitionType::LABEL_LOOP; } + else if (nextIsSentenceEnd) { + return Nn::LabelScorer::TransitionType::SENTENCE_END; + } else { return Nn::LabelScorer::TransitionType::LABEL_TO_LABEL; } @@ -752,30 +823,24 @@ void TreeTimesyncBeamSearch::createSuccessorLookups() { } } -void TreeTimesyncBeamSearch::finalizeLmScoring() { +void TreeTimesyncBeamSearch::finalizeHypotheses() { newBeam_.clear(); - for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { - auto& hyp = beam_[hypIndex]; - // Check if the hypotheses in the beam are at a root state and add the sentence-end LM score - if (hyp.currentState == network_->rootState or network_->otherRootStates.find(hyp.currentState) != network_->otherRootStates.end()) { - Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory); - hyp.score += sentenceEndScore; - hyp.trace->score.lm += sentenceEndScore; + for (auto const& hyp : 
beam_) { + if (network_->finalStates.contains(hyp.currentState)) { newBeam_.push_back(hyp); } } - if (newBeam_.empty()) { // There was no word-end hypothesis in the beam + if (newBeam_.empty()) { // There was no valid final hypothesis in the beam warning("No active word-end hypothesis at segment end."); if (sentenceEndFallback_) { log() << "Use sentence-end fallback"; // The trace of the unfinished word keeps an empty pronunciation, only the LM score is added - for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { - auto& hyp = beam_[hypIndex]; - Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory); - hyp.score += sentenceEndScore; - hyp.trace->score.lm += sentenceEndScore; + for (auto const& hyp : beam_) { newBeam_.push_back(hyp); + Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory); + newBeam_.back().score += sentenceEndScore; + newBeam_.back().trace->score.lm += sentenceEndScore; } } else { diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh index b3a00918..a2c8eaf9 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -136,6 +136,7 @@ private: Score scoreThreshold_; Score wordEndScoreThreshold_; Nn::LabelIndex blankLabelIndex_; + Nn::LabelIndex sentenceEndLabelIndex_; size_t cacheCleanupInterval_; bool useBlank_; @@ -220,10 +221,10 @@ private: /* * After reaching the segment end, go through the active hypotheses, only keep those - * which are at a word end (in the root state) and add the sentence end LM score. - * If no word-end hypotheses exist, use sentence-end fallback or construct an empty hypothesis + * which are final states of the search tree. + * If no such hypotheses exist, use sentence-end fallback or construct an empty hypothesis. */ - void finalizeLmScoring(); + void finalizeHypotheses(); }; } // namespace Search