@@ -36,6 +36,11 @@ class _ParamRow:
3636 size : int
3737
3838
@dataclasses.dataclass
class _ParamRowWithSharding(_ParamRow):
  """A `_ParamRow` extended with a column describing the value's sharding."""

  # Filled in by `_make_row_with_sharding`: a tuple built from
  # `value.sharding.spec` when that attribute exists, `str(value.sharding)`
  # when it does not, or `()` when the value has no `sharding` attribute.
  # NOTE(review): the `int | None` element type may not match what the spec
  # actually contains - confirm against the sharding objects in use.
  sharding: tuple[int | None, ...] | str
42+
43+
3944@dataclasses .dataclass
4045class _ParamRowWithStats (_ParamRow ):
4146 mean : float
@@ -92,6 +97,47 @@ def count_parameters(params: _ParamsContainer) -> int:
9297 return _count_parameters (params )
9398
9499
def _make_row(name, value) -> _ParamRow:
  """Builds the basic overview row for a single parameter.

  Args:
    name: Name to record for the parameter.
    value: Object exposing `shape` and `dtype` attributes.

  Returns:
    A `_ParamRow` holding the name, the shape, the dtype rendered as a string
    and the number of elements (the product of the shape's entries).
  """
  shape = value.shape
  dtype_name = str(value.dtype)
  num_elements = int(np.prod(shape))
  return _ParamRow(
      name=name,
      shape=shape,
      dtype=dtype_name,
      size=num_elements,
  )
107+
108+
def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
  """Builds an overview row that also records how the value is sharded.

  Args:
    name: Name to record for the parameter.
    value: Object exposing `shape` and `dtype`, and optionally `sharding`.

  Returns:
    A `_ParamRowWithSharding` whose `sharding` field is:
      * `tuple(value.sharding.spec)` if the sharding has a `spec` attribute,
      * `str(value.sharding)` if it has no `spec` attribute,
      * `()` if `value` has no `sharding` attribute at all.
  """
  base_fields = dataclasses.asdict(_make_row(name, value))
  if not hasattr(value, "sharding"):
    sharding = ()
  elif hasattr(value.sharding, "spec"):
    sharding = tuple(value.sharding.spec)
  else:
    # No partition spec available; fall back to the textual description.
    sharding = str(value.sharding)
  return _ParamRowWithSharding(sharding=sharding, **base_fields)
119+
120+
def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
  """Builds an overview row that also carries mean and standard deviation.

  Args:
    name: Name to record for the parameter.
    value: Object exposing `shape` and `dtype` attributes.
    mean: Precomputed mean of `value`; passed through `jax.device_get` and
      converted to a Python float.
    std: Precomputed standard deviation of `value`; handled like `mean`.

  Returns:
    A `_ParamRowWithStats` with the basic row fields plus `mean` and `std`.
  """
  base_fields = dataclasses.asdict(_make_row(name, value))
  mean_value = float(jax.device_get(mean))
  std_value = float(jax.device_get(std))
  return _ParamRowWithStats(mean=mean_value, std=std_value, **base_fields)
128+
129+
def _make_row_with_stats_and_sharding(
    name, value, mean, std
) -> _ParamRowWithStatsAndSharding:
  """Builds an overview row with sharding, mean and standard deviation.

  Args:
    name: Name to record for the parameter.
    value: Object exposing `shape` and `dtype`, and optionally `sharding`.
    mean: Precomputed mean of `value`; passed through `jax.device_get` and
      converted to a Python float.
    std: Precomputed standard deviation of `value`; handled like `mean`.

  Returns:
    A `_ParamRowWithStatsAndSharding` combining the fields produced by
    `_make_row_with_sharding` with `mean` and `std`.
  """
  sharded_fields = dataclasses.asdict(_make_row_with_sharding(name, value))
  mean_value = float(jax.device_get(mean))
  std_value = float(jax.device_get(std))
  return _ParamRowWithStatsAndSharding(
      mean=mean_value,
      std=std_value,
      **sharded_fields,
  )
