Skip to content

Commit da35b8b

Browse files
Implement neutral put
1 parent 6ef0830 commit da35b8b

File tree

4 files changed

+107
-9
lines changed

4 files changed

+107
-9
lines changed

effekt/jvm/src/test/scala/effekt/core/NewNormalizerTests.scala

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,57 @@ class NewNormalizerTests extends CoreTests {
460460

461461
assertAlphaEquivalentToplevels(actual, expected, List("run"))
462462
}
463+
464+
// This test shows a mutable reference captured by a block parameter.
465+
// During normalization, this block parameter gets lifted to a `def`.
466+
// One might hope for this mutable variable to be eliminated entirely,
467+
// but currently the normalizer does not inline definitions.
468+
test("Mutable reference can be lifted") {
469+
val input =
470+
"""
471+
|def run(): Int = {
472+
| def modifyProg { setter: Int => Unit }: Unit = {
473+
| setter(2)
474+
| ()
475+
| }
476+
| var x = 1
477+
| modifyProg { y => x = y }
478+
| x
479+
|}
480+
|
481+
|def main() = println(run())
482+
|""".stripMargin
483+
484+
val (mainId, actual) = normalize(input)
485+
486+
val expected =
487+
parse(
488+
"""
489+
|module input
490+
|
491+
|def run() = {
492+
| def modifyProg(){setter: (Int) => Unit} = {
493+
| let x = 2
494+
| val tmp: Unit = (setter: (Int) => Unit @ {setter})(x: Int);
495+
| let y = ()
496+
| return y: Unit
497+
| }
498+
| let y = 1
499+
| var x @ c = y: Int;
500+
| def f(y: Int) = {
501+
| put x @ c = y: Int;
502+
| let z = ()
503+
| return z: Unit
504+
| }
505+
| val tmp: Unit = (modifyProg: (){setter : (Int) => Unit} => Unit @ {})(){f: (Int) => Unit @ {c}};
506+
| get o: Int = !x @ c;
507+
| return o: Int
508+
|}
509+
|""".stripMargin
510+
)
511+
512+
assertAlphaEquivalentToplevels(actual, expected, List("run"))
513+
}
463514
}
464515

