Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 155 additions & 1 deletion datafast/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,29 @@
from pydantic import BaseModel, Field
from pathlib import Path
from typing import Any, Optional
from datasets import Dataset
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi
from datafast.llms import LLMProvider
from datafast.prompts import (
classification_prompts,
question_generation_prompts,
mcq_prompts,
text_prompts,
)
from datafast.schema.config import (
ClassificationConfig,
TextDatasetConfig,
UltraChatDatasetConfig,
MCQDatasetConfig,
)
from datafast.schema.data_rows import (
ChatRow,
TextClassificationRow,
LabelSource,
TextRow,
TextSource,
MCQRow,
MCQSource,
)
from datafast.expanders import expand_prompts
import os
Expand All @@ -31,6 +35,14 @@ class TextEntries(BaseModel):
entries: list[str] = Field(..., description="List of generated texts")


class QAEntry(BaseModel):
    """A single generated question paired with its correct answer."""
    question: str = Field(..., description="Question")
    answer: str = Field(..., description="Answer")

class QAEntries(BaseModel):
    """Structured LLM response: a batch of question/answer pairs."""
    entries: list[QAEntry] = Field(..., description="List of generated QAs")


class UserQuestions(BaseModel):
    """Structured LLM response: a batch of simulated user questions."""
    questions: list[str] = Field(..., description="List of user questions")

Expand Down Expand Up @@ -526,3 +538,145 @@ def _get_default_simulated_assistant_prompt(self) -> str:

def _get_default_user_followup_prompt(self) -> str:
    """Return the default template used to prompt follow-up user questions."""
    followup_template = question_generation_prompts.USER_FOLLOWUP_PROMPT_TEMPLATE
    return followup_template


class MCQDataset(DatasetBase):
    """Generate multiple choice questions (MCQs) from documents in a Hugging Face dataset.

    For each document, an LLM is first asked for question/correct-answer pairs
    (``QAEntries``); a second call then produces three plausible but incorrect
    answers (distractors) per question. Rows accumulate in ``self.data_rows``
    and are saved to ``config.output_file`` as JSONL.
    """

    def __init__(self, config: MCQDatasetConfig):
        super().__init__(config)
        self.config = config

    def generate(self, llms: list[LLMProvider]) -> "MCQDataset":
        """
        Generate multiple choice questions by calling providers for questions and then for incorrect answers.

        Args:
            llms: List of LLM providers to use for generation. Must not be empty.

        Returns:
            This dataset instance, with ``data_rows`` populated.

        Raises:
            ValueError: If no LLM providers are supplied or if required configuration is missing.
        """
        if not llms:
            raise ValueError("At least one LLM provider must be supplied")

        dataset = self._load_source_documents()

        # Get languages from config, default to English if not specified
        languages = self.config.languages or {"en": "English"}

        # For each document, generate questions and answers
        for sample in dataset:
            if self.config.text_column not in sample:
                print(f"Warning: Column {self.config.text_column} not found in sample, skipping")
                continue

            document = sample[self.config.text_column]
            if not document:
                continue
            # Compute the stripped length once and check both bounds.
            doc_length = len(document.strip())
            if doc_length < self.config.min_document_length:  # Skip very short documents
                continue
            if doc_length > self.config.max_document_length:  # Skip very long documents
                continue

            for lang_code, language_name in languages.items():
                self._generate_for_document(document, lang_code, language_name, llms)

        # Final save at the end
        self.to_jsonl(self.config.output_file)
        return self

    def _load_source_documents(self):
        """Load the configured HF dataset, pick a split, and apply ``sample_count``.

        Prefers the 'train' split; falls back to the first available split.

        Raises:
            ValueError: If the dataset cannot be loaded or has no splits.
        """
        try:
            hf_dataset = load_dataset(self.config.hf_dataset_name)
            # Most datasets have a 'train' split, but fallback to first available split
            split_names = list(hf_dataset.keys())
            if not split_names:
                raise ValueError(f"No splits found in dataset {self.config.hf_dataset_name}")

            main_split = "train" if "train" in split_names else split_names[0]
            dataset = hf_dataset[main_split]

            # Limit the number of samples if specified
            if self.config.sample_count is not None:
                dataset = dataset.select(range(min(self.config.sample_count, len(dataset))))

            return dataset
        except Exception as e:
            raise ValueError(f"Error loading dataset {self.config.hf_dataset_name}: {e}") from e

    def _generate_for_document(
        self,
        document: str,
        lang_code: str,
        language_name: str,
        llms: list[LLMProvider],
    ) -> None:
        """Generate MCQs for one document in one language with every provider."""
        # 1. First call: Generate questions and correct answers
        question_prompts = [
            prompt.format(
                num_samples=self.config.num_samples_per_prompt,
                language_name=language_name,
                document=document,
            )
            for prompt in (self.config.prompts or self._get_default_prompts())
        ]

        # Expand prompts with configured variations
        question_expansions = expand_prompts(
            prompt_templates=question_prompts,
            **self.config.expansion.model_dump(),
        )

        # Process each expanded prompt with each provider (not only the first one)
        for expanded_prompt, meta in question_expansions:
            for llm in llms:
                try:
                    # Generate questions with the correct answers
                    response = llm.generate(expanded_prompt, response_format=QAEntries)

                    for qa_entry in response.entries:
                        try:
                            self._process_qa_entry(qa_entry, document, lang_code, language_name, llm)
                        except Exception as e:
                            print(f"Error processing entry: {e}")
                    print(f" Generated total of {len(self.data_rows)} MCQs")
                except Exception as e:
                    print(f"Error with llm provider {llm.name}: {e}")

    def _process_qa_entry(
        self,
        qa_entry: QAEntry,
        document: str,
        lang_code: str,
        language_name: str,
        llm: LLMProvider,
    ) -> None:
        """Generate distractors for one QA pair and append an MCQRow when at least three are returned."""
        question_part = qa_entry.question
        correct_answer = qa_entry.answer

        # 2. Second call: Generate incorrect answers.
        # BUGFIX: select the template first, THEN format it. Previously
        # `config.distractor_prompt or self._get_distractor_prompt().format(...)`
        # returned a custom template unformatted, because `.format` bound only
        # to the default template and `or` short-circuited past it.
        template = self.config.distractor_prompt or self._get_distractor_prompt()
        distractor_prompt = template.format(
            question=question_part,
            correct_answer=correct_answer,
            language_name=language_name,
        )

        try:
            # Use TextEntries for the distractor response since we need a list of incorrect answers
            distractor_response = llm.generate(distractor_prompt, response_format=TextEntries)

            # Parse the incorrect answers
            incorrect_answers = [entry.strip() for entry in distractor_response.entries]

            if len(incorrect_answers) >= 3:
                # Create MCQ row with the question, correct answer, and incorrect answers
                row = MCQRow(
                    source_document=document,
                    question=question_part,
                    correct_answer=correct_answer,
                    incorrect_answer_1=incorrect_answers[0],
                    incorrect_answer_2=incorrect_answers[1],
                    incorrect_answer_3=incorrect_answers[2],
                    model_id=llm.model_id,
                    mcq_source=MCQSource.SYNTHETIC,
                    metadata={
                        "language": lang_code,
                        "source_dataset": self.config.hf_dataset_name,
                    },
                )
                self.data_rows.append(row)
            else:
                print(f"Warning: Not enough incorrect answers generated (got {len(incorrect_answers)}, need 3)")
        except Exception as e:
            print(f"Error generating distractors: {e}")

    def _get_default_prompts(self) -> list[str]:
        """Return the default prompt templates for MCQ generation."""
        return mcq_prompts.DEFAULT_TEMPLATES

    def _get_distractor_prompt(self) -> str:
        """Return the prompt template for generating incorrect answers."""
        return mcq_prompts.DISTRACTOR_TEMPLATE
