Skip to content
35 changes: 35 additions & 0 deletions src/core/coeff_policy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,38 @@ end
# Trait query: a "trivial" method needs no slope/derivative data, so the
# PreCompute backend has nothing extra to build for it.
@inline _is_trivial_method(::ConstantInterp) = true
@inline _is_trivial_method(::AbstractInterpMethod) = false
# True iff every per-axis method in the tuple is trivial (short-circuits).
@inline function _all_trivial_methods(ms::Tuple)
    return all(_is_trivial_method(m) for m in ms)
end

# ── ND oneshot resolution (separate from interpolant resolution) ──
# Centralized (method, query)-aware policy: resolved once at the user API entry
# point alongside the other resolves (extrap/search/deriv) and threaded down.
# All branches fold at compile time when `methods` is a concrete tuple type.
#
# Strategy matrix:
# methods\queries | scalar (Tuple{Vararg{Real}}) | batch
# -------------------------|------------------------------|---------
# all trivial (Lin/Const) | PreCompute (no slopes) | PreCompute
# has local Hermite | OnTheFly (per-fiber 1D) | OnTheFly (per-query loop)
# global only (Cubic/Quad) | OnTheFly (skip 2^N partials) | PreCompute (amortize solve)
#
# Rationale for the local-Hermite batch case: PreCompute backend
# (`_compute_nd_partials_hetero!`) does not implement `_extract_bc(::PchipInterp)`
# etc., so the only working path for batch + Hermite is a per-query loop calling
# `_interp_nd_oneshot_onthefly`, dispatched inside `_interp_nd_oneshot_batch_dispatch!`.
#
# ForwardDiff.Dual queries (scalar) flow through OnTheFly directly — `_collapse_dims`
# promotes its pool buffer via `_promote_query_eltype(Tv, q_eval)`.
#
# Separate function from `_resolve_coeffs` (1D) avoids dispatch ambiguity with
# the interpolant-construction overloads.
# Explicit user choice: pass the concrete strategy straight through.
@inline _resolve_coeffs_nd_oneshot(c::PreCompute, _, _) = c
@inline _resolve_coeffs_nd_oneshot(c::OnTheFly, _, _) = c
# Scalar query tuple: precompute pays off only when every axis method is
# trivial (no slope data to build); anything else goes through the
# sequential on-the-fly collapse.
@inline function _resolve_coeffs_nd_oneshot(::AutoCoeffs, ::Tuple{Vararg{Real}}, methods)
    return _all_trivial_methods(methods) ? PreCompute() : OnTheFly()
end
# Batch (non-scalar) queries: amortize one precompute across all queries,
# unless a local Hermite method is present — those have no PreCompute
# backend, so the per-query on-the-fly loop is the only working path.
@inline function _resolve_coeffs_nd_oneshot(::AutoCoeffs, queries, methods)
    return _has_any_local_method(methods) ? OnTheFly() : PreCompute()
end
19 changes: 19 additions & 0 deletions src/core/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,25 @@ chains like `promote_type(Tv, Tg, Tq)`.
return isconcretetype(Tr) ? Tr : Tv
end

"""
    _promote_query_eltype(::Type{Tv}, q::Tuple) -> Type

Fold `promote_type` over `Tv` and the type of each element of the tuple `q`,
returning the common element type. Implemented as plain tail recursion on
`Base.tail`: inference sees one concrete promotion step per level and
constant-folds the whole chain, without resorting to an `@generated` body
(which would be vulnerable to world-age problems when promotion rules for
`q`'s types come from an extension module loaded after FastInterpolations).

Used by the OnTheFly ND `_collapse_dims` entry points to pick the pool-buffer
element type: for plain-`Float64` queries the result typically equals the data
eltype (keeping the path zero-cost), while ForwardDiff.Dual query elements
widen it so Dual intermediates fit in the buffer.
"""
@inline _promote_query_eltype(::Type{Tv}, ::Tuple{}) where {Tv} = Tv
@inline function _promote_query_eltype(::Type{Tv}, q::Tuple) where {Tv}
    # Widen by the head element's type, then recurse on the tail.
    widened = promote_type(Tv, typeof(first(q)))
    return _promote_query_eltype(widened, Base.tail(q))