465516
/**

effekt/jvm/src/test/scala/effekt/core/TestRenamer.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ class TestRenamer(names: Names = Names(Map.empty), prefix: String = "", defaultS
8282
val resolvedCapt = rewrite(capt)
8383
withBinding(id) { core.Get(rewrite(id), rewrite(tpe), resolvedRef, resolvedCapt, rewrite(body)) }
8484

85+
case core.App(callee, targs, vargs, bargs) =>
86+
val resolvedCallee = rewrite(callee)
87+
val resolvedTargs = targs map rewrite
88+
val resolvedVargs = vargs map rewrite
89+
val resolvedBargs = bargs map rewrite
90+
core.App(resolvedCallee, resolvedTargs, resolvedVargs, resolvedBargs)
8591
}
8692

8793
override def block: PartialFunction[Block, Block] = {
@@ -90,6 +96,32 @@ class TestRenamer(names: Names = Names(Map.empty), prefix: String = "", defaultS
9096
Block.BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite,
9197
rewrite(body))
9298
}
99+
case Block.BlockVar(id: Id, annotatedTpe: BlockType, annotatedCapt: Captures) => {
100+
withBinding(id) {
101+
val idOut = rewrite(id)
102+
val annotatedTpeOut = rewrite(annotatedTpe)
103+
val annotatedCaptOut = rewrite(annotatedCapt)
104+
Block.BlockVar(rewrite(id), rewrite(annotatedTpe), rewrite(annotatedCapt))
105+
}
106+
}
107+
}
108+
109+
override def rewrite(t: BlockType): BlockType = t match {
110+
case BlockType.Function(tparams, cparams, vparams, bparams, result: ValueType) =>
111+
// TODO: is this how we want to treat captures here?
112+
val resolvedCapt = cparams.map(id => Map(id -> freshIdFor(id))).reduceOption(_ ++ _).getOrElse(Map())
113+
withBindings(tparams) {
114+
withMapping(resolvedCapt) {
115+
BlockType.Function(
116+
tparams.map(rewrite),
117+
resolvedCapt.values.toList.map(rewrite),
118+
vparams.map(rewrite),
119+
bparams.map(rewrite),
120+
rewrite(result)
121+
)
122+
}}
123+
case BlockType.Interface(name, targs) =>
124+
BlockType.Interface(name, targs map rewrite)
93125
}
94126

95127
override def rewrite(o: Operation): Operation = o match {

effekt/shared/src/main/scala/effekt/core/Parser.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class EffektLexers extends Parsers {
6161
lazy val `|` = literal("|")
6262

6363
lazy val `get` = keyword("get")
64+
lazy val `put` = keyword("put")
6465
lazy val `let` = keyword("let")
6566
lazy val `true` = keyword("true")
6667
lazy val `false` = keyword("false")
@@ -323,6 +324,9 @@ class CoreParsers(names: Names) extends EffektLexers {
323324
| `get` ~> id ~ (`:` ~> valueType) ~ (`=` ~> `!` ~> id) ~ (`@` ~> id) ~ (`;` ~> stmts) ^^ {
324325
case name ~ tpe ~ ref ~ cap ~ body => Get(name, tpe, ref, Set(cap), body)
325326
}
327+
| `put` ~> id ~ (`@` ~> id) ~ (`=` ~> expr) ~ (`;` ~> stmts) ^^ {
328+
case ref ~ capt ~ value ~ body => Put(ref, Set(capt), value, body)
329+
}
326330
| `def` ~> id ~ (`=` ~/> block) ~ stmts ^^ Stmt.Def.apply
327331
| `def` ~> id ~ parameters ~ (`=` ~/> stmt) ~ stmts ^^ {
328332
case name ~ (tparams, cparams, vparams, bparams) ~ body ~ rest =>

effekt/shared/src/main/scala/effekt/core/optimizer/NewNormalizer.scala

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ object semantics {
305305
case Resume(k: Id, body: BasicBlock)
306306

307307
case Var(id: BlockParam, init: Addr, body: BasicBlock)
308-
// case Put
308+
case Put(ref: Id, tpe: ValueType, cap: Captures, value: Addr, body: BasicBlock)
309309

310310
// aborts at runtime
311311
case Hole(span: Span)
@@ -322,6 +322,7 @@ object semantics {
322322
case NeutralStmt.Resume(k, body) => Set(k) ++ body.free
323323
case NeutralStmt.Var(id, init, body) => Set(init) ++ body.free - id.id
324324
case NeutralStmt.Hole(span) => Set.empty
325+
case NeutralStmt.Put(ref, tpe, cap, value, body) => Set(ref, value) ++ body.free
325326
}
326327
}
327328

@@ -396,13 +397,14 @@ object semantics {
396397
case Stack.Var(id1, curr, init, frame, next) if ref == id1.id => Some(curr)
397398
case Stack.Var(id1, curr, init, frame, next) => get(ref, next)
398399
}
399-
400-
def put(id: Id, value: Addr, ks: Stack): Stack = ks match {
400+
401+
def put(id: Id, value: Addr, ks: Stack): Option[Stack] = ks match {
401402
case Stack.Empty => sys error s"Should not happen: trying to put ${util.show(id)} in empty stack"
402-
case Stack.Unknown => sys error s"Cannot put ${util.show(id)} in unknown stack"
403-
case Stack.Reset(prompt, frame, next) => Stack.Reset(prompt, frame, put(id, value, next))
404-
case Stack.Var(id1, curr, init, frame, next) if id == id1.id => Stack.Var(id1, value, init, frame, next)
405-
case Stack.Var(id1, curr, init, frame, next) => Stack.Var(id1, curr, init, frame, put(id, value, next))
403+
// We have reached the end of the known stack, so the variable must be in the unknown part.
404+
case Stack.Unknown => None
405+
case Stack.Reset(prompt, frame, next) => put(id, value, next).map(Stack.Reset(prompt, frame, _))
406+
case Stack.Var(id1, curr, init, frame, next) if id == id1.id => Some(Stack.Var(id1, value, init, frame, next))
407+
case Stack.Var(id1, curr, init, frame, next) => put(id, value, next).map(Stack.Var(id1, curr, init, frame, _))
406408
}
407409

408410
enum Cont {
@@ -544,6 +546,9 @@ object semantics {
544546
"var" <+> toDoc(id) <+> "=" <+> toDoc(init) <> line <> toDoc(body)
545547

546548
case NeutralStmt.Hole(span) => "hole()"
549+
550+
case NeutralStmt.Put(ref, tpe, cap, value, body) =>
551+
"put" <+> toDoc(ref) <+> "=" <+> toDoc(value) <> line <> toDoc(body)
547552
}
548553

549554
def toDoc(id: Id): Doc = id.show
@@ -598,7 +603,7 @@ object semantics {
598603
(if (targs.isEmpty) emptyDoc else brackets(hsep(targs.map(toDoc), comma))) <>
599604
parens(hsep(vargs.map(toDoc), comma)) <> hcat(bargs.map(b => braces { toDoc(b) })) <> line
600605
case (addr, Binding.Unbox(innerAddr, tpe, capt)) => "def" <+> toDoc(addr) <+> "=" <+> "unbox" <+> toDoc(innerAddr) <> line
601-
case (addr, Binding.Get(ref, tpe, cap)) => "let" <+> toDoc(addr) <+> "=" <+> "!" <> toDoc(ref) <> line
606+
case (addr, Binding.Get(ref, tpe, cap)) => "get" <+> toDoc(addr) <+> "=" <+> "!" <> toDoc(ref) <> line
602607
})
603608

604609
def toDoc(block: BasicBlock): Doc =
@@ -872,7 +877,11 @@ class NewNormalizer(shouldInline: (Id, BlockLit) => Boolean) {
872877
case None => bind(id, scope.allocateGet(ref, annotatedTpe, annotatedCapt)) { evaluate(body, k, ks) }
873878
}
874879
case Stmt.Put(ref, annotatedCapt, value, body) =>
875-
evaluate(body, k, put(ref, evaluate(value), ks))
880+
put(ref, evaluate(value), ks) match {
881+
case Some(stack) => evaluate(body, k, stack)
882+
case None =>
883+
NeutralStmt.Put(ref, value.tpe, annotatedCapt, evaluate(value), nested { evaluate(body, k, ks) })
884+
}
876885

877886
// Control Effects
878887
case Stmt.Shift(prompt, core.Block.BlockLit(Nil, cparam :: Nil, Nil, k2 :: Nil, body)) =>
@@ -1017,6 +1026,8 @@ class NewNormalizer(shouldInline: (Id, BlockLit) => Boolean) {
10171026
Stmt.Var(blockParam.id, embedExpr(init), capt, embedStmt(body)(using G.bind(blockParam.id, blockParam.tpe, blockParam.capt)))
10181027
case NeutralStmt.Hole(span) =>
10191028
Stmt.Hole(span)
1029+
case NeutralStmt.Put(ref, annotatedTpe, annotatedCapt, value, body) =>
1030+
Stmt.Put(ref, annotatedCapt, embedExpr(value), embedStmt(body))
10201031
}
10211032

10221033
def embedStmt(basicBlock: BasicBlock)(using G: TypingContext): core.Stmt = basicBlock match {

0 commit comments

Comments
 (0)