Skip to content

Commit d7762dd

Browse files
committed
support WMT-ST QA format
1 parent de14fc0 commit d7762dd

7 files changed

Lines changed: 949 additions & 146 deletions

File tree

leaderboard/models.py

Lines changed: 95 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from leaderboard.utils import analyze_jsonl_file, process_jsonl_to_text
3131
from leaderboard.utils import analyze_json_file, process_json_to_text
3232
from leaderboard.utils import detect_jsonl_format
33-
from leaderboard.utils import JSONL_WMT_ST_MT_FORMAT, JSONL_WMT_GENMT_FORMAT
33+
from leaderboard.utils import JSONL_WMT_ST_MT_FORMAT, JSONL_WMT_ST_QA_FORMAT, JSONL_WMT_GENMT_FORMAT
3434
from ocelot.settings import MEDIA_ROOT
3535

3636
MAX_CODE_LENGTH = 10 # ISO 639 codes need 3 chars, but better add buffer
@@ -261,6 +261,26 @@
261261
"additionalProperties": True
262262
}
263263

264+
# requires "dataset_id" to start with "wmtslavicllm2025_"
265+
# JSON Schema (draft-07) for a single line of a WMT25-ST QA JSONL file.
# Only "dataset_id" is mandatory. The former "anyOf" constraint (at least one
# of "correct_answers" / "pred" must be present) is disabled, so neither field
# is required by this schema; when present they must have the types below.
# Unknown extra keys are accepted.
JSONL_WMT25_ST_QA_SCHEMA = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "title": "WMT25-ST QA JSONL entry",
    "type": "object",
    "properties": {
        "dataset_id": {"type": "string"},
        "correct_answers": {"type": "array", "items": {"type": "string"}},
        "pred": {"type": "string"},
    },
    "required": [
        "dataset_id",
    ],
    "additionalProperties": True,
}
283+
264284
# requires "dataset_id" to start with "wmtslavicllm2025_"
265285
JSONL_WMT25_ST_MT_SCHEMA = {
266286
"$schema": "http://json-schema.org/draft-07/schema#",
@@ -274,12 +294,12 @@
274294
"pred": { "type": "string" },
275295
},
276296
"required": [
277-
"dataset_id",
278-
],
279-
"anyOf": [
280-
{ "required": ["source"] },
281-
{ "required": ["pred"] }
297+
"dataset_id", "sent_id"
282298
],
299+
#"anyOf": [
300+
# { "required": ["source"] },
301+
# { "required": ["pred"] }
302+
#],
283303
"additionalProperties": True
284304
}
285305

@@ -420,9 +440,13 @@ def validate_jsonl_schema(json_file):
420440
return
421441

422442
# Detect format and choose appropriate schema
423-
# todo: this could be defined globally as a map
424-
is_st_mt_format = detect_jsonl_format(json_file, JSONL_WMT_ST_MT_FORMAT)
425-
schema = JSONL_WMT25_ST_MT_SCHEMA if is_st_mt_format else JSONL_WMT25_SCHEMA
443+
jsonl_format = detect_jsonl_format(json_file)
444+
if jsonl_format == JSONL_WMT_ST_MT_FORMAT:
445+
schema = JSONL_WMT25_ST_MT_SCHEMA
446+
elif jsonl_format == JSONL_WMT_ST_QA_FORMAT:
447+
schema = JSONL_WMT25_ST_QA_SCHEMA
448+
else:
449+
schema = JSONL_WMT25_SCHEMA
426450

427451
try:
428452
# Ensure we start at the beginning of the file
@@ -667,9 +691,26 @@ def validate_jsonl_src_testset(json_file):
667691

668692
json_file.seek(0)
669693
src_langs = set()
670-
671-
# Detect format
672-
is_st_mt_format = detect_jsonl_format(json_file, JSONL_WMT_ST_MT_FORMAT)
694+
695+
def _validate_jsonl_src(text, lineno, format):
    """Validate one non-empty line of a JSONL source test set.

    Closure helper: for the standard format it records the line's source
    language in the enclosing ``src_langs`` set.

    Args:
        text: Stripped text of the line (expected to be one JSON object).
        lineno: 1-based line number, used only in error messages.
        format: Detected JSONL format constant of the file. The name shadows
            the ``format`` builtin; it is kept so existing call sites stay
            valid.

    Raises:
        ValidationError: If the line is not valid JSON, is not a JSON
            object, or lacks the field required for ``format``.
    """
    try:
        obj = json.loads(text)
    except json.JSONDecodeError as e:
        raise ValidationError(f'JSONL src test set invalid JSON at line {lineno}: {e}') from e

    # Valid JSON is not necessarily an object (e.g. a list or a string);
    # without this guard the .get() calls below would raise AttributeError.
    if not isinstance(obj, dict):
        raise ValidationError(f'JSONL src test set line {lineno} is not a JSON object')

    if format == JSONL_WMT_ST_MT_FORMAT:
        # An empty "source" string is treated the same as a missing field.
        if not obj.get('source', ""):
            raise ValidationError(f'Missing "source" field at line {lineno} in JSONL src test set')
    elif format == JSONL_WMT_ST_QA_FORMAT:
        # An empty "correct_answers" list is treated the same as a missing field.
        if not obj.get('correct_answers', []):
            raise ValidationError(f'Missing "correct_answers" field at line {lineno} in JSONL src test set')
    else:
        # Standard format: every line must name its source language.
        lang = obj.get('src_lang')
        if lang is None:
            raise ValidationError(f'Missing src_lang at line {lineno} in JSONL src test set')
        src_langs.add(lang)
712+
713+
jsonl_format = detect_jsonl_format(json_file)
673714

674715
# Handle compressed files
675716
if json_file.name.endswith('.jsonl.gz'):
@@ -688,62 +729,17 @@ def validate_jsonl_src_testset(json_file):
688729
text = line.strip()
689730
if not text:
690731
continue
691-
try:
692-
obj = json.loads(text)
693-
except json.JSONDecodeError as e:
694-
raise ValidationError(f'JSONL src test set invalid JSON at line {lineno}: {e}')
732+
_validate_jsonl_src(text, lineno, jsonl_format)
695733

696-
if is_st_mt_format:
697-
# For ST MT format, check for "source" field
698-
if not obj.get('source'):
699-
raise ValidationError(f'Missing "source" field at line {lineno} in JSONL src test set')
700-
# For ST MT format, we don't have explicit language fields
701-
# but we can derive from dataset_id
702-
dataset_id = obj.get('dataset_id', '')
703-
if dataset_id.startswith('wmtslavicllm2025_'):
704-
# Extract language pair from dataset_id (e.g., wmtslavicllm2025_de-dsb)
705-
lang_pair = dataset_id.replace('wmtslavicllm2025_', '')
706-
if '-' in lang_pair:
707-
src_lang = lang_pair.split('-')[0]
708-
src_langs.add(src_lang)
709-
else:
710-
# For standard WMT25 format
711-
lang = obj.get('src_lang')
712-
if lang is None:
713-
raise ValidationError(f'Missing src_lang at line {lineno} in JSONL src test set')
714-
src_langs.add(lang)
715734
else:
716735
# Handle uncompressed files
717736
for lineno, line in enumerate(json_file, start=1):
718737
text = line.decode('utf-8').strip() if isinstance(line, bytes) else line.strip()
719738
if not text:
720739
continue
721-
try:
722-
obj = json.loads(text)
723-
except json.JSONDecodeError as e:
724-
raise ValidationError(f'JSONL src test set invalid JSON at line {lineno}: {e}')
725-
726-
if is_st_mt_format:
727-
# For ST MT format, check for "source" field
728-
if not obj.get('source'):
729-
raise ValidationError(f'Missing "source" field at line {lineno} in JSONL src test set')
730-
# For ST MT format, we don't have explicit language fields
731-
# but we can derive from dataset_id
732-
dataset_id = obj.get('dataset_id', '')
733-
if dataset_id.startswith('wmtslavicllm2025_'):
734-
# Extract language pair from dataset_id (e.g., wmtslavicllm2025_de-dsb)
735-
lang_pair = dataset_id.replace('wmtslavicllm2025_', '')
736-
if '-' in lang_pair:
737-
src_lang = lang_pair.split('-')[0]
738-
src_langs.add(src_lang)
739-
else:
740-
# For standard WMT25 format
741-
lang = obj.get('src_lang')
742-
if lang is None:
743-
raise ValidationError(f'Missing src_lang at line {lineno} in JSONL src test set')
744-
src_langs.add(lang)
740+
_validate_jsonl_src(text, lineno, jsonl_format)
745741

746-
if not src_langs:
742+
if (jsonl_format not in [JSONL_WMT_ST_MT_FORMAT, JSONL_WMT_ST_QA_FORMAT] and not src_langs):
747743
raise ValidationError(f'No source language found in JSONL file {json_file.name}')
748744
json_file.seek(0)
749745

@@ -757,9 +753,28 @@ def validate_jsonl_ref_testset(json_file):
757753

