Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 84 additions & 107 deletions nemo/collections/asr/metrics/der.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from itertools import permutations
from typing import Dict, List, Optional, Tuple

import editdistance
import numpy as np
import pandas as pd
import torch
from pyannote.core import Segment, Timeline
from pyannote.metrics.diarization import DiarizationErrorRate
from scipy.optimize import linear_sum_assignment as scipy_linear_sum_assignment

from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment
from nemo.utils import logging

__all__ = [
Expand Down Expand Up @@ -119,7 +117,7 @@ def uem_timeline_from_file(uem_file, uniq_name=''):
UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME
"""
timeline = Timeline(uri=uniq_name)
with open(uem_file, 'r') as f:
with open(uem_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
for line in lines:
line = line.strip()
Expand Down Expand Up @@ -275,19 +273,19 @@ def evaluate_der(audio_rttm_map_dict, all_reference, all_hypothesis, diar_eval_m

def calculate_session_cpWER_bruteforce(spk_hypothesis: List[str], spk_reference: List[str]) -> Tuple[float, str, str]:
    """
    Calculate cpWER with brute-force permutation search. Matches MeetEval's cpWER algorithm:
    each (ref_speaker, hyp_speaker) pair is scored independently via edit distance, then
    cpWER = sum(errors) / sum(ref_word_counts).

    NOTE: this enumerates all max(num_hyp, num_ref)! speaker assignments, so it is only
    practical for small speaker counts. `calculate_session_cpWER` solves the same
    optimization in polynomial time via the Hungarian algorithm; this function exists as
    a reference implementation for testing.

    Args:
        spk_hypothesis (list):
            List containing the hypothesis transcript for each speaker.

            Example:
            >>> spk_hypothesis = ["hey how are you we that's nice", "i'm good yes hi is your sister"]

        spk_reference (list):
            List containing the reference transcript for each speaker.

            Example:
            >>> spk_reference = ["hi how are you well that's nice", "i'm good yeah how is your sister"]

    Returns:
        cpWER (float):
            cpWER value for the given session.
        min_perm_hyp_trans (str):
            Hypothesis transcripts concatenated in the permutation that yields the lowest
            total edit distance. Words are separated by spaces.
        ref_trans (str):
            Reference transcript in an arbitrary permutation. Words are separated by spaces.
    """
    num_hyp = len(spk_hypothesis)
    num_ref = len(spk_reference)

    # Degenerate session with no speakers on either side is defined as zero error,
    # consistent with `calculate_session_cpWER`.
    if num_hyp == 0 and num_ref == 0:
        return 0.0, "", ""

    # Pad the smaller side with empty word lists so every speaker on one side is
    # matched against exactly one speaker on the other side.
    num_speakers_padded = max(num_hyp, num_ref)

    ref_word_lists = [
        spk_reference[ref_idx].split() if ref_idx < num_ref else [] for ref_idx in range(num_speakers_padded)
    ]
    hyp_word_lists = [
        spk_hypothesis[hyp_idx].split() if hyp_idx < num_hyp else [] for hyp_idx in range(num_speakers_padded)
    ]

    best_total_errors = float('inf')
    best_hyp_trans = ""
    total_ref_length = sum(len(word_list) for word_list in ref_word_lists)

    # Try every assignment of hypothesis speakers to reference speakers and keep the
    # one with the lowest summed per-pair edit distance.
    for perm in permutations(range(num_speakers_padded)):
        total_errors = 0
        hyp_texts = []
        for ref_idx, hyp_idx in enumerate(perm):
            total_errors += editdistance.eval(ref_word_lists[ref_idx], hyp_word_lists[hyp_idx])
            # Padded (nonexistent) hypothesis speakers contribute an empty transcript.
            hyp_texts.append(spk_hypothesis[hyp_idx] if hyp_idx < num_hyp else "")
        if total_errors < best_total_errors:
            best_total_errors = total_errors
            best_hyp_trans = " ".join(hyp_texts)

    # An empty reference with a non-empty hypothesis has undefined WER; report inf.
    cpWER = best_total_errors / total_ref_length if total_ref_length > 0 else float('inf')
    ref_trans = " ".join(spk_reference)
    return cpWER, best_hyp_trans, ref_trans


def calculate_session_cpWER(spk_hypothesis: List[str], spk_reference: List[str]) -> Tuple[float, str, str]:
    """
    Calculate a session-level concatenated minimum-permutation word error rate (cpWER) value,
    matching MeetEval's cpWER algorithm (https://github.com/fgnt/meeteval).

    cpWER is a scoring method that evaluates speaker diarization and speech recognition
    performance at the same time. It was proposed in the following article:
    CHiME-6 Challenge: Tackling Multispeaker Speech Recognition for Unsegmented Recordings
    https://arxiv.org/pdf/2004.09249.pdf

    Algorithm (identical to MeetEval):
        1. Build a square cost matrix of size max(num_hyp, num_ref) using raw edit distance
           counts between every (ref_speaker, hyp_speaker) pair. Missing speakers are padded
           with empty word lists.
        2. Use the Hungarian algorithm (scipy.optimize.linear_sum_assignment) to find the
           speaker assignment that minimizes total edit distance.
        3. Compute per-pair edit distance independently for the optimal assignment.
        4. cpWER = sum(errors_per_pair) / sum(ref_word_counts_per_pair).

    Args:
        spk_hypothesis (list):
            List containing the hypothesis transcript for each speaker.

            Example:
            >>> spk_hypothesis = ["hey how are you we that's nice", "i'm good yes hi is your sister"]

        spk_reference (list):
            List containing the reference transcript for each speaker.

            Example:
            >>> spk_reference = ["hi how are you well that's nice", "i'm good yeah how is your sister"]

    Returns:
        cpWER (float):
            cpWER value for the given session.
        min_perm_hyp_trans (str):
            Hypothesis transcripts concatenated in the optimal-assignment order.
            Words are separated by spaces.
        ref_trans (str):
            Reference transcript in an arbitrary permutation. Words are separated by spaces.
    """
    num_hyp = len(spk_hypothesis)
    num_ref = len(spk_reference)

    # Degenerate session with no speakers on either side is defined as zero error.
    if num_hyp == 0 and num_ref == 0:
        return 0.0, "", ""

    # Pad the smaller side with empty word lists so the cost matrix is square, as
    # required by the linear sum assignment formulation.
    num_speakers_padded = max(num_hyp, num_ref)

    ref_word_lists = [
        spk_reference[ref_idx].split() if ref_idx < num_ref else [] for ref_idx in range(num_speakers_padded)
    ]
    hyp_word_lists = [
        spk_hypothesis[hyp_idx].split() if hyp_idx < num_hyp else [] for hyp_idx in range(num_speakers_padded)
    ]

    # cost_matrix[ref_idx, hyp_idx] holds the raw word-level edit distance between the
    # two speakers' transcripts (counts, not rates, so the assignment minimizes total errors).
    cost_matrix = np.zeros((num_speakers_padded, num_speakers_padded), dtype=np.float64)
    for ref_idx in range(num_speakers_padded):
        for hyp_idx in range(num_speakers_padded):
            cost_matrix[ref_idx, hyp_idx] = editdistance.eval(ref_word_lists[ref_idx], hyp_word_lists[hyp_idx])

    # Hungarian algorithm: optimal speaker assignment in O(n^3).
    row_ind, col_ind = scipy_linear_sum_assignment(cost_matrix)

    total_errors = 0
    total_ref_length = 0
    hyp_texts = []
    for ref_idx, hyp_idx in zip(row_ind, col_ind):
        total_errors += int(cost_matrix[ref_idx, hyp_idx])
        total_ref_length += len(ref_word_lists[ref_idx])
        # Padded (nonexistent) hypothesis speakers contribute an empty transcript.
        hyp_texts.append(spk_hypothesis[hyp_idx] if hyp_idx < num_hyp else "")

    # An empty reference with a non-empty hypothesis has undefined WER; report inf.
    cpWER = total_errors / total_ref_length if total_ref_length > 0 else float('inf')

    min_perm_hyp_trans = " ".join(hyp_texts)
    ref_trans = " ".join(spk_reference)

    return cpWER, min_perm_hyp_trans, ref_trans

Expand Down
114 changes: 3 additions & 111 deletions tests/collections/speaker_tasks/test_diar_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,121 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import permutations

import pytest
import torch

from nemo.collections.asr.metrics.der import (
calculate_session_cpWER,
calculate_session_cpWER_bruteforce,
get_online_DER_stats,
get_partial_ref_labels,
)


def word_count(spk_transcript):
    """Return the total number of words across all speaker transcripts."""
    return sum(len(transcript.split()) for transcript in spk_transcript)


def calculate_wer_count(_ins, _del, _sub, ref_word_count):
    """Compute a WER value from insertion/deletion/substitution counts and reference length."""
    total_errors = _ins + _del + _sub
    return total_errors / ref_word_count


def permuted_input_test(hyp, ref, calculated):
    """
    Verify that the cpWER value is invariant under any permutation of the
    hypothesis speaker list.
    """
    for permuted_hyp in permutations(hyp):
        session_cpwer, _, _ = calculate_session_cpWER(spk_hypothesis=permuted_hyp, spk_reference=ref)
        deviation = torch.abs(torch.tensor(calculated - session_cpwer))
        assert deviation <= 1e-6


class TestConcatMinPermWordErrorRate:
"""
Tests for cpWER calculation.
"""

@pytest.mark.unit
def test_cpwer_oneword(self):
hyp = ["oneword"]
ref = ["oneword"]
_ins, _del, _sub = 0, 0, 0
cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref)
ref_word_count = word_count(ref)
calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count)
diff = torch.abs(torch.tensor(calculated - cpWER))
assert diff <= 1e-6
permuted_input_test(hyp, ref, calculated)
cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref)
diff = torch.abs(torch.tensor(cpWER_perm - cpWER))
assert diff <= 1e-6

# Test with a substitution
hyp = ["wrongword"]
_ins, _del, _sub = 0, 0, 1
cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref)
calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count)
diff = torch.abs(torch.tensor(calculated - cpWER))
assert diff <= 1e-6
permuted_input_test(hyp, ref, calculated)
cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref)
diff = torch.abs(torch.tensor(cpWER_perm - cpWER))
assert diff <= 1e-6

@pytest.mark.unit
def test_cpwer_perfect(self):
hyp = ["ff", "aa bb cc", "dd ee"]
ref = ["aa bb cc", "dd ee", "ff"]
cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref)
calculated = 0
diff = torch.abs(torch.tensor(calculated - cpWER))
assert diff <= 1e-6
permuted_input_test(hyp, ref, calculated)

@pytest.mark.unit
def test_cpwer_spk_counfusion_and_asr_error(self):
hyp = ["aa bb c ff", "dd e ii jj kk", "hi"]
ref = ["aa bb cc ff", "dd ee gg jj kk", "hh ii"]
_ins, _del, _sub = 0, 1, 4
cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref)
ref_word_count = word_count(ref)
calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count)
diff = torch.abs(torch.tensor(calculated - cpWER))
assert diff <= 1e-6
permuted_input_test(hyp, ref, calculated)
cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref)
diff = torch.abs(torch.tensor(cpWER_perm - cpWER))
assert diff <= 1e-6
from nemo.collections.asr.metrics.der import get_online_DER_stats, get_partial_ref_labels

@pytest.mark.unit
def test_cpwer_undercount(self):
hyp = ["aa bb cc", "dd ee gg", "hh ii", "jj kk"]
ref = ["aa bb cc", "dd ee", "ff", "gg", "hh ii", "jj kk"]
_ins, _del, _sub = 0, 1, 0
cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref)
ref_word_count = word_count(ref)
calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count)
diff = torch.abs(torch.tensor(calculated - cpWER))
assert diff <= 1e-6
cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref)
diff = torch.abs(torch.tensor(cpWER_perm - cpWER))
assert diff <= 1e-6

@pytest.mark.unit
def test_cpwer_overcount(self):
hyp = ["aa bb cc", "dd ee gg hh", "ii jj kk"]
ref = ["aa bb cc", "dd ee ff gg hh ii jj kk"]
_ins, _del, _sub = 0, 1, 0
cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref)
ref_word_count = word_count(ref)
calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count)
diff = torch.abs(torch.tensor(calculated - cpWER))
assert diff <= 1e-6
cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref)
diff = torch.abs(torch.tensor(cpWER_perm - cpWER))
assert diff <= 1e-6
class TestDiarMetrics:
"""Tests for DER-related utility functions (cpWER tests are in test_cpwer.py)."""

@pytest.mark.parametrize(
"pred_labels, ref_labels, expected_output",
Expand Down
Loading
Loading