end

"""
_promote_value_type(y, ::Type{Tg}) -> (Tv, y_converted)

Expand Down
13 changes: 10 additions & 3 deletions src/cubic/nd/cubic_nd_oneshot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function cubic_interp(
bc::Union{AbstractBC, NTuple{N, AbstractBC}} = CubicFit(),
extrap::Union{AbstractExtrap, NTuple{N, AbstractExtrap}} = NoExtrap(),
search::Union{AbstractSearchPolicy, NTuple{N, AbstractSearchPolicy}} = AutoSearch(),
coeffs::AbstractCoeffStrategy = PreCompute(),
coeffs::AbstractCoeffStrategy = AutoCoeffs(),
hint::Union{Nothing, NTuple{N, Base.RefValue{Int}}} = nothing
) where {Tv, N}
# Type promotion + validation (same as constructor path)
Expand All @@ -51,6 +51,13 @@ function cubic_interp(

extraps_val = _resolve_extrap_nd(extrap, bcs, Val(N), Tv)
ops = _resolve_deriv_nd(deriv, Val(N))

# OnTheFly: skip full partials build — use sequential 1D collapse (2^N× less work)
coeffs_resolved = _resolve_coeffs_nd_oneshot(coeffs, query, ntuple(_ -> CubicInterp(), Val(N)))
if coeffs_resolved isa OnTheFly
methods = map(CubicInterp, bcs)
return _interp_nd_oneshot_onthefly(grids_typed, data, query, methods, extraps_val, searches, ops, hint)::Tr
end
return _cubic_interp_nd_oneshot(grids_typed, data, query, bcs, extraps_val, searches, ops, hint)::Tr
end

Expand All @@ -70,7 +77,7 @@ function cubic_interp(
bc::Union{AbstractBC, NTuple{N, AbstractBC}} = CubicFit(),
extrap::Union{AbstractExtrap, NTuple{N, AbstractExtrap}} = NoExtrap(),
search::Union{AbstractSearchPolicy, NTuple{N, AbstractSearchPolicy}} = AutoSearch(),
coeffs::AbstractCoeffStrategy = PreCompute(),
coeffs::AbstractCoeffStrategy = AutoCoeffs(),
hint::Union{Nothing, NTuple{N, Base.RefValue{Int}}} = nothing
) where {Tv, N}
Tg = _promote_grid_eltype(grids)
Expand Down Expand Up @@ -215,7 +222,7 @@ function cubic_interp!(
bc::Union{AbstractBC, NTuple{N, AbstractBC}} = CubicFit(),
extrap::Union{AbstractExtrap, NTuple{N, AbstractExtrap}} = NoExtrap(),
search::Union{AbstractSearchPolicy, NTuple{N, AbstractSearchPolicy}} = AutoSearch(),
coeffs::AbstractCoeffStrategy = PreCompute(),
coeffs::AbstractCoeffStrategy = AutoCoeffs(),
hint::Union{Nothing, NTuple{N, Base.RefValue{Int}}} = nothing
) where {Tv, N}
_query_check_ndims(queries, Val(N))
Expand Down
39 changes: 27 additions & 12 deletions src/hetero/hetero_eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,23 @@ end
end

# ========================================
# Sequential Dimension Collapse
# Sequential Dimension Collapse (Pool-Based)
# ========================================
# Recursive type-stable dispatch: each step removes dim 1 and recurses
# with Base.tail of all tuples. Julia infers concrete types at each level.

# Base case: 1D data → one-shot eval final dimension
# Intermediate arrays are pool-allocated via @with_pool (zero heap alloc after warmup).
#
# The first argument `::Type{Tr}` is the promoted result type, computed once at the
# entry point via `_promote_query_eltype(Tv, q_eval)`. This is the buffer type for
# intermediate results — for plain Float64 queries it typically equals the data
# element type (preserving zero-alloc), while for AD (ForwardDiff.Dual) queries it
# promotes to the Dual-compatible type so the query-dependent intermediate values
# fit into the pool buffer. `Tr` is plumbed unchanged through the recursion.

# Base case: 1D data → one-shot eval final dimension (Tr is carried but unused —
# the 1D scalar eval returns a scalar directly, no buffer needed).
@inline function _collapse_dims(
::Type,
data::AbstractVector,
grids::Tuple{AbstractVector},
methods::Tuple{AbstractInterpMethod},
Expand All @@ -80,19 +90,19 @@ end
end

# Recursive case: collapse dim 1 → (M-1)D array, then recurse
@inline function _collapse_dims(
data::AbstractArray{Tv, M},
@inline @with_pool pool function _collapse_dims(
::Type{Tr},
data::AbstractArray{<:Any, M},
grids::Tuple{AbstractVector, Vararg{AbstractVector}},
methods::Tuple{AbstractInterpMethod, Vararg{AbstractInterpMethod}},
extraps::Tuple{AbstractExtrap, Vararg{AbstractExtrap}},
q_eval::Tuple{Real, Vararg{Real}},
ops::Tuple{AbstractEvalOp, Vararg{AbstractEvalOp}},
searches::Tuple{AbstractSearchPolicy, Vararg{AbstractSearchPolicy}},
hints,
) where {Tv, M}
# Allocate intermediate array for collapsed result
) where {Tr, M}
remaining_size = Base.tail(size(data))
result = Array{Tv}(undef, remaining_size...)
result = acquire!(pool, Tr, remaining_size)
hint_1 = _first_hint(hints)

# Collapse first dimension: for each fiber along dim 1, one-shot eval
Expand All @@ -104,9 +114,9 @@ end
)
end

# Recurse with remaining dimensions
# Recurse with remaining dimensions (Tr unchanged — all levels share one buffer type)
return _collapse_dims(
result, Base.tail(grids), Base.tail(methods),
Tr, result, Base.tail(grids), Base.tail(methods),
Base.tail(extraps), Base.tail(q_eval), Base.tail(ops),
Base.tail(searches), _tail_hints(hints)
)
Expand All @@ -125,7 +135,10 @@ end
hints,
) where {Tg, Tv, N, G, S, M, E, P}
q_eval = _handle_all_extraps(query, itp.grids, itp.extraps)
return _collapse_dims(itp.data, itp.grids, itp.methods, itp.extraps, q_eval, ops, searches, hints)
# Tr promotes data eltype with query eltypes → Dual-safe pool buffers for AD.
# Recursive type fold specializes at compile time for each concrete query tuple.
Tr = _promote_query_eltype(Tv, q_eval)
return _collapse_dims(Tr, itp.data, itp.grids, itp.methods, itp.extraps, q_eval, ops, searches, hints)
end

# PreCompute path: precomputed partials + local kernel eval (O(1) per query)
Expand Down Expand Up @@ -201,7 +214,9 @@ end
ops::NTuple{N, AbstractEvalOp},
) where {Tg, Tv, N, G, S, M, E, P}
data, grids, methods, extraps, q_eval, searches, hints = cell
return _collapse_dims(data, grids, methods, extraps, q_eval, ops, searches, hints)
# Tr promotes data eltype with query eltypes → Dual-safe pool buffers for AD.
Tr = _promote_query_eltype(Tv, q_eval)
return _collapse_dims(Tr, data, grids, methods, extraps, q_eval, ops, searches, hints)
end

# PreCompute: cell stores precomputed cell location (locate-once optimization)
Expand Down
4 changes: 3 additions & 1 deletion src/hetero/hetero_nointerp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,9 @@ positions, filters all per-axis tuples to Real-only axes, delegates to existing
rsrc = ($(r_searches...),)
search_r = _resolve_search_nd(rsrc, Val($N_r), rq)
hint_r = hint === nothing ? nothing : ($([:(hint[$d]) for d in real_dims]...),)
return _collapse_dims(d_sliced, rg, rm, re, q_eval, ro, search_r, hint_r)
# Tr promotes sliced-data eltype with query eltypes → Dual-safe pool buffers for AD
Tr = _promote_query_eltype(eltype(d_sliced), q_eval)
return _collapse_dims(Tr, d_sliced, rg, rm, re, q_eval, ro, search_r, hint_r)
end
end
end
Expand Down
Loading
Loading