Hi
Thank you for implementing the forward Laplacian; it significantly speeds up the Laplacian computation. However, I am encountering an error when attempting to use sparsity: the sparsity does not seem to be applied. Here are the details of the issue:

Below is a minimal example that reproduces the error.
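import time

import folx
import jax
import jax.numpy as jnp


def fwd(x):
    # Treat the flat input as a set of 3D points.
    x = x.reshape(-1, 3)
    distances = jnp.sqrt(jnp.sum(jnp.square(x), axis=1))
    sph = jnp.zeros((2, distances.shape[0]))
    sph = sph.at[0].set(distances)
    sph = sph.at[1].set(distances * 5.0)
    # Note: sph is computed but not used in the returned value.
    return jnp.sum(x)


key = jax.random.PRNGKey(12)
x = jax.random.normal(key, (100, 300))  # batch of 100 inputs, each 100 points x 3

# Forward Laplacian with a sparsity threshold of 6, vectorized over the batch.
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(6)(fwd)))

# The first call triggers compilation; only the second call is timed.
jax.block_until_ready(lapl(x))
start_time = time.time()
jax.block_until_ready(lapl(x))
end_time = time.time()
print(end_time - start_time)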
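One way to narrow down whether the sparse path is actually being used is to compare against a dense baseline. The following is a minimal sketch, not part of folx: `reference_laplacian` and `median_runtime` are hypothetical helper names introduced here for illustration. The reference computes the full Hessian and takes its trace, so it is only practical as a correctness check on small inputs.

```python
import time

import jax
import jax.numpy as jnp


def reference_laplacian(f):
    """Dense reference Laplacian of a scalar function: trace of its full Hessian.

    Hypothetical helper for checking results; memory and time grow
    quadratically with the input size.
    """
    def lapl(x):
        flat = x.reshape(-1)
        # Differentiate with respect to the flattened input, then restore the shape.
        hess = jax.hessian(lambda y: f(y.reshape(x.shape)))(flat)
        return jnp.trace(hess)
    return lapl


def median_runtime(fn, *args, repeats=10):
    """Median wall-clock time of fn(*args) after one warm-up (compilation) call."""
    jax.block_until_ready(fn(*args))  # warm-up: triggers compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - start)
    times.sort()
    return times[len(times) // 2]


def sum_of_squares(x):
    # Hessian is 2 * I, so the Laplacian equals 2 * x.size.
    return jnp.sum(x ** 2)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 300))
ref = jax.jit(jax.vmap(reference_laplacian(sum_of_squares)))
print(ref(x))  # one Laplacian per batch entry
print(median_runtime(ref, x, repeats=5))
```

With these helpers, one could time `jax.jit(jax.vmap(folx.ForwardLaplacianOperator(6)(fwd)))` against the same call with a threshold of 0 (assumed here to disable sparse propagation, which is the default for `folx.forward_laplacian`) and compare the Laplacian component of its output with `reference_laplacian(fwd)` on a small input. If the two timings are essentially identical, that suggests the sparse path is not taking effect.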