diff --git a/skythought/evals/tasks/__init__.py b/skythought/evals/tasks/__init__.py index 3c60818a..39a115c7 100644 --- a/skythought/evals/tasks/__init__.py +++ b/skythought/evals/tasks/__init__.py @@ -7,6 +7,7 @@ from .base import ConversationType, TaskConfig, TaskHandler from .gpqa_diamond.gpqa_diamond_handler import GPQADiamondTaskHandler from .gsm8k.gsm8k_handler import GSM8KTaskHandler +from .liveaops.liveaops_handler import LiveAOPSTaskHandler from .livecodebench.livecodebench_handler import LiveCodeBenchTaskHandler from .math.math_handler import MathTaskHandler from .minervamath.minervamath_handler import MinervaMathTaskHandler @@ -33,6 +34,7 @@ "minervamath": MinervaMathTaskHandler, "olympiadbench_math": OlympiadBenchMathTaskHandler, "omni_math": OMNIMathTaskHandler, + "liveaops": LiveAOPSTaskHandler, } TASK_NAMES_TO_YAML = get_tasks(os.path.dirname(__file__)) diff --git a/skythought/evals/tasks/base.py b/skythought/evals/tasks/base.py index 9916e868..7a2162af 100644 --- a/skythought/evals/tasks/base.py +++ b/skythought/evals/tasks/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional +from urllib.parse import urlparse import pandas as pd import yaml @@ -14,7 +15,7 @@ class TaskConfig(BaseModel): handler: str dataset_path: str dataset_subset: Optional[str] = None - dataset_split: str + dataset_split: Optional[str] = None dataset_kwargs: Dict[str, Any] = Field(default_factory=dict) question_key: str # Optional answer key for datasets with a single correct answer @@ -82,12 +83,28 @@ def make_conversations( return conversations def load_dataset(self, subset=None, split=None, **kwargs) -> HFDataset: - dataset = load_dataset( - path=self.task_config.dataset_path, - name=subset if subset else self.task_config.dataset_subset, - split=split if split else self.task_config.dataset_split, - **self.task_config.dataset_kwargs, - ) + # check if the path provided is a valid URL + parsed = urlparse(self.task_config.dataset_path) + if not parsed.scheme: + # HF dataset + dataset = load_dataset( + path=self.task_config.dataset_path, + name=subset if subset else self.task_config.dataset_subset, + split=split if split else self.task_config.dataset_split, + **self.task_config.dataset_kwargs, + ) + else: + # Try to load URL + # Only JSON supported for now + if split is not None or subset is not None: + raise ValueError( + "URL-based dataset does not support loading arguments like `split`, `subset`" + ) + # By default, Huggingface will create a DatasetDict object with "train" split + dataset = load_dataset("json", data_files=[self.task_config.dataset_path])[ + "train" + ] + # add an index column efficiently with map dataset = dataset.map(add_idx_map, with_indices=True) return dataset diff --git a/skythought/evals/tasks/liveaops/liveaops.yaml b/skythought/evals/tasks/liveaops/liveaops.yaml new file mode 100644 index 00000000..e2de36cd --- /dev/null +++ b/skythought/evals/tasks/liveaops/liveaops.yaml @@ -0,0 +1,8 @@ +handler: liveaops +dataset_path: https://livemathbench.github.io/data/LiveAoPSBench-2024.jsonl +dataset_subset: null # which subset on huggingface. Not applicable for a URL dataset +dataset_split: null # Rule based evaluation +question_key: question +answer_key: answer +templating_parameters: + template: "Return your final response within \\boxed{{}}. {question}" diff --git a/skythought/evals/tasks/liveaops/liveaops_handler.py b/skythought/evals/tasks/liveaops/liveaops_handler.py new file mode 100644 index 00000000..c5102e11 --- /dev/null +++ b/skythought/evals/tasks/liveaops/liveaops_handler.py @@ -0,0 +1,26 @@ +from skythought.evals.util.math_parsing_util import ( + extract_answer, + math_equal, + strip_answer_string, +) + +from ..math.math_handler import MathTaskHandler + + +class LiveAOPSTaskHandler(MathTaskHandler): + def generate_prompt(self, problem): + return self.task_config.templating_parameters["template"].format(**problem) + + def check_correctness(self, problem, generation): + # no preprocessing needed + answer = problem[self.task_config.answer_key] + pred = extract_answer(generation) + pred = strip_answer_string(pred) + return math_equal(pred, answer) + + def load_and_filter_dataset( + self, start, end, split=None, subset=None, difficulty=None + ): + assert difficulty is None, "LiveAOPS does not support `difficulty` argument" + dataset = self.load_dataset(subset=subset, split=split).to_pandas() + return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/evals/tasks/omni_math/omni_math.yaml b/skythought/evals/tasks/omni_math/omni_math.yaml index 2c80520c..6a9b2350 100644 --- a/skythought/evals/tasks/omni_math/omni_math.yaml +++ b/skythought/evals/tasks/omni_math/omni_math.yaml @@ -1,4 +1,4 @@ -handler: math +handler: omni_math dataset_path: "KbsdJames/Omni-MATH" # repo ID in huggingface dataset_subset: null # which subset on huggingface dataset_split: test_rule_based # Rule based evaluation