Skip to content

Commit aad35b5

Browse files
paul-gibbons, pre-commit-ci[bot], pggPL
committed
[PyTorch Debug] Fix issue with start_end_list logging feature (#2252)
* fixes for start_end_list usage in TE debug Signed-off-by: Paul Gibbons <pgibbons@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Paul Gibbons <pgibbons@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
1 parent 56e04f1 commit aad35b5

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

transformer_engine/debug/features/log_fp8_tensor_stats.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,16 @@ def inspect_tensor(
290290
for stat in config["stats"]:
291291
self.check_if_stat_is_supported(stat, recipe_name)
292292

293+
start_step = config.get("start_step", None)
294+
end_step = config.get("end_step", None)
295+
start_end_list = config.get("start_end_list", None)
296+
if start_end_list is not None:
297+
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
298+
293299
options = (
294-
config.get("start_step", None),
295-
config.get("end_step", None),
296-
config.get("start_end_list", None),
300+
start_step,
301+
end_step,
302+
start_end_list,
297303
"fp8",
298304
)
299305

transformer_engine/debug/features/log_tensor_stats.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,16 @@ def inspect_tensor(
130130
" log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors."
131131
)
132132

133+
start_step = config.get("start_step", None)
134+
end_step = config.get("end_step", None)
135+
start_end_list = config.get("start_end_list", None)
136+
if start_end_list is not None:
137+
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
138+
133139
options = (
134-
config.get("start_step", None),
135-
config.get("end_step", None),
136-
config.get("start_end_list", None),
140+
start_step,
141+
end_step,
142+
start_end_list,
137143
)
138144

139145
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(

transformer_engine/debug/features/utils/stats_buffer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,19 @@ def _if_run_reduction(self) -> bool:
172172
if self.at_least_one_layer_fed:
173173
return True
174174
iteration = TEDebugState.get_iteration()
175-
for _, next_iter in self.layers_to_next_iter.items():
175+
layers_to_remove = []
176+
for layer_name, next_iter in self.layers_to_next_iter.items():
177+
# When next_iter is None the feature will no longer run.
178+
if next_iter is None:
179+
layers_to_remove.append(layer_name)
180+
continue
176181
# Note that layer can be not run for many iterations,
177182
# in this case we will synchronize every step until we get any information from it.
178183
if iteration >= next_iter:
179184
return True
185+
186+
for layer_name in layers_to_remove:
187+
self.layers_to_next_iter.pop(layer_name, None)
180188
return False
181189

182190
def reset(self):

0 commit comments

Comments
 (0)