From 447469b2491546bc7c44ab84b9759adeffcce86c Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Mon, 6 Apr 2026 18:49:02 -0300 Subject: [PATCH 1/5] Add apply method to Space --- spacecore/__init__.py | 2 +- spacecore/space/_base.py | 7 ++++++- spacecore/space/_herm.py | 33 ++++++++++++++++++++------------- spacecore/space/_product.py | 6 +++++- spacecore/space/_vector.py | 18 +++++++++++++++++- 5 files changed, 49 insertions(+), 17 deletions(-) diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 7a91fa0..6449822 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -2,7 +2,7 @@ from .backend import Context, BackendOps, JaxOps, NumpyOps, jax_pytree_class -from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp +from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp, LinOp from .space import VectorSpace, HermitianSpace, Space, ProductSpace from .types import DenseArray, SparseArray, ArrayLike diff --git a/spacecore/space/_base.py b/spacecore/space/_base.py index 586bea0..44bf5b7 100644 --- a/spacecore/space/_base.py +++ b/spacecore/space/_base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Tuple +from typing import Any, Tuple, Callable from ..backend import Context from .._contextual import ContextBound @@ -89,3 +89,8 @@ def unflatten(self, v: DenseArray) -> Any: def _convert(self, new_ctx: Context) -> Space: raise NotImplementedError() + + def apply(self, x: Any, f: Callable) -> Any: + raise NotImplementedError( + f"{type(self).__name__} does not define functional application." + ) diff --git a/spacecore/space/_herm.py b/spacecore/space/_herm.py index 5aeba18..5461b5f 100644 --- a/spacecore/space/_herm.py +++ b/spacecore/space/_herm.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Tuple +from typing import Any, Tuple, Callable from ._vector import VectorSpace from ..types import DenseArray @@ -53,15 +53,15 @@ def _check_member(self, x: Any) -> None: if not self.is_hermitian(x): raise TypeError("Matrix is not Hermitian (within the specified tolerances).") - def is_hermitian(self, X: DenseArray) -> bool: + def is_hermitian(self, x: DenseArray) -> bool: ops = self.ctx.ops - Xh = ops.conj(X).T - diff = X - Xh + xh = ops.conj(x).T + diff = x - xh # Validation is typically done outside jit, so it is OK to reduce via # backend's .max when available (NumPy/JAX arrays have it). adiff = ops.abs(diff) - aX = ops.abs(X) + aX = ops.abs(x) # max(abs(diff)) <= atol + rtol*max(abs(X)) max_adiff = adiff.max() @@ -80,22 +80,22 @@ def _as_float(v): thresh = float(self.atol) + float(self.rtol) * max_aX_f return max_adiff_f <= thresh - def symmetrize(self, X: DenseArray) -> DenseArray: + def symmetrize(self, x: DenseArray) -> DenseArray: """Project onto the Hermitian cone: (X + X^H)/2.""" - return (X + X.T.conj()) * 0.5 + return (x + x.T.conj()) * 0.5 - def eigh(self, X: DenseArray, k: int = None) -> Tuple[DenseArray, DenseArray]: - self.check_member(X) - return self.ops.eigh(X) + def eigh(self, x: DenseArray, k: int = None) -> Tuple[DenseArray, DenseArray]: + self.check_member(x) + return self.ops.eigh(x) def unflatten(self, v: DenseArray) -> DenseArray: vv = self.ctx.assert_dense(v) X = self.ops.reshape(vv, self.shape) return self.symmetrize(X) - def psd_proj(self, X: DenseArray) -> DenseArray: - self.check_member(X) - evals, evecs = self.ops.eigh(X) + def psd_proj(self, x: DenseArray) -> DenseArray: + self.check_member(x) + evals, evecs = self.ops.eigh(x) evals = self.ops.maximum(evals, 0.) return self.eig_to_dense(evals, evecs) @@ -108,3 +108,10 @@ def eig_to_dense(self, evals: DenseArray, evecs: DenseArray) -> DenseArray: def _convert(self, new_ctx: Context) -> HermitianSpace: return HermitianSpace(self.n, self.atol, self.rtol, self.enforce_herm, new_ctx) + + def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: + self.check_member(x) + evals, evecs = self.eigh(x) + fevals = self._apply_entrywise(evals, f) + + return self.eig_to_dense(fevals, evecs) diff --git a/spacecore/space/_product.py b/spacecore/space/_product.py index f49628d..3457b4a 100644 --- a/spacecore/space/_product.py +++ b/spacecore/space/_product.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Tuple, List, Sequence +from typing import Any, Tuple, List, Sequence, Callable from ._base import Space from ..types import DenseArray @@ -140,3 +140,7 @@ def unflatten(self, v: DenseArray) -> Tuple[Any, ...]: xs.append(s.unflatten(vi)) return tuple(xs) + + def apply(self, x: Tuple[Any, ...], f: Callable[[Any], Any]) -> Tuple[Any, ...]: + self.check_member(x) + return tuple(s.apply(xi, f) for s, xi in zip(self.spaces, x)) diff --git a/spacecore/space/_vector.py b/spacecore/space/_vector.py index d0adcb1..cd7e322 100644 --- a/spacecore/space/_vector.py +++ b/spacecore/space/_vector.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Tuple +from typing import Any, Tuple, Callable from ._base import Space from ..types import DenseArray @@ -64,3 +64,19 @@ def unflatten(self, v: DenseArray) -> DenseArray: def _convert(self, new_ctx: Context) -> VectorSpace: return VectorSpace(self.shape, new_ctx) + + def _apply_entrywise(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: + try: + y = f(x) + except Exception: + # optional fallback if backend has vectorize/map + y = self.ops.vectorize(f)(x) + return y + + def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: + self.check_member(x) + y = self._apply_entrywise(x, f) + if self.ctx.enable_checks: + if y.shape != self.shape: + raise ValueError("Function application changed shape.") + return y From f8ec1068e562161aec1fe4eb5493888524c9b95f Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Mon, 6 Apr 2026 18:49:14 -0300 Subject: [PATCH 2/5] Add tests to apply method in Space --- tests/test_space_apply.py | 162 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 tests/test_space_apply.py diff --git a/tests/test_space_apply.py b/tests/test_space_apply.py new file mode 100644 index 0000000..1458aaf --- /dev/null +++ b/tests/test_space_apply.py @@ -0,0 +1,162 @@ +import numpy as np +import pytest + +from spacecore import Context, NumpyOps, JaxOps +from spacecore.space import VectorSpace, HermitianSpace, ProductSpace + + +def _np_ctx(): + # Adjust if your backend alias is different. + return Context(ops=NumpyOps(), enable_checks=True) + + +def _jax_ctx(): + # Adjust if your backend alias is different. + return Context(ops=JaxOps(), enable_checks=False) + + +def test_vector_apply_numpy_entrywise(): + sp = VectorSpace((3,), ctx=_np_ctx()) + x = np.asarray([1.0, 2.0, 3.0], dtype=sp.dtype) + + y = sp.apply(x, np.square) + + expected = np.asarray([1.0, 4.0, 9.0], dtype=sp.dtype) + np.testing.assert_allclose(y, expected) + + +def test_vector_apply_numpy_shape_change_raises(): + sp = VectorSpace((3,), ctx=_np_ctx()) + x = np.asarray([1.0, 2.0, 3.0], dtype=sp.dtype) + + def bad_f(z): + return z[:2] + + with pytest.raises(ValueError, match="changed shape"): + sp.apply(x, bad_f) + + +def test_product_apply_numpy_componentwise(): + sp1 = VectorSpace((2,), ctx=_np_ctx()) + sp2 = VectorSpace((3,), ctx=_np_ctx()) + psp = ProductSpace((sp1, sp2), ctx=_np_ctx()) + + x = ( + np.asarray([1.0, 2.0], dtype=psp.spaces[0].dtype), + np.asarray([3.0, 4.0, 5.0], dtype=psp.spaces[1].dtype), + ) + + y = psp.apply(x, np.square) + + np.testing.assert_allclose( + y[0], np.asarray([1.0, 4.0], dtype=psp.spaces[0].dtype) + ) + np.testing.assert_allclose( + y[1], np.asarray([9.0, 16.0, 25.0], dtype=psp.spaces[1].dtype) + ) + + +def test_hermitian_apply_numpy_spectral_on_diagonal(): + sp = HermitianSpace(3, ctx=_np_ctx()) + x = np.diag(np.asarray([1.0, 2.0, 3.0], dtype=sp.dtype)) + + y = sp.apply(x, np.exp) + + expected = np.diag(np.exp(np.asarray([1.0, 2.0, 3.0], dtype=sp.dtype))) + np.testing.assert_allclose(y, expected, rtol=1e-12, atol=1e-12) + + +def test_hermitian_apply_numpy_preserves_hermitian_structure(): + sp = HermitianSpace(2, ctx=_np_ctx()) + x = np.asarray([[2.0, 1.0], [1.0, 3.0]], dtype=sp.dtype) + + y = sp.apply(x, np.exp) + + np.testing.assert_allclose(y, y.T.conj(), rtol=1e-12, atol=1e-12) + + +@pytest.mark.parametrize( + "factory, expected", + [ + ( + lambda sp: np.asarray([1.0, 2.0, 3.0], dtype=sp.dtype), + np.asarray([1.0, 4.0, 9.0]), + ), + ], +) +def test_vector_apply_numpy_basic_regression(factory, expected): + sp = VectorSpace((3,), ctx=_np_ctx()) + x = factory(sp) + y = sp.apply(x, np.square) + np.testing.assert_allclose(y, expected.astype(sp.dtype)) + + +def test_vector_apply_jax_matches_eager_and_compiles(): + jax = pytest.importorskip("jax") + jnp = pytest.importorskip("jax.numpy") + + sp = VectorSpace((3,), ctx=_jax_ctx()) + x = jnp.asarray([1.0, 2.0, 3.0], dtype=sp.dtype) + + f = jnp.square + + y_eager = sp.apply(x, f) + + @jax.jit + def compiled_apply(z): + return sp.apply(z, f) + + y_jit = compiled_apply(x) + + np.testing.assert_allclose(np.asarray(y_eager), np.asarray([1.0, 4.0, 9.0])) + np.testing.assert_allclose(np.asarray(y_jit), np.asarray([1.0, 4.0, 9.0])) + + +def test_product_apply_jax_matches_eager_and_compiles(): + jax = pytest.importorskip("jax") + jnp = pytest.importorskip("jax.numpy") + + sp1 = VectorSpace((2,), ctx=_jax_ctx()) + sp2 = VectorSpace((3,), ctx=_jax_ctx()) + psp = ProductSpace((sp1, sp2), ctx=_jax_ctx()) + + x = ( + jnp.asarray([1.0, 2.0], dtype=psp.spaces[0].dtype), + jnp.asarray([3.0, 4.0, 5.0], dtype=psp.spaces[1].dtype), + ) + + f = jnp.square + + y_eager = psp.apply(x, f) + + @jax.jit + def compiled_apply(a, b): + return psp.apply((a, b), f) + + y_jit = compiled_apply(*x) + + np.testing.assert_allclose(np.asarray(y_eager[0]), np.asarray([1.0, 4.0])) + np.testing.assert_allclose(np.asarray(y_eager[1]), np.asarray([9.0, 16.0, 25.0])) + np.testing.assert_allclose(np.asarray(y_jit[0]), np.asarray([1.0, 4.0])) + np.testing.assert_allclose(np.asarray(y_jit[1]), np.asarray([9.0, 16.0, 25.0])) + + +def test_hermitian_apply_jax_diagonal_matches_eager_and_compiles(): + jax = pytest.importorskip("jax") + jnp = pytest.importorskip("jax.numpy") + + sp = HermitianSpace(3, ctx=_jax_ctx()) + x = jnp.diag(jnp.asarray([1.0, 2.0, 3.0], dtype=sp.dtype)) + f = jnp.exp + + y_eager = sp.apply(x, f) + + @jax.jit + def compiled_apply(z): + return sp.apply(z, f) + + y_jit = compiled_apply(x) + + expected = np.diag(np.exp(np.asarray([1.0, 2.0, 3.0]))) + np.testing.assert_allclose(np.asarray(y_eager), expected, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(np.asarray(y_jit), expected, rtol=1e-6, atol=1e-6) \ No newline at end of file From 32ec767bdcdfc0bc21c028d7ae947187b10ebbe2 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Mon, 6 Apr 2026 18:57:51 -0300 Subject: [PATCH 3/5] Move shape change check to _apply_entrywise --- spacecore/space/_vector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spacecore/space/_vector.py b/spacecore/space/_vector.py index cd7e322..ace82b8 100644 --- a/spacecore/space/_vector.py +++ b/spacecore/space/_vector.py @@ -71,12 +71,12 @@ def _apply_entrywise(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) except Exception: # optional fallback if backend has vectorize/map y = self.ops.vectorize(f)(x) + if self.ctx.enable_checks: + if y.shape != x.shape: + raise ValueError("Function application changed shape.") return y def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: self.check_member(x) y = self._apply_entrywise(x, f) - if self.ctx.enable_checks: - if y.shape != self.shape: - raise ValueError("Function application changed shape.") return y From 6476fada902412c0d2f6717275106b89513623fd Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Mon, 6 Apr 2026 19:13:26 -0300 Subject: [PATCH 4/5] Add docstrings --- spacecore/space/_herm.py | 53 +++++++++++++++++++++++++++++++++++++ spacecore/space/_product.py | 46 ++++++++++++++++++++++++++++++++ spacecore/space/_vector.py | 42 +++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+) diff --git a/spacecore/space/_herm.py b/spacecore/space/_herm.py index 5461b5f..d4bbef7 100644 --- a/spacecore/space/_herm.py +++ b/spacecore/space/_herm.py @@ -110,6 +110,59 @@ def _convert(self, new_ctx: Context) -> HermitianSpace: return HermitianSpace(self.n, self.atol, self.rtol, self.enforce_herm, new_ctx) def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: + """ + Apply a scalar function to a Hermitian matrix via spectral calculus. + + For a Hermitian matrix + $$ + X \in \mathbb{H}^n, + $$ + with eigendecomposition + $$ + X = U \operatorname{diag}(\lambda) U^*, + $$ + this method returns + $$ + f(X) = U \operatorname{diag}(f(\lambda)) U^*, + $$ + where ``f`` is applied entrywise to the eigenvalue vector + $$ + \lambda \in \mathbb{R}^n. + $$ + + Parameters + ---------- + x: + Hermitian matrix in this space. Must have shape ``(n, n)`` and + satisfy the Hermitian membership conditions of the space. + f: + Callable applied to the eigenvalues of ``x``. It should accept a + dense backend array of eigenvalues and return an array of the same + shape. + + Returns + ------- + DenseArray + The Hermitian matrix obtained by spectral application of ``f`` to + ``x``. + + Raises + ------ + TypeError + If ``x`` is not a valid Hermitian element of this space. + + Notes + ----- + This is not an entrywise matrix transformation. The function is applied + to the spectrum of ``x``, not to its matrix entries. + + In particular, if + $$ + X = U \operatorname{diag}(\lambda) U^*, + $$ + then the eigenvectors are preserved and only the eigenvalues are + transformed. + """ self.check_member(x) evals, evecs = self.eigh(x) fevals = self._apply_entrywise(evals, f) diff --git a/spacecore/space/_product.py b/spacecore/space/_product.py index 3457b4a..ee95434 100644 --- a/spacecore/space/_product.py +++ b/spacecore/space/_product.py @@ -142,5 +142,51 @@ def unflatten(self, v: DenseArray) -> Tuple[Any, ...]: return tuple(xs) def apply(self, x: Tuple[Any, ...], f: Callable[[Any], Any]) -> Tuple[Any, ...]: + """ + Apply a function to each component of a product-space element. + + For a product space + $$ + X = X_1 \times \cdots \times X_m, + $$ + and an element + $$ + x = (x_1,\dots,x_m), \qquad x_i \in X_i, + $$ + this method returns + $$ + f(x) := \bigl(f_{X_1}(x_1), \dots, f_{X_m}(x_m)\bigr), + $$ + where ``f_{X_i}`` denotes application according to the logic of the + corresponding component space ``X_i``. + + Parameters + ---------- + x: + Tuple representing an element of this product space. Its length must + equal the arity of the product space, and each component must be a + valid member of the corresponding factor space. + f: + Callable to apply to each component. The meaning of application is + delegated to each component space via ``spaces[i].apply``. + + Returns + ------- + tuple[Any, ...] + Tuple of transformed components, one for each factor space. + + Raises + ------ + TypeError + If ``x`` is not a valid product-space element. + ValueError + If ``x`` has the wrong tuple length. + + Notes + ----- + This method does not define a new joint functional calculus on the + product space. It applies the existing functional calculus of each + factor space independently, component by component. + """ self.check_member(x) return tuple(s.apply(xi, f) for s, xi in zip(self.spaces, x)) diff --git a/spacecore/space/_vector.py b/spacecore/space/_vector.py index ace82b8..97c7e6c 100644 --- a/spacecore/space/_vector.py +++ b/spacecore/space/_vector.py @@ -77,6 +77,48 @@ def _apply_entrywise(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) return y def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: + """ + Apply a scalar function to a vector-space element entrywise. + + For a space element + $$ + x \in \mathbb{K}^{n_1 \times \cdots \times n_k}, + $$ + this method returns the element + $$ + y = f(x) + $$ + obtained by applying ``f`` coordinatewise to the entries of ``x``. + + Parameters + ---------- + x: + Element of this vector space. Must have shape ``self.shape`` and + dtype compatible with this space. + f: + Callable representing an entrywise transformation. It is expected + to act elementwise on backend arrays, or to be compatible with the + backend vectorization fallback. + + Returns + ------- + DenseArray + The transformed element, with the same shape as ``x``. + + Raises + ------ + TypeError + If ``x`` is not a valid member of this space. + ValueError + If the result of the application does not preserve the shape of the + space element. + + Notes + ----- + This is the canonical functional calculus for ``VectorSpace``: + application is performed entrywise in the distinguished coordinate + representation. + """ self.check_member(x) y = self._apply_entrywise(x, f) return y From d5b74b69559a51be572713dfa963cf237e997bd7 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Mon, 6 Apr 2026 19:15:49 -0300 Subject: [PATCH 5/5] Add LinOp to __all__ --- spacecore/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 6449822..7bedc31 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -21,6 +21,7 @@ "jax_pytree_class", "NumpyOps", + "LinOp", "DenseLinOp", "SparseLinOp", "BlockDiagonalLinOp",