diff --git a/src/kirin/interp/abc.py b/src/kirin/interp/abc.py index a58f8bda2..b01c2f2a0 100644 --- a/src/kirin/interp/abc.py +++ b/src/kirin/interp/abc.py @@ -51,6 +51,10 @@ class InterpreterABC(ABC, Generic[FrameType, ValueType]): """The interpreter state.""" __eval_lock: bool = field(default=False, init=False, repr=False) """Lock for the eval method.""" + _validation_errors: dict[ir.IRNode, set[ir.ValidationError]] = field( + default_factory=dict, init=False + ) + """The validation errors collected during interpretation.""" def __init_subclass__(cls) -> None: super().__init_subclass__() @@ -330,3 +334,25 @@ def lookup_registry( def build_signature(self, frame: FrameType, node: ir.Statement) -> Signature: return Signature(node.__class__, tuple(arg.type for arg in node.args)) + + def add_validation_error(self, node: ir.IRNode, error: ir.ValidationError) -> None: + """Add a ValidationError for a given IR node. + + If the node is not present in the _validation_errors dict, create a new set. + Otherwise append to the existing set of errors. + """ + self._validation_errors.setdefault(node, set()).add(error) + + def get_validation_errors( + self, keys: set[ir.IRNode] | None = None + ) -> list[ir.ValidationError]: + """Get the validation errors collected during interpretation. + + If keys is provided, only return errors for the given nodes. + Otherwise return all errors. + """ + if keys is None: + return [err for s in self._validation_errors.values() for err in s] + return [ + err for node in keys for err in self._validation_errors.get(node, set()) + ] diff --git a/src/kirin/ir/exception.py b/src/kirin/ir/exception.py index ecd203795..5044f84eb 100644 --- a/src/kirin/ir/exception.py +++ b/src/kirin/ir/exception.py @@ -68,3 +68,15 @@ class TypeCheckError(ValidationError): class CompilerError(Exception): pass + + +class PotentialValidationError(ValidationError): + """Indicates a potential violation that may occur at runtime.""" + + pass + + +class DefiniteValidationError(ValidationError): + """Indicates a definite violation that will occur at runtime.""" + + pass diff --git a/src/kirin/validation/__init__.py b/src/kirin/validation/__init__.py new file mode 100644 index 000000000..9f55f34ad --- /dev/null +++ b/src/kirin/validation/__init__.py @@ -0,0 +1,4 @@ +from .validationpass import ( + ValidationPass as ValidationPass, + ValidationSuite as ValidationSuite, +) diff --git a/src/kirin/validation/validationpass.py b/src/kirin/validation/validationpass.py new file mode 100644 index 000000000..41a0bd405 --- /dev/null +++ b/src/kirin/validation/validationpass.py @@ -0,0 +1,181 @@ +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar +from dataclasses import field, dataclass + +from kirin import ir +from kirin.ir.exception import ValidationError + +T = TypeVar("T") + + +class ValidationPass(ABC, Generic[T]): + """Base class for a validation pass. + + Each pass analyzes an IR method and collects validation errors. + """ + + @abstractmethod + def name(self) -> str: + """Return the name of this validation pass.""" + ... + + @abstractmethod + def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: + """Run validation and return (analysis_frame, errors). + + Returns: + - analysis_frame: The result frame from the analysis + - errors: List of validation errors (empty if valid) + """ + ... + + def get_required_analyses(self) -> list[type]: + """Return list of analysis classes this pass depends on. + + Override to declare dependencies (e.g., [AddressAnalysis, AnotherAnalysis]). + The suite will run these analyses once and cache results. + """ + return [] + + def set_analysis_cache(self, cache: dict[type, Any]) -> None: + """Receive cached analysis results from the suite. + + Override to store cached analysis frames/results. + Example: + self._address_frame = cache.get(AddressAnalysis) + """ + pass + + +@dataclass +class ValidationSuite: + """Compose multiple validation passes and run them together. + + Caches analysis results to avoid redundant computation when multiple + validation passes depend on the same underlying analysis. + + Example: + suite = ValidationSuite([ + NoCloningValidation, + AnotherValidation, + ]) + result = suite.validate(my_kernel) + print(result.format_errors()) + """ + + passes: list[type[ValidationPass]] = field(default_factory=list) + fail_fast: bool = False + _analysis_cache: dict[type, Any] = field(default_factory=dict, init=False) + + def add_pass(self, pass_cls: type[ValidationPass]) -> "ValidationSuite": + """Add a validation pass to the suite.""" + self.passes.append(pass_cls) + return self + + def validate(self, method: ir.Method) -> "ValidationResult": + """Run all validation passes and collect results.""" + all_errors: dict[str, list[ValidationError]] = {} + all_frames: dict[str, Any] = {} + self._analysis_cache.clear() + + for pass_cls in self.passes: + validator = pass_cls() + pass_name = validator.name() + + try: + required = validator.get_required_analyses() + for required_analysis in required: + if required_analysis not in self._analysis_cache: + analysis = required_analysis(method.dialects) + analysis.initialize() + frame, _ = analysis.run(method) + self._analysis_cache[required_analysis] = frame + + validator.set_analysis_cache(self._analysis_cache) + + frame, errors = validator.run(method) + all_frames[pass_name] = frame + + for err in errors: + if isinstance(err, ValidationError): + try: + err.attach(method) + except Exception: + pass + + if errors: + all_errors[pass_name] = errors + if self.fail_fast: + break + + except Exception as e: + import traceback + + tb = traceback.format_exc() + all_errors[pass_name] = [ + ValidationError( + method.code, f"Validation pass '{pass_name}' failed: {e}\n{tb}" + ) + ] + if self.fail_fast: + break + + return ValidationResult(all_errors, all_frames) + + +@dataclass +class ValidationResult: + """Result of running a validation suite.""" + + errors: dict[str, list[ValidationError]] + frames: dict[str, Any] = field(default_factory=dict) + is_valid: bool = field(default=True, init=False) + + def __post_init__(self): + for _, errors in self.errors.items(): + if errors: + self.is_valid = False + break + + def error_count(self) -> int: + """Total number of violations across all passes. + + Counts violations directly from frames using the same logic as test helpers. + """ + + total = 0 + for pass_name, errors in self.errors.items(): + if errors is None: + continue + total += len(errors) + return total + + def get_frame(self, pass_name: str) -> Any: + """Get the analysis frame for a specific pass.""" + return self.frames.get(pass_name) + + def format_errors(self) -> str: + """Format all errors with their pass names.""" + if self.is_valid: + return "\n\033[32mAll validation passes succeeded\033[0m" + + lines = [ + f"\n\033[31mValidation failed with {self.error_count()} violation(s):\033[0m" + ] + for pass_name, pass_errors in self.errors.items(): + lines.append(f"\n\033[31m{pass_name}:\033[0m") + for err in pass_errors: + err_msg = err.args[0] if err.args else str(err) + lines.append(f" - {err_msg}") + if hasattr(err, "hint"): + hint = err.hint() + if hint: + lines.append(f" {hint}") + + return "\n".join(lines) + + def raise_if_invalid(self): + """Raise an exception if validation failed.""" + if not self.is_valid: + first_errors = next(iter(self.errors.values())) + raise first_errors[0]