Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion spacecore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


from .backend import Context, BackendOps, JaxOps, NumpyOps, jax_pytree_class
from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp
from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp, LinOp
from .space import VectorSpace, HermitianSpace, Space, ProductSpace
from .types import DenseArray, SparseArray, ArrayLike

Expand All @@ -21,6 +21,7 @@
"jax_pytree_class",
"NumpyOps",

"LinOp",
"DenseLinOp",
"SparseLinOp",
"BlockDiagonalLinOp",
Expand Down
7 changes: 6 additions & 1 deletion spacecore/space/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Tuple
from typing import Any, Tuple, Callable

from ..backend import Context
from .._contextual import ContextBound
Expand Down Expand Up @@ -89,3 +89,8 @@ def unflatten(self, v: DenseArray) -> Any:

def _convert(self, new_ctx: Context) -> Space:
raise NotImplementedError()

def apply(self, x: Any, f: Callable) -> Any:
raise NotImplementedError(
f"{type(self).__name__} does not define functional application."
)
86 changes: 73 additions & 13 deletions spacecore/space/_herm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Tuple
from typing import Any, Tuple, Callable

from ._vector import VectorSpace
from ..types import DenseArray
Expand Down Expand Up @@ -53,15 +53,15 @@ def _check_member(self, x: Any) -> None:
if not self.is_hermitian(x):
raise TypeError("Matrix is not Hermitian (within the specified tolerances).")

def is_hermitian(self, X: DenseArray) -> bool:
def is_hermitian(self, x: DenseArray) -> bool:
ops = self.ctx.ops
Xh = ops.conj(X).T
diff = X - Xh
xh = ops.conj(x).T
diff = x - xh

# Validation is typically done outside jit, so it is OK to reduce via
# backend's .max when available (NumPy/JAX arrays have it).
adiff = ops.abs(diff)
aX = ops.abs(X)
aX = ops.abs(x)

# max(abs(diff)) <= atol + rtol*max(abs(X))
max_adiff = adiff.max()
Expand All @@ -80,22 +80,22 @@ def _as_float(v):
thresh = float(self.atol) + float(self.rtol) * max_aX_f
return max_adiff_f <= thresh

def symmetrize(self, X: DenseArray) -> DenseArray:
def symmetrize(self, x: DenseArray) -> DenseArray:
"""Project onto the Hermitian cone: (X + X^H)/2."""
return (X + X.T.conj()) * 0.5
return (x + x.T.conj()) * 0.5

def eigh(self, X: DenseArray, k: int = None) -> Tuple[DenseArray, DenseArray]:
self.check_member(X)
return self.ops.eigh(X)
def eigh(self, x: DenseArray, k: int = None) -> Tuple[DenseArray, DenseArray]:
self.check_member(x)
return self.ops.eigh(x)

def unflatten(self, v: DenseArray) -> DenseArray:
vv = self.ctx.assert_dense(v)
X = self.ops.reshape(vv, self.shape)
return self.symmetrize(X)

def psd_proj(self, X: DenseArray) -> DenseArray:
self.check_member(X)
evals, evecs = self.ops.eigh(X)
def psd_proj(self, x: DenseArray) -> DenseArray:
self.check_member(x)
evals, evecs = self.ops.eigh(x)
evals = self.ops.maximum(evals, 0.)
return self.eig_to_dense(evals, evecs)

Expand All @@ -108,3 +108,63 @@ def eig_to_dense(self, evals: DenseArray, evecs: DenseArray) -> DenseArray:

def _convert(self, new_ctx: Context) -> HermitianSpace:
return HermitianSpace(self.n, self.atol, self.rtol, self.enforce_herm, new_ctx)

