diff --git a/README.md b/README.md
index 961bf75..cc31457 100644
--- a/README.md
+++ b/README.md
@@ -17,15 +17,9 @@ Production-ready machine learning pipeline for personality classification using
## Dashboard Preview
-
-
-
-
Watch a live demo of the Personality Classification Dashboard in action
-
+
+
+*Watch a live demo of the Personality Classification Dashboard in action*
## Quick Start
diff --git a/dash_app/dashboard/callbacks.py b/dash_app/dashboard/callbacks.py
index 717d6da..e83898f 100644
--- a/dash_app/dashboard/callbacks.py
+++ b/dash_app/dashboard/callbacks.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
+from dataclasses import dataclass
from datetime import datetime
from dash import html
@@ -13,6 +14,64 @@
)
+# Default values for prediction inputs
+class PredictionDefaults:
+ """Default values for personality prediction features."""
+
+ TIME_ALONE = 2.0
+ SOCIAL_EVENTS = 4.0
+ GOING_OUTSIDE = 3.0
+ FRIENDS_SIZE = 8.0
+ POST_FREQUENCY = 3.0
+
+
+@dataclass
+class PredictionInputs:
+ """Data class for prediction input parameters."""
+
+ time_alone: float | None = None
+ social_events: float | None = None
+ going_outside: float | None = None
+ friends_size: float | None = None
+ post_freq: float | None = None
+ stage_fear: str | None = None
+ drained_social: str | None = None
+
+ def to_feature_dict(self) -> dict[str, int | float]:
+ """Convert inputs to feature dictionary for model prediction."""
+ return {
+ "Time_spent_Alone": self.time_alone
+ if self.time_alone is not None
+ else PredictionDefaults.TIME_ALONE,
+ "Social_event_attendance": self.social_events
+ if self.social_events is not None
+ else PredictionDefaults.SOCIAL_EVENTS,
+ "Going_outside": self.going_outside
+ if self.going_outside is not None
+ else PredictionDefaults.GOING_OUTSIDE,
+ "Friends_circle_size": self.friends_size
+ if self.friends_size is not None
+ else PredictionDefaults.FRIENDS_SIZE,
+ "Post_frequency": self.post_freq
+ if self.post_freq is not None
+ else PredictionDefaults.POST_FREQUENCY,
+ # One-hot encode Stage_fear
+ "Stage_fear_No": 1 if self.stage_fear == "No" else 0,
+ "Stage_fear_Unknown": 1 if self.stage_fear == "Unknown" else 0,
+ "Stage_fear_Yes": 1 if self.stage_fear == "Yes" else 0,
+ # One-hot encode Drained_after_socializing
+ "Drained_after_socializing_No": 1 if self.drained_social == "No" else 0,
+ "Drained_after_socializing_Unknown": 1
+ if self.drained_social == "Unknown"
+ else 0,
+ "Drained_after_socializing_Yes": 1 if self.drained_social == "Yes" else 0,
+ # Set external match features to Unknown (default)
+ "match_p_Extrovert": 0,
+ "match_p_Introvert": 0,
+ "match_p_Unknown": 1,
+ }
+
+
def register_callbacks(app, model_loader, prediction_history: list) -> None:
"""Register all callbacks for the Dash application.
@@ -37,47 +96,25 @@ def register_callbacks(app, model_loader, prediction_history: list) -> None:
State("drained-after-socializing", "value"),
prevent_initial_call=True,
)
- def make_prediction(
- n_clicks,
- time_alone,
- social_events,
- going_outside,
- friends_size,
- post_freq,
- stage_fear,
- drained_social,
- ):
+ def make_prediction(n_clicks, *input_values):
"""Handle prediction requests."""
if not n_clicks:
return ""
try:
- # Build the feature dictionary with proper encoding
- data = {
- "Time_spent_Alone": time_alone if time_alone is not None else 2.0,
- "Social_event_attendance": social_events
- if social_events is not None
- else 4.0,
- "Going_outside": going_outside if going_outside is not None else 3.0,
- "Friends_circle_size": friends_size
- if friends_size is not None
- else 8.0,
- "Post_frequency": post_freq if post_freq is not None else 3.0,
- # One-hot encode Stage_fear
- "Stage_fear_No": 1 if stage_fear == "No" else 0,
- "Stage_fear_Unknown": 1 if stage_fear == "Unknown" else 0,
- "Stage_fear_Yes": 1 if stage_fear == "Yes" else 0,
- # One-hot encode Drained_after_socializing
- "Drained_after_socializing_No": 1 if drained_social == "No" else 0,
- "Drained_after_socializing_Unknown": 1
- if drained_social == "Unknown"
- else 0,
- "Drained_after_socializing_Yes": 1 if drained_social == "Yes" else 0,
- # Set external match features to Unknown (default)
- "match_p_Extrovert": 0,
- "match_p_Introvert": 0,
- "match_p_Unknown": 1,
- }
+ # Create prediction inputs from callback arguments
+ inputs = PredictionInputs(
+ time_alone=input_values[0],
+ social_events=input_values[1],
+ going_outside=input_values[2],
+ friends_size=input_values[3],
+ post_freq=input_values[4],
+ stage_fear=input_values[5],
+ drained_social=input_values[6],
+ )
+
+ # Convert to feature dictionary
+ data = inputs.to_feature_dict()
# Make prediction
result = model_loader.predict(data)
diff --git a/dash_app/dashboard/layout.py b/dash_app/dashboard/layout.py
index df5644c..648d0c7 100644
--- a/dash_app/dashboard/layout.py
+++ b/dash_app/dashboard/layout.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+from dataclasses import dataclass
from typing import Any
import dash_bootstrap_components as dbc
@@ -9,6 +10,20 @@
from dash import dcc, html
+@dataclass
+class SliderConfig:
+ """Configuration for enhanced sliders."""
+
+ slider_id: str
+ label: str
+ min_val: int
+ max_val: int
+ default: int
+ intro_text: str
+ extro_text: str
+ css_class: str
+
+
def create_layout(model_name: str, model_metadata: dict[str, Any]) -> html.Div:
"""Create the main layout for the Dash application.
@@ -207,24 +222,28 @@ def create_input_panel() -> dbc.Card:
className="section-title mb-4",
),
create_enhanced_slider(
- "time-spent-alone",
- "Time Spent Alone (hours/day)",
- 0,
- 24,
- 8,
- "Less alone time",
- "More alone time",
- "slider-social",
+ SliderConfig(
+ slider_id="time-spent-alone",
+ label="Time Spent Alone (hours/day)",
+ min_val=0,
+ max_val=24,
+ default=8,
+ intro_text="Less alone time",
+ extro_text="More alone time",
+ css_class="slider-social",
+ )
),
create_enhanced_slider(
- "social-event-attendance",
- "Social Event Attendance (events/month)",
- 0,
- 20,
- 4,
- "Fewer events",
- "More events",
- "slider-social",
+ SliderConfig(
+ slider_id="social-event-attendance",
+ label="Social Event Attendance (events/month)",
+ min_val=0,
+ max_val=20,
+ default=4,
+ intro_text="Fewer events",
+ extro_text="More events",
+ css_class="slider-social",
+ )
),
# Lifestyle Section
html.H5(
@@ -238,24 +257,28 @@ def create_input_panel() -> dbc.Card:
className="section-title mt-5 mb-4",
),
create_enhanced_slider(
- "going-outside",
- "Going Outside Frequency (times/week)",
- 0,
- 15,
- 5,
- "Stay indoors",
- "Go out frequently",
- "slider-lifestyle",
+ SliderConfig(
+ slider_id="going-outside",
+ label="Going Outside Frequency (times/week)",
+ min_val=0,
+ max_val=15,
+ default=5,
+ intro_text="Stay indoors",
+ extro_text="Go out frequently",
+ css_class="slider-lifestyle",
+ )
),
create_enhanced_slider(
- "friends-circle-size",
- "Friends Circle Size",
- 0,
- 50,
- 12,
- "Small circle",
- "Large network",
- "slider-lifestyle",
+ SliderConfig(
+ slider_id="friends-circle-size",
+ label="Friends Circle Size",
+ min_val=0,
+ max_val=50,
+ default=12,
+ intro_text="Small circle",
+ extro_text="Large network",
+ css_class="slider-lifestyle",
+ )
),
# Digital Behavior Section
html.H5(
@@ -269,14 +292,16 @@ def create_input_panel() -> dbc.Card:
className="section-title mt-5 mb-4",
),
create_enhanced_slider(
- "post-frequency",
- "Social Media Posts (per week)",
- 0,
- 20,
- 3,
- "Rarely post",
- "Frequently post",
- "slider-digital",
+ SliderConfig(
+ slider_id="post-frequency",
+ label="Social Media Posts (per week)",
+ min_val=0,
+ max_val=20,
+ default=3,
+ intro_text="Rarely post",
+ extro_text="Frequently post",
+ css_class="slider-digital",
+ )
),
# Psychological Assessment Section
html.H5(
@@ -353,38 +378,36 @@ def create_input_panel() -> dbc.Card:
)
-def create_enhanced_slider(
- slider_id: str,
- label: str,
- min_val: int,
- max_val: int,
- default: int,
- intro_text: str,
- extro_text: str,
- css_class: str,
-) -> html.Div:
- """Create an enhanced slider with personality hints."""
+def create_enhanced_slider(config: SliderConfig) -> html.Div:
+ """Create an enhanced slider with personality hints.
+
+ Args:
+ config: Slider configuration containing all parameters
+
+ Returns:
+ HTML div containing the slider component
+ """
return html.Div(
[
- html.Label(label, className="slider-label fw-bold"),
+ html.Label(config.label, className="slider-label fw-bold"),
dcc.Slider(
- id=slider_id,
- min=min_val,
- max=max_val,
+ id=config.slider_id,
+ min=config.min_val,
+ max=config.max_val,
step=1,
- value=default,
+ value=config.default,
marks={
- min_val: {
- "label": intro_text,
+ config.min_val: {
+ "label": config.intro_text,
"style": {"color": "#3498db", "fontSize": "0.8rem"},
},
- max_val: {
- "label": extro_text,
+ config.max_val: {
+ "label": config.extro_text,
"style": {"color": "#e74c3c", "fontSize": "0.8rem"},
},
},
tooltip={"placement": "bottom", "always_visible": True},
- className=f"personality-slider {css_class}",
+ className=f"personality-slider {config.css_class}",
),
],
className="slider-container mb-3",
diff --git a/docs/README.md b/docs/README.md
index ff5884e..ac7fdb7 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -1,7 +1,5 @@
# Documentation Index
-# Documentation Index
-
Welcome! This documentation covers all aspects of the Six-Stack Personality Classification Pipeline.
## Main Guides
@@ -42,7 +40,7 @@ docker build -t personality-classifier .
docker run -p 8080:8080 personality-classifier
```
-## ️ Resources
+## 📚 Resources
- Code: `src/main_modular.py`, `examples/`
- Config templates: [Configuration Guide](configuration.md)
diff --git a/docs/api-reference.md b/docs/api-reference.md
index ee82db2..f0faf32 100644
--- a/docs/api-reference.md
+++ b/docs/api-reference.md
@@ -1,48 +1,49 @@
# API Reference - Six-Stack Personality Classification Pipeline
-## Modules & Functions
-
-**config.py**
-- RND: int = 42
-- N_SPLITS: int = 5
-- N_TRIALS_STACK: int = 15
-- N_TRIALS_BLEND: int = 200
-- LOG_LEVEL: str = "INFO"
-- ENABLE_DATA_AUGMENTATION: bool = True
-- AUGMENTATION_METHOD: str = "sdv_copula"
-- AUGMENTATION_RATIO: float = 0.05
-- DIVERSITY_THRESHOLD: float = 0.95
-- QUALITY_THRESHOLD: float = 0.7
-- class ThreadConfig(Enum): N_JOBS, THREAD_COUNT
-- setup_logging(), get_logger(name)
-
-**data_loader.py**
-- load_data_with_external_merge() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
-
-**preprocessing.py**
-- preprocess_data(df) -> pd.DataFrame
-
-**data_augmentation.py**
-- augment_data(X, y, method, ratio) -> pd.DataFrame
-
-**model_builders.py**
-- build_stack(stack_id, X, y) -> model
-
-**ensemble.py**
-- blend_predictions(preds_list) -> np.ndarray
-
-**optimization.py**
-- optimize_hyperparameters(model, X, y) -> dict
-
-**utils.py**
-- Utility functions for metrics, logging, etc.
- Returns:
- tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
- (train_df, test_df, submission_template)
- """
-def merge_external_features(df_main: pd.DataFrame,
- df_external: pd.DataFrame,
- is_training: bool = True) -> pd.DataFrame:
+## Overview
+
+This document provides comprehensive API documentation for all modules and functions in the personality classification pipeline.
+
+## Configuration Module (`config.py`)
+
+### Constants
+
+- `RND: int = 42` - Random seed for reproducibility
+- `N_SPLITS: int = 5` - Number of cross-validation splits
+- `N_TRIALS_STACK: int = 15` - Optimization trials per stack
+- `N_TRIALS_BLEND: int = 200` - Optimization trials for blending
+- `LOG_LEVEL: str = "INFO"` - Logging level
+- `ENABLE_DATA_AUGMENTATION: bool = True` - Data augmentation toggle
+- `AUGMENTATION_METHOD: str = "sdv_copula"` - Augmentation method
+- `AUGMENTATION_RATIO: float = 0.05` - Augmentation ratio
+- `DIVERSITY_THRESHOLD: float = 0.95` - Diversity threshold
+- `QUALITY_THRESHOLD: float = 0.7` - Quality threshold
+
+### Classes
+
+```python
+class ThreadConfig(Enum):
+ """Thread configuration enumeration."""
+ N_JOBS = "n_jobs"
+ THREAD_COUNT = "thread_count"
+```
+
+### Functions
+
+```python
+def setup_logging() -> None:
+ """Setup logging configuration."""
+
+def get_logger(name: str) -> logging.Logger:
+ """Get logger instance with specified name."""
+```
+
+## Data Loading Module (`data_loader.py`)
+
+### Main Functions
+
+```python
+def load_data_with_external_merge() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Merge external dataset features using strategic matching.
diff --git a/docs/architecture.md b/docs/architecture.md
index a530c85..f024b3f 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -1,21 +1,43 @@
# Architecture Documentation
-## Architecture Overview
-
-- Modular pipeline: 8 core modules in `src/modules/`
-- Main pipeline: `src/main_modular.py`
-- Dashboard: `dash_app/` (Dash, Docker)
-- Model stacks: 6 specialized ensembles (A-F)
-- Data flow: Load → Preprocess → Augment → Train → Ensemble → Predict
-
-## Stacks
-- A: Traditional ML (narrow)
-- B: Traditional ML (wide)
-- C: XGBoost/CatBoost
-- D: Sklearn ensemble
-- E: Neural networks
-- F: Noise-robust
+## System Overview
+
+The Six-Stack Personality Classification Pipeline is built with a modular, scalable architecture designed for machine learning competitions and production deployment.
+
+### Core Architecture
+
+- **Modular pipeline**: 8 core modules in `src/modules/`
+- **Main pipeline**: `src/main_modular.py`
+- **Dashboard**: `dash_app/` (Dash, Docker)
+- **Model stacks**: 6 specialized ensembles (A-F)
+- **Data flow**: Load → Preprocess → Augment → Train → Ensemble → Predict
+
+## Model Stack Architecture
+
+### Stack Specializations
+
+| Stack | Type | Description |
+|-------|------|-------------|
+| **A** | Traditional ML (narrow) | Focused feature selection with classic algorithms |
+| **B** | Traditional ML (wide) | Comprehensive feature engineering with traditional models |
+| **C** | XGBoost/CatBoost | Gradient boosting specialists |
+| **D** | Sklearn ensemble | Ensemble of sklearn algorithms |
+| **E** | Neural networks | Deep learning approaches |
+| **F** | Noise-robust | Robust methods for noisy data |
## Key Features
-- Efficient, reproducible, and testable
-- Full logging and error handling
+
+### Design Principles
+
+- **Efficient**: Optimized for both speed and accuracy
+- **Reproducible**: Consistent results with random seed control
+- **Testable**: Comprehensive test coverage
+- **Modular**: Easy to extend and maintain
+
+### Core Capabilities
+
+- **Full logging**: Comprehensive logging with error handling and progress tracking
+- **Data augmentation**: Advanced synthetic data generation
+- **Hyperparameter optimization**: Automated tuning for each stack
+- **Cross-validation**: Robust evaluation methodology
+- **Ensemble learning**: Meta-learning for optimal predictions
diff --git a/docs/data-augmentation.md b/docs/data-augmentation.md
index df8b39f..b437964 100644
--- a/docs/data-augmentation.md
+++ b/docs/data-augmentation.md
@@ -1,21 +1,33 @@
# Data Augmentation Guide
-## Data Augmentation Guide
-
-### Strategy
+## Strategy
- Adaptive selection based on dataset size, balance, and feature types
-### Decision Matrix
+## Decision Matrix
| Data Type | Method |
|-------------------|---------------|
| Small/Imbalanced | SMOTE/ADASYN |
| High Categorical | Basic |
| Complex Numeric | SDV Copula |
-### Main Method
-**SDV Copula** (recommended):
-- Preserves feature distributions and correlations
-- Fast mode for development
+## Main Method
+
+### 1. SDV Copula (Recommended)
+
+**Best for**: Complex datasets with mixed feature types
+
+#### Features
+
+- **Preserves feature distributions and correlations**
+- **Fast mode for development**
+- **Handles mixed data types effectively**
+- **Statistical validation built-in**
+
+#### Implementation
+
+```python
+def sdv_copula_augmentation(X, y, n_samples):
+ """SDV Copula-based synthetic data generation."""
synthesizer = GaussianCopula(
enforce_rounding=True,
enforce_min_max_values=True
diff --git a/docs/performance-tuning.md b/docs/performance-tuning.md
index 2558a98..c6dd85d 100644
--- a/docs/performance-tuning.md
+++ b/docs/performance-tuning.md
@@ -1,8 +1,10 @@
# Performance Tuning Guide
-## Performance Tuning Guide
+## Overview
-### Key Levers
+This guide covers optimization strategies for speed, memory usage, and accuracy in the personality classification pipeline.
+
+## Key Performance Levers
- Training speed: TESTING_MODE, N_TRIALS_STACK, N_TRIALS_BLEND
- Memory: TESTING_SAMPLE_SIZE, ENABLE_DATA_AUGMENTATION
- Accuracy: Ensemble optimization, feature engineering
diff --git a/pyproject.toml b/pyproject.toml
index acbfb3a..1f0f4ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "personality-classification"
-version = "0.1.0"
+version = "2.1.1"
description = "Six-Stack Personality Classification Pipeline with Advanced ML Ensemble Methods"
authors = [
{name = "Jeremy Vachier"}
@@ -21,7 +21,7 @@ classifiers = [
dependencies = [
# Core ML libraries
- "numpy>=1.24.0,<2.0.0",
+ "numpy>=1.24.0,<3.0.0",
"pandas>=2.0.0,<3.0.0",
"scikit-learn>=1.3.0,<1.6.0",
# Advanced ML models (gradient boosting)
@@ -41,6 +41,10 @@ dependencies = [
"dash>=2.14.0,<3.0.0",
"dash-bootstrap-components>=1.7.1",
"plotly>=5.24.1",
+ # Utility libraries
+ "pyyaml>=6.0.0,<7.0.0",
+ "tqdm>=4.65.0,<5.0.0",
+ "colorama>=0.4.6,<1.0.0",
]
[project.optional-dependencies]
@@ -78,7 +82,6 @@ Documentation = "https://github.com/jeremyvachier/personality-classification/blo
[project.scripts]
personality-classify = "src.main_modular:main"
-personality-demo = "src.main_demo:main"
[build-system]
requires = ["hatchling"]
@@ -140,12 +143,11 @@ select = [
"RUF", # Ruff-specific rules
]
ignore = [
- "E501", # line too long (handled by black)
+ "E501", # line too long (handled by formatter)
"B905", # `zip()` without an explicit `strict=` parameter
- "B904", # raise exceptions with `raise ... from err` (false positives in pre-commit)
- "PLR0912", # too many branches
- "PLR0913", # too many arguments
- "PLR0915", # too many statements
+ "PLR0912", # too many branches (complex ML functions)
+ "PLR0913", # too many arguments (complex ML functions)
+ "PLR0915", # too many statements (complex ML functions)
"PLR2004", # magic value used in comparison
]
diff --git a/src/modules/ensemble.py b/src/modules/ensemble.py
index fbcf434..02589e2 100644
--- a/src/modules/ensemble.py
+++ b/src/modules/ensemble.py
@@ -2,6 +2,9 @@
Ensemble functions for out-of-fold predictions and blending optimization.
"""
+from collections.abc import Callable
+from dataclasses import dataclass
+
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
@@ -14,6 +17,18 @@
logger = get_logger(__name__)
+@dataclass
+class NoisyEnsembleConfig:
+ """Configuration for noisy label ensemble training."""
+
+ model_builder: Callable
+ X: pd.DataFrame
+ y: pd.Series
+ X_test: pd.DataFrame
+ noise_rate: float
+ sample_weights: np.ndarray | None = None
+
+
def oof_probs(
model_builder,
X: pd.DataFrame,
@@ -46,69 +61,67 @@ def oof_probs(
return oof_preds, test_preds
-def oof_probs_noisy(
- model_builder,
- X: pd.DataFrame,
- y: pd.Series,
- X_test: pd.DataFrame,
- noise_rate: float,
- sample_weights=None,
-):
- """Generate out-of-fold predictions for noisy label ensemble."""
- oof_preds = np.zeros(len(X))
- test_preds = np.zeros(len(X_test))
+def oof_probs_noisy(config: NoisyEnsembleConfig) -> tuple[np.ndarray, np.ndarray]:
+ """Generate out-of-fold predictions for noisy label ensemble.
+
+ Args:
+ config: Configuration containing model builder, data, and noise parameters
+
+ Returns:
+ Tuple of (oof_predictions, test_predictions)
+ """
+ oof_preds = np.zeros(len(config.X))
+ test_preds = np.zeros(len(config.X_test))
cv = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RND)
- for fold, (tr_idx, val_idx) in enumerate(cv.split(X, y)):
+ for fold, (tr_idx, val_idx) in enumerate(cv.split(config.X, config.y)):
logger.info(
- f" Fold {fold + 1}/{N_SPLITS} (with {noise_rate:.1%} label noise)"
+ f" Fold {fold + 1}/{N_SPLITS} (with {config.noise_rate:.1%} label noise)"
)
- X_train, X_val = X.iloc[tr_idx], X.iloc[val_idx]
- y_train, _y_val = y.iloc[tr_idx], y.iloc[val_idx]
+ X_train, X_val = config.X.iloc[tr_idx], config.X.iloc[val_idx]
+ y_train, _y_val = config.y.iloc[tr_idx], config.y.iloc[val_idx]
# Add label noise to training data
y_train_noisy = add_label_noise(
- y_train, noise_rate=noise_rate, random_state=RND + fold
+ y_train, noise_rate=config.noise_rate, random_state=RND + fold
)
# Build and fit model
- model = model_builder()
+ model = config.model_builder()
model.fit(X_train, y_train_noisy)
# Out-of-fold predictions (on clean validation data)
oof_preds[val_idx] = model.predict_proba(X_val)[:, 1]
# Test predictions (averaged across folds)
- test_preds += model.predict_proba(X_test)[:, 1] / N_SPLITS
+ test_preds += model.predict_proba(config.X_test)[:, 1] / N_SPLITS
return oof_preds, test_preds
-def improved_blend_obj(trial, oof_A, oof_B, oof_C, oof_D, oof_E, oof_F, y_true):
- """Improved blending objective with constraints and regularization."""
- # Sample blend weights
- w1 = trial.suggest_float("w1", 0.0, 1.0)
- w2 = trial.suggest_float("w2", 0.0, 1.0)
- w3 = trial.suggest_float("w3", 0.0, 1.0)
- w4 = trial.suggest_float("w4", 0.0, 1.0)
- w5 = trial.suggest_float("w5", 0.0, 1.0)
- w6 = trial.suggest_float("w6", 0.0, 1.0)
+def improved_blend_obj(trial, oof_predictions: dict[str, np.ndarray], y_true):
+ """Improved blending objective with constraints and regularization.
+
+ Args:
+ trial: Optuna trial object
+ oof_predictions: Dictionary mapping stack names to their OOF predictions
+ y_true: True labels
+ """
+ # Sample blend weights for each stack
+ weights = {}
+ for stack_name in oof_predictions:
+ weights[stack_name] = trial.suggest_float(f"w_{stack_name}", 0.0, 1.0)
# Normalize weights
- weights = np.array([w1, w2, w3, w4, w5, w6])
- weights = weights / np.sum(weights)
+ weight_values = np.array(list(weights.values()))
+ normalized_weights = weight_values / np.sum(weight_values)
# Calculate blended predictions
- blended = (
- weights[0] * oof_A
- + weights[1] * oof_B
- + weights[2] * oof_C
- + weights[3] * oof_D
- + weights[4] * oof_E
- + weights[5] * oof_F
- )
+ blended = np.zeros_like(y_true, dtype=float)
+ for i, (_stack_name, oof_pred) in enumerate(oof_predictions.items()):
+ blended += normalized_weights[i] * oof_pred
# Convert to binary predictions
y_pred = (blended >= 0.5).astype(int)
@@ -117,6 +130,7 @@ def improved_blend_obj(trial, oof_A, oof_B, oof_C, oof_D, oof_E, oof_F, y_true):
score = accuracy_score(y_true, y_pred)
# Store normalized weights in trial attributes
- trial.set_user_attr("weights", weights.tolist())
+ trial.set_user_attr("weights", normalized_weights.tolist())
+ trial.set_user_attr("stack_names", list(oof_predictions.keys()))
return score
diff --git a/tests/dash_app/test_callbacks.py b/tests/dash_app/test_callbacks.py
index eaf2974..cc031e2 100644
--- a/tests/dash_app/test_callbacks.py
+++ b/tests/dash_app/test_callbacks.py
@@ -137,36 +137,41 @@ class TestCallbackInputValidation:
def callback_function_mock(self):
"""Mock the actual callback function for testing."""
with patch("dash_app.dashboard.callbacks.register_callbacks") as mock_register:
- # Create a mock prediction function
- def mock_prediction_callback(
- n_clicks,
- time_alone,
- social_events,
- going_outside,
- friends_size,
- post_freq,
- stage_fear,
- drained_social,
- ):
+ # Create a mock prediction function that matches our refactored signature
+ def mock_prediction_callback(n_clicks, *input_values):
# Simulate input validation
if n_clicks is None or n_clicks == 0:
return "No prediction made"
- # Validate input ranges
+ # Validate input ranges - unpack the input values
+ if len(input_values) < 7:
+ return "Invalid input: Not enough values"
+
+ time_alone, social_events, going_outside = (
+ input_values[0],
+ input_values[1],
+ input_values[2],
+ )
+ friends_size, post_freq, _stage_fear, _drained_social = (
+ input_values[3],
+ input_values[4],
+ input_values[5],
+ input_values[6],
+ )
+
inputs = [
time_alone,
social_events,
going_outside,
friends_size,
post_freq,
- stage_fear,
- drained_social,
]
if any(x is None for x in inputs):
return "Invalid input: None values"
- if any(not isinstance(x, int | float) for x in inputs):
+ # Check numeric inputs
+ if any(not isinstance(x, int | float) for x in inputs if x is not None):
return "Invalid input: Non-numeric values"
return "Valid prediction"
diff --git a/uv.lock b/uv.lock
index 45a3513..f2f8459 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2481,10 +2481,11 @@ wheels = [
[[package]]
name = "personality-classification"
-version = "0.1.0"
+version = "2.1.1"
source = { editable = "." }
dependencies = [
{ name = "catboost" },
+ { name = "colorama" },
{ name = "dash" },
{ name = "dash-bootstrap-components" },
{ name = "imbalanced-learn" },
@@ -2494,9 +2495,11 @@ dependencies = [
{ name = "optuna" },
{ name = "pandas" },
{ name = "plotly" },
+ { name = "pyyaml" },
{ name = "scikit-learn" },
{ name = "scipy" },
{ name = "sdv" },
+ { name = "tqdm" },
{ name = "xgboost" },
]
@@ -2534,6 +2537,7 @@ requires-dist = [
{ name = "autogluon", marker = "extra == 'automl'", specifier = ">=1.1.1,<2.0.0" },
{ name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0,<2.0.0" },
{ name = "catboost", specifier = ">=1.2.0,<2.0.0" },
+ { name = "colorama", specifier = ">=0.4.6,<1.0.0" },
{ name = "dash", specifier = ">=2.14.0,<3.0.0" },
{ name = "dash-bootstrap-components", specifier = ">=1.7.1" },
{ name = "h2o", marker = "extra == 'automl'", specifier = ">=3.44.0,<4.0.0" },
@@ -2541,7 +2545,7 @@ requires-dist = [
{ name = "joblib", specifier = ">=1.3.0,<2.0.0" },
{ name = "lightgbm", specifier = ">=4.0.0,<5.0.0" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.5.0,<2.0.0" },
- { name = "numpy", specifier = ">=1.24.0,<2.0.0" },
+ { name = "numpy", specifier = ">=1.24.0,<3.0.0" },
{ name = "optuna", specifier = ">=3.4.0,<4.0.0" },
{ name = "pandas", specifier = ">=2.0.0,<3.0.0" },
{ name = "plotly", specifier = ">=5.24.1" },
@@ -2550,10 +2554,12 @@ requires-dist = [
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0,<8.0.0" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0,<5.0.0" },
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.3.0,<4.0.0" },
+ { name = "pyyaml", specifier = ">=6.0.0,<7.0.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.4.0,<1.0.0" },
{ name = "scikit-learn", specifier = ">=1.3.0,<1.6.0" },
{ name = "scipy", specifier = ">=1.11.0,<2.0.0" },
{ name = "sdv", specifier = ">=1.24.0,<2.0.0" },
+ { name = "tqdm", specifier = ">=4.65.0,<5.0.0" },
{ name = "types-pyyaml", marker = "extra == 'dev'", specifier = ">=6.0.0,<7.0.0" },
{ name = "types-requests", marker = "extra == 'dev'", specifier = ">=2.31.0,<3.0.0" },
{ name = "xgboost", specifier = ">=2.0.0,<3.0.0" },