139+
140+
95141def _get_parameter_rows (
96142 params : _ParamsContainer ,
97143 * ,
@@ -104,8 +150,11 @@ def _get_parameter_rows(
104150 nested. Alternatively a `tf.Module` can be provided, in which case the
105151 `trainable_variables` of the module will be used.
106152 include_stats: If True, add columns with mean and std for each variable.
153+ If the string "sharding", add a column with the sharding of the
154+ variable.
107155 If the string "global", params are sharded global arrays and this
108156 function assumes it is called on every host, i.e. can use collectives.
157+ The sharding of the variables is also added as a column.
109158
110159 Returns:
111160 A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
@@ -122,37 +171,25 @@ def _get_parameter_rows(
122171 else :
123172 names , values = [], []
124173
125- if include_stats :
126- def make_row (name , value , mean , std ):
127- kw = dict (
128- name = name ,
129- shape = value .shape ,
130- dtype = str (value .dtype ),
131- size = int (np .prod (value .shape )),
132- mean = float (jax .device_get (mean )),
133- std = float (jax .device_get (std )),
134- )
135- if include_stats == "global" and hasattr (value , "sharding" ):
136- if hasattr (value .sharding , "spec" ):
137- return _ParamRowWithStatsAndSharding (
138- sharding = tuple (value .sharding .spec ), ** kw
139- )
140- else :
141- return _ParamRowWithStatsAndSharding (
142- sharding = str (value .sharding ), ** kw
143- )
144- return _ParamRowWithStats (** kw )
145- mean_std_fn = _mean_std_jit if include_stats == "global" else _mean_std
146- return jax .tree_util .tree_map (make_row , names , values , * mean_std_fn (values ))
147- else :
148- def make_row (name , value ):
149- return _ParamRow (
150- name = name ,
151- shape = value .shape ,
152- dtype = str (value .dtype ),
153- size = int (np .prod (value .shape )),
154- )
155- return jax .tree_util .tree_map (make_row , names , values )
174+ match include_stats :
175+ case False :
176+ return jax .tree_util .tree_map (_make_row , names , values )
177+
178+ case True :
179+ mean_and_std = _mean_std (values )
180+ return jax .tree_util .tree_map (
181+ _make_row_with_stats , names , values , * mean_and_std )
182+
183+ case "global" :
184+ mean_and_std = _mean_std_jit (values )
185+ return jax .tree_util .tree_map (
186+ _make_row_with_stats_and_sharding , names , values , * mean_and_std )
187+
188+ case "sharding" :
189+ return jax .tree_util .tree_map (_make_row_with_sharding , names , values )
190+
191+ case _:
192+ raise ValueError (f"Unknown `include_stats`: { include_stats } " )
156193
157194
158195def _default_table_value_formatter (value ):
@@ -247,6 +284,7 @@ def _get_parameter_overview(
247284 False : _ParamRow ,
248285 True : _ParamRowWithStats ,
249286 "global" : _ParamRowWithStatsAndSharding ,
287+ "sharding" : _ParamRowWithSharding ,
250288 }[include_stats ]
251289 # Pass in `column_names` to enable rendering empty tables.
252290 column_names = [field .name for field in dataclasses .fields (RowType )]
@@ -267,9 +305,12 @@ def get_parameter_overview(
267305 Args:
268306 params: Dictionary with parameters as NumPy arrays. The dictionary can be
269307 nested.
270- include_stats: If True, add columns with mean and std for each variable. If
271- the string "global", params are sharded global arrays and this function
272- assumes it is called on every host, i.e. can use collectives.
308+ include_stats: If True, add columns with mean and std for each variable.
309+ If the string "sharding", add a column with the sharding of the
310+ variable.
311+ If the string "global", params are sharded global arrays and this
312+ function assumes it is called on every host, i.e. can use collectives.
313+ The sharding of the variables is also added as a column.
273314 max_lines: If not `None`, the maximum number of variables to include.
274315
275316 Returns:
0 commit comments