From 4a6a35689a0176631cb6420fcd79dcbd3232c135 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 12 Nov 2025 12:16:06 -0500 Subject: [PATCH 1/5] Add validation error handling to InterpreterABC class --- src/kirin/interp/abc.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/kirin/interp/abc.py b/src/kirin/interp/abc.py index a58f8bda2..b01c2f2a0 100644 --- a/src/kirin/interp/abc.py +++ b/src/kirin/interp/abc.py @@ -51,6 +51,10 @@ class InterpreterABC(ABC, Generic[FrameType, ValueType]): """The interpreter state.""" __eval_lock: bool = field(default=False, init=False, repr=False) """Lock for the eval method.""" + _validation_errors: dict[ir.IRNode, set[ir.ValidationError]] = field( + default_factory=dict, init=False + ) + """The validation errors collected during interpretation.""" def __init_subclass__(cls) -> None: super().__init_subclass__() @@ -330,3 +334,25 @@ def lookup_registry( def build_signature(self, frame: FrameType, node: ir.Statement) -> Signature: return Signature(node.__class__, tuple(arg.type for arg in node.args)) + + def add_validation_error(self, node: ir.IRNode, error: ir.ValidationError) -> None: + """Add a ValidationError for a given IR node. + + If the node is not present in the _validation_errors dict, create a new set. + Otherwise append to the existing set of errors. + """ + self._validation_errors.setdefault(node, set()).add(error) + + def get_validation_errors( + self, keys: set[ir.IRNode] | None = None + ) -> list[ir.ValidationError]: + """Get the validation errors collected during interpretation. + + If keys is provided, only return errors for the given nodes. + Otherwise return all errors. + """ + if keys is None: + return [err for s in self._validation_errors.values() for err in s] + return [ + err for node in keys for err in self._validation_errors.get(node, set()) + ] From cde9d2e4dd394f853ab8045c42586103eb56de08 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 12 Nov 2025 13:02:05 -0500 Subject: [PATCH 2/5] Moved ValidationPass to Kirin --- src/kirin/validation/validationpass.py | 168 +++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 src/kirin/validation/validationpass.py diff --git a/src/kirin/validation/validationpass.py b/src/kirin/validation/validationpass.py new file mode 100644 index 000000000..0755762cf --- /dev/null +++ b/src/kirin/validation/validationpass.py @@ -0,0 +1,168 @@ +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar +from dataclasses import field, dataclass + +from kirin import ir +from kirin.ir.exception import ValidationError + +T = TypeVar("T") + + +class ValidationPass(ABC, Generic[T]): + """Base class for a validation pass. + + Each pass analyzes an IR method and collects validation errors. + """ + + @abstractmethod + def name(self) -> str: + """Return the name of this validation pass.""" + ... + + @abstractmethod + def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: + """Run validation and return (analysis_frame, errors). + + Returns: + - analysis_frame: The result frame from the analysis + - errors: List of validation errors (empty if valid) + """ + ... + + def get_required_analyses(self) -> list[type]: + """Return list of analysis classes this pass depends on. + + Override to declare dependencies (e.g., [AddressAnalysis, AnotherAnalysis]). + The suite will run these analyses once and cache results. + """ + return [] + + def set_analysis_cache(self, cache: dict[type, Any]) -> None: + """Receive cached analysis results from the suite. + + Override to store cached analysis frames/results. + Example: + self._address_frame = cache.get(AddressAnalysis) + """ + pass + + +@dataclass +class ValidationSuite: + """Compose multiple validation passes and run them together. + + Caches analysis results to avoid redundant computation when multiple + validation passes depend on the same underlying analysis. + + Example: + suite = ValidationSuite([ + NoCloningValidation, + AnotherValidation, + ]) + result = suite.validate(my_kernel) + print(result.format_errors()) + """ + + passes: list[type[ValidationPass]] = field(default_factory=list) + fail_fast: bool = False + _analysis_cache: dict[type, Any] = field(default_factory=dict, init=False) + + def add_pass(self, pass_cls: type[ValidationPass]) -> "ValidationSuite": + """Add a validation pass to the suite.""" + self.passes.append(pass_cls) + return self + + def validate(self, method: ir.Method) -> "ValidationResult": + """Run all validation passes and collect results.""" + all_errors: dict[str, list[ValidationError]] = {} + all_frames: dict[str, Any] = {} + self._analysis_cache.clear() + for pass_cls in self.passes: + validator = pass_cls() + pass_name = validator.name() + + try: + required = validator.get_required_analyses() + for required_analysis in required: + if required_analysis not in self._analysis_cache: + analysis = required_analysis(method.dialects) + analysis.initialize() + frame, _ = analysis.run(method) + self._analysis_cache[required_analysis] = frame + + validator.set_analysis_cache(self._analysis_cache) + + frame, errors = validator.run(method) + all_frames[pass_name] = frame + + for err in errors: + if isinstance(err, ValidationError): + try: + err.attach(method) + except Exception: + pass + + if errors: + all_errors[pass_name] = errors + if self.fail_fast: + break + except Exception as e: + import traceback + + tb = traceback.format_exc() + all_errors[pass_name] = [ + ValidationError( + method.code, f"Validation pass '{pass_name}' failed: {e}\n{tb}" + ) + ] + if self.fail_fast: + break + + return ValidationResult(all_errors, all_frames) + + +@dataclass +class ValidationResult: + """Result of running a validation suite.""" + + errors: dict[str, list[ValidationError]] + frames: dict[str, Any] = field(default_factory=dict) + + def is_valid(self) -> bool: + """Check if validation passed (no errors).""" + return len(self.errors) == 0 + + def error_count(self) -> int: + """Total number of errors across all passes.""" + return sum(len(errs) for errs in self.errors.values()) + + def get_frame(self, pass_name: str) -> Any: + """Get the analysis frame for a specific pass.""" + return self.frames.get(pass_name) + + def format_errors(self) -> str: + """Format all errors with their pass names.""" + if not self.errors: + return "\n\033[32mAll validation passes succeeded\033[0m" + + lines = [ + f"\n\033[31mValidation failed with {self.error_count()} error(s):\033[0m" + ] + + for pass_name, pass_errors in self.errors.items(): + lines.append(f"\n\033[31m{pass_name}:\033[0m") + for err in pass_errors: + err_msg = err.args[0] if err.args else str(err) + lines.append(f" - {err_msg}") + if hasattr(err, "hint"): + hint = err.hint() + if hint: + lines.append(f" {hint}") + + return "\n".join(lines) + + def raise_if_invalid(self): + """Raise an exception if validation failed.""" + if not self.is_valid(): + first_errors = next(iter(self.errors.values())) + raise first_errors[0] From f6bbde74f016265a1d36cfc85a6f90bdf9d92467 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 12 Nov 2025 14:09:23 -0500 Subject: [PATCH 3/5] fix documentation error --- src/kirin/validation/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 src/kirin/validation/__init__.py diff --git a/src/kirin/validation/__init__.py b/src/kirin/validation/__init__.py new file mode 100644 index 000000000..9f55f34ad --- /dev/null +++ b/src/kirin/validation/__init__.py @@ -0,0 +1,4 @@ +from .validationpass import ( + ValidationPass as ValidationPass, + ValidationSuite as ValidationSuite, +) From ed4ddec2e087f1fd8c38ea4877758f69628cc26a Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Mon, 17 Nov 2025 10:59:51 -0500 Subject: [PATCH 4/5] Fix ValidaitionPass to track validity and count violations accurately --- src/kirin/validation/validationpass.py | 42 ++++++++++++++++++++------ 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/src/kirin/validation/validationpass.py b/src/kirin/validation/validationpass.py index 0755762cf..54d7b4eae 100644 --- a/src/kirin/validation/validationpass.py +++ b/src/kirin/validation/validationpass.py @@ -77,6 +77,7 @@ def validate(self, method: ir.Method) -> "ValidationResult": all_errors: dict[str, list[ValidationError]] = {} all_frames: dict[str, Any] = {} self._analysis_cache.clear() + for pass_cls in self.passes: validator = pass_cls() pass_name = validator.name() @@ -106,6 +107,7 @@ def validate(self, method: ir.Method) -> "ValidationResult": all_errors[pass_name] = errors if self.fail_fast: break + except Exception as e: import traceback @@ -127,14 +129,37 @@ class ValidationResult: errors: dict[str, list[ValidationError]] frames: dict[str, Any] = field(default_factory=dict) + is_valid: bool = field(default=True, init=False) + + def __post_init__(self): + from bloqade.analysis.validation.nocloning.lattice import May, Must - def is_valid(self) -> bool: - """Check if validation passed (no errors).""" - return len(self.errors) == 0 + for _, frame in self.frames.items(): + if frame is None: + continue + for node, value in frame.entries.items(): + if isinstance(value, (Must, May)): + self.is_valid = False def error_count(self) -> int: - """Total number of errors across all passes.""" - return sum(len(errs) for errs in self.errors.values()) + """Total number of violations across all passes. + + Counts violations directly from frames using the same logic as test helpers. + """ + from bloqade.analysis.validation.nocloning.lattice import May, Must + + total = 0 + for pass_name, frame in self.frames.items(): + if frame is None: + continue + + for node, value in frame.entries.items(): + if isinstance(value, Must): + total += len(value.violations) + elif isinstance(value, May): + total += len(value.violations) + + return total def get_frame(self, pass_name: str) -> Any: """Get the analysis frame for a specific pass.""" @@ -142,13 +167,12 @@ def get_frame(self, pass_name: str) -> Any: def format_errors(self) -> str: """Format all errors with their pass names.""" - if not self.errors: + if self.is_valid: return "\n\033[32mAll validation passes succeeded\033[0m" lines = [ - f"\n\033[31mValidation failed with {self.error_count()} error(s):\033[0m" + f"\n\033[31mValidation failed with {self.error_count()} violation(s):\033[0m" ] - for pass_name, pass_errors in self.errors.items(): lines.append(f"\n\033[31m{pass_name}:\033[0m") for err in pass_errors: @@ -163,6 +187,6 @@ def format_errors(self) -> str: def raise_if_invalid(self): """Raise an exception if validation failed.""" - if not self.is_valid(): + if not self.is_valid: first_errors = next(iter(self.errors.values())) raise first_errors[0] From dcf9c3d53e49dfd48dc26a233b13eae16f658085 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Mon, 17 Nov 2025 11:23:30 -0500 Subject: [PATCH 5/5] Added Potential and Definite ValidationErrors --- src/kirin/ir/exception.py | 12 ++++++++++++ src/kirin/validation/validationpass.py | 25 +++++++------------------ 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/kirin/ir/exception.py b/src/kirin/ir/exception.py index ecd203795..5044f84eb 100644 --- a/src/kirin/ir/exception.py +++ b/src/kirin/ir/exception.py @@ -68,3 +68,15 @@ class TypeCheckError(ValidationError): class CompilerError(Exception): pass + + +class PotentialValidationError(ValidationError): + """Indicates a potential violation that may occur at runtime.""" + + pass + + +class DefiniteValidationError(ValidationError): + """Indicates a definite violation that will occur at runtime.""" + + pass diff --git a/src/kirin/validation/validationpass.py b/src/kirin/validation/validationpass.py index 54d7b4eae..41a0bd405 100644 --- a/src/kirin/validation/validationpass.py +++ b/src/kirin/validation/validationpass.py @@ -132,33 +132,22 @@ class ValidationResult: is_valid: bool = field(default=True, init=False) def __post_init__(self): - from bloqade.analysis.validation.nocloning.lattice import May, Must - - for _, frame in self.frames.items(): - if frame is None: - continue - for node, value in frame.entries.items(): - if isinstance(value, (Must, May)): - self.is_valid = False + for _, errors in self.errors.items(): + if errors: + self.is_valid = False + break def error_count(self) -> int: """Total number of violations across all passes. Counts violations directly from frames using the same logic as test helpers. """ - from bloqade.analysis.validation.nocloning.lattice import May, Must total = 0 - for pass_name, frame in self.frames.items(): - if frame is None: + for pass_name, errors in self.errors.items(): + if errors is None: continue - - for node, value in frame.entries.items(): - if isinstance(value, Must): - total += len(value.violations) - elif isinstance(value, May): - total += len(value.violations) - + total += len(errors) return total def get_frame(self, pass_name: str) -> Any: