diff --git a/tests/__init__.py b/tests/__init__.py
index 0bbaffd..7d59866 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,4 +1,14 @@
 # conftest.py
+
+# work around Tcl/Tk errors: force a headless matplotlib backend
+import os
+os.environ.setdefault("MPLBACKEND", "Agg")
+try:
+    import matplotlib
+    matplotlib.use("Agg", force=True)
+except Exception:
+    pass
+
 import numpy as np
 import pytest
 import pyvinecopulib as pvc
@@ -27,7 +37,7 @@ def bicop_pair(request):
     """
     Returns a tuple:
-    ( family, true_params, U_tensor, bicop_fastkde, bicop_tll )
+    ( family, true_params, rotation, U_tensor, bicop_fastkde, bicop_tll, bicop_torchkde )
     notice the scope="module" so that the fixture is created only once
     and reused in all tests that use it.
     """
@@ -38,14 +48,17 @@ def bicop_pair(request):
     U = true_bc.simulate(n=N_SIM, seeds=SEEDS)  # shape (N_SIM, 2)
     U_tensor = torch.tensor(U, device=DEVICE, dtype=torch.float64)

-    # 2) fit two torchvinecopulib instances (fast KDE and TLL)
+    # 2) fit three torchvinecopulib instances (torch KDE, fast KDE, and TLL)
     bc_fast = tvc.BiCop(num_step_grid=512).to(DEVICE)
     bc_fast.fit(U_tensor, mtd_kde="fastKDE")

+    bc_torch = tvc.BiCop(num_step_grid=512).to(DEVICE)
+    bc_torch.fit(U_tensor, mtd_kde="torchKDE")
+
     bc_tll = tvc.BiCop(num_step_grid=512).to(DEVICE)
     bc_tll.fit(U_tensor, mtd_kde="tll")

-    return family, true_params, rotation, U_tensor, bc_fast, bc_tll
+    return family, true_params, rotation, U_tensor, bc_fast, bc_tll, bc_torch


 @pytest.fixture(scope="module")
diff --git a/tests/test_bicop.py b/tests/test_bicop.py
index 9cf61fd..b2b7ce3 100644
--- a/tests/test_bicop.py
+++ b/tests/test_bicop.py
@@ -20,10 +20,10 @@ def test_device_and_dtype():


 def test_monotonicity_and_range(bicop_pair):
-    family, params, rotation, U, bc_fast, bc_tll = bicop_pair
+    family, params, rotation, U, bc_fast, bc_tll, bc_torch = bicop_pair

-    # pick one of the two implementations or loop both
-    for bicop in (bc_fast, bc_tll):
+    # pick one of the implementations or loop over all three
+    for bicop in (bc_fast, bc_tll, bc_torch):
         # * simple diagonal check
         grid = torch.linspace(EPS, 1.0 - EPS, 100, device=U.device, dtype=torch.float64).unsqueeze(
             1
@@ -44,9 +44,9 @@ def test_monotonicity_and_range(bicop_pair):


 def test_inversion(bicop_pair):
-    family, params, rotation, U, bc_fast, bc_tll = bicop_pair
+    family, params, rotation, U, bc_fast, bc_tll, bc_torch = bicop_pair

-    for bicop in (bc_fast, bc_tll):
+    for bicop in (bc_fast, bc_tll, bc_torch):
         grid = torch.linspace(0.1, 0.9, 50, device=U.device, dtype=torch.float64).unsqueeze(1)
         for u in grid:
             pts = torch.hstack([grid, u.repeat(grid.size(0), 1)])
@@ -68,8 +68,8 @@ def test_inversion(bicop_pair):


 def test_pdf_integrates_to_one(bicop_pair):
-    family, params, rotation, U, bc_fast, bc_tll = bicop_pair
-    for cop in (bc_fast, bc_tll):
+    family, params, rotation, U, bc_fast, bc_tll, bc_torch = bicop_pair
+    for cop in (bc_fast, bc_tll, bc_torch):
         # our grid is uniform on [0,1]² with spacing Δ = 1/(N−1)
         Δ = 1.0 / (cop.num_step_grid - 1)
         # approximate ∫ pdf(u,v) du dv ≈ Σ_pdf_grid * Δ²
@@ -80,8 +80,8 @@ def test_pdf_integrates_to_one(bicop_pair):


 def test_log_pdf_matches_log_of_pdf(bicop_pair):
-    family, params, rotation, U, bc_fast, bc_tll = bicop_pair
-    for cop in (bc_fast, bc_tll):
+    family, params, rotation, U, bc_fast, bc_tll, bc_torch = bicop_pair
+    for cop in (bc_fast, bc_tll, bc_torch):
         pts = torch.rand(500, 2, dtype=torch.float64, device=cop.device)
         pdf = cop.pdf(pts)
         logp = cop.log_pdf(pts)
@@ -102,8 +102,8 @@ def test_log_pdf_handles_zero():


 def test_sample_marginals(bicop_pair):
-    family, params, rotation, U, bc_fast, bc_tll = bicop_pair
-    for cop in (bc_fast, bc_tll):
+    family, params, rotation, U, bc_fast, bc_tll, bc_torch = bicop_pair
+    for cop in (bc_fast, bc_tll, bc_torch):
         for is_sobol in (False, True):
             samp = cop.sample(2000, seed=0, is_sobol=is_sobol)
             # samples lie in [0,1]
@@ -117,8 +117,8 @@ def test_sample_marginals(bicop_pair):


 def test_internal_buffers_and_flags(bicop_pair):
-    _, _, _, U, bc_fast, bc_tll = bicop_pair
-    for cop, mtd_kde in [(bc_fast, "fastKDE"), (bc_tll, "tll")]:
+    _, _, _, U, bc_fast, bc_tll, bc_torch = bicop_pair
+    for cop, mtd_kde in [(bc_fast, "fastKDE"), (bc_tll, "tll"), (bc_torch, "torchKDE")]:
         print(cop)
         assert not cop.is_indep
         assert cop.mtd_kde == mtd_kde
@@ -131,7 +131,7 @@ def test_internal_buffers_and_flags(bicop_pair):


 def test_tau_estimation(bicop_pair):
-    _, _, _, U, bc_fast, bc_mtd_kde = bicop_pair
+    _, _, _, U, bc_fast, bc_tll, bc_torch = bicop_pair
     # re-fit with tau estimation
     bc = tvc.BiCop(num_step_grid=64)
     bc.fit(U, mtd_kde="tll", is_tau_est=True)
@@ -141,8 +141,8 @@ def test_tau_estimation(bicop_pair):


 def test_sample_shape_and_dtype_on_tll(bicop_pair):
-    _, _, _, U, bc_fast, bc_tll = bicop_pair
-    for cop in (bc_fast, bc_tll):
+    _, _, _, U, bc_fast, bc_tll, bc_torch = bicop_pair
+    for cop in (bc_fast, bc_tll, bc_torch):
         s = cop.sample(123, seed=7, is_sobol=True)
         assert s.shape == (123, 2)
         assert s.dtype is cop.dtype
@@ -150,8 +150,9 @@ def test_sample_shape_and_dtype_on_tll(bicop_pair):


 def test_imshow_and_plot_api(bicop_pair):
-    family, params, rotation, U, bc_fast, bc_tll = bicop_pair
-    cop = bc_fast
+    family, params, rotation, U, bc_fast, bc_tll, bc_torch = bicop_pair
+    cop = bc_torch
     # imshow
     fig, ax = cop.imshow(is_log_pdf=True)
     assert isinstance(fig, matplotlib.figure.Figure)
@@ -186,16 +187,18 @@ def test_imshow_and_plot_api(bicop_pair):


 def test_plot_accepts_unused_kwargs(bicop_pair):
-    _, _, _, U, bc_fast, _ = bicop_pair
+    _, _, _, U, bc_fast, _, bc_torch = bicop_pair
     # just ensure it doesn't crash
-    bc_fast.plot(plot_type="contour", margin_type="norm", xylim=(0, 1), grid_size=50)
-    bc_fast.plot(plot_type="surface", margin_type="unif", xylim=(0, 1), grid_size=20)
+    bc_torch.plot(plot_type="contour", margin_type="norm", xylim=(0, 1), grid_size=50)
+    bc_torch.plot(plot_type="surface", margin_type="unif", xylim=(0, 1), grid_size=20)


 def test_reset_and_str(bicop_pair):
     # ! notice scope="module" so we put this test at the end
-    family, params, rotation, U, bc_fast, bc_tll = bicop_pair
-    for cop in (bc_fast, bc_tll):
+    family, params, rotation, U, bc_fast, bc_tll, bc_torch = bicop_pair
+    for cop in (bc_fast, bc_tll, bc_torch):
         cop.reset()
         # should go back to independent
         assert cop.is_indep
@@ -261,7 +264,8 @@ def test_interp_on_trivial_grid():
 def test_imshow_with_existing_axes():
     cop = tvc.BiCop(num_step_grid=32)
     us = torch.rand(100, 2)
-    cop.fit(us, mtd_kde="fastKDE")
+    cop.fit(us, mtd_kde="torchKDE")
     fig, outer_ax = plt.subplots()
     fig2, ax2 = cop.imshow(is_log_pdf=False, ax=outer_ax, cmap="viridis")
     # should have returned the same axes object
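For context, a minimal usage sketch of the new `torchKDE` mode, mirroring the public API exercised by the tests above (the data here are placeholder uniforms, not a fitted copula sample):

```python
import torch
import torchvinecopulib as tvc

U = torch.rand(2000, 2, dtype=torch.float64)  # pseudo-observations in [0, 1]^2
cop = tvc.BiCop(num_step_grid=128)
cop.fit(U, mtd_kde="torchKDE")   # the Torch-only KDE mode added by this patch
print(cop.pdf(U[:5]))            # (5, 1) copula densities
print(cop.sample(10, seed=0))    # (10, 2) samples in [0, 1]^2
```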
diff --git a/tests/test_torchkde.py b/tests/test_torchkde.py
new file mode 100644
index 0000000..a6d9cb3
--- /dev/null
+++ b/tests/test_torchkde.py
@@ -0,0 +1,145 @@
+import torch
+import torchvinecopulib as tvc
+from scipy.stats import kstest, wasserstein_distance
+
+from . import EPS, bicop_pair
+
+
+def test_1dkde_internal_grid_is_finite():
+    """
+    The 1D KDE must build finite grids even if the input contains NaN/±Inf.
+    This ensures later interpolation/extrapolation is safe.
+    """
+    x_clean = torch.randn(2000, dtype=torch.float64)
+    x_bad = torch.tensor([float("nan"), float("inf"), -float("inf")], dtype=torch.float64)
+    x = torch.cat([x_clean, x_bad])
+
+    # build the 1D KDE (adjust the import path if the class moves)
+    kde = tvc.kdeCDFPPF1D(x, bandwidth_method="auto")
+
+    # all internal grids must be finite
+    for name in ("grid_x", "grid_pdf", "grid_cdf"):
+        grid = getattr(kde, name)
+        assert torch.isfinite(grid).all(), f"{name} contains non-finite values"
+
+    # basic validity: pdf ≥ 0; cdf in [0, 1] and non-decreasing
+    assert (kde.grid_pdf >= -EPS).all()
+    assert kde.grid_cdf.min() >= -EPS and kde.grid_cdf.max() <= 1 + EPS
+    if kde.grid_cdf.numel() > 1:
+        assert kde.grid_cdf.diff().min() >= -EPS
+
+
+def test_bicop_grids_and_eval_are_finite_with_nan_inf(bicop_pair):
+    """
+    All BiCop modes (fastKDE, tll, torchKDE) must:
+    - build fully finite internal grids even if the training data contain NaN/±Inf/OOB rows,
+    - evaluate safely on queries containing NaN/±Inf/OOB values:
+        * finite rows → finite outputs (and proper ranges),
+        * non-finite rows → NaN (no crash).
+    """
+    _, _, _, U, bc_fast, bc_tll, bc_torch = bicop_pair
+
+    # inject some bad rows into otherwise valid data and refit fresh models
+    bad = torch.tensor(
+        [
+            [float("nan"), 0.3],
+            [0.4, float("inf")],
+            [-float("inf"), 0.9],
+            [1.1, -0.1],  # out of bounds but finite
+        ],
+        dtype=torch.float64,
+        device=U.device,
+    )
+    U_dirty = torch.vstack([U, bad])
+
+    modes = [("fastKDE", bc_fast), ("tll", bc_tll), ("torchKDE", bc_torch)]
+    for name, _ in modes:
+        cop = tvc.BiCop(num_step_grid=64).to(device=U.device)
+        cop.fit(U_dirty, mtd_kde=name)
+
+        # 1) internal grids must be finite
+        for gname in ("_pdf_grid", "_cdf_grid", "_hfunc_l_grid", "_hfunc_r_grid"):
+            G = getattr(cop, gname)
+            assert torch.isfinite(G).all(), f"{name}:{gname} contains non-finite values"
+
+        # 2) evaluation must be safe on queries with NaN/±Inf/OOB
+        Q = torch.tensor(
+            [
+                [0.2, 0.7],  # clean
+                [float("nan"), 0.3],  # NaN
+                [0.4, float("inf")],  # +Inf
+                [-float("inf"), 0.9],  # -Inf
+                [0.0, 1.0],  # edge
+                [1.1, -0.1],  # OOB (finite)
+            ],
+            dtype=torch.float64,
+            device=cop.device,
+        )
+        finite = torch.isfinite(Q).all(dim=1)
+
+        # pdf/log_pdf: finite rows must be finite; pdf non-negative
+        pdf = cop.pdf(Q).squeeze(1)
+        logp = cop.log_pdf(Q).squeeze(1)
+        assert torch.isfinite(pdf[finite]).all()
+        assert torch.isfinite(logp[finite]).all()
+        assert (pdf[finite] >= -EPS).all()
+
+        # non-finite rows should come back as NaN (by design) and never crash
+        if (~finite).any():
+            assert torch.isnan(pdf[~finite]).all()
+            assert torch.isnan(logp[~finite]).all()
+
+        # cdf/hfuncs: finite rows within [0, 1]
+        for fn in (cop.cdf, cop.hfunc_r, cop.hfunc_l):
+            out = fn(Q).squeeze(1)
+            assert (out[finite] >= -EPS).all() and (out[finite] <= 1 + EPS).all()
+            if (~finite).any():
+                assert torch.isnan(out[~finite]).all()
+
+        # inverses: finite rows → valid [0, 1]; non-finite rows → NaN
+        out_r = cop.hinv_r(Q).squeeze(1)
+        out_l = cop.hinv_l(Q).squeeze(1)
+        assert (out_r[finite] >= -EPS).all() and (out_r[finite] <= 1 + EPS).all()
+        assert (out_l[finite] >= -EPS).all() and (out_l[finite] <= 1 + EPS).all()
+        if (~finite).any():
+            assert torch.isnan(out_r[~finite]).all()
+            assert torch.isnan(out_l[~finite]).all()
+
+
+def test_pit_goodness_of_fit_train_test(bicop_pair):
+    """
+    Split U into train/test.
+    Fit on train; compute the PIT on test via hfunc_r.
+    The PIT should be close to Uniform[0, 1] (KS and Wasserstein).
+    """
+    _, _, _, U, bc_fast, bc_tll, bc_torch = bicop_pair
+    n = U.shape[0]
+    n_tr = int(0.6 * n)
+    Utr, Ute = U[:n_tr], U[n_tr:]
+
+    cases = [
+        ("fastKDE", bc_fast),
+        ("tll", bc_tll),
+        ("torchKDE", bc_torch),
+    ]
+
+    for name, _ in cases:
+        cop = tvc.BiCop(num_step_grid=64)
+        cop.fit(Utr, mtd_kde=name)
+
+        pit = cop.hfunc_r(Ute).squeeze(1).cpu().numpy()  # PIT should be ~ U(0, 1)
+
+        # KS vs Uniform(0, 1)
+        ks_stat, ks_p = kstest(pit, "uniform")
+
+        # 1-Wasserstein vs an iid Uniform sample of the same size
+        uni = torch.rand_like(Ute[:, 0]).cpu().numpy()
+        wdist = wasserstein_distance(pit, uni)
+
+        # lenient but meaningful thresholds (tune if the grid resolution differs)
+        assert ks_stat < 0.12
+        assert wdist < 0.06
+        assert ks_p > 1e-3
diff --git a/torchvinecopulib/__init__.py b/torchvinecopulib/__init__.py
index 1c85a3c..d2a870d 100644
--- a/torchvinecopulib/__init__.py
+++ b/torchvinecopulib/__init__.py
@@ -1,11 +1,13 @@
 from . import util
 from .bicop import BiCop
 from .vinecop import VineCop
+from .util import kdeCDFPPF1D

 __all__ = [
     "BiCop",
     "VineCop",
     "util",
+    "kdeCDFPPF1D",
 ]
 # dynamically grab the version you just built & installed
 try:
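As a quick calibration of the thresholds asserted in `test_pit_goodness_of_fit_train_test` above: for an exactly uniform PIT, the KS statistic and the 1-Wasserstein distance sit far below the bounds. A standalone sanity check (synthetic uniforms, not fitted models):

```python
import numpy as np
from scipy.stats import kstest, wasserstein_distance

rng = np.random.default_rng(0)
pit = rng.uniform(size=2000)  # an ideal PIT is exactly Uniform(0, 1)
print(kstest(pit, "uniform").statistic)                   # typically ~0.02 << 0.12
print(wasserstein_distance(pit, rng.uniform(size=2000)))  # typically ~0.01 << 0.06
```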
diff --git a/torchvinecopulib/bicop/__init__.py b/torchvinecopulib/bicop/__init__.py
index e1e9ebb..ff06521 100644
--- a/torchvinecopulib/bicop/__init__.py
+++ b/torchvinecopulib/bicop/__init__.py
@@ -40,10 +40,13 @@
 import numpy as np
 import pyvinecopulib as pv
 import torch
-from fastkde import pdf as fkpdf
+import math
+
+import torch.nn.functional as F
+
+try:
+    # fastKDE becomes optional: mtd_kde="fastKDE" still needs it, torchKDE does not
+    from fastkde import pdf as fkpdf
+except ImportError:
+    fkpdf = None
 from matplotlib.colors import LinearSegmentedColormap
 from mpl_toolkits.mplot3d.axes3d import Axes3D
 from scipy.stats import kendalltau, norm

 from ..util import _EPS, solve_ITP
+from ..util.bandwidth import optimal_bandwidth
@@ -91,6 +94,7 @@ def __init__(
         )
         # ! device agnostic
         self.register_buffer("_dd", torch.tensor([], dtype=torch.float64))
+        self._EPS = _EPS  # instance alias used by the sanitizing helpers below
+        self._LOG_EPS = 1e-6  # floor for log-density: log(1e-6) ≈ -13.8155

     @property
     def device(self) -> torch.device:
@@ -110,6 +114,53 @@ def dtype(self) -> torch.dtype:
         """
         return self._dd.dtype

+    # two helper methods to clean data
+    def _sanitize_train_obs(self, obs: torch.Tensor) -> torch.Tensor:
+        """
+        Keep only fully finite rows and clamp to [0, 1]² for fitting.
+        Does not change self.num_obs (tests expect the original count).
+        """
+        obs = obs.to(device=self.device, dtype=self.dtype)
+        m = torch.isfinite(obs).all(dim=1)
+        if not m.any():
+            return obs[:0]  # empty (0, 2)
+        return obs[m].clamp_(0.0, 1.0)
+
+    def _eval_on_finite(self, obs: torch.Tensor, eval_fn, clamp_inputs: bool = True) -> torch.Tensor:
+        """
+        Numerically safe evaluation wrapper (tll/fastKDE/torchKDE).
+
+        - Builds an (N, 1) output tensor initialized to NaN.
+        - Runs ``eval_fn`` only on rows where (u, v) are fully finite.
+        - Optionally clamps those finite rows to [EPS, 1-EPS] to prevent out-of-bounds indexing.
+        - Fills results back into the corresponding rows; non-finite rows stay NaN.
+
+        Args:
+            obs (torch.Tensor): (N, 2) query points.
+            eval_fn (callable): function mapping (k, 2) -> (k, 1).
+            clamp_inputs (bool): clamp finite (u, v) to [EPS, 1-EPS] before eval. Defaults to True.
+
+        Returns:
+            torch.Tensor: (N, 1) tensor with evaluated results (finite rows) and NaN elsewhere.
+        """
+        # ! device agnostic
+        obs = obs.to(device=self.device, dtype=self.dtype)
+        # ensure a 2D shape first
+        if obs.ndim == 1:
+            obs = obs.view(-1, 2)
+        N = obs.shape[0]
+        out = torch.full((N, 1), float("nan"), dtype=self.dtype, device=self.device)
+
+        mask = torch.isfinite(obs).all(dim=1)
+        if mask.any():
+            z = obs[mask]
+            if clamp_inputs:
+                z = z.clamp(self._EPS, 1.0 - self._EPS)
+            val = eval_fn(z)
+            out[mask, 0] = val.view(-1)
+
+        return out
+
     @torch.no_grad()
     def reset(self) -> None:
         """Reinitialize state and zero all statistics and precomputed grids.
@@ -148,7 +199,7 @@ def fit(

         Args:
             obs (torch.Tensor): shape (n, 2) bicop obs in [0, 1]².
-            mtd_kde (str, optional): Method for estimating the copula density. One of ("tll", "fastKDE"). Defaults to "tll".
+            mtd_kde (str, optional): Method for estimating the copula density. One of ("tll", "torchKDE", "fastKDE"). Defaults to "tll".
             mtd_tll (str, optional): fit method for the transformation local-likelihood (TLL) nonparametric family, used only when ``mtd_kde="tll"``, one of ("constant", "linear", or "quadratic"). Defaults to "constant".
-            num_iter_max (int, optional): num of Sinkhorn/IPF iters for grid normalization, used only when ``mtd_kde="fastKDE"``. Defaults to 17.
+            num_iter_max (int, optional): num of Sinkhorn/IPF iters for grid normalization, used when ``mtd_kde`` is "fastKDE" or "torchKDE". Defaults to 17.
             is_tau_est (bool, optional): If True, compute and store Kendall's τ. Defaults to ``False``.
@@ -158,18 +209,21 @@ def fit(
         self.is_indep = False
         self.mtd_kde = mtd_kde
         self.num_obs.copy_(obs.shape[0])
-        # * assuming already in [0, 1]
-        obs = obs.clamp(min=0.0, max=1.0)
+
+        # clean the training points once (drop non-finite rows, clamp into [0, 1]²)
+        obs_clean = self._sanitize_train_obs(obs)
         if is_tau_est:
-            self.tau.copy_(
-                torch.as_tensor(
-                    kendalltau(obs[:, 0].cpu(), obs[:, 1].cpu()),
-                    device=device,
-                    dtype=dtype,
-                )
-            )
-        self._target = self.num_step_grid - 1.0  # * marginal target
+            if obs_clean.numel() == 0:
+                self.tau.zero_()
+            else:
+                self.tau.copy_(
+                    torch.as_tensor(
+                        kendalltau(obs_clean[:, 0].cpu(), obs_clean[:, 1].cpu()),
+                        device=self.device,
+                        dtype=self.dtype,
+                    )
+                )
+
+        # grid geometry
+        self._target = self.num_step_grid - 1.0  # * marginal target
         self.step_grid = 1.0 / self._target
+
+        # ! pdf
         if mtd_kde == "tll":
             controls = pv.FitControlsBicop(
@@ -177,7 +231,7 @@ def fit(
                 num_threads=torch.get_num_threads(),
                 nonparametric_method=mtd_tll,
             )
-            cop = pv.Bicop.from_data(data=obs.cpu().numpy(), controls=controls)
+            cop = pv.Bicop.from_data(data=obs_clean.cpu().numpy(), controls=controls)
             axis = torch.linspace(
                 _EPS,
                 1.0 - _EPS,
@@ -191,13 +245,13 @@ def fit(
                 .to(device=device, dtype=dtype)
             )
         elif mtd_kde == "fastKDE":
-            pdf_grid = torch.from_numpy(
-                fkpdf(
-                    obs[:, 0].cpu(),
-                    obs[:, 1].cpu(),
-                    num_points=self.num_step_grid * 2 + 1,
-                ).values
-            ).to(device=device, dtype=dtype)
+            if fkpdf is None:
+                raise ImportError("mtd_kde='fastKDE' requires the optional `fastkde` package")
+            if obs_clean.numel() == 0:
+                pdf_grid = torch.ones(
+                    self.num_step_grid, self.num_step_grid, dtype=dtype, device=device
+                )
+            else:
+                pdf_grid = torch.from_numpy(
+                    fkpdf(
+                        obs_clean[:, 0].cpu(),
+                        obs_clean[:, 1].cpu(),
+                        num_points=self.num_step_grid * 2 + 1,
+                    ).values
+                ).to(device=device, dtype=dtype)
             # * padding/trimming after fastkde.pdf
             H, W = pdf_grid.shape
             if H < self.num_step_grid:
@@ -224,6 +278,75 @@ def fit(
             pdf_grid *= self._target / pdf_grid.sum(dim=0, keepdim=True)
             pdf_grid *= self._target / pdf_grid.sum(dim=1, keepdim=True)
             pdf_grid /= pdf_grid.sum() * self.step_grid**2
+
+        elif mtd_kde == "torchKDE":
+            if obs_clean.numel() == 0:
+                # independent fallback (finite, valid grids)
+                pdf_grid = torch.ones(
+                    self.num_step_grid, self.num_step_grid, dtype=dtype, device=device
+                )
+            else:
+                U = obs_clean[:, 0].to(torch.float64)
+                V = obs_clean[:, 1].to(torch.float64)
+                n = U.numel()
+                G = int(self.num_step_grid)
+                tgt = float(self._target)
+                dx = float(self.step_grid)
+
+                # per-axis bandwidths (fall back to Scott's rule)
+                try:
+                    hx = float(optimal_bandwidth(U, method="auto"))
+                    hy = float(optimal_bandwidth(V, method="auto"))
+                except Exception:
+                    sf = n ** (-1.0 / 6.0)
+                    hx = float(sf * U.std(unbiased=True).clamp_min(_EPS))
+                    hy = float(sf * V.std(unbiased=True).clamp_min(_EPS))
+
+                # avoid kernels narrower than the grid spacing
+                hx = max(hx, 0.5 * dx)
+                hy = max(hy, 0.5 * dx)
+
+                # 2D histogram on the unit-square grid
+                ix = (U / dx).floor().clamp_(0, G - 1).to(torch.int64)
+                iy = (V / dx).floor().clamp_(0, G - 1).to(torch.int64)
+                lin = ix * G + iy
+                counts = torch.zeros(G * G, dtype=torch.float64, device=device)
+                counts.scatter_add_(0, lin, torch.ones_like(lin, dtype=torch.float64))
+                counts = counts.view(G, G)
+
+                # separable Gaussian kernels (truncated at 4σ)
+                rx = max(1, int(math.ceil(4.0 * (hx / dx))))
+                ry = max(1, int(math.ceil(4.0 * (hy / dx))))
+                ox = torch.arange(-rx, rx + 1, dtype=torch.float64, device=device) * dx
+                oy = torch.arange(-ry, ry + 1, dtype=torch.float64, device=device) * dx
+                kx = torch.exp(-0.5 * (ox / hx) ** 2) / (hx * math.sqrt(2.0 * math.pi))
+                ky = torch.exp(-0.5 * (oy / hy) ** 2) / (hy * math.sqrt(2.0 * math.pi))
+
+                # normalize the discrete kernels so their (Riemann) sums are 1
+                kx = kx / (kx.sum().clamp_min(_EPS) * dx)
+                ky = ky / (ky.sum().clamp_min(_EPS) * dx)
+
+                s = counts.unsqueeze(0).unsqueeze(0)
+                s = F.pad(s, (ry, ry, rx, rx), mode="reflect")  # (left, right, top, bottom)
+                s = F.conv2d(s, kx.view(1, 1, -1, 1))
+                s = F.conv2d(s, ky.view(1, 1, 1, -1))
+                pdf_grid = (s.squeeze(0).squeeze(0) / max(n, 1)).clamp_min(0.0)
+
+                # small positive floor before IPF to avoid stuck-zero rows/cols
+                pdf_grid = pdf_grid.clamp_min(self._EPS)
+
+                # IPF to enforce discrete uniform marginals on the copula grid
+                for _ in range(int(num_iter_max)):
+                    pdf_grid *= tgt / pdf_grid.sum(dim=0, keepdim=True).clamp_min(_EPS)
+                    pdf_grid *= tgt / pdf_grid.sum(dim=1, keepdim=True).clamp_min(_EPS)
+
+                # normalize to integrate to 1 on [0, 1]² (discrete dx²)
+                pdf_grid /= pdf_grid.sum().clamp_min(_EPS) * (dx * dx)
+                # keep strictly positive so log_pdf on finite rows stays finite later
+                pdf_grid = pdf_grid.clamp_min(self._EPS)
+                pdf_grid /= pdf_grid.sum().clamp_min(_EPS) * (dx * dx)
+
+            # guarantee the grid is finite
+            pdf_grid = pdf_grid.nan_to_num(0.0, posinf=0.0, neginf=0.0).to(device=device, dtype=dtype)
+
         else:
             raise NotImplementedError
         self._pdf_grid = pdf_grid
@@ -275,11 +398,13 @@ def cdf(self, obs: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: CDF values at each observation, shape (n,1).
         """
-        # ! device agnostic
-        obs = obs.to(device=self.device, dtype=self.dtype)
+        # independent copula: C(u,v) = u·v; mask so NaN rows remain NaN
         if self.is_indep:
-            return obs.prod(dim=1, keepdim=True)
-        return self._interp(grid=self._cdf_grid, obs=obs).unsqueeze(dim=1)
+            return self._eval_on_finite(
+                obs, lambda z: z[:, [0]] * z[:, [1]], clamp_inputs=False
+            ).clamp_(0.0, 1.0)
+
+        # grid-based evaluation for all KDE modes
+        val = self._eval_on_finite(obs, lambda z: self._interp(self._cdf_grid, z).unsqueeze(1))
+        return val.clamp_(0.0, 1.0)

     def hfunc_l(self, obs: torch.Tensor) -> torch.Tensor:
         """Evaluate the left h-function at given points. Computes H(u₂ | u₁) := ∂/∂u₁ C(u₁,u₂) for
@@ -291,11 +416,12 @@ def hfunc_l(self, obs: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Left h-function values at each observation, shape (n,1).
         """
-        # ! device agnostic
-        obs = obs.to(device=self.device, dtype=self.dtype)
+        # independent copula: h_l(v|u) = v
         if self.is_indep:
-            return obs[:, [1]]
-        return self._interp(grid=self._hfunc_l_grid, obs=obs).unsqueeze(dim=1)
+            return self._eval_on_finite(obs, lambda z: z[:, [1]], clamp_inputs=False)
+
+        val = self._eval_on_finite(obs, lambda z: self._interp(self._hfunc_l_grid, z).unsqueeze(1))
+        return val.clamp_(0.0, 1.0)

     def hfunc_r(self, obs: torch.Tensor) -> torch.Tensor:
         """Evaluate the right h-function at given points. Computes H(u₁ | u₂) := ∂/∂u₂ C(u₁,u₂) for
@@ -307,87 +433,160 @@ def hfunc_r(self, obs: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Right h-function values at each observation, shape (n,1).
         """
-        # ! device agnostic
-        obs = obs.to(device=self.device, dtype=self.dtype)
+        # independent copula: h_r(u|v) = u
         if self.is_indep:
-            return obs[:, [0]]
-        return self._interp(grid=self._hfunc_r_grid, obs=obs).unsqueeze(dim=1)
+            return self._eval_on_finite(obs, lambda z: z[:, [0]], clamp_inputs=False)
+
+        val = self._eval_on_finite(obs, lambda z: self._interp(self._hfunc_r_grid, z).unsqueeze(1))
+        return val.clamp_(0.0, 1.0)
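The NaN contract implemented by `_eval_on_finite` (and used by `cdf`/`hfunc_l`/`hfunc_r` above) can be seen end-to-end with a small sketch using only the public API (hypothetical data):

```python
import torch
import torchvinecopulib as tvc

cop = tvc.BiCop(num_step_grid=64)
cop.fit(torch.rand(500, 2, dtype=torch.float64), mtd_kde="torchKDE")
Q = torch.tensor([[0.2, 0.7], [float("nan"), 0.3]], dtype=torch.float64)
print(cop.cdf(Q))      # row 0: finite value in [0, 1]; row 1: NaN by design
print(cop.hfunc_l(Q))  # the h-functions follow the same contract
```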
     @torch.no_grad()
     def hinv_l(self, obs: torch.Tensor) -> torch.Tensor:
-        """Invert the left h-function via root-finding: find u₂ given (u₁, p). Solves H(u₂ | u₁) = p
-        by ITP between 0 and 1.
+        """Invert the left h-function: solve for u₂ such that H_l(u₂ | u₁) = p.

         Args:
-            obs (torch.Tensor): Points in [0,1]² where to evaluate the left h-function (rows are (u₁,u₂)), shape (n,2).
-
+            obs (torch.Tensor): Rows are (u₁, p), nominally in [0,1]², shape (n,2).
         Returns:
-            torch.Tensor: Solutions u₂ ∈ [0,1], shape (n,1).
+            torch.Tensor: u₂ ∈ [0,1], shape (n,1). Non-finite input rows yield NaN.
         """
-        # ! device agnostic
+        # normalize device/dtype once
         obs = obs.to(device=self.device, dtype=self.dtype)
+        # ensure a 2D shape
+        if obs.ndim == 1:
+            obs = obs.view(-1, 2)
+
+        # independent copula: H_l(v|u) = v ⇒ u₂ = p
         if self.is_indep:
             return obs[:, [1]]
-        # * via root-finding
-        u_l = obs[:, [0]]
-        p = obs[:, [1]]
-        return solve_ITP(
-            fun=lambda u_r: self.hfunc_l(obs=torch.hstack([u_l, u_r])) - p,
-            x_a=torch.zeros_like(p),
-            x_b=torch.ones_like(p),
-        ).clamp(min=0.0, max=1.0)
+
+        n = obs.shape[0]
+        out = torch.full((n, 1), float("nan"), dtype=self.dtype, device=self.device)
+
+        # keep only fully finite rows; the others remain NaN
+        mask = torch.isfinite(obs).all(dim=1)
+        if not mask.any():
+            return out
+
+        z = obs[mask]
+        # known u₁ in [ε, 1−ε] to stay on-grid; target probability p ∈ [0, 1]
+        u1 = z[:, [0]].clamp(self._EPS, 1.0 - self._EPS)
+        p = z[:, [1]].clamp(0.0, 1.0)
+
+        # fast paths for edge probabilities
+        u2 = torch.empty_like(p)
+        # 1-D masks over the K finite rows
+        edge0 = (p <= self._EPS).squeeze(1)  # (K,)
+        edge1 = (p >= 1.0 - self._EPS).squeeze(1)  # (K,)
+        u2[edge0, 0] = 0.0
+        u2[edge1, 0] = 1.0
+
+        mid = ~(edge0 | edge1)  # (K,)
+        if mid.any():
+            u1m = u1[mid, :]  # (Km, 1)
+            pm = p[mid, :]  # (Km, 1)
+            sol = solve_ITP(
+                fun=lambda u2m: self.hfunc_l(obs=torch.hstack([u1m, u2m])) - pm,
+                x_a=torch.zeros_like(pm),
+                x_b=torch.ones_like(pm),
+            )
+            u2[mid, 0] = sol.view(-1)  # assign as a flat column
+
+        out[mask] = u2.clamp_(0.0, 1.0)
+        return out

     @torch.no_grad()
     def hinv_r(self, obs: torch.Tensor) -> torch.Tensor:
-        """Invert the right h-function via root-finding: find u₁ given (u₂, p). Solves H(u₁ | u₂) =
-        p by ITP between 0 and 1.
+        """Invert the right h-function: solve for u₁ such that H_r(u₁ | u₂) = p.

         Args:
-            obs (torch.Tensor): Points in [0,1]² where to evaluate the right h-function (rows are (u₁,u₂)), shape (n,2).
+            obs (torch.Tensor): Rows are (p, u₂): obs[:, 0] is the target probability and
+                obs[:, 1] the conditioning value, nominally in [0,1]², shape (n,2).
         Returns:
-            torch.Tensor: Solutions u₁ ∈ [0,1], shape (n,1).
+            torch.Tensor: u₁ ∈ [0,1], shape (n,1). Non-finite input rows yield NaN.
         """
-        # ! device agnostic
+        # normalize device/dtype once
         obs = obs.to(device=self.device, dtype=self.dtype)
+        # ensure a 2D shape
+        if obs.ndim == 1:
+            obs = obs.view(-1, 2)
+
+        # independent copula: H_r(u|v) = u ⇒ u₁ = p
         if self.is_indep:
             return obs[:, [0]]
-        # * via root-finding
-        u_r = obs[:, [1]]
-        p = obs[:, [0]]
-        return solve_ITP(
-            fun=lambda u_l: self.hfunc_r(obs=torch.hstack([u_l, u_r])) - p,
-            x_a=torch.zeros_like(p),
-            x_b=torch.ones_like(p),
-        ).clamp(min=0.0, max=1.0)
+
+        n = obs.shape[0]
+        out = torch.full((n, 1), float("nan"), dtype=self.dtype, device=self.device)
+
+        # keep only fully finite rows; the others remain NaN
+        mask = torch.isfinite(obs).all(dim=1)
+        if not mask.any():
+            return out
+
+        z = obs[mask]
+        # known u₂ in [ε, 1−ε] to stay on-grid; target probability p ∈ [0, 1]
+        p = z[:, [0]].clamp(0.0, 1.0)
+        u2 = z[:, [1]].clamp(self._EPS, 1.0 - self._EPS)
+
+        # fast paths for edge probabilities
+        u1 = torch.empty_like(p)
+        # 1-D masks over the K finite rows
+        edge0 = (p <= self._EPS).squeeze(1)  # (K,)
+        edge1 = (p >= 1.0 - self._EPS).squeeze(1)  # (K,)
+        u1[edge0, 0] = 0.0
+        u1[edge1, 0] = 1.0
+
+        mid = ~(edge0 | edge1)  # (K,)
+        if mid.any():
+            pm = p[mid, :]  # (Km, 1)
+            u2m = u2[mid, :]  # (Km, 1)
+            sol = solve_ITP(
+                fun=lambda u1m: self.hfunc_r(obs=torch.hstack([u1m, u2m])) - pm,
+                x_a=torch.zeros_like(pm),
+                x_b=torch.ones_like(pm),
+            )
+            u1[mid, 0] = sol.view(-1)
+
+        out[mask] = u1.clamp_(0.0, 1.0)
+        return out

     def pdf(self, obs: torch.Tensor) -> torch.Tensor:
-        """Evaluate the copula PDF at given points. For independent copula, returns 1.
+        """Evaluate the copula PDF at given points. For the independent copula, returns 1 on
+        finite rows and leaves non-finite rows as NaN (via ``_eval_on_finite``).

         Args:
             obs (torch.Tensor): Points in [0,1]² where to evaluate the PDF (rows are (u₁,u₂)), shape (n,2).
         Returns:
             torch.Tensor: PDF values at each observation, shape (n,1).
         """
-        # ! device agnostic
-        obs = obs.to(device=self.device, dtype=self.dtype)
         if self.is_indep:
-            return torch.ones_like(obs[:, [0]])
-        return self._interp(grid=self._pdf_grid, obs=obs).unsqueeze(dim=1)
+            # finite rows -> 1; non-finite rows stay NaN
+            return self._eval_on_finite(
+                obs,
+                lambda z: torch.ones((z.shape[0], 1), dtype=self.dtype, device=self.device),
+                clamp_inputs=False,  # the independent copula needs no clamping
+            )
+
+        # nonparametric modes: interpolate and keep strictly positive on finite rows
+        return self._eval_on_finite(
+            obs,
+            # clamp_min applies only to finite rows because it sits inside eval_fn
+            lambda z: self._interp(self._pdf_grid, z).unsqueeze(1).clamp_min(self._EPS),
+        )
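`hinv_l`/`hinv_r` above delegate the actual root-finding to `solve_ITP`; the underlying idea is plain monotone inversion, which a bisection stand-in illustrates (sketch only; `solve_ITP` itself converges faster):

```python
import torch

def invert(h, p, iters=60):
    """Solve h(u) = p for u in [0, 1] by bisection; h must be increasing."""
    lo, hi = torch.zeros_like(p), torch.ones_like(p)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        gt = h(mid) > p  # overshoot -> the root lies below mid
        hi = torch.where(gt, mid, hi)
        lo = torch.where(gt, lo, mid)
    return 0.5 * (lo + hi)

p = torch.tensor([0.25, 0.5, 0.9], dtype=torch.float64)
print(invert(lambda u: u**2, p))  # ≈ sqrt(p), since h(u) = u² is increasing
```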
+ """ + p = self.pdf(obs) # shape (N,1), NaN on bad rows via _eval_on_finite + logp = torch.log(p) # NaN stays NaN, zeros -> -inf + # Only replace -inf (coming from zeros) with ln(LOG_EPS), leave everything else as-is + neg_inf = torch.isinf(logp) & (logp < 0) + if neg_inf.any(): + logp = logp.clone() + logp[neg_inf] = math.log(self._LOG_EPS) + return logp + @torch.no_grad() def sample( diff --git a/torchvinecopulib/util/__init__.py b/torchvinecopulib/util/__init__.py index b73d203..be28169 100644 --- a/torchvinecopulib/util/__init__.py +++ b/torchvinecopulib/util/__init__.py @@ -22,11 +22,14 @@ import enum from pprint import pformat -import fastkde +import math +# import fastkde import torch +import torch.nn.functional as F from scipy.stats import kendalltau -_EPS = 1e-10 +from .constants import _EPS +from .bandwidth import * @torch.no_grad() @@ -46,32 +49,156 @@ def kendall_tau(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: device=x.device, ) +def _auto_grid_size(range_len: float, h: float, cells_per_sigma: int, gmin: int, gmax: int) -> int: + if h <= 0.0 or not math.isfinite(h): # fallback + return max(gmin, 256) + est = int(math.ceil(cells_per_sigma * max(range_len, 1e-12) / h)) + return int(max(gmin, min(gmax, est))) @torch.no_grad() -def mutual_info(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Estimate mutual information using ``fastKDE``. Moves inputs to CPU and delegates to - ``fastKDE.pdf``. - - - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160. - - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234. - - Purkayastha, S., & Song, P. X. K. (2024). fastMI: A fast and consistent copula-based nonparametric estimator of mutual information. Journal of Multivariate Analysis, 201, 105270. +@torch.no_grad() +def mutual_info( + x: torch.Tensor, y: torch.Tensor, + hxy: tuple[float, float] | None = None, + cells_per_sigma: int = 10, + grid_min: int = 128, + grid_max: int = 2048, + bandwidth_method: str = "auto", + bandwidth_kwargs: dict | None = None, +) -> torch.Tensor: + """ + Estimate mutual information I(X;Y) via Torch-only 2D KDE on an adaptive grid. Args: - x (torch.Tensor): shape (n, 1) - y (torch.Tensor): shape (n, 1) + x (torch.Tensor): 1-D samples for X. Non-finite pairs (with y) are dropped. + y (torch.Tensor): 1-D samples for Y. Non-finite pairs (with x) are dropped. + hxy (tuple[float, float] | None, optional): (hx, hy). If None, choose per-axis + via `optimal_bandwidth`; falls back to Scott’s rule if needed. + cells_per_sigma (int, optional): Target grid resolution per Gaussian σ. Default 10. + grid_min (int, optional): Minimum cells per axis. Default 128. + grid_max (int, optional): Maximum cells per axis. Default 2048. + bandwidth_method (str, optional): Bandwidth selector for each axis, e.g. "auto", + "isj", or "kfold". Default "auto". + bandwidth_kwargs (dict | None, optional): Extra kwargs passed to the selector. + Returns: - torch.Tensor: Estimated mutual information + torch.Tensor: Scalar 0-D tensor (same dtype/device as `x`) with the MI estimate. 
""" - x = x.clamp(_EPS, 1.0 - _EPS).view(-1).cpu() - y = y.clamp(_EPS, 1.0 - _EPS).view(-1).cpu() - joint = torch.as_tensor(fastkde.pdf(x, y).values, dtype=x.dtype, device=x.device) - margin_x = torch.as_tensor(fastkde.pdf(x).values, dtype=x.dtype, device=x.device) - margin_y = torch.as_tensor(fastkde.pdf(y).values, dtype=x.dtype, device=x.device) - return ( - joint[joint > 0.0].log().mean() - - margin_x[margin_x > 0.0].log().mean() - - margin_y[margin_y > 0.0].log().mean() - ) + # keep caller dtype/device for the return + out_dtype, out_device = x.dtype, x.device + + # 0) sanitize: keep only finite PAIRS + x = x.view(-1).to(torch.float64) + y = y.view(-1).to(torch.float64) + m = torch.isfinite(x) & torch.isfinite(y) + x, y = x[m], y[m] + n = x.numel() + if n < 2: + return torch.tensor(0.0, dtype=out_dtype, device=out_device) + + # 1) padded ranges (no NaNs now) + x_lo, x_hi = x.min().item(), x.max().item() + y_lo, y_hi = y.min().item(), y.max().item() + rx = max(x_hi - x_lo, 1e-12); ry = max(y_hi - y_lo, 1e-12) + pad_x, pad_y = 0.1 * rx, 0.1 * ry + x_min, x_max = x_lo - pad_x, x_hi + pad_x + y_min, y_max = y_lo - pad_y, y_hi + pad_y + rxp, ryp = (x_max - x_min), (y_max - y_min) + + # 2) bandwidths + if hxy is None: + try: + hx = float(optimal_bandwidth(x, method=bandwidth_method, **(bandwidth_kwargs or {}))) + hy = float(optimal_bandwidth(y, method=bandwidth_method, **(bandwidth_kwargs or {}))) + except Exception: + sf = n ** (-1.0 / 6.0) + hx = float(sf * x.std(unbiased=True).clamp_min(1e-12)) + hy = float(sf * y.std(unbiased=True).clamp_min(1e-12)) + else: + hx, hy = map(float, hxy) + if not (math.isfinite(hx) and math.isfinite(hy)) or hx <= 0 or hy <= 0: + # fallback if user passed junk + sf = n ** (-1.0 / 6.0) + hx = float(sf * x.std(unbiased=True).clamp_min(1e-12)) + hy = float(sf * y.std(unbiased=True).clamp_min(1e-12)) + + # 3) adaptive grid sizes + def _auto_grid_size(range_len: float, h: float) -> int: + if h <= 0.0 or not math.isfinite(h): + return max(grid_min, 256) + est = int(math.ceil(cells_per_sigma * max(range_len, 1e-12) / h)) + return int(max(grid_min, min(grid_max, est))) + + nx = _auto_grid_size(rxp, hx) + ny = _auto_grid_size(ryp, hy) + + gx = torch.linspace(x_min, x_max, nx, dtype=torch.float64) + gy = torch.linspace(y_min, y_max, ny, dtype=torch.float64) + dx = float(gx[1] - gx[0]) + dy = float(gy[1] - gy[0]) + + # avoid undersmoothing relative to grid + hx = max(hx, 0.5 * dx); hy = max(hy, 0.5 * dy) + + # 4) histogram with indices clamped in-range + ix = ((x - x_min) / dx).floor().clamp(0, nx - 1).to(torch.long) + iy = ((y - y_min) / dy).floor().clamp(0, ny - 1).to(torch.long) + lin = ix * ny + iy + counts = torch.zeros(nx * ny, dtype=torch.float64) + counts.scatter_add_(0, lin, torch.ones_like(lin, dtype=torch.float64)) + counts = counts.view(nx, ny) + + # 5) separable Gaussian smoothing (truncate at 4σ) + rxk = max(1, int(math.ceil(4.0 * (hx / dx)))) + ryk = max(1, int(math.ceil(4.0 * (hy / dy)))) + ox = torch.arange(-rxk, rxk + 1, dtype=torch.float64) * dx + oy = torch.arange(-ryk, ryk + 1, dtype=torch.float64) * dy + kx = torch.exp(-0.5 * (ox / hx) ** 2) / (hx * math.sqrt(2.0 * math.pi)) + ky = torch.exp(-0.5 * (oy / hy) ** 2) / (hy * math.sqrt(2.0 * math.pi)) + + s = counts.unsqueeze(0).unsqueeze(0) # [1,1,nx,ny] + s = F.conv2d(s, kx.view(1, 1, -1, 1), padding=(rxk, 0)) + s = F.conv2d(s, ky.view(1, 1, 1, -1), padding=(0, ryk)) + pdf = (s.squeeze(0).squeeze(0) / n).clamp_min(1e-12) + + # 6) MI from normalized grid probabilities + p = (pdf * dx * 
dy).clamp_min(1e-12) + p = p / p.sum() + px = p.sum(dim=1, keepdim=True).clamp_min(1e-12) + py = p.sum(dim=0, keepdim=True).clamp_min(1e-12) + mi = (p * (p.log() - px.log() - py.log())).sum() + + return mi.to(dtype=out_dtype, device=out_device) + + +############################################## +# previous fastKDE version +############################################## +# @torch.no_grad() +# def mutual_info(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: +# """Estimate mutual information using ``fastKDE``. Moves inputs to CPU and delegates to +# ``fastKDE.pdf``. + +# - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160. +# - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234. +# - Purkayastha, S., & Song, P. X. K. (2024). fastMI: A fast and consistent copula-based nonparametric estimator of mutual information. Journal of Multivariate Analysis, 201, 105270. + +# Args: +# x (torch.Tensor): shape (n, 1) +# y (torch.Tensor): shape (n, 1) +# Returns: +# torch.Tensor: Estimated mutual information +# """ +# x = x.clamp(_EPS, 1.0 - _EPS).view(-1).cpu() +# y = y.clamp(_EPS, 1.0 - _EPS).view(-1).cpu() +# joint = torch.as_tensor(fastkde.pdf(x, y).values, dtype=x.dtype, device=x.device) +# margin_x = torch.as_tensor(fastkde.pdf(x).values, dtype=x.dtype, device=x.device) +# margin_y = torch.as_tensor(fastkde.pdf(y).values, dtype=x.dtype, device=x.device) +# return ( +# joint[joint > 0.0].log().mean() +# - margin_x[margin_x > 0.0].log().mean() +# - margin_y[margin_y > 0.0].log().mean() +# ) @torch.no_grad() @@ -153,7 +280,7 @@ def __call__(self, x: torch.Tensor, y: torch.Tensor, **kw): class kdeCDFPPF1D(torch.nn.Module): - _EPS = _EPS + _EPS = _EPS # keep your constant def __init__( self, @@ -162,52 +289,91 @@ def __init__( x_min: float = None, x_max: float = None, pad: float = 0.1, + h: torch.Tensor | float = None, # optional: allow passing a bandwidth + bandwidth_method: str = "auto", + bandwidth_kwargs: dict | None = None, ): - """1D KDE CDF/PPF using ``fastKDE`` + Simpson's rule. Given a sample ``x``, fits a kernel - density estimate via ``fastKDE`` on a grid of size ``num_step_grid`` (power of two plus - one). Precomputes PDF, CDF, and their finite‐difference slopes for fast interpolation. - - - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160. - - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234. - - Args: - x (torch.Tensor): input sample to fit the KDE. - num_step_grid (int, optional): number of grid points for the KDE, should be power of 2 plus 1. Defaults to None. - x_min (float, optional): minimum value of the grid. Defaults to x.min() - pad. - x_max (float, optional): maximum value of the grid. Defaults to x.max() + pad. - pad (float, optional): padding to extend beyond the min/max when ``x_min``/``x_max`` is None. Defaults to 1.0. 
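The reworked `mutual_info` keeps the same call shape as the fastKDE version while running entirely in Torch; a quick smoke test with placeholder Gaussian data:

```python
import torch
from torchvinecopulib.util import mutual_info

x = torch.randn(4000, dtype=torch.float64)
y = x + 0.5 * torch.randn_like(x)                  # strongly dependent pair
print(float(mutual_info(x, y)))                    # clearly positive
print(float(mutual_info(x, torch.randn_like(x))))  # near zero for independent draws
```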
+ """ + 1D KDE CDF/PPF using pure Torch: + - Bin samples on an equispaced grid + - Smooth counts with a Gaussian kernel via conv1d (zero-padded => no wraparound) + - Normalize to get PDF, integrate to get CDF """ super().__init__() - self.num_obs = x.shape[0] - self.x_min = x_min if x_min is not None else x.min().item() - pad - self.x_max = x_max if x_max is not None else x.max().item() + pad - # * power of 2 plus 1 + + # Work in float64 for stability; buffers carry device/dtype + x = x.view(-1).to(dtype=torch.float64) + x = x[torch.isfinite(x)] + self.num_obs = x.numel() + + # Domain & grid + x_lo = x.min().item() + x_hi = x.max().item() + self.x_min = float(x_min) if x_min is not None else (x_lo - pad * (x_hi - x_lo + 1e-12)) + self.x_max = float(x_max) if x_max is not None else (x_hi + pad * (x_hi - x_lo + 1e-12)) if num_step_grid is None: - num_step_grid = int(2 ** torch.log2(torch.tensor(x.numel())).ceil().item()) + 1 - self.num_step_grid = num_step_grid - # * fastkde - res = fastkde.pdf(x.view(-1).cpu().numpy(), num_points=num_step_grid) - xs = torch.from_numpy(res.var0.values).to(dtype=torch.float64) - pdfs = torch.from_numpy(res.values).to(dtype=torch.float64).clamp_min(self._EPS) - N = pdfs.shape[0] - ws = torch.ones(N, dtype=torch.float64) - ws[1:-1:2] = 4 - ws[2:-1:2] = 2 - h = xs[1] - xs[0] - cdf = torch.cumsum(pdfs * ws, dim=0) * (h / 3) - cdf = cdf / cdf[-1] - slope_fwd = (cdf[1:] - cdf[:-1]) / h - slope_inv = h / (cdf[1:] - cdf[:-1]) - slope_pdf = (pdfs[1:] - pdfs[:-1]) / h + # power-of-two-ish for nice conv/cache; "+1" keeps your original spirit + pow2 = 1 << (int(max(16, self.num_obs)).bit_length()) # >=16 + num_step_grid = int(pow2 + 1) + self.num_step_grid = int(num_step_grid) + + xs = torch.linspace(self.x_min, self.x_max, self.num_step_grid, dtype=torch.float64) + dx = xs[1] - xs[0] + + # Bandwidth (default: Silverman) + if h is None: + try: + h = optimal_bandwidth( + x, method=bandwidth_method, **(bandwidth_kwargs or {}) + ) + except Exception: + # fallback to Silverman if selector not available + std = x.std(unbiased=True).clamp_min(1e-12) + h = 1.06 * std * (self.num_obs ** (-1.0 / 5.0)) + + h = torch.as_tensor(h, dtype=torch.float64) + + # Bin to nearest grid index + idx_f = ((x - self.x_min) / dx).round() + idx = idx_f.clamp(0, self.num_step_grid - 1).to(torch.long) + counts = torch.zeros(self.num_step_grid, dtype=torch.float64).scatter_add_( + 0, idx, torch.ones_like(idx, dtype=torch.float64) + ) + + # Build discrete Gaussian kernel (truncate at 4 std devs) + rad = int(math.ceil(4.0 * (h / dx).item())) + offsets = torch.arange(-rad, rad + 1, dtype=torch.float64) * dx + kernel = torch.exp(-0.5 * (offsets / h).pow(2)) / (h * math.sqrt(2.0 * math.pi)) + kernel = kernel.unsqueeze(0).unsqueeze(0) # [1,1,kw] + signal = counts.unsqueeze(0).unsqueeze(0) # [1,1,N] + + # Linear conv with zero padding (no wrap) + pdf = F.conv1d(signal, kernel, padding=rad).squeeze() / self.num_obs # shape [N] + pdf = pdf.clamp_min(self._EPS) + + # CDF by trapezoid rule + renorm (monotone guard) + cdf = torch.empty_like(pdf) + cdf[0] = 0.0 + cdf[1:] = torch.cumsum(0.5 * (pdf[:-1] + pdf[1:]) * dx, dim=0) + cdf = (cdf / cdf[-1]).clamp(0.0, 1.0) + # enforce monotonicity in case of tiny numeric dips + cdf = torch.cummax(cdf, dim=0).values + + # Precompute slopes for fast interpolation + slope_fwd = (cdf[1:] - cdf[:-1]) / dx + slope_inv = dx / (cdf[1:] - cdf[:-1]).clamp_min(self._EPS) + slope_pdf = (pdf[1:] - pdf[:-1]) / dx + + # Register buffers for device/dtype agnosticism 
self.register_buffer("grid_x", xs) - self.register_buffer("grid_pdf", pdfs) + self.register_buffer("grid_pdf", pdf) self.register_buffer("grid_cdf", cdf) self.register_buffer("slope_fwd", slope_fwd) self.register_buffer("slope_inv", slope_inv) self.register_buffer("slope_pdf", slope_pdf) - self.h = h - # ! device agnostic - self.register_buffer("_dd", torch.tensor([], dtype=torch.float64)) + self.h = float(h.item()) + self.register_buffer("_dd", torch.tensor([], dtype=torch.float64)) # device anchor + self.negloglik = -self.log_pdf(x).mean() @property @@ -219,61 +385,72 @@ def dtype(self): return self._dd.dtype def cdf(self, x: torch.Tensor) -> torch.Tensor: - """Compute the CDF of the fitted KDE at ``x``. + """ + KDE CDF via piecewise-linear interpolation on the precomputed 1-D grid. Args: - x (torch.Tensor): Points at which to evaluate the CDF. + x (torch.Tensor): Query points. Non-finite entries return NaN in-place. + Returns: - torch.Tensor: CDF values at ``x``, clamped to [0, 1]. + torch.Tensor: CDF values in [0,1], same shape/dtype/device as `x`. """ - # ! device agnostic + x = x.to(device=self.device, dtype=self.dtype) - x_clamped = x.clamp(self.x_min, self.x_max) - idx = torch.searchsorted(self.grid_x, x_clamped, right=False) - idx = idx.clamp(1, self.grid_cdf.numel() - 1) - y = (self.grid_cdf[idx - 1]) + (self.slope_fwd[idx - 1]) * ( - x_clamped - self.grid_x[idx - 1] - ) - y = torch.where(x < self.x_min, torch.zeros_like(y), y) - y = torch.where(x > self.x_max, torch.ones_like(y), y) - return y.clamp(0.0, 1.0) + out = torch.full_like(x, float("nan")) + mask = torch.isfinite(x) + if mask.any(): + xc = x[mask].clamp(self.x_min, self.x_max) + idx = torch.searchsorted(self.grid_x, xc, right=False).clamp(1, self.grid_cdf.numel() - 1) + y = self.grid_cdf[idx - 1] + self.slope_fwd[idx - 1] * (xc - self.grid_x[idx - 1]) + y = torch.where(x[mask] < self.x_min, torch.zeros_like(y), y) + y = torch.where(x[mask] > self.x_max, torch.ones_like(y), y) + out[mask] = y + return out.clamp(0.0, 1.0) def ppf(self, q: torch.Tensor) -> torch.Tensor: - """Compute the PPF (quantile function) of the fitted KDE at ``q``. + """ + KDE percent-point function (quantile) from the precomputed monotone CDF grid. Args: - q (torch.Tensor): Quantiles at which to evaluate the PPF. + q (torch.Tensor): Probabilities. Non-finite entries return NaN in-place. + Returns: - torch.Tensor: PPF values at ``q``, clamped to [x_min, x_max]. + torch.Tensor: Quantiles in [x_min, x_max], same shape/dtype/device as `q`. """ - # ! 
device agnostic + q = q.to(device=self.device, dtype=self.dtype) - q_clamped = q.clamp(0.0, 1.0) - idx = torch.searchsorted(self.grid_cdf, q_clamped, right=False) - idx = idx.clamp(1, self.grid_cdf.numel() - 1) - x = (self.grid_x[idx - 1]) + (self.slope_inv[idx - 1]) * ( - q_clamped - self.grid_cdf[idx - 1] - ) - x = torch.where(q < 0.0, torch.full_like(x, self.x_min), x) - x = torch.where(q > 1.0, torch.full_like(x, self.x_max), x) - return x.clamp(self.x_min, self.x_max) + out = torch.full_like(q, float("nan")) + mask = torch.isfinite(q) + if mask.any(): + qc = q[mask].clamp(0.0, 1.0) + idx = torch.searchsorted(self.grid_cdf, qc, right=False).clamp(1, self.grid_cdf.numel() - 1) + x = self.grid_x[idx - 1] + self.slope_inv[idx - 1] * (qc - self.grid_cdf[idx - 1]) + x = torch.where(q[mask] < 0.0, torch.full_like(x, self.x_min), x) + x = torch.where(q[mask] > 1.0, torch.full_like(x, self.x_max), x) + out[mask] = x + return out.clamp(self.x_min, self.x_max) def pdf(self, x: torch.Tensor) -> torch.Tensor: - """Compute the PDF of the fitted KDE at ``x``. + """ + KDE PDF via piecewise-linear interpolation on the precomputed 1-D grid. Args: - x (torch.Tensor): Points at which to evaluate the PDF. + x (torch.Tensor): Query points. Non-finite entries return NaN in-place. + Returns: - torch.Tensor: PDF values at ``x``, clamped to [0, ∞). + torch.Tensor: PDF values (>=0), same shape/dtype/device as `x`. """ - # ! device agnostic + x = x.to(device=self.device, dtype=self.dtype) - x_clamped = x.clamp(self.x_min, self.x_max) - idx = torch.searchsorted(self.grid_x, x_clamped, right=False) - idx = idx.clamp(1, self.grid_pdf.numel() - 1) - f = self.grid_pdf[idx - 1] + (self.slope_pdf[idx - 1]) * (x_clamped - self.grid_x[idx - 1]) - f = torch.where((x < self.x_min) | (x > self.x_max), torch.zeros_like(f), f) - return f.clamp_min(0.0) + out = torch.full_like(x, float("nan")) + mask = torch.isfinite(x) + if mask.any(): + xc = x[mask].clamp(self.x_min, self.x_max) + idx = torch.searchsorted(self.grid_x, xc, right=False).clamp(1, self.grid_pdf.numel() - 1) + f = self.grid_pdf[idx - 1] + self.slope_pdf[idx - 1] * (xc - self.grid_x[idx - 1]) + f = torch.where((x[mask] < self.x_min) | (x[mask] > self.x_max), torch.zeros_like(f), f) + out[mask] = f + return out.clamp_min(0.0) def log_pdf(self, x: torch.Tensor) -> torch.Tensor: """Compute the log PDF of the fitted KDE at ``x``. @@ -295,6 +472,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ return -self.log_pdf(x).mean() + def __str__(self): """String representation of the ``kdeCDFPPF1D`` object. @@ -308,11 +486,177 @@ def __str__(self): "x_min": float(round(self.x_min, 4)), "x_max": float(round(self.x_max, 4)), "num_step_grid": int(self.num_step_grid), + "h": float(self.h), "dtype": self.dtype, "device": self.device, } - params_str = pformat(params, sort_dicts=False, underscore_numbers=True) - return f"{header}\n{params_str[1:-1]}\n\n" + from pprint import pformat + return f"{header}\n{pformat(params, sort_dicts=False, underscore_numbers=True)[1:-1]}\n\n" + +############################################## +# previous fastKDE version +############################################## +# class kdeCDFPPF1D(torch.nn.Module): +# _EPS = _EPS + +# def __init__( +# self, +# x: torch.Tensor, +# num_step_grid: int = None, +# x_min: float = None, +# x_max: float = None, +# pad: float = 0.1, +# ): +# """1D KDE CDF/PPF using ``fastKDE`` + Simpson's rule. 
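`kdeCDFPPF1D` keeps its fastKDE-era interface, so existing call sites only gain the optional bandwidth arguments; round-tripping CDF/PPF on standard-normal data is a quick check (sketch, using the export added in `torchvinecopulib/__init__.py`):

```python
import torch
from torchvinecopulib import kdeCDFPPF1D

x = torch.randn(5000, dtype=torch.float64)
kde = kdeCDFPPF1D(x)
q = kde.cdf(torch.tensor([0.0]))     # ≈ 0.5 for a centred sample
print(q, kde.ppf(q))                 # PPF(CDF(0)) ≈ 0
print(kde.pdf(torch.tensor([0.0])))  # ≈ 1/√(2π) ≈ 0.3989
```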

 # @torch.compile
diff --git a/torchvinecopulib/util/bandwidth.py b/torchvinecopulib/util/bandwidth.py
new file mode 100644
index 0000000..12f14ca
--- /dev/null
+++ b/torchvinecopulib/util/bandwidth.py
@@ -0,0 +1,182 @@
+"""
+bandwidth.py
+
+Pure-torch bandwidth selectors:
+- isj_bandwidth(x): Improved Sheather–Jones plug-in
+- kfold_lcv_bandwidth(x, folds): k-fold likelihood cross-validation
+- optimal_bandwidth(x, method): dispatcher over the two, with Silverman fallback
+"""
+
+import math
+
+import torch
+
+from .constants import (
+    _DEFAULT_CV_FOLDS,
+    _ISJ_GRID_SIZE,
+    _ISJ_MAX_ITER,
+    _ISJ_TOL,
+    _KFOLD_LSCV_GRID_SIZE,
+)
+
+
+# Improved Sheather–Jones (ISJ)
+
+
+def _dct2_ortho(v: torch.Tensor) -> torch.Tensor:
+    """DCT-II with 'ortho' normalization via FFT; torch ships no native DCT."""
+    n = v.shape[-1]
+    # Makhoul (1980) reordering: even entries, then reversed odd entries
+    u = torch.cat([v[..., ::2], v[..., 1::2].flip(-1)], dim=-1)
+    U = torch.fft.fft(u)
+    k = torch.arange(n, dtype=v.dtype, device=v.device)
+    w = 2.0 * torch.exp(-1j * math.pi * k / (2.0 * n))
+    out = (w * U).real
+    out[..., 0] *= math.sqrt(1.0 / (4.0 * n))
+    out[..., 1:] *= math.sqrt(1.0 / (2.0 * n))
+    return out
+
+
+def isj_bandwidth(
+    x: torch.Tensor,
+    grid_size: int = _ISJ_GRID_SIZE,
+    max_iter: int = _ISJ_MAX_ITER,
+    tol: float = _ISJ_TOL,
+) -> torch.Tensor:
+    """
+    ISJ plug-in selector using torch:
+    1) Estimate the curvature functional R(f'') from DCT coefficients of the binned data.
+    2) Iterate the SJ fixed-point equation for h, bailing out to Silverman on failure.
+    Returns h (0-D tensor).
+    """
+    n = x.numel()
+    # 1. initial pilot h from the normal-reference (Silverman) rule
+    std = x.std(unbiased=True).clamp_min(1e-12)
+    silverman = 1.06 * std * (n ** (-1.0 / 5.0))
+    h = silverman.clone()
+
+    # precompute the grid
+    x_min, x_max = torch.min(x), torch.max(x)
+    L = (x_max - x_min).clamp_min(1e-12)
+    grid = torch.linspace(
+        float(x_min - 0.5 * std), float(x_max + 0.5 * std), grid_size,
+        dtype=x.dtype, device=x.device,
+    )
+    dx = grid[1] - grid[0]
+
+    # precompute squared frequencies for the DCT
+    k = torch.arange(grid_size, device=x.device)
+    k_sq = (k * torch.pi / L) ** 2
+
+    R_K = 1.0 / (2.0 * math.sqrt(math.pi))  # roughness of the Gaussian kernel
+    mu2_K = 1.0
+
+    # 2. bin the data onto the grid (loop-invariant, so done once)
+    bins = torch.bucketize(x, grid).clamp_max(grid_size - 1)
+    counts = torch.zeros(grid_size, dtype=x.dtype, device=x.device).scatter_add_(
+        0, bins, torch.ones_like(x)
+    )
+    relfreq = counts / n
+
+    # 3. DCT of the relative frequency
+    a_k = _dct2_ortho(relfreq)
+
+    for _ in range(max_iter):
+        # 4. estimate R(f'') from the DCT coefficients
+        #    (equation as in ArviZ, https://python.arviz.org/en/stable/_modules/arviz/stats/density_utils.html)
+        t = h**2 / 2.0
+        R2 = 0.5 * torch.pi**4 * torch.sum(k_sq**2 * a_k**2 * torch.exp(-k_sq * t))
+
+        # if R2 becomes tiny/NaN, bail out to Silverman
+        if not torch.isfinite(R2) or R2 <= 0:
+            h = silverman
+            break
+
+        # 5. fixed-point update
+        h_new = (R_K / (mu2_K**2 * R2 * n)) ** 0.2
+        if torch.abs(h_new - h) < tol:
+            h = h_new
+            break
+        h = h_new
+
+    # keep h positive and not ridiculously smaller than the grid spacing
+    return h.clamp_min(dx * 1e-3)
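On well-behaved unimodal data the ISJ fixed point and the Silverman pilot should land in the same ballpark, which makes for a cheap sanity check of the selector (sketch):

```python
import torch
from torchvinecopulib.util.bandwidth import isj_bandwidth

x = torch.randn(4000, dtype=torch.float64)
h_isj = isj_bandwidth(x)
h_silverman = 1.06 * x.std(unbiased=True) * x.numel() ** (-1 / 5)
print(float(h_isj), float(h_silverman))  # same order of magnitude for Gaussian data
```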
+
+
+# K-fold cross-validation
+
+
+def _make_kfold_indices(n: int, folds: int = _DEFAULT_CV_FOLDS, seed: int = 0):
+    g = torch.Generator().manual_seed(seed)
+    perm = torch.randperm(n, generator=g)
+    parts = torch.chunk(perm, folds)
+    for i in range(folds):
+        val_idx = parts[i]
+        train_idx = torch.cat([parts[j] for j in range(folds) if j != i], dim=0)
+        yield train_idx, val_idx
+
+
+@torch.no_grad()
+def kfold_lcv_bandwidth(
+    x: torch.Tensor, h_grid: torch.Tensor | None = None, folds: int = _DEFAULT_CV_FOLDS
+) -> torch.Tensor:
+    """
+    Bandwidth via k-fold likelihood cross-validation for 1-D KDE.
+
+    Args:
+        x (torch.Tensor): 1-D samples. Non-finite values are ignored.
+        h_grid (torch.Tensor | None, optional): Candidate bandwidths; if None,
+            a log-spaced grid around a Silverman pilot is used.
+        folds (int, optional): Number of CV folds. Defaults to ``_DEFAULT_CV_FOLDS``.
+
+    Returns:
+        torch.Tensor: Selected bandwidth (float64) from the grid; a safe fallback
+            if the data are insufficient.
+    """
+    x = x.view(-1).to(torch.float64)
+    x = x[torch.isfinite(x)]
+    n = x.numel()
+    if n < 2:
+        return torch.tensor(1.0, dtype=torch.float64, device=x.device)  # safe fallback
+
+    if h_grid is None:
+        std = x.std(unbiased=True).clamp_min(1e-12)
+        h0 = 1.06 * std * (n ** (-1.0 / 5.0))
+        h_grid = torch.logspace(
+            math.log10(0.5 * h0),
+            math.log10(2.0 * h0),
+            steps=_KFOLD_LSCV_GRID_SIZE,
+            dtype=torch.float64,
+            device=x.device,
+        )
+    else:
+        h_grid = h_grid.to(dtype=torch.float64, device=x.device)
+
+    TWO_PI = 2.0 * math.pi
+    best_loss = torch.tensor(float("inf"), dtype=torch.float64, device=x.device)
+    best_h = h_grid[0]
+
+    for h in h_grid:
+        fold_losses = []
+        for tr_idx, va_idx in _make_kfold_indices(n, folds=folds, seed=0):
+            x_tr = x[tr_idx]  # [n_tr]
+            x_va = x[va_idx]  # [n_va]
+
+            # log f_h(x_va) = log( (1/n_tr) * sum_j phi((x_va - x_tr[j]) / h) / h )
+            #              = logsumexp_j( -0.5*((x_va - x_tr[j])/h)^2 - log(h*sqrt(2π)) ) - log(n_tr)
+            dif = (x_va.unsqueeze(1) - x_tr.unsqueeze(0)) / h  # [n_va, n_tr]
+            log_k = -0.5 * dif.pow(2) - 0.5 * math.log(TWO_PI) - math.log(h)
+            log_f = torch.logsumexp(log_k, dim=1) - math.log(x_tr.numel())
+            fold_losses.append(-(log_f.mean()))
+        loss = torch.stack(fold_losses).mean()
+        if loss < best_loss:
+            best_loss = loss
+            best_h = h
+
+    return best_h
+
+
+def optimal_bandwidth(x: torch.Tensor, method: str = "isj", **kwargs) -> torch.Tensor:
+    """
+    Robust dispatcher for 1-D KDE bandwidth selection.
+
+    Args:
+        x (torch.Tensor): 1-D samples. Non-finite values are ignored.
+        method (str, optional): "isj", "kfold", or "auto" (ISJ → k-fold refinement → Silverman).
+        **kwargs: Extra parameters forwarded to the chosen method.
+
+    Returns:
+        torch.Tensor: Positive scalar bandwidth (float64).
+    """
+    x = x.view(-1).to(torch.float64)
+    x = x[torch.isfinite(x)]
+    n = x.numel()
+    if n < 2:
+        return torch.tensor(1.0, dtype=torch.float64, device=x.device)
+
+    std = x.std(unbiased=True).clamp_min(1e-12)
+    silverman = 1.06 * std * (n ** (-1.0 / 5.0))
+
+    if method == "isj":
+        try:
+            return isj_bandwidth(x)
+        except Exception:
+            return silverman
+
+    if method == "kfold":
+        return kfold_lcv_bandwidth(x, kwargs.get("h_grid", None), kwargs.get("folds", _DEFAULT_CV_FOLDS))
+
+    if method == "auto":
+        try:
+            h0 = isj_bandwidth(x)
+        except Exception:
+            h0 = silverman
+        # refine ±50% around h0
+        h_grid = torch.logspace(
+            math.log10(0.5 * h0),
+            math.log10(1.5 * h0),
+            steps=64,
+            dtype=torch.float64,
+            device=x.device,
+        )
+        return kfold_lcv_bandwidth(x, h_grid=h_grid, folds=kwargs.get("folds", _DEFAULT_CV_FOLDS))
+
+    raise ValueError(f"Unknown method {method!r}")
diff --git a/torchvinecopulib/util/constants.py b/torchvinecopulib/util/constants.py
new file mode 100644
index 0000000..ef9f4bc
--- /dev/null
+++ b/torchvinecopulib/util/constants.py
@@ -0,0 +1,6 @@
+_EPS = 1e-10
+_DEFAULT_CV_FOLDS = 5
+_ISJ_GRID_SIZE = 256
+_ISJ_MAX_ITER = 10
+_ISJ_TOL = 1e-6
+_KFOLD_LSCV_GRID_SIZE = 100
\ No newline at end of file