From 2a50fab0ff20ed156480b5826267b455918e8a30 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 4 Feb 2025 14:27:35 -0800 Subject: [PATCH] Avoid use of deprecated xla_bridge.get_backend().live_buffers() xla_bridge.get_backend is deprecated, and the public API for this is jax.live_arrays(). This is a drop-in replacement with no change of behavior. PiperOrigin-RevId: 723227225 --- trax/optimizers/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trax/optimizers/trainer.py b/trax/optimizers/trainer.py index d4761fc0a..4b1adbb54 100644 --- a/trax/optimizers/trainer.py +++ b/trax/optimizers/trainer.py @@ -440,8 +440,7 @@ def _collect_weights(self, layer): def _free_accelerators(self, exceptions=(), keep_constants=True): """Deletes all live buffers from accelerator with no safety guarantees.""" - backend = jax.lib.xla_bridge.get_backend() - live_buffers = backend.live_buffers() + live_buffers = jax.live_arrays() logging.info('Deleting %d live buffers.', len(live_buffers)) exceptions_buffers = [] for x in fastmath.tree_flatten(exceptions):