Skip to content

Commit 6fe08f3

Browse files
hawkinspThe tunix Authors
authored andcommitted
[JAX] Replace reference to jax._src.lib.xla_client.SingleDeviceSharding with jax.sharding.SingleDeviceSharding, which is its public name.
PiperOrigin-RevId: 826203493
1 parent 1d94628 commit 6fe08f3

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

tests/distillation/distillation_trainer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def test_distributed_training(self):
242242

243243
self.assertIsInstance(
244244
unsharded_variables.layers[0].w1.kernel.value.sharding,
245-
jax._src.lib.xla_client.SingleDeviceSharding,
245+
jax.sharding.SingleDeviceSharding,
246246
)
247247
jax.tree.map_with_path(tc.assert_close, variables, unsharded_variables)
248248

tests/sft/checkpoint_manager_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,11 @@ def test_restore_different_sharding(self):
174174
# Check the model shardings are restored correctly.
175175
self.assertIsInstance(
176176
unsharded_variables.w1.kernel.value.sharding,
177-
jax._src.lib.xla_client.SingleDeviceSharding,
177+
jax.sharding.SingleDeviceSharding,
178178
)
179179
self.assertIsInstance(
180180
unsharded_variables.w2.kernel.value.sharding,
181-
jax._src.lib.xla_client.SingleDeviceSharding,
181+
jax.sharding.SingleDeviceSharding,
182182
)
183183

184184
# Restore the model with shardings.

tests/sft/peft_trainer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def test_dist_training(self):
309309
unsharded_variables = nnx.state(unsharded_model, nnx.Param)
310310
self.assertIsInstance(
311311
unsharded_variables.layers[0].w1.kernel.value.sharding,
312-
jax._src.lib.xla_client.SingleDeviceSharding,
312+
jax.sharding.SingleDeviceSharding,
313313
)
314314
jax.tree.map_with_path(tc.assert_close, variables, unsharded_variables)
315315

0 commit comments

Comments
 (0)