
Commit 8301b81

Author: The tunix Authors

Remove explicit sharding after applying LoRA.

PiperOrigin-RevId: 826211980
Parent: 6fe08f3

7 files changed, +31 -42 lines

examples/dpo_demo_gemma3.ipynb (0 additions, 6 deletions)

@@ -293,12 +293,6 @@
     "      base_model, lora_provider, **model_input\n",
     "  )\n",
     "\n",
-    "  with mesh:\n",
-    "    state = nnx.state(lora_model)\n",
-    "    pspecs = nnx.get_partition_spec(state)\n",
-    "    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n",
-    "    nnx.update(lora_model, sharded_state)\n",
-    "\n",
     "  return lora_model"
    ]
   },
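The block deleted here (and in the other call sites below) is the same explicit-sharding pattern everywhere. As a hedged sketch, assuming an nnx LoRA model `lora_model` and a `jax.sharding.Mesh` named `mesh` are already in scope, it did roughly this:

from flax import nnx
import jax

with mesh:
  # Pull the parameter pytree out of the module.
  state = nnx.state(lora_model)
  # Read the partition specs annotated on those parameters.
  pspecs = nnx.get_partition_spec(state)
  # Constrain the state to those specs on the active mesh.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  # Write the (now sharded) state back into the module.
  nnx.update(lora_model, sharded_state)

The commit removes that boilerplate from the call sites; the CLI and test helpers below route through a single reshard helper instead.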

examples/qlora_demo.ipynb (0 additions, 6 deletions)

@@ -490,12 +490,6 @@
     "      base_model, lora_provider, **model_input\n",
     "  )\n",
     "\n",
-    "  with mesh:\n",
-    "    state = nnx.state(lora_model)\n",
-    "    pspecs = nnx.get_partition_spec(state)\n",
-    "    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n",
-    "    nnx.update(lora_model, sharded_state)\n",
-    "\n",
     "  return lora_model"
    ]
   },

scripts/grpo_demo_llama3_qwen2.py (0 additions, 6 deletions)

@@ -422,12 +422,6 @@ def get_lora_model(base_model, model_mesh=None):
       base_model, lora_provider, **model_input
   )
 
-  with model_mesh:
-    state = nnx.state(lora_model)
-    pspecs = nnx.get_partition_spec(state)
-    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-    nnx.update(lora_model, sharded_state)
-
   return lora_model

scripts/grpo_demo_sglang_jax_rollout.py (0 additions, 6 deletions)

@@ -380,12 +380,6 @@ def get_lora_model(base_model, mesh):
   #     base_model, lora_provider, **model_input
   # )
   lora_model = base_model
-  with mesh:
-    state = nnx.state(lora_model)
-    pspecs = nnx.get_partition_spec(state)
-    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-    nnx.update(lora_model, sharded_state)
-
   return lora_model

tunix/cli/utils/model.py (3 additions, 7 deletions)

@@ -33,6 +33,7 @@
 from tunix.models.llama3 import model as llama3_lib
 from tunix.models.qwen2 import model as qwen2_lib
 from tunix.models.qwen3 import model as qwen3_lib
+from tunix.rl import reshard
 
 
 # Map prefixes to the target object containing the methods.

@@ -252,13 +253,8 @@ def apply_lora_to_model(base_model, mesh, lora_config):
   lora_model = qwix.apply_lora_to_model(
       base_model, lora_provider, **model_input
   )
-
-  with mesh:
-    state = nnx.state(lora_model)
-    pspecs = nnx.get_partition_spec(state)
-    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-    nnx.update(lora_model, sharded_state)
-
+  if mesh is not None:
+    lora_model = reshard.reshard_model_to_mesh(lora_model, mesh)
   return lora_model
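In isolation, the new call-site pattern reads as below. This is a hedged sketch, assuming `base_model`, `lora_provider`, `model_input`, and `mesh` are built as in the surrounding file:

import qwix
from tunix.rl import reshard

lora_model = qwix.apply_lora_to_model(
    base_model, lora_provider, **model_input
)
# The `if mesh is not None` check replaces the old unconditional `with mesh:` block;
# the helper additionally skips the reshard when the model already lives on `mesh`.
if mesh is not None:
  lora_model = reshard.reshard_model_to_mesh(lora_model, mesh)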

tunix/rl/reshard.py (21 additions, 1 deletion)

@@ -26,7 +26,8 @@
 from absl import logging
 import jax
 import jaxtyping
-
+from flax import nnx
+from tunix.rl import utils
 
 # TODO(tsbao): move this to util
 def callback_on_ready(

@@ -483,3 +484,22 @@ def _get_dst_sharding(x):
       ),
   )
   return resharded_array
+
+
+def reshard_model_to_mesh(model: nnx.Module, mesh: jax.sharding.Mesh):
+  """Reshard the lora model if the mesh is specified and the lora model mesh is not the same as the input mesh."""
+  model_mesh = utils.get_pytree_mesh_info(nnx.state(model))
+  if mesh is not None and model_mesh != mesh:
+    with mesh:
+      graph_def, state = nnx.split(model)
+      default_memory_kind = jax.devices()[0].default_memory().kind
+      dst_shardings = jax.tree_util.tree_map(
+          lambda x: jax.sharding.NamedSharding(
+              mesh,
+              x,
+              memory_kind=default_memory_kind,
+          ),
+          nnx.get_partition_spec(state),
+      )
+      model = nnx.merge(graph_def, reshard_pytree(state, dst_shardings))
+  return model
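A hedged usage sketch of the new helper, assuming `lora_model` is the nnx module returned by qwix.apply_lora_to_model (as in the call sites above) and that its parameters carry partition-spec annotations; the mesh construction and axis name are illustrative only:

import numpy as np
import jax
from tunix.rl import reshard

# Hypothetical 1-D mesh over all local devices.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('fsdp',))

# Returns the model unchanged when its parameters already live on `mesh`;
# otherwise it splits the module, builds NamedShardings from the annotated
# partition specs, reshards the state, and merges it back.
lora_model = reshard.reshard_model_to_mesh(lora_model, mesh)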

tunix/tests/test_common.py (7 additions, 10 deletions)

@@ -14,22 +14,23 @@
 
 """Common test utilities."""
 
-from typing import List, Tuple, Any
 from collections.abc import Iterable
 import dataclasses
+import gc
+import os
+import shutil
+from typing import Any, List, Tuple
 
 from flax import config as flax_config
 from flax import nnx
+import huggingface_hub
 import jax
 import jax.numpy as jnp
 import numpy as np
 import qwix
+from tunix.rl import reshard
 
 import sentencepiece as spm
-import huggingface_hub
-import os
-import shutil
-import gc
 
 if hasattr(flax_config, 'flax_always_shard_variable'):
   flax_config.update('flax_always_shard_variable', False)

@@ -159,11 +160,7 @@ def get_lora_model(
       model, lora_provider, **dummy_model_input
   )
   if mesh is not None:
-    with mesh:
-      state = nnx.state(lora_model)
-      pspecs = nnx.get_partition_spec(state)
-      sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-      nnx.update(lora_model, sharded_state)
+    lora_model = reshard.reshard_model_to_mesh(lora_model, mesh)
   return lora_model