54 changes: 54 additions & 0 deletions datafast/examples/test_mcq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Example script for generating a Multiple Choice Question dataset.
"""

import os
from datafast.schema.config import MCQDatasetConfig
from datafast.datasets import MCQDataset
from datafast.llms import OpenAIProvider, AnthropicProvider, GoogleProvider, HuggingFaceProvider


def main():
    """Generate a small test MCQ dataset and push it to the Hugging Face Hub."""
    # 1. Define the configuration
    config = MCQDatasetConfig(
        hf_dataset_name="patrickfleith/space_engineering_environment_effects_texts",  # Source documents (space engineering texts)
        text_column="text",  # Column containing the text to generate questions from
        sample_count=3,  # Process only 3 samples for testing
        num_samples_per_prompt=2,  # Generate 2 questions per document
        min_document_length=100,  # Skip documents shorter than 100 chars
        max_document_length=20000,  # Skip documents longer than 20000 chars
        output_file="mcq_test_dataset.jsonl",
    )

    # 2. Initialize LLM providers (uncomment lines to generate with more providers)
    providers = [
        # OpenAIProvider(model_id="gpt-4o-mini"),
        # AnthropicProvider(model_id="claude-3-5-haiku-latest"),
        GoogleProvider(model_id="gemini-2.0-flash"),
    ]

    # 3. Generate the dataset (also saves it to config.output_file)
    dataset = MCQDataset(config)
    dataset.generate(providers)

    # 4. Print results summary
    print(f"\nGenerated {len(dataset.data_rows)} MCQs")
    print(f"Results saved to {config.output_file}")

    # 5. Optional: Push to HF hub
    USERNAME = "patrickfleith"  # <--- Your hugging face username
    DATASET_NAME = "mcq_test_dataset"  # <--- Your hugging face dataset name
    url = dataset.push_to_hub(
        repo_id=f"{USERNAME}/{DATASET_NAME}",
        train_size=0.7,
        seed=20250125,
        shuffle=True,
    )
    print(f"\nDataset pushed to Hugging Face Hub: {url}")


if __name__ == "__main__":
    # Load provider API keys from a local env file before running.
    from dotenv import load_dotenv

    load_dotenv("secrets.env")
    main()
29 changes: 29 additions & 0 deletions datafast/prompts/mcq_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Prompt templates for the first LLM call: produce {num_samples}
# question/correct-answer pairs in {language_name} from {document}.
DEFAULT_TEMPLATES = [
    """You are an expert at creating exam questions. Your task is to come up with {num_samples} \
