wip: N-D LinearOperator and cg #35
base: main
Conversation
scipy/sparse/linalg/_interface.py (outdated)

```diff
-        return np.hstack([self.matvec(col.reshape(-1,1)) for col in X.T])
+        # X.mT here?
+        return np.hstack([self.matvec(col.reshape(-1, 1)) for col in X.T])
```
If the new matvec can handle batched dims, do we need hstack or can it just be done all in one go?
@izaid to maintain compatibility with user-defined matvec functions that do not support batching, I think we do need to keep this hstack.

I think it could be done all in one go for batched matvec functions, but that would have to be opt-in in the API.
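The opt-in idea above could be sketched roughly as follows (the `batched` flag and `matmat` helper here are hypothetical illustrations, not the PR's actual API):

```python
import numpy as np

def matmat(matvec, X, batched=False):
    """Apply an operator to a block of columns X.

    If the user declares their matvec batch-aware (hypothetical opt-in flag),
    apply it to the whole block in one call; otherwise fall back to the safe
    column-by-column loop from the diff above.
    """
    if batched:
        # Batch-aware matvec handles the full (n, k) block at once.
        return matvec(X)
    # Compatible fallback: one matvec call per column, stacked back together.
    return np.hstack([matvec(col.reshape(-1, 1)) for col in X.T])

M = np.arange(6.0).reshape(2, 3)
mv = lambda x: M @ x          # works on a single column or a full block

X = np.eye(3)
out_loop = matmat(mv, X)                  # column-by-column fallback
out_batch = matmat(mv, X, batched=True)   # single batched call
# both recover M @ X = M here
```

The point of the flag is that the two paths agree whenever the matvec happens to be batch-aware, while plain per-column matvecs keep working unchanged.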
Yes! This looks good in general, basically what I was thinking. Will this error in all previous cases if someone passes an N-D LinearOperator to an existing method? For instance, ...

that is also TODO 👍
Prototype working with:

```python
In [1]: import array_api_strict as xp; from scipy.sparse.linalg import cg, LinearOperator; import numpy as np

In [2]: def solve(N, batch, report_index=0, batched=False):
   ...:     rng = np.random.default_rng(0)
   ...:     M = rng.standard_normal((N, N))
   ...:     M = xp.asarray(M)
   ...:     reg = 1e-3
   ...:
   ...:     if batched:
   ...:         M = xp.broadcast_to(M[xp.newaxis, ...], (batch, *M.shape))
   ...:
   ...:     def matvec(x):
   ...:         return xp.squeeze(M.mT @ (M @ x[..., xp.newaxis]), axis=-1) + reg * x
   ...:
   ...:     shape = (batch, N, N) if batched else (N, N)
   ...:     A = LinearOperator(shape, matvec=matvec, dtype=xp.float64, xp=xp)
   ...:
   ...:     b = rng.standard_normal(N)
   ...:     b = xp.asarray(b)
   ...:
   ...:     if batched:
   ...:         b = xp.reshape(xp.arange(batch, dtype=xp.float64), (batch, 1)) * b
   ...:         x, info = cg(A, b, atol=1e-8, maxiter=5000)
   ...:         assert info == 0
   ...:         print(f"{x[report_index, ...]}")
   ...:     else:
   ...:         for i in xp.arange(batch, dtype=xp.float64):
   ...:             x, info = cg(A, i*b, atol=1e-8, maxiter=5000)
   ...:             assert info == 0
   ...:             if i == report_index:
   ...:                 print(x)
   ...:

In [3]: solve(5, 10, report_index=7)
Array([ 10.91985197,  -5.53737923,  -6.96397906, -35.6473016 ,  13.48931722], dtype=array_api_strict.float64)

In [4]: solve(5, 10, report_index=7, batched=True)
Array([ 10.91985197,  -5.53737923,  -6.96397906, -35.6473016 ,  13.48931722], dtype=array_api_strict.float64)

In [5]: import jax
   ...: jax.config.update("jax_enable_x64", True)

In [6]: import jax.numpy as xp

In [7]: solve(5, 10, report_index=7, batched=True)
[ 10.91985197  -5.53737923  -6.96397906 -35.6473016   13.48931722]
```

For JIT I think we will need to do some fancy registration of the linear operator classes as PyTrees, along the lines of https://docs.jax.dev/en/latest/_autosummary/jax.tree_util.register_pytree_node.html#jax.tree_util.register_pytree_node.
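A minimal sketch of what that PyTree registration might look like, using a hypothetical `MatOp` stand-in rather than the real `LinearOperator` class:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class MatOp:
    """Hypothetical minimal linear-operator wrapper around a dense matrix."""
    def __init__(self, M, reg):
        self.M = M      # array leaf: traced/vmapped/jitted by JAX
        self.reg = reg  # static auxiliary data: part of the treedef

    def matvec(self, x):
        return self.M @ x + self.reg * x

def _flatten(op):
    # children (traced leaves), aux_data (static)
    return (op.M,), op.reg

def _unflatten(reg, children):
    return MatOp(children[0], reg)

tree_util.register_pytree_node(MatOp, _flatten, _unflatten)

@jax.jit
def apply(op, x):
    # op can now cross the jit boundary because JAX knows how to
    # flatten it into leaves and rebuild it inside the trace
    return op.matvec(x)

op = MatOp(jnp.eye(3), 0.5)
x = jnp.ones(3)
apply(op, x)  # -> [1.5, 1.5, 1.5]
```

Putting `reg` in the aux_data means retracing happens when it changes, whereas the matrix `M` is an ordinary traced leaf; the real class would need a similar leaf/static split.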
Oh that's great! Yes, we might need to register LinearOperator as a PyTree, but that's not such a big deal, I think?
I think the JIT will probably require a separate backend that uses JAX-specific things. In particular, for:

```python
converged = xp_vector_norm(r, axis=-1) < atol
if xp.all(converged):
    return x, 0
```

I think we will need to use something from https://docs.jax.dev/en/latest/control-flow.html#structured-control-flow-primitives.
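For that early-exit pattern, `jax.lax.while_loop` is the usual replacement for a Python-level `if`/`return`. A toy sketch (the body here just halves the residual as a stand-in for one CG step; it is not the PR's solver):

```python
import jax.numpy as jnp
from jax import lax

def run_until_converged(r0, atol, maxiter=100):
    """Iterate until every batched residual norm drops below atol."""
    def cond(state):
        r, it = state
        # traced analogue of `if xp.all(converged): return`, inverted:
        # keep looping while any system is unconverged and budget remains
        unconverged = jnp.any(jnp.linalg.norm(r, axis=-1) >= atol)
        return jnp.logical_and(unconverged, it < maxiter)

    def body(state):
        r, it = state
        return r * 0.5, it + 1   # stand-in for one CG update shrinking r

    return lax.while_loop(cond, body, (r0, 0))

# batch of 4 systems, each with a length-3 residual
r, it = run_until_converged(jnp.ones((4, 3)), 1e-3)
```

The trade-off versus the NumPy path is that the whole batch runs until the *last* system converges, which matches how a batched traced loop has to behave anyway.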
```diff
     maxiter = b.shape[-1] * 10

-    dotprod = np.vdot if np.iscomplexobj(x) else np.dot
+    dotprod = np.vdot if xp.isdtype(x.dtype, "complex floating") else functools.partial(xp.vecdot, axis=-1)
```
may need axis=-1 for the complex case too
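The `axis=-1` point can be illustrated with plain NumPy (a sketch of the semantics, not the PR's code): reducing only over the last axis gives one scalar per batched system, whereas `np.vdot`/`np.dot` flatten everything into a single scalar.

```python
import numpy as np

a = np.arange(6.0).reshape(2, 3)   # batch of 2 systems, length-3 vectors
b = np.ones((2, 3))

batched = np.sum(a * b, axis=-1)        # shape (2,): one dot per system
flat = np.dot(a.ravel(), b.ravel())     # scalar: mixes the batch together
# batched -> [3., 12.];  flat -> 15.0

# the complex case needs the same per-system reduction, plus conjugation
# of the first argument (vdot-style Hermitian inner product):
z = np.array([[1 + 2j, 3j], [2 + 0j, 1 - 1j]])
w = np.ones((2, 2), dtype=complex)
herm = np.sum(np.conj(z) * w, axis=-1)  # shape (2,), conjugate-linear in z
# herm -> [1.-5.j, 3.+1.j]
```

So using `np.vdot` for complex inputs silently loses the batch dimension, which is presumably why the reviewer flags `axis=-1` for that branch too.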
```diff
         raise ValueError(msg)

-    atol = max(float(atol), float(rtol) * float(b_norm))
+    atol = xp.max(xp.stack((xp.asarray(float(atol)), float(rtol) * xp.min(b_norm))))
```
should we support different atols for different systems? If not, should this be xp.min or xp.max?
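The min-vs-max question comes down to what a single shared tolerance means for a batch with very different right-hand-side norms. A small numeric sketch (illustrative values only, not the PR's defaults):

```python
import numpy as np

atol, rtol = 1e-8, 1e-5
b_norm = np.array([1.0, 100.0, 0.01])   # norms of each batched rhs

# xp.min(b_norm): one scalar tolerance, tightest system wins, so the
# well-scaled systems are over-solved
shared_tight = max(atol, rtol * b_norm.min())    # -> 1e-7

# per-system alternative: one atol per rhs, each system gets its own
# rtol-scaled target (broadcasts against the batched residual norms)
per_system = np.maximum(atol, rtol * b_norm)     # -> [1e-5, 1e-3, 1e-7]
```

With `xp.max(b_norm)` instead, the loosest system would win and small-norm systems could terminate with a relatively large residual; the per-system vector avoids choosing either way, at the cost of a slightly richer convergence check.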
Building on scipy#23836

References for batched CG:
- https://github.com/kokkos/kokkos-kernels/blob/develop/batched/sparse/impl/KokkosBatched_CG_Team_Impl.hpp
- https://ieeexplore.ieee.org/document/10054414 (section VI)