758754
json_file.seek(0)
759755
ref_langs = set()
756+
757+
def _validate_jsonl_ref(text, lineno, format):
    """Validate one non-empty line of a JSONL reference test set.

    Closure helper: for the standard format it records every reference
    language in the enclosing ``ref_langs`` set.

    Args:
        text: Stripped text of the line (expected to be one JSON object).
        lineno: 1-based line number, used only in error messages.
        format: Detected JSONL format constant of the file. The name shadows
            the ``format`` builtin; it is kept so existing call sites stay
            valid.

    Raises:
        ValidationError: If the line is not valid JSON, is not a JSON
            object, or lacks the reference data required for ``format``.
    """
    try:
        obj = json.loads(text)
    except json.JSONDecodeError as e:
        raise ValidationError(f'JSONL ref test set invalid JSON at line {lineno}: {e}') from e

    # Valid JSON is not necessarily an object (e.g. a list or a string);
    # without this guard the .get() calls below would raise AttributeError.
    if not isinstance(obj, dict):
        raise ValidationError(f'JSONL ref test set line {lineno} is not a JSON object')

    if format == JSONL_WMT_ST_MT_FORMAT:
        # An empty "target" string is treated the same as a missing field.
        if not obj.get('target', ""):
            raise ValidationError(f'Missing "target" field at line {lineno} in JSONL ref test set')
    else:
        # NOTE(review): the ST QA format also takes this branch and is
        # therefore required to carry a "refs" array -- confirm that this is
        # intended and that QA reference files should not be checked for a
        # different field instead.
        refs = obj.get('refs', [])
        # A missing, empty or non-list "refs" value is reported uniformly
        # instead of failing later while iterating over it.
        if not refs or not isinstance(refs, list):
            raise ValidationError(f'No refs array at line {lineno} in JSONL ref test set')
        for ref in refs:
            # Entries that are not objects cannot carry a tgt_lang.
            lang = ref.get('tgt_lang') if isinstance(ref, dict) else None
            if lang is None:
                raise ValidationError(f'Missing tgt_lang in refs at line {lineno}')
            ref_langs.add(lang)
760775

761776
# Detect format
762-
is_st_mt_format = detect_jsonl_format(json_file, JSONL_WMT_ST_MT_FORMAT)
777+
jsonl_format = detect_jsonl_format(json_file)
763778

764779
# Handle compressed files
765780
if json_file.name.endswith('.jsonl.gz'):
@@ -778,70 +793,17 @@ def validate_jsonl_ref_testset(json_file):
778793
text = line.strip()
779794
if not text:
780795
continue
781-
try:
782-
obj = json.loads(text)
783-
except json.JSONDecodeError as e:
784-
raise ValidationError(f'JSONL ref test set invalid JSON at line {lineno}: {e}')
796+
_validate_jsonl_ref(text, lineno, jsonl_format)
785797

786-
if is_st_mt_format:
787-
# For ST MT format, check for "target" field
788-
if not obj.get('target'):
789-
raise ValidationError(f'Missing "target" field at line {lineno} in JSONL ref test set')
790-
# For ST MT format, we don't have explicit language fields
791-
# but we can derive from dataset_id
792-
dataset_id = obj.get('dataset_id', '')
793-
if dataset_id.startswith('wmtslavicllm2025_'):
794-
# Extract language pair from dataset_id (e.g., wmtslavicllm2025_de-dsb)
795-
lang_pair = dataset_id.replace('wmtslavicllm2025_', '')
796-
if '-' in lang_pair:
797-
tgt_lang = lang_pair.split('-')[1]
798-
ref_langs.add(tgt_lang)
799-
else:
800-
# For standard WMT25 format
801-
refs = obj.get('refs')
802-
if not refs:
803-
raise ValidationError(f'No refs array at line {lineno} in JSONL ref test set')
804-
for ref in refs:
805-
lang = ref.get('tgt_lang')
806-
if lang is None:
807-
raise ValidationError(f'Missing tgt_lang in refs at line {lineno}')
808-
ref_langs.add(lang)
809798
else:
810799
# Handle uncompressed files
811800
for lineno, line in enumerate(json_file, start=1):
812801
text = line.decode('utf-8').strip() if isinstance(line, bytes) else line.strip()
813802
if not text:
814803
continue
815-
try:
816-
obj = json.loads(text)
817-
except json.JSONDecodeError as e:
818-
raise ValidationError(f'JSONL ref test set invalid JSON at line {lineno}: {e}')
819-
820-
if is_st_mt_format:
821-
# For ST MT format, check for "target" field
822-
if not obj.get('target'):
823-
raise ValidationError(f'Missing "target" field at line {lineno} in JSONL ref test set')
824-
# For ST MT format, we don't have explicit language fields
825-
# but we can derive from dataset_id
826-
dataset_id = obj.get('dataset_id', '')
827-
if dataset_id.startswith('wmtslavicllm2025_'):
828-
# Extract language pair from dataset_id (e.g., wmtslavicllm2025_de-dsb)
829-
lang_pair = dataset_id.replace('wmtslavicllm2025_', '')
830-
if '-' in lang_pair:
831-
tgt_lang = lang_pair.split('-')[1]
832-
ref_langs.add(tgt_lang)
833-
else:
834-
# For standard WMT25 format
835-
refs = obj.get('refs')
836-
if not refs:
837-
raise ValidationError(f'No refs array at line {lineno} in JSONL ref test set')
838-
for ref in refs:
839-
lang = ref.get('tgt_lang')
840-
if lang is None:
841-
raise ValidationError(f'Missing tgt_lang in refs at line {lineno}')
842-
ref_langs.add(lang)
843-
844-
if not ref_langs:
804+
_validate_jsonl_ref(text, lineno, jsonl_format)
805+
806+
if (jsonl_format not in [JSONL_WMT_ST_MT_FORMAT, JSONL_WMT_ST_QA_FORMAT] and not ref_langs):
845807
raise ValidationError(f'No reference languages found in JSONL file {json_file.name}')
846808
json_file.seek(0)
847809

