Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion due_evaluator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@
import argparse
import sys
from typing import Optional, Set
from collections import UserList
import json

from due_evaluator.due_evaluator import DueEvaluator
from due_evaluator.utils import property_scores_to_string


class ListWrapper(UserList):
    """Metric-choice container for argparse.

    Behaves like a plain list for every listed metric name, but also
    reports membership for any name beginning with ``GROUP-ANLS`` (e.g.
    ``GROUP-ANLS@some_key``) — support for GROUP-ANLS only.
    """

    def __contains__(self, key):
        # Exact membership first; otherwise accept the GROUP-ANLS@<key> form.
        return key in self.data or key.startswith('GROUP-ANLS')


def parse_args():
"""Parse CLI arguments.

Expand All @@ -29,7 +38,7 @@ def parse_args():
parser.add_argument(
'--reference', '-r', type=argparse.FileType('r', encoding='utf-8'), required=True, help='Reference file',
)
parser.add_argument('--metric', '-m', type=str, default='F1', choices=['F1', 'MEAN-F1', 'ANLS', 'WTQ', 'GROUP-ANLS'])
parser.add_argument('--metric', '-m', type=str, default='F1', choices=ListWrapper(['F1', 'MEAN-F1', 'ANLS', 'WTQ', 'GROUP-ANLS']))
parser.add_argument(
'--return-score',
default='F1',
Expand Down
9 changes: 7 additions & 2 deletions due_evaluator/due_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ def create_scorer(self) -> BaseScorer:
scorer = MeanFScorer()
elif self.metric == 'WTQ':
scorer = WtqScorer()
elif self.metric == 'GROUP-ANLS':
scorer = GroupAnlsScorer()
elif self.metric.startswith('GROUP-ANLS'):
"""Expecting key to group by appended to the metric with @"""
if '@' in self.metric:
group_by_key = self.metric.split('@')[1]
else:
group_by_key = None
scorer = GroupAnlsScorer(group_by_key)
elif self.metric == 'GEVAL':
scorer = GevalScorer()
else:
Expand Down
10 changes: 7 additions & 3 deletions due_evaluator/scorers/fscorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ def __eq__(self, other):
class FScorer(BaseScorer):
"""Corpus level F1 Score evaluator."""

def __init__(self):
def __init__(self, group_by_key=None):
    """Initialize empty precision/recall accumulators.

    :param group_by_key: optional annotation key; when set, ``add()``
        restricts scoring to annotations whose key equals this value
        (used by the GROUP-ANLS metric variant).
    """
    self.__precision = []  # per-item precision indicator lists
    self.__recall = []  # per-item recall indicator lists
    self.group_by_key = group_by_key

@classmethod
def from_scorers(cls, scorers: List['FScorer']) -> 'FScorer':
Expand All @@ -49,7 +50,7 @@ def from_scorers(cls, scorers: List['FScorer']) -> 'FScorer':
new_scorer.__precision.extend(scorer.__precision)
new_scorer.__recall.extend(scorer.__recall)
return new_scorer

def flatten_annotations(self, annotations: List[Dict[str, Any]]) -> List[Annotation]:
flatten_items = []
for annotation in annotations:
Expand All @@ -59,7 +60,6 @@ def flatten_annotations(self, annotations: List[Dict[str, Any]]) -> List[Annotat
value=value['value'],
value_variants=value['value_variants'] if 'value_variants' in value else []))
return flatten_items


def add(self, out_items: Dict[str, Any], ref_items: Dict[str, Any]):
"""Add more items for computing corpus level scores.
Expand All @@ -72,6 +72,10 @@ def add(self, out_items: Dict[str, Any], ref_items: Dict[str, Any]):
prediction_annotations = self.flatten_annotations(out_items['annotations'])
ref_annotations = self.flatten_annotations(ref_items['annotations'])

if self.group_by_key is not None:
prediction_annotations = [el for el in prediction_annotations if el.key == self.group_by_key]
ref_annotations = [el for el in ref_annotations if el.key == self.group_by_key]

ref_annotations_copy = ref_annotations.copy()
indicators = []
for prediction in prediction_annotations:
Expand Down
18 changes: 12 additions & 6 deletions due_evaluator/scorers/group_anls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class FuzzyAnnotation:
key: str
value: str
value_variants: List[str] = field(default_factory=list)
threshold: float = 0.5

def __eq__(self, other):
def _is_float(val):
Expand All @@ -29,9 +30,9 @@ def _comp(val, pos) -> float:
return float(val == pos)
return textdistance.levenshtein.normalized_similarity(val, pos)

def _is_acceptable(val, possible_vals, threshold=.5):
def _is_acceptable(val, possible_vals, threshold=self.threshold):
best_score = max([_comp(val, pos) for pos in possible_vals] + [0.])
return best_score >= threshold
return best_score >= threshold

if self.key == other.key:
if _is_acceptable(other.value, [self.value]):
Expand All @@ -44,20 +45,25 @@ def _is_acceptable(val, possible_vals, threshold=.5):


class FuzzyFScorer(FScorer):
    """F1 scorer whose annotations compare values fuzzily (FuzzyAnnotation).

    The optional ``group_by_key`` constructor argument is inherited from
    FScorer unchanged — the previous override only delegated to
    ``super().__init__`` with an identical signature, so it was removed.
    """

    def flatten_annotations(self, annotations: List[Dict[str, Any]]) -> List[FuzzyAnnotation]:
        """Flatten nested annotation dicts into FuzzyAnnotation items.

        When grouping by a key, the fuzzy-match threshold is raised to 1.0
        (exact match required); otherwise the default 0.5 is used.
        """
        flatten_items = []
        for annotation in annotations:
            for value in annotation['values']:
                flatten_items.append(FuzzyAnnotation(
                    key=annotation['key'],
                    value=value['value'],
                    # dict.get is the idiomatic form of the previous
                    # "x if key in d else []" conditional.
                    value_variants=value.get('value_variants', []),
                    threshold=0.5 if self.group_by_key is None else 1.0))
        return flatten_items


class GroupAnlsScorer(BaseScorer):
def __init__(self):
def __init__(self, group_by_key):
    """Initialize the scorer.

    :param group_by_key: annotation key to group/restrict scoring by
        (parsed from the metric name after '@'), or None.
    """
    # NOTE(review): the inner scorer is constructed WITHOUT group_by_key,
    # while best_permutation() builds FuzzyFScorer(self.group_by_key) —
    # confirm this asymmetry is intended and not a missed argument.
    self.__inner_scorer = FuzzyFScorer()
    self.group_by_key = group_by_key

def pseudo_documents(self, doc: dict) -> List[dict]:
docs = []
Expand All @@ -80,7 +86,7 @@ def best_permutation(self, out_items: List[dict], ref_items: List[dict]):
for o in out_items:
row = []
for ri, r in enumerate(ref_items):
fscorer = FuzzyFScorer()
fscorer = FuzzyFScorer(self.group_by_key)
fscorer.add(o, r)
row.append(1 - fscorer.f_score())
matrix.append(row)
Expand All @@ -89,7 +95,7 @@ def best_permutation(self, out_items: List[dict], ref_items: List[dict]):
best_out = [out_items[i] for i in row_ind]
best_ref = [ref_items[i] for i in col_ind]
return (best_out, best_ref)

def pad(self, items: List[dict], target_length: int):
for _ in range(target_length - len(items)):
items.append({'name': '', 'annotations': []})
Expand Down