Description
Hello, thank you for the great library; it makes the Laplacian calculations much faster! However, I'm having trouble with sparsity when using folx with a LapNet-like architecture that includes sparse derivative attention blocks. I believe folx is not recognizing the Jacobian sparsity correctly. I've condensed my issue to the minimal example shown below.
x is an input array of shape (N, D) and S is an intermediate matrix of shape (N, N) with a sparse derivative: each entry S[i, j] depends only on rows i and j of x. Therefore, the speedup from treating the Jacobian as sparse should be most noticeable for large N. However, even for N = 400 the speedup is only marginal (9.57 ms vs. 7.57 ms). I noticed that by rewriting the last matrix multiplication column by column for x (x has just 2 columns in this example), I get the expected sparsity speedup of more than 15x (8.7 ms vs. 559 μs).
Could you help me understand why the two versions, which produce the same output, perform so differently, and how to make the first one equally fast? Thanks for your help!
import folx
import jax
import jax.numpy as jnp
def attn_minimal(x):
    # S = exp(x @ x^T)
    S = jnp.exp(jnp.matmul(x, jnp.swapaxes(x, -2, -1)))  # shape (N, N)
    h = jnp.matmul(S, x)
    return h.mean()

def attn_minimal_manual_matmul(x):
    S = jnp.exp(jnp.matmul(x, jnp.swapaxes(x, -2, -1)))  # shape (N, N)
    # Same product as jnp.matmul(S, x), but one column of x at a time
    h0 = jnp.matmul(S, x[..., [0]])
    h1 = jnp.matmul(S, x[..., [1]])
    h = jnp.concatenate([h0, h1], axis=-1)
    return h.mean()
N = 400
D = 2
x = jax.random.normal(jax.random.PRNGKey(1), (1,N,D))
# Forward laplacian without sparsity, full matmul
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(0)(attn_minimal)))
jax.block_until_ready(lapl(x))
%timeit jax.block_until_ready(lapl(x)) # 9.57 ms ± 60.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Forward laplacian with sparsity, full matmul
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(0.6)(attn_minimal)))
jax.block_until_ready(lapl(x))
%timeit jax.block_until_ready(lapl(x)) # 7.57 ms ± 90.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Forward laplacian without sparsity, column by column matmul
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(0)(attn_minimal_manual_matmul)))
jax.block_until_ready(lapl(x))
%timeit jax.block_until_ready(lapl(x)) # 8.7 ms ± 166 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Forward laplacian with sparsity, column by column matmul
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(0.6)(attn_minimal_manual_matmul)))
jax.block_until_ready(lapl(x))
%timeit jax.block_until_ready(lapl(x)) # 559 μs ± 14.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
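For reference, here is a small sanity check in plain JAX (no folx) of the two claims I'm relying on: that both functions agree in value and Laplacian, and that each S[i, j] depends on at most two rows of x. It is only a sketch on a tiny input. The names dense_laplacian, N_small and x_small are made up for this check, and the dense Hessian trace is just a reference, not how folx computes the Laplacian.

```python
import jax
import jax.numpy as jnp

def attn_minimal(x):
    S = jnp.exp(jnp.matmul(x, jnp.swapaxes(x, -2, -1)))  # shape (N, N)
    h = jnp.matmul(S, x)
    return h.mean()

def attn_minimal_manual_matmul(x):
    S = jnp.exp(jnp.matmul(x, jnp.swapaxes(x, -2, -1)))  # shape (N, N)
    h0 = jnp.matmul(S, x[..., [0]])
    h1 = jnp.matmul(S, x[..., [1]])
    h = jnp.concatenate([h0, h1], axis=-1)
    return h.mean()

def dense_laplacian(f, x):
    # Hypothetical helper: trace of the full Hessian over the flattened input.
    # Needs O((N*D)^2) memory, so it is only usable for small N.
    flat = x.reshape(-1)
    hess = jax.hessian(lambda v: f(v.reshape(x.shape)))(flat)
    return jnp.trace(hess)

# Tiny input (assumed sizes, chosen so the dense Hessian stays cheap)
N_small, D = 8, 2
x_small = 0.3 * jax.random.normal(jax.random.PRNGKey(1), (N_small, D))

# 1) Both variants should agree in value and in Laplacian
same_value = jnp.allclose(
    attn_minimal(x_small), attn_minimal_manual_matmul(x_small),
    rtol=1e-3, atol=1e-5)
same_laplacian = jnp.allclose(
    dense_laplacian(attn_minimal, x_small),
    dense_laplacian(attn_minimal_manual_matmul, x_small),
    rtol=1e-3, atol=1e-5)

# 2) Sparsity of dS/dx: jac_S[i, j, k, d] = dS[i, j] / dx[k, d]
jac_S = jax.jacobian(lambda x: jnp.exp(x @ x.T))(x_small)  # (N, N, N, D)
depends_on_row = jnp.any(jac_S != 0, axis=-1)              # (N, N, N)
# Largest number of rows of x that any single S[i, j] depends on
max_rows = int(depends_on_row.sum(axis=-1).max())

print("same value:    ", bool(same_value))
print("same laplacian:", bool(same_laplacian))
print("max rows of x that any S[i, j] depends on:", max_rows)
```

Since S[i, j] = exp(x_i · x_j), max_rows should not exceed 2 regardless of N, which is the sparsity I would expect folx to exploit in both variants.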