# --- datafast/datasets.py (visible portion of the patch, reconstructed) ---
from pydantic import BaseModel, Field
from pathlib import Path
from typing import Any, Optional
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi
from datafast.llms import LLMProvider
from datafast.prompts import (
    classification_prompts,
    question_generation_prompts,
    mcq_prompts,
    text_prompts,
)
from datafast.schema.config import (
    ClassificationConfig,
    TextDatasetConfig,
    UltraChatDatasetConfig,
    MCQDatasetConfig,
)
from datafast.schema.data_rows import (
    ChatRow,
    LabelSource,
    TextRow,
    TextSource,
    MCQRow,
    MCQSource,
)
from datafast.expanders import expand_prompts
import os


class TextEntries(BaseModel):
    # Flat list of generated texts; also reused as the distractor response format.
    entries: list[str] = Field(..., description="List of generated texts")


class QAEntry(BaseModel):
    # One generated question together with its correct answer.
    question: str = Field(..., description="Question")
    answer: str = Field(..., description="Answer")


class QAEntries(BaseModel):
    # Structured response format for the question-generation LLM call.
    entries: list[QAEntry] = Field(..., description="List of generated QAs")


class UserQuestions(BaseModel):
    questions: list[str] = Field(..., description="List of user questions")


class MCQDataset(DatasetBase):
    """Dataset of multiple-choice questions generated from source documents.

    Documents come from a Hugging Face dataset. For every document and
    configured language, a first LLM call produces question/correct-answer
    pairs and a second call produces three plausible but incorrect answers
    (distractors) for each question.
    """

    def __init__(self, config: MCQDatasetConfig):
        super().__init__(config)
        self.config = config

    def generate(self, llms: list[LLMProvider]) -> "MCQDataset":
        """Generate multiple choice questions with the supplied providers.

        Args:
            llms: List of LLM providers to use for generation. Must not be empty.

        Returns:
            self, with ``data_rows`` populated and the result written to
            ``config.output_file``.

        Raises:
            ValueError: If no LLM providers are supplied or the source
                dataset cannot be loaded.
        """
        if not llms:
            raise ValueError("At least one LLM provider must be supplied")

        dataset = self._load_documents()

        # Get languages from config, default to English if not specified.
        languages = self.config.languages or {"en": "English"}

        # For each document, generate questions and answers.
        for sample in dataset:
            if self.config.text_column not in sample:
                print(f"Warning: Column {self.config.text_column} not found in sample, skipping")
                continue

            document = sample[self.config.text_column]
            # Skip documents outside the configured length bounds.
            if not document or len(document.strip()) < self.config.min_document_length:
                continue
            if len(document.strip()) > self.config.max_document_length:
                continue

            for lang_code, language_name in languages.items():
                self._generate_for_document(document, lang_code, language_name, llms)

        # Final save at the end.
        self.to_jsonl(self.config.output_file)
        return self

    def _load_documents(self):
        """Load the configured HF dataset split, subsampled to ``sample_count``."""
        try:
            hf_dataset = load_dataset(self.config.hf_dataset_name)
            # Most datasets have a 'train' split; fall back to the first available one.
            split_names = list(hf_dataset.keys())
            if not split_names:
                raise ValueError(f"No splits found in dataset {self.config.hf_dataset_name}")

            main_split = "train" if "train" in split_names else split_names[0]
            dataset = hf_dataset[main_split]

            # Limit the number of samples if specified.
            if self.config.sample_count is not None:
                dataset = dataset.select(range(min(self.config.sample_count, len(dataset))))
            return dataset
        except Exception as e:
            raise ValueError(f"Error loading dataset {self.config.hf_dataset_name}: {e}")

    def _generate_for_document(self, document: str, lang_code: str, language_name: str, llms: list[LLMProvider]) -> None:
        """Run the two-step question + distractor generation for one document/language."""
        # 1. First call: generate questions and correct answers.
        question_prompts = [
            prompt.format(
                num_samples=self.config.num_samples_per_prompt,
                language_name=language_name,
                document=document,
            )
            for prompt in (self.config.prompts or self._get_default_prompts())
        ]

        # Expand prompts with configured variations.
        question_expansions = expand_prompts(
            prompt_templates=question_prompts,
            **self.config.expansion.model_dump(),
        )

        # Process each expanded prompt with EVERY provider (not just the first).
        for expanded_prompt, meta in question_expansions:
            for llm in llms:
                try:
                    # Generate questions with the correct answers.
                    response = llm.generate(expanded_prompt, response_format=QAEntries)

                    for qa_entry in response.entries:
                        try:
                            self._add_mcq_row(qa_entry, document, lang_code, language_name, llm)
                        except Exception as e:
                            print(f"Error processing entry: {e}")
                    print(f" Generated total of {len(self.data_rows)} MCQs")
                except Exception as e:
                    # NOTE(review): assumes providers expose a `name` attribute — confirm on LLMProvider.
                    print(f"Error with llm provider {llm.name}: {e}")

    def _add_mcq_row(self, qa_entry: QAEntry, document: str, lang_code: str, language_name: str, llm: LLMProvider) -> None:
        """Generate distractors for one QA pair and append a finished MCQRow."""
        question_part = qa_entry.question
        correct_answer = qa_entry.answer

        # 2. Second call: generate incorrect answers.
        # BUGFIX: format the chosen template AFTER resolving it. Previously
        # `config.distractor_prompt or default.format(...)` only formatted the
        # default, so a custom distractor_prompt was sent with raw
        # {question}/{correct_answer}/{language_name} placeholders.
        distractor_template = self.config.distractor_prompt or self._get_distractor_prompt()
        distractor_prompt = distractor_template.format(
            question=question_part,
            correct_answer=correct_answer,
            language_name=language_name,
        )

        try:
            # Use TextEntries for the distractor response since we need a list of incorrect answers.
            distractor_response = llm.generate(distractor_prompt, response_format=TextEntries)
            incorrect_answers = [entry.strip() for entry in distractor_response.entries]

            if len(incorrect_answers) >= 3:
                # Create MCQ row with the question, correct answer, and incorrect answers.
                row = MCQRow(
                    source_document=document,
                    question=question_part,
                    correct_answer=correct_answer,
                    incorrect_answer_1=incorrect_answers[0],
                    incorrect_answer_2=incorrect_answers[1],
                    incorrect_answer_3=incorrect_answers[2],
                    model_id=llm.model_id,
                    mcq_source=MCQSource.SYNTHETIC,
                    metadata={
                        "language": lang_code,
                        "source_dataset": self.config.hf_dataset_name,
                    },
                )
                self.data_rows.append(row)
            else:
                print(f"Warning: Not enough incorrect answers generated (got {len(incorrect_answers)}, need 3)")
        except Exception as e:
            print(f"Error generating distractors: {e}")

    def _get_default_prompts(self) -> list[str]:
        """Return the default prompt templates for MCQ generation."""
        return mcq_prompts.DEFAULT_TEMPLATES

    def _get_distractor_prompt(self) -> str:
        """Return the prompt template for generating incorrect answers."""
        return mcq_prompts.DISTRACTOR_TEMPLATE
