diff --git a/delta/backend/fx_lowering.py b/delta/backend/fx_lowering.py index 48d96af..c15c6b7 100644 --- a/delta/backend/fx_lowering.py +++ b/delta/backend/fx_lowering.py @@ -527,6 +527,8 @@ def _lower_layer(self, node: Layer, ctx: FXContext) -> Node: # Instantiate the actual torch module if node.kind == 'Linear': self.layers[layer_id] = torch.nn.Linear(*node.args, **node.kwargs) + elif node.kind == 'Conv1d': + self.layers[layer_id] = torch.nn.Conv1d(*node.args, **node.kwargs) elif node.kind == 'Conv2d': self.layers[layer_id] = torch.nn.Conv2d(*node.args, **node.kwargs) elif node.kind == 'Embedding': diff --git a/delta/ir/sir_builder.py b/delta/ir/sir_builder.py index 9258728..438df57 100644 --- a/delta/ir/sir_builder.py +++ b/delta/ir/sir_builder.py @@ -256,9 +256,7 @@ def _build_obs_decl(self, stmt: ObsDecl) -> ObsRef: def _build_constraint_stmt(self, stmt: ASTConstraintStmt) -> ConstraintOp: """Build a constraint statement.""" - print(f"DEBUG: Building constraint stmt: {stmt}") expr_node = self._build_expr(stmt.expr) - print(f"DEBUG: Constraint expr_node created: {type(expr_node)}") # Determine constraint kind if stmt.kind == ASTConstraintKind.REQUIRE: @@ -279,18 +277,16 @@ def _build_constraint_stmt(self, stmt: ASTConstraintStmt) -> ConstraintOp: if stmt.slack: slack = self._build_expr(stmt.slack) - print("DEBUG: Before SIRProperty") props = SIRProperty( - dtype=FloatType(), - requires_grad=expr_node.requires_grad, + dtype=ConstraintType(), + role=RoleInfo.constraint(), location=stmt.location ) - print("DEBUG: Before ConstraintOp") return ConstraintOp( kind=kind, lhs=expr_node, - rhs=None, + rhs=None, # All constraints normalized to expr checks weight=weight, slack=slack, _props=props @@ -358,7 +354,6 @@ def _build_if_stmt(self, stmt: IfStmt) -> Optional[SIRNode]: gate = GateOp(compare=TensorOpKind.GT, lhs=condition, rhs=Const(value=0.0), temperature=temperature) then_val = then_block.result or Const(value=0.0) else_val = (else_block.result if else_block else None) or Const(value=0.0) - print(f"DEBUG: IfStmt temperature={temperature}, then_result={then_block.result}, else_result={else_block.result if else_block else 'None'}") return MixOp(gate, then_val, else_val) return None @@ -447,9 +442,9 @@ def _build_block(self, block: Block) -> SIRBlock: if block.result: result = self._build_expr(block.result) nodes.append(result) - elif nodes and not isinstance(block.statements[-1], (ReturnStmt, ParamDecl, ObsDecl)): + elif nodes and not isinstance(block.statements[-1], (ParamDecl, ObsDecl)): # Fallback: if no explicit result, use the last node as result (like Rust/Ruby) - # but only if it's not a return or declaration + # Declared params/obs don't produce values result = nodes[-1] return SIRBlock(nodes=nodes, result=result) diff --git a/delta/run.py b/delta/run.py index 4f71f0c..d3256de 100644 --- a/delta/run.py +++ b/delta/run.py @@ -27,7 +27,7 @@ # Note: We import frontend AST types here for initializer evaluation. # Ideally this metadata would be extracted during compilation, but # for now we use the AST directly since CompileResult exposes it. -from delta.frontend.ast import ParamDecl, Call, Identifier, Literal +from delta.frontend.ast import ParamDecl, Call, Identifier, Literal, BinaryOp @dataclass @@ -48,7 +48,11 @@ class DeltaModel: def __init__(self, compile_result: CompileResult): self._compile_result = compile_result - self._module = list(compile_result.graph_modules.values())[0] + # Prefer 'forward' module if available (standard for training), otherwise take the first one + if compile_result.graph_modules and "forward" in compile_result.graph_modules: + self._module = compile_result.graph_modules["forward"] + else: + self._module = list(compile_result.graph_modules.values())[0] self._device = torch.device("cpu") self._init_params_from_sir() @@ -150,7 +154,15 @@ def _init_from_name(self, name: str, shape: tuple) -> torch.Tensor: def _extract_shape_from_initializer(self, initializer_expr: Any) -> Optional[tuple]: """Extract shape from an initializer expression like randn(3, 2).""" - from delta.frontend.ast import Call, Identifier, Literal, Tensor + from delta.frontend.ast import Call, Identifier, Literal, Tensor, BinaryOp + + # Handle scaling: randn(...) * 0.01 + if isinstance(initializer_expr, BinaryOp): + # Recursively check operands + shape = self._extract_shape_from_initializer(initializer_expr.left) + if shape: return shape + return self._extract_shape_from_initializer(initializer_expr.right) + if isinstance(initializer_expr, Call) and isinstance(initializer_expr.func, Identifier): func_name = initializer_expr.func.name args = initializer_expr.args