From 053a0bc0fd52fe3804ecb7551a96c0b09129d352 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Fri, 3 Apr 2026 13:47:41 +0100 Subject: [PATCH 1/3] ENH: array_namespace: support `torch.compile` Co-authored-by: Evgeni Burovski --- array_api_compat/common/_helpers.py | 18 +++++++++++++++--- tests/test_no_dependencies.py | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 8a307f9d..45d1978e 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -641,7 +641,7 @@ def your_function(x, y): is_pydata_sparse_array """ - namespaces: set[Namespace] = set() + namespaces: list[Namespace] = [] for x in xs: xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) if info is _ClsToXPInfo.SCALAR: @@ -663,10 +663,22 @@ def your_function(x, y): ) xp = get_ns(api_version=api_version) - namespaces.add(xp) + namespaces.append(xp) + + # Use a list of modules to avoid a graph break under torch.compile: + # torch._dynamo.exc.Unsupported: Dynamo cannot determine whether the underlying object is hashable + # Explanation: Dynamo does not know whether the underlying python object for + # PythonModuleVariable( Date: Fri, 3 Apr 2026 14:24:02 +0100 Subject: [PATCH 2/3] Apply suggestion from @lucascolley --- array_api_compat/common/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 45d1978e..10134367 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -678,7 +678,7 @@ def your_function(x, y): namespaces = unique_namespaces try: - (xp,) = tuple(namespaces) + (xp,) = namespaces return xp except ValueError: if not namespaces: From 530e856bb40ad79f48a4da962cc8da55c0eef272 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 3 Apr 2026 20:08:01 +0200 Subject: [PATCH 3/3] TST: add a test of torch.compile-ing array_namespace (#414) --- tests/test_torch.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_torch.py b/tests/test_torch.py index 35ef5dda..b064a46d 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -161,3 +161,18 @@ def test_round(): r = xp.round(x, decimals=1, out=o) assert xp.all(r == o) assert r is o + + +def test_dynamo_array_namespace(): + """Check that torch.compiling array_namespace does not incur graph breaks.""" + from array_api_compat import array_namespace + + def foo(x): + xp = array_namespace(x) + return xp.multiply(x, x) + + bar = torch.compile(fullgraph=True)(foo) + + x = torch.arange(3) + y = bar(x) + assert xp.all(y == x**2)