diff --git a/chex/_src/restrict_backends.py b/chex/_src/restrict_backends.py index 4487a5c4..c4ffe8ce 100644 --- a/chex/_src/restrict_backends.py +++ b/chex/_src/restrict_backends.py @@ -80,20 +80,20 @@ def is_allowed(backend_platform): return ((backend_platform in allowed) if allowed is not None else (backend_platform not in forbidden)) - inner_backend_compile = compiler.backend_compile + inner_backend_compile_and_load = compiler.backend_compile_and_load - @functools.wraps(inner_backend_compile) + @functools.wraps(inner_backend_compile_and_load) def wrapper(backend, *args, **kwargs): if not is_allowed(backend.platform): raise RestrictedBackendError( f'Compiling a JAX program for {backend.platform} is forbidden by ' f'restrict_backends().') - return inner_backend_compile(backend, *args, **kwargs) + return inner_backend_compile_and_load(backend, *args, **kwargs) try: - compiler.backend_compile = wrapper + compiler.backend_compile_and_load = wrapper yield finally: - backend_compile = compiler.backend_compile - assert backend_compile is wrapper, backend_compile - compiler.backend_compile = inner_backend_compile + backend_compile_and_load = compiler.backend_compile_and_load + assert backend_compile_and_load is wrapper, backend_compile_and_load + compiler.backend_compile_and_load = inner_backend_compile_and_load