-
Notifications
You must be signed in to change notification settings - Fork 0
Fixes for scaled initializers, function inlining, model selection, and added Conv1d support. #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -256,9 +256,7 @@ def _build_obs_decl(self, stmt: ObsDecl) -> ObsRef: | |
|
|
||
| def _build_constraint_stmt(self, stmt: ASTConstraintStmt) -> ConstraintOp: | ||
| """Build a constraint statement.""" | ||
| print(f"DEBUG: Building constraint stmt: {stmt}") | ||
| expr_node = self._build_expr(stmt.expr) | ||
| print(f"DEBUG: Constraint expr_node created: {type(expr_node)}") | ||
|
|
||
| # Determine constraint kind | ||
| if stmt.kind == ASTConstraintKind.REQUIRE: | ||
|
|
@@ -279,18 +277,16 @@ def _build_constraint_stmt(self, stmt: ASTConstraintStmt) -> ConstraintOp: | |
| if stmt.slack: | ||
| slack = self._build_expr(stmt.slack) | ||
|
|
||
| print("DEBUG: Before SIRProperty") | ||
| props = SIRProperty( | ||
| dtype=FloatType(), | ||
| requires_grad=expr_node.requires_grad, | ||
| dtype=ConstraintType(), | ||
| role=RoleInfo.constraint(), | ||
| location=stmt.location | ||
| ) | ||
|
|
||
| print("DEBUG: Before ConstraintOp") | ||
| return ConstraintOp( | ||
| kind=kind, | ||
| lhs=expr_node, | ||
| rhs=None, | ||
| rhs=None, # All constraints normalized to expr checks | ||
| weight=weight, | ||
| slack=slack, | ||
| _props=props | ||
|
|
@@ -358,7 +354,6 @@ def _build_if_stmt(self, stmt: IfStmt) -> Optional[SIRNode]: | |
| gate = GateOp(compare=TensorOpKind.GT, lhs=condition, rhs=Const(value=0.0), temperature=temperature) | ||
| then_val = then_block.result or Const(value=0.0) | ||
| else_val = (else_block.result if else_block else None) or Const(value=0.0) | ||
| print(f"DEBUG: IfStmt temperature={temperature}, then_result={then_block.result}, else_result={else_block.result if else_block else 'None'}") | ||
| return MixOp(gate, then_val, else_val) | ||
|
|
||
| return None | ||
|
|
@@ -447,9 +442,9 @@ def _build_block(self, block: Block) -> SIRBlock: | |
| if block.result: | ||
| result = self._build_expr(block.result) | ||
| nodes.append(result) | ||
| elif nodes and not isinstance(block.statements[-1], (ReturnStmt, ParamDecl, ObsDecl)): | ||
| elif nodes and not isinstance(block.statements[-1], (ParamDecl, ObsDecl)): | ||
|
||
| # Fallback: if no explicit result, use the last node as result (like Rust/Ruby) | ||
| # but only if it's not a return or declaration | ||
| # Declared params/obs don't produce values | ||
| result = nodes[-1] | ||
|
|
||
| return SIRBlock(nodes=nodes, result=result) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -27,7 +27,7 @@ | |||||
| # Note: We import frontend AST types here for initializer evaluation. | ||||||
| # Ideally this metadata would be extracted during compilation, but | ||||||
| # for now we use the AST directly since CompileResult exposes it. | ||||||
| from delta.frontend.ast import ParamDecl, Call, Identifier, Literal | ||||||
| from delta.frontend.ast import ParamDecl, Call, Identifier, Literal, BinaryOp | ||||||
|
||||||
| from delta.frontend.ast import ParamDecl, Call, Identifier, Literal, BinaryOp | |
| from delta.frontend.ast import ParamDecl, Call, Identifier, Literal |
Copilot
AI
Jan 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The BinaryOp import was added to the top-level imports but is also imported again within the _extract_shape_from_initializer method at line 157. Consider removing the redundant local import since BinaryOp is now available from the module-level import at line 30.
| from delta.frontend.ast import Call, Identifier, Literal, Tensor, BinaryOp | |
| from delta.frontend.ast import Tensor |
Copilot
AI
Jan 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The handling of BinaryOp expressions in _extract_shape_from_initializer is incomplete. While it recursively extracts shapes from binary operations like randn(3, 2) * 0.01, the corresponding _eval_initializer method (lines 90-137) doesn't handle BinaryOp expressions. This means scaled initializers will have their shapes extracted correctly, but the actual initialization won't apply the scaling factor. Consider also updating _eval_initializer to handle BinaryOp expressions to fully support scaled initializers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type
ConstraintType()does not exist in the codebase. This will cause a NameError at runtime when building constraint statements. The type should beFloatType()instead, as constraints evaluate to scalar penalty terms that are added to the objective. The original code usedFloatType()for the dtype of constraints.