diff --git a/README.md b/README.md
index f154ec2..584b2b6 100644
--- a/README.md
+++ b/README.md
@@ -269,6 +269,16 @@ tot sacreBLEU docAsAWhole 32.786
avg sacreBLEU mwerSegmenter 25.850
```
+To also calculate the COMET score, additionally pass the source-language OSt file and mark it as ``src`` in the format list, as shown below:
+```
+MTeval -i sample-data/sample.en.cs.mt sample-data/sample.en.OSt sample-data/sample.cs.OSt -f mt src ref
+```
+This adds an extra line to the output reporting the COMET score:
+```
+tot COMET docAsAWhole 0.770
+```
+The COMET evaluation is optional and runs only when the source file is provided.
+
#### Evaluating SLT
Spoken language translation evaluates "machine translation in time". So a time-stamped MT output (``slt``) is compared with the reference translation (non-timed, ``ref``) and the timing of the golden transcript (``ostt``).
@@ -293,6 +303,7 @@ tot Flicker count_changed_content 23
tot sacreBLEU docAsAWhole 32.786
...
```
+As with MTeval, include the source-language OSt file (format ``src``) to also calculate the COMET score.
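+For example, mirroring the MTeval call above (the ``slt`` file name here is illustrative):
+```
+SLTeval -i sample-data/sample.en.cs.slt sample-data/sample.cs.OSt sample-data/sample.en.OStt sample-data/sample.en.OSt -f slt ref ostt src
+```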
#### Evaluating ASR
@@ -392,6 +403,9 @@ Usage:
SLTIndexParser path_to_index_file path_to_dataset
```
+5. Note that a stable internet connection is required when the COMET score is calculated, so that the COMET model can be
+downloaded to the local system.
+
## Terminology and Abbreviations
* OSt ... original speech manually transcribed (i.e. golden transcript)
diff --git a/SLTev/ASReval.py b/SLTev/ASReval.py
index 502d04c..d4c6e8b 100644
--- a/SLTev/ASReval.py
+++ b/SLTev/ASReval.py
@@ -168,7 +168,8 @@ def main(input_files=[], file_formats=[], arguments={}):
'ostt': read_ostt_file(gold_files["ostt"][0]),
'references': read_references(gold_files["ost"]),
'SLTev_home': SLTev_home,
- 'candidate': read_candidate_file(candidate_file[0])
+ 'candidate': read_candidate_file(candidate_file[0]),
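+        # there is no separate source file for ASR, so the source-language OSt also serves as the COMET "src"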
+ 'src': read_references(gold_files["ost"])
}
_ = check_time_stamp_candiates_format(candidate_file[0], split_token) # submission checking
diff --git a/SLTev/MTeval.py b/SLTev/MTeval.py
index 6457c4c..cfe1597 100644
--- a/SLTev/MTeval.py
+++ b/SLTev/MTeval.py
@@ -82,6 +82,7 @@ def main(inputs=[], file_formats=[], arguments={}):
evaluation_object = {
'references': read_references(gold_files["ref"]),
'mt': read_candidate_file(candidate_file[0]),
+ 'src': read_references(gold_files["src"]),
'SLTev_home': sltev_home,
}
diff --git a/SLTev/SLTev.py b/SLTev/SLTev.py
index 31280c0..62b1bee 100644
--- a/SLTev/SLTev.py
+++ b/SLTev/SLTev.py
@@ -361,7 +361,7 @@ def slt_submission_evaluation(args, inputs_object):
def build_input_fils_and_file_formats(submission_file, gold_input_files):
- status, references, ostt, aligns = split_gold_inputs_submission_in_working_directory(submission_file, gold_input_files)
+ status, src, references, ostt, aligns = split_gold_inputs_submission_in_working_directory(submission_file, gold_input_files)
input_files = [submission_file]
file_formats = [remove_digits(status)]
@@ -375,6 +375,9 @@ def build_input_fils_and_file_formats(submission_file, gold_input_files):
if ostt != "":
input_files.append(ostt)
file_formats.append("ostt")
+ if src != "":
+ input_files.append(src)
+ file_formats.append("src")
for align_file in aligns:
input_files.append(align_file)
diff --git a/SLTev/SLTeval.py b/SLTev/SLTeval.py
index fd7170b..035765b 100644
--- a/SLTev/SLTeval.py
+++ b/SLTev/SLTeval.py
@@ -97,6 +97,7 @@ def main(input_files=[], file_formats=[], arguments={}):
'references': read_references(gold_files["ref"]),
'candidate': read_candidate_file(candidate_file[0]),
'align': gold_files["align"],
+ 'src': read_references(gold_files["src"]),
'SLTev_home': sltev_home,
}
diff --git a/SLTev/evaluator.py b/SLTev/evaluator.py
index 67ba131..18879c4 100644
--- a/SLTev/evaluator.py
+++ b/SLTev/evaluator.py
@@ -9,7 +9,7 @@
from flicker_modules import calc_revise_count, calc_flicker_score
from flicker_modules import calc_average_flickers_per_sentence, calc_average_flickers_per_tokens
from quality_modules import calc_bleu_score_documentlevel, calc_bleu_score_segmenterlevel
-from quality_modules import calc_bleu_score_timespanlevel
+from quality_modules import calc_bleu_score_timespanlevel, calculate_comet_score
from utilities import mwerSegmenter_error_message, eprint
from files_modules import read_alignment_file
@@ -268,6 +268,8 @@ def normal_evaluation_without_parity(inputs_object):
references_statistical_info(references) # print statistical info
average_refernces_token_count = get_average_references_token_count(references)
candidate_sentences = inputs_object.get('candidate')
+    src_file = inputs_object.get('src')
evaluation_object = {
'candidate_sentences': candidate_sentences,
@@ -283,6 +285,8 @@ def normal_evaluation_without_parity(inputs_object):
# bleu score evaluation
documantlevel_bleu_score_evaluation(references, candidate_sentences)
wordbased_segmenter_bleu_score_evaluation(evaluation_object)
+    if src_file:
+        comet_score_evaluation(src_file, candidate_sentences, references)
#flicker evaluation
print("tot Flicker count_changed_Tokens ", int(calc_revise_count(candidate_sentences)))
@@ -338,6 +342,7 @@ def normal_timestamp_evaluation(inputs_object):
average_refernces_token_count = get_average_references_token_count(references)
candidate_sentences = inputs_object.get('candidate')
OStt_sentences = inputs_object.get('ostt')
+ src_file = inputs_object.get('src')
print_ostt_duration(OStt_sentences)
Ts = []
for reference in references:
@@ -373,6 +378,8 @@ def normal_timestamp_evaluation(inputs_object):
documantlevel_bleu_score_evaluation(references, candidate_sentences)
wordbased_segmenter_bleu_score_evaluation(evaluation_object)
time_span_bleu_score_evaluation(evaluation_object)
+    if src_file:
+ comet_score_evaluation(src_file, candidate_sentences, references)
#flicker evaluation
print("tot Flicker count_changed_Tokens ", int(calc_revise_count(candidate_sentences)))
print("tot Flicker count_changed_content ", int(calc_flicker_score(candidate_sentences)))
@@ -385,6 +392,13 @@ def normal_timestamp_evaluation(inputs_object):
str("{0:.3f}".format(round(calc_average_flickers_per_tokens(candidate_sentences), 3))),
)
+def comet_score_evaluation(src_file, mt_sentences, references):
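+    """Print the document-level COMET score when the COMET model could be loaded and run."""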
+ comet_score, success = calculate_comet_score(src_file, mt_sentences, references)
+ if success:
+ print(
+ "tot COMET docAsWhole ",
+ str("{0:.3f}".format(round(comet_score, 3))),
+ )
def simple_mt_evaluation(inputs_object):
current_path = os.getcwd()
@@ -413,6 +427,7 @@ def normal_mt_evaluation(inputs_object):
references_statistical_info(references) # print statistical info
average_refernces_token_count = get_average_references_token_count(references)
mt_sentences = inputs_object.get('mt')
+ src_file = inputs_object.get('src')
evaluation_object = {
'candidate_sentences': mt_sentences,
@@ -427,6 +442,8 @@ def normal_mt_evaluation(inputs_object):
# bleu score evaluation
documantlevel_bleu_score_evaluation(references, mt_sentences)
wordbased_segmenter_bleu_score_evaluation(evaluation_object)
+    if src_file:
+ comet_score_evaluation(src_file, mt_sentences, references)
diff --git a/SLTev/quality_modules.py b/SLTev/quality_modules.py
index 6fbffb1..70367c4 100644
--- a/SLTev/quality_modules.py
+++ b/SLTev/quality_modules.py
@@ -1,7 +1,9 @@
#!/usr/bin/env python
import sacrebleu
+from comet import download_model, load_from_checkpoint
from files_modules import quality_segmenter
+from utilities import eprint
def calc_bleu_score_documentlevel(references, candiate_sentences):
@@ -162,3 +164,37 @@ def calc_bleu_score_timespanlevel(evaluation_object):
)
return bleu_scores, avg_SacreBleu
+def calculate_comet_score(sources, candidates, references=None):
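+    """
+    Compute the COMET score of the candidates against the sources and references
+    using the Unbabel/wmt22-comet-da model.
+
+    Returns (score, success): the system-level score scaled to 0-100 and True on success,
+    or (0, False) when the model could not be downloaded or run.
+    """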
+ try:
+ model_name = "Unbabel/wmt22-comet-da"
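+        # use the final (complete) version of each candidate sentence and keep only its text tokens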
+ merge_mt_sentences = []
+ for i in range(len(candidates)):
+ mt = candidates[i][-1][3:-1]
+ merge_mt_sentences += mt
+
+        merge_references_sentences = []
+        for ref in references:
+            joined_sentences = []
+            for sentence in ref:
+                joined_sentences.append(" ".join(sentence[:-1]))
+            merge_references_sentences.append(joined_sentences)
+
+        merge_src_sentences = []
+        for src in sources:
+            joined_sentences = []
+            for sentence in src:
+                joined_sentences.append(" ".join(sentence[:-1]))
+            merge_src_sentences.append(joined_sentences)
+
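+        # collapse everything into one segment so COMET scores the document as a whole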
+ ref = [" ".join(i) for i in merge_references_sentences]
+ mt = [" ".join(merge_mt_sentences[:])]
+ src = [" ".join(i) for i in merge_src_sentences]
+ data = [{'src': x[0], 'mt': x[1], 'ref': x[2]} for x in zip(src, mt, ref)]
+
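+        # the checkpoint is downloaded (and needs internet access) on first use; gpus=0 runs inference on CPU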
+ model_path = download_model(model_name)
+ model = load_from_checkpoint(model_path)
+ model_output = model.predict(data, batch_size=8, gpus=0)
+ return model_output['system_score'] * 100, True
+    except Exception:
+        eprint("Unable to calculate the COMET score: downloading or running the COMET model failed (an internet connection is required to fetch the model).")
+        return 0, False
diff --git a/SLTev/utilities.py b/SLTev/utilities.py
index b83d8e6..669ab83 100644
--- a/SLTev/utilities.py
+++ b/SLTev/utilities.py
@@ -281,7 +281,7 @@ def split_gold_inputs_submission_in_working_directory(submission_file, gold_inpu
-    :return tt, ostt, align: OSt, OStt, align files according to the submission file
+    :return status, src, references, ostt, aligns: the submission status and the source OSt, reference OSt, OStt, and align files matching the submission file
"""
- status, ostt = "", ""
+ status, ostt, src = "", "", ""
references, aligns = list(), list()
submission_file_name = os.path.split(submission_file)[1]
@@ -300,6 +300,11 @@ def split_gold_inputs_submission_in_working_directory(submission_file, gold_inpu
== submission_file_name_without_prefix + "." + target_lang + ".OSt"
):
references.append(file)
+ elif (
+ ".".join(input_name[:-1]) + "." + remove_digits(input_name[-1])
+ == submission_file_name_without_prefix + "." + source_lang + ".OSt"
+ ):
+ src = file
elif (
".".join(input_name[:-1]) + "." + remove_digits(input_name[-1])
== submission_file_name_without_prefix + "." + source_lang + ".OStt"
@@ -310,7 +315,7 @@ def split_gold_inputs_submission_in_working_directory(submission_file, gold_inpu
== submission_file_name_without_prefix + "." + source_lang + "." + target_lang + ".align"
):
aligns.append(file)
- return status, references, ostt, aligns
+ return status, src, references, ostt, aligns
def mwerSegmenter_error_message():
@@ -459,6 +464,10 @@ def extract_mt_gold_files_for_candidate(candidate_file, gold_inputs):
except:
eprint( "evaluation failed, the reference file does not exist for ", candidate_file[0])
error = 1
+ try:
+ gold_files["src"] = gold_inputs["src"]
+ except:
+ gold_files["src"] = ""
return gold_files, error
@@ -479,6 +488,10 @@ def extract_slt_gold_files_for_candidate(candidate_file, gold_inputs):
gold_files["align"] = gold_inputs["align"]
except:
gold_files["align"] = []
+ try:
+ gold_files["src"] = gold_inputs["src"]
+ except:
+ gold_files["src"] = ""
return gold_files, error
diff --git a/requirements.txt b/requirements.txt
index 5f18788..e837181 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ gitdir
jiwer
filelock
pytest
+unbabel-comet==2.0.2
diff --git a/setup.py b/setup.py
index be88404..15c9c6b 100644
--- a/setup.py
+++ b/setup.py
@@ -32,6 +32,7 @@
"jiwer",
"filelock",
"pytest",
+ "unbabel-comet"
],
url="https://github.com/ELITR/SLTev.git",
classifiers=[