difficult multiple choice questions written in {language_name} in relation to the following document along with the correct answer.
The question should be self-contained, short and answerable.
It is very important to have unique questions. No questions should be like 'what is X and what about Y?' or 'what is X and when did Y happen?'.
The answer must be short.
It must relate to the details of the document, but the question should contain enough context to be answerable for a person without the document.

### Document
{document}

Now come up with {num_samples} questions in relation to the document.
Make sure the questions are difficult, but answerable with a short answer.
Provide the correct answer for each question."""
]

# Prompt template for the second LLM call: given {question} and
# {correct_answer}, produce 3 plausible incorrect answers in {language_name}.
DISTRACTOR_TEMPLATE = """
You are an expert in creating plausible but incorrect answers for multiple choice questions.

For the following question, and correct answer, generate 3 short, plausible but incorrect answers in {language_name}.
The incorrect answers should be wrong but not obviously wrong - they should be tempting distractors.
Do not provide explanations, just list the three incorrect answers one per line (without prefixing with A, B, C, or numbers etc.)

Question: {question}

Correct Answer: {correct_answer}

Now provide exactly 3 short, plausible but incorrect answers:
"""
77 changes: 77 additions & 0 deletions datafast/schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,80 @@ class UltraChatDatasetConfig(BaseModel):
default={"en": "English"},
description="Language ISO codes and their corresponding names",
)


class MCQDatasetConfig(BaseModel):
    """
    Configuration for generating multiple choice questions from text in a Hugging Face dataset.
    Each question has one correct answer and three plausible but incorrect answers.
    """
    # NOTE(review): presumably a type discriminator matching the other
    # *Config classes in this module — confirm against those definitions.
    dataset_type: str = Field(default="mcq_dataset")

    # Hugging Face dataset information
    hf_dataset_name: str = Field(
        ...,  # required field
        description="Name of the Hugging Face dataset to use"
    )

    text_column: str = Field(
        ...,  # required field
        description="Column name containing the text to generate questions from"
    )

    # MCQ Generation parameters
    num_samples_per_prompt: int = Field(
        default=3,
        description="Number of questions to generate for each text"
    )

    sample_count: Optional[int] = Field(
        default=None,
        description="Optional number of samples to process from the dataset"
    )

    min_document_length: int = Field(
        default=100,
        description="Minimum number of characters below which documents will be skipped"
    )

    max_document_length: int = Field(
        default=10000,
        description="Maximum number of characters above which documents will be skipped"
    )

    # Where to save the output
    output_file: str = Field(
        default="mcq_dataset.jsonl",
        description="Path to save MCQ dataset results"
    )

    # Optional custom prompts (must contain the same format placeholders
    # as the defaults in mcq_prompts — TODO confirm against consumers)
    prompts: Optional[list[str]] = Field(
        default=None,
        description="Optional custom prompt templates"
    )

    distractor_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom distractor prompt template"
    )

    # Standard config options
    expansion: PromptExpansionConfig = PromptExpansionConfig()

    languages: dict[str, str] = Field(
        default={"en": "English"},
        description="Language ISO codes and their corresponding names"
    )

    # Reject empty strings for the two required fields (Field(...) alone
    # only rejects missing values, not "").
    @field_validator("hf_dataset_name")
    def validate_dataset_name(cls, v):
        if not v:
            raise ValueError("hf_dataset_name is required")
        return v

    @field_validator("text_column")
    def validate_text_column(cls, v):
        if not v:
            raise ValueError("text_column is required")
        return v
21 changes: 21 additions & 0 deletions datafast/schema/data_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ class TextSource(str, Enum):
CONSENSUS = "consensus"


class MCQSource(str, Enum):
    """Provenance of an MCQ row (mirrors the TextSource values above)."""
    SYNTHETIC = "synthetic"  # generated by an LLM
    VERIFIED = "verified"
    HUMAN = "human"
    CONSENSUS = "consensus"


LabelType = Union[str, list[str], list[int]]


Expand Down Expand Up @@ -52,3 +59,17 @@ class TextClassificationRow(BaseModel):
# System and metadata fields
uuid: UUID = Field(default_factory=uuid4)
metadata: dict[str, str] = Field(default_factory=dict)


class MCQRow(BaseModel):
    """One multiple-choice question: source document, question, the correct
    answer, and three incorrect answers (distractors)."""
    source_document: str  # full text the question was generated from
    question: str
    correct_answer: str
    incorrect_answer_1: str
    incorrect_answer_2: str
    incorrect_answer_3: str
    model_id: Optional[str] = None  # identifier of the LLM that produced the row
    mcq_source: MCQSource = MCQSource.SYNTHETIC
    # System and metadata fields
    uuid: UUID = Field(default_factory=uuid4)
    metadata: dict[str, str] = Field(default_factory=dict)  # e.g. language, source_dataset

Loading