# --- datafast/examples/test_mcq.py (new file) ---
"""
Example script for generating a Multiple Choice Question dataset.
"""

import os
from datafast.schema.config import MCQDatasetConfig
from datafast.datasets import MCQDataset
from datafast.llms import OpenAIProvider, AnthropicProvider, GoogleProvider, HuggingFaceProvider


def main():
    """Generate a small MCQ dataset, save it, and push it to the HF Hub."""
    # 1. Define the configuration
    config = MCQDatasetConfig(
        hf_dataset_name="patrickfleith/space_engineering_environment_effects_texts",  # Space-engineering environment-effects texts
        text_column="text",        # Column containing the text to generate questions from
        sample_count=3,            # Process only 3 samples for testing
        num_samples_per_prompt=2,  # Generate 2 questions per document
        min_document_length=100,   # Skip documents shorter than 100 chars
        max_document_length=20000, # Skip documents longer than 20000 chars
        output_file="mcq_test_dataset.jsonl",
    )

    # 2. Initialize LLM providers (uncomment the ones you have API keys for)
    providers = [
        # OpenAIProvider(model_id="gpt-4o-mini"),
        # AnthropicProvider(model_id="claude-3-5-haiku-latest"),
        GoogleProvider(model_id="gemini-2.0-flash"),
    ]

    # 3. Generate the dataset
    dataset = MCQDataset(config)
    dataset.generate(providers)

    # 4. Print results summary
    print(f"\nGenerated {len(dataset.data_rows)} MCQs")
    print(f"Results saved to {config.output_file}")

    # 5. Optional: Push to HF hub
    USERNAME = "patrickfleith"        # <--- Your hugging face username
    DATASET_NAME = "mcq_test_dataset" # <--- Your hugging face dataset name
    url = dataset.push_to_hub(
        repo_id=f"{USERNAME}/{DATASET_NAME}",
        train_size=0.7,
        seed=20250125,
        shuffle=True,
    )
    print(f"\nDataset pushed to Hugging Face Hub: {url}")


