
Commit f24fd6e

Author: The tunix Authors

Remove explicit sharding after applying LoRA.

PiperOrigin-RevId: 825163138

1 parent b0e19d4 commit f24fd6e

File tree

6 files changed: +43 -54 lines changed

6 files changed

+43
-54
lines changed
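Every call site touched below drops the same explicit sharding block that used to run right after LoRA was applied. For reference while reading the diffs, a minimal standalone sketch of the deleted pattern follows; the single-axis mesh and the toy nnx.Linear standing in for the LoRA model are illustrative assumptions, since the real call sites operate on the model returned by qwix.apply_lora_to_model.

import jax
from flax import nnx

# Illustrative stand-ins (assumptions): a 1-D mesh over all local devices and
# a small annotated module in place of the qwix LoRA model.
mesh = jax.make_mesh((jax.device_count(),), ('fsdp',))
lora_model = nnx.Linear(
    16,
    16,
    kernel_init=nnx.with_partitioning(
        nnx.initializers.lecun_normal(), ('fsdp', None)
    ),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
    rngs=nnx.Rngs(0),
)

# The explicit post-LoRA sharding block that this commit removes at each call
# site shown in the diffs below:
with mesh:
  state = nnx.state(lora_model)
  pspecs = nnx.get_partition_spec(state)
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(lora_model, sharded_state)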

examples/dpo_demo_gemma3.ipynb

Lines changed: 11 additions & 17 deletions

@@ -293,12 +293,6 @@
 " base_model, lora_provider, **model_input\n",
 " )\n",
 "\n",
-" with mesh:\n",
-" state = nnx.state(lora_model)\n",
-" pspecs = nnx.get_partition_spec(state)\n",
-" sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n",
-" nnx.update(lora_model, sharded_state)\n",
-"\n",
 " return lora_model"
 ]
 },
@@ -332,9 +326,9 @@
 },
 "outputs": [],
 "source": [
-"TEMPLATE = \"\"\"\u003cstart_of_turn\u003euser\n",
-"{question}\u003cend_of_turn\u003e\n",
-"\u003cstart_of_turn\u003emodel\"\"\"\n",
+"TEMPLATE = \"\"\"<start_of_turn>user\n",
+"{question}<end_of_turn>\n",
+"<start_of_turn>model\"\"\"\n",
 "\n",
 "\n",
 "def generate(\n",
@@ -426,10 +420,10 @@
 " except:\n",
 " print(\"SKIPPED accuracy check\")\n",
 "\n",
-" if corr_ctr_per_question \u003e 0:\n",
+" if corr_ctr_per_question > 0:\n",
 " break\n",
 "\n",
-" if corr_ctr_per_question \u003e 0:\n",
+" if corr_ctr_per_question > 0:\n",
 " corr += 1\n",
 " if corr_lst and make_lst:\n",
 " response_lst.append((question, answer, multiple_call_response))\n",
@@ -439,7 +433,7 @@
 "\n",
 " total += 1\n",
 " if total % 10 == 0:\n",
-" print(f\"===\u003e {corr=}, {total=}, {corr / total * 100=}\")\n",
+" print(f\"===> {corr=}, {total=}, {corr / total * 100=}\")\n",
 "\n",
 " to_return = (\n",
 " corr,\n",
@@ -459,13 +453,13 @@
 },
 "outputs": [],
 "source": [
-"def extract_hash_answer(text: str) -\u003e str | None:\n",
+"def extract_hash_answer(text: str) -> str | None:\n",
 " if \"####\" not in text:\n",
 " return None\n",
 " return text.split(\"####\")[1].strip()\n",
 "\n",
 "\n",
-"def get_dataset(data_dir, split=\"train\") -\u003e grain.MapDataset:\n",
+"def get_dataset(data_dir, split=\"train\") -> grain.MapDataset:\n",
 " # Download data\n",
 " if not os.path.exists(data_dir):\n",
 " os.makedirs(data_dir)\n",
@@ -548,7 +542,7 @@
 },
 "outputs": [],
 "source": [
-"def get_dataset() -\u003e grain.MapDataset:\n",
+"def get_dataset() -> grain.MapDataset:\n",
 " dpo_dataset = load_dataset(\n",
 " \"argilla/distilabel-intel-orca-dpo-pairs\", split=\"train\"\n",
 " )\n",
@@ -565,7 +559,7 @@
 " samples_to_add = total_samples_needed - num_gsm8k_train_samples\n",
 " print(f\"Number of additional random samples needed: {samples_to_add}\")\n",
 "\n",
-" if samples_to_add \u003e 0:\n",
+" if samples_to_add > 0:\n",
 " # Randomly select additional samples from the original dataset\n",
 " # Ensure we don't sample more than the total available in the original dataset\n",
 " random_samples = dpo_dataset.shuffle(seed=42).select(\n",
@@ -745,7 +739,7 @@
 },
 "outputs": [],
 "source": [
-"# The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. \u003e10 minutes per step, please open a bug. Really appreciated!\n",
+"# The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. >10 minutes per step, please open a bug. Really appreciated!\n",
 "\n",
 "if mesh is None:\n",
 " dpo_trainer.train(train_dataset)\n",
examples/qlora_demo.ipynb

Lines changed: 8 additions & 14 deletions
Large diffs are not rendered by default.

scripts/grpo_demo_llama3_qwen2.py

Lines changed: 0 additions & 6 deletions

@@ -422,12 +422,6 @@ def get_lora_model(base_model, model_mesh=None):
       base_model, lora_provider, **model_input
   )
 
-  with model_mesh:
-    state = nnx.state(lora_model)
-    pspecs = nnx.get_partition_spec(state)
-    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-    nnx.update(lora_model, sharded_state)
-
   return lora_model
 
 
scripts/grpo_demo_sglang_jax_rollout.py

Lines changed: 0 additions & 6 deletions

@@ -380,12 +380,6 @@ def get_lora_model(base_model, mesh):
   #     base_model, lora_provider, **model_input
   # )
   lora_model = base_model
-  with mesh:
-    state = nnx.state(lora_model)
-    pspecs = nnx.get_partition_spec(state)
-    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-    nnx.update(lora_model, sharded_state)
-
   return lora_model
 
 
tunix/cli/utils/model.py

Lines changed: 0 additions & 6 deletions

@@ -253,12 +253,6 @@ def apply_lora_to_model(base_model, mesh, lora_config):
       base_model, lora_provider, **model_input
   )
 
-  with mesh:
-    state = nnx.state(lora_model)
-    pspecs = nnx.get_partition_spec(state)
-    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-    nnx.update(lora_model, sharded_state)
-
   return lora_model
 
 
tunix/tests/test_common.py

Lines changed: 24 additions & 5 deletions

@@ -30,6 +30,8 @@
 import os
 import shutil
 import gc
+from tunix.rl import utils
+from tunix.rl import reshard
 
 if hasattr(flax_config, 'flax_always_shard_variable'):
   flax_config.update('flax_always_shard_variable', False)
@@ -158,12 +160,29 @@ def get_lora_model(
   lora_model = qwix.apply_lora_to_model(
       model, lora_provider, **dummy_model_input
   )
-  if mesh is not None:
+
+  # Reshard the model if the mesh is specified and the lora model mesh is not
+  # the same as the input mesh.
+  lora_model_mesh = utils.get_pytree_mesh_info(nnx.state(lora_model))
+  if (
+      lora_model_mesh is not None
+      and mesh is not None
+      and lora_model_mesh != mesh
+  ):
     with mesh:
-      state = nnx.state(lora_model)
-      pspecs = nnx.get_partition_spec(state)
-      sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-      nnx.update(lora_model, sharded_state)
+      graph_def, state = nnx.split(lora_model)
+      default_memory_kind = jax.devices()[0].default_memory().kind
+      dst_shardings = jax.tree_util.tree_map(
+          lambda x: jax.sharding.NamedSharding(
+              mesh,
+              x,
+              memory_kind=default_memory_kind,
+          ),
+          nnx.get_partition_spec(state),
+      )
+      lora_model = nnx.merge(
+          graph_def, reshard.reshard_pytree(state, dst_shardings)
+      )
   return lora_model
 
 
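The test helper above now reshards only when the LoRA model already carries a mesh that differs from the requested one, and it pins the destination memory kind when building the target shardings. A condensed standalone sketch of that destination-sharding step follows; the toy array, the 1-D mesh, and the use of jax.device_put as a stand-in for the tunix reshard.reshard_pytree helper are assumptions for illustration, not part of the commit.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Illustrative setup (assumption): a 1-D mesh over all local devices.
mesh = jax.make_mesh((jax.device_count(),), ('fsdp',))

# Destination sharding that also pins the memory kind, mirroring the
# tree_map over nnx.get_partition_spec(state) in the updated test helper.
default_memory_kind = jax.devices()[0].default_memory().kind
dst_sharding = NamedSharding(mesh, P('fsdp'), memory_kind=default_memory_kind)

# jax.device_put stands in here for reshard.reshard_pytree; both place the
# data onto the destination sharding.
x = jnp.zeros((jax.device_count() * 2, 4))
x_resharded = jax.device_put(x, dst_sharding)
print(x_resharded.sharding)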