diff --git a/due_evaluator/__main__.py b/due_evaluator/__main__.py
index b25470b..e690426 100755
--- a/due_evaluator/__main__.py
+++ b/due_evaluator/__main__.py
@@ -4,12 +4,21 @@
 import argparse
 import sys
 from typing import Optional, Set
+from collections import UserList
 import json
 
 from due_evaluator.due_evaluator import DueEvaluator
 from due_evaluator.utils import property_scores_to_string
 
 
+class ListWrapper(UserList):
+    """Accept metric names with a group-by key appended (supported for GROUP-ANLS only)."""
+    def __contains__(self, key):
+        if key in self.data:
+            return True
+        return key.startswith('GROUP-ANLS')
+
+
 def parse_args():
     """Parse CLI arguments.
 
@@ -29,7 +38,7 @@ def parse_args():
     parser.add_argument(
         '--reference', '-r', type=argparse.FileType('r', encoding='utf-8'), required=True, help='Reference file',
     )
-    parser.add_argument('--metric', '-m', type=str, default='F1', choices=['F1', 'MEAN-F1', 'ANLS', 'WTQ', 'GROUP-ANLS'])
+    parser.add_argument('--metric', '-m', type=str, default='F1', choices=ListWrapper(['F1', 'MEAN-F1', 'ANLS', 'WTQ', 'GROUP-ANLS']))
     parser.add_argument(
         '--return-score',
         default='F1',
diff --git a/due_evaluator/due_evaluator.py b/due_evaluator/due_evaluator.py
index ba0cc39..c2fa603 100644
--- a/due_evaluator/due_evaluator.py
+++ b/due_evaluator/due_evaluator.py
@@ -79,8 +79,13 @@ def create_scorer(self) -> BaseScorer:
             scorer = MeanFScorer()
         elif self.metric == 'WTQ':
             scorer = WtqScorer()
-        elif self.metric == 'GROUP-ANLS':
-            scorer = GroupAnlsScorer()
+        elif self.metric.startswith('GROUP-ANLS'):
+            # The key to group by may be appended to the metric name after '@'.
+            if '@' in self.metric:
+                group_by_key = self.metric.split('@')[1]
+            else:
+                group_by_key = None
+            scorer = GroupAnlsScorer(group_by_key)
         elif self.metric == 'GEVAL':
             scorer = GevalScorer()
         else:
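Note (not part of the patch): argparse validates `--metric` by testing membership in `choices`, so a plain list would reject suffixed values such as `GROUP-ANLS@key`; `ListWrapper` makes the membership test succeed for any `GROUP-ANLS`-prefixed string. A minimal sketch of the accepted syntax, assuming a hypothetical `line_item` grouping key:

```python
from collections import UserList

class ListWrapper(UserList):
    """Treat any 'GROUP-ANLS*' string as a valid argparse choice."""
    def __contains__(self, key):
        if key in self.data:
            return True
        return key.startswith('GROUP-ANLS')

choices = ListWrapper(['F1', 'MEAN-F1', 'ANLS', 'WTQ', 'GROUP-ANLS'])
metric = 'GROUP-ANLS@line_item'  # hypothetical group-by key
assert metric in choices

# create_scorer() then splits the key off after '@':
group_by_key = metric.split('@')[1] if '@' in metric else None
assert group_by_key == 'line_item'
```

One consequence worth noting: as written, any string starting with `GROUP-ANLS` passes validation, including malformed values without an `@` separator, which silently fall back to `group_by_key = None`.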
diff --git a/due_evaluator/scorers/fscorer.py b/due_evaluator/scorers/fscorer.py
index 40e9cc3..3d4b6d0 100644
--- a/due_evaluator/scorers/fscorer.py
+++ b/due_evaluator/scorers/fscorer.py
@@ -28,10 +28,11 @@ def __eq__(self, other):
 class FScorer(BaseScorer):
     """Corpus level F1 Score evaluator."""
 
-    def __init__(self):
+    def __init__(self, group_by_key=None):
         """Initialize class."""
         self.__precision = []
         self.__recall = []
+        self.group_by_key = group_by_key
 
     @classmethod
     def from_scorers(cls, scorers: List['FScorer']) -> 'FScorer':
@@ -49,7 +50,7 @@ def from_scorers(cls, scorers: List['FScorer']) -> 'FScorer':
             new_scorer.__precision.extend(scorer.__precision)
             new_scorer.__recall.extend(scorer.__recall)
         return new_scorer
-    
+
     def flatten_annotations(self, annotations: List[Dict[str, Any]]) -> List[Annotation]:
         flatten_items = []
         for annotation in annotations:
@@ -59,7 +60,6 @@ def flatten_annotations(self, annotations: List[Dict[str, Any]]) -> List[Annotat
                     value=value['value'],
                     value_variants=value['value_variants'] if 'value_variants' in value else []))
         return flatten_items
-
 
     def add(self, out_items: Dict[str, Any], ref_items: Dict[str, Any]):
         """Add more items for computing corpus level scores.
@@ -72,6 +72,10 @@ def add(self, out_items: Dict[str, Any], ref_items: Dict[str, Any]):
         prediction_annotations = self.flatten_annotations(out_items['annotations'])
         ref_annotations = self.flatten_annotations(ref_items['annotations'])
 
+        if self.group_by_key is not None:
+            prediction_annotations = [el for el in prediction_annotations if el.key == self.group_by_key]
+            ref_annotations = [el for el in ref_annotations if el.key == self.group_by_key]
+
         ref_annotations_copy = ref_annotations.copy()
         indicators = []
         for prediction in prediction_annotations:
diff --git a/due_evaluator/scorers/group_anls.py b/due_evaluator/scorers/group_anls.py
index 01e1386..4b71395 100644
--- a/due_evaluator/scorers/group_anls.py
+++ b/due_evaluator/scorers/group_anls.py
@@ -15,6 +15,7 @@ class FuzzyAnnotation:
     key: str
     value: str
     value_variants: List[str] = field(default_factory=list)
+    threshold: float = 0.5
 
     def __eq__(self, other):
         def _is_float(val):
@@ -29,9 +30,9 @@ def _comp(val, pos) -> float:
                 return float(val == pos)
             return textdistance.levenshtein.normalized_similarity(val, pos)
 
-        def _is_acceptable(val, possible_vals, threshold=.5):
+        def _is_acceptable(val, possible_vals, threshold=self.threshold):
             best_score = max([_comp(val, pos) for pos in possible_vals] + [0.])
-            return best_score >= threshold
+            return best_score >= threshold
 
         if self.key == other.key:
             if _is_acceptable(other.value, [self.value]):
@@ -44,6 +45,9 @@ def _is_acceptable(val, possible_vals, threshold=.5):
 
 
 class FuzzyFScorer(FScorer):
+    def __init__(self, group_by_key=None):
+        super().__init__(group_by_key)
+
     def flatten_annotations(self, annotations: List[Dict[str, Any]]) -> List[FuzzyAnnotation]:
         flatten_items = []
         for annotation in annotations:
@@ -51,13 +55,15 @@ def flatten_annotations(self, annotations: List[Dict[str, Any]]) -> List[FuzzyAn
             flatten_items.append(FuzzyAnnotation(
                 key=annotation['key'],
                 value=value['value'],
-                value_variants=value['value_variants'] if 'value_variants' in value else []))
+                value_variants=value['value_variants'] if 'value_variants' in value else [],
+                threshold=0.5 if self.group_by_key is None else 1.0))
         return flatten_items
 
 
 class GroupAnlsScorer(BaseScorer):
-    def __init__(self):
+    def __init__(self, group_by_key):
         self.__inner_scorer = FuzzyFScorer()
+        self.group_by_key = group_by_key
 
     def pseudo_documents(self, doc: dict) -> List[dict]:
         docs = []
@@ -80,7 +86,7 @@ def best_permutation(self, out_items: List[dict], ref_items: List[dict]):
         for o in out_items:
             row = []
             for ri, r in enumerate(ref_items):
-                fscorer = FuzzyFScorer()
+                fscorer = FuzzyFScorer(self.group_by_key)
                 fscorer.add(o, r)
                 row.append(1 - fscorer.f_score())
             matrix.append(row)
@@ -89,7 +95,7 @@ def best_permutation(self, out_items: List[dict], ref_items: List[dict]):
         best_out = [out_items[i] for i in row_ind]
         best_ref = [ref_items[i] for i in col_ind]
         return (best_out, best_ref)
-    
+
     def pad(self, items: List[dict], target_length: int):
         for _ in range(target_length - len(items)):
             items.append({'name': '', 'annotations': []})
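On matching strictness: when a group-by key is set, `FuzzyFScorer.flatten_annotations` builds `FuzzyAnnotation` objects with `threshold=1.0`, so values must match their references exactly (or via `value_variants`), while the default threshold of 0.5 preserves the ANLS-style fuzzy comparison. A simplified sketch of the cutoff behaviour, omitting the numeric-equality branch of `_comp` and using illustrative values:

```python
import textdistance

def is_acceptable(val, possible_vals, threshold):
    # Simplified FuzzyAnnotation._is_acceptable: the best normalized
    # Levenshtein similarity against the accepted values must reach the cutoff.
    best = max([textdistance.levenshtein.normalized_similarity(val, p)
                for p in possible_vals] + [0.0])
    return best >= threshold

# Default fuzzy mode (no group_by_key): a near-miss still counts.
assert is_acceptable('acme corp', ['acme corp.'], threshold=0.5)
# Grouped mode (threshold=1.0): the same near-miss is rejected.
assert not is_acceptable('acme corp', ['acme corp.'], threshold=1.0)
```

One open question for review: `GroupAnlsScorer.__init__` still constructs its inner scorer as `FuzzyFScorer()` without passing `group_by_key`, so only the throwaway scorers inside `best_permutation` receive it; it is worth confirming that this asymmetry is intended.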