From 4a867306102b54a8aed260deced442d12bd2672d Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 30 Jan 2026 16:11:21 -0800
Subject: [PATCH 1/4] [tx] Make sharding explicit in LoRA constructors

---
 skyrl-tx/tx/layers/lora.py       | 24 ++++++++----------
 skyrl-tx/tx/models/deepseekv3.py | 42 +++++++++++++++++++++-----------
 skyrl-tx/tx/models/llama3.py     | 27 +++++++++++++-------
 skyrl-tx/tx/models/qwen3.py      | 36 ++++++++++++++++++---------
 4 files changed, 80 insertions(+), 49 deletions(-)

diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py
index 573b83adb..7aa6ee118 100644
--- a/skyrl-tx/tx/layers/lora.py
+++ b/skyrl-tx/tx/layers/lora.py
@@ -117,6 +117,7 @@ def __init__(
         num_embeddings: int,
         features: int,
         *,
+        sharding: tuple[str | None, ...],
         max_lora_adapters: int = 0,
         max_lora_rank: int = 8,
         dtype: jnp.dtype = jnp.float32,
@@ -131,13 +132,9 @@ def __init__(
             features=features,
             dtype=dtype,
             param_dtype=param_dtype,
-            embedding_init=embedding_init,
+            embedding_init=nnx.with_partitioning(embedding_init, sharding),
             rngs=rngs,
         )
-        assert (
-            self.embedding[...].sharding is not None
-        ), "LoRAEmbed layer needs sharding, you can specify it by using nnx.with_partitioning on the embedding_init"
-        sharding = self.embedding[...].sharding.spec
 
         self.init_lora(
             max_lora_adapters=max_lora_adapters,
@@ -181,6 +178,7 @@ def __init__(
         in_features: int,
         out_features: int,
         *,
+        sharding: tuple[str | None, ...],
         max_lora_adapters: int = 0,
         max_lora_rank: int = 8,
         dtype: jnp.dtype = jnp.float32,
@@ -200,14 +198,11 @@ def __init__(
             use_bias=use_bias,
             dtype=dtype,
             param_dtype=param_dtype,
-            kernel_init=kernel_init,
-            bias_init=bias_init,
+            kernel_init=nnx.with_partitioning(kernel_init, sharding),
+            bias_init=nnx.with_partitioning(bias_init, (sharding[-1],)),
             rngs=rngs,
         )
-        assert (
-            self.kernel[...].sharding is not None
-        ), "LoRALinear layer needs sharding, you can specify it by using nnx.with_partitioning on the kernel_init"
-        sharding = self.kernel[...].sharding.spec
+
         self.init_lora(
             max_lora_adapters=max_lora_adapters,
             max_lora_rank=max_lora_rank,
@@ -233,6 +228,7 @@ def __init__(
         in_features: int,
         out_features: int,
         *,
+        sharding: tuple[str | None, ...],
         max_lora_adapters: int = 0,
         max_lora_rank: int = 8,
         dtype: jnp.dtype = jnp.float32,
@@ -243,10 +239,10 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
 
-        self.weight = Param(num_experts, in_features, out_features, dtype=dtype, kernel_init=kernel_init, rngs=rngs)
+        self.weight = Param(
+            num_experts, in_features, out_features, dtype=dtype, kernel_init=nnx.with_partitioning(kernel_init, sharding), rngs=rngs
+        )
 
-        assert self.weight[...].sharding is not None, "LoRAExpert layer needs sharding"
-        sharding = self.weight[...].sharding.spec
         self.init_lora(
             max_lora_adapters=max_lora_adapters,
             max_lora_rank=max_lora_rank,
diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py
index 07aea272b..9232832d1 100644
--- a/skyrl-tx/tx/models/deepseekv3.py
+++ b/skyrl-tx/tx/models/deepseekv3.py
@@ -37,12 +37,13 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.q_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_heads * self.qk_head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.q_a_proj = None
@@ -53,36 +54,39 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.q_a_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.q_lora_rank,
+            sharding=("fsdp", None),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=config.attention_bias,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", None)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
         self.q_b_proj = LoRALinear(
             in_features=self.q_lora_rank,
             out_features=self.num_heads * self.qk_head_dim,
+            sharding=(None, tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.kv_a_proj_with_mqa = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.kv_lora_rank + self.qk_rope_head_dim,
+            sharding=("fsdp", None),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=config.attention_bias,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", None)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
@@ -90,24 +94,26 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.kv_b_proj = LoRALinear(
             in_features=self.kv_lora_rank,
             out_features=self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            sharding=(None, tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.o_proj = LoRALinear(
             in_features=self.num_heads * self.v_head_dim,
             out_features=config.hidden_size,
+            sharding=(tp_shard, "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=config.attention_bias,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -189,10 +195,11 @@ def __init__(
         self.gate_proj = LoRALinear(
             config.hidden_size,
             intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -200,10 +207,11 @@ def __init__(
         self.up_proj = LoRALinear(
             config.hidden_size,
             intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -211,10 +219,11 @@ def __init__(
         self.down_proj = LoRALinear(
             intermediate_size,
             config.hidden_size,
+            sharding=("tp", "fsdp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -260,30 +269,33 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
             config.n_routed_experts,
             config.hidden_size,
             config.moe_intermediate_size,
+            sharding=("ep", "fsdp", "tp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.up_proj = LoRAExpert(
             config.n_routed_experts,
             config.hidden_size,
             config.moe_intermediate_size,
+            sharding=("ep", "fsdp", "tp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.down_proj = LoRAExpert(
             config.n_routed_experts,
             config.moe_intermediate_size,
             config.hidden_size,
+            sharding=("ep", "tp", "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -452,11 +464,12 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.embed_tokens = LoRAEmbed(
             num_embeddings=config.vocab_size,
             features=config.hidden_size,
+            sharding=("tp", None),
             dtype=dtype,
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)),
+            embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
         self.layers = nnx.List(
@@ -520,10 +533,11 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.lm_head = LoRALinear(
             config.hidden_size,
             config.vocab_size,
+            sharding=(None, "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py
index 0522f75be..711d48b8f 100644
--- a/skyrl-tx/tx/models/llama3.py
+++ b/skyrl-tx/tx/models/llama3.py
@@ -31,48 +31,52 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.q_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_heads * self.head_dim,
+            sharding=(None, tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.k_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
+            sharding=(None, tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.v_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
+            sharding=(None, tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.o_proj = LoRALinear(
             in_features=self.num_heads * self.head_dim,
             out_features=config.hidden_size,
+            sharding=(tp_shard, None),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, None)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -115,10 +119,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.gate_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
+            sharding=(None, "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -126,10 +131,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.up_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
+            sharding=(None, "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -137,10 +143,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.down_proj = LoRALinear(
             config.intermediate_size,
             config.hidden_size,
+            sharding=("tp", None),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", None)),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -196,11 +203,12 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.embed_tokens = LoRAEmbed(
             num_embeddings=config.vocab_size,
             features=config.hidden_size,
+            sharding=("tp", None),
             dtype=dtype,
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)),
+            embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
         self.layers = nnx.List(
@@ -263,10 +271,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.lm_head = LoRALinear(
             config.hidden_size,
             config.vocab_size,
+            sharding=(None, "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py
index e35ce8069..1348cac09 100644
--- a/skyrl-tx/tx/models/qwen3.py
+++ b/skyrl-tx/tx/models/qwen3.py
@@ -32,45 +32,49 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.q_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.k_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.v_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.o_proj = LoRALinear(
             in_features=self.num_heads * self.head_dim,
             out_features=config.hidden_size,
+            sharding=(tp_shard, "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -116,10 +120,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.gate_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -127,10 +132,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.up_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -138,10 +144,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.down_proj = LoRALinear(
             config.intermediate_size,
             config.hidden_size,
+            sharding=("tp", "fsdp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -161,30 +168,33 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             config.num_experts,
             config.hidden_size,
             config.moe_intermediate_size,
+            sharding=("ep", "fsdp", "tp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.up_proj = LoRAExpert(
             config.num_experts,
             config.hidden_size,
             config.moe_intermediate_size,
+            sharding=("ep", "fsdp", "tp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.down_proj = LoRAExpert(
             config.num_experts,
             config.moe_intermediate_size,
             config.hidden_size,
+            sharding=("ep", "tp", "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -311,11 +321,12 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.embed_tokens = LoRAEmbed(
             num_embeddings=config.vocab_size,
             features=config.hidden_size,
+            sharding=("tp", None),
             dtype=dtype,
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)),
+            embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
         self.layers = nnx.List(
@@ -378,10 +389,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.lm_head = LoRALinear(
             config.hidden_size,
             config.vocab_size,
+            sharding=(None, "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,

From 1ec793e0a7f301774b60a46f68b07659eb9f9f89 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 30 Jan 2026 16:17:40 -0800
Subject: [PATCH 2/4] black

---
 skyrl-tx/tx/layers/lora.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py
index 7aa6ee118..776b7af59 100644
--- a/skyrl-tx/tx/layers/lora.py
+++ b/skyrl-tx/tx/layers/lora.py
@@ -240,7 +240,12 @@ def __init__(
         self.out_features = out_features
 
         self.weight = Param(
-            num_experts, in_features, out_features, dtype=dtype, kernel_init=nnx.with_partitioning(kernel_init, sharding), rngs=rngs
+            num_experts,
+            in_features,
+            out_features,
+            dtype=dtype,
+            kernel_init=nnx.with_partitioning(kernel_init, sharding),
+            rngs=rngs,
         )
 
         self.init_lora(

From 915571c8912e18a568df8db1e2ed52844e489943 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 30 Jan 2026 17:26:56 -0800
Subject: [PATCH 3/4] fix llama3 sharding

---
 skyrl-tx/tests/models/test_llama3.py               |  2 +-
 skyrl-tx/tests/models/test_llama3_lora_training.py |  2 +-
 skyrl-tx/tx/models/llama3.py                       | 14 +++++++-------
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/skyrl-tx/tests/models/test_llama3.py b/skyrl-tx/tests/models/test_llama3.py
index fa195567f..91b8575bc 100644
--- a/skyrl-tx/tests/models/test_llama3.py
+++ b/skyrl-tx/tests/models/test_llama3.py
@@ -39,7 +39,7 @@ def test_llama3(tp: int):
     base_config = AutoConfig.from_pretrained(model_name)
     config = Llama3Config(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True)
 
-    mesh = jax.make_mesh((1, tp), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
+    mesh = jax.make_mesh((1, tp), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
     with jax.set_mesh(mesh):
         model = Llama3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
         load_safetensors(tmp, config, model)
diff --git a/skyrl-tx/tests/models/test_llama3_lora_training.py b/skyrl-tx/tests/models/test_llama3_lora_training.py
index af91d373e..61fa029c6 100644
--- a/skyrl-tx/tests/models/test_llama3_lora_training.py
+++ b/skyrl-tx/tests/models/test_llama3_lora_training.py
@@ -18,7 +18,7 @@ def test_lora_training():
     config = Llama3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)
     checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])
 
-    mesh = jax.make_mesh((1, 1), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
+    mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
     with jax.set_mesh(mesh):
         model = Llama3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0))
         load_safetensors(checkpoint_path, config, model)
diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py
index 711d48b8f..b1ae1027b 100644
--- a/skyrl-tx/tx/models/llama3.py
+++ b/skyrl-tx/tx/models/llama3.py
@@ -31,7 +31,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.q_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_heads * self.head_dim,
-            sharding=(None, tp_shard),
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
@@ -44,7 +44,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.k_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
-            sharding=(None, tp_shard),
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
@@ -57,7 +57,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.v_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
-            sharding=(None, tp_shard),
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
@@ -70,7 +70,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.o_proj = LoRALinear(
             in_features=self.num_heads * self.head_dim,
             out_features=config.hidden_size,
-            sharding=(tp_shard, None),
+            sharding=(tp_shard, "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
@@ -119,7 +119,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.gate_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
-            sharding=(None, "tp"),
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
@@ -131,7 +131,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.up_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
-            sharding=(None, "tp"),
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
@@ -143,7 +143,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.down_proj = LoRALinear(
             config.intermediate_size,
             config.hidden_size,
-            sharding=("tp", None),
+            sharding=("tp", "fsdp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,

From 2ebe8319b41ff338304736af21548b62a1f47166 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 30 Jan 2026 18:15:34 -0800
Subject: [PATCH 4/4] fix test

---
 skyrl-tx/tests/models/test_models_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py
index f0dc261bc..a90973371 100644
--- a/skyrl-tx/tests/models/test_models_common.py
+++ b/skyrl-tx/tests/models/test_models_common.py
@@ -13,7 +13,7 @@ from tx.utils.models import load_safetensors
 
 
 MODEL_PARAMS = [
-    ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")),
+    ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")),
     ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")),
 ]
 MODEL_IDS = ["llama3", "qwen3"]
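
After this series, callers pass the partition spec to the LoRA layers directly and the layers apply nnx.with_partitioning to their initializers internally, instead of each call site wrapping kernel_init or embedding_init itself. A minimal sketch of constructing one such layer under a mesh, mirroring the updated tests; the feature sizes, dtype, single-device mesh shape, and the tx.layers.lora import path (inferred from the file paths in the diff) are assumptions for illustration, not values taken from this series:

    import jax
    import jax.numpy as jnp
    from flax import nnx

    from tx.layers.lora import LoRALinear  # import path inferred from skyrl-tx/tx/layers/lora.py

    # Mesh axes named as in the updated tests: ("fsdp", "tp").
    mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
    with jax.set_mesh(mesh):
        proj = LoRALinear(
            in_features=1024,         # illustrative size
            out_features=4096,        # illustrative size
            sharding=("fsdp", "tp"),  # explicit partition spec, the new constructor argument
            max_lora_adapters=1,
            max_lora_rank=8,
            dtype=jnp.float32,
            param_dtype=jnp.float32,
            use_bias=False,
            kernel_init=nnx.initializers.lecun_normal(),
            rngs=nnx.Rngs(0),
        )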