Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/kirin/interp/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class InterpreterABC(ABC, Generic[FrameType, ValueType]):
"""The interpreter state."""
__eval_lock: bool = field(default=False, init=False, repr=False)
"""Lock for the eval method."""
_validation_errors: dict[ir.IRNode, set[ir.ValidationError]] = field(
default_factory=dict, init=False
)
"""The validation errors collected during interpretation."""

def __init_subclass__(cls) -> None:
super().__init_subclass__()
Expand Down Expand Up @@ -330,3 +334,25 @@ def lookup_registry(

def build_signature(self, frame: FrameType, node: ir.Statement) -> Signature:
return Signature(node.__class__, tuple(arg.type for arg in node.args))

def add_validation_error(self, node: ir.IRNode, error: ir.ValidationError) -> None:
"""Add a ValidationError for a given IR node.

If the node is not present in the _validation_errors dict, create a new set.
Otherwise append to the existing set of errors.
"""
self._validation_errors.setdefault(node, set()).add(error)

def get_validation_errors(
self, keys: set[ir.IRNode] | None = None
) -> list[ir.ValidationError]:
"""Get the validation errors collected during interpretation.

If keys is provided, only return errors for the given nodes.
Otherwise return all errors.
"""
if keys is None:
return [err for s in self._validation_errors.values() for err in s]
return [
err for node in keys for err in self._validation_errors.get(node, set())
]
12 changes: 12 additions & 0 deletions src/kirin/ir/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,15 @@ class TypeCheckError(ValidationError):

class CompilerError(Exception):
pass


class PotentialValidationError(ValidationError):
"""Indicates a potential violation that may occur at runtime."""

pass


class DefiniteValidationError(ValidationError):
"""Indicates a definite violation that will occur at runtime."""

pass
4 changes: 4 additions & 0 deletions src/kirin/validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .validationpass import (
ValidationPass as ValidationPass,
ValidationSuite as ValidationSuite,
)
181 changes: 181 additions & 0 deletions src/kirin/validation/validationpass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
from dataclasses import field, dataclass

from kirin import ir
from kirin.ir.exception import ValidationError

T = TypeVar("T")


class ValidationPass(ABC, Generic[T]):
"""Base class for a validation pass.

Each pass analyzes an IR method and collects validation errors.
"""

@abstractmethod
def name(self) -> str:
"""Return the name of this validation pass."""
...

@abstractmethod
def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
"""Run validation and return (analysis_frame, errors).

Returns:
- analysis_frame: The result frame from the analysis
- errors: List of validation errors (empty if valid)
"""
...

def get_required_analyses(self) -> list[type]:
"""Return list of analysis classes this pass depends on.

Override to declare dependencies (e.g., [AddressAnalysis, AnotherAnalysis]).
The suite will run these analyses once and cache results.
"""
return []

def set_analysis_cache(self, cache: dict[type, Any]) -> None:
"""Receive cached analysis results from the suite.

Override to store cached analysis frames/results.
Example:
self._address_frame = cache.get(AddressAnalysis)
"""
pass


@dataclass
class ValidationSuite:
"""Compose multiple validation passes and run them together.

Caches analysis results to avoid redundant computation when multiple
validation passes depend on the same underlying analysis.

Example:
suite = ValidationSuite([
NoCloningValidation,
AnotherValidation,
])
result = suite.validate(my_kernel)
print(result.format_errors())
"""

passes: list[type[ValidationPass]] = field(default_factory=list)
fail_fast: bool = False
_analysis_cache: dict[type, Any] = field(default_factory=dict, init=False)

def add_pass(self, pass_cls: type[ValidationPass]) -> "ValidationSuite":
"""Add a validation pass to the suite."""
self.passes.append(pass_cls)
return self

def validate(self, method: ir.Method) -> "ValidationResult":
"""Run all validation passes and collect results."""
all_errors: dict[str, list[ValidationError]] = {}
all_frames: dict[str, Any] = {}
self._analysis_cache.clear()

for pass_cls in self.passes:
validator = pass_cls()
pass_name = validator.name()

try:
required = validator.get_required_analyses()
for required_analysis in required:
if required_analysis not in self._analysis_cache:
analysis = required_analysis(method.dialects)
analysis.initialize()
frame, _ = analysis.run(method)
self._analysis_cache[required_analysis] = frame

validator.set_analysis_cache(self._analysis_cache)

frame, errors = validator.run(method)
all_frames[pass_name] = frame

for err in errors:
if isinstance(err, ValidationError):
try:
err.attach(method)
except Exception:
pass

if errors:
all_errors[pass_name] = errors
if self.fail_fast:
break

except Exception as e:
import traceback

tb = traceback.format_exc()
all_errors[pass_name] = [
ValidationError(
method.code, f"Validation pass '{pass_name}' failed: {e}\n{tb}"
)
]
if self.fail_fast:
break

return ValidationResult(all_errors, all_frames)


@dataclass
class ValidationResult:
"""Result of running a validation suite."""

errors: dict[str, list[ValidationError]]
frames: dict[str, Any] = field(default_factory=dict)
is_valid: bool = field(default=True, init=False)

def __post_init__(self):
for _, errors in self.errors.items():
if errors:
self.is_valid = False
break

def error_count(self) -> int:
"""Total number of violations across all passes.

Counts violations directly from frames using the same logic as test helpers.
"""

total = 0
for pass_name, errors in self.errors.items():
if errors is None:
continue
total += len(errors)
return total

def get_frame(self, pass_name: str) -> Any:
"""Get the analysis frame for a specific pass."""
return self.frames.get(pass_name)

def format_errors(self) -> str:
"""Format all errors with their pass names."""
if self.is_valid:
return "\n\033[32mAll validation passes succeeded\033[0m"

lines = [
f"\n\033[31mValidation failed with {self.error_count()} violation(s):\033[0m"
]
for pass_name, pass_errors in self.errors.items():
lines.append(f"\n\033[31m{pass_name}:\033[0m")
for err in pass_errors:
err_msg = err.args[0] if err.args else str(err)
lines.append(f" - {err_msg}")
if hasattr(err, "hint"):
hint = err.hint()
if hint:
lines.append(f" {hint}")

return "\n".join(lines)

def raise_if_invalid(self):
"""Raise an exception if validation failed."""
if not self.is_valid:
first_errors = next(iter(self.errors.values()))
raise first_errors[0]