Skip to content

Commit 90cab09

Browse files
committed
add a check on jax version to solve the version error
1 parent 5641f6c commit 90cab09

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

axlearn/common/array_serialization.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,12 @@ async def _async_serialize(
306306
and arr_inp.is_fully_addressable
307307
)
308308
# pylint: disable=protected-access
309-
spec_has_metadata = {
310-
"0.6.2": lambda: serialization.ts_impl._spec_has_metadata,
311-
"0.5.3": lambda: serialization._spec_has_metadata,
312-
}[jax.__version__]()
309+
if jax.__version__.startswith("0.8.0") or jax.__version__ == "0.6.2":
310+
spec_has_metadata = serialization.ts_impl._spec_has_metadata
311+
elif jax.__version__ == "0.5.3":
312+
spec_has_metadata = serialization._spec_has_metadata
313+
else:
314+
raise ValueError(f"Unsupported JAX version for spec_has_metadata: {jax.__version__}")
313315
if not spec_has_metadata(tensorstore_spec):
314316
# pylint: disable-next=protected-access
315317
tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)

0 commit comments

Comments
 (0)