if __name__ == "__main__":
    from dotenv import load_dotenv

    load_dotenv("secrets.env")
    main()


# --- datafast/prompts/mcq_prompts.py (new file) ---
# Templates expect: {num_samples}, {language_name}, {document}.
DEFAULT_TEMPLATES = [
    """You are an expert at creating exam questions. Your task is to come up with {num_samples} \
difficult multiple choice questions written in {language_name} in relation to the following document along with the correct answer.
The question should be self-contained, short and answerable.
It is very important to have unique questions. No questions should be like 'what is X and what about Y?' or 'what is X and when did Y happen?'.
The answer must be short.
It must relate to the details of the document, but the question should contain enough context to be answerable for a person without the document.

### Document
{document}

Now come up with {num_samples} questions in relation to the document.
Make sure the questions are difficult, but answerable with a short answer.
Provide the correct answer for each question."""
]

# Template expects: {language_name}, {question}, {correct_answer}.
DISTRACTOR_TEMPLATE = """
You are an expert in creating plausible but incorrect answers for multiple choice questions.

For the following question, and correct answer, generate 3 short, plausible but incorrect answers in {language_name}.
The incorrect answers should be wrong but not obviously wrong - they should be tempting distractors.
Do not provide explanations, just list the three incorrect answers one per line (without prefixing with A, B, C, or numbers etc.)

Question: {question}

Correct Answer: {correct_answer}

Now provide exactly 3 short, plausible but incorrect answers:
"""
# --- datafast/schema/config.py (additions) ---
class MCQDatasetConfig(BaseModel):
    """
    Configuration for generating multiple choice questions from text in a Hugging Face dataset.
    Each question has one correct answer and three plausible but incorrect answers.
    """
    dataset_type: str = Field(default="mcq_dataset")

    # Hugging Face dataset information
    hf_dataset_name: str = Field(
        ...,  # required field
        description="Name of the Hugging Face dataset to use"
    )

    text_column: str = Field(
        ...,  # required field
        description="Column name containing the text to generate questions from"
    )

    # MCQ Generation parameters
    num_samples_per_prompt: int = Field(
        default=3,
        description="Number of questions to generate for each text"
    )

    sample_count: Optional[int] = Field(
        default=None,
        description="Optional number of samples to process from the dataset"
    )

    min_document_length: int = Field(
        default=100,
        description="Minimum number of characters below which documents will be skipped"
    )

    max_document_length: int = Field(
        default=10000,
        description="Maximum number of characters above which documents will be skipped"
    )

    # Where to save the output
    output_file: str = Field(
        default="mcq_dataset.jsonl",
        description="Path to save MCQ dataset results"
    )

    # Optional custom prompts
    prompts: Optional[list[str]] = Field(
        default=None,
        description="Optional custom prompt templates"
    )

    distractor_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom distractor prompt template"
    )

    # Standard config options
    expansion: PromptExpansionConfig = PromptExpansionConfig()

    languages: dict[str, str] = Field(
        default={"en": "English"},
        description="Language ISO codes and their corresponding names"
    )

    # NOTE(review): assumes `field_validator` is imported from pydantic at the
    # top of config.py — confirm it is present there.
    @field_validator("hf_dataset_name", "text_column")
    def _validate_not_empty(cls, v, info):
        # One validator for both fields; rejects empty strings, which Pydantic's
        # required-field check alone does not. Error messages match the
        # originals ("hf_dataset_name is required" / "text_column is required").
        if not v:
            raise ValueError(f"{info.field_name} is required")
        return v


# --- datafast/schema/data_rows.py (additions) ---
class MCQSource(str, Enum):
    # Provenance of an MCQ row, mirroring TextSource's members.
    SYNTHETIC = "synthetic"
    VERIFIED = "verified"
    HUMAN = "human"
    CONSENSUS = "consensus"


class MCQRow(BaseModel):
    """One multiple-choice question: source text, question, and four options."""
    source_document: str            # document the question was generated from
    question: str
    correct_answer: str
    incorrect_answer_1: str         # distractors, in generation order
    incorrect_answer_2: str
    incorrect_answer_3: str
    model_id: Optional[str] = None  # identifier of the generating LLM
    mcq_source: MCQSource = MCQSource.SYNTHETIC
    uuid: UUID = Field(default_factory=uuid4)
    metadata: dict[str, str] = Field(default_factory=dict)