From e332ffd81feeeebc2cfb4b7ab10de26dce7190d9 Mon Sep 17 00:00:00 2001 From: Jeremy Vachier <89128100+jvachier@users.noreply.github.com> Date: Sun, 1 Mar 2026 12:12:48 +0100 Subject: [PATCH 1/3] Adding mypy. --- .github/workflows/test.yaml | 5 + .gitignore | 1 + .pre-commit-config.yaml | 8 + Makefile | 8 +- README.md | 22 +++ gradio_apps/README.md | 10 ++ pyproject.toml | 40 +++++ src/modules/data_preprocess_nltk.py | 4 +- src/modules/data_processor.py | 13 +- src/modules/load_data.py | 4 +- src/modules/model_bert_other.py | 4 +- src/modules/model_sentiment_analysis.py | 39 ++--- src/modules/optuna_transformer.py | 4 +- src/modules/sentiment_analysis_utils.py | 2 +- src/modules/speech_to_text.py | 2 +- src/modules/transformer_components.py | 6 +- src/modules/utils.py | 16 +- src/sentiment_analysis.py | 3 + src/translation_french_english.py | 13 +- tests/test_data_processor.py | 1 + tests/test_load_data.py | 201 ++++++++++++++++++++++++ tests/test_text_vectorizer.py | 147 +++++++++++++++++ tests/test_transformer_model.py | 17 +- tests/test_utils.py | 125 +++++++++++++++ uv.lock | 81 ++++++++++ 25 files changed, 723 insertions(+), 53 deletions(-) create mode 100644 tests/test_load_data.py create mode 100644 tests/test_text_vectorizer.py create mode 100644 tests/test_utils.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 65a3e01..d965d97 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -48,6 +48,11 @@ jobs: uv run ruff check ./src ./app ./tests uv run ruff format --check ./src ./app ./tests + # Run mypy type checking + - name: Run mypy type checking + run: | + uv run mypy src/ app/ tests/ --ignore-missing-imports + # Run pytest (excludes audio-dependent modules like speech_to_text) - name: Run tests with pytest run: | diff --git a/.gitignore b/.gitignore index 427d2c9..f8334b4 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,4 @@ src/models/*.json app_simple.py +mypy_output.txt diff --git 
a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d281a5..e2a5e51 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,14 @@ repos: - repo: local hooks: + - id: mypy + name: mypy + entry: uv run mypy + language: system + types: [python] + files: ^(src/|app/|tests/).*\.py$ + args: [--ignore-missing-imports] + - id: pytest name: pytest entry: uv run python -m pytest diff --git a/Makefile b/Makefile index 68ab0ed..7586d29 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Makefile for Sentiment Analysis -.PHONY: help install test lint format clean run +.PHONY: help install test lint format clean run type-check help: ## Show available commands @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m %s\n", $$1, $$2}' @@ -20,6 +20,12 @@ lint: ## Check and fix code quality uv run ruff check --fix ./src ./app ./tests uv run ruff format ./src ./app ./tests +type-check: ## Run mypy type checking + uv run mypy src/ app/ tests/ + +type-check-strict: ## Run mypy with strict mode + uv run mypy --strict src/ app/ + format: ## Format code only uv run ruff format ./src ./app ./tests diff --git a/README.md b/README.md index dfa9257..6305419 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,21 @@ [![Linting: Ruff](https://img.shields.io/badge/linting-ruff-yellowgreen)](https://github.com/charliermarsh/ruff) +[![Type Checking: mypy](https://img.shields.io/badge/type%20checking-mypy-blue)](http://mypy-lang.org/) [![CI: Passed](https://img.shields.io/badge/CI-Passed-brightgreen)](https://github.com/jvachier/Sentiment_Analysis/actions/workflows/test.yaml) +[![Tests: pytest](https://img.shields.io/badge/tests-pytest-orange)](https://docs.pytest.org/) +[![Pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit) [![Deep 
Learning](https://img.shields.io/badge/Deep%20Learning-TensorFlow-orange)](https://www.tensorflow.org/) [![Keras](https://img.shields.io/badge/Keras-red)](https://keras.io/) [![TensorFlow](https://img.shields.io/badge/TensorFlow-2.0%2B-orange)](https://www.tensorflow.org/) [![Python](https://img.shields.io/badge/Python-3.11%2B-blue)](https://www.python.org/) [![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv) +[![NLP](https://img.shields.io/badge/NLP-Natural%20Language%20Processing-green)](https://en.wikipedia.org/wiki/Natural_language_processing) +[![Transformers](https://img.shields.io/badge/Transformers-From%20Scratch-blueviolet)](https://arxiv.org/abs/1706.03762) +[![Neural Machine Translation](https://img.shields.io/badge/Neural-Machine%20Translation-purple)](https://en.wikipedia.org/wiki/Neural_machine_translation) +[![Sentiment Analysis](https://img.shields.io/badge/Sentiment-Analysis-pink)](https://en.wikipedia.org/wiki/Sentiment_analysis) +[![Speech Recognition](https://img.shields.io/badge/Speech-Recognition-cyan)](https://en.wikipedia.org/wiki/Speech_recognition) +[![Gradio](https://img.shields.io/badge/UI-Gradio-ff7c00)](https://gradio.app/) +[![Optuna](https://img.shields.io/badge/Hyperparameter-Optuna-lightblue)](https://optuna.org/) +[![Dash](https://img.shields.io/badge/Dashboard-Dash-blue)](https://dash.plotly.com/) [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) # Sentiment Analysis and Translation @@ -292,6 +303,17 @@ Sentiment_Analysis/ This Kaggle notebook provides a detailed tutorial on the transformer architecture implemented in this repository. 
+--- + +## Live Demo + +**HuggingFace Space: English-to-French Translator** +- Try the enhanced Transformer model live in your browser +- Real-time translation with greedy and beam search decoding +- No installation required - instant access +- [Launch Demo on HuggingFace](https://huggingface.co/spaces/Jvachier/transformer-nmt-en-fr) + + --- ## Customization diff --git a/gradio_apps/README.md b/gradio_apps/README.md index 7dea952..8a02ef5 100644 --- a/gradio_apps/README.md +++ b/gradio_apps/README.md @@ -18,6 +18,16 @@ tags: - tensorflow - keras - from-scratch + - nlp + - seq2seq + - attention-mechanism + - encoder-decoder + - deep-learning + - machine-translation + - multilingual + - text-generation + - custom-model + - educational --- # English to French Enhanced Transformer diff --git a/pyproject.toml b/pyproject.toml index b02c3cc..4475e6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dev = [ "ruff>=0.2.1", "scikit-optimize>=0.9.0", "pre-commit>=4.0.0", + "mypy>=1.8.0", + "types-tensorflow>=2.16.0", ] macos = [ "tensorflow-io-gcs-filesystem<0.35.0", @@ -69,3 +71,41 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["src", "app"] + +[tool.mypy] +python_version = "3.11" +warn_return_any = false +warn_unused_configs = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +check_untyped_defs = true +disallow_untyped_calls = false +warn_redundant_casts = true +warn_unused_ignores = true +strict_optional = true +no_implicit_optional = true +ignore_missing_imports = true +show_error_codes = true +pretty = true + +[[tool.mypy.overrides]] +module = [ + "tensorflow.*", + "keras.*", + "nltk.*", + "vosk.*", + "dash.*", + "optuna.*", + "transformers.*", + "polars.*", +] +ignore_missing_imports = true +ignore_errors = true + +[[tool.mypy.overrides]] +module = [ + "src.translation_french_english", + "src.modules.optuna_transformer", + "tests.test_transformer_model", +] +disable_error_code = ["dict-item", 
"misc"] diff --git a/src/modules/data_preprocess_nltk.py b/src/modules/data_preprocess_nltk.py index f322826..9d1ae81 100644 --- a/src/modules/data_preprocess_nltk.py +++ b/src/modules/data_preprocess_nltk.py @@ -78,7 +78,7 @@ def encode(self, text_tensor, label): """ text = text_tensor.numpy().decode("utf-8") text = self.preprocess_text(text) - encoded_text = self.tokenizer.texts_to_sequences([text])[0] + encoded_text = self.tokenizer.texts_to_sequences([text])[0] # type: ignore[union-attr] return encoded_text, label def fit_tokenizer(self, ds_raw): @@ -163,7 +163,7 @@ def encode(self, text_tensor): """ text = text_tensor.numpy().decode("utf-8") text = self.preprocess_text(text) - encoded_text = self.tokenizer.texts_to_sequences([text])[0] + encoded_text = self.tokenizer.texts_to_sequences([text])[0] # type: ignore[union-attr] return encoded_text def fit_tokenizer(self, ds_raw): diff --git a/src/modules/data_processor.py b/src/modules/data_processor.py index 63dcd1a..2b698c3 100644 --- a/src/modules/data_processor.py +++ b/src/modules/data_processor.py @@ -2,7 +2,7 @@ import tensorflow as tf import string import re -from typing import Tuple, Dict +from typing import Tuple, Dict, Optional class DatasetProcessor: @@ -20,7 +20,8 @@ def __init__(self, file_path: str, delimiters: str = r"|"): """ self.file_path = file_path self.delimiters = delimiters - self.split_df = None + self.split_df: Optional[pl.DataFrame] = None + self.df: Optional[pl.DataFrame] = None def load_data(self) -> None: """Load the Parquet file using Polars.""" @@ -28,6 +29,11 @@ def load_data(self) -> None: def process_data(self) -> None: """Process the dataset by splitting, cleaning, and tokenizing.""" + if self.df is None: + raise ValueError( + "Data must be loaded before processing. Call load_data() first." 
+ ) + # Split the 'en' column into rows based on delimiters if "en" in self.df.columns: en_split = self.df.select(pl.col("en").str.split(self.delimiters)).explode( @@ -76,6 +82,9 @@ def shuffle_and_split( """ # Calculate the number of samples for validation and test sets + if self.split_df is None: + raise ValueError("Data must be processed before splitting") + num_val_samples = int(val_split * len(self.split_df)) num_train_samples = len(self.split_df) - 2 * num_val_samples diff --git a/src/modules/load_data.py b/src/modules/load_data.py index 1893e26..66b2a81 100644 --- a/src/modules/load_data.py +++ b/src/modules/load_data.py @@ -1,6 +1,6 @@ import pandas as pd import tensorflow as tf -from pydantic import BaseModel, FilePath, Field, ValidationError +from pydantic import BaseModel, Field, ValidationError from src.modules.utils import DatasetPaths @@ -9,7 +9,7 @@ class DataLoaderConfig(BaseModel): Configuration for the DataLoader class. """ - data_path: FilePath = Field( + data_path: str = Field( default=DatasetPaths.RAW_DATA.value, description="Path to the CSV file containing the dataset.", ) diff --git a/src/modules/model_bert_other.py b/src/modules/model_bert_other.py index f72c9b3..b163fac 100644 --- a/src/modules/model_bert_other.py +++ b/src/modules/model_bert_other.py @@ -141,7 +141,9 @@ def build_model(self, num_classes): dropout = tf.keras.layers.Dropout(0.3)(cls_token) output = tf.keras.layers.Dense(1, activation="sigmoid")(dropout) - model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output) + model: tf.keras.Model = tf.keras.Model( + inputs=[input_ids, attention_mask], outputs=output + ) model.compile( optimizer=tf.keras.optimizers.legacy.RMSprop( learning_rate=self.learning_rate diff --git a/src/modules/model_sentiment_analysis.py b/src/modules/model_sentiment_analysis.py index dfd169c..529cc71 100644 --- a/src/modules/model_sentiment_analysis.py +++ b/src/modules/model_sentiment_analysis.py @@ -62,7 +62,10 @@ def train_and_evaluate( 
callbacks=[early_stopping_callback, callbacks_model], ) test_results = model.evaluate(test_data) - logging.info("Test Accuracy: {:.2f}%".format(test_results[1] * 100)) + if isinstance(test_results, list): + logging.info("Test Accuracy: {:.2f}%".format(test_results[1] * 100)) + else: + logging.info("Test Accuracy: {:.2f}%".format(test_results * 100)) def inference_model( self, model: tf.keras.Model, text_vec: tf.keras.layers.TextVectorization @@ -77,10 +80,10 @@ def inference_model( Returns: tf.keras.Model: An inference model for sentiment prediction. """ - inputs = tf.keras.Input(shape=(1,), dtype=tf.string) + inputs: tf.Tensor = tf.keras.Input(shape=(1,), dtype=tf.string) process_inputs = text_vec(inputs) outputs = model(process_inputs) - inference_model = tf.keras.Model(inputs=inputs, outputs=outputs) + inference_model: tf.keras.Model = tf.keras.Model(inputs=inputs, outputs=outputs) return inference_model @@ -121,7 +124,7 @@ def optimize( test_data (tf.data.Dataset): Test dataset. """ - def _objective(trial): + def _objective(trial: optuna.trial.Trial) -> float: """ Objective function for Optuna to optimize the model's hyperparameters. 
@@ -141,7 +144,7 @@ def _objective(trial): ) ) n_layers_bidirectional = trial.suggest_int("n_units_bidirectional", 1, 3) - for i in range(n_layers_bidirectional): + for i in range(n_layers_bidirectional): # type: int num_hidden_bidirectional = trial.suggest_int( f"n_units_bidirectional_l{i}", 64, 128, log=True ) @@ -165,8 +168,8 @@ def _objective(trial): model.add(tf.keras.layers.Dropout(self.dropout_rate)) n_layers_nn = trial.suggest_int("n_layers_nn", 1, 2) - for i in range(n_layers_nn): - num_hidden_nn = trial.suggest_int(f"n_units_nn_l{i}", 64, 128, log=True) + for j in range(n_layers_nn): + num_hidden_nn = trial.suggest_int(f"n_units_nn_l{j}", 64, 128, log=True) model.add(tf.keras.layers.Dense(num_hidden_nn, activation="gelu")) model.add(tf.keras.layers.Dropout(self.dropout_rate)) @@ -199,7 +202,9 @@ def _objective(trial): ) # Evaluate the model accuracy on the validation set. score = model.evaluate(test_data, verbose=1) - return score[1] + if isinstance(score, list): + return float(score[1]) + return float(score) # Create an Optuna study study = optuna.create_study( @@ -267,7 +272,7 @@ def get_model_api(self) -> tf.keras.Model: dropout_layer2 = tf.keras.layers.Dropout(self.dropout_rate)(dense_layer2) dense_layer3 = tf.keras.layers.Dense(32, activation="gelu")(dropout_layer2) outputs = tf.keras.layers.Dense(1, activation="sigmoid")(dense_layer3) - model = tf.keras.Model(inputs=inputs, outputs=outputs) + model: tf.keras.Model = tf.keras.Model(inputs=inputs, outputs=outputs) model.compile( optimizer=tf.keras.optimizers.RMSprop(), loss=tf.keras.losses.BinaryCrossentropy(), @@ -282,13 +287,9 @@ def get_config(self) -> dict: Returns: dict: A dictionary containing the model's configuration. 
""" - config = super().get_config() - config.update( - { - "embedding_dim": self.embedding_dim, - "lstm_units": self.lstm_units, - "dropout_rate": self.dropout_rate, - "max_token": self.max_token, - } - ) - return config + return { + "embedding_dim": self.embedding_dim, + "lstm_units": self.lstm_units, + "dropout_rate": self.dropout_rate, + "max_token": self.max_token, + } diff --git a/src/modules/optuna_transformer.py b/src/modules/optuna_transformer.py index b3dd650..8183173 100644 --- a/src/modules/optuna_transformer.py +++ b/src/modules/optuna_transformer.py @@ -61,7 +61,9 @@ def build_transformer_model( dropout_outputs ) - transformer = tf.keras.Model([encoder_inputs, decoder_inputs], final_outputs) + transformer: tf.keras.Model = tf.keras.Model( + [encoder_inputs, decoder_inputs], final_outputs + ) # Compile the model transformer.compile( diff --git a/src/modules/sentiment_analysis_utils.py b/src/modules/sentiment_analysis_utils.py index c708a0a..c464925 100644 --- a/src/modules/sentiment_analysis_utils.py +++ b/src/modules/sentiment_analysis_utils.py @@ -58,7 +58,7 @@ def create_or_load_inference_model( return tf.keras.models.load_model(inference_model_path) logging.info("Creating and saving the inference model.") - trainer = ModelTrainer() + trainer = ModelTrainer(config_path=ModelPaths.MODEL_TRAINER_CONFIG.value) inference_model = trainer.inference_model(model, text_vec) inference_model.save(inference_model_path) return inference_model diff --git a/src/modules/speech_to_text.py b/src/modules/speech_to_text.py index 39c0724..c72992e 100644 --- a/src/modules/speech_to_text.py +++ b/src/modules/speech_to_text.py @@ -45,7 +45,7 @@ def __init__(self, model_path: str): ) self.stream.start_stream() self.rec = vosk.KaldiRecognizer(self.model, 16000) - self.recognized_text = [] + self.recognized_text: list[str] = [] self.recording = False def start_recording(self) -> None: diff --git a/src/modules/transformer_components.py b/src/modules/transformer_components.py 
index 2ea1867..556063d 100644 --- a/src/modules/transformer_components.py +++ b/src/modules/transformer_components.py @@ -118,7 +118,7 @@ def build(self, input_shape): self.layernorm_3 = tf.keras.layers.LayerNormalization() super().build(input_shape) - def call(self, inputs, encoder_outputs, mask=None): + def call(self, inputs, encoder_outputs, mask=None): # type: ignore[override] causal_mask = self.get_causal_attention_mask(inputs) if mask is not None: padding_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype="float32") @@ -214,8 +214,8 @@ def evaluate_bleu( references.append([ref_sentence]) # Calculate BLEU score - bleu_score = corpus_bleu( - references, candidates, smoothing_function=smoothing_function + bleu_score = float( + corpus_bleu(references, candidates, smoothing_function=smoothing_function) ) logging.info(f"BLEU score evaluation completed: {bleu_score:.4f}") return bleu_score diff --git a/src/modules/utils.py b/src/modules/utils.py index caea7ea..a7bdf30 100644 --- a/src/modules/utils.py +++ b/src/modules/utils.py @@ -7,7 +7,7 @@ class DatasetPaths(str, Enum): Enum for dataset-related paths. """ - RAW_DATA = Path("src/data/tripadvisor_hotel_reviews.csv") + RAW_DATA = str(Path("src/data/tripadvisor_hotel_reviews.csv")) class ModelPaths(str, Enum): @@ -15,11 +15,11 @@ class ModelPaths(str, Enum): Enum for model-related paths. 
""" - MODEL_BUILDER_CONFIG = Path("src/configurations/model_builder_config.json") - MODEL_TRAINER_CONFIG = Path("src/configurations/model_trainer_config.json") - TRAINED_MODEL = Path("src/models/sentiment_keras_binary.keras") - INFERENCE_MODEL = Path("src/models/inference_model.keras") - TRANSFORMER_MODEL = Path("src/models/transformer_best_model.keras") + MODEL_BUILDER_CONFIG = str(Path("src/configurations/model_builder_config.json")) + MODEL_TRAINER_CONFIG = str(Path("src/configurations/model_trainer_config.json")) + TRAINED_MODEL = str(Path("src/models/sentiment_keras_binary.keras")) + INFERENCE_MODEL = str(Path("src/models/inference_model.keras")) + TRANSFORMER_MODEL = str(Path("src/models/transformer_best_model.keras")) class OptunaPaths(str, Enum): @@ -27,8 +27,8 @@ class OptunaPaths(str, Enum): Enum for Optuna-related paths. """ - OPTUNA_CONFIG = Path("src/configurations/optuna_config.json") - OPTUNA_MODEL = Path("src/models/optuna_model_binary.json") + OPTUNA_CONFIG = str(Path("src/configurations/optuna_config.json")) + OPTUNA_MODEL = str(Path("src/models/optuna_model_binary.json")) class TextVectorizerConfig(int, Enum): diff --git a/src/sentiment_analysis.py b/src/sentiment_analysis.py index 386f74e..022d6d8 100644 --- a/src/sentiment_analysis.py +++ b/src/sentiment_analysis.py @@ -60,6 +60,9 @@ def main() -> None: valid_data = vectorized_dataset.get("valid_data") test_data = vectorized_dataset.get("test_data") + if train_data is None or valid_data is None or test_data is None: + raise ValueError("Failed to vectorize datasets") + # Initialize the sentiment analysis model logging.info("Initializing the sentiment analysis model.") diff --git a/src/translation_french_english.py b/src/translation_french_english.py index 327deab..4f12033 100644 --- a/src/translation_french_english.py +++ b/src/translation_french_english.py @@ -84,7 +84,9 @@ def transformer_model( dropout_outputs ) - transformer = tf.keras.Model([encoder_inputs, decoder_inputs], final_outputs) + 
transformer: tf.keras.Model = tf.keras.Model( + [encoder_inputs, decoder_inputs], final_outputs + ) # Compile the model transformer.compile( @@ -161,7 +163,7 @@ def translation_test( """ # Get French vocabulary (target language) for decoding fr_vocab = preprocessor.target_vectorization.get_vocabulary() - fr_index_lookup = dict(zip(range(len(fr_vocab)), fr_vocab)) + fr_index_lookup: dict[int, str] = {i: word for i, word in enumerate(fr_vocab)} # Debug: print vocabulary info logging.info(f"French vocabulary size: {len(fr_vocab)}") @@ -177,7 +179,7 @@ def translation_test( )[:, :-1] predictions = transformer([tokenized_input_sentence, tokenized_target_sentence]) - sampled_token_index = np.argmax(predictions[0, i, :]) + sampled_token_index: int = int(np.argmax(predictions[0, i, :])) sampled_token = fr_index_lookup[sampled_token_index] decoded_sentence += " " + sampled_token @@ -228,7 +230,10 @@ def main() -> None: # Evaluate the model logging.info("Evaluating the model on the test dataset.") results = transformer.evaluate(test_ds) - logging.info(f"Test loss: {results[0]}, Test accuracy: {results[1]}") + if isinstance(results, list): + logging.info(f"Test loss: {results[0]}, Test accuracy: {results[1]}") + else: + logging.info(f"Test loss: {results}, Test accuracy: N/A") # Calculate BLEU score bleu_score = evaluate_bleu(transformer, test_ds, preprocessor) diff --git a/tests/test_data_processor.py b/tests/test_data_processor.py index ee2b83c..feef2ed 100644 --- a/tests/test_data_processor.py +++ b/tests/test_data_processor.py @@ -37,6 +37,7 @@ def test_dataset_processor(sample_data: pl.DataFrame) -> None: data_splits: Dict[str, pl.DataFrame] = processor.shuffle_and_split() # Check if the data is processed correctly + assert processor.split_df is not None, "Split dataframe is None!" assert len(processor.split_df) > 0, "Processed dataset is empty!" assert "fr" in processor.split_df.columns, "'fr' column is missing!" 
assert processor.split_df["fr"][0].startswith("[start]"), "Start token missing!" diff --git a/tests/test_load_data.py b/tests/test_load_data.py new file mode 100644 index 0000000..0743f62 --- /dev/null +++ b/tests/test_load_data.py @@ -0,0 +1,201 @@ +import pytest +import pandas as pd +import tensorflow as tf +import tempfile +import os +from src.modules.load_data import DataLoader, DataLoaderConfig + + +@pytest.fixture +def sample_csv(): + """Create a temporary CSV file with sample data.""" + data = { + "Review": [ + "Excellent hotel!", + "Terrible experience", + "Average place", + "Outstanding service", + "Very disappointing", + ], + "Rating": [5, 1, 3, 5, 2], + } + df = pd.DataFrame(data) + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + df.to_csv(f.name, index=False) + temp_path = f.name + + yield temp_path + + # Cleanup + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_data_loader_config_valid(sample_csv): + """Test DataLoaderConfig with valid path.""" + config = DataLoaderConfig(data_path=sample_csv) + assert str(config.data_path) == sample_csv + + +def test_data_loader_config_with_string_path(): + """Test DataLoaderConfig accepts string paths (file existence checked at load time).""" + config = DataLoaderConfig(data_path="/any/path/to/file.csv") + assert config.data_path == "/any/path/to/file.csv" + + +def test_data_loader_initialization(sample_csv): + """Test DataLoader initialization.""" + loader = DataLoader(data_path=sample_csv) + assert loader.data_path == sample_csv + + +def test_data_loader_load_data(sample_csv): + """Test DataLoader load_data method.""" + loader = DataLoader(data_path=sample_csv) + result = loader.load_data() + + # Check that all expected keys are in the result + assert "raw" in result + assert "train" in result + assert "valid" in result + assert "test" in result + assert "target" in result + + # Check that datasets are TensorFlow datasets + assert 
isinstance(result["raw"], tf.data.Dataset) + assert isinstance(result["train"], tf.data.Dataset) + assert isinstance(result["valid"], tf.data.Dataset) + assert isinstance(result["test"], tf.data.Dataset) + """Test that sentiment labels are correctly assigned.""" + loader = DataLoader(data_path=sample_csv) + result = loader.load_data() + target = result["target"] + + # Reviews with Rating < 3 should have label 0 + # Reviews with Rating >= 3 should have label 1 + assert len(target) == 5 + assert set(target.unique()) == {0, 1} # Only binary labels + + +def test_data_loader_dataset_splits(sample_csv): + """Test that dataset is properly split.""" + loader = DataLoader(data_path=sample_csv) + result = loader.load_data() + + # Count elements in each split + train_count = sum(1 for _ in result["train"]) + _valid_count = sum(1 for _ in result["valid"]) + test_count = sum(1 for _ in result["test"]) + _total_count = sum(1 for _ in result["raw"]) + + # Verify split proportions + assert train_count >= test_count, "Training set should be larger than test set" + + +def test_data_loader_data_format(sample_csv): + """Test that loaded data has correct format.""" + loader = DataLoader(data_path=sample_csv) + result = loader.load_data() + + # Check one batch from training set + for text, label in result["train"].take(1): + # Text should be a string tensor + assert text.dtype == tf.string + # Label should be integer + assert label.dtype in [tf.int32, tf.int64] + # Label should be 0 or 1 + assert label.numpy() in [0, 1] + + +def test_data_loader_shuffling(sample_csv): + """Test that dataset is shuffled.""" + loader = DataLoader(data_path=sample_csv) + result = loader.load_data() + + # Get first few samples from training set + samples1 = [] + for text, label in result["train"].take(2): + samples1.append((text.numpy().decode(), label.numpy())) + + # Load again and check if order is different (with shuffling) + loader2 = DataLoader(data_path=sample_csv) + result2 = loader2.load_data() + + 
samples2 = [] + for text, label in result2["train"].take(2): + samples2.append((text.numpy().decode(), label.numpy())) + + # Due to shuffling, samples may be different + # (This test may occasionally fail if shuffle produces same order) + # We just verify that both have valid data + assert len(samples1) == len(samples2) + + +def test_data_loader_with_missing_column(): + """Test DataLoader with CSV missing required columns.""" + # Create CSV with wrong columns + data = {"WrongColumn": ["text1", "text2"], "AnotherWrong": [1, 2]} + df = pd.DataFrame(data) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + df.to_csv(f.name, index=False) + temp_path = f.name + + try: + loader = DataLoader(data_path=temp_path) + # This should raise an error when trying to load (AttributeError for missing column) + with pytest.raises(AttributeError): + loader.load_data() + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_data_loader_empty_csv(): + """Test DataLoader with CSV with minimal data.""" + # Create CSV with at least one row to avoid pandas error + df = pd.DataFrame({"Review": ["test"], "Rating": [3]}) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + df.to_csv(f.name, index=False) + temp_path = f.name + + try: + loader = DataLoader(data_path=temp_path) + result = loader.load_data() + + # Check that dataset is created + assert sum(1 for _ in result["raw"]) > 0 + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_data_loader_rating_boundaries(sample_csv): + """Test that rating boundaries are correctly classified.""" + # Create CSV with boundary ratings + data = { + "Review": ["Review 1", "Review 2", "Review 3", "Review 4"], + "Rating": [1, 2, 3, 4], # 1,2 -> negative (0); 3,4 -> positive (1) + } + df = pd.DataFrame(data) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + df.to_csv(f.name, index=False) + temp_path = f.name + + 
try: + loader = DataLoader(data_path=temp_path) + result = loader.load_data() + target = result["target"] + + # Ratings 1, 2 should be 0; ratings 3, 4 should be 1 + assert target.iloc[0] == 0 # Rating 1 + assert target.iloc[1] == 0 # Rating 2 + assert target.iloc[2] == 1 # Rating 3 + assert target.iloc[3] == 1 # Rating 4 + finally: + if os.path.exists(temp_path): + os.remove(temp_path) diff --git a/tests/test_text_vectorizer.py b/tests/test_text_vectorizer.py new file mode 100644 index 0000000..6b858e8 --- /dev/null +++ b/tests/test_text_vectorizer.py @@ -0,0 +1,147 @@ +import pytest +import tensorflow as tf +from pydantic import ValidationError +from src.modules.text_vectorizer_sentiment_analysis import ( + TextVectorizer, + TextVectorizerConfig, +) + + +def test_text_vectorizer_config_valid(): + """Test TextVectorizerConfig with valid parameters.""" + config = TextVectorizerConfig(max_tokens=10000, output_sequence_length=200) + assert config.max_tokens == 10000 + assert config.output_sequence_length == 200 + + +def test_text_vectorizer_config_invalid_max_tokens(): + """Test TextVectorizerConfig with invalid max_tokens.""" + with pytest.raises(ValidationError): + TextVectorizerConfig(max_tokens=0, output_sequence_length=200) + + +def test_text_vectorizer_config_invalid_sequence_length(): + """Test TextVectorizerConfig with invalid output_sequence_length.""" + with pytest.raises(ValidationError): + TextVectorizerConfig(max_tokens=10000, output_sequence_length=-1) + + +def test_text_vectorizer_config_defaults(): + """Test TextVectorizerConfig with default values.""" + config = TextVectorizerConfig() + assert config.max_tokens == 20000 + assert config.output_sequence_length == 500 + + +def test_text_vectorizer_initialization(): + """Test TextVectorizer initialization.""" + vectorizer = TextVectorizer(max_tokens=5000, output_sequence_length=100) + assert vectorizer.max_tokens == 5000 + assert vectorizer.output_sequence_length == 100 + assert 
isinstance(vectorizer.text_vec, tf.keras.layers.TextVectorization) + + +def test_text_vectorizer_adapt(): + """Test TextVectorizer adapt method.""" + # Create a simple dataset + texts = ["hello world", "test data", "machine learning"] + labels = [0, 1, 0] + ds_train = tf.data.Dataset.from_tensor_slices((texts, labels)) + + # Initialize and adapt vectorizer + vectorizer = TextVectorizer(max_tokens=100, output_sequence_length=10) + vectorizer.adapt(ds_train) + + # Check that vocabulary was built + vocab = vectorizer.text_vec.get_vocabulary() + assert len(vocab) > 0 + assert "" in vocab # padding token + assert "[UNK]" in vocab # unknown token + + +def test_text_vectorizer_vectorize_datasets(): + """Test TextVectorizer vectorize_datasets method.""" + # Create datasets + texts_train = ["hello world", "test data"] + labels_train = [1, 0] + ds_train = tf.data.Dataset.from_tensor_slices((texts_train, labels_train)) + + texts_valid = ["validation text"] + labels_valid = [1] + ds_valid = tf.data.Dataset.from_tensor_slices((texts_valid, labels_valid)) + + texts_test = ["test text"] + labels_test = [0] + ds_test = tf.data.Dataset.from_tensor_slices((texts_test, labels_test)) + + # Initialize, adapt, and vectorize + vectorizer = TextVectorizer(max_tokens=100, output_sequence_length=10) + vectorizer.adapt(ds_train) + + result = vectorizer.vectorize_datasets(ds_train, ds_valid, ds_test) + ds_train_vec = result["train_data"] + _ds_valid_vec = result["valid_data"] + _ds_test_vec = result["test_data"] + + # Check that datasets are properly vectorized + for batch in ds_train_vec.take(1): + texts, labels = batch + assert texts.shape[1] == 10 # output_sequence_length + assert len(labels.shape) == 1 + break + + +def test_text_vectorizer_with_empty_texts(): + """Test TextVectorizer with empty texts.""" + texts = ["", "non-empty text"] + labels = [0, 1] + ds_train = tf.data.Dataset.from_tensor_slices((texts, labels)) + + vectorizer = TextVectorizer(max_tokens=100, 
output_sequence_length=10) + vectorizer.adapt(ds_train) + + # Vectorize and check it handles empty strings + result = vectorizer.vectorize_datasets(ds_train, ds_train, ds_train) + ds_train_vec = result["train_data"] + + for batch in ds_train_vec.take(1): + texts, labels = batch + assert texts.shape[0] > 0 # should have batches + break + + +def test_text_vectorizer_vocab_size(): + """Test that vocabulary size respects max_tokens.""" + texts = [f"word{i}" for i in range(200)] + labels = [i % 2 for i in range(200)] + ds_train = tf.data.Dataset.from_tensor_slices((texts, labels)) + + max_tokens = 50 + vectorizer = TextVectorizer(max_tokens=max_tokens, output_sequence_length=10) + vectorizer.adapt(ds_train) + + vocab = vectorizer.text_vec.get_vocabulary() + # Vocab includes padding and [UNK] tokens + assert len(vocab) <= max_tokens + + +def test_text_vectorizer_sequence_length(): + """Test that sequences are properly truncated/padded.""" + texts = [ + "short", + "this is a much longer text that should be truncated to the max length", + ] + labels = [0, 1] + ds_train = tf.data.Dataset.from_tensor_slices((texts, labels)) + + sequence_length = 5 + vectorizer = TextVectorizer(max_tokens=100, output_sequence_length=sequence_length) + vectorizer.adapt(ds_train) + + result = vectorizer.vectorize_datasets(ds_train, ds_train, ds_train) + ds_train_vec = result["train_data"] + + for batch in ds_train_vec.take(1): + texts_vec, _ = batch + assert texts_vec.shape[1] == sequence_length + break diff --git a/tests/test_transformer_model.py b/tests/test_transformer_model.py index 0eb62db..8ee2509 100644 --- a/tests/test_transformer_model.py +++ b/tests/test_transformer_model.py @@ -13,9 +13,9 @@ @pytest.fixture -def setup_data() -> Tuple[ - TextPreprocessor, tf.data.Dataset, tf.data.Dataset, tf.data.Dataset -]: +def setup_data() -> ( + Tuple[TextPreprocessor, tf.data.Dataset, tf.data.Dataset, tf.data.Dataset] +): """ Fixture to set up a mocked dataset and preprocessor for testing. 
@@ -130,9 +130,10 @@ def test_transformer_model_evaluation( ) # Evaluate the model - results: list[float] = model.evaluate(test_ds, verbose=0) + results = model.evaluate(test_ds, verbose=0) # Check if evaluation results are returned + assert isinstance(results, list), "Evaluation did not return a list." assert len(results) == 2, "Evaluation did not return loss and accuracy." assert results[0] >= 0, "Test loss is invalid." assert 0 <= results[1] <= 1, "Test accuracy is invalid." @@ -185,12 +186,12 @@ def test_transformer_model_loading( model.save(transformer_model_path) # Load the model - loaded_model: tf.keras.Model = tf.keras.models.load_model( + loaded_model = tf.keras.models.load_model( transformer_model_path, custom_objects={ - "PositionalEmbedding": PositionalEmbedding, - "TransformerEncoder": TransformerEncoder, - "TransformerDecoder": TransformerDecoder, + "PositionalEmbedding": PositionalEmbedding, # type: ignore[dict-item] + "TransformerEncoder": TransformerEncoder, # type: ignore[dict-item] + "TransformerDecoder": TransformerDecoder, # type: ignore[dict-item] }, ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..fddea7d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,125 @@ +from src.modules.utils import ( + DatasetPaths, + ModelPaths, + OptunaPaths, + TextVectorizerConfig, +) + + +def test_dataset_paths_enum(): + """Test DatasetPaths enum values.""" + assert hasattr(DatasetPaths, "RAW_DATA") + # Enum values are strings (str, Enum inheritance converts Path to str) + assert isinstance(DatasetPaths.RAW_DATA.value, str) + assert str(DatasetPaths.RAW_DATA.value).endswith(".csv") + + +def test_model_paths_enum(): + """Test ModelPaths enum values.""" + # Check that all required paths exist + assert hasattr(ModelPaths, "MODEL_BUILDER_CONFIG") + assert hasattr(ModelPaths, "MODEL_TRAINER_CONFIG") + assert hasattr(ModelPaths, "TRAINED_MODEL") + assert hasattr(ModelPaths, "INFERENCE_MODEL") + assert hasattr(ModelPaths, 
"TRANSFORMER_MODEL") + + # Check that paths are strings (str, Enum inheritance converts Path to str) + assert isinstance(ModelPaths.MODEL_BUILDER_CONFIG.value, str) + assert isinstance(ModelPaths.MODEL_TRAINER_CONFIG.value, str) + assert isinstance(ModelPaths.TRAINED_MODEL.value, str) + assert isinstance(ModelPaths.INFERENCE_MODEL.value, str) + assert isinstance(ModelPaths.TRANSFORMER_MODEL.value, str) + + # Check file extensions + assert str(ModelPaths.MODEL_BUILDER_CONFIG.value).endswith(".json") + assert str(ModelPaths.MODEL_TRAINER_CONFIG.value).endswith(".json") + assert str(ModelPaths.TRAINED_MODEL.value).endswith(".keras") + assert str(ModelPaths.INFERENCE_MODEL.value).endswith(".keras") + assert str(ModelPaths.TRANSFORMER_MODEL.value).endswith(".keras") + + +def test_optuna_paths_enum(): + """Test OptunaPaths enum values.""" + assert hasattr(OptunaPaths, "OPTUNA_CONFIG") + assert hasattr(OptunaPaths, "OPTUNA_MODEL") + + # Enum values are strings (str, Enum inheritance converts Path to str) + assert isinstance(OptunaPaths.OPTUNA_CONFIG.value, str) + assert isinstance(OptunaPaths.OPTUNA_MODEL.value, str) + + +def test_text_vectorizer_config_enum(): + """Test TextVectorizerConfig enum values.""" + assert hasattr(TextVectorizerConfig, "max_tokens") + assert hasattr(TextVectorizerConfig, "output_sequence_length") + + # Check values are positive integers + assert TextVectorizerConfig.max_tokens.value > 0 + assert TextVectorizerConfig.output_sequence_length.value > 0 + + # Check specific default values + assert TextVectorizerConfig.max_tokens.value == 20000 + assert TextVectorizerConfig.output_sequence_length.value == 500 + + +def test_paths_structure(): + """Test that all paths follow expected structure.""" + # All model paths should start with "src/" + assert str(ModelPaths.MODEL_BUILDER_CONFIG.value).startswith("src/") + assert str(ModelPaths.MODEL_TRAINER_CONFIG.value).startswith("src/") + assert str(ModelPaths.TRAINED_MODEL.value).startswith("src/") + + # 
Configuration paths should be in "src/configurations/" + assert "configurations" in str(ModelPaths.MODEL_BUILDER_CONFIG.value) + assert "configurations" in str(ModelPaths.MODEL_TRAINER_CONFIG.value) + assert "configurations" in str(OptunaPaths.OPTUNA_CONFIG.value) + + # Model paths should be in "src/models/" + assert "models" in str(ModelPaths.TRAINED_MODEL.value) + assert "models" in str(ModelPaths.INFERENCE_MODEL.value) + assert "models" in str(OptunaPaths.OPTUNA_MODEL.value) + + +def test_enum_uniqueness(): + """Test that enum values are unique.""" + # Check ModelPaths + model_values = [item.value for item in ModelPaths] + assert len(model_values) == len( + set(model_values) + ), "ModelPaths has duplicate values" + + # Check OptunaPaths + optuna_values = [item.value for item in OptunaPaths] + assert len(optuna_values) == len( + set(optuna_values) + ), "OptunaPaths has duplicate values" + + +def test_enum_string_conversion(): + """Test that enum values can be converted to strings.""" + # Test that Path objects can be converted to strings + assert isinstance(str(DatasetPaths.RAW_DATA.value), str) + assert isinstance(str(ModelPaths.TRAINED_MODEL.value), str) + assert isinstance(str(OptunaPaths.OPTUNA_CONFIG.value), str) + + +def test_enum_membership(): + """Test enum membership checks.""" + # Check that specific values are in the enums + assert "RAW_DATA" in DatasetPaths.__members__ + assert "TRAINED_MODEL" in ModelPaths.__members__ + assert "OPTUNA_CONFIG" in OptunaPaths.__members__ + assert "max_tokens" in TextVectorizerConfig.__members__ + + +def test_path_components(): + """Test that paths have expected components.""" + # Test that configurations are in the right directory + config_path = str(ModelPaths.MODEL_BUILDER_CONFIG.value) + assert config_path.startswith("src") + assert "configurations" in config_path + + # Test that models are in the right directory + model_path = str(ModelPaths.TRAINED_MODEL.value) + assert model_path.startswith("src") + assert 
"models" in model_path diff --git a/uv.lock b/uv.lock index 0327131..f7a6c9f 100644 --- a/uv.lock +++ b/uv.lock @@ -942,6 +942,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/71/cf/e01dc4cc79779cd82d77888a88ae2fa424d93b445ad4f6c02bfc18335b70/libclang-18.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8", size = 22361112, upload-time = "2024-03-17T16:42:59.565Z" }, ] +[[package]] +name = "librt" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/9c/b4b0c54d84da4a94b37bd44151e46d5e583c9534c7e02250b961b1b6d8a8/librt-0.8.1.tar.gz", hash = "sha256:be46a14693955b3bd96014ccbdb8339ee8c9346fbe11c1b78901b55125f14c73", size = 177471, upload-time = "2026-02-17T16:13:06.101Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/01/0e748af5e4fee180cf7cd12bd12b0513ad23b045dccb2a83191bde82d168/librt-0.8.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:681dc2451d6d846794a828c16c22dc452d924e9f700a485b7ecb887a30aad1fd", size = 65315, upload-time = "2026-02-17T16:11:25.152Z" }, + { url = "https://files.pythonhosted.org/packages/9d/4d/7184806efda571887c798d573ca4134c80ac8642dcdd32f12c31b939c595/librt-0.8.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3b4350b13cc0e6f5bec8fa7caf29a8fb8cdc051a3bae45cfbfd7ce64f009965", size = 68021, upload-time = "2026-02-17T16:11:26.129Z" }, + { url = "https://files.pythonhosted.org/packages/ae/88/c3c52d2a5d5101f28d3dc89298444626e7874aa904eed498464c2af17627/librt-0.8.1-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ac1e7817fd0ed3d14fd7c5df91daed84c48e4c2a11ee99c0547f9f62fdae13da", size = 194500, upload-time = "2026-02-17T16:11:27.177Z" }, + { url = 
"https://files.pythonhosted.org/packages/d6/5d/6fb0a25b6a8906e85b2c3b87bee1d6ed31510be7605b06772f9374ca5cb3/librt-0.8.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:747328be0c5b7075cde86a0e09d7a9196029800ba75a1689332348e998fb85c0", size = 205622, upload-time = "2026-02-17T16:11:28.242Z" }, + { url = "https://files.pythonhosted.org/packages/b2/a6/8006ae81227105476a45691f5831499e4d936b1c049b0c1feb17c11b02d1/librt-0.8.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0af2bd2bc204fa27f3d6711d0f360e6b8c684a035206257a81673ab924aa11e", size = 218304, upload-time = "2026-02-17T16:11:29.344Z" }, + { url = "https://files.pythonhosted.org/packages/ee/19/60e07886ad16670aae57ef44dada41912c90906a6fe9f2b9abac21374748/librt-0.8.1-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d480de377f5b687b6b1bc0c0407426da556e2a757633cc7e4d2e1a057aa688f3", size = 211493, upload-time = "2026-02-17T16:11:30.445Z" }, + { url = "https://files.pythonhosted.org/packages/9c/cf/f666c89d0e861d05600438213feeb818c7514d3315bae3648b1fc145d2b6/librt-0.8.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d0ee06b5b5291f609ddb37b9750985b27bc567791bc87c76a569b3feed8481ac", size = 219129, upload-time = "2026-02-17T16:11:32.021Z" }, + { url = "https://files.pythonhosted.org/packages/8f/ef/f1bea01e40b4a879364c031476c82a0dc69ce068daad67ab96302fed2d45/librt-0.8.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:9e2c6f77b9ad48ce5603b83b7da9ee3e36b3ab425353f695cba13200c5d96596", size = 213113, upload-time = "2026-02-17T16:11:33.192Z" }, + { url = "https://files.pythonhosted.org/packages/9b/80/cdab544370cc6bc1b72ea369525f547a59e6938ef6863a11ab3cd24759af/librt-0.8.1-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:439352ba9373f11cb8e1933da194dcc6206daf779ff8df0ed69c5e39113e6a99", size = 212269, upload-time = "2026-02-17T16:11:34.373Z" }, + { url = 
"https://files.pythonhosted.org/packages/9d/9c/48d6ed8dac595654f15eceab2035131c136d1ae9a1e3548e777bb6dbb95d/librt-0.8.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:82210adabbc331dbb65d7868b105185464ef13f56f7f76688565ad79f648b0fe", size = 234673, upload-time = "2026-02-17T16:11:36.063Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/35b68b1db517f27a01be4467593292eb5315def8900afad29fabf56304ba/librt-0.8.1-cp311-cp311-win32.whl", hash = "sha256:52c224e14614b750c0a6d97368e16804a98c684657c7518752c356834fff83bb", size = 54597, upload-time = "2026-02-17T16:11:37.544Z" }, + { url = "https://files.pythonhosted.org/packages/71/02/796fe8f02822235966693f257bf2c79f40e11337337a657a8cfebba5febc/librt-0.8.1-cp311-cp311-win_amd64.whl", hash = "sha256:c00e5c884f528c9932d278d5c9cbbea38a6b81eb62c02e06ae53751a83a4d52b", size = 61733, upload-time = "2026-02-17T16:11:38.691Z" }, + { url = "https://files.pythonhosted.org/packages/28/ad/232e13d61f879a42a4e7117d65e4984bb28371a34bb6fb9ca54ec2c8f54e/librt-0.8.1-cp311-cp311-win_arm64.whl", hash = "sha256:f7cdf7f26c2286ffb02e46d7bac56c94655540b26347673bea15fa52a6af17e9", size = 52273, upload-time = "2026-02-17T16:11:40.308Z" }, +] + [[package]] name = "mako" version = "1.3.10" @@ -1055,6 +1076,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/db/1784b87285588788170f87e987bfb4bda218d62a70a81ebb66c94e7f9b95/ml_dtypes-0.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655", size = 127681, upload-time = "2024-01-03T19:21:07.337Z" }, ] +[[package]] +name = "mypy" +version = "1.19.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "librt", marker = "platform_python_implementation != 'PyPy'" }, + { name = "mypy-extensions" }, + { name = "pathspec" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/f5/db/4efed9504bc01309ab9c2da7e352cc223569f05478012b5d9ece38fd44d2/mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba", size = 3582404, upload-time = "2025-12-15T05:03:48.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/47/6b3ebabd5474d9cdc170d1342fbf9dddc1b0ec13ec90bf9004ee6f391c31/mypy-1.19.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8dfc6ab58ca7dda47d9237349157500468e404b17213d44fc1cb77bce532288", size = 13028539, upload-time = "2025-12-15T05:03:44.129Z" }, + { url = "https://files.pythonhosted.org/packages/5c/a6/ac7c7a88a3c9c54334f53a941b765e6ec6c4ebd65d3fe8cdcfbe0d0fd7db/mypy-1.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3f276d8493c3c97930e354b2595a44a21348b320d859fb4a2b9f66da9ed27ab", size = 12083163, upload-time = "2025-12-15T05:03:37.679Z" }, + { url = "https://files.pythonhosted.org/packages/67/af/3afa9cf880aa4a2c803798ac24f1d11ef72a0c8079689fac5cfd815e2830/mypy-1.19.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2abb24cf3f17864770d18d673c85235ba52456b36a06b6afc1e07c1fdcd3d0e6", size = 12687629, upload-time = "2025-12-15T05:02:31.526Z" }, + { url = "https://files.pythonhosted.org/packages/2d/46/20f8a7114a56484ab268b0ab372461cb3a8f7deed31ea96b83a4e4cfcfca/mypy-1.19.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a009ffa5a621762d0c926a078c2d639104becab69e79538a494bcccb62cc0331", size = 13436933, upload-time = "2025-12-15T05:03:15.606Z" }, + { url = "https://files.pythonhosted.org/packages/5b/f8/33b291ea85050a21f15da910002460f1f445f8007adb29230f0adea279cb/mypy-1.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7cee03c9a2e2ee26ec07479f38ea9c884e301d42c6d43a19d20fb014e3ba925", size = 13661754, upload-time = "2025-12-15T05:02:26.731Z" }, + { url = 
"https://files.pythonhosted.org/packages/fd/a3/47cbd4e85bec4335a9cd80cf67dbc02be21b5d4c9c23ad6b95d6c5196bac/mypy-1.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b84a7a18f41e167f7995200a1d07a4a6810e89d29859df936f1c3923d263042", size = 10055772, upload-time = "2025-12-15T05:03:26.179Z" }, + { url = "https://files.pythonhosted.org/packages/8d/f4/4ce9a05ce5ded1de3ec1c1d96cf9f9504a04e54ce0ed55cfa38619a32b8d/mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247", size = 2471239, upload-time = "2025-12-15T05:03:07.248Z" }, +] + [[package]] name = "mypy-extensions" version = "1.1.0" @@ -1849,12 +1891,14 @@ audio = [ ] dev = [ { name = "black" }, + { name = "mypy" }, { name = "pre-commit" }, { name = "pylint" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "ruff" }, { name = "scikit-optimize" }, + { name = "types-tensorflow" }, ] kaggle = [ { name = "kaggle" }, @@ -1876,6 +1920,7 @@ requires-dist = [ { name = "kagglehub", marker = "extra == 'kaggle'", specifier = ">=0.3.11" }, { name = "keras", specifier = ">=3.9.0" }, { name = "matplotlib", specifier = ">=3.8.2" }, + { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" }, { name = "nltk", specifier = ">=3.9.1" }, { name = "numpy", specifier = ">=1.26.2" }, { name = "optuna", specifier = ">=4.2.1" }, @@ -1899,6 +1944,7 @@ requires-dist = [ { name = "tensorflow-metal", marker = "extra == 'macos'", specifier = ">=1.2.0" }, { name = "tfds-nightly", specifier = ">=4.9.7.dev202503080044" }, { name = "transformers", specifier = ">=4.49.0" }, + { name = "types-tensorflow", marker = "extra == 'dev'", specifier = ">=2.16.0" }, { name = "vosk", specifier = ">=0.3.44" }, { name = "vosk", marker = "extra == 'audio'", specifier = "==0.3.44" }, { name = "wheel", specifier = ">=0.45.1" }, @@ -2316,6 +2362,41 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" }, ] +[[package]] +name = "types-protobuf" +version = "6.32.1.20260221" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5f/e2/9aa4a3b2469508bd7b4e2ae11cbedaf419222a09a1b94daffcd5efca4023/types_protobuf-6.32.1.20260221.tar.gz", hash = "sha256:6d5fb060a616bfb076cbb61b4b3c3969f5fc8bec5810f9a2f7e648ee5cbcbf6e", size = 64408, upload-time = "2026-02-21T03:55:13.916Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/e8/1fd38926f9cf031188fbc5a96694203ea6f24b0e34bd64a225ec6f6291ba/types_protobuf-6.32.1.20260221-py3-none-any.whl", hash = "sha256:da7cdd947975964a93c30bfbcc2c6841ee646b318d3816b033adc2c4eb6448e4", size = 77956, upload-time = "2026-02-21T03:55:12.894Z" }, +] + +[[package]] +name = "types-requests" +version = "2.32.4.20260107" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0f/f3/a0663907082280664d745929205a89d41dffb29e89a50f753af7d57d0a96/types_requests-2.32.4.20260107.tar.gz", hash = "sha256:018a11ac158f801bfa84857ddec1650750e393df8a004a8a9ae2a9bec6fcb24f", size = 23165, upload-time = "2026-01-07T03:20:54.091Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/12/709ea261f2bf91ef0a26a9eed20f2623227a8ed85610c1e54c5805692ecb/types_requests-2.32.4.20260107-py3-none-any.whl", hash = "sha256:b703fe72f8ce5b31ef031264fe9395cac8f46a04661a79f7ed31a80fb308730d", size = 20676, upload-time = "2026-01-07T03:20:52.929Z" }, +] + +[[package]] +name = "types-tensorflow" +version = "2.18.0.20260224" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = 
"types-protobuf" }, + { name = "types-requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/af/cb/4914c2fbc1cf8a8d1ef2a7c727bb6f694879be85edeee880a0c88e696af8/types_tensorflow-2.18.0.20260224.tar.gz", hash = "sha256:9b0ccc91c79c88791e43d3f80d6c879748fa0361409c5ff23c7ffe3709be00f2", size = 258786, upload-time = "2026-02-24T04:06:45.613Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/1d/a1c3c60f0eb1a204500dbdc66e3d18aafabc86ad07a8eca71ea05bc8c5a8/types_tensorflow-2.18.0.20260224-py3-none-any.whl", hash = "sha256:6a25f5f41f3e06f28c1f65c6e09f484d4ba0031d6d8df83a39df9d890245eefc", size = 329746, upload-time = "2026-02-24T04:06:44.4Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From e3e680fab7f3d7a6f70fef9b23675de6548b916b Mon Sep 17 00:00:00 2001 From: Jeremy Vachier <89128100+jvachier@users.noreply.github.com> Date: Sun, 1 Mar 2026 12:20:55 +0100 Subject: [PATCH 2/3] fixing ruff. --- tests/test_transformer_model.py | 6 +++--- tests/test_utils.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_transformer_model.py b/tests/test_transformer_model.py index 8ee2509..c63fd3d 100644 --- a/tests/test_transformer_model.py +++ b/tests/test_transformer_model.py @@ -13,9 +13,9 @@ @pytest.fixture -def setup_data() -> ( - Tuple[TextPreprocessor, tf.data.Dataset, tf.data.Dataset, tf.data.Dataset] -): +def setup_data() -> Tuple[ + TextPreprocessor, tf.data.Dataset, tf.data.Dataset, tf.data.Dataset +]: """ Fixture to set up a mocked dataset and preprocessor for testing. 
diff --git a/tests/test_utils.py b/tests/test_utils.py index fddea7d..0205e8e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -84,15 +84,15 @@ def test_enum_uniqueness(): """Test that enum values are unique.""" # Check ModelPaths model_values = [item.value for item in ModelPaths] - assert len(model_values) == len( - set(model_values) - ), "ModelPaths has duplicate values" + assert len(model_values) == len(set(model_values)), ( + "ModelPaths has duplicate values" + ) # Check OptunaPaths optuna_values = [item.value for item in OptunaPaths] - assert len(optuna_values) == len( - set(optuna_values) - ), "OptunaPaths has duplicate values" + assert len(optuna_values) == len(set(optuna_values)), ( + "OptunaPaths has duplicate values" + ) def test_enum_string_conversion(): From 560ad4ccacd01fb2717fd62fd65f68c6c713abc4 Mon Sep 17 00:00:00 2001 From: Jeremy Vachier <89128100+jvachier@users.noreply.github.com> Date: Sun, 1 Mar 2026 14:19:13 +0100 Subject: [PATCH 3/3] Updating requirement. --- gradio_apps/requirements.txt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gradio_apps/requirements.txt b/gradio_apps/requirements.txt index 6217096..d899871 100644 --- a/gradio_apps/requirements.txt +++ b/gradio_apps/requirements.txt @@ -1,3 +1,4 @@ -gradio==4.0.0 -tensorflow==2.19.0 -numpy==1.26.0 +huggingface-hub==0.25.1 +tensorflow==2.20.0 +numpy>=1.26.2 +audioop-lts