-
Notifications
You must be signed in to change notification settings - Fork 131
Open
Description
I use this code to identify leaf node.
class ModelInfo:
def __init__(self, model: TorchModelReference):
self.layer_list = torchinfo.summary(model._reference_model, model.batch_input_shape, verbose=0).summary_list
def get_info_for_layer(self, layer_key) -> LayerInfo:
return self._get_info(layer_key, self.layer_list, layer_key, self.layer_list[0])
def _get_info(self, layer_key: str, layer_list: list[LayerInfo], full_key, parent_info) -> LayerInfo:
key_elements = layer_key.split(".")
if len(key_elements) > 1:
parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer}
if key_elements[0] in parents:
current_info = parents[key_elements[0]]
return self._get_info(".".join(key_elements[1:]), current_info.children, full_key, current_info)
leafs = {info.var_name: info for info in layer_list if
info.is_leaf_layer and info.parent_info.var_name == parent_info.var_name}
if key_elements[0] in leafs:
return leafs[key_elements[0]]
raise Exception(f"Could not resolve layer info for {full_key} - step failed for part {'.'.join(key_elements)}")
but it always run into Exception: Could not resolve layer info for model.2.cv0.conv - step failed for part cv0.conv
I think this is because it mix different '2' modules , this '2' should under model module, but it identify the one under '22' varname module, and lead to the model can not find 'cv0'.
Metadata
Metadata
Assignees
Labels
No labels