praxis call site: https://github.com/google/praxis/blob/43db2717e0d2e9ef09b566f7e7bbad049d63dceb/praxis/layers/gpu_fast_attention.py#L133

JAX definition: https://github.com/google/jax/blame/main/jax/experimental/pallas/ops/attention.py#L163

It looks like JAX has added `segment_ids` as a required argument, but praxis has not been updated to pass it.
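
A possible stopgap on the caller side, sketched under assumptions: check whether the attention function's signature has a `segment_ids` parameter and pass it only when it does. `call_attention`, `old_mha`, and `new_mha` below are hypothetical stand-ins, not the real praxis or JAX functions, and I'm assuming `None` is accepted to mean "no segment masking", which I haven't verified against the linked JAX revision.

```python
import inspect


def call_attention(attn_fn, q, k, v, **kwargs):
    """Call attn_fn, adding segment_ids=None only if its signature has that parameter.

    Hypothetical helper for illustration; not part of praxis or JAX.
    """
    params = inspect.signature(attn_fn).parameters
    if "segment_ids" in params and "segment_ids" not in kwargs:
        # Assumption: None means "no segment masking" for the newer signature.
        kwargs["segment_ids"] = None
    return attn_fn(q, k, v, **kwargs)


# Stand-ins mimicking the two signature shapes (not the real implementations).
def old_mha(q, k, v, sm_scale=1.0):
    return ("old", sm_scale)


def new_mha(q, k, v, segment_ids, sm_scale=1.0):
    return ("new", segment_ids, sm_scale)


print(call_attention(old_mha, 1, 2, 3))
print(call_attention(new_mha, 1, 2, 3, sm_scale=0.5))
```

This keeps the call working against both older and newer JAX versions, though the cleaner fix is probably for praxis to pass `segment_ids` explicitly and pin a minimum JAX version.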