From 523362a754210f695f9ceddd42ff5e4be452d763 Mon Sep 17 00:00:00 2001 From: jiaruixu Date: Wed, 31 Dec 2025 23:40:05 -0800 Subject: [PATCH] Enhance .gitignore to exclude local issue drafts and add Hessian computation to LossGradHess class in LinearSplineLogisticRegression. Introduce progressive fitting with stratified sampling and early stopping features, along with updates to fitting methods and parameters for improved performance and usability. --- .gitignore | 3 + src/splinator/estimators.py | 382 ++++++++++++++++++++++++++++-- tests/test_progressive_fitting.py | 198 ++++++++++++++++ 3 files changed, 562 insertions(+), 21 deletions(-) create mode 100644 tests/test_progressive_fitting.py diff --git a/.gitignore b/.gitignore index ae979d6..49e6957 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,6 @@ cython_debug/ .DS_Store .idea/ + +# Local issue drafts +.github/ISSUES/ diff --git a/src/splinator/estimators.py b/src/splinator/estimators.py index fa34d8f..c04abd7 100644 --- a/src/splinator/estimators.py +++ b/src/splinator/estimators.py @@ -17,7 +17,7 @@ class MinimizationMethod(Enum): - """These two constrained-optimization methods are supported by scipy.optimize.minimize""" + """Optimization methods supported by scipy.optimize.minimize""" # they are case-sensitive slsqp = 'SLSQP' trust_constr = 'trust-constr' @@ -68,6 +68,36 @@ def grad(self, coefs): return grad + def hess(self, coefs): + # type: (np.ndarray) -> np.ndarray + """ + Compute the Hessian of the logistic loss. + + The Hessian of logistic regression is: H = X.T @ diag(w) @ X + alpha * I + where w = p * (1 - p) and p = sigmoid(X @ coefs). + + This is used by trust-constr for faster convergence (Newton-like steps). + """ + z = np.dot(self.X, coefs) + p = expit(z) + # Weights for the Hessian: p * (1 - p) + # This is always positive, making the Hessian positive semi-definite + weights = p * (1 - p) + + # H = X.T @ diag(weights) @ X + # Efficient computation: (X.T * weights) @ X + H = np.dot(self.X.T * weights, self.X) + + # Add regularization term + if self.intercept: + # Don't regularize the intercept + reg_indices = np.arange(1, H.shape[0]) + H[reg_indices, reg_indices] += self.alpha + else: + H += self.alpha * np.eye(H.shape[0]) + + return H + class LinearSplineLogisticRegression(RegressorMixin, TransformerMixin, BaseEstimator): """Piecewise Logistic Regression with Linear Splines. @@ -91,12 +121,30 @@ class LinearSplineLogisticRegression(RegressorMixin, TransformerMixin, BaseEstim Whether to include an intercept term in the model. method : str, default='SLSQP' Optimization method for scipy.optimize.minimize. Supports 'SLSQP' or 'trust-constr'. + SLSQP is recommended for problems with monotonicity constraints (fastest). + trust-constr with Hessian is faster for unconstrained problems. minimizer_options : dict or None, default=None Additional options passed to the scipy minimizer. C : float, default=100 Inverse of regularization strength (larger values = weaker regularization). two_stage_fitting_initial_size : int or None, default=None If provided, performs initial fit on a subsample of this size for faster convergence. + Deprecated: prefer using `progressive_fitting_fractions` for better performance. + progressive_fitting_fractions : tuple or None, default=None + Tuple of fractions for progressive fitting (e.g., (0.1, 0.3, 1.0)). + Each stage uses stratified sampling to maintain score distribution coverage. + If provided, overrides `two_stage_fitting_initial_size`. + stratified_sampling : bool, default=True + If True, uses quantile-based stratified sampling to ensure subsamples cover + the full score range. Only applies when using progressive/two-stage fitting. + early_stopping_tol : float or None, default=1e-4 + If provided, stops progressive fitting early when coefficient change + (relative L2 norm) falls below this threshold between stages. + use_hessian : bool or 'auto', default='auto' + Controls Hessian usage with trust-constr method. + - 'auto': Enable Hessian when monotonicity='none' (3-4x faster) + - True: Always use Hessian (can be slow with constraints) + - False: Never use Hessian random_state : int, default=31 Random seed for reproducibility. verbose : bool, default=False @@ -110,6 +158,8 @@ class LinearSplineLogisticRegression(RegressorMixin, TransformerMixin, BaseEstim The knot positions used in the model. n_features_in_ : int Number of features seen during fit. + fitting_history_ : list + History of fitting stages when using progressive fitting. Examples -------- @@ -119,6 +169,18 @@ class LinearSplineLogisticRegression(RegressorMixin, TransformerMixin, BaseEstim >>> y = np.zeros((100, )) >>> estimator = LinearSplineLogisticRegression() >>> estimator.fit(X, y) + + For monotonic calibration (recommended): + >>> estimator = LinearSplineLogisticRegression( + ... monotonicity='increasing', # SLSQP is fastest for constraints + ... n_knots=50, + ... ) + + For unconstrained fitting on large data (uses Hessian automatically): + >>> estimator = LinearSplineLogisticRegression( + ... monotonicity='none', + ... method='trust-constr', # Auto-enables Hessian for 3-4x speedup + ... ) """ def __init__( @@ -131,9 +193,13 @@ def __init__( method: str = MinimizationMethod.slsqp.value, minimizer_options: Optional[Dict[str, Any]] = None, C: int = 100, - two_stage_fitting_initial_size: int = None, + two_stage_fitting_initial_size: Optional[int] = None, + progressive_fitting_fractions: Optional[Tuple[float, ...]] = None, + stratified_sampling: bool = True, + early_stopping_tol: Optional[float] = 1e-4, + use_hessian: Union[bool, str] = 'auto', random_state: int = 31, - verbose=False, + verbose: bool = False, ): # type: (...) -> None """ @@ -164,6 +230,16 @@ def __init__( two_stage_fitting_initial_size : int, default=None subsample size of training data for first fitting. If two_stage_fitting is not used, this should be None. + Deprecated: prefer using `progressive_fitting_fractions`. + progressive_fitting_fractions : tuple, default=None + Fractions of data for progressive fitting stages (e.g., (0.1, 0.3, 1.0)). + Uses stratified sampling and warm-starts each stage with previous coefficients. + stratified_sampling : bool, default=True + Use quantile-based stratified sampling for progressive fitting. + early_stopping_tol : float, default=1e-4 + Stop early if coefficient change between stages is below this threshold. + use_hessian : bool or 'auto', default='auto' + Use analytical Hessian with trust-constr. 'auto' enables for unconstrained problems. random_state : int, default=31 random seed number, default is 31 """ @@ -176,6 +252,10 @@ def __init__( self.minimizer_options = minimizer_options self.C = C self.two_stage_fitting_initial_size = two_stage_fitting_initial_size + self.progressive_fitting_fractions = progressive_fitting_fractions + self.stratified_sampling = stratified_sampling + self.early_stopping_tol = early_stopping_tol + self.use_hessian = use_hessian self.random_state = random_state self.verbose = verbose @@ -220,10 +300,27 @@ def _fit(self, X, y, initial_guess=None): lgh = LossGradHess(design_X, y, 1 / self.C, self.intercept) + # Determine whether to use Hessian + # - 'auto': use Hessian only for trust-constr with no monotonicity constraints (3-4x faster) + # - True: always use Hessian with trust-constr + # - False: never use Hessian + if self.use_hessian == 'auto': + use_hess = ( + self.method == MinimizationMethod.trust_constr.value + and self.monotonicity == Monotonicity.none.value + ) + else: + use_hess = ( + self.use_hessian + and self.method == MinimizationMethod.trust_constr.value + ) + hess = lgh.hess if use_hess else None + result = minimize( fun=lgh.loss, x0=x0, jac=lgh.grad, + hess=hess, method=self.method, constraints=constraint, options=self.minimizer_options or {}, @@ -247,15 +344,120 @@ def get_additional_columns(self, X): additional_columns = np.delete(X, self.input_score_column_index, axis=1) return additional_columns + def _stratified_subsample(self, X, y, n_samples, n_strata=10): + # type: (np.ndarray, np.ndarray, int, int) -> Tuple[np.ndarray, np.ndarray, np.ndarray] + """ + Create a stratified subsample based on quantiles of the input scores. + + This ensures the subsample covers the full range of scores, which is + important for fitting splines accurately. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + y : array-like of shape (n_samples,) + n_samples : int + Target number of samples in the subsample + n_strata : int + Number of strata (quantile bins) to use + + Returns + ------- + X_sub : array-like + y_sub : array-like + indices : array-like + Indices of selected samples + """ + input_scores = self.get_input_scores(X) + n_total = len(X) + + # Calculate quantile boundaries for stratification + percentiles = np.linspace(0, 100, n_strata + 1) + boundaries = np.percentile(input_scores, percentiles) + + # Assign each sample to a stratum + strata = np.digitize(input_scores, boundaries[1:-1]) + + # Sample proportionally from each stratum + selected_indices = [] + samples_per_stratum = max(1, n_samples // n_strata) + + for s in range(n_strata): + stratum_indices = np.where(strata == s)[0] + if len(stratum_indices) == 0: + continue + + # Sample from this stratum + n_to_sample = min(samples_per_stratum, len(stratum_indices)) + sampled = self.random_state_.choice( + stratum_indices, n_to_sample, replace=False + ) + selected_indices.extend(sampled) + + # If we need more samples to reach target, sample randomly from remainder + selected_indices = np.array(selected_indices) + if len(selected_indices) < n_samples: + remaining = np.setdiff1d(np.arange(n_total), selected_indices) + n_extra = min(n_samples - len(selected_indices), len(remaining)) + if n_extra > 0: + extra = self.random_state_.choice(remaining, n_extra, replace=False) + selected_indices = np.concatenate([selected_indices, extra]) + + # Shuffle to avoid any ordering effects + self.random_state_.shuffle(selected_indices) + selected_indices = selected_indices[:n_samples] + + if isinstance(X, pd.DataFrame): + X_sub = X.iloc[selected_indices] + else: + X_sub = X[selected_indices, :] + y_sub = y[selected_indices] + + return X_sub, y_sub, selected_indices + + def _random_subsample(self, X, y, n_samples): + # type: (np.ndarray, np.ndarray, int) -> Tuple[np.ndarray, np.ndarray, np.ndarray] + """Simple random subsampling (original behavior).""" + indices = self.random_state_.choice( + np.arange(len(X)), n_samples, replace=False + ) + if isinstance(X, pd.DataFrame): + X_sub, y_sub = X.iloc[indices], y[indices] + else: + X_sub, y_sub = X[indices, :], y[indices] + return X_sub, y_sub, indices + + def _check_convergence(self, coefs_old, coefs_new): + # type: (np.ndarray, np.ndarray) -> bool + """Check if coefficients have converged based on relative change.""" + if coefs_old is None or self.early_stopping_tol is None: + return False + + # Relative L2 norm of change + norm_old = np.linalg.norm(coefs_old) + if norm_old < 1e-10: + # If old coefficients are near zero, use absolute change + change = np.linalg.norm(coefs_new - coefs_old) + else: + change = np.linalg.norm(coefs_new - coefs_old) / norm_old + + return change < self.early_stopping_tol + def fit(self, X, y): # type: (pd.DataFrame, Union[np.ndarray, pd.Series], Optional[np.ndarray]) -> None """ - When the dataset is too large, we choose to use a random subset of the data to do an initial fit; - Then we take the coefficients as initial guess to fit again using the entire dataset. This will speed - up training and avoid under-fitting. - We use two_stage_fitting_size as the sampling size. + Fit the linear spline logistic regression model. + + Supports three fitting modes: + 1. Direct fitting (default): Fit on full data + 2. Two-stage fitting (legacy): Uses `two_stage_fitting_initial_size` + 3. Progressive fitting (recommended): Uses `progressive_fitting_fractions` + + Progressive fitting with stratified sampling is recommended for large datasets + as it provides faster convergence while maintaining calibration quality. """ self.random_state_ = check_random_state(self.random_state) + self.fitting_history_ = [] # Validate X and y, this sets n_features_in_ automatically X, y = validate_data( @@ -289,24 +491,162 @@ def fit(self, X, y): if self.method not in ['SLSQP', 'trust-constr']: raise ValueError("optimization method can only be either 'SLSQP' or 'trust-constr'") - if self.two_stage_fitting_initial_size is None: - self._fit(X, y, initial_guess=None) + n_samples = X.shape[0] + + # Determine fitting mode + if self.progressive_fitting_fractions is not None: + # Mode 3: Progressive fitting (recommended) + self._progressive_fit(X, y) + elif self.two_stage_fitting_initial_size is not None: + # Mode 2: Legacy two-stage fitting + self._two_stage_fit(X, y) else: - if self.two_stage_fitting_initial_size > X.shape[0]: - raise ValueError("two_stage_fitting_initial_size should be smaller than data size") + # Mode 1: Direct fitting on full data + self._fit(X, y, initial_guess=None) + self.fitting_history_.append({ + 'stage': 0, + 'n_samples': n_samples, + 'fraction': 1.0, + 'iterations': getattr(self.result_, 'nit', None), + 'converged': self.result_.success, + }) - # initial fitting without guess - index = self.random_state_.choice(np.arange(len(X)), self.two_stage_fitting_initial_size, replace=False) - if isinstance(X, pd.DataFrame): - X_sub, y_sub = X.iloc[index], y[index] - else: - X_sub, y_sub = X[index, :], y[index] - self._fit(X_sub, y_sub, initial_guess=None) + return self - # final fitting with coefs from initial run as guess - self._fit(X, y, initial_guess=self.coefficients_) + def _two_stage_fit(self, X, y): + # type: (np.ndarray, np.ndarray) -> None + """Legacy two-stage fitting for backward compatibility.""" + n_samples = X.shape[0] + + if self.two_stage_fitting_initial_size > n_samples: + raise ValueError("two_stage_fitting_initial_size should be smaller than data size") + + # Warn if subsample is too small relative to knots + samples_per_knot = self.two_stage_fitting_initial_size / (self.knots_.shape[0] + 1) + if samples_per_knot < 50: + warn( + f"Subsample size ({self.two_stage_fitting_initial_size}) may be too small " + f"for {self.knots_.shape[0]} knots ({samples_per_knot:.0f} samples/knot). " + f"Consider increasing subsample size or reducing knots for better warm-start." + ) - return self + # Stage 1: Fit on subsample + if self.stratified_sampling: + X_sub, y_sub, _ = self._stratified_subsample( + X, y, self.two_stage_fitting_initial_size + ) + else: + X_sub, y_sub, _ = self._random_subsample( + X, y, self.two_stage_fitting_initial_size + ) + + self._fit(X_sub, y_sub, initial_guess=None) + self.fitting_history_.append({ + 'stage': 0, + 'n_samples': self.two_stage_fitting_initial_size, + 'fraction': self.two_stage_fitting_initial_size / n_samples, + 'iterations': getattr(self.result_, 'nit', None), + 'converged': self.result_.success, + }) + + if self.verbose: + print(f"Stage 1: {self.two_stage_fitting_initial_size} samples, " + f"{self.result_.nit} iterations") + + # Stage 2: Fit on full data with warm start + coefs_stage1 = self.coefficients_.copy() + self._fit(X, y, initial_guess=coefs_stage1) + self.fitting_history_.append({ + 'stage': 1, + 'n_samples': n_samples, + 'fraction': 1.0, + 'iterations': getattr(self.result_, 'nit', None), + 'converged': self.result_.success, + }) + + if self.verbose: + print(f"Stage 2: {n_samples} samples, {self.result_.nit} iterations") + + def _progressive_fit(self, X, y): + # type: (np.ndarray, np.ndarray) -> None + """ + Progressive fitting with gradual sample increase. + + Uses stratified sampling to ensure each stage covers the full score range. + Warm-starts each stage with coefficients from the previous stage. + Supports early stopping if coefficients converge. + """ + n_samples = X.shape[0] + fractions = self.progressive_fitting_fractions + + # Validate fractions + if not all(0 < f <= 1.0 for f in fractions): + raise ValueError("All fractions must be in (0, 1]") + if fractions[-1] != 1.0: + # Ensure we always end with full data + fractions = tuple(fractions) + (1.0,) + + prev_coefs = None + + for stage, frac in enumerate(fractions): + n_stage_samples = int(n_samples * frac) + n_stage_samples = max(n_stage_samples, self.knots_.shape[0] + 10) # Ensure enough samples + n_stage_samples = min(n_stage_samples, n_samples) + + if frac >= 1.0: + # Use full data for final stage + X_stage, y_stage = X, y + else: + # Warn if subsample too small for knots + samples_per_knot = n_stage_samples / (self.knots_.shape[0] + 1) + if samples_per_knot < 50 and self.verbose: + print(f" Warning: Stage {stage} has {samples_per_knot:.0f} samples/knot (recommend >= 50)") + + # Subsample for intermediate stages + if self.stratified_sampling: + X_stage, y_stage, _ = self._stratified_subsample( + X, y, n_stage_samples + ) + else: + X_stage, y_stage, _ = self._random_subsample( + X, y, n_stage_samples + ) + + self._fit(X_stage, y_stage, initial_guess=prev_coefs) + + self.fitting_history_.append({ + 'stage': stage, + 'n_samples': n_stage_samples, + 'fraction': frac, + 'iterations': getattr(self.result_, 'nit', None), + 'converged': self.result_.success, + }) + + if self.verbose: + print(f"Stage {stage + 1}/{len(fractions)}: " + f"{n_stage_samples} samples ({frac:.1%}), " + f"{self.result_.nit} iterations") + + # Check for early stopping (but always complete final stage) + if frac < 1.0 and self._check_convergence(prev_coefs, self.coefficients_): + if self.verbose: + print(f"Early stopping: coefficients converged at stage {stage + 1}") + # Still fit on full data for final refinement, but with fewer iterations expected + prev_coefs = self.coefficients_.copy() + self._fit(X, y, initial_guess=prev_coefs) + self.fitting_history_.append({ + 'stage': stage + 1, + 'n_samples': n_samples, + 'fraction': 1.0, + 'iterations': getattr(self.result_, 'nit', None), + 'converged': self.result_.success, + 'early_stopped': True, + }) + if self.verbose: + print(f"Final stage: {n_samples} samples, {self.result_.nit} iterations") + break + + prev_coefs = self.coefficients_.copy() def transform(self, X): if not self.is_fitted: diff --git a/tests/test_progressive_fitting.py b/tests/test_progressive_fitting.py new file mode 100644 index 0000000..df879c7 --- /dev/null +++ b/tests/test_progressive_fitting.py @@ -0,0 +1,198 @@ +"""Tests for the progressive fitting improvements (Issue #9).""" +from __future__ import absolute_import, division, print_function + +import unittest +import numpy as np +from scipy.special import expit + +from splinator.estimators import LinearSplineLogisticRegression, LossGradHess +from splinator.monotonic_spline import Monotonicity + + +class TestHessian(unittest.TestCase): + """Test the Hessian computation.""" + + def setUp(self): + np.random.seed(42) + self.n_samples = 50 + self.n_features = 5 + self.X = np.random.randn(self.n_samples, self.n_features) + self.y = np.random.randint(0, 2, self.n_samples) + self.alpha = 0.01 + + def test_hessian_shape(self): + """Test that Hessian has correct shape.""" + lgh = LossGradHess(self.X, self.y, self.alpha, intercept=True) + coefs = np.zeros(self.n_features) + H = lgh.hess(coefs) + self.assertEqual(H.shape, (self.n_features, self.n_features)) + + def test_hessian_symmetry(self): + """Test that Hessian is symmetric.""" + lgh = LossGradHess(self.X, self.y, self.alpha, intercept=True) + coefs = np.random.randn(self.n_features) + H = lgh.hess(coefs) + np.testing.assert_array_almost_equal(H, H.T) + + def test_hessian_positive_semidefinite(self): + """Test that Hessian is positive semi-definite.""" + lgh = LossGradHess(self.X, self.y, self.alpha, intercept=True) + coefs = np.random.randn(self.n_features) + H = lgh.hess(coefs) + eigenvalues = np.linalg.eigvalsh(H) + self.assertTrue(np.all(eigenvalues >= -1e-10)) + + def test_hessian_numerical_gradient(self): + """Test Hessian against numerical differentiation of gradient.""" + lgh = LossGradHess(self.X, self.y, self.alpha, intercept=True) + coefs = np.random.randn(self.n_features) * 0.1 + H_analytical = lgh.hess(coefs) + + # Numerical Hessian via finite differences on gradient + eps = 1e-5 + H_numerical = np.zeros((self.n_features, self.n_features)) + for i in range(self.n_features): + coefs_plus = coefs.copy() + coefs_plus[i] += eps + coefs_minus = coefs.copy() + coefs_minus[i] -= eps + H_numerical[:, i] = (lgh.grad(coefs_plus) - lgh.grad(coefs_minus)) / (2 * eps) + + np.testing.assert_array_almost_equal(H_analytical, H_numerical, decimal=4) + + +class TestProgressiveFitting(unittest.TestCase): + """Test progressive fitting functionality.""" + + def setUp(self): + np.random.seed(123) + self.n_samples = 500 + self.X = np.random.randn(self.n_samples, 1) * 2 + probs = expit(self.X[:, 0]) + self.y = np.random.binomial(1, probs) + + def test_progressive_fitting_runs(self): + """Test that progressive fitting runs without errors.""" + model = LinearSplineLogisticRegression( + n_knots=10, + progressive_fitting_fractions=(0.2, 1.0), + random_state=42, + ) + model.fit(self.X, self.y) + self.assertTrue(model.is_fitted) + self.assertEqual(len(model.fitting_history_), 2) + + def test_progressive_fitting_with_monotonicity(self): + """Test progressive fitting with monotonicity constraints.""" + model = LinearSplineLogisticRegression( + n_knots=10, + monotonicity='increasing', + progressive_fitting_fractions=(0.3, 1.0), + random_state=42, + ) + model.fit(self.X, self.y) + self.assertTrue(model.is_fitted) + + def test_progressive_fitting_history(self): + """Test that fitting history is recorded correctly.""" + model = LinearSplineLogisticRegression( + n_knots=10, + progressive_fitting_fractions=(0.2, 0.5, 1.0), + random_state=42, + ) + model.fit(self.X, self.y) + + self.assertEqual(len(model.fitting_history_), 3) + for i, history in enumerate(model.fitting_history_): + self.assertIn('stage', history) + self.assertIn('n_samples', history) + self.assertIn('fraction', history) + + +class TestStratifiedSampling(unittest.TestCase): + """Test stratified sampling functionality.""" + + def setUp(self): + np.random.seed(456) + self.n_samples = 500 + # Bimodal distribution + X1 = np.random.randn(self.n_samples // 2, 1) - 3 + X2 = np.random.randn(self.n_samples // 2, 1) + 3 + self.X = np.vstack([X1, X2]) + probs = expit(self.X[:, 0] * 0.5) + self.y = np.random.binomial(1, probs) + + def test_stratified_sampling_method(self): + """Test the _stratified_subsample method directly.""" + model = LinearSplineLogisticRegression(n_knots=10, random_state=42) + model.random_state_ = np.random.RandomState(42) + model.input_score_column_index = 0 + + X_sub, y_sub, indices = model._stratified_subsample( + self.X, self.y, n_samples=100, n_strata=10 + ) + + self.assertEqual(len(X_sub), 100) + self.assertEqual(len(y_sub), 100) + + # Check coverage + original_range = np.ptp(self.X[:, 0]) + subsample_range = np.ptp(X_sub[:, 0]) + coverage = subsample_range / original_range + self.assertGreater(coverage, 0.8) + + +class TestEarlyStopping(unittest.TestCase): + """Test early stopping functionality.""" + + def test_convergence_check(self): + """Test the _check_convergence method.""" + model = LinearSplineLogisticRegression( + n_knots=10, + early_stopping_tol=1e-4, + random_state=42 + ) + + coefs_old = np.array([1.0, 2.0, 3.0]) + coefs_similar = np.array([1.0001, 2.0001, 3.0001]) + coefs_different = np.array([1.5, 2.5, 3.5]) + + self.assertTrue(model._check_convergence(coefs_old, coefs_similar)) + self.assertFalse(model._check_convergence(coefs_old, coefs_different)) + self.assertFalse(model._check_convergence(None, coefs_similar)) + + +class TestBackwardCompatibility(unittest.TestCase): + """Test backward compatibility with existing API.""" + + def setUp(self): + np.random.seed(999) + self.n_samples = 300 + self.X = np.random.randn(self.n_samples, 1) + probs = expit(self.X[:, 0]) + self.y = np.random.binomial(1, probs) + + def test_legacy_two_stage_fitting(self): + """Test that legacy two_stage_fitting_initial_size still works.""" + model = LinearSplineLogisticRegression( + n_knots=10, + two_stage_fitting_initial_size=100, + random_state=42, + ) + model.fit(self.X, self.y) + self.assertTrue(model.is_fitted) + self.assertEqual(len(model.fitting_history_), 2) + + def test_direct_fitting(self): + """Test that direct fitting (no progressive) still works.""" + model = LinearSplineLogisticRegression( + n_knots=10, + random_state=42, + ) + model.fit(self.X, self.y) + self.assertTrue(model.is_fitted) + self.assertEqual(len(model.fitting_history_), 1) + + +if __name__ == '__main__': + unittest.main()