diff --git a/benchmarks/benchmarks/sparse_linalg_solve.py b/benchmarks/benchmarks/sparse_linalg_solve.py index d1931a01ed58..c09fe0151521 100644 --- a/benchmarks/benchmarks/sparse_linalg_solve.py +++ b/benchmarks/benchmarks/sparse_linalg_solve.py @@ -1,6 +1,8 @@ """ Check the speed of the conjugate gradient solver. """ +import inspect + import numpy as np from numpy.testing import assert_equal @@ -8,7 +10,7 @@ with safe_import(): from scipy import linalg, sparse - from scipy.sparse.linalg import cg, minres, gmres, tfqmr, spsolve + from scipy.sparse.linalg import cg, minres, gmres, tfqmr, spsolve, LinearOperator with safe_import(): from scipy.sparse.linalg import lgmres with safe_import(): @@ -18,7 +20,11 @@ def _create_sparse_poisson1d(n): # Make Gilbert Strang's favorite matrix # http://www-math.mit.edu/~gs/PIX/cupcakematrix.jpg - P1d = sparse.diags([[-1]*(n-1), [2]*n, [-1]*(n-1)], [-1, 0, 1]) + P1d = sparse.diags_array( + [[-1]*(n-1), [2]*n, [-1]*(n-1)], + offsets=[-1, 0, 1], + dtype=np.float64 + ) assert_equal(P1d.shape, (n, n)) return P1d @@ -27,12 +33,19 @@ def _create_sparse_poisson2d(n): P1d = _create_sparse_poisson1d(n) P2d = sparse.kronsum(P1d, P1d) assert_equal(P2d.shape, (n*n, n*n)) - return P2d.tocsr() + return sparse.csr_array(P2d) + + +def _create_sparse_poisson2d_coo(n): + P1d = _create_sparse_poisson1d(n) + P2d = sparse.kronsum(P1d, P1d) + assert_equal(P2d.shape, (n*n, n*n)) + return sparse.coo_array(P2d) class Bench(Benchmark): params = [ - [4, 6, 10, 16, 25, 40, 64, 100], + [4, 8, 16, 32, 64, 128, 256, 512], ['dense', 'spsolve', 'cg', 'minres', 'gmres', 'lgmres', 'gcrotmk', 'tfqmr'] ] @@ -57,6 +70,90 @@ def time_solve(self, n, solver): self.mapping[solver](self.P_sparse, self.b) +class BatchedCG(Benchmark): + params = [ + [2, 4, 6, 8, 16, 32, 64], + [1, 10, 100, 500, 1000, 5000, 10000] + ] + param_names = ['(n,n)', 'batch_size'] + + def setup(self, n, batch_size): + if n >= 32 and batch_size >= 500: + raise NotImplementedError() + if n >= 16 and batch_size > 5000: + raise NotImplementedError() + rng = np.random.default_rng(42) + + self.batched = "xp" in inspect.signature(LinearOperator.__init__).parameters + if self.batched: + P_sparse = _create_sparse_poisson2d_coo(n) + if batch_size > 1: + self.P_sparse = sparse.vstack( + [P_sparse] * batch_size, format="coo" + ).reshape(batch_size, n*n, n*n) + self.b = rng.standard_normal((batch_size, n*n)) + else: + self.P_sparse = P_sparse + self.b = rng.standard_normal(n*n) + else: + self.P_sparse = _create_sparse_poisson2d(n) + self.b = [rng.standard_normal(n*n) for _ in range(batch_size)] + + def time_solve(self, n, batch_size): + if self.batched: + cg(self.P_sparse, self.b) + else: + for i in range(batch_size): + cg(self.P_sparse, self.b[i]) + + +def _create_dense_random(n, batch_shape=None): + rng = np.random.default_rng(42) + M = rng.standard_normal((n*n, n*n)) + reg = 1e-3 + if batch_shape: + M = np.broadcast_to(M[np.newaxis, ...], (*batch_shape, n*n, n*n)) + + def matvec(x): + return np.squeeze(M.mT @ (M @ x[..., np.newaxis]), axis=-1) + reg * x + + return LinearOperator(shape=M.shape, matvec=matvec, dtype=np.float64) + + +class BatchedCGDense(Benchmark): + params = [ + [2, 4, 8, 16, 24], + [1, 10, 100, 500, 1000, 5000, 10000] + ] + param_names = ['(n,n)', 'batch_size'] + + def setup(self, n, batch_size): + if n >= 24 and batch_size > 100: + raise NotImplementedError() + if n >= 16 and batch_size > 500: + raise NotImplementedError() + rng = np.random.default_rng(42) + + self.batched = "xp" in inspect.signature(LinearOperator.__init__).parameters + if self.batched: + if batch_size > 1: + self.A = _create_dense_random(n, batch_shape=(batch_size,)) + self.b = rng.standard_normal((batch_size, n*n)) + else: + self.A = _create_dense_random(n) + self.b = rng.standard_normal(n*n) + else: + self.A = _create_dense_random(n) + self.b = [rng.standard_normal(n*n) for _ in range(batch_size)] + + def time_solve(self, n, batch_size): + if self.batched: + cg(self.A, self.b) + else: + for i in range(batch_size): + cg(self.A, self.b[i]) + + class Lgmres(Benchmark): params = [ [10, 50, 100, 1000, 10000], diff --git a/doc/source/_templates/autosummary/class.rst b/doc/source/_templates/autosummary/class.rst index e7637016a652..be0fb6e298e7 100644 --- a/doc/source/_templates/autosummary/class.rst +++ b/doc/source/_templates/autosummary/class.rst @@ -18,12 +18,12 @@ .. autosummary:: :toctree: {% for item in all_methods %} - {%- if not item.startswith('_') or item in ['__call__', '__mul__', '__getitem__', '__len__', '__pow__'] %} + {%- if not item.startswith('_') or item in ['__call__', '__mul__', '__getitem__', '__len__', '__pow__', '__matmul__', '__truediv__', '__add__', '__rmul__', '__rmatmul__'] %} {{ name }}.{{ item }} {%- endif -%} {%- endfor %} {% for item in inherited_members %} - {%- if item in ['__call__', '__mul__', '__getitem__', '__len__', '__pow__'] %} + {%- if item in ['__call__', '__mul__', '__getitem__', '__len__', '__pow__', '__matmul__', '__truediv__', '__add__', '__rmul__', '__rmatmul__'] %} {{ name }}.{{ item }} {%- endif -%} {%- endfor %} diff --git a/doc/source/conf.py b/doc/source/conf.py index a26fe3bdab6d..fa9f9cbe6b05 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -296,7 +296,8 @@ # Generate plots for example sections numpydoc_use_plots = True np_docscrape.ClassDoc.extra_public_methods = [ # should match class.rst - '__call__', '__mul__', '__getitem__', '__len__', + '__call__', '__mul__', '__getitem__', '__len__', '__pow__', '__matmul__', + '__truediv__', '__add__', '__rmul__', '__rmatmul__' ] # ----------------------------------------------------------------------------- diff --git a/pixi.lock b/pixi.lock index aaa2c02968a7..7556dbdf7bcd 100644 --- a/pixi.lock +++ b/pixi.lock @@ -161,6 +161,7 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-7_kmp_llvm.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.4.1-pyhe01879c_0.conda + - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/attrs-25.4.0-pyhcf101f3_1.conda - conda: https://prefix.dev/conda-forge/linux-64/backports.zstd-1.2.0-py312h90b7ffd_0.conda - conda: https://prefix.dev/conda-forge/linux-64/blas-devel-3.11.0-4_hcf00494_mkl.conda @@ -176,8 +177,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/coverage-7.13.0-py312h8a5da7c_0.conda - conda: https://prefix.dev/conda-forge/noarch/cpython-3.12.12-py312hd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.11.0-pyhcf101f3_0.conda + - conda: https://prefix.dev/conda-forge/noarch/decorator-5.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/execnet-2.1.2-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/executing-2.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/filelock-3.20.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/fmt-12.0.0-h2b0788b_0.conda - conda: https://prefix.dev/conda-forge/noarch/fsspec-2025.12.0-pyhd8ed1ab_0.conda @@ -190,8 +193,11 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/idna-3.11-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.7.0-pyhe01879c_1.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython-9.8.0-pyh53cf698_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython_pygments_lexers-1.1.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/jax-0.7.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/jaxlib-0.7.2-cpu_py312h9a1a051_2.conda + - conda: https://prefix.dev/conda-forge/noarch/jedi-0.19.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.6-pyhcf101f3_1.conda - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.45-default_hbd61a6d_104.conda - conda: https://prefix.dev/conda-forge/linux-64/libabseil-20250512.1-cxx17_hba17884_0.conda @@ -226,6 +232,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/markupsafe-3.0.3-py312h8a5da7c_0.conda - conda: https://prefix.dev/conda-forge/noarch/marray-python-0.0.12-pyh332efcf_0.conda + - conda: https://prefix.dev/conda-forge/noarch/matplotlib-inline-0.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/mkl-2025.3.0-h0e700b2_462.conda - conda: https://prefix.dev/conda-forge/linux-64/mkl-devel-2025.3.0-ha770c72_462.conda - conda: https://prefix.dev/conda-forge/linux-64/mkl-include-2025.3.0-hf2ce2f3_462.conda @@ -240,10 +247,15 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/opt_einsum-3.4.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/optree-0.18.0-py312hd9148b4_0.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda + - conda: https://prefix.dev/conda-forge/noarch/parso-0.8.5-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/platformdirs-4.5.1-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://prefix.dev/conda-forge/noarch/pooch-1.8.2-pyhd8ed1ab_3.conda + - conda: https://prefix.dev/conda-forge/noarch/prompt-toolkit-3.0.52-pyha770c72_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ptyprocess-0.7.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-3.0.1-pyh7a1b43c_0.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-abi-11-hc364b38_1.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-global-3.0.1-pyhc7ab6ef_0.conda @@ -266,22 +278,26 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/sleef-3.9.0-ha0421bc_0.conda - conda: https://prefix.dev/conda-forge/noarch/sortedcontainers-2.4.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/spin-0.15-pyh8f84b5b_0.conda + - conda: https://prefix.dev/conda-forge/noarch/stack_data-0.6.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_105.conda - conda: https://prefix.dev/conda-forge/linux-64/tbb-2022.3.0-h8d10470_1.conda - conda: https://prefix.dev/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_ha0e22de_103.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.3.0-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.1.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/traitlets-5.14.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.15.0-h396c80c_0.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.15.0-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.6.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/wcwidth-0.2.14-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/yaml-0.2.5-h280c20c_3.conda - conda: https://prefix.dev/conda-forge/noarch/zipp-3.23.0-pyhcf101f3_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda osx-arm64: - conda: https://prefix.dev/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.4.1-pyhe01879c_0.conda + - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/attrs-25.4.0-pyhcf101f3_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/backports.zstd-1.2.0-py313hf42fe89_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/blas-devel-3.11.0-4_h11c0a38_openblas.conda @@ -297,8 +313,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/coverage-7.13.0-py313h7d74516_0.conda - conda: https://prefix.dev/conda-forge/noarch/cpython-3.13.11-py313hd8ed1ab_100.conda - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.11.0-pyhcf101f3_0.conda + - conda: https://prefix.dev/conda-forge/noarch/decorator-5.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/execnet-2.1.2-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/executing-2.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/filelock-3.20.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/fmt-12.0.0-h669d743_0.conda - conda: https://prefix.dev/conda-forge/noarch/fsspec-2025.12.0-pyhd8ed1ab_0.conda @@ -311,8 +329,11 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/idna-3.11-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.7.0-pyhe01879c_1.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython-9.8.0-pyh53cf698_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython_pygments_lexers-1.1.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/jax-0.7.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/jaxlib-0.7.2-cpu_py313hf0aba26_2.conda + - conda: https://prefix.dev/conda-forge/noarch/jedi-0.19.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.6-pyhcf101f3_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libabseil-20250512.1-cxx17_hd41c47c_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libblas-3.11.0-4_h51639a9_openblas.conda @@ -339,6 +360,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/osx-arm64/markupsafe-3.0.3-py313h7d74516_0.conda - conda: https://prefix.dev/conda-forge/noarch/marray-python-0.0.12-pyh332efcf_0.conda + - conda: https://prefix.dev/conda-forge/noarch/matplotlib-inline-0.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ml_dtypes-0.5.4-np2py313h9ce8dcc_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/mpc-1.3.1-h8f1351a_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/mpfr-4.2.1-hb693164_3.conda @@ -352,10 +374,15 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/opt_einsum-3.4.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/optree-0.18.0-py313ha61f8ec_0.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda + - conda: https://prefix.dev/conda-forge/noarch/parso-0.8.5-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/platformdirs-4.5.1-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://prefix.dev/conda-forge/noarch/pooch-1.8.2-pyhd8ed1ab_3.conda + - conda: https://prefix.dev/conda-forge/noarch/prompt-toolkit-3.0.52-pyha770c72_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ptyprocess-0.7.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-3.0.1-pyh7a1b43c_0.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-abi-11-hc364b38_1.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-global-3.0.1-pyhc7ab6ef_0.conda @@ -378,21 +405,25 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/sleef-3.9.0-hb028509_0.conda - conda: https://prefix.dev/conda-forge/noarch/sortedcontainers-2.4.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/spin-0.15-pyh8f84b5b_0.conda + - conda: https://prefix.dev/conda-forge/noarch/stack_data-0.6.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_105.conda - conda: https://prefix.dev/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h892fb3f_3.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.3.0-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.1.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/traitlets-5.14.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.15.0-h396c80c_0.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.15.0-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.6.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/wcwidth-0.2.14-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/yaml-0.2.5-h925e9cb_3.conda - conda: https://prefix.dev/conda-forge/noarch/zipp-3.23.0-pyhcf101f3_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.4.1-pyhe01879c_0.conda + - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/attrs-25.4.0-pyhcf101f3_1.conda - conda: https://prefix.dev/conda-forge/win-64/backports.zstd-1.2.0-py313h2a31948_0.conda - conda: https://prefix.dev/conda-forge/win-64/blas-devel-3.11.0-4_h85df5b5_mkl.conda @@ -406,8 +437,10 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/coverage-7.13.0-py313hd650c13_0.conda - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.11.0-pyhcf101f3_0.conda + - conda: https://prefix.dev/conda-forge/noarch/decorator-5.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/execnet-2.1.2-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/executing-2.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/filelock-3.20.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/fmt-12.0.0-h29169d4_0.conda - conda: https://prefix.dev/conda-forge/noarch/fsspec-2025.12.0-pyhd8ed1ab_0.conda @@ -420,6 +453,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/idna-3.11-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.7.0-pyhe01879c_1.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython-9.8.0-pyhe2676ad_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython_pygments_lexers-1.1.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/jedi-0.19.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.6-pyhcf101f3_1.conda - conda: https://prefix.dev/conda-forge/win-64/libabseil-20250512.1-cxx17_habfad5f_0.conda - conda: https://prefix.dev/conda-forge/win-64/libblas-3.11.0-4_hf2e6a31_mkl.conda @@ -446,6 +482,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/win-64/markupsafe-3.0.3-py313hd650c13_0.conda - conda: https://prefix.dev/conda-forge/noarch/marray-python-0.0.12-pyh332efcf_0.conda + - conda: https://prefix.dev/conda-forge/noarch/matplotlib-inline-0.2.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/mkl-2025.3.0-hac47afa_454.conda - conda: https://prefix.dev/conda-forge/win-64/mkl-devel-2025.3.0-h57928b3_454.conda - conda: https://prefix.dev/conda-forge/win-64/mkl-include-2025.3.0-h57928b3_454.conda @@ -457,10 +494,13 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/openssl-3.6.0-h725018a_0.conda - conda: https://prefix.dev/conda-forge/win-64/optree-0.18.0-py313hf069bd2_0.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda + - conda: https://prefix.dev/conda-forge/noarch/parso-0.8.5-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/platformdirs-4.5.1-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://prefix.dev/conda-forge/noarch/pooch-1.8.2-pyhd8ed1ab_3.conda + - conda: https://prefix.dev/conda-forge/noarch/prompt-toolkit-3.0.52-pyha770c72_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-3.0.1-pyh7a1b43c_0.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-abi-11-hc364b38_1.conda - conda: https://prefix.dev/conda-forge/noarch/pybind11-global-3.0.1-pyh5e4992e_0.conda @@ -480,12 +520,14 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/sleef-3.9.0-h67fd636_0.conda - conda: https://prefix.dev/conda-forge/noarch/sortedcontainers-2.4.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/spin-0.15-pyha7b4d00_0.conda + - conda: https://prefix.dev/conda-forge/noarch/stack_data-0.6.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sympy-1.14.0-pyh04b8f61_5.conda - conda: https://prefix.dev/conda-forge/win-64/tbb-2022.3.0-hd094cb3_1.conda - conda: https://prefix.dev/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h2c6b04d_3.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.3.0-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.1.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/traitlets-5.14.3-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.15.0-h396c80c_0.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.15.0-pyhcf101f3_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda @@ -494,6 +536,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h2b53caa_33.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_33.conda - conda: https://prefix.dev/conda-forge/win-64/vcomp14-14.44.35208-h818238b_33.conda + - conda: https://prefix.dev/conda-forge/noarch/wcwidth-0.2.14-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/win_inet_pton-1.1.0-pyh7428d3b_8.conda - conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h6a83c73_3.conda - conda: https://prefix.dev/conda-forge/noarch/zipp-3.23.0-pyhcf101f3_1.conda diff --git a/pixi.toml b/pixi.toml index c0cea14c1718..36ed1bcf853e 100644 --- a/pixi.toml +++ b/pixi.toml @@ -26,7 +26,7 @@ solve-group = "default" [environments.ipython] # tasks: ipython -features = ["run-deps", "test-deps", "ipython"] +features = ["run-deps", "test-deps", "ipython-dep", "ipython-task"] solve-group = "default" [environments.bench] @@ -82,8 +82,8 @@ features = ["run-deps", "test-deps", "mkl", "torch-cpu", "torch-cpu-tasks"] solve-group = "array-api-cpu" [environments.array-api-cpu] -# tasks: test-cpu -features = ["run-deps", "test-deps", "test-cpu", "mkl", "array_api_strict", "dask", "jax-cpu", "marray", "torch-cpu"] +# tasks: test-cpu, ipython-cpu +features = ["run-deps", "test-deps", "test-cpu", "mkl", "array_api_strict", "dask", "jax-cpu", "marray", "torch-cpu", "ipython-dep", "ipython-cpu-task"] solve-group = "array-api-cpu" [environments.build-cuda] @@ -245,14 +245,20 @@ description = "Build the documentation" ### IPython ### -[feature.ipython.dependencies] +[feature.ipython-dep.dependencies] ipython = "*" -[feature.ipython.tasks.ipython] +[feature.ipython-task.tasks.ipython] cmd = "spin ipython --no-build" depends-on = "build" description = "Launch IPython" +[feature.ipython-cpu-task.tasks.ipython-cpu] +cmd = "spin ipython --build-dir=build-cpu --no-build" +depends-on = "build-cpu" +env.SCIPY_ARRAY_API = "1" +description = "Launch IPython" + ### Benchmarking ### diff --git a/scipy/_lib/_array_api.py b/scipy/_lib/_array_api.py index d051bb285e40..bdb229e49838 100644 --- a/scipy/_lib/_array_api.py +++ b/scipy/_lib/_array_api.py @@ -1034,3 +1034,7 @@ def xp_device_type(a: Array) -> Literal["cpu", "cuda", None]: return xp_device_type(a._meta) # array-api-strict is a stand-in for unknown libraries; don't special-case it return None + + +def xp_isscalar(x): + return np.isscalar(x) or (is_array_api_obj(x) and x.ndim == 0) diff --git a/scipy/sparse/_sputils.py b/scipy/sparse/_sputils.py index 326c68bab785..a1b9cb0a637e 100644 --- a/scipy/sparse/_sputils.py +++ b/scipy/sparse/_sputils.py @@ -369,14 +369,14 @@ def isintlike(x) -> bool: return True -def isshape(x, nonneg=False, *, allow_nd=(2,)) -> bool: +def isshape(x, nonneg=False, *, allow_nd=(2,), check_nd=True) -> bool: """Is x a valid tuple of dimensions? If nonneg, also checks that the dimensions are non-negative. Shapes of length in the tuple allow_nd are allowed. """ ndim = len(x) - if ndim not in allow_nd: + if check_nd and ndim not in allow_nd: return False for d in x: diff --git a/scipy/sparse/linalg/_eigen/_svds.py b/scipy/sparse/linalg/_eigen/_svds.py index f0591f8fe252..36dde26ab77f 100644 --- a/scipy/sparse/linalg/_eigen/_svds.py +++ b/scipy/sparse/linalg/_eigen/_svds.py @@ -34,6 +34,8 @@ def _iv(A, k, ncv, tol, which, v0, maxiter, if math.prod(A.shape) == 0: message = "`A` must not be empty." raise ValueError(message) + if len(A.shape) != 2: + raise ValueError("Only 2-D input is supported for `A` (a single matrix)") # input validation/standardization for `k` kmax = min(A.shape) if solver == 'propack' else min(A.shape) - 1 diff --git a/scipy/sparse/linalg/_eigen/arpack/arpack.py b/scipy/sparse/linalg/_eigen/arpack/arpack.py index a5927fd8473d..9e3f2aa3d3ac 100644 --- a/scipy/sparse/linalg/_eigen/arpack/arpack.py +++ b/scipy/sparse/linalg/_eigen/arpack/arpack.py @@ -1327,6 +1327,8 @@ def eigs(A, k=6, M=None, sigma=None, which='LM', v0=None, """ A = convert_pydata_sparse_to_scipy(A) + if (A_ndim := len(A.shape)) > 2: + raise ValueError(f"{A_ndim}-dimensional `A` is unsupported, expected 2-D.") M = convert_pydata_sparse_to_scipy(M) if A.shape[0] != A.shape[1]: raise ValueError(f'expected square matrix (shape={A.shape})') @@ -1664,6 +1666,8 @@ def eigsh(A, k=6, M=None, sigma=None, which='LM', v0=None, else: return ret.real + if (A_ndim := len(A.shape)) > 2: + raise ValueError(f"{A_ndim}-dimensional `A` is unsupported, expected 2-D.") if A.shape[0] != A.shape[1]: raise ValueError(f'expected square matrix (shape={A.shape})') if M is not None: diff --git a/scipy/sparse/linalg/_eigen/arpack/tests/test_arpack.py b/scipy/sparse/linalg/_eigen/arpack/tests/test_arpack.py index 560df435f2a9..8bdc5aa69590 100644 --- a/scipy/sparse/linalg/_eigen/arpack/tests/test_arpack.py +++ b/scipy/sparse/linalg/_eigen/arpack/tests/test_arpack.py @@ -688,3 +688,12 @@ def test_real_eigs_real_k_subset(): assert_allclose(dist, 0, atol=np.sqrt(eps)) prev_w = w + +@pytest.mark.parametrize("func", [eigs, eigsh]) +def test_nD(func): + """Check that >2-D operators are rejected cleanly.""" + def id(x): + return x + A = LinearOperator(shape=(2, 2, 2), matvec=id, dtype=np.float64) + with pytest.raises(ValueError, match="expected 2-D"): + func(A) diff --git a/scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py b/scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py index 63e8287c971c..69f9f86ec9b1 100644 --- a/scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py +++ b/scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py @@ -547,6 +547,7 @@ def lobpcg( raise ValueError( f"The shape {A.shape} of the primary matrix\n" f"defined by a callable object is wrong.\n" + f"Expected {(n, n)}." ) elif issparse(A): A = A.toarray() diff --git a/scipy/sparse/linalg/_eigen/tests/test_svds.py b/scipy/sparse/linalg/_eigen/tests/test_svds.py index c2b85d3dd5a3..7e7e98fc3d9e 100644 --- a/scipy/sparse/linalg/_eigen/tests/test_svds.py +++ b/scipy/sparse/linalg/_eigen/tests/test_svds.py @@ -133,7 +133,7 @@ class SVDSCommonTests: _A_empty_msg = "`A` must not be empty." _A_dtype_msg = "`A` must be of numeric data type" _A_type_msg = "type not understood" - _A_ndim_msg = "array must have ndim <= 2" + _A_ndim_msg = "Only 2-D input" _A_validation_inputs = [ (np.asarray([[]]), ValueError, _A_empty_msg), (np.array([['a', 'b'], ['c', 'd']], dtype='object'), ValueError, _A_dtype_msg), diff --git a/scipy/sparse/linalg/_interface.py b/scipy/sparse/linalg/_interface.py index 5f9cfdff5564..cb0e3d4f6365 100644 --- a/scipy/sparse/linalg/_interface.py +++ b/scipy/sparse/linalg/_interface.py @@ -47,19 +47,22 @@ import numpy as np +from scipy import sparse from scipy.sparse import issparse from scipy.sparse._sputils import isshape, isintlike, asmatrix, is_pydata_spmatrix +from scipy._lib._array_api import array_namespace, _asarray, is_lazy_array, np_compat, xp_copy, xp_isscalar +from scipy._lib import array_api_extra as xpx __all__ = ['LinearOperator', 'aslinearoperator'] class LinearOperator: - """Common interface for performing matrix vector products + """Common interface for performing matrix vector products. Many iterative methods (e.g. `cg`, `gmres`) do not need to know the individual entries of a matrix to solve a linear system ``A@x = b``. Such solvers only require the computation of matrix vector - products, ``A@v`` where ``v`` is a dense vector. This class serves as + products, ``A@v``, where ``v`` is a dense vector. This class serves as an abstract interface between iterative solvers and matrix-like objects. @@ -68,7 +71,8 @@ class LinearOperator: A subclass must implement either one of the methods ``_matvec`` and ``_matmat``, and the attributes/properties ``shape`` (pair of - integers) and ``dtype`` (may be None). It may call the ``__init__`` + integers, optionally with additional batch dimensions at the front) + and ``dtype`` (may be None). It may call the ``__init__`` on this class to have these attributes validated. Implementing ``_matvec`` automatically implements ``_matmat`` (using a naive algorithm) and vice-versa. @@ -80,20 +84,34 @@ class LinearOperator: ``_adjoint`` is preferable; ``_rmatvec`` is mostly there for backwards compatibility. + The defined operator may have additional "batch" dimensions + prepended to the core shape, to represent a batch of 2-D operators; + see :ref:`linalg_batch` for details. + TODO: check whether we need to add any caveats for broadcasting. + Parameters ---------- shape : tuple - Matrix dimensions ``(M, N)``. + Matrix dimensions ``(..., M, N)``, + where ``...`` represents any additional batch dimensions. matvec : callable f(v) - Returns returns ``A @ v``. + Returns ``A @ v``, where ``v`` is a dense vector + with shape ``(..., N)`` or ``(..., N, 1)``. + If `shape` contains batch dimensions, this must handle batched input. rmatvec : callable f(v) - Returns ``A^H @ v``, where ``A^H`` is the conjugate transpose of ``A``. + Returns ``A^H @ v``, where ``A^H`` is the conjugate transpose of ``A``, + and ``v`` is a dense vector of shape ``(..., M)`` or ``(..., M, 1)``. + If `shape` contains batch dimensions, this must handle batched input. matmat : callable f(V) - Returns ``A @ V``, where ``V`` is a dense matrix with dimensions ``(N, K)``. - dtype : dtype - Data type of the matrix. + Returns ``A @ V``, where ``V`` is a dense matrix + with dimensions ``(..., N, K)``. + If `shape` contains batch dimensions, this must handle batched input. rmatmat : callable f(V) - Returns ``A^H @ V``, where ``V`` is a dense matrix with dimensions ``(M, K)``. + Returns ``A^H @ V``, where ``V`` is a dense matrix + with dimensions ``(..., M, K)``. + If `shape` contains batch dimensions, this must handle batched input. + dtype : dtype + Data type of the matrix or matrices. Attributes ---------- @@ -101,36 +119,55 @@ class LinearOperator: For linear operators describing products etc. of other linear operators, the operands of the binary operation. ndim : int - Number of dimensions (this is always 2) + Number of dimensions (greater than 2 in the case of batch dimensions). + T : LinearOperator + Transpose. + H : LinearOperator + Hermitian adjoint. + + Methods + ------- + matvec + matmat + adjoint + transpose + rmatvec + rmatmat + dot + rdot + __mul__ + __matmul__ + __call__ + __add__ + __truediv__ + __rmul__ + __rmatmul__ See Also -------- - aslinearoperator : Construct LinearOperators + aslinearoperator : Construct a `LinearOperator`. Notes ----- The user-defined `matvec` function must properly handle the case - where ``v`` has shape ``(N,)`` as well as the ``(N,1)`` case. The shape of - the return type is handled internally by `LinearOperator`. + where ``v`` has shape ``(..., N)`` as well as the ``(..., N, 1)`` case. + The shape of the return type is handled internally by `LinearOperator`. It is highly recommended to explicitly specify the `dtype`, otherwise it is determined automatically at the cost of a single matvec application on ``int8`` zero vector using the promoted `dtype` of the output. - Python ``int`` could be difficult to automatically cast to numpy integers - in the definition of the `matvec` so the determination may be inaccurate. It is assumed that `matmat`, `rmatvec`, and `rmatmat` would result in the same dtype of the output given an ``int8`` input as `matvec`. - LinearOperator instances can also be multiplied, added with each - other and exponentiated, all lazily: the result of these operations - is always a new, composite LinearOperator, that defers linear + `LinearOperator` instances can also be multiplied, added with each + other, and exponentiated, all lazily: the result of these operations + is always a new, composite `LinearOperator`, that defers linear operations to the original operators and combines the results. - More details regarding how to subclass a LinearOperator and several - examples of concrete LinearOperator instances can be found in the + More details regarding how to subclass a `LinearOperator` and several + examples of concrete `LinearOperator` instances can be found in the external project `PyLops `_. - Examples -------- >>> import numpy as np @@ -148,13 +185,14 @@ class LinearOperator: """ - ndim = 2 # Necessary for right matmul with numpy arrays. __array_ufunc__ = None # generic type compatibility with scipy-stubs __class_getitem__ = classmethod(types.GenericAlias) + ndim: int + def __new__(cls, *args, **kwargs): if cls is LinearOperator: # Operate as _CustomLinearOperator factory. @@ -170,21 +208,26 @@ def __new__(cls, *args, **kwargs): return obj - def __init__(self, dtype, shape): + def __init__(self, dtype, shape, xp=None): """Initialize this LinearOperator. To be called by subclasses. ``dtype`` may be None; ``shape`` should - be convertible to a length-2 tuple. + be convertible to a length >=2 tuple. """ + xp = np_compat if xp is None else xp if dtype is not None: - dtype = np.dtype(dtype) + dtype = xp.empty(0, dtype=dtype).dtype shape = tuple(shape) - if not isshape(shape): - raise ValueError(f"invalid shape {shape!r} (must be 2-d)") + if len(shape) < 2: + raise ValueError(f"invalid shape {shape!r} (must be at least 2-d)") + if not is_lazy_array(xp.empty(0)) and not isshape(shape, check_nd=False): + raise ValueError(f"invalid shape {shape!r}") self.dtype = dtype self.shape = shape + self.ndim = len(shape) + self._xp = xp def _init_dtype(self): """Determine the dtype by executing `matvec` on an `int8` test vector. @@ -198,138 +241,203 @@ def _init_dtype(self): Called from subclasses at the end of the __init__ routine. """ if self.dtype is None: - v = np.zeros(self.shape[-1], dtype=np.int8) + batch_shape = self.shape[:-2] + N = self.shape[-1] + v = self._xp.zeros((*batch_shape, N), dtype=self._xp.int8) try: - matvec_v = np.asarray(self.matvec(v)) + matvec_v = self._xp.asarray(self.matvec(v)) except OverflowError: - # Python large `int` promoted to `np.int64`or `np.int32` - self.dtype = np.dtype(int) + self.dtype = xpx.default_dtype("integral") else: self.dtype = matvec_v.dtype def _matmat(self, X): """Default matrix-matrix multiplication handler. + + If self is a linear operator of shape ``(..., M, N)``, + then this method will be called on a shape ``(..., N, K)`` array, + and should return a shape ``(..., M, K)`` array. - Falls back on the user-defined _matvec method, so defining that will - define matrix multiplication (though in a very suboptimal way). + Falls back to `matvec`, so defining that will + define matrix multiplication too (though in a very suboptimal way). """ - - return np.hstack([self.matvec(col.reshape(-1,1)) for col in X.T]) + xp = self._xp + # Maintain backwards-compatibility for 1-D input + if X.ndim == 1: + X = X[xp.newaxis] + + # NOTE: we can't use `matvec` directly as we can't assume that user-defined + # `matvec` functions support batching. + return xp.concat( + [self.matvec(X[..., :, i, xp.newaxis]) for i in range(X.shape[-1])], + axis=-1 + ) def _matvec(self, x): """Default matrix-vector multiplication handler. - If self is a linear operator of shape (M, N), then this method will - be called on a shape (N,) or (N, 1) ndarray, and should return a - shape (M,) or (M, 1) ndarray. + If self is a linear operator of shape ``(..., M, N)``, + then this method will be called on a shape + ``(..., N)`` or ``(..., N, 1)`` array, + and should return a shape ``(..., M)`` or ``(..., M, 1)`` array. - This default implementation falls back on _matmat, so defining that + Falls back to `matmat`, so defining that will define matrix-vector multiplication as well. """ - return self.matmat(x.reshape(-1, 1)) + N = self.shape[-1] + # matmat expects 2-D core + if x.shape[-1] == N: + x = x[..., self._xp.newaxis] + return self.matmat(x) def matvec(self, x): """Matrix-vector multiplication. - Performs the operation y=A@x where A is an MxN linear - operator and x is a column vector or 1-d array. + Performs the operation ``A @ x`` where ``A`` is an ``M`` x ``N`` + linear operator (or batch of linear operators) + and `x` is a row or column vector (or batch of such vectors). Parameters ---------- x : {matrix, ndarray} - An array with shape (N,) or (N,1). + An array with shape ``(..., N)`` representing a row vector + (or batch of row vectors), + or an array with shape ``(..., N, 1)`` representing a column vector + (or batch of column vectors). Returns ------- y : {matrix, ndarray} - A matrix or ndarray with shape (M,) or (M,1) depending - on the type and shape of the x argument. + An array with shape ``(..., M)`` or ``(..., M, 1)`` depending + on the type and shape of `x`. Notes ----- - This matvec wraps the user-specified matvec routine or overridden - _matvec method to ensure that y has the correct shape and type. + This method wraps the user-specified ``matvec`` routine or overridden + ``_matvec`` method to ensure that `y` has the correct shape and type. + + If ``x.shape[-1] == N``, `x` is interpreted as a row vector + (or batch of row vectors if there are leading batch dimensions). + Otherwise, if ``x.shape[-2:] == (N, 1)``, + `x` is interpreted as a column vector + (or batch of column vectors if there are leading batch dimensions). """ - - x = np.asanyarray(x) - - M,N = self.shape - - if x.shape != (N,) and x.shape != (N,1): - raise ValueError('dimension mismatch') + xp = self._xp + + x = _asarray(x, subok=True, xp=xp) + + *self_broadcast_dims, M, N = self.shape + + x_broadcast_dims: tuple[int, ...] = () + row_vector: bool = False + if x.ndim >= 1 and (row_vector := x.shape[-1] == N): + x_broadcast_dims = x.shape[:-1] + if column_vector := x.shape[-2:] == (N, 1): + x_broadcast_dims = x.shape[:-2] + if not (row_vector or column_vector): + msg = ( + f"Dimension mismatch: `x` must have a shape ending in " + f"`({N},)` or `({N}, 1)`. Given shape: {x.shape}" + ) + raise ValueError(msg) y = self._matvec(x) if isinstance(x, np.matrix): y = asmatrix(y) else: - y = np.asarray(y) + y = xp.asarray(y) - if x.ndim == 1: - y = y.reshape(M) - elif x.ndim == 2: - y = y.reshape(M,1) - else: - raise ValueError('invalid shape returned by user-defined matvec()') + broadcasted_dims = xpx.broadcast_shapes(self_broadcast_dims, x_broadcast_dims) + if row_vector: + y = xp.reshape(y, (*broadcasted_dims, M)) + elif column_vector: + y = y.reshape(*broadcasted_dims, M, 1) return y def rmatvec(self, x): """Adjoint matrix-vector multiplication. - Performs the operation y = A^H @ x where A is an MxN linear - operator and x is a column vector or 1-d array. + Performs the operation ``A^H @ x`` where ``A`` is an + ``M`` x ``N`` linear operator (or batch of linear operators) + and `x` is a row or column vector (or batch of such vectors). Parameters ---------- x : {matrix, ndarray} - An array with shape (M,) or (M,1). + An array with shape ``(..., M)`` representing a row vector + (or batch of row vectors), + or an array with shape ``(..., M, 1)`` representing a column vector + (or batch of column vectors). Returns ------- y : {matrix, ndarray} - A matrix or ndarray with shape (N,) or (N,1) depending - on the type and shape of the x argument. + An array with shape ``(..., N)`` or ``(..., N, 1)`` depending + on the type and shape of `x`. Notes ----- - This rmatvec wraps the user-specified rmatvec routine or overridden - _rmatvec method to ensure that y has the correct shape and type. + This method wraps the user-specified ``rmatvec`` routine or overridden + ``_rmatvec`` method to ensure that `y` has the correct shape and type. + + If ``x.shape[-1] == M``, `x` is interpreted as a row vector + (or batch of row vectors if there are leading batch dimensions). + Otherwise, if ``x.shape[-2:] == (M, 1)``, + `x` is interpreted as a column vector + (or batch of column vectors if there are leading batch dimensions). """ - - x = np.asanyarray(x) - - M,N = self.shape - - if x.shape != (M,) and x.shape != (M,1): - raise ValueError('dimension mismatch') + xp = self._xp + x = _asarray(x, subok=True, xp=xp) + + *self_broadcast_dims, M, N = self.shape + + x_broadcast_dims: tuple[int, ...] = () + row_vector: bool = False + if x.ndim >= 1 and (row_vector := x.shape[-1] == M): + x_broadcast_dims = x.shape[:-1] + if column_vector := x.shape[-2:] == (M, 1): + x_broadcast_dims = x.shape[:-2] + if not (row_vector or column_vector): + msg = ( + f"Dimension mismatch: `x` must have a shape ending in " + f"`({M},)` or `({M}, 1)`. Given shape: {x.shape}" + ) + raise ValueError(msg) y = self._rmatvec(x) if isinstance(x, np.matrix): y = asmatrix(y) else: - y = np.asarray(y) + y = xp.asarray(y) - if x.ndim == 1: - y = y.reshape(N) - elif x.ndim == 2: - y = y.reshape(N,1) - else: - raise ValueError('invalid shape returned by user-defined rmatvec()') + broadcasted_dims = xpx.broadcast_shapes(self_broadcast_dims, x_broadcast_dims) + if row_vector: + y = xp.reshape(y, (*broadcasted_dims, N)) + elif column_vector: + y = xp.reshape(y, (*broadcasted_dims, N, 1)) return y def _rmatvec(self, x): - """Default implementation of _rmatvec; defers to adjoint.""" + """Default implementation of `_rmatvec`. + Defers to `_rmatmat` or `adjoint`.""" if type(self)._adjoint == LinearOperator._adjoint: # _adjoint not overridden, prevent infinite recursion if (hasattr(self, "_rmatmat") and type(self)._rmatmat != LinearOperator._rmatmat): + xp = self._xp # Try to use _rmatmat as a fallback - return self._rmatmat(x.reshape(-1, 1)).reshape(-1) + M = self.shape[-2] + # _rmatmat expects 2-D core + if x.shape[-1] == M: + return xp.reshape(self._rmatmat(x[..., xp.newaxis]), *x.shape) + else: + return self._rmatmat(x) raise NotImplementedError else: return self.H.matvec(x) @@ -337,33 +445,36 @@ def _rmatvec(self, x): def matmat(self, X): """Matrix-matrix multiplication. - Performs the operation y=A@X where A is an MxN linear - operator and X dense N*K matrix or ndarray. + Performs the operation ``A @ X`` where ``A`` is an ``M`` x ``N`` + linear operator (or batch of linear operators) + and ``X`` is a dense ``N`` x ``K`` matrix + (or batch of dense matrices). Parameters ---------- X : {matrix, ndarray} - An array with shape (N,K). + An array with shape ``(..., N, K)`` representing the dense matrix + (or batch of dense matrices). Returns ------- Y : {matrix, ndarray} - A matrix or ndarray with shape (M,K) depending on - the type of the X argument. + An array with shape ``(..., M, K)`` depending on + the type of `X`. Notes ----- - This matmat wraps any user-specified matmat routine or overridden - _matmat method to ensure that y has the correct type. + This method wraps any user-specified ``matmat`` routine or overridden + ``_matmat`` method to ensure that `Y` has the correct type. """ if not (issparse(X) or is_pydata_spmatrix(X)): - X = np.asanyarray(X) + X = _asarray(X, subok=True, xp=self._xp) - if X.ndim != 2: - raise ValueError(f'expected 2-d ndarray or matrix, not {X.ndim}-d') + if X.ndim < 2: + raise ValueError(f'expected at least 2-d ndarray or matrix, not {X.ndim}-d') - if X.shape[0] != self.shape[1]: + if X.shape[-2] != self.shape[-1]: raise ValueError(f'dimension mismatch: {self.shape}, {X.shape}') try: @@ -384,32 +495,36 @@ def matmat(self, X): def rmatmat(self, X): """Adjoint matrix-matrix multiplication. - Performs the operation y = A^H @ x where A is an MxN linear - operator and x is a column vector or 1-d array, or 2-d array. + Performs the operation ``A^H @ X`` where ``A`` is an ``M`` x ``N`` + linear operator (or batch of linear operators) + and `X` is a dense ``M`` x ``K`` matrix + (or batch of dense matrices). The default implementation defers to the adjoint. Parameters ---------- X : {matrix, ndarray} - A matrix or 2D array. + An array with shape ``(..., M, K)`` representing the dense matrix + (or batch of dense matrices). Returns ------- Y : {matrix, ndarray} - A matrix or 2D array depending on the type of the input. + An array with shape ``(..., N, K)`` depending on the type of `X`. Notes ----- - This rmatmat wraps the user-specified rmatmat routine. + This method wraps any user-specified ``rmatmat`` routine or overridden + ``_rmatmat`` method to ensure that `Y` has the correct type. """ if not (issparse(X) or is_pydata_spmatrix(X)): - X = np.asanyarray(X) + X = _asarray(X, subok=True, xp=self._xp) - if X.ndim != 2: - raise ValueError(f'expected 2-d ndarray or matrix, not {X.ndim}-d') + if X.ndim < 2: + raise ValueError(f'expected at least 2-d ndarray or matrix, not {X.ndim}-d') - if X.shape[0] != self.shape[0]: + if X.shape[-2] != self.shape[-2]: raise ValueError(f'dimension mismatch: {self.shape}, {X.shape}') try: @@ -427,179 +542,398 @@ def rmatmat(self, X): return Y def _rmatmat(self, X): - """Default implementation of _rmatmat defers to rmatvec or adjoint.""" + """Default implementation of `_rmatmat`; defers to `rmatvec` or `adjoint`.""" if type(self)._adjoint == LinearOperator._adjoint: - return np.hstack([self.rmatvec(col.reshape(-1, 1)) for col in X.T]) + xp = self._xp + # Maintain backwards-compatibility for 1-D input + if X.ndim == 1: + X = X[xp.newaxis] + + # NOTE: we can't use `rmatvec` directly as we can't assume that user-defined + # `rmatvec` functions support batching. + return xp.concat( + [self.rmatvec(X[..., :, i, xp.newaxis]) for i in range(X.shape[-1])], + axis=-1 + ) else: return self.H.matmat(X) def __call__(self, x): - return self@x + """Apply this linear operator. + + Equivalent to `__matmul__`. + """ + return self @ x def __mul__(self, x): + """Multiplication. + + Used by the ``*`` operator. Equivalent to `dot`. + """ return self.dot(x) def __truediv__(self, other): - if not np.isscalar(other): + """Scalar Division. + + Returns a lazily scaled linear operator. + """ + if not xp_isscalar(other): raise ValueError("Can only divide a linear operator by a scalar.") - return _ScaledLinearOperator(self, 1.0/other) + return _ScaledLinearOperator(self, 1.0/other, xp=self._xp) def dot(self, x): - """Matrix-matrix or matrix-vector multiplication. + """Multi-purpose multiplication method. Parameters ---------- - x : array_like - 1-d or 2-d array, representing a vector or matrix. + x : array_like or `LinearOperator` or scalar + Array-like input will be interpreted as a dense column vector, + matrix, or row vector (or batch of such vectors or matrices) + depending on its shape. See the Returns section for details. Returns ------- - Ax : array - 1-d or 2-d array (depending on the shape of x) that represents - the result of applying this linear operator on x. + Ax : array or `LinearOperator` + - For `LinearOperator` input, operator composition is performed. + + - For scalar input, a lazily scaled operator is returned. + + - Otherwise, the input is expected to take the form of a dense + vector or matrix (or batch of such vectors or matrices), + interpreted as follows + (where ``self`` is an ``M`` by ``N`` linear operator): + + - If `x` has shape ``(N, 1)`` + it is interpreted as a column vector + and `matvec` is called. + - Otherwise, if `x` has shape ``(..., N, K)`` for some + integer ``K``, it is interpreted as a matrix + (or batch of matrices if there are batch dimensions) + and `matmat` is called. + - Otherwise, if `x` has shape ``(..., N)``, + it is interpreted as a row vector + (or batch of row vectors if there are batch dimensions) + and `matvec` is called. + + .. warning :: + + `x` of shape ``(..., N, N)`` will be interpreted as + a batch of matrices of shape ``(N, N)``. + If `x` is intended to be a batch of row vectors of + where the batch shape happens to end in ``N``, please + use the `.matvec` method instead. + + Notes + ----- + For clarity, it is recommended to use the `matvec` or + `matmat` methods directly instead of this method + when interacting with dense vectors and matrices. + + See Also + -------- + __mul__ : Equivalent method used by the ``*`` operator. + __matmul__ : + Method used by the ``@`` operator which rejects scalar + input before calling this method. """ if isinstance(x, LinearOperator): - return _ProductLinearOperator(self, x) - elif np.isscalar(x): - return _ScaledLinearOperator(self, x) + if (xp_x := getattr(x, "_xp", np)) != self._xp: + msg = ( + f"Mismatched array namespaces." + f"Namespace for self is {self._xp}, namespace for x is {xp_x}" + ) + raise TypeError(msg) + return _ProductLinearOperator(self, x, self._xp) + elif xp_isscalar(x): + if (xp_x := array_namespace(x, self._xp.empty(0))) != self._xp: + msg = ( + f"Mismatched array namespaces." + f"Namespace for self is {self._xp}, namespace for x is {xp_x}" + ) + raise TypeError(msg) + return _ScaledLinearOperator(self, x, self._xp) else: if not issparse(x) and not is_pydata_spmatrix(x): - # Sparse matrices shouldn't be converted to numpy arrays. - x = np.asarray(x) - - if x.ndim == 1 or x.ndim == 2 and x.shape[1] == 1: + x = self._xp.asarray(x) + + N = self.shape[-1] + + # maintain column vector backwards-compatibility in 2-D case + column_vector = x.shape == (N, 1) + # otherwise, treat input as a matrix if the shape fits + matrix = x.ndim >= 2 and x.shape[-2] == N + # otherwise, treat as a row-vector + row_vector = x.shape[-1] == N + + # NOTE: for `x.ndim > 2`, `np.dot(a, b)` implements different semantics: + # sum product over the last axis of `a` and the second-to-last axis of `b`. + + if not (row_vector or column_vector or matrix): + msg = ( + f"Dimension mismatch: `x` must have a shape ending in " + f"`({N},)` or `({N}, 1)` or `({N}, K)` for some integer `K`. " + f"Given shape: {x.shape}" + ) + raise ValueError(msg) + + if column_vector: return self.matvec(x) - elif x.ndim == 2: + elif matrix: return self.matmat(x) - else: - raise ValueError(f'expected 1-d or 2-d array or matrix, got {x!r}') + elif row_vector: + return self.matvec(x) def __matmul__(self, other): - if np.isscalar(other): + """Matrix Multiplication. + + Used by the ``@`` operator. + Rejects scalar input. + Otherwise, equivalent to `dot`. + """ + if xp_isscalar(other): raise ValueError("Scalar operands are not allowed, " "use '*' instead") return self.__mul__(other) def __rmatmul__(self, other): - if np.isscalar(other): + """Matrix Multiplication from the right. + + Used by the ``@`` operator from the right. + Rejects scalar input. + Otherwise, equivalent to `rdot`. + """ + if xp_isscalar(other): raise ValueError("Scalar operands are not allowed, " "use '*' instead") return self.__rmul__(other) def __rmul__(self, x): - if np.isscalar(x): - return _ScaledLinearOperator(self, x) - else: - return self._rdot(x) - - def _rdot(self, x): - """Matrix-matrix or matrix-vector multiplication from the right. - + """Multi-purpose multiplication method from the right. + + Used by the ``*`` operator from the right. Equivalent to `rdot`. + """ + return self.rdot(x) + + def rdot(self, x): + """Multi-purpose multiplication method from the right. + + .. note :: + + For complex data, this does not perform conjugation, + returning ``xA`` rather than ``x A^H``. + To calculate adjoint multiplication instead, use one of + `rmatvec` or `rmatmat`, or take the adjoint first, + like ``self.H.rdot(x)`` or ``x * self.H``. + + .. note :: + + Array-like input to this function is unsupported for linear operators + with batch shapes. + It is recommended to transpose data separately + and then use forward operations like `matvec` and `matmat` directly. + Parameters ---------- - x : array_like - 1-d or 2-d array, representing a vector or matrix. + x : array_like or `LinearOperator` or scalar + Array-like input will be interpreted as a dense row vector, + matrix, or column vector, depending on its shape. + See the Returns section for details. Returns ------- - xA : array - 1-d or 2-d array (depending on the shape of x) that represents - the result of applying this linear operator on x from the right. - - Notes - ----- - This is copied from dot to implement right multiplication. + xA : array or `LinearOperator` + - For `LinearOperator` input, operator composition is performed. + + - For scalar input, a lazily scaled operator is returned. + + - Otherwise, the input is expected to take the form of a dense + vector or matrix, interpreted as follows + (where ``self`` is an ``M`` by ``N`` linear operator): + + - If `x` has shape ``(M,)`` + it is interpreted as a row vector. + - Otherwise, if `x` has shape ``(1, M)`` + it is interpreted as a column vector. + - Otherwise, if `x` has shape ``(K, M)`` for some + integer ``K``, it is interpreted as a matrix. + + See Also + -------- + dot : Multi-purpose multiplication method from the left. + __rmul__ : + Equivalent method, used by the ``*`` operator from the right. + __rmatmul__ : + Method used by the ``@`` operator from the right + which rejects scalar input before calling this method. """ if isinstance(x, LinearOperator): - return _ProductLinearOperator(x, self) - elif np.isscalar(x): - return _ScaledLinearOperator(self, x) + if (xp_x := getattr(x, "_xp", np)) != self._xp: + msg = ( + f"Mismatched array namespaces." + f"Namespace for self is {self._xp}, namespace for x is {xp_x}" + ) + raise TypeError(msg) + return _ProductLinearOperator(x, self, self._xp) + elif xp_isscalar(x): + if (xp_x := array_namespace(x, self._xp.empty(0))) != self._xp: + msg = ( + f"Mismatched array namespaces." + f"Namespace for self is {self._xp}, namespace for x is {xp_x}" + ) + raise TypeError(msg) + return _ScaledLinearOperator(self, x, self._xp) else: + if self.ndim > 2: + msg = ( + "Array-like input is unsupported in `rdot` for batched" + "operators (with `ndim > 2`).\n" + "It is recommended to transpose data separately and then" + "use forward operations like `matvec` and `matmat` directly." + ) + raise ValueError(msg) if not issparse(x) and not is_pydata_spmatrix(x): - # Sparse matrices shouldn't be converted to numpy arrays. - x = np.asarray(x) + x = self._xp.asarray(x) + + M = self.shape[-2] + + # treat 1-D input as a row-vector + row_vector = x.shape == (M,) + # maintain column vector backwards-compatibility in 2-D case + column_vector = x.shape == (1, M) + # otherwise, treat input as a matrix if the shape fits + matrix = x.ndim == 2 and x.shape[-1] == M + + if not (row_vector or column_vector or matrix): + msg = ( + f"Dimension mismatch: `x` must have shape `({M},)`, " + f"`(1, {M})`, or `(K, {M})` for some integer `K`. " + f"Given shape: {x.shape}" + ) + raise ValueError(msg) # We use transpose instead of rmatvec/rmatmat to avoid # unnecessary complex conjugation if possible. - if x.ndim == 1 or x.ndim == 2 and x.shape[0] == 1: + if row_vector or column_vector: return self.T.matvec(x.T).T - elif x.ndim == 2: + elif matrix: return self.T.matmat(x.T).T - else: - raise ValueError(f'expected 1-d or 2-d array or matrix, got {x!r}') def __pow__(self, p): - if np.isscalar(p): - return _PowerLinearOperator(self, p) + if xp_isscalar(p): + if (xp_p := array_namespace(p, self._xp.empty(0))) != self._xp: + msg = ( + f"Mismatched array namespaces." + f"Namespace for self is {self._xp}, namespace for p is {xp_p}" + ) + raise TypeError(msg) + return _PowerLinearOperator(self, p, self._xp) else: return NotImplemented def __add__(self, x): + """Linear operator addition. + + The input must be a `LinearOperator`. + A lazily summed linear operator is returned. + """ if isinstance(x, LinearOperator): - return _SumLinearOperator(self, x) + if (xp_x := getattr(x, "_xp", np)) != self._xp: + msg = ( + f"Mismatched array namespaces." + f"Namespace for self is {self._xp}, namespace for x is {xp_x}" + ) + raise TypeError(msg) + return _SumLinearOperator(self, x, xp=self._xp) else: return NotImplemented def __neg__(self): - return _ScaledLinearOperator(self, -1) + return _ScaledLinearOperator(self, -1, xp=self._xp) def __sub__(self, x): return self.__add__(-x) def __repr__(self): - M,N = self.shape if self.dtype is None: dt = 'unspecified dtype' else: dt = 'dtype=' + str(self.dtype) - return f'<{M}x{N} {self.__class__.__name__} with {dt}>' + shape = 'x'.join(str(dim) for dim in self.shape) + return f'<{shape} {self.__class__.__name__} with {dt}>' def adjoint(self): """Hermitian adjoint. - Returns the Hermitian adjoint of self, aka the Hermitian + Returns the Hermitian adjoint of this linear operator, + also known as the Hermitian conjugate or Hermitian transpose. For a complex matrix, the Hermitian adjoint is equal to the conjugate transpose. - Can be abbreviated self.H instead of self.adjoint(). - Returns ------- - A_H : LinearOperator + `LinearOperator` Hermitian adjoint of self. + + See Also + -------- + :attr:`~scipy.sparse.linalg.LinearOperator.H` : Equivalent attribute. """ return self._adjoint() - H = property(adjoint) + @property + def H(self): + """Hermitian adjoint. + + See Also + -------- + scipy.sparse.linalg.LinearOperator.adjoint : Equivalent method. + """ + return self.adjoint() def transpose(self): - """Transpose this linear operator. + """Transpose. - Returns a LinearOperator that represents the transpose of this one. - Can be abbreviated self.T instead of self.transpose(). + Returns + ------- + `LinearOperator` + Transposition of this linear operator. + + See Also + -------- + :attr:`~scipy.sparse.linalg.LinearOperator.T` : Equivalent attribute. """ return self._transpose() - T = property(transpose) + @property + def T(self): + """Transpose. + + See Also + -------- + scipy.sparse.linalg.LinearOperator.transpose : Equivalent method. + """ + return self.transpose() def _adjoint(self): - """Default implementation of _adjoint; defers to rmatvec.""" - return _AdjointLinearOperator(self) + """Default implementation of `_adjoint`. + Defers to adjoint functions, e.g. `_rmatvec` for `_matvec`.""" + return _AdjointLinearOperator(self, self._xp) def _transpose(self): - """ Default implementation of _transpose; defers to rmatvec + conj""" - return _TransposedLinearOperator(self) + """Default implementation of `_transpose`. + For `_matvec`, defers to `_rmatvec` + `np.conj`.""" + return _TransposedLinearOperator(self, self._xp) class _CustomLinearOperator(LinearOperator): """Linear operator defined in terms of user-specified operations.""" def __init__(self, shape, matvec, rmatvec=None, matmat=None, - dtype=None, rmatmat=None): - super().__init__(dtype, shape) + dtype=None, rmatmat=None, xp=None): + super().__init__(dtype, shape, xp) self.args = () @@ -632,20 +966,23 @@ def _rmatmat(self, X): return super()._rmatmat(X) def _adjoint(self): - return _CustomLinearOperator(shape=(self.shape[1], self.shape[0]), - matvec=self.__rmatvec_impl, - rmatvec=self.__matvec_impl, - matmat=self.__rmatmat_impl, - rmatmat=self.__matmat_impl, - dtype=self.dtype) + return _CustomLinearOperator( + shape=(*self.shape[:-2], self.shape[-1], self.shape[-2]), + matvec=self.__rmatvec_impl, + rmatvec=self.__matvec_impl, + matmat=self.__rmatmat_impl, + rmatmat=self.__matmat_impl, + dtype=self.dtype, + xp=self._xp + ) class _AdjointLinearOperator(LinearOperator): """Adjoint of arbitrary Linear Operator""" - def __init__(self, A): - shape = (A.shape[1], A.shape[0]) - super().__init__(dtype=A.dtype, shape=shape) + def __init__(self, A, xp=None): + shape = (*A.shape[:-2], A.shape[-1], A.shape[-2]) + super().__init__(A.dtype, shape, xp) self.A = A self.args = (A,) @@ -664,44 +1001,49 @@ def _rmatmat(self, x): class _TransposedLinearOperator(LinearOperator): """Transposition of arbitrary Linear Operator""" - def __init__(self, A): - shape = (A.shape[1], A.shape[0]) - super().__init__(dtype=A.dtype, shape=shape) + def __init__(self, A, xp=None): + shape = (*A.shape[:-2], A.shape[-1], A.shape[-2]) + super().__init__(A.dtype, shape, xp) self.A = A self.args = (A,) + self._xp = A._xp def _matvec(self, x): - # NB. np.conj works also on sparse matrices - return np.conj(self.A._rmatvec(np.conj(x))) + return self._xp.conj(self.A._rmatvec(self._xp.conj(x))) def _rmatvec(self, x): - return np.conj(self.A._matvec(np.conj(x))) + return self._xp.conj(self.A._matvec(self._xp.conj(x))) def _matmat(self, x): - # NB. np.conj works also on sparse matrices - return np.conj(self.A._rmatmat(np.conj(x))) + return self._xp.conj(self.A._rmatmat(self._xp.conj(x))) def _rmatmat(self, x): - return np.conj(self.A._matmat(np.conj(x))) + return self._xp.conj(self.A._matmat(self._xp.conj(x))) + -def _get_dtype(operators, dtypes=None): +def _get_dtype(operators, dtypes=None, xp=None): + """Returns the promoted dtype from input dtypes and operators.""" + xp = np_compat if xp is None else xp if dtypes is None: dtypes = [] for obj in operators: if obj is not None and hasattr(obj, 'dtype'): dtypes.append(obj.dtype) - return np.result_type(*dtypes) + return xp.result_type(*dtypes) class _SumLinearOperator(LinearOperator): - def __init__(self, A, B): - if not isinstance(A, LinearOperator) or \ - not isinstance(B, LinearOperator): + """Representing ``A + B``""" + def __init__(self, A, B, xp=None): + if not isinstance(A, LinearOperator) or not isinstance(B, LinearOperator): raise ValueError('both operands have to be a LinearOperator') - if A.shape != B.shape: + *A_broadcast_dims, A_M, A_N = A.shape + *B_broadcast_dims, B_M, B_N = B.shape + if (A_M, A_N) != (B_M, B_N): raise ValueError(f'cannot add {A} and {B}: shape mismatch') + broadcasted_dims = xp.broadcast_shapes(A_broadcast_dims, B_broadcast_dims) self.args = (A, B) - super().__init__(_get_dtype([A, B]), A.shape) + super().__init__(_get_dtype([A, B]), (*broadcasted_dims, A_M, A_N), xp) def _matvec(self, x): return self.args[0].matvec(x) + self.args[1].matvec(x) @@ -721,14 +1063,16 @@ def _adjoint(self): class _ProductLinearOperator(LinearOperator): - def __init__(self, A, B): - if not isinstance(A, LinearOperator) or \ - not isinstance(B, LinearOperator): + """Representing ``A @ B``""" + def __init__(self, A, B, xp=None): + if not isinstance(A, LinearOperator) or not isinstance(B, LinearOperator): raise ValueError('both operands have to be a LinearOperator') - if A.shape[1] != B.shape[0]: + *A_broadcast_dims, A_M, A_N = A.shape + *B_broadcast_dims, B_M, B_N = B.shape + if A_N != B_M: raise ValueError(f'cannot multiply {A} and {B}: shape mismatch') - super().__init__(_get_dtype([A, B]), - (A.shape[0], B.shape[1])) + broadcasted_dims = np.broadcast_shapes(A_broadcast_dims, B_broadcast_dims) + super().__init__(_get_dtype([A, B]), (*broadcasted_dims, A_M, B_N), xp) self.args = (A, B) def _matvec(self, x): @@ -749,7 +1093,8 @@ def _adjoint(self): class _ScaledLinearOperator(LinearOperator): - def __init__(self, A, alpha): + """Representing ``alpha * A``""" + def __init__(self, A, alpha, xp=None): if not isinstance(A, LinearOperator): raise ValueError('LinearOperator expected as A') if not np.isscalar(alpha): @@ -761,7 +1106,7 @@ def __init__(self, A, alpha): alpha = alpha * alpha_original dtype = _get_dtype([A], [type(alpha)]) - super().__init__(dtype, A.shape) + super().__init__(dtype, A.shape, xp) self.args = (A, alpha) # Note: args[1] is alpha (a scalar), so use `*` below, not `@` @@ -769,33 +1114,35 @@ def _matvec(self, x): return self.args[1] * self.args[0].matvec(x) def _rmatvec(self, x): - return np.conj(self.args[1]) * self.args[0].rmatvec(x) + return self._xp.conj(self.args[1]) * self.args[0].rmatvec(x) def _rmatmat(self, x): - return np.conj(self.args[1]) * self.args[0].rmatmat(x) + return self._xp.conj(self.args[1]) * self.args[0].rmatmat(x) def _matmat(self, x): return self.args[1] * self.args[0].matmat(x) def _adjoint(self): A, alpha = self.args - return A.H * np.conj(alpha) + return A.H * self._xp.conj(alpha) class _PowerLinearOperator(LinearOperator): - def __init__(self, A, p): + """Representing ``A ** p``""" + def __init__(self, A, p, xp=None): if not isinstance(A, LinearOperator): raise ValueError('LinearOperator expected as A') - if A.shape[0] != A.shape[1]: - raise ValueError(f'square LinearOperator expected, got {A!r}') + if A.shape[-2] != A.shape[-1]: + msg = f'square core-dimensions of LinearOperator expected, got {A!r}' + raise ValueError(msg) if not isintlike(p) or p < 0: raise ValueError('non-negative integer expected as p') - super().__init__(_get_dtype([A]), A.shape) + super().__init__(_get_dtype([A]), A.shape, xp) self.args = (A, p) def _power(self, fun, x): - res = np.array(x, copy=True) + res = xp_copy(x) for i in range(self.args[1]): res = fun(res) return res @@ -818,38 +1165,52 @@ def _adjoint(self): class MatrixLinearOperator(LinearOperator): - def __init__(self, A): - super().__init__(A.dtype, A.shape) + """Operator defined by a matrix `A` which implements ``@``.""" + def __init__(self, A, xp=None): + super().__init__(A.dtype, A.shape, xp) self.A = A self.__adj = None self.args = (A,) def _matmat(self, X): - return self.A.dot(X) + return self.A @ X def _adjoint(self): if self.__adj is None: - self.__adj = _AdjointMatrixOperator(self.A) + self.__adj = _AdjointMatrixOperator(self.A, self._xp) return self.__adj class _AdjointMatrixOperator(MatrixLinearOperator): - def __init__(self, adjoint_array): - self.A = adjoint_array.T.conj() - self.args = (adjoint_array,) - self.shape = adjoint_array.shape[1], adjoint_array.shape[0] + """Representing ``A.H``, for `MatrixLinearOperator` `A`.""" + def __init__(self, A, xp=None): + xp = np_compat if xp is None else xp + if A.ndim > 2: + if issparse(A): + A_T = sparse.swapaxes(A, -1, -2) + else: + A_T = A.mT + else: + A_T = A.T + self.A = xp.conj(A_T) + self.args = (A,) + self.shape = ( + *A.shape[:-2], A.shape[-1], A.shape[-2] + ) + self.ndim = A.ndim + self._xp = xp @property def dtype(self): return self.args[0].dtype def _adjoint(self): - return MatrixLinearOperator(self.args[0]) + return MatrixLinearOperator(self.args[0], self._xp) class IdentityOperator(LinearOperator): - def __init__(self, shape, dtype=None): - super().__init__(dtype, shape) + def __init__(self, shape, dtype=None, xp=None): + super().__init__(dtype, shape, xp) def _matvec(self, x): return x @@ -868,21 +1229,22 @@ def _adjoint(self): def aslinearoperator(A): - """Return A as a LinearOperator. + """Return `A` as a `LinearOperator`. - 'A' may be any of the following types: - - ndarray - - matrix - - sparse array (e.g. csr_array, lil_array, etc.) - - LinearOperator - - An object with .shape and .matvec attributes + `A` may be any of the following types: + - `numpy.ndarray` + - `numpy.matrix` + - `scipy.sparse` array + (e.g. `~scipy.sparse.csr_array`, `~scipy.sparse.lil_array`, etc.) + - `LinearOperator` + - An object with ``.shape`` and ``.matvec`` attributes - See the LinearOperator documentation for additional information. + See the `LinearOperator` documentation for additional information. Notes ----- - If 'A' has no .dtype attribute, the data type is determined by calling - :func:`LinearOperator.matvec()` - set the .dtype attribute to prevent this + If `A` has no ``.dtype`` attribute, the data type is determined by calling + :func:`LinearOperator.matvec()` - set the ``.dtype`` attribute to prevent this call upon the linear operator creation. Examples @@ -893,19 +1255,32 @@ def aslinearoperator(A): >>> aslinearoperator(M) <2x3 MatrixLinearOperator with dtype=int32> """ + A, _ = _xp_aslinearoperator(A) + return A + +def _xp_aslinearoperator(A): + """ + Return `A` as a linear operator, + as well as a compatible array namespace `xp` for `A`. + Fallback to NumPy for unknown types. + """ if isinstance(A, LinearOperator): - return A + return A, getattr(A, "_xp", np_compat) - elif isinstance(A, np.ndarray) or isinstance(A, np.matrix): - if A.ndim > 2: - raise ValueError('array must have ndim <= 2') + elif issparse(A): + return MatrixLinearOperator(A), np_compat + + elif isinstance(A, np.matrix): A = np.atleast_2d(np.asarray(A)) - return MatrixLinearOperator(A) - - elif issparse(A) or is_pydata_spmatrix(A): - return MatrixLinearOperator(A) - - else: + return MatrixLinearOperator(A), np_compat + + try: + xp = array_namespace(A) + xp = np_compat if xp is sparse else xp + A = xpx.atleast_nd(A, ndim=2, xp=xp) + return MatrixLinearOperator(A, xp=xp), xp + + except: if hasattr(A, 'shape') and hasattr(A, 'matvec'): rmatvec = None rmatmat = None @@ -917,8 +1292,10 @@ def aslinearoperator(A): rmatmat = A.rmatmat if hasattr(A, 'dtype'): dtype = A.dtype + xp = array_namespace(A) or np_compat + xp = np_compat if xp is sparse else xp return LinearOperator(A.shape, A.matvec, rmatvec=rmatvec, - rmatmat=rmatmat, dtype=dtype) + rmatmat=rmatmat, dtype=dtype, xp=xp), xp else: raise TypeError('type not understood') diff --git a/scipy/sparse/linalg/_isolve/_gcrotmk.py b/scipy/sparse/linalg/_isolve/_gcrotmk.py index 1b053919c101..ccd0cf1cb2ec 100644 --- a/scipy/sparse/linalg/_isolve/_gcrotmk.py +++ b/scipy/sparse/linalg/_isolve/_gcrotmk.py @@ -273,7 +273,7 @@ def gcrotmk(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=1000, M=None, callback True """ - A,M,x,b = make_system(A,M,x0,b) + A,M,x,b,_ = make_system(A,M,x0,b) if not np.isfinite(b).all(): raise ValueError("RHS must contain only finite numbers") diff --git a/scipy/sparse/linalg/_isolve/iterative.py b/scipy/sparse/linalg/_isolve/iterative.py index 637f2d021f62..910f63b2cd02 100644 --- a/scipy/sparse/linalg/_isolve/iterative.py +++ b/scipy/sparse/linalg/_isolve/iterative.py @@ -1,13 +1,18 @@ import warnings +import functools + import numpy as np from scipy.sparse.linalg._interface import LinearOperator from .utils import make_system from scipy.linalg import get_lapack_funcs +from scipy._lib import array_api_extra as xpx +from scipy._lib._array_api import is_lazy_array, xp_copy, xp_vector_norm + __all__ = ['bicg', 'bicgstab', 'cg', 'cgs', 'gmres', 'qmr'] -def _get_atol_rtol(name, b_norm, atol=0., rtol=1e-5): +def _get_atol_rtol(name, b_norm, atol=0., rtol=1e-5, xp=np): """ A helper function to handle tolerance normalization """ @@ -16,7 +21,7 @@ def _get_atol_rtol(name, b_norm, atol=0., rtol=1e-5): "if set, `atol` must be a real, non-negative number.") raise ValueError(msg) - atol = max(float(atol), float(rtol) * float(b_norm)) + atol = xp.max(xp.stack((xp.asarray(float(atol)), float(rtol) * xp.min(b_norm)))) return atol, rtol @@ -87,7 +92,7 @@ def bicg(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, callback=No >>> np.allclose(A.dot(x), b) True """ - A, M, x, b = make_system(A, M, x0, b) + A, M, x, b, _ = make_system(A, M, x0, b) bnrm2 = np.linalg.norm(b) atol, _ = _get_atol_rtol('bicg', bnrm2, atol, rtol) @@ -226,7 +231,7 @@ def bicgstab(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, >>> np.allclose(A.dot(x), b) True """ - A, M, x, b = make_system(A, M, x0, b) + A, M, x, b, _ = make_system(A, M, x0, b) bnrm2 = np.linalg.norm(b) atol, _ = _get_atol_rtol('bicgstab', bnrm2, atol, rtol) @@ -376,44 +381,65 @@ def cg(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, callback=None >>> np.allclose(A.dot(x), b) True """ - A, M, x, b = make_system(A, M, x0, b) - bnrm2 = np.linalg.norm(b) + A, M, x, b, xp = make_system(A, M, x0, b, reject_nD=False) + batched = A.ndim > 2 + bnrm2 = xp_vector_norm(b, axis=-1) - atol, _ = _get_atol_rtol('cg', bnrm2, atol, rtol) + atol, _ = _get_atol_rtol('cg', bnrm2, atol, rtol, xp=xp) - if bnrm2 == 0: + if not xp.any(bnrm2): return b, 0 - n = len(b) - if maxiter is None: - maxiter = n*10 + maxiter = b.shape[-1] * 10 - dotprod = np.vdot if np.iscomplexobj(x) else np.dot + dotprod = np.vdot if xp.isdtype(x.dtype, "complex floating") else functools.partial(xp.vecdot, axis=-1) matvec = A.matvec psolve = M.matvec - r = b - matvec(x) if x.any() else b.copy() + r = b - matvec(x) if xp.any(x) else xp_copy(b) # Dummy value to initialize var, silences warnings rho_prev, p = None, None for iteration in range(maxiter): - if np.linalg.norm(r) < atol: # Are we done? + converged = xp_vector_norm(r, axis=-1) < atol + if xp.all(converged): return x, 0 z = psolve(r) rho_cur = dotprod(r, z) + if iteration > 0: - beta = rho_cur / rho_prev + if is_lazy_array(converged): + beta = xp.where(~converged, rho_cur / rho_prev, 0.0) + elif not batched: + beta = rho_cur / rho_prev + else: + beta = xp.zeros_like(rho_cur, dtype=float) + mask = ~converged + beta[mask] = rho_cur[mask] / rho_prev[mask] + beta = beta[..., xp.newaxis] + p *= beta p += z else: # First spin - p = np.empty_like(r) - p[:] = z[:] + p = xp.empty_like(r) + p = xpx.at(p)[:, ...].set(z[:, ...]) q = matvec(p) - alpha = rho_cur / dotprod(p, q) + c = dotprod(p, q) + + if is_lazy_array(converged): + alpha = xp.where(~converged, rho_cur / c, 0.0) + elif not batched: + alpha = rho_cur / c + else: + alpha = xp.zeros_like(rho_cur, dtype=float) + mask = ~converged + alpha[mask] = rho_cur[mask] / c[mask] + alpha = alpha[..., xp.newaxis] + x += alpha*p r -= alpha*q rho_prev = rho_cur @@ -496,7 +522,7 @@ def cgs(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, callback=Non >>> np.allclose(A.dot(x), b) True """ - A, M, x, b = make_system(A, M, x0, b) + A, M, x, b, _ = make_system(A, M, x0, b) bnrm2 = np.linalg.norm(b) atol, _ = _get_atol_rtol('cgs', bnrm2, atol, rtol) @@ -694,7 +720,7 @@ def gmres(A, b, x0=None, *, rtol=1e-5, atol=0., restart=None, maxiter=None, M=No if callback is None: callback_type = None - A, M, x, b = make_system(A, M, x0, b) + A, M, x, b, _ = make_system(A, M, x0, b) matvec = A.matvec psolve = M.matvec n = len(b) @@ -907,7 +933,7 @@ def qmr(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M1=None, M2=None, True """ A_ = A - A, M, x, b = make_system(A, None, x0, b) + A, M, x, b, _ = make_system(A, None, x0, b) bnrm2 = np.linalg.norm(b) atol, _ = _get_atol_rtol('qmr', bnrm2, atol, rtol) diff --git a/scipy/sparse/linalg/_isolve/lgmres.py b/scipy/sparse/linalg/_isolve/lgmres.py index e44613c94691..41721f2b3903 100644 --- a/scipy/sparse/linalg/_isolve/lgmres.py +++ b/scipy/sparse/linalg/_isolve/lgmres.py @@ -119,7 +119,7 @@ def lgmres(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=1000, M=None, callback= >>> np.allclose(A.dot(x), b) True """ - A,M,x,b = make_system(A,M,x0,b) + A,M,x,b,_ = make_system(A,M,x0,b) if not np.isfinite(b).all(): raise ValueError("RHS must contain only finite numbers") diff --git a/scipy/sparse/linalg/_isolve/lsmr.py b/scipy/sparse/linalg/_isolve/lsmr.py index 9fd0a30e6d53..5081ebee1315 100644 --- a/scipy/sparse/linalg/_isolve/lsmr.py +++ b/scipy/sparse/linalg/_isolve/lsmr.py @@ -195,6 +195,8 @@ def lsmr(A, b, damp=0.0, atol=1e-6, btol=1e-6, conlim=1e8, """ A = aslinearoperator(A) + if A.ndim > 2: + raise ValueError(f"{A.ndim}-dimensional `A` is unsupported, expected 2-D.") b = atleast_1d(b) if b.ndim > 1: b = b.squeeze() diff --git a/scipy/sparse/linalg/_isolve/lsqr.py b/scipy/sparse/linalg/_isolve/lsqr.py index 3e490a0769e6..abedfb855ffe 100644 --- a/scipy/sparse/linalg/_isolve/lsqr.py +++ b/scipy/sparse/linalg/_isolve/lsqr.py @@ -322,6 +322,8 @@ def lsqr(A, b, damp=0.0, atol=1e-6, btol=1e-6, conlim=1e8, """ A = convert_pydata_sparse_to_scipy(A) A = aslinearoperator(A) + if A.ndim > 2: + raise ValueError(f"{A.ndim}-dimensional `A` is unsupported, expected 2-D.") b = np.atleast_1d(b) if b.ndim > 1: b = b.squeeze() diff --git a/scipy/sparse/linalg/_isolve/minres.py b/scipy/sparse/linalg/_isolve/minres.py index b5933175e589..6991fcd43f30 100644 --- a/scipy/sparse/linalg/_isolve/minres.py +++ b/scipy/sparse/linalg/_isolve/minres.py @@ -89,7 +89,7 @@ def minres(A, b, x0=None, *, rtol=1e-5, shift=0.0, maxiter=None, https://web.stanford.edu/group/SOL/software/minres/minres-matlab.zip """ - A, M, x, b = make_system(A, M, x0, b) + A, M, x, b, _ = make_system(A, M, x0, b) matvec = A.matvec psolve = M.matvec diff --git a/scipy/sparse/linalg/_isolve/tests/test_iterative.py b/scipy/sparse/linalg/_isolve/tests/test_iterative.py index cd15445a8443..55e89a9cc4d0 100644 --- a/scipy/sparse/linalg/_isolve/tests/test_iterative.py +++ b/scipy/sparse/linalg/_isolve/tests/test_iterative.py @@ -809,3 +809,13 @@ def x_cb(x): restart=10, callback_type='x') assert info == 20 assert count[0] == 20 + + +def test_nD(solver): + """Check that >2-D operators are rejected cleanly.""" + def id(x): + return x + A = LinearOperator(shape=(2, 2, 2), matvec=id, dtype=np.float64) + b = np.ones((2, 2)) + with pytest.raises(ValueError, match="expected 2-D"): + solver(A, b) diff --git a/scipy/sparse/linalg/_isolve/tests/test_lsmr.py b/scipy/sparse/linalg/_isolve/tests/test_lsmr.py index 2a468305cba5..a9e32daa3726 100644 --- a/scipy/sparse/linalg/_isolve/tests/test_lsmr.py +++ b/scipy/sparse/linalg/_isolve/tests/test_lsmr.py @@ -16,13 +16,13 @@ """ +import numpy as np from numpy import array, arange, eye, zeros, ones, transpose, hstack from numpy.linalg import norm from numpy.testing import assert_allclose import pytest from scipy.sparse import coo_array -from scipy.sparse.linalg._interface import aslinearoperator -from scipy.sparse.linalg import lsmr +from scipy.sparse.linalg import lsmr, LinearOperator, aslinearoperator from .test_lsqr import G, b @@ -194,3 +194,12 @@ def test_lsmr_maxiter_zero_with_show(): assert result[0] is not None # x should exist assert result[1] == 0 # istop should be 0 assert result[2] == 0 # itn should be 0 + +def test_nD(): + """Check that >2-D operators are rejected cleanly.""" + def id(x): + return x + A = LinearOperator(shape=(2, 2, 2), matvec=id, dtype=np.float64) + b = np.ones((2, 2)) + with pytest.raises(ValueError, match="expected 2-D"): + lsmr(A, b) diff --git a/scipy/sparse/linalg/_isolve/tests/test_lsqr.py b/scipy/sparse/linalg/_isolve/tests/test_lsqr.py index d77048af48a6..339704b0a339 100644 --- a/scipy/sparse/linalg/_isolve/tests/test_lsqr.py +++ b/scipy/sparse/linalg/_isolve/tests/test_lsqr.py @@ -2,8 +2,7 @@ from numpy.testing import assert_allclose, assert_array_equal, assert_equal import pytest import scipy.sparse -import scipy.sparse.linalg -from scipy.sparse.linalg import lsqr +from scipy.sparse.linalg import lsqr, LinearOperator # Set up a test problem n = 35 @@ -118,3 +117,12 @@ def test_initialization(): x = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=maxit, x0=x0) assert_allclose(x_ref[0], x[0]) assert_array_equal(b_copy, b) + +def test_nD(): + """Check that >2-D operators are rejected cleanly.""" + def id(x): + return x + A = LinearOperator(shape=(2, 2, 2), matvec=id, dtype=np.float64) + b = np.ones((2, 2)) + with pytest.raises(ValueError, match="expected 2-D"): + lsqr(A, b) diff --git a/scipy/sparse/linalg/_isolve/tfqmr.py b/scipy/sparse/linalg/_isolve/tfqmr.py index 3904c2766341..9b5d52bbf9dc 100644 --- a/scipy/sparse/linalg/_isolve/tfqmr.py +++ b/scipy/sparse/linalg/_isolve/tfqmr.py @@ -97,7 +97,7 @@ def tfqmr(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, if np.issubdtype(b.dtype, np.int64): b = b.astype(dtype) - A, M, x, b = make_system(A, M, x0, b) + A, M, x, b, _ = make_system(A, M, x0, b) # Check if the R.H.S is a zero vector if np.linalg.norm(b) == 0.: diff --git a/scipy/sparse/linalg/_isolve/utils.py b/scipy/sparse/linalg/_isolve/utils.py index 28925f48014b..75f34cd6e24e 100644 --- a/scipy/sparse/linalg/_isolve/utils.py +++ b/scipy/sparse/linalg/_isolve/utils.py @@ -3,39 +3,25 @@ __all__ = [] -from numpy import asanyarray, asarray, array, zeros - -from scipy.sparse.linalg._interface import aslinearoperator, LinearOperator, \ - IdentityOperator - -_coerce_rules = {('f','f'):'f', ('f','d'):'d', ('f','F'):'F', - ('f','D'):'D', ('d','f'):'d', ('d','d'):'d', - ('d','F'):'D', ('d','D'):'D', ('F','f'):'F', - ('F','d'):'D', ('F','F'):'F', ('F','D'):'D', - ('D','f'):'D', ('D','d'):'D', ('D','F'):'D', - ('D','D'):'D'} - - -def coerce(x,y): - if x not in 'fdFD': - x = 'd' - if y not in 'fdFD': - y = 'd' - return _coerce_rules[x,y] +import numpy as np +import scipy +from scipy.sparse.linalg._interface import ( + _xp_aslinearoperator, LinearOperator, IdentityOperator +) +from scipy._lib._array_api import array_namespace, is_lazy_array, xp_copy, xp_ravel, xp_result_type, _asarray def id(x): return x - -def make_system(A, M, x0, b): +def make_system(A, M, x0, b, reject_nD=True): """Make a linear system Ax=b Parameters ---------- A : LinearOperator sparse or dense matrix (or any valid input to aslinearoperator) - M : {LinearOperator, Nones} + M : {LinearOperator, None} preconditioner sparse or dense matrix (or any valid input to aslinearoperator) x0 : {array_like, str, None} @@ -47,7 +33,7 @@ def make_system(A, M, x0, b): Returns ------- - (A, M, x, b) + (A, M, x, b, xp) A : LinearOperator matrix of the linear system M : LinearOperator @@ -56,33 +42,47 @@ def make_system(A, M, x0, b): initial guess b : rank 1 ndarray right hand side + xp : compatible array namespace """ A_ = A - A = aslinearoperator(A) - - if A.shape[0] != A.shape[1]: - raise ValueError(f'expected square matrix, but got shape={(A.shape,)}') - - N = A.shape[0] - - b = asanyarray(b) - - if not (b.shape == (N,1) or b.shape == (N,)): + A, xp = _xp_aslinearoperator(A) + if reject_nD and A.ndim > 2: + raise ValueError(f"{A.ndim}-dimensional `A` is unsupported, expected 2-D.") + # lazy = is_lazy_array(xp.empty(0)) + lazy = False + + N = A.shape[-2] + if not lazy and N != A.shape[-1]: + raise ValueError(f'expected square matrix or stack of square matrices, but got shape={(A.shape,)}') + + xp_b = array_namespace(b) + if xp_b != xp: + msg = f"Mismatched array namespaces. Namespace for A is {xp}, namespace for b is {xp_b}" + raise TypeError(msg) + + b = _asarray(b, subok=True, xp=xp) + + column_vector = not lazy and b.ndim == 2 and b.shape[-2:] == (N, 1) # maintain column vector backwards-compatibility in 2-D case + row_vector = b.shape[-1] == N # otherwise treat as a row-vector + + if not lazy and not (column_vector or row_vector): raise ValueError(f'shapes of A {A.shape} and b {b.shape} are ' 'incompatible') - if b.dtype.char not in 'fdFD': - b = b.astype('d') # upcast non-FP types to double + if not xp.isdtype(b.dtype, ("real floating", "complex floating")): + b = xp.astype(b, xp.float64) # upcast non-FP types to float64 - if hasattr(A,'dtype'): - xtype = A.dtype.char + if hasattr(A, 'dtype'): + x_dtype = A.dtype else: - xtype = A.matvec(b).dtype.char - xtype = coerce(xtype, b.dtype.char) + x_dtype = A.matvec(b).dtype + # XXX: does this match the previous coercion? + x_dtype = xp_result_type(x_dtype, b.dtype, force_floating=True, xp=xp) - b = asarray(b,dtype=xtype) # make b the same type as x - b = b.ravel() + b = xp.astype(b, x_dtype) # make b the same type as x + if column_vector: + b = xp_ravel(b) # process preconditioner if M is None: @@ -95,27 +95,36 @@ def make_system(A, M, x0, b): else: rpsolve = id if psolve is id and rpsolve is id: - M = IdentityOperator(shape=A.shape, dtype=A.dtype) + M = IdentityOperator(shape=A.shape, dtype=A.dtype, xp=xp) else: M = LinearOperator(A.shape, matvec=psolve, rmatvec=rpsolve, - dtype=A.dtype) + dtype=A.dtype, xp=xp) else: - M = aslinearoperator(M) + M, xp_M = _xp_aslinearoperator(M) + if xp_M != xp: + msg = f"Mismatched array namespaces. Namespace for A is {xp}, namespace for M is {xp_M}" + raise TypeError(msg) if A.shape != M.shape: raise ValueError('matrix and preconditioner have different shapes') # set initial guess if x0 is None: - x = zeros(N, dtype=xtype) + x = xp.zeros((*M.shape[:-2], N), dtype=x_dtype) + # XXX: proper error handling for `x0` of type `str` but not equal to `'Mb'`? elif isinstance(x0, str): if x0 == 'Mb': # use nonzero initial guess ``M @ b`` - bCopy = b.copy() + bCopy = xp_copy(b) x = M.matvec(bCopy) else: - x = array(x0, dtype=xtype) - if not (x.shape == (N, 1) or x.shape == (N,)): + x = xp.asarray(x0, dtype=x_dtype, copy=True) + + column_vector = x.ndim == 2 and x.shape[-2:] == (N, 1) # maintain column vector backwards-compatibility in 2-D case + row_vector = x.shape[-1] == N # otherwise treat as a row-vector + + if not (row_vector or column_vector): raise ValueError(f'shapes of A {A.shape} and ' f'x0 {x.shape} are incompatible') - x = x.ravel() + if column_vector: + x = xp_ravel(x) - return A, M, x, b + return A, M, x, b, xp diff --git a/scipy/sparse/linalg/tests/test_interface.py b/scipy/sparse/linalg/tests/test_interface.py index 2dab9c07283e..758f509f28ff 100644 --- a/scipy/sparse/linalg/tests/test_interface.py +++ b/scipy/sparse/linalg/tests/test_interface.py @@ -4,7 +4,7 @@ from functools import partial from itertools import product import operator -from typing import NamedTuple +from typing import NamedTuple, Literal import pytest from pytest import raises as assert_raises, warns from numpy.testing import assert_, assert_equal, assert_allclose @@ -224,33 +224,36 @@ class TestDotTests: """ class OperatorArgs(NamedTuple): """ - shape: shape of the operator + shape: (core) shape of the operator op_dtype: dtype of the operator data_dtype: real dtype corresponding to op_dtype for data generation complex: the operator has a complex dtype + batch_shape: batch shape of the operator """ shape: tuple[int, ...] op_dtype: str data_dtype: str complex: bool + batch_shape: tuple[int, ...] real_square_args: OperatorArgs = OperatorArgs( - (12, 12), "float64", "float64", False + (12, 12), "float64", "float64", False, (4,) ) + # TODO: batch shape (0,) integer_square_args: OperatorArgs = OperatorArgs( - (9, 9), "int32", "float32", False + (9, 9), "int32", "float32", False, (3, 4, 5) ) complex_square_args: OperatorArgs = OperatorArgs( - (13, 13), "complex64", "float32", True + (13, 13), "complex64", "float32", True, () ) real_overdetermined_args: OperatorArgs = OperatorArgs( - (17, 11), "float64", "float64", False + (17, 11), "float64", "float64", False, (3,) ) complex_overdetermined_args: OperatorArgs = OperatorArgs( - (17, 11), "complex128", "float64", True + (17, 11), "complex128", "float64", True, (3,) ) real_underdetermined_args: OperatorArgs = OperatorArgs( - (5, 9), "float64", "float64", False + (5, 9), "float64", "float64", False, (3,) ) square_args_list: list[OperatorArgs] = [ @@ -285,11 +288,20 @@ def check_matvec( """ rng = np.random.default_rng(42) - u = rng.standard_normal(op.shape[-1], dtype=data_dtype) - v = rng.standard_normal(op.shape[-2], dtype=data_dtype) + dtype = np.dtype(data_dtype) + u = rng.standard_normal(op.shape[-1], dtype=dtype) + v = rng.standard_normal(op.shape[-2], dtype=dtype) if complex_data: - u = u + (1j * rng.standard_normal(op.shape[-1], dtype=data_dtype)) - v = v + (1j * rng.standard_normal(op.shape[-2], dtype=data_dtype)) + u = u + (1j * rng.standard_normal(op.shape[-1], dtype=dtype)) + v = v + (1j * rng.standard_normal(op.shape[-2], dtype=dtype)) + + # TODO: handle empty batches + # TODO: test vectors with different but broadcastable batch shapes? + # Test `u` and `v` with the same batch shape as `op` + if batch_shape := op.shape[:-2]: + batch_scale = rng.standard_normal((*batch_shape, 1), dtype=dtype) + u = batch_scale * u + v = batch_scale * v op_u = op.matvec(u) opH_v = op.rmatvec(v) @@ -336,11 +348,18 @@ def check_matmat( rng = np.random.default_rng(42) k = rng.integers(2, 100) - U = rng.standard_normal(size=(op.shape[-1], k), dtype=data_dtype) - V = rng.standard_normal(size=(op.shape[-2], k), dtype=data_dtype) + dtype = np.dtype(data_dtype) + U = rng.standard_normal(size=(op.shape[-1], k), dtype=dtype) + V = rng.standard_normal(size=(op.shape[-2], k), dtype=dtype) if complex_data: - U = U + (1j * rng.standard_normal(size=(op.shape[-1], k), dtype=data_dtype)) - V = V + (1j * rng.standard_normal(size=(op.shape[-2], k), dtype=data_dtype)) + U = U + (1j * rng.standard_normal(size=(op.shape[-1], k), dtype=dtype)) + V = V + (1j * rng.standard_normal(size=(op.shape[-2], k), dtype=dtype)) + + # TODO: handle empty batches + if batch_shape := op.shape[:-2]: + batch_scale = rng.standard_normal((*batch_shape, 1, 1), dtype=dtype) + U = batch_scale * U + V = batch_scale * V op_U = op.matmat(U) opH_V = op.rmatmat(V) @@ -355,8 +374,8 @@ def check_matmat( assert_allclose(op_U, op.dot(U)) assert_allclose(opH_V, op.H.dot(V)) - op_U_H = np.conj(op_U).T - UH = np.conj(U).T + op_U_H = np.conj(op_U).mT + UH = np.conj(U).mT op_U_H_V = np.matmul(op_U_H, V) UH_opH_V = np.matmul(UH, opH_V) @@ -365,14 +384,18 @@ def check_matmat( assert_allclose(op_U_H_V, UH_opH_V, rtol=rtol) @pytest.mark.parametrize("args", square_args_list) - def test_identity_square(self, args): - """Simple identity operator on square matrices""" + def test_identity_square(self, args: OperatorArgs): + """ + Simple identity operator on square matrices. + Tests batches of RHS via `args.batch_shape`. + """ def identity(x): return x + shape = args.batch_shape + args.shape op = interface.LinearOperator( - shape=args.shape, dtype=args.op_dtype, - matvec=identity, rmatvec=identity + shape=shape, dtype=args.op_dtype, + matvec=identity, rmatvec=identity, ) self.check_matvec(op, data_dtype=args.data_dtype, complex_data=args.complex) @@ -380,37 +403,57 @@ def identity(x): @pytest.mark.parametrize("args", all_args_list) def test_identity_nonsquare(self, args): - """Identity operator with zero-padding on non-square matrices""" + """ + Identity operator with zero-padding on non-square matrices. + Tests batches of RHS via `args.batch_shape`. + """ + M, N = args.shape + def mv(x): # handle column vectors too # (`LinearOperator` handles reshape in post-processing) - x = x.flatten() + if x.shape[-2:] == (N, 1): + x = np.reshape(x, (*x.shape[:-2], -1)) + + x_broadcast_dims = x.shape[:-1] - match np.sign(x.shape[0] - args.shape[-2]): + match np.sign(x.shape[-1] - M): case 0: # square return x case 1: # crop x to size - return x[:args.shape[-2]] + return x[..., :M] case -1: # pad with zeros - pad_width = (0, args.shape[-2] - x.shape[0]) - return np.pad(x, pad_width, mode='constant', constant_values=0) + no_padding = [(0, 0)] * len(x_broadcast_dims) + pad_width = (0, M - x.shape[-1]) + return np.pad( + x, (*no_padding, pad_width), + mode='constant', constant_values=0 + ) def rmv(x): # handle column vectors too # (`LinearOperator` handles reshape in post-processing) - x = x.flatten() + if x.shape[-2:] == (M, 1): + x = np.reshape(x, (*x.shape[:-2], -1)) + + x_broadcast_dims = x.shape[:-1] - match np.sign(args.shape[-1] - x.shape[0]): + match np.sign(N - x.shape[-1]): case 0: # square return x case 1: # pad with zeros - pad_width = (0, args.shape[-1] - x.shape[0]) - return np.pad(x, pad_width, mode='constant', constant_values=0) + no_padding = [(0, 0)] * len(x_broadcast_dims) + pad_width = (0, N - x.shape[-1]) + return np.pad( + x, (*no_padding, pad_width), + mode='constant', constant_values=0 + ) case -1: # crop x to size - return x[:args.shape[-1]] + return x[..., :N] + shape = args.batch_shape + args.shape op = interface.LinearOperator( - shape=args.shape, dtype=args.op_dtype, matvec=mv, rmatvec=rmv + shape=shape, dtype=args.op_dtype, matvec=mv, rmatvec=rmv ) self.check_matvec(op, data_dtype=args.data_dtype, complex_data=args.complex) @@ -418,15 +461,19 @@ def rmv(x): @pytest.mark.parametrize("args", square_args_list) def test_scaling_square(self, args): - """Simple (complex) scaling operator on square matrices""" + """ + Simple (complex) scaling operator on square matrices. + Tests batches of RHS via `args.batch_shape`. + """ def scale(x): return (3 + 2j) * x def r_scale(x): return (3 - 2j) * x + shape = args.batch_shape + args.shape op = interface.LinearOperator( - shape=args.shape, dtype=args.op_dtype, matvec=scale, rmatvec=r_scale + shape=shape, dtype=args.op_dtype, matvec=scale, rmatvec=r_scale ) self.check_matvec( op, data_dtype=args.data_dtype, complex_data=args.complex, @@ -437,10 +484,13 @@ def r_scale(x): check_operators=True, check_dot=True ) - def test_subclass_matmat(self): + # TODO: test empty batches + @pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3, 4)]) + def test_subclass_matmat(self, batch_shape: tuple[int, ...]): """ Simple rotation operator defined by `matmat` and `adjoint`, subclassing `LinearOperator`. + Tests batches of RHS via `batch_shape`. """ def rmatmat(X): theta = np.pi / 2 @@ -470,7 +520,7 @@ def _adjoint(self): theta = np.pi / 2 dtype = "float64" - op = RotOp(shape=(2, 2), dtype=dtype, theta=theta) + op = RotOp(shape=(*batch_shape, 2, 2), dtype=dtype, theta=theta) self.check_matvec( op, data_dtype=dtype, complex_data=False, @@ -481,13 +531,30 @@ def _adjoint(self): check_operators=True, check_dot=True ) + # TODO: test empty batches + @pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3, 4)]) @pytest.mark.parametrize( - "matrix", [ - np.asarray([[1, 2j, 3j], [4j, 5j, 6]]), - sparse.random_array((5, 5)) - ] + "format", ["dense", "sparse"] ) - def test_aslinearop(self, matrix): + def test_aslinearop( + self, format: Literal["dense", "sparse"], batch_shape: tuple[int, ...] + ): + """ + Test operators coming from `aslinearoperator`, + *including batched LHS*. + """ + rng = np.random.default_rng(42) + constructor = sparse.random_array if format == "sparse" else rng.standard_normal + if batch_shape: + batch_size = np.prod(batch_shape) + core_matrices = [constructor((4, 4)) for _ in range(batch_size)] + if format == "sparse": + matrix = sparse.vstack(core_matrices).reshape(batch_shape + (4, 4)) + else: + stacked = np.stack(core_matrices, axis=0) + matrix = np.reshape(stacked, batch_shape + (4, 4)) + else: + matrix = constructor((4, 4)) op = interface.aslinearoperator(matrix) data_dtype = "float64" self.check_matvec(op, data_dtype=data_dtype, complex_data=True)