def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray:
"""
Apply a scalar function to a Hermitian matrix via spectral calculus.

For a Hermitian matrix
$$
X \in \mathbb{H}^n,
$$
with eigendecomposition
$$
X = U \operatorname{diag}(\lambda) U^*,
$$
this method returns
$$
f(X) = U \operatorname{diag}(f(\lambda)) U^*,
$$
where ``f`` is applied entrywise to the eigenvalue vector
$$
\lambda \in \mathbb{R}^n.
$$

Parameters
----------
x:
Hermitian matrix in this space. Must have shape ``(n, n)`` and
satisfy the Hermitian membership conditions of the space.
f:
Callable applied to the eigenvalues of ``x``. It should accept a
dense backend array of eigenvalues and return an array of the same
shape.

Returns
-------
DenseArray
The Hermitian matrix obtained by spectral application of ``f`` to
``x``.

Raises
------
TypeError
If ``x`` is not a valid Hermitian element of this space.

Notes
-----
This is not an entrywise matrix transformation. The function is applied
to the spectrum of ``x``, not to its matrix entries.

In particular, if
$$
X = U \operatorname{diag}(\lambda) U^*,
$$
then the eigenvectors are preserved and only the eigenvalues are
transformed.
"""
self.check_member(x)
evals, evecs = self.eigh(x)
fevals = self._apply_entrywise(evals, f)

return self.eig_to_dense(fevals, evecs)
52 changes: 51 additions & 1 deletion spacecore/space/_product.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Tuple, List, Sequence
from typing import Any, Tuple, List, Sequence, Callable

from ._base import Space
from ..types import DenseArray
Expand Down Expand Up @@ -140,3 +140,53 @@ def unflatten(self, v: DenseArray) -> Tuple[Any, ...]:
xs.append(s.unflatten(vi))

return tuple(xs)

def apply(self, x: Tuple[Any, ...], f: Callable[[Any], Any]) -> Tuple[Any, ...]:
"""
Apply a function to each component of a product-space element.

For a product space
$$
X = X_1 \times \cdots \times X_m,
$$
and an element
$$
x = (x_1,\dots,x_m), \qquad x_i \in X_i,
$$
this method returns
$$
f(x) := \bigl(f_{X_1}(x_1), \dots, f_{X_m}(x_m)\bigr),
$$
where ``f_{X_i}`` denotes application according to the logic of the
corresponding component space ``X_i``.

Parameters
----------
x:
Tuple representing an element of this product space. Its length must
equal the arity of the product space, and each component must be a
valid member of the corresponding factor space.
f:
Callable to apply to each component. The meaning of application is
delegated to each component space via ``spaces[i].apply``.

Returns
-------
tuple[Any, ...]
Tuple of transformed components, one for each factor space.

Raises
------
TypeError
If ``x`` is not a valid product-space element.
ValueError
If ``x`` has the wrong tuple length.

Notes
-----
This method does not define a new joint functional calculus on the
product space. It applies the existing functional calculus of each
factor space independently, component by component.
"""
self.check_member(x)
return tuple(s.apply(xi, f) for s, xi in zip(self.spaces, x))
60 changes: 59 additions & 1 deletion spacecore/space/_vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Tuple
from typing import Any, Tuple, Callable

from ._base import Space
from ..types import DenseArray
Expand Down Expand Up @@ -64,3 +64,61 @@ def unflatten(self, v: DenseArray) -> DenseArray:

def _convert(self, new_ctx: Context) -> VectorSpace:
return VectorSpace(self.shape, new_ctx)

def _apply_entrywise(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray:
try:
y = f(x)
except Exception:
# optional fallback if backend has vectorize/map
y = self.ops.vectorize(f)(x)
if self.ctx.enable_checks:
if y.shape != x.shape:
raise ValueError("Function application changed shape.")
return y

def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray:
"""
Apply a scalar function to a vector-space element entrywise.

For a space element
$$
x \in \mathbb{K}^{n_1 \times \cdots \times n_k},
$$
this method returns the element
$$
y = f(x)
$$
obtained by applying ``f`` coordinatewise to the entries of ``x``.

Parameters
----------
x:
Element of this vector space. Must have shape ``self.shape`` and
dtype compatible with this space.
f:
Callable representing an entrywise transformation. It is expected
to act elementwise on backend arrays, or to be compatible with the
backend vectorization fallback.

Returns
-------
DenseArray
The transformed element, with the same shape as ``x``.

Raises
------
TypeError
If ``x`` is not a valid member of this space.
ValueError
If the result of the application does not preserve the shape of the
space element.

Notes
-----
This is the canonical functional calculus for ``VectorSpace``:
application is performed entrywise in the distinguished coordinate
representation.
"""
self.check_member(x)
y = self._apply_entrywise(x, f)
return y
Loading
Loading