diff --git a/README.md b/README.md index 961bf75..cc31457 100644 --- a/README.md +++ b/README.md @@ -17,15 +17,9 @@ Production-ready machine learning pipeline for personality classification using ## Dashboard Preview -
- -
- Watch a live demo of the Personality Classification Dashboard in action -
+![Dashboard Demo](docs/images/personality_classification_app.mp4) + +*Watch a live demo of the Personality Classification Dashboard in action* ## Quick Start diff --git a/dash_app/dashboard/callbacks.py b/dash_app/dashboard/callbacks.py index 717d6da..e83898f 100644 --- a/dash_app/dashboard/callbacks.py +++ b/dash_app/dashboard/callbacks.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from dataclasses import dataclass from datetime import datetime from dash import html @@ -13,6 +14,64 @@ ) +# Default values for prediction inputs +class PredictionDefaults: + """Default values for personality prediction features.""" + + TIME_ALONE = 2.0 + SOCIAL_EVENTS = 4.0 + GOING_OUTSIDE = 3.0 + FRIENDS_SIZE = 8.0 + POST_FREQUENCY = 3.0 + + +@dataclass +class PredictionInputs: + """Data class for prediction input parameters.""" + + time_alone: float | None = None + social_events: float | None = None + going_outside: float | None = None + friends_size: float | None = None + post_freq: float | None = None + stage_fear: str | None = None + drained_social: str | None = None + + def to_feature_dict(self) -> dict[str, int | float]: + """Convert inputs to feature dictionary for model prediction.""" + return { + "Time_spent_Alone": self.time_alone + if self.time_alone is not None + else PredictionDefaults.TIME_ALONE, + "Social_event_attendance": self.social_events + if self.social_events is not None + else PredictionDefaults.SOCIAL_EVENTS, + "Going_outside": self.going_outside + if self.going_outside is not None + else PredictionDefaults.GOING_OUTSIDE, + "Friends_circle_size": self.friends_size + if self.friends_size is not None + else PredictionDefaults.FRIENDS_SIZE, + "Post_frequency": self.post_freq + if self.post_freq is not None + else PredictionDefaults.POST_FREQUENCY, + # One-hot encode Stage_fear + "Stage_fear_No": 1 if self.stage_fear == "No" else 0, + "Stage_fear_Unknown": 1 if self.stage_fear == "Unknown" else 0, + "Stage_fear_Yes": 1 if self.stage_fear == "Yes" else 0, + # One-hot encode Drained_after_socializing + "Drained_after_socializing_No": 1 if self.drained_social == "No" else 0, + "Drained_after_socializing_Unknown": 1 + if self.drained_social == "Unknown" + else 0, + "Drained_after_socializing_Yes": 1 if self.drained_social == "Yes" else 0, + # Set external match features to Unknown (default) + "match_p_Extrovert": 0, + "match_p_Introvert": 0, + "match_p_Unknown": 1, + } + + def register_callbacks(app, model_loader, prediction_history: list) -> None: """Register all callbacks for the Dash application. @@ -37,47 +96,25 @@ def register_callbacks(app, model_loader, prediction_history: list) -> None: State("drained-after-socializing", "value"), prevent_initial_call=True, ) - def make_prediction( - n_clicks, - time_alone, - social_events, - going_outside, - friends_size, - post_freq, - stage_fear, - drained_social, - ): + def make_prediction(n_clicks, *input_values): """Handle prediction requests.""" if not n_clicks: return "" try: - # Build the feature dictionary with proper encoding - data = { - "Time_spent_Alone": time_alone if time_alone is not None else 2.0, - "Social_event_attendance": social_events - if social_events is not None - else 4.0, - "Going_outside": going_outside if going_outside is not None else 3.0, - "Friends_circle_size": friends_size - if friends_size is not None - else 8.0, - "Post_frequency": post_freq if post_freq is not None else 3.0, - # One-hot encode Stage_fear - "Stage_fear_No": 1 if stage_fear == "No" else 0, - "Stage_fear_Unknown": 1 if stage_fear == "Unknown" else 0, - "Stage_fear_Yes": 1 if stage_fear == "Yes" else 0, - # One-hot encode Drained_after_socializing - "Drained_after_socializing_No": 1 if drained_social == "No" else 0, - "Drained_after_socializing_Unknown": 1 - if drained_social == "Unknown" - else 0, - "Drained_after_socializing_Yes": 1 if drained_social == "Yes" else 0, - # Set external match features to Unknown (default) - "match_p_Extrovert": 0, - "match_p_Introvert": 0, - "match_p_Unknown": 1, - } + # Create prediction inputs from callback arguments + inputs = PredictionInputs( + time_alone=input_values[0], + social_events=input_values[1], + going_outside=input_values[2], + friends_size=input_values[3], + post_freq=input_values[4], + stage_fear=input_values[5], + drained_social=input_values[6], + ) + + # Convert to feature dictionary + data = inputs.to_feature_dict() # Make prediction result = model_loader.predict(data) diff --git a/dash_app/dashboard/layout.py b/dash_app/dashboard/layout.py index df5644c..648d0c7 100644 --- a/dash_app/dashboard/layout.py +++ b/dash_app/dashboard/layout.py @@ -2,6 +2,7 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any import dash_bootstrap_components as dbc @@ -9,6 +10,20 @@ from dash import dcc, html +@dataclass +class SliderConfig: + """Configuration for enhanced sliders.""" + + slider_id: str + label: str + min_val: int + max_val: int + default: int + intro_text: str + extro_text: str + css_class: str + + def create_layout(model_name: str, model_metadata: dict[str, Any]) -> html.Div: """Create the main layout for the Dash application. @@ -207,24 +222,28 @@ def create_input_panel() -> dbc.Card: className="section-title mb-4", ), create_enhanced_slider( - "time-spent-alone", - "Time Spent Alone (hours/day)", - 0, - 24, - 8, - "Less alone time", - "More alone time", - "slider-social", + SliderConfig( + slider_id="time-spent-alone", + label="Time Spent Alone (hours/day)", + min_val=0, + max_val=24, + default=8, + intro_text="Less alone time", + extro_text="More alone time", + css_class="slider-social", + ) ), create_enhanced_slider( - "social-event-attendance", - "Social Event Attendance (events/month)", - 0, - 20, - 4, - "Fewer events", - "More events", - "slider-social", + SliderConfig( + slider_id="social-event-attendance", + label="Social Event Attendance (events/month)", + min_val=0, + max_val=20, + default=4, + intro_text="Fewer events", + extro_text="More events", + css_class="slider-social", + ) ), # Lifestyle Section html.H5( @@ -238,24 +257,28 @@ def create_input_panel() -> dbc.Card: className="section-title mt-5 mb-4", ), create_enhanced_slider( - "going-outside", - "Going Outside Frequency (times/week)", - 0, - 15, - 5, - "Stay indoors", - "Go out frequently", - "slider-lifestyle", + SliderConfig( + slider_id="going-outside", + label="Going Outside Frequency (times/week)", + min_val=0, + max_val=15, + default=5, + intro_text="Stay indoors", + extro_text="Go out frequently", + css_class="slider-lifestyle", + ) ), create_enhanced_slider( - "friends-circle-size", - "Friends Circle Size", - 0, - 50, - 12, - "Small circle", - "Large network", - "slider-lifestyle", + SliderConfig( + slider_id="friends-circle-size", + label="Friends Circle Size", + min_val=0, + max_val=50, + default=12, + intro_text="Small circle", + extro_text="Large network", + css_class="slider-lifestyle", + ) ), # Digital Behavior Section html.H5( @@ -269,14 +292,16 @@ def create_input_panel() -> dbc.Card: className="section-title mt-5 mb-4", ), create_enhanced_slider( - "post-frequency", - "Social Media Posts (per week)", - 0, - 20, - 3, - "Rarely post", - "Frequently post", - "slider-digital", + SliderConfig( + slider_id="post-frequency", + label="Social Media Posts (per week)", + min_val=0, + max_val=20, + default=3, + intro_text="Rarely post", + extro_text="Frequently post", + css_class="slider-digital", + ) ), # Psychological Assessment Section html.H5( @@ -353,38 +378,36 @@ def create_input_panel() -> dbc.Card: ) -def create_enhanced_slider( - slider_id: str, - label: str, - min_val: int, - max_val: int, - default: int, - intro_text: str, - extro_text: str, - css_class: str, -) -> html.Div: - """Create an enhanced slider with personality hints.""" +def create_enhanced_slider(config: SliderConfig) -> html.Div: + """Create an enhanced slider with personality hints. + + Args: + config: Slider configuration containing all parameters + + Returns: + HTML div containing the slider component + """ return html.Div( [ - html.Label(label, className="slider-label fw-bold"), + html.Label(config.label, className="slider-label fw-bold"), dcc.Slider( - id=slider_id, - min=min_val, - max=max_val, + id=config.slider_id, + min=config.min_val, + max=config.max_val, step=1, - value=default, + value=config.default, marks={ - min_val: { - "label": intro_text, + config.min_val: { + "label": config.intro_text, "style": {"color": "#3498db", "fontSize": "0.8rem"}, }, - max_val: { - "label": extro_text, + config.max_val: { + "label": config.extro_text, "style": {"color": "#e74c3c", "fontSize": "0.8rem"}, }, }, tooltip={"placement": "bottom", "always_visible": True}, - className=f"personality-slider {css_class}", + className=f"personality-slider {config.css_class}", ), ], className="slider-container mb-3", diff --git a/docs/README.md b/docs/README.md index ff5884e..ac7fdb7 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,7 +1,5 @@ # Documentation Index -# Documentation Index - Welcome! This documentation covers all aspects of the Six-Stack Personality Classification Pipeline. ## Main Guides @@ -42,7 +40,7 @@ docker build -t personality-classifier . docker run -p 8080:8080 personality-classifier ``` -## ️ Resources +## 📚 Resources - Code: `src/main_modular.py`, `examples/` - Config templates: [Configuration Guide](configuration.md) diff --git a/docs/api-reference.md b/docs/api-reference.md index ee82db2..f0faf32 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -1,48 +1,49 @@ # API Reference - Six-Stack Personality Classification Pipeline -## Modules & Functions - -**config.py** -- RND: int = 42 -- N_SPLITS: int = 5 -- N_TRIALS_STACK: int = 15 -- N_TRIALS_BLEND: int = 200 -- LOG_LEVEL: str = "INFO" -- ENABLE_DATA_AUGMENTATION: bool = True -- AUGMENTATION_METHOD: str = "sdv_copula" -- AUGMENTATION_RATIO: float = 0.05 -- DIVERSITY_THRESHOLD: float = 0.95 -- QUALITY_THRESHOLD: float = 0.7 -- class ThreadConfig(Enum): N_JOBS, THREAD_COUNT -- setup_logging(), get_logger(name) - -**data_loader.py** -- load_data_with_external_merge() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame] - -**preprocessing.py** -- preprocess_data(df) -> pd.DataFrame - -**data_augmentation.py** -- augment_data(X, y, method, ratio) -> pd.DataFrame - -**model_builders.py** -- build_stack(stack_id, X, y) -> model - -**ensemble.py** -- blend_predictions(preds_list) -> np.ndarray - -**optimization.py** -- optimize_hyperparameters(model, X, y) -> dict - -**utils.py** -- Utility functions for metrics, logging, etc. - Returns: - tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: - (train_df, test_df, submission_template) - """ -def merge_external_features(df_main: pd.DataFrame, - df_external: pd.DataFrame, - is_training: bool = True) -> pd.DataFrame: +## Overview + +This document provides comprehensive API documentation for all modules and functions in the personality classification pipeline. + +## Configuration Module (`config.py`) + +### Constants + +- `RND: int = 42` - Random seed for reproducibility +- `N_SPLITS: int = 5` - Number of cross-validation splits +- `N_TRIALS_STACK: int = 15` - Optimization trials per stack +- `N_TRIALS_BLEND: int = 200` - Optimization trials for blending +- `LOG_LEVEL: str = "INFO"` - Logging level +- `ENABLE_DATA_AUGMENTATION: bool = True` - Data augmentation toggle +- `AUGMENTATION_METHOD: str = "sdv_copula"` - Augmentation method +- `AUGMENTATION_RATIO: float = 0.05` - Augmentation ratio +- `DIVERSITY_THRESHOLD: float = 0.95` - Diversity threshold +- `QUALITY_THRESHOLD: float = 0.7` - Quality threshold + +### Classes + +```python +class ThreadConfig(Enum): + """Thread configuration enumeration.""" + N_JOBS = "n_jobs" + THREAD_COUNT = "thread_count" +``` + +### Functions + +```python +def setup_logging() -> None: + """Setup logging configuration.""" + +def get_logger(name: str) -> logging.Logger: + """Get logger instance with specified name.""" +``` + +## Data Loading Module (`data_loader.py`) + +### Main Functions + +```python +def load_data_with_external_merge() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ Merge external dataset features using strategic matching. diff --git a/docs/architecture.md b/docs/architecture.md index a530c85..f024b3f 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -1,21 +1,43 @@ # Architecture Documentation -## Architecture Overview - -- Modular pipeline: 8 core modules in `src/modules/` -- Main pipeline: `src/main_modular.py` -- Dashboard: `dash_app/` (Dash, Docker) -- Model stacks: 6 specialized ensembles (A-F) -- Data flow: Load → Preprocess → Augment → Train → Ensemble → Predict - -## Stacks -- A: Traditional ML (narrow) -- B: Traditional ML (wide) -- C: XGBoost/CatBoost -- D: Sklearn ensemble -- E: Neural networks -- F: Noise-robust +## System Overview + +The Six-Stack Personality Classification Pipeline is built with a modular, scalable architecture designed for machine learning competitions and production deployment. + +### Core Architecture + +- **Modular pipeline**: 8 core modules in `src/modules/` +- **Main pipeline**: `src/main_modular.py` +- **Dashboard**: `dash_app/` (Dash, Docker) +- **Model stacks**: 6 specialized ensembles (A-F) +- **Data flow**: Load → Preprocess → Augment → Train → Ensemble → Predict + +## Model Stack Architecture + +### Stack Specializations + +| Stack | Type | Description | +|-------|------|-------------| +| **A** | Traditional ML (narrow) | Focused feature selection with classic algorithms | +| **B** | Traditional ML (wide) | Comprehensive feature engineering with traditional models | +| **C** | XGBoost/CatBoost | Gradient boosting specialists | +| **D** | Sklearn ensemble | Ensemble of sklearn algorithms | +| **E** | Neural networks | Deep learning approaches | +| **F** | Noise-robust | Robust methods for noisy data | ## Key Features -- Efficient, reproducible, and testable -- Full logging and error handling + +### Design Principles + +- **Efficient**: Optimized for both speed and accuracy +- **Reproducible**: Consistent results with random seed control +- **Testable**: Comprehensive test coverage +- **Modular**: Easy to extend and maintain + +### Core Capabilities + +- **Full logging**: Comprehensive error handling and progress tracking +- **Data augmentation**: Advanced synthetic data generation +- **Hyperparameter optimization**: Automated tuning for each stack +- **Cross-validation**: Robust evaluation methodology +- **Ensemble learning**: Meta-learning for optimal predictions diff --git a/docs/data-augmentation.md b/docs/data-augmentation.md index df8b39f..b437964 100644 --- a/docs/data-augmentation.md +++ b/docs/data-augmentation.md @@ -1,21 +1,33 @@ # Data Augmentation Guide -## Data Augmentation Guide - -### Strategy +## Strategy - Adaptive selection based on dataset size, balance, and feature types -### Decision Matrix +## Decision Matrix | Data Type | Method | |-------------------|---------------| | Small/Imbalanced | SMOTE/ADASYN | | High Categorical | Basic | | Complex Numeric | SDV Copula | -### Main Method -**SDV Copula** (recommended): -- Preserves feature distributions and correlations -- Fast mode for development +## Main Method + +### 1. SDV Copula (Recommended) + +**Best for**: Complex datasets with mixed feature types + +#### Features + +- **Preserves feature distributions and correlations** +- **Fast mode for development** +- **Handles mixed data types effectively** +- **Statistical validation built-in** + +#### Implementation + +```python +def sdv_copula_augmentation(X, y, n_samples): + """SDV Copula-based synthetic data generation.""" synthesizer = GaussianCopula( enforce_rounding=True, enforce_min_max_values=True diff --git a/docs/performance-tuning.md b/docs/performance-tuning.md index 2558a98..c6dd85d 100644 --- a/docs/performance-tuning.md +++ b/docs/performance-tuning.md @@ -1,8 +1,10 @@ # Performance Tuning Guide -## Performance Tuning Guide +## Overview -### Key Levers +This guide covers optimization strategies for speed, memory usage, and accuracy in the personality classification pipeline. + +## Key Performance Levers - Training speed: TESTING_MODE, N_TRIALS_STACK, N_TRIALS_BLEND - Memory: TESTING_SAMPLE_SIZE, ENABLE_DATA_AUGMENTATION - Accuracy: Ensemble optimization, feature engineering diff --git a/pyproject.toml b/pyproject.toml index acbfb3a..1f0f4ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "personality-classification" -version = "0.1.0" +version = "2.1.1" description = "Six-Stack Personality Classification Pipeline with Advanced ML Ensemble Methods" authors = [ {name = "Jeremy Vachier"} @@ -21,7 +21,7 @@ classifiers = [ dependencies = [ # Core ML libraries - "numpy>=1.24.0,<2.0.0", + "numpy>=1.24.0,<3.0.0", "pandas>=2.0.0,<3.0.0", "scikit-learn>=1.3.0,<1.6.0", # Advanced ML models (gradient boosting) @@ -41,6 +41,10 @@ dependencies = [ "dash>=2.14.0,<3.0.0", "dash-bootstrap-components>=1.7.1", "plotly>=5.24.1", + # Utility libraries + "pyyaml>=6.0.0,<7.0.0", + "tqdm>=4.65.0,<5.0.0", + "colorama>=0.4.6,<1.0.0", ] [project.optional-dependencies] @@ -78,7 +82,6 @@ Documentation = "https://github.com/jeremyvachier/personality-classification/blo [project.scripts] personality-classify = "src.main_modular:main" -personality-demo = "src.main_demo:main" [build-system] requires = ["hatchling"] @@ -140,12 +143,11 @@ select = [ "RUF", # Ruff-specific rules ] ignore = [ - "E501", # line too long (handled by black) + "E501", # line too long (handled by formatter) "B905", # `zip()` without an explicit `strict=` parameter - "B904", # raise exceptions with `raise ... from err` (false positives in pre-commit) - "PLR0912", # too many branches - "PLR0913", # too many arguments - "PLR0915", # too many statements + "PLR0912", # too many branches (complex ML functions) + "PLR0913", # too many arguments (complex ML functions) + "PLR0915", # too many statements (complex ML functions) "PLR2004", # magic value used in comparison ] diff --git a/src/modules/ensemble.py b/src/modules/ensemble.py index fbcf434..02589e2 100644 --- a/src/modules/ensemble.py +++ b/src/modules/ensemble.py @@ -2,6 +2,9 @@ Ensemble functions for out-of-fold predictions and blending optimization. """ +from collections.abc import Callable +from dataclasses import dataclass + import numpy as np import pandas as pd from sklearn.metrics import accuracy_score @@ -14,6 +17,18 @@ logger = get_logger(__name__) +@dataclass +class NoisyEnsembleConfig: + """Configuration for noisy label ensemble training.""" + + model_builder: Callable + X: pd.DataFrame + y: pd.Series + X_test: pd.DataFrame + noise_rate: float + sample_weights: np.ndarray | None = None + + def oof_probs( model_builder, X: pd.DataFrame, @@ -46,69 +61,67 @@ def oof_probs( return oof_preds, test_preds -def oof_probs_noisy( - model_builder, - X: pd.DataFrame, - y: pd.Series, - X_test: pd.DataFrame, - noise_rate: float, - sample_weights=None, -): - """Generate out-of-fold predictions for noisy label ensemble.""" - oof_preds = np.zeros(len(X)) - test_preds = np.zeros(len(X_test)) +def oof_probs_noisy(config: NoisyEnsembleConfig) -> tuple[np.ndarray, np.ndarray]: + """Generate out-of-fold predictions for noisy label ensemble. + + Args: + config: Configuration containing model builder, data, and noise parameters + + Returns: + Tuple of (oof_predictions, test_predictions) + """ + oof_preds = np.zeros(len(config.X)) + test_preds = np.zeros(len(config.X_test)) cv = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RND) - for fold, (tr_idx, val_idx) in enumerate(cv.split(X, y)): + for fold, (tr_idx, val_idx) in enumerate(cv.split(config.X, config.y)): logger.info( - f" Fold {fold + 1}/{N_SPLITS} (with {noise_rate:.1%} label noise)" + f" Fold {fold + 1}/{N_SPLITS} (with {config.noise_rate:.1%} label noise)" ) - X_train, X_val = X.iloc[tr_idx], X.iloc[val_idx] - y_train, _y_val = y.iloc[tr_idx], y.iloc[val_idx] + X_train, X_val = config.X.iloc[tr_idx], config.X.iloc[val_idx] + y_train, _y_val = config.y.iloc[tr_idx], config.y.iloc[val_idx] # Add label noise to training data y_train_noisy = add_label_noise( - y_train, noise_rate=noise_rate, random_state=RND + fold + y_train, noise_rate=config.noise_rate, random_state=RND + fold ) # Build and fit model - model = model_builder() + model = config.model_builder() model.fit(X_train, y_train_noisy) # Out-of-fold predictions (on clean validation data) oof_preds[val_idx] = model.predict_proba(X_val)[:, 1] # Test predictions (averaged across folds) - test_preds += model.predict_proba(X_test)[:, 1] / N_SPLITS + test_preds += model.predict_proba(config.X_test)[:, 1] / N_SPLITS return oof_preds, test_preds -def improved_blend_obj(trial, oof_A, oof_B, oof_C, oof_D, oof_E, oof_F, y_true): - """Improved blending objective with constraints and regularization.""" - # Sample blend weights - w1 = trial.suggest_float("w1", 0.0, 1.0) - w2 = trial.suggest_float("w2", 0.0, 1.0) - w3 = trial.suggest_float("w3", 0.0, 1.0) - w4 = trial.suggest_float("w4", 0.0, 1.0) - w5 = trial.suggest_float("w5", 0.0, 1.0) - w6 = trial.suggest_float("w6", 0.0, 1.0) +def improved_blend_obj(trial, oof_predictions: dict[str, np.ndarray], y_true): + """Improved blending objective with constraints and regularization. + + Args: + trial: Optuna trial object + oof_predictions: Dictionary mapping stack names to their OOF predictions + y_true: True labels + """ + # Sample blend weights for each stack + weights = {} + for stack_name in oof_predictions: + weights[stack_name] = trial.suggest_float(f"w_{stack_name}", 0.0, 1.0) # Normalize weights - weights = np.array([w1, w2, w3, w4, w5, w6]) - weights = weights / np.sum(weights) + weight_values = np.array(list(weights.values())) + normalized_weights = weight_values / np.sum(weight_values) # Calculate blended predictions - blended = ( - weights[0] * oof_A - + weights[1] * oof_B - + weights[2] * oof_C - + weights[3] * oof_D - + weights[4] * oof_E - + weights[5] * oof_F - ) + blended = np.zeros_like(y_true, dtype=float) + for i, (_stack_name, oof_pred) in enumerate(oof_predictions.items()): + blended += normalized_weights[i] * oof_pred # Convert to binary predictions y_pred = (blended >= 0.5).astype(int) @@ -117,6 +130,7 @@ def improved_blend_obj(trial, oof_A, oof_B, oof_C, oof_D, oof_E, oof_F, y_true): score = accuracy_score(y_true, y_pred) # Store normalized weights in trial attributes - trial.set_user_attr("weights", weights.tolist()) + trial.set_user_attr("weights", normalized_weights.tolist()) + trial.set_user_attr("stack_names", list(oof_predictions.keys())) return score diff --git a/tests/dash_app/test_callbacks.py b/tests/dash_app/test_callbacks.py index eaf2974..cc031e2 100644 --- a/tests/dash_app/test_callbacks.py +++ b/tests/dash_app/test_callbacks.py @@ -137,36 +137,41 @@ class TestCallbackInputValidation: def callback_function_mock(self): """Mock the actual callback function for testing.""" with patch("dash_app.dashboard.callbacks.register_callbacks") as mock_register: - # Create a mock prediction function - def mock_prediction_callback( - n_clicks, - time_alone, - social_events, - going_outside, - friends_size, - post_freq, - stage_fear, - drained_social, - ): + # Create a mock prediction function that matches our refactored signature + def mock_prediction_callback(n_clicks, *input_values): # Simulate input validation if n_clicks is None or n_clicks == 0: return "No prediction made" - # Validate input ranges + # Validate input ranges - unpack the input values + if len(input_values) < 7: + return "Invalid input: Not enough values" + + time_alone, social_events, going_outside = ( + input_values[0], + input_values[1], + input_values[2], + ) + friends_size, post_freq, _stage_fear, _drained_social = ( + input_values[3], + input_values[4], + input_values[5], + input_values[6], + ) + inputs = [ time_alone, social_events, going_outside, friends_size, post_freq, - stage_fear, - drained_social, ] if any(x is None for x in inputs): return "Invalid input: None values" - if any(not isinstance(x, int | float) for x in inputs): + # Check numeric inputs + if any(not isinstance(x, int | float) for x in inputs if x is not None): return "Invalid input: Non-numeric values" return "Valid prediction" diff --git a/uv.lock b/uv.lock index 45a3513..f2f8459 100644 --- a/uv.lock +++ b/uv.lock @@ -2481,10 +2481,11 @@ wheels = [ [[package]] name = "personality-classification" -version = "0.1.0" +version = "2.1.1" source = { editable = "." } dependencies = [ { name = "catboost" }, + { name = "colorama" }, { name = "dash" }, { name = "dash-bootstrap-components" }, { name = "imbalanced-learn" }, @@ -2494,9 +2495,11 @@ dependencies = [ { name = "optuna" }, { name = "pandas" }, { name = "plotly" }, + { name = "pyyaml" }, { name = "scikit-learn" }, { name = "scipy" }, { name = "sdv" }, + { name = "tqdm" }, { name = "xgboost" }, ] @@ -2534,6 +2537,7 @@ requires-dist = [ { name = "autogluon", marker = "extra == 'automl'", specifier = ">=1.1.1,<2.0.0" }, { name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0,<2.0.0" }, { name = "catboost", specifier = ">=1.2.0,<2.0.0" }, + { name = "colorama", specifier = ">=0.4.6,<1.0.0" }, { name = "dash", specifier = ">=2.14.0,<3.0.0" }, { name = "dash-bootstrap-components", specifier = ">=1.7.1" }, { name = "h2o", marker = "extra == 'automl'", specifier = ">=3.44.0,<4.0.0" }, @@ -2541,7 +2545,7 @@ requires-dist = [ { name = "joblib", specifier = ">=1.3.0,<2.0.0" }, { name = "lightgbm", specifier = ">=4.0.0,<5.0.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.5.0,<2.0.0" }, - { name = "numpy", specifier = ">=1.24.0,<2.0.0" }, + { name = "numpy", specifier = ">=1.24.0,<3.0.0" }, { name = "optuna", specifier = ">=3.4.0,<4.0.0" }, { name = "pandas", specifier = ">=2.0.0,<3.0.0" }, { name = "plotly", specifier = ">=5.24.1" }, @@ -2550,10 +2554,12 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0,<8.0.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0,<5.0.0" }, { name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.3.0,<4.0.0" }, + { name = "pyyaml", specifier = ">=6.0.0,<7.0.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.4.0,<1.0.0" }, { name = "scikit-learn", specifier = ">=1.3.0,<1.6.0" }, { name = "scipy", specifier = ">=1.11.0,<2.0.0" }, { name = "sdv", specifier = ">=1.24.0,<2.0.0" }, + { name = "tqdm", specifier = ">=4.65.0,<5.0.0" }, { name = "types-pyyaml", marker = "extra == 'dev'", specifier = ">=6.0.0,<7.0.0" }, { name = "types-requests", marker = "extra == 'dev'", specifier = ">=2.31.0,<3.0.0" }, { name = "xgboost", specifier = ">=2.0.0,<3.0.0" },