diff --git a/python/module_logging/perf/analysis.py b/python/module_logging/perf/analysis.py index 8504140..5552f00 100644 --- a/python/module_logging/perf/analysis.py +++ b/python/module_logging/perf/analysis.py @@ -14,6 +14,18 @@ def demangle(mangled_name): return result.stdout.decode("utf-8").strip() +def fix_op_signature(op_sig: str) -> str: + op_name = op_sig.split("(", maxsplit=1)[0] + has_cast = "cast" in op_name + has_type = "bf16" in op_name or "bfloat16" in op_name + if has_cast and has_type: + fixed_op_name = op_name.replace("cast", "opt").replace("bf16", "hf").replace("bfloat16", "half") + fixed_sig = op_sig.replace(op_name, fixed_op_name) + else: + fixed_sig = op_sig + return fixed_sig + + class STATE(Enum): BEGIN = auto() MODULE = auto() @@ -436,7 +448,8 @@ def identify_op_time(self, line: str): Logger.debug(line) extention_op_time = float(line.split(" ")[-2]) / 1000000 extention_op_name = demangle(line.split(" ")[1]) - self.current_op = AtenOp(extention_op_name, self.current_m_name) + fixed_op_name = fix_op_signature(extention_op_name) + self.current_op = AtenOp(fixed_op_name, self.current_m_name) self.current_op.set_time(extention_op_time) self.op_or_module.append(self.current_op) self.total += extention_op_time