@@ -854,9 +816,19 @@ def validate_jsonl_submission(json_file):
854816
validate_jsonl_schema(json_file)
855817
json_file.seek(0)
856818
has_hyps = False
857-
819+
820+
def _validate_jsonl_hyps(text, lineno, format):
    """Check one non-empty line of a JSONL submission for a hypothesis.

    Args:
        text: Stripped text of the line (expected to be one JSON object).
        lineno: 1-based line number, used only in error messages.
        format: Detected JSONL format constant of the file. The name shadows
            the ``format`` builtin; it is kept so existing call sites stay
            valid.

    Returns:
        True if the line carries a non-empty hypothesis, else False. For the
        ST MT / ST QA formats a line without "pred" raises instead, so the
        result is always True for those formats.

    Raises:
        ValidationError: If the line is not valid JSON, is not a JSON
            object, or (ST MT / ST QA formats) has no non-empty "pred".
    """
    try:
        obj = json.loads(text)
    except json.JSONDecodeError as e:
        # Report malformed lines the same way the test-set validators do
        # instead of letting the decoder error propagate.
        raise ValidationError(f'JSONL submission invalid JSON at line {lineno}: {e}') from e

    # Valid JSON is not necessarily an object (e.g. a list or a string);
    # without this guard the .get() calls below would raise AttributeError.
    if not isinstance(obj, dict):
        raise ValidationError(f'JSONL submission line {lineno} is not a JSON object')

    if format in (JSONL_WMT_ST_MT_FORMAT, JSONL_WMT_ST_QA_FORMAT):
        hyps = obj.get('pred', "")
        if not hyps:
            raise ValidationError(f'Missing "pred" field at line {lineno} in JSONL submission')
    else:
        # Standard format: accept either key, whichever is non-empty.
        hyps = obj.get('hypothesis') or obj.get('hyps') or ""
    return bool(hyps)
829+
858830
# Detect format
859-
is_st_mt_format = detect_jsonl_format(json_file, JSONL_WMT_ST_MT_FORMAT)
831+
jsonl_format = detect_jsonl_format(json_file)
860832

861833
# Handle compressed files
862834
if json_file.name.endswith('.jsonl.gz'):
@@ -875,16 +847,7 @@ def validate_jsonl_submission(json_file):
875847
text = line.strip()
876848
if not text:
877849
continue
878-
obj = json.loads(text)
879-
880-
if is_st_mt_format:
881-
# For ST MT format, check for "pred" field
882-
hyps = obj.get('pred', "")
883-
else:
884-
# For standard WMT25 format
885-
hyps = obj.get('hypothesis') or obj.get('hyps') or ""
886-
887-
if hyps:
850+
if _validate_jsonl_hyps(text, lineno, jsonl_format):
888851
has_hyps = True
889852
break
890853
else:
@@ -893,26 +856,16 @@ def validate_jsonl_submission(json_file):
893856
text = line.decode('utf-8').strip() if isinstance(line, bytes) else line.strip()
894857
if not text:
895858
continue
896-
obj = json.loads(text)
897-
898-
if is_st_mt_format:
899-
# For ST MT format, check for "pred" field
900-
hyps = obj.get('pred', "")
901-
else:
902-
# For standard WMT25 format
903-
hyps = obj.get('hypothesis') or obj.get('hyps') or ""
904-
905-
if hyps:
859+
if _validate_jsonl_hyps(text, lineno, jsonl_format):
906860
has_hyps = True
907861
break
908862

909863
if not has_hyps:
910-
field_name = "pred" if is_st_mt_format else "hypothesis"
864+
field_name = "hypothesis" if jsonl_format is JSONL_WMT_GENMT_FORMAT else "pred"
911865
raise ValidationError(f'Could not find "{field_name}" node anywhere in the JSONL submission')
912866
json_file.seek(0)
913867

914868

915-
916869
def validate_team_name(value):
917870
"""Validates team name matches r'^[a-zA-Z0-9_\\- ]{2,32}$'."""
918871
valid_name = re.compile(r'^[a-zA-Z0-9_\- ]{2,32}$')

0 commit comments

Comments
 (0)