diff --git a/src/ml_filter/data_models.py b/src/ml_filter/data_models.py
new file mode 100644
index 00000000..6cfa8f05
--- /dev/null
+++ b/src/ml_filter/data_models.py
@@ -0,0 +1,86 @@
+from enum import Enum
+from typing import Dict, Union
+
+from pydantic import BaseModel, Field
+
+
+# Closed set of decoding strategies; the `str` mixin makes each member compare equal to its string value
+class DecodingStrategy(str, Enum):
+    """Decoding strategies for text generation models (greedy, beam search, top-k and top-p sampling)"""
+
+    GREEDY = "greedy"
+    BEAM_SEARCH = "beam_search"
+    TOP_K = "top_k"
+    TOP_P = "top_p"
+
+
+# Common base: every parameter set carries the `strategy` it belongs to
+class DecodingParameters(BaseModel):
+    """Base model for decoding parameters; `strategy` is required here and defaulted by each subclass"""
+
+    strategy: DecodingStrategy
+
+
+# Concrete parameter sets, one per DecodingStrategy member; each presets `strategy` to its own member
+class GreedyParameters(DecodingParameters):
+    """Greedy decoding parameters; no settings beyond the preset `strategy`"""
+
+    strategy: DecodingStrategy = Field(default=DecodingStrategy.GREEDY)
+
+
+class BeamSearchParameters(DecodingParameters):
+    """Beam search parameters: a positive `num_beams` and a required `early_stopping` flag"""
+
+    strategy: DecodingStrategy = Field(default=DecodingStrategy.BEAM_SEARCH)
+    num_beams: int = Field(..., gt=0, description="Number of beams must be greater than 0.")
+    early_stopping: bool
+
+
+class TopKParameters(DecodingParameters):
+    """Top-K sampling parameters: positive `top_k` and positive `temperature`, both required"""
+
+    strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_K)
+    top_k: int = Field(..., gt=0, description="Number of top candidates to consider. Must be greater than 0.")
+    temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.")
+
+
+class TopPParameters(DecodingParameters):
+    """Top-P (nucleus) sampling parameters: `top_p` in (0, 1] and positive `temperature`, both required"""
+
+    strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_P)
+    top_p: float = Field(
+        ..., gt=0, le=1, description="Cumulative probability for nucleus sampling. Must be in the range (0, 1]."
+    )
+    temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.")
+
+
+# NOTE(review): `decoding_parameters` is a Union without a discriminator - confirm how dict input gets resolved
+class DocumentInfo(BaseModel):
+    """Document metadata: id, prompt, prompt language, raw data path, model and decoding parameters"""
+
+    document_id: str
+    prompt: str
+    prompt_lang: str
+    raw_data_path: str
+    model: str
+    decoding_parameters: Union[GreedyParameters, BeamSearchParameters, TopKParameters, TopPParameters]
+
+
+class CorrelationMetrics(BaseModel):
+    """Correlation metrics for performance evaluation, stored as a two-level mapping of floats"""
+
+    correlation: Dict[str, Dict[str, float]]  # Per ground truth approach; inner keys presumably metric names
+
+
+class TTestResults(BaseModel):
+    """T-Test results for performance evaluation, stored as one float per key"""
+
+    t_test_p_values: Dict[str, float]  # p-values for each ground truth approach
+
+
+class StatisticReport(BaseModel):
+    """Complete statistical report: document info plus its correlation metrics and t-test results"""
+
+    document_info: DocumentInfo
+    correlation_metrics: CorrelationMetrics
+    t_test_results: TTestResults
diff --git a/tests/test_data_models.py b/tests/test_data_models.py
new file mode 100644
index 00000000..65a2ef45
--- /dev/null
+++ b/tests/test_data_models.py
@@ -0,0 +1,103 @@
+import pytest
+from pydantic import ValidationError
+
+from ml_filter.data_models import (
+    BeamSearchParameters,
+    CorrelationMetrics,
+    DecodingStrategy,
+    DocumentInfo,
+    GreedyParameters,
+    StatisticReport,
+    TopKParameters,
+    TopPParameters,
+    TTestResults,
+)
+
+
+def test_greedy_parameters():
+    params = GreedyParameters()
+    assert params.strategy == DecodingStrategy.GREEDY
+
+
+def test_beam_search_parameters():
+    params = BeamSearchParameters(num_beams=10, early_stopping=False)
+    assert params.strategy == DecodingStrategy.BEAM_SEARCH
+    assert params.num_beams == 10
+    assert not params.early_stopping
+
+
+def test_top_k_parameters():
+    params = TopKParameters(top_k=30, temperature=0.7)
+    assert params.strategy == DecodingStrategy.TOP_K
+    assert params.top_k == 30
+    assert params.temperature == 0.7
+
+
+def test_top_p_parameters():
+    params = TopPParameters(top_p=0.85, temperature=0.9)
+    assert params.strategy == DecodingStrategy.TOP_P
+    assert params.top_p == 0.85
+    assert params.temperature == 0.9
+
+
+def test_invalid_decoding_parameters():
+    with pytest.raises(ValidationError):
+        BeamSearchParameters(num_beams=-1, early_stopping=False)  # num_beams must be > 0
+    with pytest.raises(ValidationError):
+        TopKParameters(top_k=-5, temperature=0.7)  # top_k must be > 0
+    with pytest.raises(ValidationError):
+        TopPParameters(top_p=1.5, temperature=0.8)  # top_p must be <= 1
+
+
+def test_document_info_with_greedy():
+    doc_info = DocumentInfo(
+        document_id="doc_001",
+        prompt="Asses the educational value of the text.",
+        prompt_lang="en",
+        raw_data_path="/path/to/raw_data.json",
+        model="test_model",
+        decoding_parameters=GreedyParameters(),
+    )
+    assert doc_info.document_id == "doc_001"
+    assert doc_info.decoding_parameters.strategy == DecodingStrategy.GREEDY
+
+
+def test_document_info_with_top_p():
+    doc_info = DocumentInfo(
+        document_id="doc_002",
+        prompt="Asses, whether the text contains adult content.",
+        prompt_lang="en",
+        raw_data_path="/path/to/raw_data.json",
+        model="test_model",
+        decoding_parameters=TopPParameters(top_p=0.8, temperature=0.6),
+    )
+    assert doc_info.document_id == "doc_002"
+    assert doc_info.decoding_parameters.top_p == 0.8
+    assert doc_info.decoding_parameters.temperature == 0.6
+
+
+def test_statistic_report():
+    doc_info = DocumentInfo(
+        document_id="doc_003",
+        prompt="Asses, whether the text contains chain of thoughts.",
+        prompt_lang="en",
+        raw_data_path="/path/to/raw_data.json",
+        model="test_model",
+        decoding_parameters=BeamSearchParameters(num_beams=5, early_stopping=True),
+    )
+    correlation_metrics = CorrelationMetrics(
+        correlation={
+            "average": {"pearson": 0.85, "spearman": 0.82},
+            "min": {"pearson": 0.75, "spearman": 0.72},
+        }
+    )
+    t_test_results = TTestResults(t_test_p_values={"average": 0.03, "min": 0.05})
+    report = StatisticReport(
+        document_info=doc_info,
+        correlation_metrics=correlation_metrics,
+        t_test_results=t_test_results,
+    )
+
+    assert report.document_info.document_id == "doc_003"
+    assert report.correlation_metrics.correlation["average"]["pearson"] == 0.85
+    assert report.t_test_results.t_test_p_values["average"] == 0.03
diff --git a/tests/test_translate.py b/tests/test_translate.py
index 096db068..135e4850 100644
--- a/tests/test_translate.py
+++ b/tests/test_translate.py
@@ -39,6 +39,8 @@ def test_translate_jsonl_to_multiple_languages(
     """Test the translate_jsonl_to_multiple_languages method."""
 
     class MockTranslationClient:
+        name: str = "mock_translation_client"
+
         def translate_text(self, text, source_language, target_language):
             return mock_translate_text(text, source_language, target_language)
 
@@ -81,7 +83,7 @@ def supported_target_languages(self):
 
     # Verify output files
     for lang in target_languages:
-        output_file = output_folder / f"input_{lang}.jsonl"
+        output_file = output_folder / f"input_{lang}_{mock_client.name}.jsonl"
         assert output_file.exists()
 
         with open(output_file, "r", encoding="utf-8") as f: