Skip to content

Commit af1d38d

Browse files
authored
feat: add mutual information scoring and Shannon entropy (#68)
- config.py: add MutualInfoThresholds (low_mi_warning, max_categories, min_samples, entropy_bins) wired into HashPrepConfig - summaries/mutual_info.py: new module - summarize_mutual_information() computes sklearn MI scores (mutual_info_classif for categorical targets, mutual_info_regression for numeric targets) for all eligible features, with label-encoding for categoricals; scores sorted descending and stored in summaries["mutual_information"] when a target column is set - summaries/variables.py: add _shannon_entropy() helper; embed entropy (entropy_bits + normalized_entropy) in numeric summaries (discretised into bins) and categorical summaries (from value-count probabilities) - checks/mutual_info.py: new low_mutual_information check — flags features whose MI with the target is below the configured warning threshold - checks/__init__.py + core/analyzer.py: register low_mutual_information in CHECKS and ALL_CHECKS; inject MI summary into analyzer.summaries - summaries/__init__.py: export summarize_mutual_information - tests/test_mutual_info.py: 28 tests covering entropy in summaries, MI computation correctness, low_mi check unit, and end-to-end integration; threshold-sensitive tests use per-test seeded RNGs and n=2000 to avoid KNN estimator variance; all 208 tests pass (180 existing + 28 new)
1 parent 9f6f822 commit af1d38d

8 files changed

Lines changed: 483 additions & 0 deletions

File tree

hashprep/checks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
_check_high_missing_values,
1313
_check_missing_patterns,
1414
)
15+
from .mutual_info import _check_low_mutual_information
1516
from .outliers import (
1617
_check_constant_length,
1718
_check_datetime_skew,
@@ -60,6 +61,7 @@ def _check_dataset_drift(analyzer):
6061
"empty_dataset": _check_empty_dataset,
6162
"normality": _check_normality,
6263
"variance_homogeneity": _check_variance_homogeneity,
64+
"low_mutual_information": _check_low_mutual_information,
6365
}
6466

6567
CORRELATION_CHECKS = {"feature_correlation", "categorical_correlation", "mixed_correlation"}

hashprep/checks/mutual_info.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""
2+
Check for features with near-zero mutual information with the target column.
3+
Near-zero MI means the feature carries almost no information about the target
4+
and is likely useless (or worse — noise) for a predictive model.
5+
"""
6+
7+
from ..config import DEFAULT_CONFIG
8+
from ..summaries.mutual_info import summarize_mutual_information
9+
from .core import Issue
10+
11+
_MI = DEFAULT_CONFIG.mutual_info
12+
13+
14+
def _check_low_mutual_information(analyzer) -> list[Issue]:
15+
"""
16+
Flag features whose mutual information with the target column is below
17+
the configured warning threshold. Requires target_col to be set.
18+
"""
19+
if analyzer.target_col is None:
20+
return []
21+
22+
mi_result = summarize_mutual_information(analyzer.df, analyzer.target_col, analyzer.column_types)
23+
if not mi_result or not mi_result.get("scores"):
24+
return []
25+
26+
issues = []
27+
scores = mi_result["scores"]
28+
task = mi_result["task"]
29+
30+
for col, score in scores.items():
31+
if score < _MI.low_mi_warning:
32+
issues.append(
33+
Issue(
34+
category="low_mutual_information",
35+
severity="warning",
36+
column=col,
37+
description=(
38+
f"Column '{col}' has near-zero mutual information with target "
39+
f"'{analyzer.target_col}' (MI={score:.4f} nats, task={task})"
40+
),
41+
impact_score="medium",
42+
quick_fix=(
43+
"Options:\n"
44+
"- Drop feature: Near-zero MI suggests no predictive signal for the target.\n"
45+
"- Investigate interactions: Feature may be useful combined with others.\n"
46+
"- Check encoding: Categorical features may need different encoding.\n"
47+
"- Retain for now: MI is marginal; feature interactions may matter."
48+
),
49+
)
50+
)
51+
52+
return issues

hashprep/config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,20 @@ class ImbalanceThresholds:
126126
majority_class_ratio: float = 0.9
127127

128128

129+
@dataclass(frozen=True)
class MutualInfoThresholds:
    """Thresholds for mutual information and entropy checks.

    An MI score (in nats) below ``low_mi_warning`` flags a feature as
    potentially uninformative. ``max_categories_for_mi`` caps the cardinality
    of categorical columns admitted to MI scoring, ``min_samples_for_mi`` is
    the minimum row count required before MI is computed at all, and
    ``entropy_bins`` controls how numeric columns are discretised when
    estimating Shannon entropy.
    """

    low_mi_warning: float = 0.01
    max_categories_for_mi: int = 50
    min_samples_for_mi: int = 20
    entropy_bins: int = 10
141+
142+
129143
@dataclass(frozen=True)
130144
class StatisticalTestThresholds:
131145
"""Thresholds for normality and variance homogeneity tests."""
@@ -206,6 +220,7 @@ class HashPrepConfig:
206220
drift: DriftThresholds = field(default_factory=DriftThresholds)
207221
distribution: DistributionThresholds = field(default_factory=DistributionThresholds)
208222
imbalance: ImbalanceThresholds = field(default_factory=ImbalanceThresholds)
223+
mutual_info: MutualInfoThresholds = field(default_factory=MutualInfoThresholds)
209224
statistical_tests: StatisticalTestThresholds = field(default_factory=StatisticalTestThresholds)
210225
datetime: DateTimeThresholds = field(default_factory=DateTimeThresholds)
211226
type_inference: TypeInferenceConfig = field(default_factory=TypeInferenceConfig)

hashprep/core/analyzer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
summarize_variable_types,
1818
summarize_variables,
1919
)
20+
from ..summaries.mutual_info import summarize_mutual_information
2021
from ..utils.sampling import DatasetSampler, SamplingConfig
2122
from ..utils.type_inference import infer_types
2223
from .visualizations import (
@@ -61,6 +62,7 @@ class DatasetAnalyzer:
6162
"constant_length",
6263
"normality",
6364
"variance_homogeneity",
65+
"low_mutual_information",
6466
]
6567

6668
def __init__(
@@ -125,6 +127,11 @@ def analyze(self) -> dict:
125127
self.summaries.update(summarize_interactions(self.df))
126128
self.summaries.update(summarize_missing_values(self.df))
127129

130+
if self.target_col is not None:
131+
mi_result = summarize_mutual_information(self.df, self.target_col, self.column_types)
132+
if mi_result:
133+
self.summaries["mutual_information"] = mi_result
134+
128135
if self.sampler:
129136
self.summaries["sampling_info"] = self.sampler.get_sampling_info()
130137

hashprep/summaries/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
)
1919
from .interactions import summarize_interactions as summarize_interactions
2020
from .missing import summarize_missing_values as summarize_missing_values
21+
from .mutual_info import summarize_mutual_information as summarize_mutual_information
2122
from .variables import summarize_variables as summarize_variables

hashprep/summaries/mutual_info.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
Mutual information between each feature and the target column.
3+
4+
Uses sklearn's mutual_info_classif (categorical target) or
5+
mutual_info_regression (numeric target). Categorical features are
6+
label-encoded before scoring.
7+
"""
8+
9+
import pandas as pd
10+
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
11+
from sklearn.preprocessing import LabelEncoder
12+
13+
from ..config import DEFAULT_CONFIG
14+
from ..utils.logging import get_logger
15+
16+
_log = get_logger("summaries.mutual_info")
17+
_MI = DEFAULT_CONFIG.mutual_info
18+
19+
20+
def summarize_mutual_information(
    df: pd.DataFrame,
    target_col: str,
    column_types: dict[str, str],
) -> dict:
    """
    Compute mutual information between every feature and the target column.

    Returns a dict:
        {
            "target": target_col,
            "task": "classification" | "regression",
            "scores": {col: mi_score, ...},  # nats, sorted descending
        }
    or an empty dict when MI cannot be computed (too few samples, bad target, etc.).
    """
    if target_col not in df.columns:
        return {}

    target_type = column_types.get(target_col, "Unsupported")
    # Count rows with a usable target value (equivalent to, but cheaper
    # than, len(df.dropna(subset=[target_col]))).
    n = int(df[target_col].notna().sum())
    if n < _MI.min_samples_for_mi:
        return {}

    # Numeric targets use the regression estimator; everything else is
    # treated as classification (the target is label-encoded below).
    if target_type in ("Numeric",):
        task = "regression"
        mi_fn = mutual_info_regression
    else:
        task = "classification"
        mi_fn = mutual_info_classif

    # Build feature matrix — include Numeric and low-cardinality Categorical
    # cols. Entirely-null columns are skipped up front: they carry no
    # information, and a NaN median fill would otherwise make the single
    # sklearn call fail and wipe out the scores for every feature at once.
    feature_cols = []
    discrete_mask = []

    for col in df.columns:
        if col == target_col:
            continue
        if df[col].notna().sum() == 0:
            continue
        typ = column_types.get(col, "Unsupported")
        if typ == "Numeric":
            feature_cols.append(col)
            discrete_mask.append(False)
        elif typ == "Categorical" and df[col].nunique() <= _MI.max_categories_for_mi:
            feature_cols.append(col)
            discrete_mask.append(True)

    if not feature_cols:
        return {}

    # Build X: label-encode categoricals, drop rows missing target
    sub = df[feature_cols + [target_col]].dropna(subset=[target_col])
    X = sub[feature_cols].copy()

    for col, is_discrete in zip(feature_cols, discrete_mask):
        if is_discrete:
            le = LabelEncoder()
            filled = X[col].fillna("__missing__").astype(str)
            X[col] = le.fit_transform(filled)
        else:
            # Median-impute missing numerics; fall back to 0.0 when the
            # column has no finite values left inside `sub` (rows kept after
            # the target dropna), so one degenerate column cannot poison
            # the whole matrix with NaNs.
            median = X[col].median()
            X[col] = X[col].fillna(0.0 if pd.isna(median) else median)

    y_raw = sub[target_col]
    if task == "classification":
        le_y = LabelEncoder()
        y = le_y.fit_transform(y_raw.fillna("__missing__").astype(str))
    else:
        y = y_raw.values

    try:
        mi_scores = mi_fn(X.values, y, discrete_features=discrete_mask, random_state=0)
    except Exception as e:  # sklearn raises ValueError on degenerate input
        _log.debug("Mutual information computation failed: %s", e)
        return {}

    scores = {col: float(score) for col, score in zip(feature_cols, mi_scores)}
    # Sort descending by MI score so the most informative features come first
    scores = dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))

    return {
        "target": target_col,
        "task": task,
        "scores": scores,
    }

hashprep/summaries/variables.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,32 @@
1010

1111
_SUMMARY = DEFAULT_CONFIG.summaries
1212
_ST = DEFAULT_CONFIG.statistical_tests
13+
_MI = DEFAULT_CONFIG.mutual_info
14+
15+
16+
def _shannon_entropy(series: pd.Series, bins: int | None = None) -> dict | None:
17+
"""
18+
Compute Shannon entropy (bits) for a series.
19+
- Categorical/text: uses value-count probabilities directly.
20+
- Numeric: discretises into `bins` equal-width bins first.
21+
Returns a dict with 'entropy_bits' and 'normalized_entropy' (0–1),
22+
or None when there are fewer than 2 distinct values.
23+
"""
24+
if series.empty:
25+
return None
26+
if bins is not None:
27+
# Discretise numeric series into bins
28+
try:
29+
series = pd.cut(series, bins=bins, labels=False, duplicates="drop")
30+
except Exception:
31+
return None
32+
probs = series.dropna().value_counts(normalize=True)
33+
if len(probs) < 2:
34+
return None
35+
entropy_bits = float(-np.sum(probs * np.log2(probs)))
36+
max_entropy = float(np.log2(len(probs)))
37+
normalized = entropy_bits / max_entropy if max_entropy > 0 else 0.0
38+
return {"entropy_bits": entropy_bits, "normalized_entropy": normalized}
1339

1440

1541
def get_monotonicity(series: pd.Series) -> str:
@@ -159,6 +185,7 @@ def _summarize_numeric(df, col):
159185
"common_values": common_values,
160186
"extreme_values": extremes,
161187
"normality": normality,
188+
"entropy": _shannon_entropy(finite, bins=_MI.entropy_bins),
162189
}
163190
return stats
164191

@@ -341,6 +368,7 @@ def _summarize_categorical(df, col):
341368
},
342369
"words": text_summary["words"],
343370
"characters": text_summary["characters"],
371+
"entropy": _shannon_entropy(series),
344372
}
345373
return stats
346374

0 commit comments

Comments
 (0)