diff --git a/README.md b/README.md
index f154ec2..584b2b6 100644
--- a/README.md
+++ b/README.md
@@ -269,6 +269,16 @@ tot sacreBLEU docAsAWhole 32.786
 avg sacreBLEU mwerSegmenter 25.850
 ```
+If you want to calculate the COMET score as well, you need to include the OSt file in the source language as ``src``, as shown below:
+```
+MTeval -i sample-data/sample.en.cs.mt sample-data/sample.en.OSt sample-data/sample.cs.OSt -f mt src ref
+```
+This adds an additional line to the output reporting the COMET score:
+```
+tot COMET docAsWhole 0.770
+```
+Passing ``src`` is optional.
+
 
 #### Evaluating SLT
 
 Spoken language translation evaluates "machine translation in time". So a time-stamped MT output (``slt``) is compared with the reference translation (non-timed, ``ref``) and the timing of the golden transcript (``ostt``).
@@ -293,6 +303,7 @@ tot Flicker count_changed_content 23
 tot sacreBLEU docAsAWhole 32.786
 ...
 ```
+As with MTeval, to calculate the COMET score you need to include the OSt file in the source language as ``src``.
 
 #### Evaluating ASR
 
@@ -392,6 +403,9 @@ Usage:
 SLTIndexParser path_to_index_file path_to_dataset
 ```
+5. Note that a stable internet connection is necessary to download the COMET model to the local system
+when calculating the COMET score.
+
 ## Terminology and Abbreviations
 
 * OSt ... original speech manually transcribed (i.e. golden transcript)
diff --git a/SLTev/ASReval.py b/SLTev/ASReval.py
index 502d04c..d4c6e8b 100644
--- a/SLTev/ASReval.py
+++ b/SLTev/ASReval.py
@@ -168,7 +168,8 @@ def main(input_files=[], file_formats=[], arguments={}):
         'ostt': read_ostt_file(gold_files["ostt"][0]),
         'references': read_references(gold_files["ost"]),
         'SLTev_home': SLTev_home,
-        'candidate': read_candidate_file(candidate_file[0])
+        'candidate': read_candidate_file(candidate_file[0]),
+        'src': read_references(gold_files["ost"])
     }
 
     _ = check_time_stamp_candiates_format(candidate_file[0], split_token)  # submission checking
diff --git a/SLTev/MTeval.py b/SLTev/MTeval.py
index 6457c4c..cfe1597 100644
--- a/SLTev/MTeval.py
+++ b/SLTev/MTeval.py
@@ -82,6 +82,7 @@ def main(inputs=[], file_formats=[], arguments={}):
     evaluation_object = {
         'references': read_references(gold_files["ref"]),
         'mt': read_candidate_file(candidate_file[0]),
+        'src': read_references(gold_files["src"]),
         'SLTev_home': sltev_home,
     }
 
diff --git a/SLTev/SLTev.py b/SLTev/SLTev.py
index 31280c0..62b1bee 100644
--- a/SLTev/SLTev.py
+++ b/SLTev/SLTev.py
@@ -361,7 +361,7 @@ def slt_submission_evaluation(args, inputs_object):
 
 
 def build_input_fils_and_file_formats(submission_file, gold_input_files):
-    status, references, ostt, aligns = split_gold_inputs_submission_in_working_directory(submission_file, gold_input_files)
+    status, src, references, ostt, aligns = split_gold_inputs_submission_in_working_directory(submission_file, gold_input_files)
     input_files = [submission_file]
     file_formats = [remove_digits(status)]
 
@@ -375,6 +375,9 @@ def build_input_fils_and_file_formats(submission_file, gold_input_files):
     if ostt != "":
         input_files.append(ostt)
         file_formats.append("ostt")
+    if src != "":
+        input_files.append(src)
+        file_formats.append("src")
 
     for align_file in aligns:
         input_files.append(align_file)
diff --git a/SLTev/SLTeval.py b/SLTev/SLTeval.py
index fd7170b..035765b 100644
--- a/SLTev/SLTeval.py
+++ b/SLTev/SLTeval.py
@@ -97,6 +97,7 @@ def main(input_files=[], file_formats=[], arguments={}):
         'references': read_references(gold_files["ref"]),
         'candidate': read_candidate_file(candidate_file[0]),
         'align': gold_files["align"],
+        'src': read_references(gold_files["src"]),
         'SLTev_home': sltev_home,
     }
 
diff --git a/SLTev/evaluator.py b/SLTev/evaluator.py
index 67ba131..18879c4 100644
--- a/SLTev/evaluator.py
+++ b/SLTev/evaluator.py
@@ -9,7 +9,7 @@
 from flicker_modules import calc_revise_count, calc_flicker_score
 from flicker_modules import calc_average_flickers_per_sentence, calc_average_flickers_per_tokens
 from quality_modules import calc_bleu_score_documentlevel, calc_bleu_score_segmenterlevel
-from quality_modules import calc_bleu_score_timespanlevel
+from quality_modules import calc_bleu_score_timespanlevel, calculate_comet_score
 from utilities import mwerSegmenter_error_message, eprint
 from files_modules import read_alignment_file
 
@@ -268,6 +268,8 @@ def normal_evaluation_without_parity(inputs_object):
     references_statistical_info(references)  # print statistical info
     average_refernces_token_count = get_average_references_token_count(references)
     candidate_sentences = inputs_object.get('candidate')
+    mt_sentences = inputs_object.get('mt', candidate_sentences)  # falls back to the candidate when no 'mt' input is given
+    src_file = inputs_object.get('src')
 
     evaluation_object = {
         'candidate_sentences': candidate_sentences,
@@ -283,6 +285,8 @@
     # bleu score evaluation
     documantlevel_bleu_score_evaluation(references, candidate_sentences)
     wordbased_segmenter_bleu_score_evaluation(evaluation_object)
+    if src_file != '' and src_file != []:
+        comet_score_evaluation(src_file, mt_sentences, references)
 
     #flicker evaluation
     print("tot Flicker count_changed_Tokens ", int(calc_revise_count(candidate_sentences)))
@@ -338,6 +342,7 @@ def normal_timestamp_evaluation(inputs_object):
     average_refernces_token_count = get_average_references_token_count(references)
     candidate_sentences = inputs_object.get('candidate')
     OStt_sentences = inputs_object.get('ostt')
+    src_file = inputs_object.get('src')
     print_ostt_duration(OStt_sentences)
     Ts = []
     for reference in references:
@@ -373,6 +378,8 @@
     documantlevel_bleu_score_evaluation(references, candidate_sentences)
     wordbased_segmenter_bleu_score_evaluation(evaluation_object)
     time_span_bleu_score_evaluation(evaluation_object)
+    if src_file != '' and src_file != []:
+        comet_score_evaluation(src_file, candidate_sentences, references)
     #flicker evaluation
     print("tot Flicker count_changed_Tokens ", int(calc_revise_count(candidate_sentences)))
     print("tot Flicker count_changed_content ", int(calc_flicker_score(candidate_sentences)))
@@ -385,6 +392,13 @@
         str("{0:.3f}".format(round(calc_average_flickers_per_tokens(candidate_sentences), 3))),
     )
+def comet_score_evaluation(src_file, mt_sentences, references):
+    comet_score, success = calculate_comet_score(src_file, mt_sentences, references)
+    if success:
+        print(
+            "tot COMET docAsWhole ",
+            str("{0:.3f}".format(round(comet_score, 3))),
+        )
 
 
 def simple_mt_evaluation(inputs_object):
     current_path = os.getcwd()
@@ -413,6 +427,7 @@
     references_statistical_info(references)  # print statistical info
     average_refernces_token_count = get_average_references_token_count(references)
     mt_sentences = inputs_object.get('mt')
+    src_file = inputs_object.get('src')
 
     evaluation_object = {
         'candidate_sentences': mt_sentences,
@@ -427,6 +442,8 @@
     # bleu score evaluation
     documantlevel_bleu_score_evaluation(references, mt_sentences)
     wordbased_segmenter_bleu_score_evaluation(evaluation_object)
+    if src_file != '' and src_file != []:
+        comet_score_evaluation(src_file, mt_sentences, references)
diff --git a/SLTev/quality_modules.py b/SLTev/quality_modules.py
index 6fbffb1..70367c4 100644
--- a/SLTev/quality_modules.py
+++ b/SLTev/quality_modules.py
@@ -1,7 +1,9 @@
 #!/usr/bin/env python
 
 import sacrebleu
+from comet import download_model, load_from_checkpoint
 from files_modules import quality_segmenter
+from utilities import eprint
 
 
 def calc_bleu_score_documentlevel(references, candiate_sentences):
@@ -162,3 +164,37 @@
     )
     return bleu_scores, avg_SacreBleu
 
+def calculate_comet_score(sources, candidates, references=None):
+    try:
+        model_name = "Unbabel/wmt22-comet-da"
+        merge_mt_sentences = []
+        for i in range(len(candidates)):
+            mt = candidates[i][-1][3:-1]
+            merge_mt_sentences += mt
+
+        merge_references_sentences = []
+        for ref in references:
+            l = []
+            for sentence in ref:
+                l.append(" ".join(sentence[:-1]))
+            merge_references_sentences.append(l)
+
+        merge_src_sentences = []
+        for src in sources:
+            l = []
+            for sentence in src:
+                l.append(" ".join(sentence[:-1]))
+            merge_src_sentences.append(l)
+
+        ref = [" ".join(i) for i in merge_references_sentences]
+        mt = [" ".join(merge_mt_sentences[:])]
+        src = [" ".join(i) for i in merge_src_sentences]
+        data = [{'src': x[0], 'mt': x[1], 'ref': x[2]} for x in zip(src, mt, ref)]
+
+        model_path = download_model(model_name)
+        model = load_from_checkpoint(model_path)
+        model_output = model.predict(data, batch_size=8, gpus=0)
+        return model_output['system_score'] * 100, True
+    except Exception:
+        eprint("Unable to calculate the COMET score: downloading or loading the COMET model failed (an internet connection is required to fetch it).")
+        return 0, False
diff --git a/SLTev/utilities.py b/SLTev/utilities.py
index b83d8e6..669ab83 100644
--- a/SLTev/utilities.py
+++ b/SLTev/utilities.py
@@ -281,7 +281,7 @@ def split_gold_inputs_submission_in_working_directory(submission_file, gold_inpu
     :return tt, ostt, align: OSt, OStt, align files according to the submission file
     """
-    status, ostt = "", ""
+    status, ostt, src = "", "", ""
     references, aligns = list(), list()
 
     submission_file_name = os.path.split(submission_file)[1]
@@ -300,6 +300,11 @@
             == submission_file_name_without_prefix + "." + target_lang + ".OSt"
         ):
             references.append(file)
+        elif (
+            ".".join(input_name[:-1]) + "." + remove_digits(input_name[-1])
+            == submission_file_name_without_prefix + "." + source_lang + ".OSt"
+        ):
+            src = file
         elif (
             ".".join(input_name[:-1]) + "." + remove_digits(input_name[-1])
             == submission_file_name_without_prefix + "." + source_lang + ".OStt"
         ):
@@ -310,7 +315,7 @@
             == submission_file_name_without_prefix + "." + source_lang + "." + target_lang + ".align"
         ):
             aligns.append(file)
-    return status, references, ostt, aligns
+    return status, src, references, ostt, aligns
 
 
 def mwerSegmenter_error_message():
@@ -459,6 +464,10 @@
     except:
         eprint(
             "evaluation failed, the reference file does not exist for ", candidate_file[0])
         error = 1
+    try:
+        gold_files["src"] = gold_inputs["src"]
+    except:
+        gold_files["src"] = ""
 
     return gold_files, error
@@ -479,6 +488,10 @@
         gold_files["align"] = gold_inputs["align"]
     except:
         gold_files["align"] = []
+    try:
+        gold_files["src"] = gold_inputs["src"]
+    except:
+        gold_files["src"] = ""
 
     return gold_files, error
 
diff --git a/requirements.txt b/requirements.txt
index 5f18788..e837181 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ gitdir
 jiwer
 filelock
 pytest
+unbabel-comet==2.0.2
diff --git a/setup.py b/setup.py
index be88404..15c9c6b 100644
--- a/setup.py
+++ b/setup.py
@@ -32,6 +32,7 @@
         "jiwer",
         "filelock",
         "pytest",
+        "unbabel-comet"
     ],
     url="https://github.com/ELITR/SLTev.git",
     classifiers=[