You say:
It is noteworthy that the parallel version will appear to be much slower due to slow compilation in JAX. This could be improved by using a different implementation of the associative scan or by fixing the number of levels, as is done in TensorFlow Probability.
What do you mean by 'fixing the number of levels'?
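For context, the "associative scan" in the quoted passage is a parallel prefix scan. Below is a minimal sketch in JAX that compares the sequential `lax.scan` with `lax.associative_scan`. It uses a scalar linear recurrence x_t = a_t x_{t-1} + b_t as a stand-in for whatever the parallel version actually computes. The recurrence and the function names are hypothetical and only illustrate the technique.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical example recurrence: x_t = a_t * x_{t-1} + b_t, with x_{-1} = 0.
def sequential_recurrence(a, b):
    """Compute the recurrence one step at a time (O(T) sequential depth)."""
    def step(x_prev, ab):
        a_t, b_t = ab
        x_t = a_t * x_prev + b_t
        return x_t, x_t

    _, xs = lax.scan(step, jnp.zeros((), dtype=a.dtype), (a, b))
    return xs


def combine(earlier, later):
    """Compose two affine maps x -> a*x + b, applying `earlier` first, then `later`.

    Composition of affine maps is associative, which is what makes the
    parallel scan valid.
    """
    a_i, b_i = earlier
    a_j, b_j = later
    return a_j * a_i, a_j * b_i + b_j


def parallel_recurrence(a, b):
    """Same recurrence via a parallel prefix scan (O(log T) depth).

    Each prefix is an affine map x_t = A_t * x_{-1} + B_t; with x_{-1} = 0,
    x_t is just B_t.
    """
    _, xs = lax.associative_scan(combine, (a, b))
    return xs


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
T = 16
a = jax.random.uniform(key_a, (T,), minval=0.5, maxval=1.0)
b = jax.random.normal(key_b, (T,))

seq = sequential_recurrence(a, b)
par = parallel_recurrence(a, b)
print(jnp.max(jnp.abs(seq - par)))  # should be close to zero
```

Both functions should agree up to floating-point rounding. The parallel version does more total work than the sequential one in exchange for logarithmic depth.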
