Skip to content

Commit c50acb7

Browse files
jpuigcervercopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 615750878
1 parent 0b961e5 commit c50acb7

File tree

2 files changed

+106
-36
lines changed

2 files changed

+106
-36
lines changed

clu/parameter_overview.py

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ class _ParamRow:
3636
size: int
3737

3838

39+
@dataclasses.dataclass
40+
class _ParamRowWithSharding(_ParamRow):
41+
sharding: tuple[int | None, ...] | str
42+
43+
3944
@dataclasses.dataclass
4045
class _ParamRowWithStats(_ParamRow):
4146
mean: float
@@ -92,6 +97,47 @@ def count_parameters(params: _ParamsContainer) -> int:
9297
return _count_parameters(params)
9398

9499

100+
def _make_row(name, value) -> _ParamRow:
101+
return _ParamRow(
102+
name=name,
103+
shape=value.shape,
104+
dtype=str(value.dtype),
105+
size=int(np.prod(value.shape)),
106+
)
107+
108+
109+
def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
110+
row = _make_row(name, value)
111+
if hasattr(value, "sharding"):
112+
if hasattr(value.sharding, "spec"):
113+
sharding = tuple(value.sharding.spec)
114+
else:
115+
sharding = str(value.sharding)
116+
else:
117+
sharding = ()
118+
return _ParamRowWithSharding(**dataclasses.asdict(row), sharding=sharding)
119+
120+
121+
def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
122+
row = _make_row(name, value)
123+
return _ParamRowWithStats(
124+
**dataclasses.asdict(row),
125+
mean=float(jax.device_get(mean)),
126+
std=float(jax.device_get(std)),
127+
)
128+
129+
130+
def _make_row_with_stats_and_sharding(
131+
name, value, mean, std
132+
) -> _ParamRowWithStatsAndSharding:
133+
row = _make_row_with_sharding(name, value)
134+
return _ParamRowWithStatsAndSharding(
135+
**dataclasses.asdict(row),
136+
mean=float(jax.device_get(mean)),
137+
std=float(jax.device_get(std)),
138+
)
139+
140+
95141
def _get_parameter_rows(
96142
params: _ParamsContainer,
97143
*,
@@ -104,8 +150,11 @@ def _get_parameter_rows(
104150
nested. Alternatively a `tf.Module` can be provided, in which case the
105151
`trainable_variables` of the module will be used.
106152
include_stats: If True, add columns with mean and std for each variable.
153+
If the string "sharding", add column a column with the sharding of the
154+
variable.
107155
If the string "global", params are sharded global arrays and this
108156
function assumes it is called on every host, i.e. can use collectives.
157+
The sharding of the variables is also added as a column.
109158
110159
Returns:
111160
A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
@@ -122,37 +171,25 @@ def _get_parameter_rows(
122171
else:
123172
names, values = [], []
124173

125-
if include_stats:
126-
def make_row(name, value, mean, std):
127-
kw = dict(
128-
name=name,
129-
shape=value.shape,
130-
dtype=str(value.dtype),
131-
size=int(np.prod(value.shape)),
132-
mean=float(jax.device_get(mean)),
133-
std=float(jax.device_get(std)),
134-
)
135-
if include_stats == "global" and hasattr(value, "sharding"):
136-
if hasattr(value.sharding, "spec"):
137-
return _ParamRowWithStatsAndSharding(
138-
sharding=tuple(value.sharding.spec), **kw
139-
)
140-
else:
141-
return _ParamRowWithStatsAndSharding(
142-
sharding=str(value.sharding), **kw
143-
)
144-
return _ParamRowWithStats(**kw)
145-
mean_std_fn = _mean_std_jit if include_stats == "global" else _mean_std
146-
return jax.tree_util.tree_map(make_row, names, values, *mean_std_fn(values))
147-
else:
148-
def make_row(name, value):
149-
return _ParamRow(
150-
name=name,
151-
shape=value.shape,
152-
dtype=str(value.dtype),
153-
size=int(np.prod(value.shape)),
154-
)
155-
return jax.tree_util.tree_map(make_row, names, values)
174+
match include_stats:
175+
case False:
176+
return jax.tree_util.tree_map(_make_row, names, values)
177+
178+
case True:
179+
mean_and_std = _mean_std(values)
180+
return jax.tree_util.tree_map(
181+
_make_row_with_stats, names, values, *mean_and_std)
182+
183+
case "global":
184+
mean_and_std = _mean_std_jit(values)
185+
return jax.tree_util.tree_map(
186+
_make_row_with_stats_and_sharding, names, values, *mean_and_std)
187+
188+
case "sharding":
189+
return jax.tree_util.tree_map(_make_row_with_sharding, names, values)
190+
191+
case _:
192+
raise ValueError(f"Unknown `include_stats`: {include_stats}")
156193

157194

158195
def _default_table_value_formatter(value):
@@ -247,6 +284,7 @@ def _get_parameter_overview(
247284
False: _ParamRow,
248285
True: _ParamRowWithStats,
249286
"global": _ParamRowWithStatsAndSharding,
287+
"sharding": _ParamRowWithSharding,
250288
}[include_stats]
251289
# Pass in `column_names` to enable rendering empty tables.
252290
column_names = [field.name for field in dataclasses.fields(RowType)]
@@ -267,9 +305,12 @@ def get_parameter_overview(
267305
Args:
268306
params: Dictionary with parameters as NumPy arrays. The dictionary can be
269307
nested.
270-
include_stats: If True, add columns with mean and std for each variable. If
271-
the string "global", params are sharded global arrays and this function
272-
assumes it is called on every host, i.e. can use collectives.
308+
include_stats: If True, add columns with mean and std for each variable.
309+
If the string "sharding", add column a column with the sharding of the
310+
variable.
311+
If the string "global", params are sharded global arrays and this
312+
function assumes it is called on every host, i.e. can use collectives.
313+
The sharding of the variables is also added as a column.
273314
max_lines: If not `None`, the maximum number of variables to include.
274315
275316
Returns:

clu/parameter_overview_test.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from flax import linen as nn
2020
import jax
2121
import jax.numpy as jnp
22+
import numpy as np
2223

2324

2425
EMPTY_PARAMETER_OVERVIEW = """+------+-------+-------+------+------+-----+
@@ -35,6 +36,14 @@
3536
+-------------+--------------+---------+------+
3637
Total: 56 -- 224 bytes"""
3738

39+
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING = """+-------------+--------------+---------+------+----------+
40+
| Name | Shape | Dtype | Size | Sharding |
41+
+-------------+--------------+---------+------+----------+
42+
| conv/bias | (2,) | float32 | 2 | () |
43+
| conv/kernel | (3, 3, 3, 2) | float32 | 54 | () |
44+
+-------------+--------------+---------+------+----------+
45+
Total: 56 -- 224 bytes"""
46+
3847
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+---------+------+------+-----+
3948
| Name | Shape | Dtype | Size | Mean | Std |
4049
+-------------+--------------+---------+------+------+-----+
@@ -43,6 +52,14 @@
4352
+-------------+--------------+---------+------+------+-----+
4453
Total: 56 -- 224 bytes"""
4554

55+
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING = """+-------------+--------------+---------+------+------+-----+----------+
56+
| Name | Shape | Dtype | Size | Mean | Std | Sharding |
57+
+-------------+--------------+---------+------+------+-----+----------+
58+
| conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | () |
59+
| conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 | () |
60+
+-------------+--------------+---------+------+------+-----+----------+
61+
Total: 56 -- 224 bytes"""
62+
4663
FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+---------+------+------+-----+
4764
| Name | Shape | Dtype | Size | Mean | Std |
4865
+--------------------+--------------+---------+------+------+-----+
@@ -66,7 +83,7 @@ def test_count_parameters_empty(self):
6683

6784
def test_count_parameters(self):
6885
rng = jax.random.PRNGKey(42)
69-
# Weights of a 2D convolution with 2 filters..
86+
# Weights of a 2D convolution with 2 filters.
7087
variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
7188
# 3 * 3*3 * 2 + 2 (bias) = 56 parameters
7289
self.assertEqual(56,
@@ -78,7 +95,7 @@ def test_get_parameter_overview_empty(self):
7895

7996
def test_get_parameter_overview(self):
8097
rng = jax.random.PRNGKey(42)
81-
# Weights of a 2D convolution with 2 filters..
98+
# Weights of a 2D convolution with 2 filters.
8299
variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
83100
variables = jax.tree_map(jnp.ones_like, variables)
84101
self.assertEqual(
@@ -91,6 +108,18 @@ def test_get_parameter_overview(self):
91108
self.assertEqual(
92109
FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS,
93110
parameter_overview.get_parameter_overview(variables))
111+
# Add sharding with PartitionSpecs.
112+
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), "d")
113+
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
114+
variables = jax.jit(lambda x: x, out_shardings=sharding)(variables)
115+
self.assertEqual(
116+
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING,
117+
parameter_overview.get_parameter_overview(
118+
variables["params"], include_stats="sharding"))
119+
self.assertEqual(
120+
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING,
121+
parameter_overview.get_parameter_overview(
122+
variables["params"], include_stats="global"))
94123

95124
def test_get_parameter_overview_shape_dtype_struct(self):
96125
variables_shape_dtype_struct = jax.eval_shape(

0 commit comments

Comments
 (0)