From d9e53aec50f15d3d90f761d2fb14f4a4cf2d3d3c Mon Sep 17 00:00:00 2001 From: LPTK Date: Tue, 10 Oct 2017 14:00:43 +0200 Subject: [PATCH 01/66] Some tests for higher-order pattern variables --- .../scala/squid/quasi/QuasiEmbedder.scala | 6 +- .../feature/HigherOrderPatternVariables.scala | 59 +++++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 src/test/scala/squid/feature/HigherOrderPatternVariables.scala diff --git a/core/src/main/scala/squid/quasi/QuasiEmbedder.scala b/core/src/main/scala/squid/quasi/QuasiEmbedder.scala index 7bcf78ba..327c910d 100644 --- a/core/src/main/scala/squid/quasi/QuasiEmbedder.scala +++ b/core/src/main/scala/squid/quasi/QuasiEmbedder.scala @@ -42,9 +42,7 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { /** holes: Seq(either term or type); splicedHoles: which holes are spliced term holes (eg: List($xs*)) * holes in ction mode are interpreted as free variables (and are never spliced) - * unquotedTypes contains the types unquoted in (in ction mode with $ and in xtion mode with $$) - * TODO: actually use `unquotedTypes`! - * TODO maybe this should go in QuasiMacros... */ + * unquotedTypes contains the types that are inserted (in ction mode with $ and in xtion mode with $$) */ def applyQQ(Base: Tree, tree: c.Tree, holes: Seq[Either[TermName,TypeName]], splicedHoles: collection.Set[TermName], hopvHoles: collection.Map[TermName,List[List[TermName]]], typeBounds: Map[TypeName,EitherOrBoth[Tree,Tree]], @@ -497,7 +495,7 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { hopvHoles get TermName(name) match { case Some(idents) => - // TODO make sure HOPV holes withb the same name are not repeated? or at least have the same param types + // TODO make sure HOPV holes with the same name are not repeated? or at least have the same param types // TODO handle HOPV holes in spliced position? val scp = ctx map { case k -> v => k.name -> (k -> v) } diff --git a/src/test/scala/squid/feature/HigherOrderPatternVariables.scala b/src/test/scala/squid/feature/HigherOrderPatternVariables.scala new file mode 100644 index 00000000..980e93f7 --- /dev/null +++ b/src/test/scala/squid/feature/HigherOrderPatternVariables.scala @@ -0,0 +1,59 @@ +package squid +package feature + +import utils._ + +class HigherOrderPatternVariables extends MyFunSuite { + import TestDSL.Predef._ + + test("Matching lambda bodies") { + + val id = ir"(z:Int) => z" + + ir"(a: Int) => a + 1" matches { + case ir"(x: Int) => $body(x):Int" => + body eqt ir"(_:Int)+1" + } and { + case ir"(x: Int) => $body(x):$t" => + body eqt ir"(_:Int)+1" + eqt(t, irTypeOf[Int]) + } and { + case ir"(x: Int) => ($exp(x):Int)+1" => + exp eqt id + } + + ir"(a: Int, b: Int) => a + 1" matches { + case ir"(x: Int, y: Int) => $body(y):Int" => fail + case ir"(x: Int, y: Int) => $body(x):Int" => + } and { + case ir"(x: Int, y: Int) => $body(x,y):Int" => + } + + ir"(a: Int, b: Int) => a + b" matches { + case ir"(x: Int, y: Int) => $body(x):Int" => fail + case ir"(x: Int, y: Int) => $body(x,y):Int" => + body eqt ir"(_:Int)+(_:Int)" + } and { + case ir"(x: Int, y: Int) => ($lhs(y):Int)+($rhs(y):Int)" => fail + case ir"(x: Int, y: Int) => ($lhs(x):Int)+($rhs(y):Int)" => + lhs eqt id + rhs eqt id + } + + } + + test("Matching let-binding bodies") { + + ir"val a = 0; val b = 1; a + b" matches { + case ir"val x: Int = $v; $body(x):Int" => + v eqt ir"0" + body matches { + case ir"(y:Int) => { val x: Int = $w; $body(x,y):Int }" => + w eqt ir"1" + body eqt ir"(u:Int,v:Int) => (v:Int)+(u:Int)" + } + } + + } + +} From e0b60d386669e9996c1746dcda90cdf6fe5be60b Mon Sep 17 00:00:00 2001 From: LPTK Date: Sat, 28 Oct 2017 17:21:48 -0700 Subject: [PATCH 02/66] Generalize `hopHole` to handle arbitrary patterns: introduce `hopHole2` `hopHole2` should eventually replace `hopHole` The implementation should also be completed with correct handling of FVs --- .../scala/squid/lang/InspectableBase.scala | 2 +- .../main/scala/squid/quasi/MetaBases.scala | 2 + .../main/scala/squid/quasi/QuasiBase.scala | 13 +- .../scala/squid/quasi/QuasiEmbedder.scala | 124 +++++++++--------- .../main/scala/squid/quasi/QuasiMacros.scala | 23 ++-- src/main/scala/squid/ir/AST.scala | 6 + 6 files changed, 94 insertions(+), 76 deletions(-) diff --git a/core/src/main/scala/squid/lang/InspectableBase.scala b/core/src/main/scala/squid/lang/InspectableBase.scala index df6cee4c..29cfa83e 100644 --- a/core/src/main/scala/squid/lang/InspectableBase.scala +++ b/core/src/main/scala/squid/lang/InspectableBase.scala @@ -95,7 +95,7 @@ trait InspectableBase extends IntermediateBase with quasi.QuasiBase with TraceDe } protected implicit class ProtectedInspectableRepOps(private val self: Rep) { - def extract (that: Rep) = baseSelf.extract(self, that) + def extract (that: Rep) = baseSelf.extractRep(self, that) } implicit class InspectableRepOps(private val self: Rep) { /** Note: this is only a to-level call to `base.extractRep`; not supposed to be called in implementations of `extract` itself */ diff --git a/core/src/main/scala/squid/quasi/MetaBases.scala b/core/src/main/scala/squid/quasi/MetaBases.scala index f423f4c4..cc7f4b03 100644 --- a/core/src/main/scala/squid/quasi/MetaBases.scala +++ b/core/src/main/scala/squid/quasi/MetaBases.scala @@ -170,6 +170,8 @@ trait MetaBases { def hole(name: String, typ: TypeRep): Rep = q"$Base.hole($name, $typ)" def hopHole(name: String, typ: TypeRep, yes: List[List[BoundVal]], no: List[BoundVal]): Rep = q"$Base.hopHole($name, $typ, ${yes map (_ map (_._2))}, ${no map (_._2)})" + override def hopHole2(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal]): Rep = + q"$Base.hopHole2($name, $typ, ${argss}, ${visible map (_._2)})" def splicedHole(name: String, typ: TypeRep): Rep = q"$Base.splicedHole($name, $typ)" def substitute(r: => Rep, defs: Map[String, Rep]): Rep = diff --git a/core/src/main/scala/squid/quasi/QuasiBase.scala b/core/src/main/scala/squid/quasi/QuasiBase.scala index 92120b53..a894c89f 100644 --- a/core/src/main/scala/squid/quasi/QuasiBase.scala +++ b/core/src/main/scala/squid/quasi/QuasiBase.scala @@ -46,6 +46,14 @@ self: Base => * this should check that the extracted term does not contain any reference to a bound value contained in the `no` * parameter, and it should extract a function term with the arity of the `yes` parameter. */ def hopHole(name: String, typ: TypeRep, yes: List[List[BoundVal]], no: List[BoundVal]): Rep + def hopHole2(name: String, typ: TypeRep, args: List[List[Rep]], visible: List[BoundVal]): Rep = { + // TODO remove `hopHole` and implement this correctly everywhere + //val vars = args map (_ map (_.asInstanceOf[BoundVal]))\ + // ^ Note: this will succeed even if there are some non-BoundVal, because BoundVal is eliminated by erasure + // it's actually wrong (in AST, readVal(_:BoundVal) adds a Rep wrapper! + //hopHole(name, typ, vars, visible.toSet -- vars.flatten toList) + throw new UnsupportedOperationException("Higher-order patterns") + } /** Pattern hole in type position */ def typeHole(name: String): TypeRep @@ -320,11 +328,14 @@ self: Base => def $Code[T](q: Code[T]): T = ??? // TODO B/E -- also, rename to 'unquote'? def $Code[A,B](q: Code[A] => Code[B]): A => B = ??? - /* To support hole syntax `xs?` (old syntax `$$xs`) (in ction) or `$xs` (in xtion) */ + /* To support hole syntax `?xs` (old syntax `$$xs`) (in ction) or `$xs` (in xtion) */ def $$[T](name: Symbol): T = ??? /* To support hole syntax `$$xs: _*` (in ction) or `$xs: _*` (in xtion) */ def $$_*[T](name: Symbol): Seq[T] = ??? + /** Higher-order pattern */ + def $$_hop[T](name: Symbol)(args: Any*): T = ??? + implicit def liftFun[A:IRType,B,C](qf: IR[A,C] => IR[B,C]): IR[A => B,C] = { val bv = bindVal("lifted", typeRepOf[A], Nil) // add TODO annotation recording the lifting? diff --git a/core/src/main/scala/squid/quasi/QuasiEmbedder.scala b/core/src/main/scala/squid/quasi/QuasiEmbedder.scala index 327c910d..c9cc12a6 100644 --- a/core/src/main/scala/squid/quasi/QuasiEmbedder.scala +++ b/core/src/main/scala/squid/quasi/QuasiEmbedder.scala @@ -43,10 +43,10 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { /** holes: Seq(either term or type); splicedHoles: which holes are spliced term holes (eg: List($xs*)) * holes in ction mode are interpreted as free variables (and are never spliced) * unquotedTypes contains the types that are inserted (in ction mode with $ and in xtion mode with $$) */ - def applyQQ(Base: Tree, tree: c.Tree, holes: Seq[Either[TermName,TypeName]], - splicedHoles: collection.Set[TermName], hopvHoles: collection.Map[TermName,List[List[TermName]]], - typeBounds: Map[TypeName,EitherOrBoth[Tree,Tree]], - unquotedTypes: Seq[(TypeName, Type, Tree)], unapply: Option[c.Tree], config: QuasiConfig): c.Tree = { + def applyQQ(Base: Tree, tree: c.Tree, holes: Seq[Either[TermName,TypeName]], + splicedHoles: collection.Set[TermName], hopHoles: collection.Set[TermName], + typeBounds: Map[TypeName,EitherOrBoth[Tree,Tree]], + unquotedTypes: Seq[(TypeName, Type, Tree)], unapply: Option[c.Tree], config: QuasiConfig): c.Tree = { //debug("HOLES:",holes) @@ -224,7 +224,7 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { apply(Base, finalTree, termScope, config, unapply, - typeSymbols, holeSymbols, holes, splicedHoles, hopvHoles, termHoles, typeHoles, typedTree, typedTreeType, stmts, convNames, unquotedTypes) + typeSymbols, holeSymbols, holes, splicedHoles, hopHoles, termHoles, typeHoles, typedTree, typedTreeType, stmts, convNames, unquotedTypes) } @@ -240,7 +240,7 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { def apply( baseTree: Tree, rawTree: c.Tree, termScopeParam: List[Type], config: QuasiConfig, unapply: Option[c.Tree], typeSymbols: Map[TypeName, Symbol], holeSymbols: Set[Symbol], - holes: Seq[Either[TermName,TypeName]], splicedHoles: collection.Set[TermName], hopvHoles: collection.Map[TermName,List[List[TermName]]], + holes: Seq[Either[TermName,TypeName]], splicedHoles: collection.Set[TermName], hopHoles: collection.Set[TermName], termHoles: Set[TermName], typeHoles: Set[TypeName], typedTree: Tree, typedTreeType: Type, stmts: List[Tree], convNames: Set[TermName], unquotedTypes: Seq[(TypeName, Type, Tree)] ): c.Tree = { @@ -320,6 +320,23 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { override def liftTerm(x: Tree, parent: Tree, expectedType: Option[Type], inVarargsPos: Boolean)(implicit ctx: Map[TermSymbol, b.BoundVal]): b.Rep = { object HoleName { def unapply(tr: Tree) = Some(holeName(tr,x)) } + def mkHoleType(name: String, tpt: Tree) = expectedType match { + case None if tpt.tpe <:< Nothing => throw EmbeddingException(s"No type info for hole '$name'" + ( + if (debug.debugOptionEnabled) s", in: $parent" else "" )) // TODO: check only after we have explored all repetitions of the hole? (not sure if possible) + case Some(tp) => + assert(!SCALA_REPEATED(tp.typeSymbol.fullName.toString)) + + if (tp <:< Nothing) parent match { + case q"$_: $_" => // this hole is ascribed explicitly; the Nothing type must be desired + case _ => macroContext.warning(macroContext.enclosingPosition, + s"Type inferred for hole '$name' was Nothing. Ascribe the hole explicitly to remove this warning.") // TODO show real hole in its original form + } + + val mergedTp = if (tp <:< tpt.tpe || tpt.tpe <:< Nothing) tp else tpt.tpe + debug(s"[Hole '$name'] Expected",tp," required",tpt.tpe," merged",mergedTp) + + mergedTp + } x match { @@ -463,74 +480,57 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { throw QuasiException("Unknown use of free variable syntax operator `?`.") + /** Handling of higher-order patterns: */ + case q"$baseTree.$$$$_hop[$tpt](${nameTree @ HoleName(name)})(..$args)" => // TODO check baseTree + // TODO remove _extracted_ binders (as in `val $x = ...`) from 'visible' bindings?; add them to context requirements... + debug("HOP",tpt,nameTree,args) + + val holeType = mkHoleType(name, tpt) + + // TODO handle contexts and FVs... + // TODO make sure HOPV holes with the same name are not repeated? or at least have the same param types + // Q: handle HOPV holes in spliced position? + + val termName = TermName(name) + + val hopvType = FunctionType(args map (_.tpe) : _*)(holeType) + termHoleInfo(termName) = Map() -> hopvType + + val largs = args map (liftTerm(_,x,None,false)) + + b.hopHole2(name, liftType(holeType), (largs)::Nil, ctx.values.toList) + /** Replaces calls to $$(name) with actual holes */ case q"$baseTree.$$$$[$tpt]($nameTree)" => // TODO check baseTree val name = holeName(nameTree, x) - val holeType = expectedType match { - case None if tpt.tpe <:< Nothing => throw EmbeddingException(s"No type info for hole '$name'" + ( - if (debug.debugOptionEnabled) s", in: $parent" else "" )) // TODO: check only after we have explored all repetitions of the hole? (not sure if possible) - case Some(tp) => - assert(!SCALA_REPEATED(tp.typeSymbol.fullName.toString)) - - if (tp <:< Nothing) parent match { - case q"$_: $_" => // this hole is ascribed explicitly; the Nothing type must be desired - case _ => macroContext.warning(macroContext.enclosingPosition, - s"Type inferred for hole '$name' was Nothing. Ascribe the hole explicitly to remove this warning.") // TODO show real hole in its original form - } - - val mergedTp = if (tp <:< tpt.tpe || tpt.tpe <:< Nothing) tp else tpt.tpe - debug(s"[Hole '$name'] Expected",tp," required",tpt.tpe," merged",mergedTp) - - mergedTp - } - + val holeType = mkHoleType(name, tpt) //if (splicedHoles(TermName(name))) throw EmbeddingException(s"Misplaced spliced hole: '$name'") // used to be: q"$Base.splicedHole[${_expectedType}](${name.toString})" val termName = TermName(name) - hopvHoles get TermName(name) match { - case Some(idents) => - - // TODO make sure HOPV holes with the same name are not repeated? or at least have the same param types - // TODO handle HOPV holes in spliced position? - - val scp = ctx map { case k -> v => k.name -> (k -> v) } - val identVals = idents map (_ map scp) - val yes = identVals map (_ map (_._2)) - val keys = idents.flatten.toSet - val no = scp.filterKeys(!keys(_)).values.map(_._2).toList - debug(s"HOPV: yes=$yes; no=$no") - val List(identVals2) = identVals // FIXME generalize - val hopvType = FunctionType(identVals2 map (_._1.typeSignature) : _*)(holeType) - termHoleInfo(termName) = Map() -> hopvType - b.hopHole(name, liftType(holeType), yes, no) - - case _ => - if (unapply isEmpty) { - debug(s"Free variable: $name: $tpt") - freeVariableInstances ::= name -> holeType - // TODO maybe aggregate glb type beforehand so we can pass the precise type here... could even pass the _same_ hole! - } else { - //val termName = TermName(name) - val scp = ctx.keys map (k => k.name -> k.typeSignature) toMap; - termHoleInfo get termName map { case (scp0, holeType0) => - val newScp = scp0 ++ scp.map { case (n,t) => lub(t :: scp0.getOrElse(n, Any) :: Nil) } - newScp -> lub(holeType :: holeType0 :: Nil) - } getOrElse { - termHoleInfo(termName) = scp -> holeType - } - } - - if (splicedHoles(TermName(name))) - b.splicedHole(name, liftType(holeType)) - else b.hole(name, liftType(holeType)) - + if (unapply isEmpty) { + debug(s"Free variable: $name: $tpt") + freeVariableInstances ::= name -> holeType + // TODO maybe aggregate glb type beforehand so we can pass the precise type here... could even pass the _same_ hole! + } else { + //val termName = TermName(name) + val scp = ctx.keys map (k => k.name -> k.typeSignature) toMap; + termHoleInfo get termName map { case (scp0, holeType0) => + val newScp = scp0 ++ scp.map { case (n,t) => lub(t :: scp0.getOrElse(n, Any) :: Nil) } + newScp -> lub(holeType :: holeType0 :: Nil) + } getOrElse { + termHoleInfo(termName) = scp -> holeType + } } + if (splicedHoles(TermName(name))) + b.splicedHole(name, liftType(holeType)) + else b.hole(name, liftType(holeType)) + /** Removes implicit conversions that were generated in order to apply the subtyping knowledge extracted from pattern matching */ case q"${Ident(name:TermName)}.apply($x)" if convNames(name) => liftTerm(x, parent, typeIfNotNothing(x.tpe)) @@ -640,7 +640,7 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { val typeTypesToExtract = typeSymbols mapValues { sym => tq"$baseTree.IRType[$sym]" } val extrTyps = holes.map { - case Left(vname) => termTypesToExtract(vname) + case Left(vname) => termTypesToExtract(vname) // TODO B/E case Right(tname) => typeTypesToExtract(tname) // TODO B/E } debug("Extracted Types: "+extrTyps.map(showCode(_)).mkString(", ")) diff --git a/core/src/main/scala/squid/quasi/QuasiMacros.scala b/core/src/main/scala/squid/quasi/QuasiMacros.scala index 3919a05b..b163e107 100644 --- a/core/src/main/scala/squid/quasi/QuasiMacros.scala +++ b/core/src/main/scala/squid/quasi/QuasiMacros.scala @@ -106,7 +106,7 @@ class QuasiMacros(val c: whitebox.Context) { holeSymbols = Set(), holes = Seq(), splicedHoles = Set(), - hopvHoles = Map(), + hopHoles = Set(), termHoles = Set(), typeHoles = Set(), typedTree = code, @@ -200,7 +200,7 @@ class QuasiMacros(val c: whitebox.Context) { var holes: List[(Either[TermName,TypeName], Tree)] = Nil // (Left(value-hole) | Right(type-hole), original-hole-tree) val splicedHoles = mutable.Set[TermName]() - val hopvHoles = mutable.Map[TermName,List[List[TermName]]]() + val hopvHoles = mutable.Set[TermName]() var typeBounds: List[(TypeName, EitherOrBoth[Tree,Tree])] = Nil // Keeps track of which holes still have not been found in the source code @@ -302,18 +302,17 @@ class QuasiMacros(val c: whitebox.Context) { case Ident(name: TermName) if builder.holes.contains(name) => mkTermHole(name, false) - // Identify and treat Higher-Order Pattern Variables (HOPV) - case q"${Ident(name: TermName)}(...$argss)" if isUnapply && builder.holes.contains(name) => - val idents = argss map (_ map { - case Ident(name:TermName) => name - case e => throw EmbeddingException(s"Unexpected expression in higher-order pattern variable argument: ${showCode(e)}") - }) - val hole = builder.holes(name) - val n = hole.name filter (_.toString != "_") getOrElse ( + // Identify and treat Higher-Order Patterns (HOP) + case q"${Ident(nme: TermName)}(...$argss)" if isUnapply && builder.holes.contains(nme) => + val List(args) = argss // TODO handle ...$argss + val hole = builder.holes(nme) + val name = hole.name filter (_.toString != "_") getOrElse ( throw QuasiException("All higher-order holes should be named.") // Q: necessary restriction? ) toTermName; - hopvHoles += n -> idents - mkTermHole(name, false) + hopvHoles += name + remainingHoles -= nme + holes ::= Left(name) -> hole.tree + q"$base.$$$$_hop(${Symbol(name toString)})(..$args)" // Interprets bounds on extracted types, like in: `case List[$t where (Null <:< t <:< AnyRef)]`: case tq"${Ident(name: TypeName)} where $bounds" if isUnapply && builder.holes.contains(name.toTermName) => diff --git a/src/main/scala/squid/ir/AST.scala b/src/main/scala/squid/ir/AST.scala index 9a47fceb..773f2982 100644 --- a/src/main/scala/squid/ir/AST.scala +++ b/src/main/scala/squid/ir/AST.scala @@ -298,6 +298,12 @@ trait AST extends InspectableBase with ScalaTyping with ASTReinterpreter with Ru override def hopHole(name: String, typ: TypeRep, yes: List[List[Val]], no: List[Val]) = rep(new HOPHole(name, typ, yes, no)) class HOPHole(name: String, typ: TypeRep, val yes: List[List[Val]], val no: List[Val]) extends HoleClass(name, typ)() + override def hopHole2(name: String, typ: TypeRep, args: List[List[Rep]], visible: List[BoundVal]): Rep = { + // TODO replace `hopHole` + val vars = args map (_ map (r => dfn(r).asInstanceOf[BoundVal])) + hopHole(name, typ, vars, visible.toSet -- vars.flatten toList) + } + case class Constant(value: Any) extends Def { lazy val typ = value match { case () => TypeRep(ruh.Unit) From 3ff5419c36b15a63ae0d087654fbe4c4036f0434 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 29 Oct 2017 14:32:50 +0100 Subject: [PATCH 03/66] FastANF implementation --- .../main/scala/squid/quasi/QuasiBase.scala | 2 + src/main/scala/squid/ir/fastanf/FastANF.scala | 639 ++++++++++++++++-- src/main/scala/squid/ir/fastanf/Rep.scala | 64 +- src/main/scala/squid/ir/fastanf/Symbols.scala | 4 +- src/main/scala/squid/ir/fastanf/TypeRep.scala | 2 +- src/main/scala/squid/test/Test.scala | 45 ++ .../scala/squid/ir/fastir/BasicTests.scala | 94 +++ .../squid/ir/fastir/RewritingTests.scala | 68 ++ 8 files changed, 857 insertions(+), 61 deletions(-) create mode 100644 src/main/scala/squid/test/Test.scala create mode 100644 src/test/scala/squid/ir/fastir/RewritingTests.scala diff --git a/core/src/main/scala/squid/quasi/QuasiBase.scala b/core/src/main/scala/squid/quasi/QuasiBase.scala index a894c89f..d14af403 100644 --- a/core/src/main/scala/squid/quasi/QuasiBase.scala +++ b/core/src/main/scala/squid/quasi/QuasiBase.scala @@ -299,6 +299,8 @@ self: Base => res } + protected def mergeAll(as: Extract*): Option[Extract] = mergeAll(as.map(Some(_))) + def mergeableReps(a: Rep, b: Rep): Boolean = a =~= b def mergeTypes(a: TypeRep, b: TypeRep): Option[TypeRep] = diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 693e053e..f6b751c6 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -2,9 +2,10 @@ package squid package ir.fastanf import utils._ -import lang.{InspectableBase, Base} -import squid.ir.CurryEncoding -import squid.lang.ScalaCore +import lang.{Base, InspectableBase, ScalaCore} +import squid.ir.{Covariant, CurryEncoding, IRException, Invariant} + +import scala.collection.immutable.ListMap class FastANF extends InspectableBase with CurryEncoding with ScalaCore { private[this] implicit val base = this @@ -36,12 +37,45 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { override final def wrapExtract(r: => Rep): Rep = wrap(super.wrapExtract(r), true) @inline final def currentScope = scopes.head - - // TODO make squid.lang.Base and QQs agnostic in the argument list impl to avoid this conversion! + def toArgumentLists(argss: List[ArgList]): ArgumentLists = { + def toArgumentList(args: Seq[Rep]): ArgumentList = + args.foldRight(NoArguments: ArgumentList)(_ ~: _) + def toArgumentListWithSpliced(args: Seq[Rep])(splicedArg: Rep) = + args.foldRight(SplicedArgument(splicedArg): ArgumentList)(_ ~: _) + + argss.foldRight(NoArgumentLists: ArgumentLists) { + (args, acc) => args match { + case Args(as @ _*) => toArgumentList(as) ~~: acc + case ArgsVarargs(Args(as @ _*), Args(bs @ _*)) => toArgumentList(as ++ bs) ~~: acc // ArgVararg ~converted as Args! + case ArgsVarargSpliced(Args(as @ _*), s) => toArgumentListWithSpliced(as)(s) ~~: acc + } + } + } + + def toListOfArgList(argss: ArgumentLists): List[ArgList] = { + def toArgList(args: ArgumentList): List[Rep] -> Option[Rep] = args match { + case NoArguments => Nil -> None + case SplicedArgument(a) => Nil -> Some(a) // Everything after spliced argument is ignored. + case r : Rep => List(r) -> None + case ArgumentCons(h, t) => + val (rest, spliced) = toArgList(t) + (h :: rest) -> spliced + } + argss match { - case Nil => NoArgumentLists - case Args(a) :: Nil => a + case ArgumentListCons(h, t) => + val (args, spliced) = toArgList(h) + val _args = Args(args: _*) + spliced.fold(_args: ArgList)(s => ArgsVarargSpliced(_args, s)) :: toListOfArgList(t) + case NoArgumentLists => Nil + case SplicedArgument(spliced) => List(ArgsVarargSpliced(Args(), spliced)) // Not sure + case ac : ArgumentCons => + val (args, spliced) = toArgList(ac) + val _args = Args(args: _*) + spliced.fold(_args: ArgList)(s => ArgsVarargSpliced(_args, s)) :: Nil + case NoArguments => Nil + case r : Rep => List(Args(r)) } } @@ -61,8 +95,8 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def lambdaType(paramTyps: List[TypeRep], ret: TypeRep): TypeRep = DummyTypeRep def staticModule(fullName: String): Rep = StaticModule(fullName) - def module(prefix: Rep, name: String, typ: TypeRep): Rep = unsupported - def newObject(tp: TypeRep): Rep = unsupported + def module(prefix: Rep, name: String, typ: TypeRep): Rep = Module(prefix, name, typ) + def newObject(typ: TypeRep): Rep = NewObject(typ) def methodApp(self: Rep, mtd: MtdSymbol, targs: List[TypeRep], argss: List[ArgList], tp: TypeRep): Rep = { MethodApp(self, mtd, targs, argss |> toArgumentLists, tp) |> letbind } @@ -74,62 +108,575 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { case Ascribe(trueSelf, _) => Ascribe(trueSelf, typ) // Hopefully Scala's subtyping is transitive! case _ => Ascribe(self, typ) } - - def loadMtdSymbol(typ: TypSymbol, symName: String, index: Option[Int] = None, static: Boolean = false): MtdSymbol = new MethodSymbol(symName) // TODO - + + def loadMtdSymbol(typ: TypSymbol, symName: String, index: Option[Int] = None, static: Boolean = false): MtdSymbol = new MethodSymbol(typ, symName) // TODO + object Const extends ConstAPI { def unapply[T: IRType](ir: IR[T,_]): Option[T] = ir.rep match { case cst @ Constant(v) if typLeq(cst.typ, irTypeOf[T].rep) => Some(v.asInstanceOf[T]) case _ => None } } - - def repEq(a: Rep, b: Rep): Boolean = a === b // FIXME impl term equivalence - - + + // /** Artifact of a term extraction: map from hole name to terms, types and spliced term lists */ + + def repEq(a: Rep, b: Rep): Boolean = { + (a extractRep b) === (b extractRep a) && (a extractRep b) === Some(EmptyExtract) && (b extractRep a) === Some(EmptyExtract) + + //val aExtractB = a extractRep b + // + //if (aExtractB.isEmpty) false + //else { + // (aExtractB, b extractRep a) match { + // case (Some((xs, xts, fxs)), Some((ys, yts, fys))) => + // val extractsHole: (String -> Rep) => Boolean = { + // case (k: String, Hole(n, _)) if k == n => true + // case _ => false + // } + // + // val extractsTypeHole: (String -> TypeRep) => Boolean = { + // case (k: String, DummyTypeRep) => true + // case _ => false + // } + // + // fxs.isEmpty && fys.isEmpty && + // (xs forall extractsHole) && (ys forall extractsHole) && + // (xts forall extractsTypeHole) && (yts forall extractsTypeHole) + // case _ => false + // } + //} + } + + // * --- * --- * --- * Implementations of `IntermediateBase` methods * --- * --- * --- * - - def nullValue[T: IRType]: IR[T,{}] = unsupported - def reinterpret(r: Rep, newBase: Base)(extrudedHandle: BoundVal => newBase.Rep): newBase.Rep = unsupported + + def nullValue[T: IRType]: IR[T,{}] = IR[T, {}](const(DummyTypeRep)) + def reinterpret(r: Rep, newBase: Base)(extrudedHandle: BoundVal => newBase.Rep): newBase.Rep = { + def go: Rep => newBase.Rep = r => reinterpret(r, newBase)(extrudedHandle) + def reinterpretType: TypeRep => newBase.TypeRep = t => newBase.staticTypeApp(newBase.loadTypSymbol("scala.Any"), Nil) + def reinterpretBV:BoundVal => newBase.BoundVal = bv => newBase.bindVal(bv.name, reinterpretType(bv.typ), Nil) + def reinterpretTypSym(t: TypeSymbol): newBase.TypSymbol = newBase.loadTypSymbol(t.name) + def reinterpretMtdSym(s: MtdSymbol): newBase.MtdSymbol = newBase.loadMtdSymbol(reinterpretTypSym(s.typ), s.name) + def reinterpretArgList(argss: ArgumentLists): List[newBase.ArgList] = toListOfArgList(argss) map { + case ArgsVarargSpliced(args, varargs) => newBase.ArgsVarargSpliced(args.map(newBase)(go), go(varargs)) + case ArgsVarargs(args, varargs) => newBase.ArgsVarargs(args.map(newBase)(go), varargs.map(newBase)(go)) + case args : Args => args.map(newBase)(go) + } + def defToRep(d: Def): newBase.Rep = d match { + case app @ App(f, a) => newBase.app(go(f), go(a))(reinterpretType(app.typ)) + case ma : MethodApp => newBase.methodApp( + go(ma.self), + reinterpretMtdSym(ma.mtd), + ma.targs.map(reinterpretType), + reinterpretArgList(ma.argss), + reinterpretType(ma.typ)) + case l: Lambda => newBase.lambda(List(reinterpretBV(l.bound)), go(l.body)) + } + + r match { + case Constant(v) => newBase.const(v) + case StaticModule(fn) => newBase.staticModule(fn) + case NewObject(t) => newBase.newObject(reinterpretType(t)) + case Hole(n, t) => newBase.hole(n, reinterpretType(t)) + case SplicedHole(n, t) => newBase.splicedHole(n, reinterpretType(t)) + case Ascribe(s, t) => newBase.ascribe(go(s), reinterpretType(t)) + case HOPHole(n, t, yes, no) => newBase.hopHole( + n, + reinterpretType(t), + yes.map(_.map(reinterpretBV)), + no.map(reinterpretBV)) + case Module(p, n, t) => newBase.module(go(p), n, reinterpretType(t)) + case lb: LetBinding => newBase.letin( + reinterpretBV(lb.bound), + defToRep(lb.value), + go(lb.body), + reinterpretType(lb.typ)) + case s: Symbol => extrudedHandle(s) + } + + } def repType(r: Rep): TypeRep = r.typ def boundValType(bv: BoundVal) = bv.typ // * --- * --- * --- * Implementations of `InspectableBase` methods * --- * --- * --- * + + def extractType(xtor: TypeRep, xtee: TypeRep, va: squid.ir.Variance): Option[Extract] = Some(EmptyExtract) //unsupported + def bottomUp(r: Rep)(f: Rep => Rep): Rep = transformRepAndDef(r)(identity, f)(identity) + def topDown(r: Rep)(f: Rep => Rep): Rep = transformRepAndDef(r)(f)(identity) + def transformRepAndDef(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity)(preDef: Def => Def, postDef: Def => Def = identity): Rep = { + def _transformRepAndDef(r: Rep) = transformRepAndDef(r)(pre, post)(preDef, postDef) + + def transformDef(d: Def): Def = (d map preDef match { + case App(f, a) => App(_transformRepAndDef(f), _transformRepAndDef(a)) + case ma: MethodApp => MethodApp(_transformRepAndDef(ma.self), ma.mtd, ma.targs, ma.argss map (_transformRepAndDef(_)), ma.typ) + case l: Lambda => new Lambda(l.name, l.bound, l.boundType, _transformRepAndDef(l.body)) + }) map postDef + + (r map pre match { + case lb: LetBinding => + new LetBinding( + lb.name, + lb.bound, + transformDef(lb.value), + _transformRepAndDef(lb.body) + ) + case Ascribe(s, t) => + Ascribe(_transformRepAndDef(s), t) + case Module(p, n, t) => + Module(_transformRepAndDef(p), n, t) + case r @ ((_:Constant) | (_:Hole) | (_:Symbol) | (_:SplicedHole) | (_:HOPHole) | (_:NewObject) | (_:StaticModule)) => r + }) map post + } - def extractType(xtor: TypeRep, xtee: TypeRep,va: squid.ir.Variance): Option[Extract] = unsupported - def bottomUp(r: Rep)(f: Rep => Rep): Rep = transformRep(r)(identity, f) - def topDown(r: Rep)(f: Rep => Rep): Rep = transformRep(r)(f) - def transformRep(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity): Rep = unsupported - protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = unsupported - protected def spliceExtract(xtor: Rep, args: Args): Option[Extract] = unsupported + def transformRep(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity): Rep = + transformRepAndDef(r)(pre, post)(identity, identity) + + protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = extractWithCtx(xtor, xtee)(ListMap.empty) + + def extractWithCtx(xtor: Rep, xtee: Rep)(implicit ctx: ListMap[BoundVal, BoundVal]): Option[Extract] = xtor -> xtee match { + case (lb1: LetBinding, lb2: LetBinding) => + val normal = for { + //e1 <- extractWithCtx(lb1.bound, lb2.bound) + e1 <- lb1.boundType extract (lb1.boundType, Covariant) + e2 <- extractValue(lb1.value, lb2.value) + e3 <- extractWithCtx(lb1.body, lb2.body)(ctx + (lb1.bound -> lb2.bound)) + m <- mergeAll(e1, e2, e3) + } yield m + + /* + * For instance when: + * xtor: val x0 = List(Hole(...): _*) + * xtee: val x0 = Seq(1, 2, 3); val x1 = List(x0) + */ + // TODO CHECK CTX + lazy val lookFurtherInXtee = extractWithCtx(lb1, lb2.body)(ctx + (lb1.bound -> lb2.bound)) + lazy val lookFurtherInXtor = extractWithCtx(lb1.body, lb2)(ctx + (lb1.bound -> lb2.bound)) + + normal //orElse lookFurtherInXtee orElse lookFurtherInXtor + + // Matches 42 and 42: Any, is it safe to ignore the typ? + case (_, Ascribe(s, _)) => extractWithCtx(xtor, s) + + case (Ascribe(s, t) , _) => + for { + e1 <- t extract (xtee.typ, Covariant) // t <:< a.typ, which one to use? + e2 <- extractWithCtx(s, xtee) + m <- merge(e1, e2) + } yield m + + case (Hole(n, t), _) => + for { + e <- t extract (xtee.typ, Covariant) + m <- merge(e, repExtract(n -> xtee)) + } yield m + + case (HOPHole(name, typ, args, visible), _) => unsupported + + case (bv1: BoundVal, bv2: BoundVal) => + if (bv1 == bv2) Some(EmptyExtract) + else for { + candidate <- ctx.get(bv1) + if candidate == bv2 + } yield EmptyExtract + + case (Constant(v1), Constant(v2)) if v1 == v2 => + xtor.typ extract (xtee.typ, Covariant) + + // Assuming if they have the same name the type is the same + case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Some(EmptyExtract) + + // Assuming if they have the same name and prefix the type is the same + case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithCtx(p1, p2) + + case (NewObject(t1), NewObject(t2)) => t1 extract (t2, Covariant) + + case _ => None + } + protected def spliceExtract(xtor: Rep, args: Args): Option[Extract] = xtor match { + // Should check that type matches, but don't see how to access it for Args + case SplicedHole(n, _) => Some(Map(), Map(), Map(n -> args.reps)) + + case Hole(n, t) => + val rep = methodApp( + staticModule("scala.collection.Seq"), + loadMtdSymbol( + loadTypSymbol("scala.collection.generic.GenericCompanion"), + "apply", + None), + List(t), + List(Args()(args.reps: _*)), + staticTypeApp(loadTypSymbol("scala.collection.Seq"), List(t))) + Some(repExtract(n -> rep)) + + case _ => throw IRException(s"Trying to splice-extract with invalid extractor $xtor") + } + + override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = { + def go(ex: Extract, matchedBVs: Set[BoundVal])(_xtor: Rep, _xtee: Rep)(implicit ctx: ListMap[BoundVal, BoundVal]): Option[Rep] = { + println(s"XTEE ->> ${_xtee}") + + def checkRefs(r: Rep): Option[Rep] = { + def refs(r: Rep): Set[BoundVal] = { + def bvsUsed(value: Def): Set[BoundVal] = value match { + case ma: MethodApp => + def bvsInArgss(argss: ArgumentLists): Set[BoundVal] = { + def go(argss: ArgumentLists, acc: Set[BoundVal]): Set[BoundVal] = argss match { + case ArgumentListCons(h, t) => go(t, go(h, acc)) + case ArgumentCons(h, t) => go(t, go(h, acc)) + case SplicedArgument(bv: BoundVal) => acc + bv + case bv: BoundVal => acc + bv + case _ => acc + } + + go(argss, Set.empty) + } + + val selfBV = ma.self match { + case bv: BoundVal => Set(bv) + case _ => Set.empty + } + + selfBV ++ bvsInArgss(ma.argss) + + case l: Lambda => + val bodyBV: Set[BoundVal] = l.body match { + case bv: BoundVal => Set(bv) + case _ => Set.empty + } + + bodyBV + l.bound + } + + + r match { + case lb: LetBinding => bvsUsed(lb.value) ++ refs(lb.body) + case bv: BoundVal => Set(bv) + case _ => Set.empty + } + } + + if ((refs(r) & matchedBVs).isEmpty) Some(r) else None + } + + def mkCode(e: Extract): Option[Rep] = { + for { + c <- code(e) + c <- checkRefs(c) + } yield c + } + + def letIn(body: Rep)(xBV: BoundVal, newX: Rep): Rep = { + println(s"----- In $body replacing $xBV with $newX") + def findAndReplace(argss: ArgumentLists, r: BoundVal, newR: BoundVal): ArgumentLists = argss.map { + case a if a == r => println(s"### $a -> $newR"); newR + case a => println(s"+++ $a"); a + } + + // In `body` replace every occurrence of `x` with `outerLB` + def _letIn(body: Rep)(xBV: BoundVal, outerLB: LetBinding): Rep = { + def replaceInValue(v: Def): Def = { + v match { + case ma: MethodApp => + println(s"@@@@ $ma") + MethodApp( + if (ma.self == xBV) outerLB.bound else ma.self, + ma.mtd, + ma.targs, + findAndReplace(ma.argss, xBV, outerLB.bound), + ma.typ) + case l: Lambda => + new Lambda(l.name, + l.bound, + l.boundType, + _letIn(l.body)(xBV, outerLB)) + } + } + + body match { + case lb: LetBinding if lb.bound == xBV => new LetBinding(outerLB.name, outerLB.bound, outerLB.value, _letIn(lb.body)(xBV, outerLB)) + + case lb: LetBinding => + new LetBinding( + lb.name, + lb.bound, + replaceInValue(lb.value), + _letIn(lb.body)(xBV, outerLB)) + + case `xBV` => outerLB.bound + case _ => body + } + } + + newX match { + case outerLB: LetBinding => + //xBV rebind outerLB.bound + //body + _letIn(removeBVs(body, matchedBVs - xBV))(xBV, outerLB) alsoApply println + + case bv: BoundVal => + xBV rebind bv + body + + case _ => ??? //_letIn(removeBVs(body, matchedBVs - xBV))(xBV, newX) + } + } + + def withBV(r: Rep)(f: LetBinding => Option[Extract]): Option[Extract -> BoundVal] = { +// def extractArgss(f: A => Option[Extract])(argss: ArgumentLists): Option[Extract -> BoundVal] = argss match { +// case ArgumentListCons(h, t) => extractArgss(f)(h) orElse extractArgss(f)(t) +// case ArgumentCons(h, t) => extractArgss(f)(h) orElse extractArgss(f)(t) +// case SplicedArgument(a) => withBV(f)(a) +// case r: Rep => withBV(f)(r) +// case NoArguments | NoArgumentLists => None +// } + + r match { + case lb: LetBinding => + f(lb) map { _ -> lb.bound} orElse { + lb.value match { + case l: Lambda => withBV(l.body)(f) + case _ => None + } + } orElse withBV(lb.body)(f) + + case _ => None + } + } + + def extractValueWithBV(v: Def, r: Rep): Option[Extract -> BoundVal] = withBV(r) { lb => extractValue(v, lb.value) } + def extractBVWithBV(bv: BoundVal, r: Rep): Option[Extract -> BoundVal] = withBV(r) { lb => extractWithCtx(bv, lb.bound) } + + def removeBVs(r: Rep, bvs: Set[BoundVal]): Rep = { + def lambdaRemove(l: Lambda): Def = { + // TODO check l.bound? + new Lambda(l.name, l.bound, l.boundType, removeBVs(l.body, bvs)) + } + + r match { + case lb: LetBinding if bvs contains lb.bound => lb.body + case lb: LetBinding => lb.value match { + case l: Lambda => new LetBinding(lb.name, lb.bound, lambdaRemove(l), removeBVs(lb.body, bvs)) + case _ => new LetBinding(lb.name, lb.bound, lb.value, removeBVs(lb.body, bvs)) + } + case _ => r + } + } + + (_xtor, _xtee) match { + case (lb1: LetBinding, lb2: LetBinding) => + for { + (e, matchedBV) <- extractValueWithBV(lb1.value, lb2) + _ = println(s"----- Matched $matchedBV") + if !(matchedBVs contains matchedBV) // Cannot match twice the same code line + m <- merge(e, ex) + r <- go(m, matchedBVs + matchedBV)(lb1.body, lb2)(ctx + (lb1.bound -> matchedBV)) + } yield r + + // Match return of `xtor` with something in `xtee` + case (bv: BoundVal, lb: LetBinding) => + println(s"Knowing: $ctx") + println(s"Matching... $bv in $lb") + for { + (e, matchedBV) <- extractBVWithBV(bv, lb) + _ = println(s"----->> Matched $matchedBV") + m <- merge(e, ex) + newX <- mkCode(m) + _ = println(s"Code => $newX") + newC = letIn(lb)(matchedBV, newX) alsoApply (c => println(s"New code => $c")) + } yield newC + + case (bv: BoundVal, xtee: Rep) => + for { + (e, matchedBV) <- extractBVWithBV(bv, xtee) + m <- merge(e, ex) + newX <- mkCode(m) + newC = letIn(xtee)(matchedBV, newX) + } yield newC + + // Match Constant(42) with `value` of the `LetBinding` + case (xtor: Rep, lb: LetBinding) => + type LazyInlined[R] = R -> List[LetBinding] + implicit def lazyInlined[R](r: R): LazyInlined[R] = r -> List.empty + def run(lr: LazyInlined[Rep]): Rep = lr match { + case r -> lbs => lbs.foldLeft(r) { (acc, outerLB) => new LetBinding(outerLB.name, outerLB.bound, outerLB.value, acc) } + } + + // Puts `acc` inside `outer`. + def surroundWithLB(acc: LazyInlined[Rep], outer: LazyInlined[Rep]): LazyInlined[Rep] = outer match { + case (outerLB: LetBinding, lbs) => + new LetBinding(outerLB.name, outerLB.bound, outerLB.value, run(acc)) -> lbs + case _ => throw new IllegalArgumentException + } + + + def rewrite(lb: LetBinding): Option[Rep] = { + def rewriteValue(lb: LetBinding): Option[LazyInlined[Rep]] = { + val rewrittenValue = { + def rewriteArgssReps(argss: ArgumentLists): Option[LazyInlined[ArgumentLists]] = { + def rewriteArgsReps(args: ArgumentList): Option[LazyInlined[ArgumentList]] = args match { + case ArgumentCons(h, t) => for { + (h, outerLBs) <- rewriteRep(h) + (t, innerLBs) <- rewriteArgsReps(t) + } yield ArgumentCons(h, t) -> (innerLBs ::: outerLBs) + + case SplicedArgument(a) => rewriteRep(a) map { case arg -> lbs => SplicedArgument(arg) -> lbs } + case r: Rep => rewriteRep(r) + case NoArguments => Some(NoArguments) + } + + argss match { + case ArgumentListCons(h, t) => for { + (h, outerLBs) <- rewriteArgsReps(h) + (t, innerLBs) <- rewriteArgssReps(t) + } yield ArgumentListCons(h, t) -> (innerLBs ::: outerLBs) + + case args: ArgumentList => rewriteArgsReps(args) + case NoArgumentLists => Some(NoArgumentLists) + } + } + + def rewriteRep(xtee: Rep) = { + def codeWithOuter(e: Extract) = { + def toLazy(r: Rep): LazyInlined[Rep] = { + def go(r: Rep, acc: List[LetBinding]): LazyInlined[Rep] = r match { + case lb: LetBinding => go(lb.body, lb :: acc) + case _ => r -> acc + } + + go(r, List.empty) + } + + mkCode(e) map toLazy + } + + extractWithCtx(xtor, xtee).fold(Option(lazyInlined(xtee)))(codeWithOuter) + } + + lb.value match { + case ma: MethodApp => for { + (self, outerLBs) <- rewriteRep(ma.self) + (argss, innerLBs) <- rewriteArgssReps(ma.argss) + } yield MethodApp(self, ma.mtd, ma.targs, argss, ma.typ) -> (innerLBs ::: outerLBs) + + case l: Lambda => rewriteRep(l.body) map { + case body -> lbs => new Lambda(l.name, l.bound, l.boundType, body) -> lbs + } + } + } + + // Puts the value back into its LB + rewrittenValue map { case value -> lbs => new LetBinding(lb.name, lb.bound, value, lb.body) -> lbs} + } + + def go(lb: LetBinding, acc: Option[List[LazyInlined[Rep]]]): Option[List[LazyInlined[Rep]]] = lb.body match { + case innerLB: LetBinding => + go(innerLB, + for { + acc <- acc + lb <- rewriteValue(lb) + } yield lb :: acc) + + case ret => + for { + acc <- acc + lb <- rewriteValue(lb) + } yield lazyInlined(ret) :: lb :: acc + } + + go(lb, Some(List.empty)) map { + case ret :: outerLBs => run(outerLBs.foldLeft(ret)(surroundWithLB)) + } + } + + rewrite(lb) + + // Matching return values + case (r1: Rep, r2: Rep) => + for { + e <- extractWithCtx(r1, r2) + m <- merge(e, ex) + c <- mkCode(m) + } yield c + } + } + + go(EmptyExtract, Set.empty)(xtor, xtee)(ListMap.empty) + } + + def extractValue(v1: Def, v2: Def)(implicit ctx: ListMap[BoundVal, BoundVal]) = (v1, v2) match { + case (l1: Lambda, l2: Lambda) => + for { + e1 <- l1.boundType extract (l2.boundType, Covariant) + e2 <- extractWithCtx(l1.body, l2.body)(ctx + (l1.bound -> l2.bound)) + m <- merge(e1, e2) + } yield m + + case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => + lazy val targExtract = mergeAll(for { + (e1, e2) <- ma1.targs zip ma2.targs + } yield e1 extract (e2, Invariant)) // TODO Invariant? Depends on its positions... + + def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists): Option[Extract] = { + def go(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep]): Option[Extract] = (argss1, argss2) match { + case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => mergeOpt(go(h1, h2, acc), go(t1, t2, acc)) + case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => mergeOpt(go(h1, h2, acc), go(t1, t2, acc)) + case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithCtx(arg1, arg2) + case (sa: SplicedArgument, ArgumentCons(h, t)) => go(sa, t, h :: acc) + case (sa: SplicedArgument, r: Rep) => go(sa, NoArguments, r :: acc) + case (SplicedArgument(arg), NoArguments) => spliceExtract(arg, Args(acc.reverse: _*)) // Reverses list... + case (r1: Rep, r2: Rep) => extractWithCtx(r1, r2) + case (NoArguments, NoArguments) => Some(EmptyExtract) + case (NoArgumentLists, NoArgumentLists) => Some(EmptyExtract) + case _ => None + } + + go(argss1, argss2, Nil) + } + + for { + e1 <- extractWithCtx(ma1.self, ma2.self) + e2 <- targExtract + e3 <- extractArgss(ma1.argss, ma2.argss) + e4 <- ma1.typ extract (ma2.typ, Covariant) + m <- mergeAll(e1, e2, e3, e4) + } yield m + + case _ => None + } // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * - - def hole(name: String, typ: TypeRep) = unsupported - def splicedHole(name: String, typ: TypeRep): Rep = unsupported - def typeHole(name: String): TypeRep = unsupported - def hopHole(name: String, typ: TypeRep, yes: List[List[BoundVal]], no: List[BoundVal]) = unsupported - - def substitute(r: => Rep, defs: Map[String, Rep]): Rep = unsupported - - + + def hole(name: String, typ: TypeRep) = Hole(name, typ) + def splicedHole(name: String, typ: TypeRep): Rep = SplicedHole(name, typ) + def typeHole(name: String): TypeRep = DummyTypeRep + def hopHole(name: String, typ: TypeRep, yes: List[List[BoundVal]], no: List[BoundVal]) = HOPHole(name, typ, yes, no) + + def substitute(r: => Rep, defs: Map[String, Rep]): Rep = + if (defs isEmpty) r + else bottomUp(r) { + case h @ Hole(n, _) => defs getOrElse(n, h) + case h @ SplicedHole(n, _) => defs getOrElse(n, h) + case h => h + } + + // * --- * --- * --- * Implementations of `TypingBase` methods * --- * --- * --- * import scala.reflect.runtime.universe.TypeTag // TODO do without this - - def uninterpretedType[A: TypeTag]: TypeRep = unsupported - def typeApp(self: TypeRep, typ: TypSymbol, targs: List[TypeRep]): TypeRep = unsupported - def staticTypeApp(typ: TypSymbol, targs: List[TypeRep]): TypeRep = DummyTypeRep //unsupported - def recordType(fields: List[(String, TypeRep)]): TypeRep = unsupported - def constType(value: Any, underlying: TypeRep): TypeRep = unsupported - - def typLeq(a: TypeRep, b: TypeRep): Boolean = unsupported - - def loadTypSymbol(fullName: String): TypSymbol = new TypeSymbol // TODO - - + + def uninterpretedType[A: TypeTag]: TypeRep = DummyTypeRep + def typeApp(self: TypeRep, typ: TypSymbol, targs: List[TypeRep]): TypeRep = DummyTypeRep + def staticTypeApp(typ: TypSymbol, targs: List[TypeRep]): TypeRep = DummyTypeRep //unsupported + def recordType(fields: List[(String, TypeRep)]): TypeRep = DummyTypeRep + def constType(value: Any, underlying: TypeRep): TypeRep = DummyTypeRep + + def typLeq(a: TypeRep, b: TypeRep): Boolean = true + + def loadTypSymbol(fullName: String): TypSymbol = new TypeSymbol(fullName) // TODO + + // * --- * --- * --- * Misc * --- * --- * --- * def unsupported = lastWords("This part of the IR is not yet implemented/supported") diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index 4606314c..ecf4635d 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -37,25 +37,39 @@ object Som { /** Trivial expression, that can be used as arguments */ sealed abstract class Def extends DefOption with DefOrTypeRep with FlatSom[Def] { def fold[R](df: Def => R, typeRep: TypeRep => R): R = df(this) + def map(f: Def => Def): Def = f(this) } /** Expression that can be used as an argument or result; this includes let bindings. */ sealed abstract class Rep extends RepOption with ArgumentList with FlatSom[Rep] { def typ: TypeRep def * = SplicedArgument(this) + def map(f: Rep => Rep) = f(this) + def isPure: Bool = true } final case class Constant(value: Any) extends Rep with CachedHashCode { - lazy val typ = value match { // TODO impl and rm lazy - case _ => lastWords(s"Not a valid constant: $value") + // TODO impl and rm lazy + lazy val typ = value match { + case _ => DummyTypeRep } } +final case class Hole(name: String, typ: TypeRep) extends Rep with CachedHashCode + +final case class SplicedHole(name: String, typ: TypeRep) extends Rep with CachedHashCode + +final case class HOPHole(name: String, typ: TypeRep, args: List[List[Symbol]], visible: List[Symbol]) extends Rep with CachedHashCode + // TODO intern objects final case class StaticModule(fullName: String) extends Rep with CachedHashCode { - lazy val typ = ??? + val typ = DummyTypeRep } +final case class Module(prefix: Rep, name: String, typ: TypeRep) extends Rep with CachedHashCode + +final case class NewObject(typ: TypeRep) extends Rep with CachedHashCode + // TODO make sure this does not generate a field for `base` // if it does, perhaps use `private[this] implicit val`? final case class Ascribe(self: Rep, typ: TypeRep)(implicit base: FastANF) extends Rep with CachedHashCode { @@ -112,26 +126,52 @@ sealed trait ArgumentLists extends CachedHashCode { def asSingleArg: Rep = this.asInstanceOf[Rep] def ~~: (as: ArgumentList): ArgumentLists = ArgumentListCons(as, this) - - def toArgssString = this match { - case NoArgumentLists => "" - case NoArguments => s"()" - case r: Rep => s"($r)" - case _ => ??? // TODO + + def toArgssString: String = { + def withoutParen(args: ArgumentList): String = args match { + case NoArguments => "" + case r: Rep => s"$r" + case SplicedArgument(r) => s"$r: _*" + case ArgumentCons(h, t: Rep) => s"$h, ${withoutParen(t)}" + case ArgumentCons(h, NoArguments) => s"$h" + case ArgumentCons(h, t) => s"$h, ${withoutParen(t)}" + } + + this match { + case NoArgumentLists => "" + case NoArguments => s"()" + case r: Rep => s"(${withoutParen(r)})" + case s: SplicedArgument => s"(${withoutParen(s)})" + case args: ArgumentCons => s"(${withoutParen(args)})" + case ArgumentListCons(h, t) => s"(${withoutParen(h)})${t.toArgssString}" + } } + + def map(f: Rep => Rep): ArgumentLists } + final case object NoArgumentLists extends ArgumentLists { override def ~~: (as: ArgumentList): ArgumentLists = as + def map(f: Rep => Rep) = this } sealed trait ArgumentList extends ArgumentLists { def ~: (a: Rep): ArgumentList = ArgumentCons(a, this) + def map(f: Rep => Rep): ArgumentList } final case object NoArguments extends ArgumentList { override def ~: (a: Rep): ArgumentList = a + def map(f: Rep => Rep) = this +} +// Q: can make extend AnyVal? requires making all upper traits universal) +final case class SplicedArgument(arg: Rep) extends ArgumentList { + def map(f: Rep => Rep) = SplicedArgument(f(arg)) +} +final case class ArgumentCons(head: Rep, tail: ArgumentList) extends ArgumentList { + def map(f: Rep => Rep) = ArgumentCons(f(head), tail map f) +} +final case class ArgumentListCons(head: ArgumentList, tail: ArgumentLists) extends ArgumentLists { + def map(f: Rep => Rep) = ArgumentListCons(head map f, tail map f) } -final case class SplicedArgument(arg: Rep) extends ArgumentList // Q: can make extend AnyVal? requires making all upper traits universal) -final case class ArgumentCons(head: Rep, tail: ArgumentLists) extends ArgumentList -final case class ArgumentListCons(head: ArgumentList, tail: ArgumentLists) extends ArgumentLists trait Binding extends SymbolParent { diff --git a/src/main/scala/squid/ir/fastanf/Symbols.scala b/src/main/scala/squid/ir/fastanf/Symbols.scala index fd2c8b58..9f4a5426 100644 --- a/src/main/scala/squid/ir/fastanf/Symbols.scala +++ b/src/main/scala/squid/ir/fastanf/Symbols.scala @@ -3,10 +3,10 @@ package ir.fastanf import utils._ -class TypeSymbol { +case class TypeSymbol(name: String) { } -case class MethodSymbol(name: String) { +case class MethodSymbol(typ: TypeSymbol, name: String) { } diff --git a/src/main/scala/squid/ir/fastanf/TypeRep.scala b/src/main/scala/squid/ir/fastanf/TypeRep.scala index 225e6f68..5ed853d3 100644 --- a/src/main/scala/squid/ir/fastanf/TypeRep.scala +++ b/src/main/scala/squid/ir/fastanf/TypeRep.scala @@ -4,7 +4,7 @@ package ir.fastanf import utils._ // cannot be sealed unless we put everything into one file... -private[fastanf] trait DefOrTypeRep { +/*private[fastanf]*/ trait DefOrTypeRep { def typ: TypeRep def fold[R](df: Def => R, typeRep: TypeRep => R): R } diff --git a/src/main/scala/squid/test/Test.scala b/src/main/scala/squid/test/Test.scala new file mode 100644 index 00000000..2a49e446 --- /dev/null +++ b/src/main/scala/squid/test/Test.scala @@ -0,0 +1,45 @@ +package squid +package test + +import squid.ir.FastANF +import squid.ir.fastanf.{DummyTypeRep, SplicedHole} + +object Test extends App { + object Embedding extends FastANF + import Embedding.Predef._ + import Embedding.Quasicodes._ + + //def odd(x: Int, y: Int)(z: Int)(r: Double*): Int = x + y + z + //def foo = 42 + // + //val bla = ir"foo" + // + //val program = ir"(a: Int, b: Int, c: Int, d: Double, e: Double) => odd(a, b)($bla)(d, e)" + + //val program = dbg_ir"(y: Int) => y + 5" + // + //println(s"$program") + + //val p1 = code"(x: Int) => x + 4" + + //case class Point(x: Int, y: Int) + //val p = Point(0, 1) + // + //code{p.x} match { + // case code"p.x" => println(code"p.y") + //} + + def f(args: Int*)(moreArgs: Int*) = 0 + + val p = code"f(1, 2, 3)(Seq(5, 6): _*)" + println(p) + + p match { + case code"f($n*)($fivesix: _*)" => println((n, fivesix)) + } + + val p2 = code"(x: Int) => x + 5" alsoApply println + p2 match { + case code"(y: Int) => y + 5" => + } +} diff --git a/src/test/scala/squid/ir/fastir/BasicTests.scala b/src/test/scala/squid/ir/fastir/BasicTests.scala index 9dcbe68d..0796cbe1 100644 --- a/src/test/scala/squid/ir/fastir/BasicTests.scala +++ b/src/test/scala/squid/ir/fastir/BasicTests.scala @@ -3,6 +3,7 @@ package ir.fastir import utils._ import squid.ir.FastANF +import squid.ir.{SimpleRuleBasedTransformer,TopDownTransformer} class BasicTests extends MyFunSuiteBase(BasicTests.Embedding) { import BasicTests.Embedding.Predef._ @@ -28,6 +29,99 @@ class BasicTests extends MyFunSuiteBase(BasicTests.Embedding) { } + test("ArgumentLists Pretty Print") { + import squid.ir.fastanf._ + + val c0 = Constant(0) + val c1 = Constant(1) + val c2 = Constant(2) + val c3 = Constant(3) + + assert((c0 ~: c1).toArgssString == s"($c0, $c1)") + assert((c0 ~~: (c1 ~: c2)).toArgssString == s"($c0)($c1, $c2)") + assert((c0 ~~: (c1 ~: c2 ~: NoArguments)).toArgssString == s"($c0)($c1, $c2)") + assert((c0 ~~: (c1 ~: c2 ~: SplicedArgument(c3))).toArgssString == s"($c0)($c1, $c2, $c3: _*)") + } + + test("Transformations") { + import squid.ir.fastanf._ + object Embedding extends FastANF + import Embedding.Predef._ + import Embedding.Quasicodes._ + + assert(Embedding.bottomUpPartial(code"42".rep){case Constant(n: Int) => Constant(n * 2)} == code"84".rep) + //assert(Embedding.bottomUpPartial(code"(x: Int) => x + 5".rep){case Constant(5) => Constant(42)} == code"(x: Int) => x + 42".rep) + //assert(Embedding.bottomUpPartial(code{ Some(5) }.rep){ case Constant(5) => Constant(42) } == code{ Some(42) }.rep) + //assert(Embedding.bottomUpPartial(code"foo(4, 5)".rep){ case Constant(4) => Constant(42) } == code"foo(42, 5)".rep) + + //case class Point(x: Int, y: Int) + //val p = Point(0, 1) + // + //assert(Embedding.bottomUpPartial(code{p.x + p.y}.rep){ case Constant(0) => Constant(42) } == code"p.y + p.y") + // + //code{p.x} match { + // case code"p.x" => assert(true) + // case _ => assert(false) + //} + // + //code{p.x + p.y} match { + // case code"p.x" => assert(true) + // case _ => assert(false) + //} + + assert(Embedding.bottomUpPartial(code"println((1, 23))".rep) { case Constant(1) => Constant(42) } == code"println((42,23))".rep) + } + + test("Transformers") { + object Tr extends BasicTests.Embedding.SelfTransformer with SimpleRuleBasedTransformer with TopDownTransformer { + rewrite { + case ir"123" => ir"111" + case ir"readInt" => ir"42" + } + } + + val p = + BasicTests.Embedding.debugFor { + ir"readInt+123" alsoApply println transformWith Tr alsoApply println + } + + assert(p =~= ir"${ir"42"}+123") + + } + + test("Extract") { + import squid.ir.fastanf._ + object Embedding extends FastANF + import Embedding.Predef._ + import Embedding.Quasicodes._ + + code"42: Int" match { + case code"${Const(x: Int)}" => assert(x == 42) + } + + code"(x: Int) => x" match { + case code"(x: Int) => x" => + } + + code"(x: Int, y: Int) => x + 1 + y" match { + case code"(y: Int, z: Int) => y + 1 + z" => assert(true) + case _ => assert(false) + } + + //case class Point(x: Int, y: Int) + //val p = Point(0, 1) + // + //code{p.x} match { + // case code"p.x" => assert(true) + // case _ => assert(false) + //} + // + //code{p.x + p.y} match { + // case code"p.x" => assert(true) + // case _ => assert(false) + //} + } + } object BasicTests { object Embedding extends FastANF diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala new file mode 100644 index 00000000..d512dbad --- /dev/null +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -0,0 +1,68 @@ +package squid +package ir +package fastir + +class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { + import RewritingTests.Embedding.Predef._ + + test("Worksheet") { + val a = ir"readDouble.toInt" alsoApply println + + val b = a rewrite { + case ir"readDouble.toInt" => + ir"readInt" + } alsoApply println + + + } + + test("Rewriting simple expressions only once") { + val a = ir"println((50,60))" + val b = a rewrite { + case ir"($x:Int,$y:Int)" => + ir"($y:Int,$x:Int)" + case ir"(${Const(n)}:Int)" => Const(n+1) + } + + assert(b =~= ir"println((61,51))") + } + + test("Rewriting Sequences of Bindings") { + + val a = ir"val aaa = readInt; val bbb = readDouble.toInt; aaa+bbb" alsoApply println // FIXMElater: why isn't name 'bbb' conserved? + val b = a rewrite { + case ir"readDouble.toInt" => + ir"readInt" + } alsoApply println + + ir"val aa = readInt; val bb = readInt; aa+bb" alsoApply println + + assert(b =~= ir"val aa = readInt; val bb = readInt; aa+bb") + +// { +// val a = ir"val a = 11.toDouble; val b = 22.toDouble; val c = 33.toDouble; (a,b,c)" +// val c = a rewrite { +// case ir"val a = ($x:Int).toDouble; val b = ($y:Int).toDouble; $body: $bt" => +// ir"val a = ($x+$y).toDouble/2; val b = a; $body" +// } +// assert(c =~= ir"val a = (11+${ir"22"}).toDouble/2; val c = 33.toDouble; (a,a,c)") +// +// /* +// // TODO make this work like the one above +// val d = a rewrite { +// case ir"val a = ($x:Int).toDouble; ($y:Int).toDouble" => +// ir"($x+$y).toDouble/2" +// } +// println(d) +// //d eqt ir"val a = (11+${ir"22"}).toDouble/2; val c = 33.toDouble; (a,a,c)" +// */ +// } +// +// + } + + +} +object RewritingTests { + object Embedding extends FastANF +} From e4c647944d4b4a447faf2918302a26ef9e430031 Mon Sep 17 00:00:00 2001 From: LPTK Date: Thu, 19 Oct 2017 23:38:24 -0700 Subject: [PATCH 04/66] Make ANF not retarded: proper letin implementation and arguments binding --- .../scala/squid/lang/IntermediateBase.scala | 1 + src/main/scala/squid/ir/fastanf/FastANF.scala | 81 ++++++++++++++----- src/main/scala/squid/ir/fastanf/Rep.scala | 59 ++++++++++---- 3 files changed, 109 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/squid/lang/IntermediateBase.scala b/core/src/main/scala/squid/lang/IntermediateBase.scala index 14439c66..8212e429 100644 --- a/core/src/main/scala/squid/lang/IntermediateBase.scala +++ b/core/src/main/scala/squid/lang/IntermediateBase.scala @@ -21,6 +21,7 @@ trait IntermediateBase extends Base { ibase: IntermediateBase => def reinterpret(r: Rep, newBase: Base)(extrudedHandle: (BoundVal => newBase.Rep) = DefaultExtrudedHandler): newBase.Rep + // TODO rename to `defaultValue` def nullValue[T: IRType]: IR[T,{}] diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index f6b751c6..774ccb21 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -38,7 +38,9 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { @inline final def currentScope = scopes.head - def toArgumentLists(argss: List[ArgList]): ArgumentLists = { + def toArgumentLists(argss0: List[ArgList]): ArgumentLists = { + val argss = argss0.map(_.map(this)(inlineBlock)) // TODO optimize: avoid reconstruction of the ArgList's + def toArgumentList(args: Seq[Rep]): ArgumentList = args.foldRight(NoArguments: ArgumentList)(_ ~: _) def toArgumentListWithSpliced(args: Seq[Rep])(splicedArg: Rep) = @@ -83,7 +85,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { // * --- * --- * --- * Implementations of `Base` methods * --- * --- * --- * def bindVal(name: String, typ: TypeRep, annots: List[Annot]): BoundVal = new UnboundSymbol(name,typ) - def readVal(bv: BoundVal): Rep = bv + def readVal(bv: BoundVal): Rep = curSub getOrElse (bv, bv) def const(value: Any): Rep = Constant(value) // Note: method `lambda(params: List[BoundVal], body: => Rep): Rep` is implemented by CurryEncoding @@ -98,12 +100,49 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def module(prefix: Rep, name: String, typ: TypeRep): Rep = Module(prefix, name, typ) def newObject(typ: TypeRep): Rep = NewObject(typ) def methodApp(self: Rep, mtd: MtdSymbol, targs: List[TypeRep], argss: List[ArgList], tp: TypeRep): Rep = { - MethodApp(self, mtd, targs, argss |> toArgumentLists, tp) |> letbind + MethodApp(self |> inlineBlock, mtd, targs, argss |> toArgumentLists, tp) |> letbind } def byName(mkArg: => Rep): Rep = wrapNest(mkArg) def letbind(d: Def): Rep = currentScope += d - + def inlineBlock(r: Rep): Rep = r |>=? { + case lb: LetBinding => + currentScope += lb + inlineBlock(lb.body) + } + + override def letin(bound: BoundVal, value: Rep, body: => Rep, bodyType: TypeRep): Rep = { + value match { + case s: Symbol => + s.owner |>? { + case lb: RebindableBinding => + lb.name = bound.name + } + bound rebind s + body + case lb: LetBinding => + // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding + val last = lb.last + bound rebind last.bound + last.body = body + last.name = bound.name // TODO make sure we're only renaming an automatically-named binding? + lb + case (_:HOPHole) | (_:Hole) | (_:SplicedHole) => + ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes + case _ => + withSubs(bound -> value)(body) + // ^ executing `body` will reify some statements into the reification scope, and likely return a symbol + // during this reification, we need all references to `bound` to be replaced by the actual `value` + } + } + + var curSub: Map[Symbol,Rep] = Map.empty + def withSubs[R](subs: Symbol -> Rep)(k: => R): R = { + val oldSub = curSub + curSub += subs + try k finally curSub = oldSub + } + override def ascribe(self: Rep, typ: TypeRep): Rep = if (self.typ =:= typ) self else self match { case Ascribe(trueSelf, _) => Ascribe(trueSelf, typ) // Hopefully Scala's subtyping is transitive! case _ => Ascribe(self, typ) @@ -150,7 +189,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { // * --- * --- * --- * Implementations of `IntermediateBase` methods * --- * --- * --- * - def nullValue[T: IRType]: IR[T,{}] = IR[T, {}](const(DummyTypeRep)) + def nullValue[T: IRType]: IR[T,{}] = IR[T, {}](const(null)) // FIXME: should implement proper semantics; e.g. nullValue[Int] == ir"0", not ir"null" def reinterpret(r: Rep, newBase: Base)(extrudedHandle: BoundVal => newBase.Rep): newBase.Rep = { def go: Rep => newBase.Rep = r => reinterpret(r, newBase)(extrudedHandle) def reinterpretType: TypeRep => newBase.TypeRep = t => newBase.staticTypeApp(newBase.loadTypSymbol("scala.Any"), Nil) @@ -208,25 +247,31 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def _transformRepAndDef(r: Rep) = transformRepAndDef(r)(pre, post)(preDef, postDef) def transformDef(d: Def): Def = (d map preDef match { - case App(f, a) => App(_transformRepAndDef(f), _transformRepAndDef(a)) - case ma: MethodApp => MethodApp(_transformRepAndDef(ma.self), ma.mtd, ma.targs, ma.argss map (_transformRepAndDef(_)), ma.typ) - case l: Lambda => new Lambda(l.name, l.bound, l.boundType, _transformRepAndDef(l.body)) + case App(f, a) => App(_transformRepAndDef(f), _transformRepAndDef(a)) // Note: App is a MethodApp, but we can transform it more efficiently this way + case ma: MethodApp => MethodApp(_transformRepAndDef(ma.self), ma.mtd, ma.targs, ma.argss argssMap (_transformRepAndDef(_)), ma.typ) + case l: Lambda => // Note: destructive modification of the lambda binding! + //new Lambda(l.name, l.bound, l.boundType, _transformRepAndDef(l.body)) + l.body = l.body |> _transformRepAndDef + l }) map postDef - (r map pre match { - case lb: LetBinding => - new LetBinding( - lb.name, - lb.bound, - transformDef(lb.value), - _transformRepAndDef(lb.body) - ) + post(pre(r) match { + case lb: LetBinding => // Note: destructive modification of the let-binding! + //new LetBinding( + // lb.name, + // lb.bound, + // transformDef(lb.value), + // _transformRepAndDef(lb.body) + //) + lb.value = lb.value |> transformDef + lb.body = lb.body |> _transformRepAndDef + lb case Ascribe(s, t) => Ascribe(_transformRepAndDef(s), t) case Module(p, n, t) => Module(_transformRepAndDef(p), n, t) case r @ ((_:Constant) | (_:Hole) | (_:Symbol) | (_:SplicedHole) | (_:HOPHole) | (_:NewObject) | (_:StaticModule)) => r - }) map post + }) } def transformRep(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity): Rep = @@ -369,7 +414,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def letIn(body: Rep)(xBV: BoundVal, newX: Rep): Rep = { println(s"----- In $body replacing $xBV with $newX") - def findAndReplace(argss: ArgumentLists, r: BoundVal, newR: BoundVal): ArgumentLists = argss.map { + def findAndReplace(argss: ArgumentLists, r: BoundVal, newR: BoundVal): ArgumentLists = argss.argssMap { case a if a == r => println(s"### $a -> $newR"); newR case a => println(s"+++ $a"); a } diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index ecf4635d..ecefb186 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -44,8 +44,9 @@ sealed abstract class Def extends DefOption with DefOrTypeRep with FlatSom[Def] sealed abstract class Rep extends RepOption with ArgumentList with FlatSom[Rep] { def typ: TypeRep def * = SplicedArgument(this) - def map(f: Rep => Rep) = f(this) def isPure: Bool = true + def argssMap(f: Rep => Rep) = f(this) + def argssList: List[Rep] = this :: Nil } final case class Constant(value: Any) extends Rep with CachedHashCode { @@ -64,6 +65,7 @@ final case class HOPHole(name: String, typ: TypeRep, args: List[List[Symbol]], v // TODO intern objects final case class StaticModule(fullName: String) extends Rep with CachedHashCode { val typ = DummyTypeRep + override def toString = fullName } final case class Module(prefix: Rep, name: String, typ: TypeRep) extends Rep with CachedHashCode @@ -84,6 +86,8 @@ trait MethodApp extends Def { def targs: List[TypeRep] def argss: ArgumentLists def typ: TypeRep + protected def doChecks = // we can't execute the checks here right away, because of initialization order + (self :: argss.argssList).foreach(r => assert(!r.isInstanceOf[LetBinding], s"Illegal ANF argument/self: $r")) override def toString = s"$self.${mtd.name}${argss.toArgssString}" } object MethodApp { @@ -97,7 +101,7 @@ object MethodApp { } } final case class SimpleMethodApp protected(self: Rep, mtd: MethodSymbol, targs: List[TypeRep], argss: ArgumentLists)(val typ: TypeRep) - extends MethodApp with CachedHashCode + extends MethodApp with CachedHashCode { doChecks } final case class App(fun: Rep, arg: Rep)(implicit base: FastANF) extends Def with MethodApp { val self = fun @@ -105,6 +109,7 @@ final case class App(fun: Rep, arg: Rep)(implicit base: FastANF) extends Def wit def targs = Nil def argss = arg lazy val typ = fun.typ.asFunType.map(_._2).getOrElse(lastWords(s"Application on a non-function type `${fun.typ}`")) + doChecks } /** To avoid useless wrappers/boxing in the most common cases, we have this scheme for argument lists: @@ -147,46 +152,62 @@ sealed trait ArgumentLists extends CachedHashCode { } } - def map(f: Rep => Rep): ArgumentLists + def argssMap(f: Rep => Rep): ArgumentLists + def argssList: List[Rep] // TODO use more efficient structure to accumulate args (with constant-time ++ and +:); also, make these lazy vals in non-trivial argument lists? } final case object NoArgumentLists extends ArgumentLists { override def ~~: (as: ArgumentList): ArgumentLists = as - def map(f: Rep => Rep) = this + def argssMap(f: Rep => Rep) = this + def argssList: List[Rep] = Nil } sealed trait ArgumentList extends ArgumentLists { def ~: (a: Rep): ArgumentList = ArgumentCons(a, this) - def map(f: Rep => Rep): ArgumentList + def argssMap(f: Rep => Rep): ArgumentList } final case object NoArguments extends ArgumentList { override def ~: (a: Rep): ArgumentList = a - def map(f: Rep => Rep) = this + def argssMap(f: Rep => Rep) = this + def argssList: List[Rep] = Nil } // Q: can make extend AnyVal? requires making all upper traits universal) final case class SplicedArgument(arg: Rep) extends ArgumentList { - def map(f: Rep => Rep) = SplicedArgument(f(arg)) + def argssMap(f: Rep => Rep) = SplicedArgument(f(arg)) + def argssList: List[Rep] = arg :: Nil } final case class ArgumentCons(head: Rep, tail: ArgumentList) extends ArgumentList { - def map(f: Rep => Rep) = ArgumentCons(f(head), tail map f) + def argssMap(f: Rep => Rep) = ArgumentCons(f(head), tail argssMap f) + def argssList: List[Rep] = head :: tail.argssList } final case class ArgumentListCons(head: ArgumentList, tail: ArgumentLists) extends ArgumentLists { - def map(f: Rep => Rep) = ArgumentListCons(head map f, tail map f) + def argssMap(f: Rep => Rep) = ArgumentListCons(head argssMap f, tail argssMap f) + def argssList: List[Rep] = head.argssList ++ tail.argssList } trait Binding extends SymbolParent { - val name: String + def name: String def bound: Symbol def boundType: TypeRep } -class LetBinding(val name: String, val bound: Symbol, val value: Def, private var _body: Rep) extends Rep with Binding { +trait RebindableBinding extends Binding { + def bound_= (newBound: Symbol): Unit + def name_= (newName: String): Unit +} +class LetBinding(var name: String, var bound: Symbol, var value: Def, private var _body: Rep) extends Rep with RebindableBinding { def body = _body def body_= (newBody: Rep) = _body = newBody def boundType = value.typ def typ = body.typ + /** Returns the last let-bindings of this conceptual block of code. + * Caution: this has linear complexity */ + def last: LetBinding = body match { + case lb: LetBinding => lb.last + case _ => this + } override def toString: String = s"val $bound = $value; $body" } -class Lambda(val name: String, val bound: Symbol, val boundType: TypeRep, val body: Rep)(implicit base: FastANF) extends Def with Binding { +class Lambda(var name: String, var bound: Symbol, val boundType: TypeRep, var body: Rep)(implicit base: FastANF) extends Def with RebindableBinding { val typ: TypeRep = base.funType(boundType, body.typ) override def toString: String = s"($bound: $boundType) => $body" } @@ -211,13 +232,18 @@ abstract class Symbol extends Rep with SymbolParent { def representative: Symbol = owner.bound def owner: Binding = _parent match { case bnd: Binding => - assert(bnd.bound eq this) + //assert(bnd.bound eq this) // FIXME? still seems to crash; is the assertion correct? bnd case parent: Symbol => val bnd = parent.owner _parent = bnd bnd } + def owner_=(bnd: RebindableBinding) = { + // TODO add appropriate assertions...? + bnd.bound = this + _parent = bnd + } def typ = owner.boundType def dfn: DefOption = owner match { case bnd: LetBinding => bnd.value @@ -229,6 +255,11 @@ abstract class Symbol extends Rep with SymbolParent { case s: Symbol => s.representative eq representative case _ => false } - override def toString: String = s"${owner.name}#${System.identityHashCode(representative)}" + //override def toString: String = s"${owner.name}#${System.identityHashCode(representative)}" + /* below is for easier debugging -- revert to the above for better perf */ + val id = Symbol.curId alsoDo (Symbol.curId += 1); override def toString: String = s"${owner.name}_${representative.id}" +} +object Symbol { + private var curId = 0 } From f7ab7b587c102527c9f26e287abc566083756c76 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 29 Oct 2017 14:39:19 +0100 Subject: [PATCH 05/66] HOPHole implementation --- src/main/scala/squid/ir/fastanf/FastANF.scala | 477 +++++++++--------- src/main/scala/squid/ir/fastanf/Rep.scala | 2 +- .../scala/squid/ir/fastir/BasicTests.scala | 32 +- .../fastir/HigherOrderPatternVariables.scala | 73 +++ .../squid/ir/fastir/RewritingTests.scala | 85 ++-- 5 files changed, 359 insertions(+), 310 deletions(-) create mode 100644 src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 774ccb21..e1853983 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -123,10 +123,12 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { case lb: LetBinding => // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding val last = lb.last + val boundName = bound.name bound rebind last.bound last.body = body - last.name = bound.name // TODO make sure we're only renaming an automatically-named binding? + last.name = boundName // TODO make sure we're only renaming an automatically-named binding? lb + // case c: Constant => bottomUpPartial(body) { case `bound` => c } case (_:HOPHole) | (_:Hole) | (_:SplicedHole) => ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes case _ => @@ -143,6 +145,16 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { try k finally curSub = oldSub } + override def tryInline(fun: Rep, arg: Rep)(retTp: TypeRep): Rep = { + fun match { + case lb: LetBinding => lb.value match { + case l: Lambda => letin(l.bound, arg, l.body, l.body.typ) + case _ => super.tryInline(fun, arg)(retTp) + } + case _ => super.tryInline(fun, arg)(retTp) + } + } + override def ascribe(self: Rep, typ: TypeRep): Rep = if (self.typ =:= typ) self else self match { case Ascribe(trueSelf, _) => Ascribe(trueSelf, typ) // Hopefully Scala's subtyping is transitive! case _ => Ascribe(self, typ) @@ -159,32 +171,8 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { // /** Artifact of a term extraction: map from hole name to terms, types and spliced term lists */ - def repEq(a: Rep, b: Rep): Boolean = { - (a extractRep b) === (b extractRep a) && (a extractRep b) === Some(EmptyExtract) && (b extractRep a) === Some(EmptyExtract) - - //val aExtractB = a extractRep b - // - //if (aExtractB.isEmpty) false - //else { - // (aExtractB, b extractRep a) match { - // case (Some((xs, xts, fxs)), Some((ys, yts, fys))) => - // val extractsHole: (String -> Rep) => Boolean = { - // case (k: String, Hole(n, _)) if k == n => true - // case _ => false - // } - // - // val extractsTypeHole: (String -> TypeRep) => Boolean = { - // case (k: String, DummyTypeRep) => true - // case _ => false - // } - // - // fxs.isEmpty && fys.isEmpty && - // (xs forall extractsHole) && (ys forall extractsHole) && - // (xts forall extractsTypeHole) && (yts forall extractsTypeHole) - // case _ => false - // } - //} - } + def repEq(a: Rep, b: Rep): Boolean = + (a extractRep b) === Some(EmptyExtract) && (b extractRep a) === Some(EmptyExtract) // * --- * --- * --- * Implementations of `IntermediateBase` methods * --- * --- * --- * @@ -219,11 +207,11 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { case Hole(n, t) => newBase.hole(n, reinterpretType(t)) case SplicedHole(n, t) => newBase.splicedHole(n, reinterpretType(t)) case Ascribe(s, t) => newBase.ascribe(go(s), reinterpretType(t)) - case HOPHole(n, t, yes, no) => newBase.hopHole( + case HOPHole(n, t, args, visible) => newBase.hopHole( n, reinterpretType(t), - yes.map(_.map(reinterpretBV)), - no.map(reinterpretBV)) + args.map(_.map(reinterpretBV)), + visible.map(reinterpretBV)) case Module(p, n, t) => newBase.module(go(p), n, reinterpretType(t)) case lb: LetBinding => newBase.letin( reinterpretBV(lb.bound), @@ -244,14 +232,14 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def bottomUp(r: Rep)(f: Rep => Rep): Rep = transformRepAndDef(r)(identity, f)(identity) def topDown(r: Rep)(f: Rep => Rep): Rep = transformRepAndDef(r)(f)(identity) def transformRepAndDef(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity)(preDef: Def => Def, postDef: Def => Def = identity): Rep = { - def _transformRepAndDef(r: Rep) = transformRepAndDef(r)(pre, post)(preDef, postDef) + def transformRepAndDef0(r: Rep) = transformRepAndDef(r)(pre, post)(preDef, postDef) def transformDef(d: Def): Def = (d map preDef match { - case App(f, a) => App(_transformRepAndDef(f), _transformRepAndDef(a)) // Note: App is a MethodApp, but we can transform it more efficiently this way - case ma: MethodApp => MethodApp(_transformRepAndDef(ma.self), ma.mtd, ma.targs, ma.argss argssMap (_transformRepAndDef(_)), ma.typ) + case App(f, a) => App(transformRepAndDef0(f), transformRepAndDef0(a)) // Note: App is a MethodApp, but we can transform it more efficiently this way + case ma: MethodApp => MethodApp(transformRepAndDef0(ma.self), ma.mtd, ma.targs, ma.argss argssMap (transformRepAndDef0(_)), ma.typ) case l: Lambda => // Note: destructive modification of the lambda binding! - //new Lambda(l.name, l.bound, l.boundType, _transformRepAndDef(l.body)) - l.body = l.body |> _transformRepAndDef + //new Lambda(l.name, l.bound, l.boundType, transformRepAndDef0(l.body)) + l.body = l.body |> transformRepAndDef0 l }) map postDef @@ -261,15 +249,15 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { // lb.name, // lb.bound, // transformDef(lb.value), - // _transformRepAndDef(lb.body) + // transformRepAndDef0(lb.body) //) lb.value = lb.value |> transformDef - lb.body = lb.body |> _transformRepAndDef + lb.body = lb.body |> transformRepAndDef0 lb case Ascribe(s, t) => - Ascribe(_transformRepAndDef(s), t) + Ascribe(transformRepAndDef0(s), t) case Module(p, n, t) => - Module(_transformRepAndDef(p), n, t) + Module(transformRepAndDef0(p), n, t) case r @ ((_:Constant) | (_:Hole) | (_:Symbol) | (_:SplicedHole) | (_:HOPHole) | (_:NewObject) | (_:StaticModule)) => r }) } @@ -279,64 +267,107 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = extractWithCtx(xtor, xtee)(ListMap.empty) - def extractWithCtx(xtor: Rep, xtee: Rep)(implicit ctx: ListMap[BoundVal, BoundVal]): Option[Extract] = xtor -> xtee match { - case (lb1: LetBinding, lb2: LetBinding) => - val normal = for { - //e1 <- extractWithCtx(lb1.bound, lb2.bound) - e1 <- lb1.boundType extract (lb1.boundType, Covariant) - e2 <- extractValue(lb1.value, lb2.value) - e3 <- extractWithCtx(lb1.body, lb2.body)(ctx + (lb1.bound -> lb2.bound)) - m <- mergeAll(e1, e2, e3) - } yield m - - /* - * For instance when: - * xtor: val x0 = List(Hole(...): _*) - * xtee: val x0 = Seq(1, 2, 3); val x1 = List(x0) - */ - // TODO CHECK CTX - lazy val lookFurtherInXtee = extractWithCtx(lb1, lb2.body)(ctx + (lb1.bound -> lb2.bound)) - lazy val lookFurtherInXtor = extractWithCtx(lb1.body, lb2)(ctx + (lb1.bound -> lb2.bound)) - - normal //orElse lookFurtherInXtee orElse lookFurtherInXtor - - // Matches 42 and 42: Any, is it safe to ignore the typ? - case (_, Ascribe(s, _)) => extractWithCtx(xtor, s) - - case (Ascribe(s, t) , _) => - for { - e1 <- t extract (xtee.typ, Covariant) // t <:< a.typ, which one to use? - e2 <- extractWithCtx(s, xtee) - m <- merge(e1, e2) - } yield m - - case (Hole(n, t), _) => - for { - e <- t extract (xtee.typ, Covariant) - m <- merge(e, repExtract(n -> xtee)) - } yield m - - case (HOPHole(name, typ, args, visible), _) => unsupported - - case (bv1: BoundVal, bv2: BoundVal) => - if (bv1 == bv2) Some(EmptyExtract) - else for { - candidate <- ctx.get(bv1) - if candidate == bv2 - } yield EmptyExtract - - case (Constant(v1), Constant(v2)) if v1 == v2 => - xtor.typ extract (xtee.typ, Covariant) - - // Assuming if they have the same name the type is the same - case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Some(EmptyExtract) - - // Assuming if they have the same name and prefix the type is the same - case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithCtx(p1, p2) - - case (NewObject(t1), NewObject(t2)) => t1 extract (t2, Covariant) - - case _ => None + def extractWithCtx(xtor: Rep, xtee: Rep)(implicit ctx: ListMap[BoundVal, Set[BoundVal]]): Option[Extract] = { + def reverse[A, B](m: Map[A, Set[B]]): Map[B, A] = for { + (a, bs) <- m + b <- bs + } yield b -> a + + xtor -> xtee match { + case (lb1: LetBinding, lb2: LetBinding) => + val normal = for { + //e1 <- extractWithCtx(lb1.bound, lb2.bound) + e1 <- lb1.boundType extract (lb1.boundType, Covariant) + e2 <- extractValue(lb1.value, lb2.value) + e3 <- extractWithCtx(lb1.body, lb2.body)(ctx + (lb1.bound -> Set(lb2.bound))) + m <- mergeAll(e1, e2, e3) + } yield m + + /* + * For instance when: + * xtor: val x0 = List(Hole(...): _*) + * xtee: val x0 = Seq(1, 2, 3); val x1 = List(x0) + */ + lazy val lookFurtherInXtee = extractWithCtx(lb1, lb2.body) + // lazy val lookFurtherInXtor = extractWithCtx(lb1.body, lb2) + + normal orElse lookFurtherInXtee //orElse lookFurtherInXtor + + // Matches 42 and 42: Any, is it safe to ignore the typ? + case (_, Ascribe(s, _)) => extractWithCtx(xtor, s) + + case (Ascribe(s, t) , _) => + for { + e1 <- t extract (xtee.typ, Covariant) // t <:< a.typ, which one to use? + e2 <- extractWithCtx(s, xtee) + m <- merge(e1, e2) + } yield m + + case (Hole(n, t), bv: BoundVal) => + val r = bv.owner match { + case lb: LetBinding => + new LetBinding(lb.name, lb.bound, lb.value, lb.bound) + case _ => bv + } + + for { + e <- t extract (xtee.typ, Covariant) + m <- merge(e, repExtract(n -> r)) + } yield m + + case (Hole(n, t), _) => + for { + e <- t extract (xtee.typ, Covariant) + m <- merge(e, repExtract(n -> xtee)) + } yield m + + case (h @ HOPHole(name, typ, argss, visible), _) => + type Func = List[List[BoundVal]] -> Rep + def emptyFunc(r: Rep) = List.empty[List[BoundVal]] -> r + def fargss(f: Func) = f._1 + def fbody(f: Func) = f._2 + + val ctx0 = ctx.mapValues(_.head) + val invCtx = reverse(ctx) + + bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx(bv) => return None } + + def extendFunc(args: List[Rep], f: Func): Func = { + val args0 = args.map(bottomUpPartial(_) { case bv: BoundVal => ctx0.getOrElse(bv, bv) }) + val xs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) + val transformation = (args0 zip xs).toMap + val body0 = bottomUp(fbody(f)) { case r => transformation.getOrElse(r, r) } + (xs :: fargss(f)) -> body0 + } + + for { + e1 <- typ extract (xtee.typ, Covariant) + f = argss.foldRight(emptyFunc(xtee))(extendFunc) + l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } + e2 = repExtract(name -> l) + m <- merge(e1, e2) + } yield m + + case (bv1: BoundVal, bv2: BoundVal) => + if (bv1 == bv2) Some(EmptyExtract) + else for { + candidates <- ctx.get(bv1) + if candidates contains bv2 + } yield EmptyExtract + + case (Constant(v1), Constant(v2)) if v1 == v2 => + xtor.typ extract (xtee.typ, Covariant) + + // Assuming if they have the same name the type is the same + case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Some(EmptyExtract) + + // Assuming if they have the same name and prefix the type is the same + case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithCtx(p1, p2) + + case (NewObject(t1), NewObject(t2)) => t1 extract (t2, Covariant) + + case _ => None + } } protected def spliceExtract(xtor: Rep, args: Args): Option[Extract] = xtor match { @@ -359,23 +390,21 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = { - def go(ex: Extract, matchedBVs: Set[BoundVal])(_xtor: Rep, _xtee: Rep)(implicit ctx: ListMap[BoundVal, BoundVal]): Option[Rep] = { - println(s"XTEE ->> ${_xtee}") - + def rewriteRep0(ex: Extract, matchedBVs: Set[BoundVal])(xtor: Rep, xtee: Rep)(implicit ctx: ListMap[BoundVal, Set[BoundVal]]): Option[Rep] = { def checkRefs(r: Rep): Option[Rep] = { def refs(r: Rep): Set[BoundVal] = { def bvsUsed(value: Def): Set[BoundVal] = value match { case ma: MethodApp => def bvsInArgss(argss: ArgumentLists): Set[BoundVal] = { - def go(argss: ArgumentLists, acc: Set[BoundVal]): Set[BoundVal] = argss match { - case ArgumentListCons(h, t) => go(t, go(h, acc)) - case ArgumentCons(h, t) => go(t, go(h, acc)) + def bvsInArgss0(argss: ArgumentLists, acc: Set[BoundVal]): Set[BoundVal] = argss match { + case ArgumentListCons(h, t) => bvsInArgss0(t, bvsInArgss0(h, acc)) + case ArgumentCons(h, t) => bvsInArgss0(t, bvsInArgss0(h, acc)) case SplicedArgument(bv: BoundVal) => acc + bv case bv: BoundVal => acc + bv case _ => acc } - go(argss, Set.empty) + bvsInArgss0(argss, Set.empty) } val selfBV = ma.self match { @@ -408,137 +437,80 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def mkCode(e: Extract): Option[Rep] = { for { c <- code(e) - c <- checkRefs(c) +// c <- checkRefs(c) } yield c } - def letIn(body: Rep)(xBV: BoundVal, newX: Rep): Rep = { - println(s"----- In $body replacing $xBV with $newX") - def findAndReplace(argss: ArgumentLists, r: BoundVal, newR: BoundVal): ArgumentLists = argss.argssMap { - case a if a == r => println(s"### $a -> $newR"); newR - case a => println(s"+++ $a"); a - } + // TODO function name? + def traverseO[A](r: Rep)(f: LetBinding => Option[A]): Option[A] = r match { + case lb: LetBinding => f(lb) orElse traverseO(lb.body)(f) + case _ => None + } - // In `body` replace every occurrence of `x` with `outerLB` - def _letIn(body: Rep)(xBV: BoundVal, outerLB: LetBinding): Rep = { - def replaceInValue(v: Def): Def = { - v match { - case ma: MethodApp => - println(s"@@@@ $ma") - MethodApp( - if (ma.self == xBV) outerLB.bound else ma.self, - ma.mtd, - ma.targs, - findAndReplace(ma.argss, xBV, outerLB.bound), - ma.typ) - case l: Lambda => - new Lambda(l.name, - l.bound, - l.boundType, - _letIn(l.body)(xBV, outerLB)) + def extractValueWithBV(v: Def, r: Rep): Option[Extract -> Set[BoundVal]] = traverseO(r) { lb => + def getBVs(e: Extract): Set[BoundVal] = { + val a = e._1.values.foldLeft(Set.empty[BoundVal]) { case (acc, r) => + r match { + case lb: LetBinding => acc + lb.bound + case _ => acc } } - body match { - case lb: LetBinding if lb.bound == xBV => new LetBinding(outerLB.name, outerLB.bound, outerLB.value, _letIn(lb.body)(xBV, outerLB)) - - case lb: LetBinding => - new LetBinding( - lb.name, - lb.bound, - replaceInValue(lb.value), - _letIn(lb.body)(xBV, outerLB)) - - case `xBV` => outerLB.bound - case _ => body + val b = e._3.values.flatten.foldLeft(Set.empty[BoundVal]) { case (acc, r) => + r match { + case lb: LetBinding => acc + lb.bound + case _ => acc + } } - } - - newX match { - case outerLB: LetBinding => - //xBV rebind outerLB.bound - //body - _letIn(removeBVs(body, matchedBVs - xBV))(xBV, outerLB) alsoApply println - - case bv: BoundVal => - xBV rebind bv - body - case _ => ??? //_letIn(removeBVs(body, matchedBVs - xBV))(xBV, newX) + a ++ b } - } - - def withBV(r: Rep)(f: LetBinding => Option[Extract]): Option[Extract -> BoundVal] = { -// def extractArgss(f: A => Option[Extract])(argss: ArgumentLists): Option[Extract -> BoundVal] = argss match { -// case ArgumentListCons(h, t) => extractArgss(f)(h) orElse extractArgss(f)(t) -// case ArgumentCons(h, t) => extractArgss(f)(h) orElse extractArgss(f)(t) -// case SplicedArgument(a) => withBV(f)(a) -// case r: Rep => withBV(f)(r) -// case NoArguments | NoArgumentLists => None -// } - r match { - case lb: LetBinding => - f(lb) map { _ -> lb.bound} orElse { - lb.value match { - case l: Lambda => withBV(l.body)(f) - case _ => None - } - } orElse withBV(lb.body)(f) - - case _ => None - } + extractValue(v, lb.value) map (e => e -> (getBVs(e) + lb.bound)) } - def extractValueWithBV(v: Def, r: Rep): Option[Extract -> BoundVal] = withBV(r) { lb => extractValue(v, lb.value) } - def extractBVWithBV(bv: BoundVal, r: Rep): Option[Extract -> BoundVal] = withBV(r) { lb => extractWithCtx(bv, lb.bound) } - - def removeBVs(r: Rep, bvs: Set[BoundVal]): Rep = { - def lambdaRemove(l: Lambda): Def = { - // TODO check l.bound? - new Lambda(l.name, l.bound, l.boundType, removeBVs(l.body, bvs)) - } - - r match { - case lb: LetBinding if bvs contains lb.bound => lb.body - case lb: LetBinding => lb.value match { - case l: Lambda => new LetBinding(lb.name, lb.bound, lambdaRemove(l), removeBVs(lb.body, bvs)) - case _ => new LetBinding(lb.name, lb.bound, lb.value, removeBVs(lb.body, bvs)) - } - case _ => r - } + def removeBVs(r: Rep, bvs: Set[BoundVal]): Rep = r match { + case lb: LetBinding if bvs contains lb.bound => removeBVs(lb.body, bvs) + case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, removeBVs(lb.body, bvs)) + case _ => r } - (_xtor, _xtee) match { + (xtor, xtee) match { case (lb1: LetBinding, lb2: LetBinding) => for { - (e, matchedBV) <- extractValueWithBV(lb1.value, lb2) - _ = println(s"----- Matched $matchedBV") - if !(matchedBVs contains matchedBV) // Cannot match twice the same code line + (e, newBVs) <- extractValueWithBV(lb1.value, lb2) m <- merge(e, ex) - r <- go(m, matchedBVs + matchedBV)(lb1.body, lb2)(ctx + (lb1.bound -> matchedBV)) + removed = removeBVs(lb2, newBVs) + r <- rewriteRep0(m, matchedBVs ++ newBVs)(lb1.body, removed)(ctx + (lb1.bound -> newBVs)) } yield r // Match return of `xtor` with something in `xtee` - case (bv: BoundVal, lb: LetBinding) => - println(s"Knowing: $ctx") - println(s"Matching... $bv in $lb") + case (bv: BoundVal, _: LetBinding) => for { - (e, matchedBV) <- extractBVWithBV(bv, lb) - _ = println(s"----->> Matched $matchedBV") + x <- ctx.get(bv) + newX <- mkCode(ex) + _ = assert(x.size == 1) + xHead = x.head + newLB = newX match { + case _: LetBinding => letin(xHead, newX, xtee, xtee.typ) + case _ => bottomUpPartial(xtee) { case `xHead` => newX } + } + } yield newLB + + case (h: HOPHole, lb: LetBinding) => + for { + e <- extractWithCtx(h, lb) m <- merge(e, ex) newX <- mkCode(m) - _ = println(s"Code => $newX") - newC = letIn(lb)(matchedBV, newX) alsoApply (c => println(s"New code => $c")) - } yield newC + } yield newX - case (bv: BoundVal, xtee: Rep) => + case (h: Hole, lb: LetBinding) => for { - (e, matchedBV) <- extractBVWithBV(bv, xtee) + e <- extractWithCtx(h, lb) m <- merge(e, ex) newX <- mkCode(m) - newC = letIn(xtee)(matchedBV, newX) - } yield newC + // _ = lb.body = newX + } yield newX // Match Constant(42) with `value` of the `LetBinding` case (xtor: Rep, lb: LetBinding) => @@ -615,9 +587,9 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { rewrittenValue map { case value -> lbs => new LetBinding(lb.name, lb.bound, value, lb.body) -> lbs} } - def go(lb: LetBinding, acc: Option[List[LazyInlined[Rep]]]): Option[List[LazyInlined[Rep]]] = lb.body match { + def rewriteValue0(lb: LetBinding, acc: Option[List[LazyInlined[Rep]]]): Option[List[LazyInlined[Rep]]] = lb.body match { case innerLB: LetBinding => - go(innerLB, + rewriteValue0(innerLB, for { acc <- acc lb <- rewriteValue(lb) @@ -630,65 +602,70 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } yield lazyInlined(ret) :: lb :: acc } - go(lb, Some(List.empty)) map { + rewriteValue0(lb, Some(List.empty)) map { case ret :: outerLBs => run(outerLBs.foldLeft(ret)(surroundWithLB)) + case _ => lb } } rewrite(lb) - // Matching return values - case (r1: Rep, r2: Rep) => + case (_: Rep, _: Rep) => for { - e <- extractWithCtx(r1, r2) + e <- extractWithCtx(xtor, xtee) m <- merge(e, ex) c <- mkCode(m) } yield c + + // Matching return values + case _ => None } } - go(EmptyExtract, Set.empty)(xtor, xtee)(ListMap.empty) + rewriteRep0(EmptyExtract, Set.empty)(xtor, xtee)(ListMap.empty) } - def extractValue(v1: Def, v2: Def)(implicit ctx: ListMap[BoundVal, BoundVal]) = (v1, v2) match { - case (l1: Lambda, l2: Lambda) => - for { - e1 <- l1.boundType extract (l2.boundType, Covariant) - e2 <- extractWithCtx(l1.body, l2.body)(ctx + (l1.bound -> l2.bound)) - m <- merge(e1, e2) - } yield m - - case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => - lazy val targExtract = mergeAll(for { - (e1, e2) <- ma1.targs zip ma2.targs - } yield e1 extract (e2, Invariant)) // TODO Invariant? Depends on its positions... - - def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists): Option[Extract] = { - def go(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep]): Option[Extract] = (argss1, argss2) match { - case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => mergeOpt(go(h1, h2, acc), go(t1, t2, acc)) - case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => mergeOpt(go(h1, h2, acc), go(t1, t2, acc)) - case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithCtx(arg1, arg2) - case (sa: SplicedArgument, ArgumentCons(h, t)) => go(sa, t, h :: acc) - case (sa: SplicedArgument, r: Rep) => go(sa, NoArguments, r :: acc) - case (SplicedArgument(arg), NoArguments) => spliceExtract(arg, Args(acc.reverse: _*)) // Reverses list... - case (r1: Rep, r2: Rep) => extractWithCtx(r1, r2) - case (NoArguments, NoArguments) => Some(EmptyExtract) - case (NoArgumentLists, NoArgumentLists) => Some(EmptyExtract) - case _ => None - } + def extractValue(v1: Def, v2: Def)(implicit ctx: ListMap[BoundVal, Set[BoundVal]]) = { + (v1, v2) match { + case (l1: Lambda, l2: Lambda) => + for { + e1 <- l1.boundType extract (l2.boundType, Covariant) + e2 <- extractWithCtx(l1.body, l2.body)(ctx + (l1.bound -> Set(l2.bound))) + m <- merge(e1, e2) + } yield m + + case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => + lazy val targExtract = mergeAll(for { + (e1, e2) <- ma1.targs zip ma2.targs + } yield e1 extract (e2, Invariant)) // TODO Invariant? Depends on its positions... + + def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists): Option[Extract] = { + def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep]): Option[Extract] = (argss1, argss2) match { + case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => mergeOpt(extractArgss0(h1, h2, acc), extractArgss0(t1, t2, acc)) + case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => mergeOpt(extractArgss0(h1, h2, acc), extractArgss0(t1, t2, acc)) + case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithCtx(arg1, arg2) + case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) + case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) + case (SplicedArgument(arg), NoArguments) => spliceExtract(arg, Args(acc.reverse: _*)) // Reverses list... + case (r1: Rep, r2: Rep) => extractWithCtx(r1, r2) + case (NoArguments, NoArguments) => Some(EmptyExtract) + case (NoArgumentLists, NoArgumentLists) => Some(EmptyExtract) + case _ => None + } - go(argss1, argss2, Nil) - } + extractArgss0(argss1, argss2, Nil) + } - for { - e1 <- extractWithCtx(ma1.self, ma2.self) - e2 <- targExtract - e3 <- extractArgss(ma1.argss, ma2.argss) - e4 <- ma1.typ extract (ma2.typ, Covariant) - m <- mergeAll(e1, e2, e3, e4) - } yield m + for { + e1 <- extractWithCtx(ma1.self, ma2.self) + e2 <- targExtract + e3 <- extractArgss(ma1.argss, ma2.argss) + e4 <- ma1.typ extract (ma2.typ, Covariant) + m <- mergeAll(e1, e2, e3, e4) + } yield m - case _ => None + case _ => None + } } // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index ecefb186..0b136a66 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -209,7 +209,7 @@ class LetBinding(var name: String, var bound: Symbol, var value: Def, private va } class Lambda(var name: String, var bound: Symbol, val boundType: TypeRep, var body: Rep)(implicit base: FastANF) extends Def with RebindableBinding { val typ: TypeRep = base.funType(boundType, body.typ) - override def toString: String = s"($bound: $boundType) => $body" + override def toString: String = s"($bound: $boundType) => $body --" } /** Currently used mainly for reification. */ // Note: could intern these objects diff --git a/src/test/scala/squid/ir/fastir/BasicTests.scala b/src/test/scala/squid/ir/fastir/BasicTests.scala index 0796cbe1..83001da2 100644 --- a/src/test/scala/squid/ir/fastir/BasicTests.scala +++ b/src/test/scala/squid/ir/fastir/BasicTests.scala @@ -73,20 +73,28 @@ class BasicTests extends MyFunSuiteBase(BasicTests.Embedding) { } test("Transformers") { - object Tr extends BasicTests.Embedding.SelfTransformer with SimpleRuleBasedTransformer with TopDownTransformer { - rewrite { - case ir"123" => ir"111" - case ir"readInt" => ir"42" - } - } - - val p = - BasicTests.Embedding.debugFor { - ir"readInt+123" alsoApply println transformWith Tr alsoApply println - } + //object Tr extends BasicTests.Embedding.SelfTransformer with SimpleRuleBasedTransformer with TopDownTransformer { + // rewrite { + // case ir"123" => ir"readInt + 5" + // case ir"readDouble" => ir"10.0" + // } + //} + // + //val p = + // BasicTests.Embedding.debugFor { + // ir"123+readDouble+123" alsoApply println transformWith Tr alsoApply println + // } + } - assert(p =~= ir"${ir"42"}+123") + test("Rewriting simple expressions only once") { + val a = ir"println((50,60))" + val b = a rewrite { + case ir"($x:Int,$y:Int)" => + ir"($y:Int,$x:Int)" + case ir"(${Const(n)}:Int)" => Const(n+1) + } + assert(b =~= ir"println((61,51))") } test("Extract") { diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala new file mode 100644 index 00000000..80123d5b --- /dev/null +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -0,0 +1,73 @@ +package squid +package ir +package fastir + +import utils._ + +class HigherOrderPatternVariables extends MyFunSuite { + import TestDSL.Predef._ + + test("Matching lambda bodies") { + + val id = ir"(z:Int) => z" + + ir"(a: Int) => a + 1" matches { + case ir"(x: Int) => $body(x):Int" => + body eqt ir"(_:Int)+1" + } and { + case ir"(x: Int) => $body(x):$t" => + body eqt ir"(_:Int)+1" + eqt(t, irTypeOf[Int]) + } and { + case ir"(x: Int) => ($exp(x):Int)+1" => + exp eqt id + } + + ir"(a: Int, b: Int) => a + 1" matches { + case ir"(x: Int, y: Int) => $body(y):Int" => fail + case ir"(x: Int, y: Int) => $body(x):Int" => + } and { + case ir"(x: Int, y: Int) => $body(x,y):Int" => + } + + ir"(a: Int, b: Int) => a + b" matches { + case ir"(x: Int, y: Int) => $body(x):Int" => fail + case ir"(x: Int, y: Int) => $body(x,y):Int" => + body eqt ir"(_:Int)+(_:Int)" + } and { + case ir"(x: Int, y: Int) => ($lhs(y):Int)+($rhs(y):Int)" => fail + case ir"(x: Int, y: Int) => ($lhs(x):Int)+($rhs(y):Int)" => + lhs eqt id + rhs eqt id + } + + } + + test("Matching let-binding bodies") { + + ir"val a = 0; val b = 1; a + b" matches { + case ir"val x: Int = $v; $body(x):Int" => + v eqt ir"0" + body matches { + case ir"(y:Int) => { val x: Int = $w; $body(x,y):Int }" => + w eqt ir"1" + body eqt ir"(u:Int,v:Int) => (v:Int)+(u:Int)" + } + } + } + + test("Find a good name") { + val a = ir"val a = 10.toDouble; val b = a + 1; b" rewrite { + case ir"val x = 10.toDouble; $body(x):Double" => body(ir"readDouble") + } + assert(a == ir"val r = readDouble + 1; r") + + // val b = ir"val a = 10; val b = readInt; val c = a + b; c" rewrite { + // case ir"val x = 10; val y = readInt; $body(x): Int" => ir"$body(42)" + // // case ir"val x = 10.toDouble; $body(x):Double" => body(ir"readDouble") + // } alsoApply println + // assert(b == ir"readDouble.toInt + 42") + // assert(b == ir"val r = readDouble + 1; r") + } + +} diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index d512dbad..bcb4e92a 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -5,64 +5,55 @@ package fastir class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { import RewritingTests.Embedding.Predef._ - test("Worksheet") { - val a = ir"readDouble.toInt" alsoApply println - - val b = a rewrite { - case ir"readDouble.toInt" => - ir"readInt" - } alsoApply println - - - } + test("Simple rewrites") { + val a = ir"123" rewrite { + case ir"123" => ir"666" + } + assert(a =~= ir"666") - test("Rewriting simple expressions only once") { - val a = ir"println((50,60))" - val b = a rewrite { - case ir"($x:Int,$y:Int)" => - ir"($y:Int,$x:Int)" - case ir"(${Const(n)}:Int)" => Const(n+1) + val b = ir"42.toFloat" rewrite { + case ir"42.toFloat" => ir"42f" } + assert(b =~= ir"42f") - assert(b =~= ir"println((61,51))") + val c = ir"42.toDouble" rewrite { + case ir"(${Const(n)}: Int).toDouble" => ir"${Const(n.toDouble)}" + } + assert(c =~= ir"42.0") } - test("Rewriting Sequences of Bindings") { - - val a = ir"val aaa = readInt; val bbb = readDouble.toInt; aaa+bbb" alsoApply println // FIXMElater: why isn't name 'bbb' conserved? - val b = a rewrite { - case ir"readDouble.toInt" => - ir"readInt" - } alsoApply println + test("Rewriting subpatterns") { + val a = ir"(readInt + 111) * .5" rewrite { + case ir"(($n: Int) + 111) * .5" => ir"$n * .25" + } + assert(a =~= ir"readInt * .25") - ir"val aa = readInt; val bb = readInt; aa+bb" alsoApply println + val b = ir"(x: Int) => (x-5) * 32" rewrite { + case ir"($b: Int) * 32" => ir"$b" + } + assert(b =~= ir"(x: Int) => x - 5") - assert(b =~= ir"val aa = readInt; val bb = readInt; aa+bb") + val c = ir"Option(42).get" rewrite { + case ir"Option(($n: Int)).get" => n + } + assert(c =~= ir"42") -// { -// val a = ir"val a = 11.toDouble; val b = 22.toDouble; val c = 33.toDouble; (a,b,c)" -// val c = a rewrite { -// case ir"val a = ($x:Int).toDouble; val b = ($y:Int).toDouble; $body: $bt" => -// ir"val a = ($x+$y).toDouble/2; val b = a; $body" -// } -// assert(c =~= ir"val a = (11+${ir"22"}).toDouble/2; val c = 33.toDouble; (a,a,c)") -// -// /* -// // TODO make this work like the one above -// val d = a rewrite { -// case ir"val a = ($x:Int).toDouble; ($y:Int).toDouble" => -// ir"($x+$y).toDouble/2" -// } -// println(d) -// //d eqt ir"val a = (11+${ir"22"}).toDouble/2; val c = 33.toDouble; (a,a,c)" -// */ -// } -// -// + val d = ir"val a = Option(42).get; a * 5" rewrite { + case ir"Option(($n: Int)).get" => n + case ir"($a: Int) * 5" => ir"$a * 2" + } + assert(d =~= ir"val a = 42; a * 2") } - + test("Rewriting simple expressions only once") { + val a = ir"println((50, 60))" rewrite { + case ir"($x:Int,$y:Int)" => ir"($y:Int,$x:Int)" + case ir"(${Const(n)}:Int)" => Const(n+1) + } alsoApply println + assert(a =~= ir"println((61,51))") + } } + object RewritingTests { object Embedding extends FastANF } From d5d31610b2d57edc51344425977aecc51a733629 Mon Sep 17 00:00:00 2001 From: LPTK Date: Tue, 31 Oct 2017 17:39:16 +0100 Subject: [PATCH 06/66] =?UTF-8?q?Don=E2=80=99t=20inline=20all=20arguments?= =?UTF-8?q?=20(cf.=20some=20are=20by-name);=20do=20it=20in=20`substitute`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/scala/squid/ir/fastanf/FastANF.scala | 16 ++++++++-------- src/main/scala/squid/ir/fastanf/Rep.scala | 3 ++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index e1853983..d46fdbf3 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -37,15 +37,15 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { override final def wrapExtract(r: => Rep): Rep = wrap(super.wrapExtract(r), true) @inline final def currentScope = scopes.head - - def toArgumentLists(argss0: List[ArgList]): ArgumentLists = { - val argss = argss0.map(_.map(this)(inlineBlock)) // TODO optimize: avoid reconstruction of the ArgList's - + + def toArgumentLists(argss: List[ArgList]): ArgumentLists = { + // Note: some arguments may be let-bindings (ie: blocks), which is only possible if they are by-name arguments + def toArgumentList(args: Seq[Rep]): ArgumentList = args.foldRight(NoArguments: ArgumentList)(_ ~: _) def toArgumentListWithSpliced(args: Seq[Rep])(splicedArg: Rep) = args.foldRight(SplicedArgument(splicedArg): ArgumentList)(_ ~: _) - + argss.foldRight(NoArgumentLists: ArgumentLists) { (args, acc) => args match { case Args(as @ _*) => toArgumentList(as) ~~: acc @@ -54,7 +54,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } } } - + def toListOfArgList(argss: ArgumentLists): List[ArgList] = { def toArgList(args: ArgumentList): List[Rep] -> Option[Rep] = args match { case NoArguments => Nil -> None @@ -676,12 +676,12 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def hopHole(name: String, typ: TypeRep, yes: List[List[BoundVal]], no: List[BoundVal]) = HOPHole(name, typ, yes, no) def substitute(r: => Rep, defs: Map[String, Rep]): Rep = - if (defs isEmpty) r + if (defs isEmpty) r |> inlineBlock else bottomUp(r) { case h @ Hole(n, _) => defs getOrElse(n, h) case h @ SplicedHole(n, _) => defs getOrElse(n, h) case h => h - } + } |> inlineBlock // * --- * --- * --- * Implementations of `TypingBase` methods * --- * --- * --- * diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index 0b136a66..e46d9990 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -87,7 +87,8 @@ trait MethodApp extends Def { def argss: ArgumentLists def typ: TypeRep protected def doChecks = // we can't execute the checks here right away, because of initialization order - (self :: argss.argssList).foreach(r => assert(!r.isInstanceOf[LetBinding], s"Illegal ANF argument/self: $r")) + //(self :: argss.argssList).foreach(r => assert(!r.isInstanceOf[LetBinding], s"Illegal ANF argument/self: $r")) + assert(!self.isInstanceOf[LetBinding], s"Illegal ANF self argument: $self") // ^ some arguments may be by-name! override def toString = s"$self.${mtd.name}${argss.toArgssString}" } object MethodApp { From f4104178f9c5fac521442920df3621d158a7cdb8 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 29 Oct 2017 09:37:37 +0100 Subject: [PATCH 07/66] Add HOPHole2 support --- src/main/scala/squid/ir/fastanf/FastANF.scala | 130 +++++++++++------- src/main/scala/squid/ir/fastanf/Rep.scala | 2 + src/main/scala/squid/test/Test.scala | 30 ++-- .../scala/squid/ir/fastir/BasicTests.scala | 2 +- .../fastir/HigherOrderPatternVariables.scala | 107 +++++++++----- .../squid/ir/fastir/RewritingTests.scala | 42 ++++-- 6 files changed, 209 insertions(+), 104 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index d46fdbf3..effa0c55 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -129,7 +129,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { last.name = boundName // TODO make sure we're only renaming an automatically-named binding? lb // case c: Constant => bottomUpPartial(body) { case `bound` => c } - case (_:HOPHole) | (_:Hole) | (_:SplicedHole) => + case (_:HOPHole) | (_:HOPHole2) | (_:Hole) | (_:SplicedHole) => ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes case _ => withSubs(bound -> value)(body) @@ -179,25 +179,25 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def nullValue[T: IRType]: IR[T,{}] = IR[T, {}](const(null)) // FIXME: should implement proper semantics; e.g. nullValue[Int] == ir"0", not ir"null" def reinterpret(r: Rep, newBase: Base)(extrudedHandle: BoundVal => newBase.Rep): newBase.Rep = { - def go: Rep => newBase.Rep = r => reinterpret(r, newBase)(extrudedHandle) + def reinterpret0: Rep => newBase.Rep = r => reinterpret(r, newBase)(extrudedHandle) def reinterpretType: TypeRep => newBase.TypeRep = t => newBase.staticTypeApp(newBase.loadTypSymbol("scala.Any"), Nil) def reinterpretBV:BoundVal => newBase.BoundVal = bv => newBase.bindVal(bv.name, reinterpretType(bv.typ), Nil) def reinterpretTypSym(t: TypeSymbol): newBase.TypSymbol = newBase.loadTypSymbol(t.name) def reinterpretMtdSym(s: MtdSymbol): newBase.MtdSymbol = newBase.loadMtdSymbol(reinterpretTypSym(s.typ), s.name) def reinterpretArgList(argss: ArgumentLists): List[newBase.ArgList] = toListOfArgList(argss) map { - case ArgsVarargSpliced(args, varargs) => newBase.ArgsVarargSpliced(args.map(newBase)(go), go(varargs)) - case ArgsVarargs(args, varargs) => newBase.ArgsVarargs(args.map(newBase)(go), varargs.map(newBase)(go)) - case args : Args => args.map(newBase)(go) + case ArgsVarargSpliced(args, varargs) => newBase.ArgsVarargSpliced(args.map(newBase)(reinterpret0), reinterpret0(varargs)) + case ArgsVarargs(args, varargs) => newBase.ArgsVarargs(args.map(newBase)(reinterpret0), varargs.map(newBase)(reinterpret0)) + case args : Args => args.map(newBase)(reinterpret0) } def defToRep(d: Def): newBase.Rep = d match { - case app @ App(f, a) => newBase.app(go(f), go(a))(reinterpretType(app.typ)) + case app @ App(f, a) => newBase.app(reinterpret0(f), reinterpret0(a))(reinterpretType(app.typ)) case ma : MethodApp => newBase.methodApp( - go(ma.self), + reinterpret0(ma.self), reinterpretMtdSym(ma.mtd), ma.targs.map(reinterpretType), reinterpretArgList(ma.argss), reinterpretType(ma.typ)) - case l: Lambda => newBase.lambda(List(reinterpretBV(l.bound)), go(l.body)) + case l: Lambda => newBase.lambda(List(reinterpretBV(l.bound)), reinterpret0(l.body)) } r match { @@ -206,17 +206,23 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { case NewObject(t) => newBase.newObject(reinterpretType(t)) case Hole(n, t) => newBase.hole(n, reinterpretType(t)) case SplicedHole(n, t) => newBase.splicedHole(n, reinterpretType(t)) - case Ascribe(s, t) => newBase.ascribe(go(s), reinterpretType(t)) + case Ascribe(s, t) => newBase.ascribe(reinterpret0(s), reinterpretType(t)) case HOPHole(n, t, args, visible) => newBase.hopHole( n, reinterpretType(t), args.map(_.map(reinterpretBV)), visible.map(reinterpretBV)) - case Module(p, n, t) => newBase.module(go(p), n, reinterpretType(t)) + case HOPHole2(n, t, args, visible) => newBase.hopHole2( + n, + reinterpretType(t), + args.map(_.map(reinterpret0)), + visible.map(reinterpretBV) + ) + case Module(p, n, t) => newBase.module(reinterpret0(p), n, reinterpretType(t)) case lb: LetBinding => newBase.letin( reinterpretBV(lb.bound), defToRep(lb.value), - go(lb.body), + reinterpret0(lb.body), reinterpretType(lb.typ)) case s: Symbol => extrudedHandle(s) } @@ -258,7 +264,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { Ascribe(transformRepAndDef0(s), t) case Module(p, n, t) => Module(transformRepAndDef0(p), n, t) - case r @ ((_:Constant) | (_:Hole) | (_:Symbol) | (_:SplicedHole) | (_:HOPHole) | (_:NewObject) | (_:StaticModule)) => r + case r @ ((_:Constant) | (_:Hole) | (_:Symbol) | (_:SplicedHole) | (_:HOPHole) | (_:HOPHole2) | (_:NewObject) | (_:StaticModule)) => r }) } @@ -268,11 +274,40 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = extractWithCtx(xtor, xtee)(ListMap.empty) def extractWithCtx(xtor: Rep, xtee: Rep)(implicit ctx: ListMap[BoundVal, Set[BoundVal]]): Option[Extract] = { + println(s"$xtor\n$xtee with $ctx\n\n") def reverse[A, B](m: Map[A, Set[B]]): Map[B, A] = for { (a, bs) <- m b <- bs } yield b -> a + def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal]): Option[Extract] = { + type Func = List[List[BoundVal]] -> Rep + def emptyFunc(r: Rep) = List.empty[List[BoundVal]] -> r + def fargss(f: Func) = f._1 + def fbody(f: Func) = f._2 + + val ctx0 = ctx.mapValues(_.head) + val invCtx = reverse(ctx) + + bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx.getOrElse(bv, bv) => return None } + + def extendFunc(args: List[Rep], f: Func): Func = { + val args0 = args.map(bottomUpPartial(_) { case bv: BoundVal => ctx0.getOrElse(bv, bv) }) + val xs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) + val transformation = (args0 zip xs).toMap + val body0 = bottomUp(fbody(f)) { case r => transformation.getOrElse(r, r) } + (xs :: fargss(f)) -> body0 + } + + for { + e1 <- typ extract (xtee.typ, Covariant) + f = argss.foldRight(emptyFunc(xtee))(extendFunc) + l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } + e2 = repExtract(name -> l) + m <- merge(e1, e2) + } yield m + } + xtor -> xtee match { case (lb1: LetBinding, lb2: LetBinding) => val normal = for { @@ -288,6 +323,13 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { * xtor: val x0 = List(Hole(...): _*) * xtee: val x0 = Seq(1, 2, 3); val x1 = List(x0) */ + + /* TODO problematic since it will ignore `val aX = 42` + * XTEE: `val a = readInt; val b = 42` + * XTOR: `val aX = 42; val bX = readInt` + * Could swap `a` and `b` if they are pure, and not linked (b references a)? + * Would be nice to have `x + y + z` match `y + x + z` + */ lazy val lookFurtherInXtee = extractWithCtx(lb1, lb2.body) // lazy val lookFurtherInXtor = extractWithCtx(lb1.body, lb2) @@ -305,8 +347,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { case (Hole(n, t), bv: BoundVal) => val r = bv.owner match { - case lb: LetBinding => - new LetBinding(lb.name, lb.bound, lb.value, lb.bound) + case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // Why doesn't `lb.body = lb.bound; lb` work? case _ => bv } @@ -321,32 +362,10 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { m <- merge(e, repExtract(n -> xtee)) } yield m - case (h @ HOPHole(name, typ, argss, visible), _) => - type Func = List[List[BoundVal]] -> Rep - def emptyFunc(r: Rep) = List.empty[List[BoundVal]] -> r - def fargss(f: Func) = f._1 - def fbody(f: Func) = f._2 - - val ctx0 = ctx.mapValues(_.head) - val invCtx = reverse(ctx) - - bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx(bv) => return None } - - def extendFunc(args: List[Rep], f: Func): Func = { - val args0 = args.map(bottomUpPartial(_) { case bv: BoundVal => ctx0.getOrElse(bv, bv) }) - val xs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) - val transformation = (args0 zip xs).toMap - val body0 = bottomUp(fbody(f)) { case r => transformation.getOrElse(r, r) } - (xs :: fargss(f)) -> body0 - } - - for { - e1 <- typ extract (xtee.typ, Covariant) - f = argss.foldRight(emptyFunc(xtee))(extendFunc) - l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } - e2 = repExtract(name -> l) - m <- merge(e1, e2) - } yield m + case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) + case (h@HOPHole2(name, typ, argss, visible), _) => + //println(s"Hole $h -> $xtee\n\n") + extractHOPHole(name, typ, argss, visible) case (bv1: BoundVal, bv2: BoundVal) => if (bv1 == bv2) Some(EmptyExtract) @@ -471,7 +490,9 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def removeBVs(r: Rep, bvs: Set[BoundVal]): Rep = r match { case lb: LetBinding if bvs contains lb.bound => removeBVs(lb.body, bvs) - case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, removeBVs(lb.body, bvs)) + case lb: LetBinding => + lb.body = removeBVs(lb.body, bvs) + lb case _ => r } @@ -504,6 +525,13 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { newX <- mkCode(m) } yield newX + case (h2: HOPHole2, lb: LetBinding) => + for { + e <- extractWithCtx(h2, lb) + m <- merge(e, ex) + newX <- mkCode(m) + } yield newX + case (h: Hole, lb: LetBinding) => for { e <- extractWithCtx(h, lb) @@ -517,13 +545,14 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { type LazyInlined[R] = R -> List[LetBinding] implicit def lazyInlined[R](r: R): LazyInlined[R] = r -> List.empty def run(lr: LazyInlined[Rep]): Rep = lr match { - case r -> lbs => lbs.foldLeft(r) { (acc, outerLB) => new LetBinding(outerLB.name, outerLB.bound, outerLB.value, acc) } + case r -> lbs => lbs.foldLeft(r) { (acc, outerLB) => outerLB.body = acc; outerLB } } // Puts `acc` inside `outer`. def surroundWithLB(acc: LazyInlined[Rep], outer: LazyInlined[Rep]): LazyInlined[Rep] = outer match { case (outerLB: LetBinding, lbs) => - new LetBinding(outerLB.name, outerLB.bound, outerLB.value, run(acc)) -> lbs + outerLB.body = run(acc) + outerLB -> lbs case _ => throw new IllegalArgumentException } @@ -578,13 +607,19 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } yield MethodApp(self, ma.mtd, ma.targs, argss, ma.typ) -> (innerLBs ::: outerLBs) case l: Lambda => rewriteRep(l.body) map { - case body -> lbs => new Lambda(l.name, l.bound, l.boundType, body) -> lbs + case body -> lbs => + l.body = body + l -> lbs } } } // Puts the value back into its LB - rewrittenValue map { case value -> lbs => new LetBinding(lb.name, lb.bound, value, lb.body) -> lbs} + rewrittenValue map { + case value -> lbs => + lb.value = value + lb -> lbs + } } def rewriteValue0(lb: LetBinding, acc: Option[List[LazyInlined[Rep]]]): Option[List[LazyInlined[Rep]]] = lb.body match { @@ -626,6 +661,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } def extractValue(v1: Def, v2: Def)(implicit ctx: ListMap[BoundVal, Set[BoundVal]]) = { + //println(s"$v1\n$v2 with $ctx \n\n") (v1, v2) match { case (l1: Lambda, l2: Lambda) => for { @@ -674,12 +710,14 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def splicedHole(name: String, typ: TypeRep): Rep = SplicedHole(name, typ) def typeHole(name: String): TypeRep = DummyTypeRep def hopHole(name: String, typ: TypeRep, yes: List[List[BoundVal]], no: List[BoundVal]) = HOPHole(name, typ, yes, no) - + override def hopHole2(name: String, typ: TypeRep, args: List[List[Rep]], visible: List[BoundVal]) = + HOPHole2(name, typ, args, visible filterNot (args.flatten contains _)) def substitute(r: => Rep, defs: Map[String, Rep]): Rep = if (defs isEmpty) r |> inlineBlock else bottomUp(r) { case h @ Hole(n, _) => defs getOrElse(n, h) case h @ SplicedHole(n, _) => defs getOrElse(n, h) + case h: BoundVal => defs getOrElse(h.name, h) // TODO FVs in lambda become BVs too early, this should be changed!! case h => h } |> inlineBlock diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index e46d9990..9bac0b7e 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -62,6 +62,8 @@ final case class SplicedHole(name: String, typ: TypeRep) extends Rep with Cached final case class HOPHole(name: String, typ: TypeRep, args: List[List[Symbol]], visible: List[Symbol]) extends Rep with CachedHashCode +final case class HOPHole2(name: String, typ: TypeRep, args: List[List[Rep]], visible: List[Symbol]) extends Rep with CachedHashCode + // TODO intern objects final case class StaticModule(fullName: String) extends Rep with CachedHashCode { val typ = DummyTypeRep diff --git a/src/main/scala/squid/test/Test.scala b/src/main/scala/squid/test/Test.scala index 2a49e446..344e8744 100644 --- a/src/main/scala/squid/test/Test.scala +++ b/src/main/scala/squid/test/Test.scala @@ -2,7 +2,6 @@ package squid package test import squid.ir.FastANF -import squid.ir.fastanf.{DummyTypeRep, SplicedHole} object Test extends App { object Embedding extends FastANF @@ -28,18 +27,19 @@ object Test extends App { //code{p.x} match { // case code"p.x" => println(code"p.y") //} - - def f(args: Int*)(moreArgs: Int*) = 0 - - val p = code"f(1, 2, 3)(Seq(5, 6): _*)" - println(p) - - p match { - case code"f($n*)($fivesix: _*)" => println((n, fivesix)) - } - - val p2 = code"(x: Int) => x + 5" alsoApply println - p2 match { - case code"(y: Int) => y + 5" => - } + // val a = ir"val a = 10.toDouble; val b = readDouble; (a + b) * 5" + // val b = a rewrite { + // case dbg_ir"val x = 10.toDouble; val y = readDouble; $body(x + y)" => ir"$body(222)" + // } alsoApply println + + // ir"(a: Int, b: Int) => a + 1" match { + //// case ir"(x: Int, y: Int) => $body(y):Int" => + // case db_ir"(x: Int, y: Int) => $body(x):Int" => + // } + + val a = ir"(readInt + 111) * .5" rewrite { + case ir"(($n: Int) + 111) * .5" => ir"$n * .25" + } alsoApply println + println(ir"readInt * .25") + assert(a =~= ir"readInt * .25") } diff --git a/src/test/scala/squid/ir/fastir/BasicTests.scala b/src/test/scala/squid/ir/fastir/BasicTests.scala index 83001da2..4f98c60f 100644 --- a/src/test/scala/squid/ir/fastir/BasicTests.scala +++ b/src/test/scala/squid/ir/fastir/BasicTests.scala @@ -69,7 +69,7 @@ class BasicTests extends MyFunSuiteBase(BasicTests.Embedding) { // case _ => assert(false) //} - assert(Embedding.bottomUpPartial(code"println((1, 23))".rep) { case Constant(1) => Constant(42) } == code"println((42,23))".rep) + //assert(Embedding.bottomUpPartial(code"println((1, 23))".rep) { case Constant(1) => Constant(42) } == code"println((42,23))".rep) } test("Transformers") { diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index 80123d5b..f9298dbc 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -2,25 +2,17 @@ package squid package ir package fastir -import utils._ - -class HigherOrderPatternVariables extends MyFunSuite { - import TestDSL.Predef._ +class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVariables.Embedding) { + import HigherOrderPatternVariables.Embedding.Predef._ test("Matching lambda bodies") { val id = ir"(z:Int) => z" ir"(a: Int) => a + 1" matches { - case ir"(x: Int) => $body(x):Int" => - body eqt ir"(_:Int)+1" - } and { - case ir"(x: Int) => $body(x):$t" => - body eqt ir"(_:Int)+1" - eqt(t, irTypeOf[Int]) - } and { - case ir"(x: Int) => ($exp(x):Int)+1" => - exp eqt id + case ir"(x: Int) => $body(x): Int" => assert(body == ir"(_:Int) + 1") + case ir"(x: Int) => $body(x):$t" => assert(body == ir"(_:Int)+1") + case ir"(x: Int) => ($exp(x):Int)+1" => assert(exp == id) } ir"(a: Int, b: Int) => a + 1" matches { @@ -32,35 +24,83 @@ class HigherOrderPatternVariables extends MyFunSuite { ir"(a: Int, b: Int) => a + b" matches { case ir"(x: Int, y: Int) => $body(x):Int" => fail - case ir"(x: Int, y: Int) => $body(x,y):Int" => - body eqt ir"(_:Int)+(_:Int)" - } and { + case ir"(x: Int, y: Int) => $body(x,y):Int" => assert(body == ir"(_:Int)+(_:Int)") + } + + ir"(a: Int, b: Int) => a + b" matches { case ir"(x: Int, y: Int) => ($lhs(y):Int)+($rhs(y):Int)" => fail case ir"(x: Int, y: Int) => ($lhs(x):Int)+($rhs(y):Int)" => - lhs eqt id - rhs eqt id + assert(lhs == id) + assert(rhs == id) } - } test("Matching let-binding bodies") { + // Not implemented error in `letin` + //ir"val a = 0; val b = 1; a + b" matches { + // case ir"val x: Int = $v; $body(x):Int" => + // assert(v == ir"0") + // body matches { + // case ir"(y:Int) => { val x: Int = $w; $body(x,y):Int }" => + // assert(w == ir"1") + // assert(body == ir"(u:Int,v:Int) => (v:Int)+(u:Int)") + // } + //} + } + + test("Non-trivial arguments") { + val id = ir"(z: Int) => z" - ir"val a = 0; val b = 1; a + b" matches { - case ir"val x: Int = $v; $body(x):Int" => - v eqt ir"0" - body matches { - case ir"(y:Int) => { val x: Int = $w; $body(x,y):Int }" => - w eqt ir"1" - body eqt ir"(u:Int,v:Int) => (v:Int)+(u:Int)" - } + ir"(a: Int, b: Int) => a + b" matches { + case ir"(x: Int, y: Int) => $body(x + y): Int" => assert(body == id) + case ir"(x: Int, y: Int) => $body(x): Int" => fail + case ir"(x: Int, y: Int) => $body(y): Int" => fail + } + + ir"(a: Int, b: Int, c: Int) => a + b + c" matches { + case ir"(x: Int, y: Int, z: Int) => $body(x + y, z): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + } + + ir"(a: Int, b: Int, c: Int) => a + b + c" matches { + case ir"(x: Int, y: Int, z: Int) => $body(x + y + z): Int" => assert(body == id) + } + + // TODO `extract` should see the different combinations of `x + y + z` + // case ir"(x: Int, y: Int, z: Int) => $body(x + z, y)" => println(body); assert(body == ir"(r: Int, s: Int) => r + s") + + // TODO doesn't "align", `extract` is too structural + // case ir"(x: Int, y: Int, z: Int) => $body(x, y + z)" => assert(body == ir"(r: Int, s: Int) => r + s") + + ir"(a: Int) => readInt + a" matches { + case ir"(x: Int) => $body(readInt, x): Int" => assert(body == ir"(r: Int, s: Int) => r + s") } - } - test("Find a good name") { - val a = ir"val a = 10.toDouble; val b = a + 1; b" rewrite { - case ir"val x = 10.toDouble; $body(x):Double" => body(ir"readDouble") + ir"(a: Int) => readInt + a" matches { + case ir"(x: Int) => $body(x, readInt): Int" => assert(body == ir"(r: Int, s: Int) => s + r") } - assert(a == ir"val r = readDouble + 1; r") + + ir"(a: Int, b: Int) => readInt + (a + b)" matches { + case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + } + + // TODO doesn't "align", `extract` is too structural + //ir"(a: Int, b: Int) => readInt + (a + b)" matches { + // case ir"(x: Int, y: Int) => $body(x + y, readInt): Int" => assert(body == ir"(r: Int, s: Int) => s + r") + //} + } + + //test("Currying") { + // ir"(a: Int, b: Int) => a + b" match { + // case ir"(x: Int, y: Int) => $body(x)(y)" => + // } + //} + + test("Match letbindinds") { + // TODO `apply` should call inline + //val a = ir"val a = 10.toDouble; val b = a + 1; val c = b + 2; c" matches { + // case ir"val x = 10.toDouble; $body(x):Double" => + // assert(ir"$body(42)" == ir"(val a = (x: Int) => (val b = x + 1; val c = b + 2; c)); val tmp = a.apply(42.0); tmp") + //} // val b = ir"val a = 10; val b = readInt; val c = a + b; c" rewrite { // case ir"val x = 10; val y = readInt; $body(x): Int" => ir"$body(42)" @@ -69,5 +109,8 @@ class HigherOrderPatternVariables extends MyFunSuite { // assert(b == ir"readDouble.toInt + 42") // assert(b == ir"val r = readDouble + 1; r") } +} +object HigherOrderPatternVariables { + object Embedding extends FastANF } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index bcb4e92a..a04e4f82 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -5,6 +5,10 @@ package fastir class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { import RewritingTests.Embedding.Predef._ + //object T extends SimpleRuleBasedTransformer with TopDownTransformer { + // val base: DSL.type = DSL + //} + test("Simple rewrites") { val a = ir"123" rewrite { case ir"123" => ir"666" @@ -20,6 +24,14 @@ class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { case ir"(${Const(n)}: Int).toDouble" => ir"${Const(n.toDouble)}" } assert(c =~= ir"42.0") + + //assertDoesNotCompile(""" + // T.rewrite { case ir"0.5" => ir"42" } + //""") + + //assertDoesNotCompile(""" + // T.rewrite { case ir"123" => ir"($$n:Int)" } + //""") } test("Rewriting subpatterns") { @@ -28,30 +40,40 @@ class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { } assert(a =~= ir"readInt * .25") - val b = ir"(x: Int) => (x-5) * 32" rewrite { - case ir"($b: Int) * 32" => ir"$b" - } - assert(b =~= ir"(x: Int) => x - 5") - - val c = ir"Option(42).get" rewrite { + val b = ir"Option(42).get" rewrite { case ir"Option(($n: Int)).get" => n } - assert(c =~= ir"42") + assert(b =~= ir"42") - val d = ir"val a = Option(42).get; a * 5" rewrite { + val c = ir"val a = Option(42).get; a * 5" rewrite { case ir"Option(($n: Int)).get" => n case ir"($a: Int) * 5" => ir"$a * 2" } - assert(d =~= ir"val a = 42; a * 2") + assert(c =~= ir"val a = 42; a * 2") } test("Rewriting simple expressions only once") { val a = ir"println((50, 60))" rewrite { case ir"($x:Int,$y:Int)" => ir"($y:Int,$x:Int)" case ir"(${Const(n)}:Int)" => Const(n+1) - } alsoApply println + } assert(a =~= ir"println((61,51))") } + + test("Function Rewritings") { + val a = ir"(x: Int) => (x-5) * 32" rewrite { + case ir"($b: Int) * 32" => ir"$b" + } + assert(a =~= ir"(x: Int) => x - 5") + + val b = ir"(x: Int) => (x-5) * 32" rewrite { + case ir"(x: Int) => ($b: Int) * 32" => dbg_ir"val x = 42; (p: Int) => $b + p" + } alsoApply println + + println(ir"val u = 42; (v: Int) => (u - 5) + v") + + assert(b =~= ir"val u = 42; (v: Int) => (u - 5) + v") + } } object RewritingTests { From 65dffc8038cac1000ff7feed3c9225d08eb3e907 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 31 Oct 2017 18:02:38 +0100 Subject: [PATCH 08/66] Handle effects --- src/main/scala/squid/ir/fastanf/Effects.scala | 75 ++ src/main/scala/squid/ir/fastanf/FastANF.scala | 757 +++++++++--------- src/main/scala/squid/ir/fastanf/Rep.scala | 9 +- .../squid/ir/fastir/RewritingTests.scala | 56 +- 4 files changed, 516 insertions(+), 381 deletions(-) create mode 100644 src/main/scala/squid/ir/fastanf/Effects.scala diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala new file mode 100644 index 00000000..61502827 --- /dev/null +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -0,0 +1,75 @@ +package squid.ir.fastanf + +import squid.utils.Bool + +import scala.collection.mutable + +trait Effects { + protected val pureMtds = mutable.Set[MethodSymbol]() + //protected val pureTyps = mutable.Set[TypeSymbol]() + + def addPureMtd(m: MethodSymbol): Unit = pureMtds += m + //def addPureTyp(t: TypeSymbol): Unit = pureTyps += t + + def isPure(r: Rep): Boolean = effect(r) == Pure + def isPure(d: Def): Boolean = defEffect(d) == Pure + + def effect(d: Def): Effect = defEffect(d) + + def effect(r: Rep): Effect = r match { + case lb: LetBinding => defEffect(lb.value) |+| effect(lb.body) + + //case s: Symbol => s.owner match { + // case lb: LetBinding => effect(lb) + // case _ => Pure + //} + case _: Symbol => Pure + + case Ascribe(r, _) => effect(r) + + case Module(r, _, _) => effect(r) + + case Constant(_) | _: Symbol | + StaticModule(_) | NewObject(_) | + Hole(_, _) | SplicedHole(_, _) | + HOPHole(_, _, _, _) | HOPHole2(_, _, _, _) => Pure + } + + def mtdEffect(m: MethodSymbol): Effect = { + //println(m) + if (pureMtds contains m) Pure else Impure + } + + def defEffect(d: Def): Effect = d match { + case l: Lambda => effect(l.body) + case ma: MethodApp => + val selfEff = effect(ma.self) + val argssEff = ma.argss.argssList.map(effect).fold(Pure)(_ |+| _) + val mtdEff = mtdEffect(ma.mtd) + selfEff |+| argssEff |+| mtdEff + case DefHole(_) => Pure + case Unreachable => Pure + } + + sealed trait Effect { + def |+|(e: Effect): Effect + } + + case object Pure extends Effect { + def |+|(e: Effect): Effect = e + } + + case object Impure extends Effect { + def |+|(e: Effect): Effect = Impure + } +} + +trait StandardEffects extends Effects { + addPureMtd(MethodSymbol(TypeSymbol("scala.Int"),"$plus")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Int"),"$times")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Int"), "toDouble")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Int"), "toFloat")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Option$"), "apply")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Option"), "get")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Tuple2$"), "apply")) +} diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index effa0c55..e1dd781d 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -3,11 +3,11 @@ package ir.fastanf import utils._ import lang.{Base, InspectableBase, ScalaCore} -import squid.ir.{Covariant, CurryEncoding, IRException, Invariant} +import squid.ir._ -import scala.collection.immutable.ListMap +import scala.collection.immutable.{ListMap, ListSet} -class FastANF extends InspectableBase with CurryEncoding with ScalaCore { +class FastANF extends InspectableBase with CurryEncoding with StandardEffects with ScalaCore { private[this] implicit val base = this @@ -99,43 +99,81 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def staticModule(fullName: String): Rep = StaticModule(fullName) def module(prefix: Rep, name: String, typ: TypeRep): Rep = Module(prefix, name, typ) def newObject(typ: TypeRep): Rep = NewObject(typ) - def methodApp(self: Rep, mtd: MtdSymbol, targs: List[TypeRep], argss: List[ArgList], tp: TypeRep): Rep = { - MethodApp(self |> inlineBlock, mtd, targs, argss |> toArgumentLists, tp) |> letbind + def methodApp(self: Rep, mtd: MtdSymbol, targs: List[TypeRep], argss: List[ArgList], tp: TypeRep): Rep = mtd match { + case MethodSymbol(TypeSymbol("squid.lib.package$"), "Imperative") => argss match { + case List(h, t) => + val firstArgss = h.reps + val holes = h.reps.filter { + case Hole(_, _) => true + case _ => false + } + + val lastArgss = t.reps + assert(lastArgss.size == 1) + holes.foldRight(lastArgss.head) { case (h, acc) => + letin(bindVal("tmp", h.typ, Nil), h, acc, acc.typ) + } + } + + + case _ => MethodApp(self |> inlineBlock, mtd, targs, argss |> toArgumentLists, tp) |> letbind } def byName(mkArg: => Rep): Rep = wrapNest(mkArg) def letbind(d: Def): Rep = currentScope += d def inlineBlock(r: Rep): Rep = r |>=? { case lb: LetBinding => + println(s"INLINE: $lb --> $scopes") currentScope += lb + println(s"$scopes") inlineBlock(lb.body) } - override def letin(bound: BoundVal, value: Rep, body: => Rep, bodyType: TypeRep): Rep = { - value match { - case s: Symbol => - s.owner |>? { - case lb: RebindableBinding => - lb.name = bound.name - } - bound rebind s - body - case lb: LetBinding => - // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding - val last = lb.last - val boundName = bound.name - bound rebind last.bound - last.body = body - last.name = boundName // TODO make sure we're only renaming an automatically-named binding? - lb - // case c: Constant => bottomUpPartial(body) { case `bound` => c } - case (_:HOPHole) | (_:HOPHole2) | (_:Hole) | (_:SplicedHole) => - ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes - case _ => - withSubs(bound -> value)(body) - // ^ executing `body` will reify some statements into the reification scope, and likely return a symbol - // during this reification, we need all references to `bound` to be replaced by the actual `value` - } + override def letin(bound: BoundVal, value: Rep, body: => Rep, bodyType: TypeRep): Rep = value match { + case s: Symbol => + withSubs(bound, value)(body) + + //s.owner |>? { + // case lb: RebindableBinding => + // lb.name = bound.name + //} + //bound rebind s + //body + case lb: LetBinding => + // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding + val last = lb.last + val boundName = bound.name + bound rebind last.bound + last.body = body + last.name = boundName // TODO make sure we're only renaming an automatically-named binding? + lb + // case c: Constant => bottomUpPartial(body) { case `bound` => c } + case h: Hole => + //Wrap construct? How? + //new LetBinding(bound.name, bound, DefHole(h), body) + + // letin(x, Hole, Constant(20)) => `val tmp = defHole; 20;` + + val dh = DefHole(h) |> letbind // flag + + //(dh |>? { + // case bv: BoundVal => bv.owner |>? { + // case lb: LetBinding => + // lb.body = body + // lb + // } + //}).flatten.getOrElse(body) + + //new LetBinding(bound.name, bound, dh, body) alsoApply (currentScope += _) alsoApply (bound.rebind) + withSubs(bound -> dh)(body) + + + case (_:HOPHole) | (_:HOPHole2) | (_:SplicedHole) => + ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes + case _ => + withSubs(bound -> value)(body) + // ^ executing `body` will reify some statements into the reification scope, and likely return a symbol + // during this reification, we need all references to `bound` to be replaced by the actual `value` } var curSub: Map[Symbol,Rep] = Map.empty @@ -146,9 +184,10 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } override def tryInline(fun: Rep, arg: Rep)(retTp: TypeRep): Rep = { + println(s"tryInline $fun -- $arg") fun match { case lb: LetBinding => lb.value match { - case l: Lambda => letin(l.bound, arg, l.body, l.body.typ) + case l: Lambda => letin(l.bound, arg, l.body, l.body.typ) // flag case _ => super.tryInline(fun, arg)(retTp) } case _ => super.tryInline(fun, arg)(retTp) @@ -160,7 +199,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { case _ => Ascribe(self, typ) } - def loadMtdSymbol(typ: TypSymbol, symName: String, index: Option[Int] = None, static: Boolean = false): MtdSymbol = new MethodSymbol(typ, symName) // TODO + def loadMtdSymbol(typ: TypSymbol, symName: String, index: Option[Int] = None, static: Boolean = false): MtdSymbol = MethodSymbol(typ, symName) // TODO object Const extends ConstAPI { def unapply[T: IRType](ir: IR[T,_]): Option[T] = ir.rep match { @@ -198,6 +237,8 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { reinterpretArgList(ma.argss), reinterpretType(ma.typ)) case l: Lambda => newBase.lambda(List(reinterpretBV(l.bound)), reinterpret0(l.body)) + case DefHole(h) => newBase.hole(h.name, reinterpretType(h.typ)) + case Unreachable => unsupported } r match { @@ -229,7 +270,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } def repType(r: Rep): TypeRep = r.typ - def boundValType(bv: BoundVal) = bv.typ + def boundValType(bv: BoundVal): TypeRep = bv.typ // * --- * --- * --- * Implementations of `InspectableBase` methods * --- * --- * --- * @@ -240,14 +281,16 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def transformRepAndDef(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity)(preDef: Def => Def, postDef: Def => Def = identity): Rep = { def transformRepAndDef0(r: Rep) = transformRepAndDef(r)(pre, post)(preDef, postDef) - def transformDef(d: Def): Def = (d map preDef match { + def transformDef(d: Def): Def = postDef(preDef(d) match { case App(f, a) => App(transformRepAndDef0(f), transformRepAndDef0(a)) // Note: App is a MethodApp, but we can transform it more efficiently this way case ma: MethodApp => MethodApp(transformRepAndDef0(ma.self), ma.mtd, ma.targs, ma.argss argssMap (transformRepAndDef0(_)), ma.typ) case l: Lambda => // Note: destructive modification of the lambda binding! //new Lambda(l.name, l.bound, l.boundType, transformRepAndDef0(l.body)) l.body = l.body |> transformRepAndDef0 l - }) map postDef + case dh: DefHole => dh + case Unreachable => Unreachable + }) post(pre(r) match { case lb: LetBinding => // Note: destructive modification of the let-binding! @@ -271,23 +314,146 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def transformRep(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity): Rep = transformRepAndDef(r)(pre, post)(identity, identity) - protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = extractWithCtx(xtor, xtee)(ListMap.empty) + def firstHole(d: Def)(implicit es: State): Option[Hole] = (for { + DefHole(h @ Hole(name, _)) <- holes(d) + if !(es.ex._1 contains name) && !(es.ex._3 contains name) + } yield h).headOption + + def holes(d: Def): List[DefHole] = d match { + case l: Lambda => Nil // TODO handle + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft(List.empty[DefHole]) { + case (acc, h: Hole) => DefHole(h) :: acc + case (acc, _) => acc + }.reverse + case dh: DefHole => List(dh) + case Unreachable => Nil + } - def extractWithCtx(xtor: Rep, xtee: Rep)(implicit ctx: ListMap[BoundVal, Set[BoundVal]]): Option[Extract] = { - println(s"$xtor\n$xtee with $ctx\n\n") + protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { + println(s"Extract($xtor, $xtee)") + for { + es <- extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)) + if es.mks.xtor.isEmpty && es.mks.xtee.isEmpty + } yield es.ex + } + + type Ctx = Map[BoundVal, Set[BoundVal]] + def updateWith(ctx: Ctx)(u: (BoundVal, BoundVal)): Ctx = u match { + case (k, v) => ctx + (k -> (ctx(k) + v)) + } + + case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedBVs: Set[BoundVal], makeUnreachable: Boolean) { + def withExtract(newEx: Extract): State = copy(ex = newEx) + def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) + def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = updateWith(ctx)(p)) + def updateMarkers(newMks: Markers): State = copy(mks = newMks) + def withoutMarkers(xtorMk: BoundVal, xteeMk: BoundVal): State = copy(mks = Markers(mks.xtor - xtorMk, mks.xtee - xteeMk)) + def withMatchedBV(bv: BoundVal): State = copy(matchedBVs = matchedBVs + bv) + } + object State { + def forRewriting(xtor: Rep, xtee: Rep): State = State(xtor, xtee, true) + def forExtraction(xtor: Rep, xtee: Rep): State = State(xtor, xtee, false) + + private def apply(xtor: Rep, xtee: Rep, makeUnreachable: Bool): State = + State(EmptyExtract, ListMap.empty.withDefaultValue(Set.empty), Markers(xtor, xtee), Set.empty, makeUnreachable) + } + + sealed trait Marker + case object EndPoint extends Marker + case object NonEndPoint extends Marker + + case class Markers(xtor: ListSet[BoundVal], xtee: ListSet[BoundVal]) { + private def marker(ls: ListSet[BoundVal])(bv: BoundVal) = if (ls contains bv) EndPoint else NonEndPoint + def xtorMarker(bv: BoundVal): Marker = marker(xtor)(bv) + def xteeMarker(bv: BoundVal): Marker = marker(xtee)(bv) + } + object Markers { + def apply(xtor: Rep, xtee: Rep): Markers = Markers(extractionStarts(xtor), extractionStarts(xtee)) + + private def extractionStarts(r: Rep): ListSet[BoundVal] = { + def bvs(d: Def, acc: ListSet[BoundVal]): ListSet[BoundVal] = d match { + case _: Lambda => ListSet.empty // TODO The lambda may never be applied. + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft(acc) { + case (acc, bv: BoundVal) => acc - bv + case (acc, _) => acc // Assuming no LBs in self or argument positions + } + case _: DefHole => ListSet.empty + case Unreachable => ListSet.empty + } + + def extractionStarts0(r: Rep, acc: ListSet[BoundVal]): ListSet[BoundVal] = r match { + case lb: LetBinding => effect(lb) match { + case Pure => extractionStarts0(lb.body, bvs(lb.value, acc + lb.bound)) + case Impure => extractionStarts0(lb.body, bvs(lb.value, acc)) + } + case bv: BoundVal => acc - bv + case _ => acc + } + + extractionStarts0(r, ListSet.empty) + } + + //type Signature = List[Int] + //case class EndPoints(endPoints: ListSet[BoundVal], signatures: Map[BoundVal, Signature]) + //def findEndPoints(r: Rep): EndPoints = { + // case class Signature0(impurePos: Int, signature: List[Int]) + // case class EndPoints0(endPoints: ListSet[BoundVal], nodes: Map[BoundVal, Signature0], currImpurePos: Int) + // + // def update(acc: EndPoints0, lb: LetBinding): EndPoints0 = lb.value match { + // case ma: MethodApp => + // val (newAcc, nodes) = (ma.self :: ma.argss.argssList).foldRight((acc, List.empty[Signature0])) { + // case (bv: BoundVal, (acc, nodes)) => (acc.copy(endPoints = acc.endPoints - bv), acc.nodes(bv) :: nodes) + // case (_, acc) => acc + // } + // + // val impurePos = acc.currImpurePos + // + // val sig = for { + // Signature0(childImpurePos, childSig) <- nodes + // childSigComponent <- childSig + // sigComponent = childSigComponent + (childImpurePos - impurePos) + // } yield sigComponent + // + // val sig0 = defEffect(ma) match { + // case Pure => sig + // case Impure => 0 :: sig + // } + // + // newAcc.copy( + // endPoints = newAcc.endPoints + lb.bound, + // nodes = acc.nodes + (lb.bound -> Signature0(impurePos, sig0)), + // currImpurePos = if (isPure(lb)) newAcc.currImpurePos else newAcc.currImpurePos + 1) + // case l: Lambda => ??? + // case DefHole(_) => acc + // case Unreachable => acc + // } + // + // def endPoints0(acc: EndPoints0, r: Rep): EndPoints0 = r match { + // case lb: LetBinding => endPoints0(update(acc, lb), lb.body) + // case _ => acc + // } + // + // val res = endPoints0(EndPoints0(ListSet.empty, Map.empty, 0), r) + // EndPoints(res.endPoints, res.nodes mapValues (_.signature)) + //} + } + + def extractWithState(xtor: Rep, xtee: Rep)(implicit es: State): Option[State] = { + + //println(s"$xtor\n$xtee with $ctx\n\n") def reverse[A, B](m: Map[A, Set[B]]): Map[B, A] = for { (a, bs) <- m b <- bs } yield b -> a - def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal]): Option[Extract] = { + def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): Option[State] = { type Func = List[List[BoundVal]] -> Rep def emptyFunc(r: Rep) = List.empty[List[BoundVal]] -> r def fargss(f: Func) = f._1 def fbody(f: Func) = f._2 - val ctx0 = ctx.mapValues(_.head) - val invCtx = reverse(ctx) + val ctx0 = es.ctx.mapValues(_.head) + val invCtx = reverse(es.ctx) bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx.getOrElse(bv, bv) => return None } @@ -295,7 +461,7 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { val args0 = args.map(bottomUpPartial(_) { case bv: BoundVal => ctx0.getOrElse(bv, bv) }) val xs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) val transformation = (args0 zip xs).toMap - val body0 = bottomUp(fbody(f)) { case r => transformation.getOrElse(r, r) } + val body0 = bottomUp(fbody(f))(r => transformation.getOrElse(r, r)) (xs :: fargss(f)) -> body0 } @@ -305,89 +471,154 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } e2 = repExtract(name -> l) m <- merge(e1, e2) - } yield m + } yield es withExtract m } - xtor -> xtee match { - case (lb1: LetBinding, lb2: LetBinding) => - val normal = for { - //e1 <- extractWithCtx(lb1.bound, lb2.bound) - e1 <- lb1.boundType extract (lb1.boundType, Covariant) - e2 <- extractValue(lb1.value, lb2.value) - e3 <- extractWithCtx(lb1.body, lb2.body)(ctx + (lb1.bound -> Set(lb2.bound))) - m <- mergeAll(e1, e2, e3) - } yield m - - /* - * For instance when: - * xtor: val x0 = List(Hole(...): _*) - * xtee: val x0 = Seq(1, 2, 3); val x1 = List(x0) - */ - - /* TODO problematic since it will ignore `val aX = 42` - * XTEE: `val a = readInt; val b = 42` - * XTOR: `val aX = 42; val bX = readInt` - * Could swap `a` and `b` if they are pure, and not linked (b references a)? - * Would be nice to have `x + y + z` match `y + x + z` - */ - lazy val lookFurtherInXtee = extractWithCtx(lb1, lb2.body) - // lazy val lookFurtherInXtor = extractWithCtx(lb1.body, lb2) - - normal orElse lookFurtherInXtee //orElse lookFurtherInXtor - - // Matches 42 and 42: Any, is it safe to ignore the typ? - case (_, Ascribe(s, _)) => extractWithCtx(xtor, s) + def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): Option[State] = (effect(lb1.value), effect(lb2.value)) match { + case (Impure, Impure) => for { + es1 <- extractDefs(lb1.value, lb2.value) + _ = if (es1.makeUnreachable) lb2.value = Unreachable + c <- extractWithState(lb1.body, lb2.body)(es1 withCtx (lb1.bound, lb2.bound) withMatchedBV lb2.bound) + } yield c + + case (Impure, Pure) => extractWithState(lb1, lb2.body) + + case (Pure, Impure) => + firstHole(lb1.value).fold(extractWithState(lb1.body, lb2)) { + case Hole(name, typ) => for { + e <- typ extract(lb2.value.typ, Covariant) + lb = new LetBinding(lb2.name, lb2.bound, lb2.value, lb2.bound)// alsoApply (currentScope += _) //alsoApply (lb2.bound.rebind) + m <- mergeAll(es.ex, e, repExtract(name ->lb)) + //_ = if (es.makeUnreachable) lb2.value = Unreachable + es2 <- extractWithState(lb1, lb2.body)(es withExtract m withMatchedBV lb2.bound) + } yield es2 + } - case (Ascribe(s, t) , _) => - for { - e1 <- t extract (xtee.typ, Covariant) // t <:< a.typ, which one to use? - e2 <- extractWithCtx(s, xtee) - m <- merge(e1, e2) - } yield m + case (Pure, Pure) => (es.mks.xtorMarker(lb1.bound), es.mks.xteeMarker(lb2.bound)) match { + case (EndPoint, EndPoint) => + extractWithState(lb1.bound, lb2.bound) match { + case Some(es0) => + if (es0.makeUnreachable) lb2.value = Unreachable + extractWithState(lb1.body, lb2.body)(es0 withoutMarkers(lb1.bound, lb2.bound) withMatchedBV lb2.bound) + case None => extractWithState(lb1, lb2.body) + } + case (EndPoint, NonEndPoint) => extractWithState(lb1, lb2.body) + case (NonEndPoint, EndPoint) => extractWithState(lb1.body, lb2) + case (NonEndPoint, NonEndPoint) => extractWithState(lb1.body, lb2.body) + } + } + def extractHole(h: Hole, r: Rep)(implicit es: State): Option[State] = (h, r) match { case (Hole(n, t), bv: BoundVal) => val r = bv.owner match { - case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // Why doesn't `lb.body = lb.bound; lb` work? + case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag case _ => bv } for { e <- t extract (xtee.typ, Covariant) - m <- merge(e, repExtract(n -> r)) - } yield m + m <- mergeAll(e, es.ex, repExtract(n -> r)) + } yield es withExtract m withMatchedBV bv + + case (Hole(n, t), lb: LetBinding) => for { + e <- t extract (lb.typ, Covariant) + newLB = new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag + _ = if (es.makeUnreachable) lb.value = Unreachable + m <- mergeAll(e, es.ex, repExtract(n -> newLB)) + } yield es withExtract m withMatchedBV lb.bound + + case (Hole(n, t), _) => for { + e <- t extract (xtee.typ, Covariant) + m <- mergeAll(e, es.ex, repExtract(n -> xtee)) + } yield es withExtract m + } + + def extractInside(bv: BoundVal, d: Def)(implicit es: State): Option[State] = { + def bvs(d: Def): List[BoundVal] = d match { + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldRight(List.empty[BoundVal]) { + case (bv: BoundVal, acc) => bv :: acc + case (_, acc) => acc + } + case _ => Nil + } + + bvs(d).foldLeft(Option.empty[State]) { case (acc, bv2) => + acc orElse extractWithState(bv, bv2)(es) + } alsoApply(s => println(s"FOO: $s")) + } + + def filledWith(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3 - case (Hole(n, t), _) => + println(s"extractWithState: $xtor\n$xtee\n") + xtor -> xtee match { + case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) + + // Stop at markers? + case (lb: LetBinding, _: Rep) => extractWithState(lb.body, xtee) + + // TODO really need the pure? + case (bv: BoundVal, lb: LetBinding) if isPure(xtor) => + extractWithState(bv, lb.bound) orElse extractInside(bv, lb.value) orElse extractWithState(bv, lb.body) + + case (_: Rep, lb: LetBinding) if lb.value == Unreachable => extractWithState(xtor, lb.body) + + case (_, Ascribe(s, _)) => extractWithState(xtor, s) + + case (Ascribe(s, t) , _) => for { - e <- t extract (xtee.typ, Covariant) - m <- merge(e, repExtract(n -> xtee)) - } yield m + e1 <- t extract (xtee.typ, Covariant) + m <- merge(e1, es.ex) + es2 <- extractWithState(s, xtee)(es withExtract m) + } yield es2 + + case (h: Hole, _) => + filledWith(h) match { + case Some(_) => Some(es) // TODO check if the hole contains what we are trying to extract + //case Some(lb: LetBinding) if xtee == lb.bound => Some(es) + //case Some(r) if xtee == r => Some(es) + //case Some(_) => None + case None => extractHole(h, xtee) + } case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) + case (h@HOPHole2(name, typ, argss, visible), _) => //println(s"Hole $h -> $xtee\n\n") extractHOPHole(name, typ, argss, visible) case (bv1: BoundVal, bv2: BoundVal) => - if (bv1 == bv2) Some(EmptyExtract) - else for { - candidates <- ctx.get(bv1) - if candidates contains bv2 - } yield EmptyExtract + println(s"EXTRACTIONSTATE IN BV: $es") + println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") + if (bv1 == bv2 || (es.ctx(bv1) contains bv2)) Some(es) + else (bv1.owner, bv2.owner) match { + case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) map { es => + if (es.makeUnreachable) lb2.value = Unreachable + es withCtx (lb1.bound, lb2.bound) withMatchedBV lb2.bound + } + case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_.withCtx(l1.bound, l2.bound)) + case (_: UnboundSymbol, _: UnboundSymbol) => None + } + + case (Constant(v1), Constant(v2)) if v1 == v2 => for { + eTyp <- xtor.typ extract (xtee.typ, Covariant) + m <- merge(eTyp, es.ex) + } yield es withExtract m - case (Constant(v1), Constant(v2)) if v1 == v2 => - xtor.typ extract (xtee.typ, Covariant) // Assuming if they have the same name the type is the same - case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Some(EmptyExtract) + case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Some(es) // Assuming if they have the same name and prefix the type is the same - case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithCtx(p1, p2) + case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - case (NewObject(t1), NewObject(t2)) => t1 extract (t2, Covariant) + case (NewObject(t1), NewObject(t2)) => for { + eTyp <- t1 extract (t2, Covariant) + m <- merge(eTyp, es.ex) + } yield es withExtract m case _ => None } - } + } alsoApply (res => println(s"Extract: $res")) protected def spliceExtract(xtor: Rep, args: Args): Option[Extract] = xtor match { // Should check that type matches, but don't see how to access it for Args @@ -409,283 +640,75 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = { - def rewriteRep0(ex: Extract, matchedBVs: Set[BoundVal])(xtor: Rep, xtee: Rep)(implicit ctx: ListMap[BoundVal, Set[BoundVal]]): Option[Rep] = { - def checkRefs(r: Rep): Option[Rep] = { - def refs(r: Rep): Set[BoundVal] = { - def bvsUsed(value: Def): Set[BoundVal] = value match { - case ma: MethodApp => - def bvsInArgss(argss: ArgumentLists): Set[BoundVal] = { - def bvsInArgss0(argss: ArgumentLists, acc: Set[BoundVal]): Set[BoundVal] = argss match { - case ArgumentListCons(h, t) => bvsInArgss0(t, bvsInArgss0(h, acc)) - case ArgumentCons(h, t) => bvsInArgss0(t, bvsInArgss0(h, acc)) - case SplicedArgument(bv: BoundVal) => acc + bv - case bv: BoundVal => acc + bv - case _ => acc - } - - bvsInArgss0(argss, Set.empty) - } - - val selfBV = ma.self match { - case bv: BoundVal => Set(bv) - case _ => Set.empty - } - - selfBV ++ bvsInArgss(ma.argss) - - case l: Lambda => - val bodyBV: Set[BoundVal] = l.body match { - case bv: BoundVal => Set(bv) - case _ => Set.empty - } - - bodyBV + l.bound - } - + def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): Option[State] = { + println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") - r match { - case lb: LetBinding => bvsUsed(lb.value) ++ refs(lb.body) - case bv: BoundVal => Set(bv) - case _ => Set.empty - } + (xtor, xtee) match { + case (lb1: LetBinding, lb2: LetBinding) => ((effect(lb1), es.mks.xtorMarker(lb1.bound)), (effect(lb2), es.mks.xteeMarker(lb2.bound))) match { + case ((Pure, NonEndPoint), (Pure, NonEndPoint)) => None + case _ => extractWithState(lb1, lb2) // TODO With unreachable handling } - - if ((refs(r) & matchedBVs).isEmpty) Some(r) else None + case _ => extractWithState(xtor, xtee) } - - def mkCode(e: Extract): Option[Rep] = { - for { - c <- code(e) -// c <- checkRefs(c) - } yield c - } - - // TODO function name? - def traverseO[A](r: Rep)(f: LetBinding => Option[A]): Option[A] = r match { - case lb: LetBinding => f(lb) orElse traverseO(lb.body)(f) - case _ => None + } + + def genCode(es: State): Option[Rep] = { + def check(matchedBVs: Set[BoundVal])(r: Rep): Boolean = r match { + case lb: LetBinding => checkDef(matchedBVs)(lb.value) + case bv: BoundVal => !(matchedBVs contains bv) + case _ => true } - - def extractValueWithBV(v: Def, r: Rep): Option[Extract -> Set[BoundVal]] = traverseO(r) { lb => - def getBVs(e: Extract): Set[BoundVal] = { - val a = e._1.values.foldLeft(Set.empty[BoundVal]) { case (acc, r) => - r match { - case lb: LetBinding => acc + lb.bound - case _ => acc - } - } - - val b = e._3.values.flatten.foldLeft(Set.empty[BoundVal]) { case (acc, r) => - r match { - case lb: LetBinding => acc + lb.bound - case _ => acc - } - } - - a ++ b + + def checkDef(matchedBVs: Set[BoundVal])(d: Def): Boolean = d match { + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft(true) { + case (checks, bv: BoundVal) => checks && !(matchedBVs contains bv) + case (checks, lb: LetBinding) => checks && check(matchedBVs)(lb) + case (checks, _) => true } - - extractValue(v, lb.value) map (e => e -> (getBVs(e) + lb.bound)) - } - - def removeBVs(r: Rep, bvs: Set[BoundVal]): Rep = r match { - case lb: LetBinding if bvs contains lb.bound => removeBVs(lb.body, bvs) - case lb: LetBinding => - lb.body = removeBVs(lb.body, bvs) - lb - case _ => r - } - - (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) => - for { - (e, newBVs) <- extractValueWithBV(lb1.value, lb2) - m <- merge(e, ex) - removed = removeBVs(lb2, newBVs) - r <- rewriteRep0(m, matchedBVs ++ newBVs)(lb1.body, removed)(ctx + (lb1.bound -> newBVs)) - } yield r - - // Match return of `xtor` with something in `xtee` - case (bv: BoundVal, _: LetBinding) => - for { - x <- ctx.get(bv) - newX <- mkCode(ex) - _ = assert(x.size == 1) - xHead = x.head - newLB = newX match { - case _: LetBinding => letin(xHead, newX, xtee, xtee.typ) - case _ => bottomUpPartial(xtee) { case `xHead` => newX } - } - } yield newLB - - case (h: HOPHole, lb: LetBinding) => - for { - e <- extractWithCtx(h, lb) - m <- merge(e, ex) - newX <- mkCode(m) - } yield newX - - case (h2: HOPHole2, lb: LetBinding) => - for { - e <- extractWithCtx(h2, lb) - m <- merge(e, ex) - newX <- mkCode(m) - } yield newX - - case (h: Hole, lb: LetBinding) => - for { - e <- extractWithCtx(h, lb) - m <- merge(e, ex) - newX <- mkCode(m) - // _ = lb.body = newX - } yield newX - - // Match Constant(42) with `value` of the `LetBinding` - case (xtor: Rep, lb: LetBinding) => - type LazyInlined[R] = R -> List[LetBinding] - implicit def lazyInlined[R](r: R): LazyInlined[R] = r -> List.empty - def run(lr: LazyInlined[Rep]): Rep = lr match { - case r -> lbs => lbs.foldLeft(r) { (acc, outerLB) => outerLB.body = acc; outerLB } - } - - // Puts `acc` inside `outer`. - def surroundWithLB(acc: LazyInlined[Rep], outer: LazyInlined[Rep]): LazyInlined[Rep] = outer match { - case (outerLB: LetBinding, lbs) => - outerLB.body = run(acc) - outerLB -> lbs - case _ => throw new IllegalArgumentException - } - - - def rewrite(lb: LetBinding): Option[Rep] = { - def rewriteValue(lb: LetBinding): Option[LazyInlined[Rep]] = { - val rewrittenValue = { - def rewriteArgssReps(argss: ArgumentLists): Option[LazyInlined[ArgumentLists]] = { - def rewriteArgsReps(args: ArgumentList): Option[LazyInlined[ArgumentList]] = args match { - case ArgumentCons(h, t) => for { - (h, outerLBs) <- rewriteRep(h) - (t, innerLBs) <- rewriteArgsReps(t) - } yield ArgumentCons(h, t) -> (innerLBs ::: outerLBs) - - case SplicedArgument(a) => rewriteRep(a) map { case arg -> lbs => SplicedArgument(arg) -> lbs } - case r: Rep => rewriteRep(r) - case NoArguments => Some(NoArguments) - } - - argss match { - case ArgumentListCons(h, t) => for { - (h, outerLBs) <- rewriteArgsReps(h) - (t, innerLBs) <- rewriteArgssReps(t) - } yield ArgumentListCons(h, t) -> (innerLBs ::: outerLBs) - - case args: ArgumentList => rewriteArgsReps(args) - case NoArgumentLists => Some(NoArgumentLists) - } - } - - def rewriteRep(xtee: Rep) = { - def codeWithOuter(e: Extract) = { - def toLazy(r: Rep): LazyInlined[Rep] = { - def go(r: Rep, acc: List[LetBinding]): LazyInlined[Rep] = r match { - case lb: LetBinding => go(lb.body, lb :: acc) - case _ => r -> acc - } - - go(r, List.empty) - } - - mkCode(e) map toLazy - } - - extractWithCtx(xtor, xtee).fold(Option(lazyInlined(xtee)))(codeWithOuter) - } - - lb.value match { - case ma: MethodApp => for { - (self, outerLBs) <- rewriteRep(ma.self) - (argss, innerLBs) <- rewriteArgssReps(ma.argss) - } yield MethodApp(self, ma.mtd, ma.targs, argss, ma.typ) -> (innerLBs ::: outerLBs) - - case l: Lambda => rewriteRep(l.body) map { - case body -> lbs => - l.body = body - l -> lbs - } - } - } - - // Puts the value back into its LB - rewrittenValue map { - case value -> lbs => - lb.value = value - lb -> lbs - } - } - - def rewriteValue0(lb: LetBinding, acc: Option[List[LazyInlined[Rep]]]): Option[List[LazyInlined[Rep]]] = lb.body match { - case innerLB: LetBinding => - rewriteValue0(innerLB, - for { - acc <- acc - lb <- rewriteValue(lb) - } yield lb :: acc) - - case ret => - for { - acc <- acc - lb <- rewriteValue(lb) - } yield lazyInlined(ret) :: lb :: acc - } - - rewriteValue0(lb, Some(List.empty)) map { - case ret :: outerLBs => run(outerLBs.foldLeft(ret)(surroundWithLB)) - case _ => lb - } - } - - rewrite(lb) - - case (_: Rep, _: Rep) => - for { - e <- extractWithCtx(xtor, xtee) - m <- merge(e, ex) - c <- mkCode(m) - } yield c - - // Matching return values - case _ => None + case l: Lambda => !(matchedBVs contains l.bound) && check(matchedBVs)(l.body) + case _ => true } + + for { + code <- code(es.ex) + if check(es.matchedBVs)(code) + } yield code } - - rewriteRep0(EmptyExtract, Set.empty)(xtor, xtee)(ListMap.empty) + + rewriteRepWithState(xtor, xtee)(State.forRewriting(xtor, xtee)) flatMap genCode alsoApply (c => println(s"Code: $c")) } - - def extractValue(v1: Def, v2: Def)(implicit ctx: ListMap[BoundVal, Set[BoundVal]]) = { - //println(s"$v1\n$v2 with $ctx \n\n") + + def extractDefs(v1: Def, v2: Def)(implicit es: State): Option[State] = { + println(s"VALUES: \n\t$v1\n\t$v2 with $es \n\n") (v1, v2) match { + // Has already been matched... + case (_, Unreachable) => None + //case (Unreachable, _) => Some(es) + case (l1: Lambda, l2: Lambda) => for { e1 <- l1.boundType extract (l2.boundType, Covariant) - e2 <- extractWithCtx(l1.body, l2.body)(ctx + (l1.bound -> Set(l2.bound))) - m <- merge(e1, e2) - } yield m + m1 <- merge(e1, es.ex) + es2 <- extractWithState(l1.body, l2.body)(es.withExtract(m1).withCtx(l1.bound -> l2.bound)) + m2 <- merge(es2.ex, m1) + } yield es2 withExtract m2 case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => lazy val targExtract = mergeAll(for { (e1, e2) <- ma1.targs zip ma2.targs } yield e1 extract (e2, Invariant)) // TODO Invariant? Depends on its positions... - def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists): Option[Extract] = { - def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep]): Option[Extract] = (argss1, argss2) match { + def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): Option[Extract] = { + def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): Option[Extract] = (argss1, argss2) match { case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => mergeOpt(extractArgss0(h1, h2, acc), extractArgss0(t1, t2, acc)) case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => mergeOpt(extractArgss0(h1, h2, acc), extractArgss0(t1, t2, acc)) - case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithCtx(arg1, arg2) + case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2)(es) map (_.ex) case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) case (SplicedArgument(arg), NoArguments) => spliceExtract(arg, Args(acc.reverse: _*)) // Reverses list... - case (r1: Rep, r2: Rep) => extractWithCtx(r1, r2) - case (NoArguments, NoArguments) => Some(EmptyExtract) - case (NoArgumentLists, NoArgumentLists) => Some(EmptyExtract) + case (r1: Rep, r2: Rep) => extractWithState(r1, r2)(es) map (_.ex) + case (NoArguments, NoArguments) => Some(es.ex) + case (NoArgumentLists, NoArgumentLists) => Some(es.ex) case _ => None } @@ -693,17 +716,31 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { } for { - e1 <- extractWithCtx(ma1.self, ma2.self) + es1 <- extractWithState(ma1.self, ma2.self)(es) e2 <- targExtract - e3 <- extractArgss(ma1.argss, ma2.argss) + e3 <- extractArgss(ma1.argss, ma2.argss)(es1) e4 <- ma1.typ extract (ma2.typ, Covariant) - m <- mergeAll(e1, e2, e3, e4) - } yield m + m <- mergeAll(e2, e3, e4) + } yield es withExtract m + + case (DefHole(Hole(name, typ)), _) => + for { + e <- typ extract (v2.typ, Covariant) + m <- merge(e, repExtract(name -> wrapConstruct(letbind(v2)))) //wrapconstr + } yield es withExtract m case _ => None } } + def cleanup(r: Rep): Rep = r match { + case lb: LetBinding if lb.value == Unreachable => cleanup(lb.body) + case lb: LetBinding => + lb.body = cleanup(lb.body) + lb + case _ => r + } + // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * def hole(name: String, typ: TypeRep) = Hole(name, typ) @@ -712,14 +749,16 @@ class FastANF extends InspectableBase with CurryEncoding with ScalaCore { def hopHole(name: String, typ: TypeRep, yes: List[List[BoundVal]], no: List[BoundVal]) = HOPHole(name, typ, yes, no) override def hopHole2(name: String, typ: TypeRep, args: List[List[Rep]], visible: List[BoundVal]) = HOPHole2(name, typ, args, visible filterNot (args.flatten contains _)) - def substitute(r: => Rep, defs: Map[String, Rep]): Rep = - if (defs isEmpty) r |> inlineBlock + def substitute(r: => Rep, defs: Map[String, Rep]): Rep = { + println(s"Subs: $r with $defs") + if (defs isEmpty) r //|> inlineBlock // TODO works if I remove this... else bottomUp(r) { - case h @ Hole(n, _) => defs getOrElse(n, h) - case h @ SplicedHole(n, _) => defs getOrElse(n, h) - case h: BoundVal => defs getOrElse(h.name, h) // TODO FVs in lambda become BVs too early, this should be changed!! + case h@Hole(n, _) => defs getOrElse(n, h) + case h@SplicedHole(n, _) => defs getOrElse(n, h) + //case h: BoundVal => defs getOrElse(h.name, h) // TODO FVs in lambda become BVs too early, this should be changed!! case h => h } |> inlineBlock + } // * --- * --- * --- * Implementations of `TypingBase` methods * --- * --- * --- * diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index 9bac0b7e..160957fd 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -40,11 +40,14 @@ sealed abstract class Def extends DefOption with DefOrTypeRep with FlatSom[Def] def map(f: Def => Def): Def = f(this) } +case object Unreachable extends Def { + override def typ = ??? +} + /** Expression that can be used as an argument or result; this includes let bindings. */ sealed abstract class Rep extends RepOption with ArgumentList with FlatSom[Rep] { def typ: TypeRep def * = SplicedArgument(this) - def isPure: Bool = true def argssMap(f: Rep => Rep) = f(this) def argssList: List[Rep] = this :: Nil } @@ -115,6 +118,10 @@ final case class App(fun: Rep, arg: Rep)(implicit base: FastANF) extends Def wit doChecks } +final case class DefHole(hole: Hole) extends Def { + def typ: TypeRep = hole.typ +} + /** To avoid useless wrappers/boxing in the most common cases, we have this scheme for argument lists: * Scala example | syntax | representation * ------------------------------------------------------------ diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index a04e4f82..ab7123d7 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -51,29 +51,43 @@ class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { } assert(c =~= ir"val a = 42; a * 2") } - - test("Rewriting simple expressions only once") { - val a = ir"println((50, 60))" rewrite { - case ir"($x:Int,$y:Int)" => ir"($y:Int,$x:Int)" - case ir"(${Const(n)}:Int)" => Const(n+1) - } - assert(a =~= ir"println((61,51))") - } - - test("Function Rewritings") { - val a = ir"(x: Int) => (x-5) * 32" rewrite { - case ir"($b: Int) * 32" => ir"$b" + + test("Rewriting with dead-ends") { + val b = ir"Option(42).get; 20" rewrite { + case ir"Option(($n: Int)).get; 20" => n } - assert(a =~= ir"(x: Int) => x - 5") - - val b = ir"(x: Int) => (x-5) * 32" rewrite { - case ir"(x: Int) => ($b: Int) * 32" => dbg_ir"val x = 42; (p: Int) => $b + p" - } alsoApply println - - println(ir"val u = 42; (v: Int) => (u - 5) + v") + assert(b =~= ir"42") + } + + //test("Rewriting with impures") { + // val a = ir"val a = readInt; val b = readInt; (a + b) * 0.5" rewrite { + // case ir"(($h1: Int) + ($h2: Int)) * 0.5" => dbg_ir"($h1 * $h2) + 42.0" + // } + // assert(a =~= ir"(readInt + readInt) + 42.0") + //} - assert(b =~= ir"val u = 42; (v: Int) => (u - 5) + v") - } + //test("Rewriting simple expressions only once") { + // val a = ir"println((50, 60))" rewrite { + // case ir"($x:Int,$y:Int)" => ir"($y:Int,$x:Int)" + // case ir"(${Const(n)}:Int)" => Const(n+1) + // } + // assert(a =~= ir"println((61,51))") + //} + // + //test("Function Rewritings") { + // val a = ir"(x: Int) => (x-5) * 32" rewrite { + // case ir"($b: Int) * 32" => ir"$b" + // } + // assert(a =~= ir"(x: Int) => x - 5") + // + // val b = ir"(x: Int) => (x-5) * 32" rewrite { + // case ir"(x: Int) => ($b: Int) * 32" => dbg_ir"val x = 42; (p: Int) => $b + p" + // } alsoApply println + // + // println(ir"val u = 42; (v: Int) => (u - 5) + v") + // + // assert(b =~= ir"val u = 42; (v: Int) => (u - 5) + v") + //} } object RewritingTests { From fa8929a699d34c30049dc61d7743ee1abff2295f Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 20 Nov 2017 17:45:55 +0100 Subject: [PATCH 09/66] Fix HOPHole, defhole, and some cleaning --- src/main/scala/squid/ir/fastanf/Effects.scala | 3 +- src/main/scala/squid/ir/fastanf/FastANF.scala | 337 ++++++++++-------- 2 files changed, 196 insertions(+), 144 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index 61502827..4256f858 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -47,7 +47,7 @@ trait Effects { val argssEff = ma.argss.argssList.map(effect).fold(Pure)(_ |+| _) val mtdEff = mtdEffect(ma.mtd) selfEff |+| argssEff |+| mtdEff - case DefHole(_) => Pure + case DefHole(_) => Impure case Unreachable => Pure } @@ -72,4 +72,5 @@ trait StandardEffects extends Effects { addPureMtd(MethodSymbol(TypeSymbol("scala.Option$"), "apply")) addPureMtd(MethodSymbol(TypeSymbol("scala.Option"), "get")) addPureMtd(MethodSymbol(TypeSymbol("scala.Tuple2$"), "apply")) + addPureMtd(MethodSymbol(TypeSymbol("squid.lib.package$"),"uncurried2")) } diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index e1dd781d..380898c2 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -129,9 +129,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi inlineBlock(lb.body) } - override def letin(bound: BoundVal, value: Rep, body: => Rep, bodyType: TypeRep): Rep = value match { - case s: Symbol => - withSubs(bound, value)(body) + override def letin(bound: BoundVal, value: Rep, body: => Rep, bodyType: TypeRep): Rep = { + println(s"letin: $value --> $bound") + value match { + case s: Symbol => + withSubs(bound, value)(body) //s.owner |>? { // case lb: RebindableBinding => @@ -139,41 +141,41 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi //} //bound rebind s //body - case lb: LetBinding => - // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding - val last = lb.last - val boundName = bound.name - bound rebind last.bound - last.body = body - last.name = boundName // TODO make sure we're only renaming an automatically-named binding? - lb - // case c: Constant => bottomUpPartial(body) { case `bound` => c } - case h: Hole => - //Wrap construct? How? - //new LetBinding(bound.name, bound, DefHole(h), body) - - // letin(x, Hole, Constant(20)) => `val tmp = defHole; 20;` - - val dh = DefHole(h) |> letbind // flag - - //(dh |>? { - // case bv: BoundVal => bv.owner |>? { - // case lb: LetBinding => - // lb.body = body - // lb - // } - //}).flatten.getOrElse(body) - - //new LetBinding(bound.name, bound, dh, body) alsoApply (currentScope += _) alsoApply (bound.rebind) - withSubs(bound -> dh)(body) - - - case (_:HOPHole) | (_:HOPHole2) | (_:SplicedHole) => - ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes - case _ => - withSubs(bound -> value)(body) + case lb: LetBinding => + // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding + val last = lb.last + val boundName = bound.name + bound rebind last.bound + last.body = body + last.name = boundName // TODO make sure we're only renaming an automatically-named binding? + lb + // case c: Constant => bottomUpPartial(body) { case `bound` => c } + case h: Hole => + //Wrap construct? How? + + // letin(x, Hole, Constant(20)) => `val tmp = defHole; 20;` + + val dh = wrapConstruct(new LetBinding(bound.name, bound, DefHole(h), bound) alsoApply bound.rebind) // flag wrapConstruct? + + //(dh |>? { + // case bv: BoundVal => bv.owner |>? { + // case lb: LetBinding => + // lb.body = body + // lb + // } + //}).flatten.getOrElse(body) + + //new LetBinding(bound.name, bound, dh, body) alsoApply (currentScope += _) alsoApply (bound.rebind) + withSubs(bound -> dh)(body) + + + case (_:HOPHole) | (_:HOPHole2) | (_:SplicedHole) => + ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes + case _ => + withSubs(bound -> value)(body) // ^ executing `body` will reify some statements into the reification scope, and likely return a symbol // during this reification, we need all references to `bound` to be replaced by the actual `value` + } } var curSub: Map[Symbol,Rep] = Map.empty @@ -187,7 +189,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(s"tryInline $fun -- $arg") fun match { case lb: LetBinding => lb.value match { - case l: Lambda => letin(l.bound, arg, l.body, l.body.typ) // flag + case l: Lambda => letin(l.bound, arg, l.body, l.body.typ) case _ => super.tryInline(fun, arg)(retTp) } case _ => super.tryInline(fun, arg)(retTp) @@ -314,21 +316,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def transformRep(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity): Rep = transformRepAndDef(r)(pre, post)(identity, identity) - def firstHole(d: Def)(implicit es: State): Option[Hole] = (for { - DefHole(h @ Hole(name, _)) <- holes(d) - if !(es.ex._1 contains name) && !(es.ex._3 contains name) - } yield h).headOption - - def holes(d: Def): List[DefHole] = d match { - case l: Lambda => Nil // TODO handle - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft(List.empty[DefHole]) { - case (acc, h: Hole) => DefHole(h) :: acc - case (acc, _) => acc - }.reverse - case dh: DefHole => List(dh) - case Unreachable => Nil - } - protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract($xtor, $xtee)") for { @@ -343,12 +330,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedBVs: Set[BoundVal], makeUnreachable: Boolean) { - def withExtract(newEx: Extract): State = copy(ex = newEx) + def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = updateWith(ctx)(p)) def updateMarkers(newMks: Markers): State = copy(mks = newMks) def withoutMarkers(xtorMk: BoundVal, xteeMk: BoundVal): State = copy(mks = Markers(mks.xtor - xtorMk, mks.xtee - xteeMk)) - def withMatchedBV(bv: BoundVal): State = copy(matchedBVs = matchedBVs + bv) + def withMatched(bv: BoundVal): State = copy(matchedBVs = matchedBVs + bv) } object State { def forRewriting(xtor: Rep, xtee: Rep): State = State(xtor, xtee, true) @@ -447,6 +434,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } yield b -> a def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): Option[State] = { + println("EXTRACTINGHOPHOLE") type Func = List[List[BoundVal]] -> Rep def emptyFunc(r: Rep) = List.empty[List[BoundVal]] -> r def fargss(f: Func) = f._1 @@ -455,82 +443,141 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val ctx0 = es.ctx.mapValues(_.head) val invCtx = reverse(es.ctx) - bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx.getOrElse(bv, bv) => return None } + println(s"INVCTX: $invCtx") + //bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx.getOrElse(bv, bv) => return None } + + def changeRepBVs(r: Rep)(f: BoundVal => Option[BoundVal]): Rep = r match { + case bv: BoundVal => + f(bv) match { + case Some(bv0) => bv0 + case None => bv.owner match { + case lb: LetBinding => + lb.value = changeDefBVs(lb.value)(f) + lb + case _ => r + } + } + case lb: LetBinding => + lb.value = changeDefBVs(lb.value)(f) + lb + + case _ => r + } + def changeDefBVs(d: Def)(f: BoundVal => Option[BoundVal]): Def = d match { + case ma: MethodApp => + MethodApp( + changeRepBVs(ma.self)(f), + ma.mtd, + ma.targs, + ma.argss.argssMap(r => changeRepBVs(r)(f)), + ma.typ + ) + + case l: Lambda => + new Lambda( + l.name, + l.bound, + l.boundType, + changeRepBVs(l.body)(f) + ) + + case _ => d + } - def extendFunc(args: List[Rep], f: Func): Func = { - val args0 = args.map(bottomUpPartial(_) { case bv: BoundVal => ctx0.getOrElse(bv, bv) }) + def extendFunc(args: List[Rep], f: Option[Func]): Option[Func] = { + println(s"ARGS: $args") + val args0 = args.map { + case bv: BoundVal => bv.owner match { + case lb: LetBinding => + lb.body = lb.bound // Cut out the HOPHole + changeRepBVs(bv)(ctx0.get) + } + case arg => changeRepBVs(arg)(ctx0.get) + } // Traverse bottom up (parent, etc) + + println(s"ARGS0: $args0") + val xs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) - val transformation = (args0 zip xs).toMap - val body0 = bottomUp(fbody(f))(r => transformation.getOrElse(r, r)) - (xs :: fargss(f)) -> body0 + + val transformations = (args0 zip xs).toMap + + println(s"Transformation: $transformations") + + val after = for { + f <- f + _ = println(s"BEFORE $f") + body0 <- transformations.foldLeft(Option(fbody(f))) { + case (Some(body), (xtor, res)) => rewriteRep0(xtor, body, _ => Some(res))(State.forRewriting(xtor, body), true) + case _ => None + } + } yield (xs :: fargss(f)) -> body0 + + println(s"AFTER: $after") + + after } for { e1 <- typ extract (xtee.typ, Covariant) - f = argss.foldRight(emptyFunc(xtee))(extendFunc) + f <- argss.foldRight(Option(emptyFunc(xtee)))(extendFunc) + _ = println(s"F: $f") l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } e2 = repExtract(name -> l) - m <- merge(e1, e2) - } yield es withExtract m + m <- mergeAll(e1, e2, es.ex) + } yield es withNewExtract m } def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): Option[State] = (effect(lb1.value), effect(lb2.value)) match { case (Impure, Impure) => for { es1 <- extractDefs(lb1.value, lb2.value) _ = if (es1.makeUnreachable) lb2.value = Unreachable - c <- extractWithState(lb1.body, lb2.body)(es1 withCtx (lb1.bound, lb2.bound) withMatchedBV lb2.bound) + c <- extractWithState(lb1.body, lb2.body)(es1 withCtx (lb1.bound, lb2.bound) withMatched lb2.bound) } yield c - - case (Impure, Pure) => extractWithState(lb1, lb2.body) - - case (Pure, Impure) => - firstHole(lb1.value).fold(extractWithState(lb1.body, lb2)) { - case Hole(name, typ) => for { - e <- typ extract(lb2.value.typ, Covariant) - lb = new LetBinding(lb2.name, lb2.bound, lb2.value, lb2.bound)// alsoApply (currentScope += _) //alsoApply (lb2.bound.rebind) - m <- mergeAll(es.ex, e, repExtract(name ->lb)) - //_ = if (es.makeUnreachable) lb2.value = Unreachable - es2 <- extractWithState(lb1, lb2.body)(es withExtract m withMatchedBV lb2.bound) - } yield es2 - } - + case (Pure, Pure) => (es.mks.xtorMarker(lb1.bound), es.mks.xteeMarker(lb2.bound)) match { case (EndPoint, EndPoint) => extractWithState(lb1.bound, lb2.bound) match { case Some(es0) => if (es0.makeUnreachable) lb2.value = Unreachable - extractWithState(lb1.body, lb2.body)(es0 withoutMarkers(lb1.bound, lb2.bound) withMatchedBV lb2.bound) + extractWithState(lb1.body, lb2.body)(es0 withoutMarkers(lb1.bound, lb2.bound) withMatched lb2.bound) case None => extractWithState(lb1, lb2.body) } case (EndPoint, NonEndPoint) => extractWithState(lb1, lb2.body) case (NonEndPoint, EndPoint) => extractWithState(lb1.body, lb2) - case (NonEndPoint, NonEndPoint) => extractWithState(lb1.body, lb2.body) + case (NonEndPoint, NonEndPoint) => extractWithState(lb1.body, lb2) } - } - def extractHole(h: Hole, r: Rep)(implicit es: State): Option[State] = (h, r) match { - case (Hole(n, t), bv: BoundVal) => - val r = bv.owner match { - case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag - case _ => bv - } + case (Impure, Pure) => extractWithState(lb1, lb2.body) + case (Pure, Impure) => extractWithState(lb1.body, lb2) + } - for { + def extractHole(h: Hole, r: Rep)(implicit es: State): Option[State] = { + println(s"ExtractHole: $h --> $r") + + (h, r) match { + case (Hole(n, t), bv: BoundVal) => + //val r = bv.owner match { + // case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag + // case _ => bv + //} + + for { + e <- t extract (xtee.typ, Covariant) + m <- mergeAll(e, es.ex, repExtract(n -> bv)) + } yield es withNewExtract m withMatched bv + + case (Hole(n, t), lb: LetBinding) => for { + e <- t extract (lb.typ, Covariant) + newLB = wrapConstruct(letbind(lb.value))// flag + _ = if (es.makeUnreachable) lb.value = Unreachable + m <- mergeAll(e, es.ex, repExtract(n -> newLB)) + } yield es withNewExtract m withMatched lb.bound + + case (Hole(n, t), _) => for { e <- t extract (xtee.typ, Covariant) - m <- mergeAll(e, es.ex, repExtract(n -> r)) - } yield es withExtract m withMatchedBV bv - - case (Hole(n, t), lb: LetBinding) => for { - e <- t extract (lb.typ, Covariant) - newLB = new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag - _ = if (es.makeUnreachable) lb.value = Unreachable - m <- mergeAll(e, es.ex, repExtract(n -> newLB)) - } yield es withExtract m withMatchedBV lb.bound - - case (Hole(n, t), _) => for { - e <- t extract (xtee.typ, Covariant) - m <- mergeAll(e, es.ex, repExtract(n -> xtee)) - } yield es withExtract m + m <- mergeAll(e, es.ex, repExtract(n -> xtee)) + } yield es withNewExtract m + } } def extractInside(bv: BoundVal, d: Def)(implicit es: State): Option[State] = { @@ -547,18 +594,33 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } alsoApply(s => println(s"FOO: $s")) } - def filledWith(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3 + def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] println(s"extractWithState: $xtor\n$xtee\n") xtor -> xtee match { + case (h: Hole, lb: LetBinding) => contentsOf(h) match { + case Some(lb1: LetBinding) if lb1.value == lb.value => Some(es) + case Some(_) => None + case None => extractHole(h, xtee) + } + + case (h: Hole, _) => contentsOf(h) match { + case Some(`xtee`) => Some(es) + case Some(_) => None + case None => extractHole(h, xtee) + } + + case (h@HOPHole2(name, typ, argss, visible), _) => + //println(s"Hole $h -> $xtee\n\n") + extractHOPHole(name, typ, argss, visible) + case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) // Stop at markers? case (lb: LetBinding, _: Rep) => extractWithState(lb.body, xtee) - // TODO really need the pure? - case (bv: BoundVal, lb: LetBinding) if isPure(xtor) => - extractWithState(bv, lb.bound) orElse extractInside(bv, lb.value) orElse extractWithState(bv, lb.body) + case (bv: BoundVal, lb: LetBinding) => + extractWithState(bv, lb.bound) orElse extractInside(bv, lb.value) orElse extractWithState(bv, lb.body) case (_: Rep, lb: LetBinding) if lb.value == Unreachable => extractWithState(xtor, lb.body) @@ -568,23 +630,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi for { e1 <- t extract (xtee.typ, Covariant) m <- merge(e1, es.ex) - es2 <- extractWithState(s, xtee)(es withExtract m) + es2 <- extractWithState(s, xtee)(es withNewExtract m) } yield es2 - case (h: Hole, _) => - filledWith(h) match { - case Some(_) => Some(es) // TODO check if the hole contains what we are trying to extract - //case Some(lb: LetBinding) if xtee == lb.bound => Some(es) - //case Some(r) if xtee == r => Some(es) - //case Some(_) => None - case None => extractHole(h, xtee) - } - case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) - - case (h@HOPHole2(name, typ, argss, visible), _) => - //println(s"Hole $h -> $xtee\n\n") - extractHOPHole(name, typ, argss, visible) case (bv1: BoundVal, bv2: BoundVal) => println(s"EXTRACTIONSTATE IN BV: $es") @@ -593,7 +642,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) map { es => if (es.makeUnreachable) lb2.value = Unreachable - es withCtx (lb1.bound, lb2.bound) withMatchedBV lb2.bound + es withCtx (lb1.bound, lb2.bound) withMatched lb2.bound } case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_.withCtx(l1.bound, l2.bound)) case (_: UnboundSymbol, _: UnboundSymbol) => None @@ -602,7 +651,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (Constant(v1), Constant(v2)) if v1 == v2 => for { eTyp <- xtor.typ extract (xtee.typ, Covariant) m <- merge(eTyp, es.ex) - } yield es withExtract m + } yield es withNewExtract m // Assuming if they have the same name the type is the same @@ -614,7 +663,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (NewObject(t1), NewObject(t2)) => for { eTyp <- t1 extract (t2, Covariant) m <- merge(eTyp, es.ex) - } yield es withExtract m + } yield es withNewExtract m case _ => None } @@ -639,14 +688,17 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => throw IRException(s"Trying to splice-extract with invalid extractor $xtor") } - override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = { + override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = + rewriteRep0(xtor, xtee, code)(State.forRewriting(xtor, xtee), false) + + def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(es: State, internalRec: Boolean): Option[Rep] = { def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): Option[State] = { println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) => ((effect(lb1), es.mks.xtorMarker(lb1.bound)), (effect(lb2), es.mks.xteeMarker(lb2.bound))) match { + case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1), es.mks.xtorMarker(lb1.bound)), (effect(lb2), es.mks.xteeMarker(lb2.bound))) match { case ((Pure, NonEndPoint), (Pure, NonEndPoint)) => None - case _ => extractWithState(lb1, lb2) // TODO With unreachable handling + case _ => extractWithState(lb1, lb2) // TODO With unreachable handling TODO why did I mean? } case _ => extractWithState(xtor, xtee) } @@ -671,11 +723,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi for { code <- code(es.ex) - if check(es.matchedBVs)(code) + if check(es.matchedBVs)(cleanup(code)) } yield code } - - rewriteRepWithState(xtor, xtee)(State.forRewriting(xtor, xtee)) flatMap genCode alsoApply (c => println(s"Code: $c")) + + rewriteRepWithState(xtor, xtee)(es) flatMap genCode alsoApply (c => println(s"Code: $c")) } def extractDefs(v1: Def, v2: Def)(implicit es: State): Option[State] = { @@ -689,9 +741,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi for { e1 <- l1.boundType extract (l2.boundType, Covariant) m1 <- merge(e1, es.ex) - es2 <- extractWithState(l1.body, l2.body)(es.withExtract(m1).withCtx(l1.bound -> l2.bound)) + es2 <- extractWithState(l1.body, l2.body)(es withNewExtract m1 withCtx l1.bound -> l2.bound) m2 <- merge(es2.ex, m1) - } yield es2 withExtract m2 + } yield es2 withNewExtract m2 case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => lazy val targExtract = mergeAll(for { @@ -721,24 +773,23 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi e3 <- extractArgss(ma1.argss, ma2.argss)(es1) e4 <- ma1.typ extract (ma2.typ, Covariant) m <- mergeAll(e2, e3, e4) - } yield es withExtract m + } yield es withNewExtract m - case (DefHole(Hole(name, typ)), _) => - for { - e <- typ extract (v2.typ, Covariant) - m <- merge(e, repExtract(name -> wrapConstruct(letbind(v2)))) //wrapconstr - } yield es withExtract m + case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2))) case _ => None } } - def cleanup(r: Rep): Rep = r match { - case lb: LetBinding if lb.value == Unreachable => cleanup(lb.body) - case lb: LetBinding => - lb.body = cleanup(lb.body) - lb - case _ => r + def cleanup(r: Rep): Rep = { + println("CLEANING!!") + r match { + case lb: LetBinding if lb.value == Unreachable => cleanup(lb.body) + case lb: LetBinding => + lb.body = cleanup(lb.body) + lb + case _ => r + } } // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * From 7b7dea574cd1d5c7edcbe0db8965f1024c32c8d5 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 20 Nov 2017 18:25:09 +0100 Subject: [PATCH 10/66] Handle when HOPHole arg is simply a BV --- src/main/scala/squid/ir/fastanf/FastANF.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 380898c2..139045c3 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -491,6 +491,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case lb: LetBinding => lb.body = lb.bound // Cut out the HOPHole changeRepBVs(bv)(ctx0.get) + case _ => ctx0.getOrElse(bv, bv) } case arg => changeRepBVs(arg)(ctx0.get) } // Traverse bottom up (parent, etc) @@ -507,6 +508,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi f <- f _ = println(s"BEFORE $f") body0 <- transformations.foldLeft(Option(fbody(f))) { + case (Some(body), (bv: BoundVal, res)) => Some(changeRepBVs(body)(bv => Some(res))) case (Some(body), (xtor, res)) => rewriteRep0(xtor, body, _ => Some(res))(State.forRewriting(xtor, body), true) case _ => None } @@ -645,7 +647,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es withCtx (lb1.bound, lb2.bound) withMatched lb2.bound } case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_.withCtx(l1.bound, l2.bound)) - case (_: UnboundSymbol, _: UnboundSymbol) => None + //case (_: UnboundSymbol, _: UnboundSymbol) => None + case _ => None } case (Constant(v1), Constant(v2)) if v1 == v2 => for { From aff2a998a98598b2f24eeeb5b782fe9addaa5ea3 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 21 Nov 2017 14:35:36 +0100 Subject: [PATCH 11/66] Don't try to extract again failed extractions --- src/main/scala/squid/ir/fastanf/FastANF.scala | 307 +++++++++--------- 1 file changed, 159 insertions(+), 148 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 139045c3..ec7739bc 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -319,30 +319,32 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract($xtor, $xtee)") for { - es <- extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)) + es <- extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) if es.mks.xtor.isEmpty && es.mks.xtee.isEmpty } yield es.ex } type Ctx = Map[BoundVal, Set[BoundVal]] - def updateWith(ctx: Ctx)(u: (BoundVal, BoundVal)): Ctx = u match { - case (k, v) => ctx + (k -> (ctx(k) + v)) + def updateWith(m: Map[BoundVal, Set[BoundVal]])(u: (BoundVal, BoundVal)): Map[BoundVal, Set[BoundVal]] = u match { + case (k, v) => m + (k -> (m(k) + v)) } - case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedBVs: Set[BoundVal], makeUnreachable: Boolean) { + type ExtractState = Either[State, State] + case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], makeUnreachable: Boolean) { def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = updateWith(ctx)(p)) def updateMarkers(newMks: Markers): State = copy(mks = newMks) def withoutMarkers(xtorMk: BoundVal, xteeMk: BoundVal): State = copy(mks = Markers(mks.xtor - xtorMk, mks.xtee - xteeMk)) def withMatched(bv: BoundVal): State = copy(matchedBVs = matchedBVs + bv) + def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) } object State { def forRewriting(xtor: Rep, xtee: Rep): State = State(xtor, xtee, true) def forExtraction(xtor: Rep, xtee: Rep): State = State(xtor, xtee, false) private def apply(xtor: Rep, xtee: Rep, makeUnreachable: Bool): State = - State(EmptyExtract, ListMap.empty.withDefaultValue(Set.empty), Markers(xtor, xtee), Set.empty, makeUnreachable) + State(EmptyExtract, ListMap.empty.withDefaultValue(Set.empty), Markers(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), makeUnreachable) } sealed trait Marker @@ -379,53 +381,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi extractionStarts0(r, ListSet.empty) } - - //type Signature = List[Int] - //case class EndPoints(endPoints: ListSet[BoundVal], signatures: Map[BoundVal, Signature]) - //def findEndPoints(r: Rep): EndPoints = { - // case class Signature0(impurePos: Int, signature: List[Int]) - // case class EndPoints0(endPoints: ListSet[BoundVal], nodes: Map[BoundVal, Signature0], currImpurePos: Int) - // - // def update(acc: EndPoints0, lb: LetBinding): EndPoints0 = lb.value match { - // case ma: MethodApp => - // val (newAcc, nodes) = (ma.self :: ma.argss.argssList).foldRight((acc, List.empty[Signature0])) { - // case (bv: BoundVal, (acc, nodes)) => (acc.copy(endPoints = acc.endPoints - bv), acc.nodes(bv) :: nodes) - // case (_, acc) => acc - // } - // - // val impurePos = acc.currImpurePos - // - // val sig = for { - // Signature0(childImpurePos, childSig) <- nodes - // childSigComponent <- childSig - // sigComponent = childSigComponent + (childImpurePos - impurePos) - // } yield sigComponent - // - // val sig0 = defEffect(ma) match { - // case Pure => sig - // case Impure => 0 :: sig - // } - // - // newAcc.copy( - // endPoints = newAcc.endPoints + lb.bound, - // nodes = acc.nodes + (lb.bound -> Signature0(impurePos, sig0)), - // currImpurePos = if (isPure(lb)) newAcc.currImpurePos else newAcc.currImpurePos + 1) - // case l: Lambda => ??? - // case DefHole(_) => acc - // case Unreachable => acc - // } - // - // def endPoints0(acc: EndPoints0, r: Rep): EndPoints0 = r match { - // case lb: LetBinding => endPoints0(update(acc, lb), lb.body) - // case _ => acc - // } - // - // val res = endPoints0(EndPoints0(ListSet.empty, Map.empty, 0), r) - // EndPoints(res.endPoints, res.nodes mapValues (_.signature)) - //} } - def extractWithState(xtor: Rep, xtee: Rep)(implicit es: State): Option[State] = { + def extractWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { //println(s"$xtor\n$xtee with $ctx\n\n") def reverse[A, B](m: Map[A, Set[B]]): Map[B, A] = for { @@ -433,7 +391,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi b <- bs } yield b -> a - def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): Option[State] = { + def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { println("EXTRACTINGHOPHOLE") type Func = List[List[BoundVal]] -> Rep def emptyFunc(r: Rep) = List.empty[List[BoundVal]] -> r @@ -444,7 +402,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val invCtx = reverse(es.ctx) println(s"INVCTX: $invCtx") - //bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx.getOrElse(bv, bv) => return None } + println(s"VISIBLE: $visible") + bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx.getOrElse(bv, bv) => return Left(es) } def changeRepBVs(r: Rep)(f: BoundVal => Option[BoundVal]): Rep = r match { case bv: BoundVal => @@ -518,31 +477,38 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi after } - - for { + + (for { e1 <- typ extract (xtee.typ, Covariant) f <- argss.foldRight(Option(emptyFunc(xtee)))(extendFunc) _ = println(s"F: $f") l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } e2 = repExtract(name -> l) m <- mergeAll(e1, e2, es.ex) - } yield es withNewExtract m + } yield m) + .fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)) } - def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): Option[State] = (effect(lb1.value), effect(lb2.value)) match { - case (Impure, Impure) => for { - es1 <- extractDefs(lb1.value, lb2.value) - _ = if (es1.makeUnreachable) lb2.value = Unreachable - c <- extractWithState(lb1.body, lb2.body)(es1 withCtx (lb1.bound, lb2.bound) withMatched lb2.bound) - } yield c + def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = (effect(lb1.value), effect(lb2.value)) match { + case (Impure, Impure) => + extractDefs(lb1.value, lb2.value).right flatMap { es => + if (es.makeUnreachable) lb2.value = Unreachable + extractWithState(lb1.body, lb2.body)(es withCtx (lb1.bound, lb2.bound) withMatched lb2.bound) + } + + //for { + // es1 <- extractDefs(lb1.value, lb2.value).right + // _ = if (es1.makeUnreachable) lb2.value = Unreachable + // es2 <- extractWithState(lb1.body, lb2.body)(es1 withCtx (lb1.bound, lb2.bound) withMatched lb2.bound).right + //} yield es2 case (Pure, Pure) => (es.mks.xtorMarker(lb1.bound), es.mks.xteeMarker(lb2.bound)) match { case (EndPoint, EndPoint) => extractWithState(lb1.bound, lb2.bound) match { - case Some(es0) => - if (es0.makeUnreachable) lb2.value = Unreachable + case Right(es0) => + if (es.makeUnreachable) lb2.value = Unreachable extractWithState(lb1.body, lb2.body)(es0 withoutMarkers(lb1.bound, lb2.bound) withMatched lb2.bound) - case None => extractWithState(lb1, lb2.body) + case Left(es0) => extractWithState(lb1, lb2.body)(es0) } case (EndPoint, NonEndPoint) => extractWithState(lb1, lb2.body) case (NonEndPoint, EndPoint) => extractWithState(lb1.body, lb2) @@ -553,7 +519,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (Pure, Impure) => extractWithState(lb1.body, lb2) } - def extractHole(h: Hole, r: Rep)(implicit es: State): Option[State] = { + def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { println(s"ExtractHole: $h --> $r") (h, r) match { @@ -563,26 +529,44 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // case _ => bv //} - for { - e <- t extract (xtee.typ, Covariant) - m <- mergeAll(e, es.ex, repExtract(n -> bv)) - } yield es withNewExtract m withMatched bv - - case (Hole(n, t), lb: LetBinding) => for { - e <- t extract (lb.typ, Covariant) - newLB = wrapConstruct(letbind(lb.value))// flag - _ = if (es.makeUnreachable) lb.value = Unreachable - m <- mergeAll(e, es.ex, repExtract(n -> newLB)) - } yield es withNewExtract m withMatched lb.bound - - case (Hole(n, t), _) => for { - e <- t extract (xtee.typ, Covariant) - m <- mergeAll(e, es.ex, repExtract(n -> xtee)) - } yield es withNewExtract m + (t extract (xtee.typ, Covariant)) + .flatMap(mergeAll(_, es.ex, repExtract(n -> bv))) + .fold[ExtractState](Left(es))(m => Right(es withNewExtract m withMatched bv)) + + //for { + // e <- t extract (xtee.typ, Covariant) + // m <- mergeAll(e, es.ex, repExtract(n -> bv)) + //} yield es withNewExtract m withMatched bv + + case (Hole(n, t), lb: LetBinding) => + (t extract (lb.typ, Covariant)) + .flatMap { e => + val newLB = wrapConstruct(letbind(lb.value)) + if (es.makeUnreachable) lb.value = Unreachable + mergeAll(e, es.ex, repExtract(n -> newLB)) + } + .fold[ExtractState](Left(es))(m => Right(es withNewExtract m withMatched lb.bound)) + + // for { + // e <- t extract (lb.typ, Covariant) + // newLB = wrapConstruct(letbind(lb.value))// flag + // _ = if (es.makeUnreachable) lb.value = Unreachable + // m <- mergeAll(e, es.ex, repExtract(n -> newLB)) + //} yield es withNewExtract m withMatched lb.bound + + case (Hole(n, t), _) => + (t extract (xtee.typ, Covariant)) + .flatMap(mergeAll(_, es.ex, repExtract(n -> xtee))) + .fold[ExtractState](Left(es))(m => Right(es withNewExtract m)) + + // for { + // e <- t extract (xtee.typ, Covariant) + // m <- mergeAll(e, es.ex, repExtract(n -> xtee)) + //} yield es withNewExtract m } } - def extractInside(bv: BoundVal, d: Def)(implicit es: State): Option[State] = { + def extractInside(bv: BoundVal, d: Def)(implicit es: State): ExtractState = { def bvs(d: Def): List[BoundVal] = d match { case ma: MethodApp => (ma.self :: ma.argss.argssList).foldRight(List.empty[BoundVal]) { case (bv: BoundVal, acc) => bv :: acc @@ -591,8 +575,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => Nil } - bvs(d).foldLeft(Option.empty[State]) { case (acc, bv2) => - acc orElse extractWithState(bv, bv2)(es) + bvs(d).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => for { + _ <- acc.left + es1 <- extractWithState(bv, bv2)(es).left + } yield es1 } alsoApply(s => println(s"FOO: $s")) } @@ -601,74 +587,80 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(s"extractWithState: $xtor\n$xtee\n") xtor -> xtee match { case (h: Hole, lb: LetBinding) => contentsOf(h) match { - case Some(lb1: LetBinding) if lb1.value == lb.value => Some(es) - case Some(_) => None + case Some(lb1: LetBinding) if lb1.value == lb.value => Right(es) + case Some(_) => Left(es) case None => extractHole(h, xtee) } case (h: Hole, _) => contentsOf(h) match { - case Some(`xtee`) => Some(es) - case Some(_) => None + case Some(`xtee`) => Right(es) + case Some(_) => Left(es) case None => extractHole(h, xtee) } - + case (h@HOPHole2(name, typ, argss, visible), _) => //println(s"Hole $h -> $xtee\n\n") extractHOPHole(name, typ, argss, visible) - + case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) - + // Stop at markers? case (lb: LetBinding, _: Rep) => extractWithState(lb.body, xtee) - - case (bv: BoundVal, lb: LetBinding) => - extractWithState(bv, lb.bound) orElse extractInside(bv, lb.value) orElse extractWithState(bv, lb.body) - + + case (bv: BoundVal, lb: LetBinding) => for { + es1 <- extractWithState(bv, lb.bound).left + es2 <- extractInside(bv, lb.value)(es1).left + es3 <- extractWithState(bv, lb.body)(es2).left + } yield es3 + case (_: Rep, lb: LetBinding) if lb.value == Unreachable => extractWithState(xtor, lb.body) case (_, Ascribe(s, _)) => extractWithState(xtor, s) - - case (Ascribe(s, t) , _) => - for { - e1 <- t extract (xtee.typ, Covariant) - m <- merge(e1, es.ex) - es2 <- extractWithState(s, xtee)(es withNewExtract m) - } yield es2 + + case (Ascribe(s, t), _) => for { + es1 <- (t extract(xtee.typ, Covariant)) + .flatMap(merge(_, es.ex)) + .fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)).right + es2 <- extractWithState(s, xtee)(es1).right + } yield es2 case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) case (bv1: BoundVal, bv2: BoundVal) => println(s"EXTRACTIONSTATE IN BV: $es") println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") - if (bv1 == bv2 || (es.ctx(bv1) contains bv2)) Some(es) + if (bv1 == bv2 || (es.ctx(bv1) contains bv2)) Right(es) + else if (es.failedMatches(bv1) contains bv2) Left(es) else (bv1.owner, bv2.owner) match { - case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) map { es => - if (es.makeUnreachable) lb2.value = Unreachable - es withCtx (lb1.bound, lb2.bound) withMatched lb2.bound + case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { + case Right(es0) => + if (es0.makeUnreachable) lb2.value = Unreachable + Right(es0 withCtx (lb1.bound, lb2.bound) withMatched lb2.bound) + case Left(es0) => Left(es0 withFailed lb1.bound -> lb2.bound) } - case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_.withCtx(l1.bound, l2.bound)) - //case (_: UnboundSymbol, _: UnboundSymbol) => None - case _ => None + case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2).right map (_ withCtx l1.bound -> l2.bound) + case _ => Left(es) } - case (Constant(v1), Constant(v2)) if v1 == v2 => for { + case (Constant(v1), Constant(v2)) if v1 == v2 => (for { eTyp <- xtor.typ extract (xtee.typ, Covariant) m <- merge(eTyp, es.ex) - } yield es withNewExtract m - + } yield es withNewExtract m) + .fold[ExtractState](Left(es))(Right(_)) // Assuming if they have the same name the type is the same - case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Some(es) + case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) // Assuming if they have the same name and prefix the type is the same case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - case (NewObject(t1), NewObject(t2)) => for { + case (NewObject(t1), NewObject(t2)) => (for { eTyp <- t1 extract (t2, Covariant) m <- merge(eTyp, es.ex) - } yield es withNewExtract m + } yield es withNewExtract m) + .fold[ExtractState](Left(es))(Right(_)) - case _ => None + case _ => Left(es) } } alsoApply (res => println(s"Extract: $res")) @@ -695,12 +687,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi rewriteRep0(xtor, xtee, code)(State.forRewriting(xtor, xtee), false) def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(es: State, internalRec: Boolean): Option[Rep] = { - def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): Option[State] = { + def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") (xtor, xtee) match { case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1), es.mks.xtorMarker(lb1.bound)), (effect(lb2), es.mks.xteeMarker(lb2.bound))) match { - case ((Pure, NonEndPoint), (Pure, NonEndPoint)) => None + case ((Pure, NonEndPoint), (Pure, NonEndPoint)) => Left(es) case _ => extractWithState(lb1, lb2) // TODO With unreachable handling TODO why did I mean? } case _ => extractWithState(xtor, xtee) @@ -729,58 +721,77 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if check(es.matchedBVs)(cleanup(code)) } yield code } - - rewriteRepWithState(xtor, xtee)(es) flatMap genCode alsoApply (c => println(s"Code: $c")) + + rewriteRepWithState(xtor, xtee)(es) match { + case Right(es) => genCode(es) + case Left(_) => None + } + + // flatMap genCode alsoApply (c => println(s"Code: $c")) } - def extractDefs(v1: Def, v2: Def)(implicit es: State): Option[State] = { + def extractDefs(v1: Def, v2: Def)(implicit es: State): ExtractState = { println(s"VALUES: \n\t$v1\n\t$v2 with $es \n\n") (v1, v2) match { // Has already been matched... - case (_, Unreachable) => None + case (_, Unreachable) => Left(es) //case (Unreachable, _) => Some(es) case (l1: Lambda, l2: Lambda) => for { - e1 <- l1.boundType extract (l2.boundType, Covariant) - m1 <- merge(e1, es.ex) - es2 <- extractWithState(l1.body, l2.body)(es withNewExtract m1 withCtx l1.bound -> l2.bound) - m2 <- merge(es2.ex, m1) - } yield es2 withNewExtract m2 + es1 <- (l1.boundType extract (l2.boundType, Covariant)) + .flatMap(merge(_, es.ex)) + .fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)).right + es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound).right + } yield es2 case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => - lazy val targExtract = mergeAll(for { - (e1, e2) <- ma1.targs zip ma2.targs - } yield e1 extract (e2, Invariant)) // TODO Invariant? Depends on its positions... - - def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): Option[Extract] = { - def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): Option[Extract] = (argss1, argss2) match { - case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => mergeOpt(extractArgss0(h1, h2, acc), extractArgss0(t1, t2, acc)) - case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => mergeOpt(extractArgss0(h1, h2, acc), extractArgss0(t1, t2, acc)) - case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2)(es) map (_.ex) + def targExtract(es: State): ExtractState = mergeAll(for { + (e1, e2) <- ma1.targs zip ma2.targs + } yield e1 extract(e2, Invariant)) + .flatMap(merge(_, es.ex)) + .fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)) + + def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = { + def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): ExtractState = (argss1, argss2) match { + case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => for { + es0 <- extractArgss0(h1, h2, acc).right + es1 <- extractArgss0(t1, t2, acc)(es0).right + } yield es1 + + case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => for { + es0 <- extractArgss0(h1, h2, acc).right + es1 <- extractArgss0(t1, t2, acc)(es0).right + } yield es1 + + case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) - case (SplicedArgument(arg), NoArguments) => spliceExtract(arg, Args(acc.reverse: _*)) // Reverses list... - case (r1: Rep, r2: Rep) => extractWithState(r1, r2)(es) map (_.ex) - case (NoArguments, NoArguments) => Some(es.ex) - case (NoArgumentLists, NoArgumentLists) => Some(es.ex) - case _ => None + case (SplicedArgument(arg), NoArguments) => spliceExtract(arg, Args(acc.reverse: _*)) match { + case Some(ex) => merge(ex, es.ex).fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)) + case None => Left(es) + } + case (r1: Rep, r2: Rep) => extractWithState(r1, r2) + case (NoArguments, NoArguments) => Right(es) + case (NoArgumentLists, NoArgumentLists) => Right(es) + case _ => Left(es) } extractArgss0(argss1, argss2, Nil) } for { - es1 <- extractWithState(ma1.self, ma2.self)(es) - e2 <- targExtract - e3 <- extractArgss(ma1.argss, ma2.argss)(es1) - e4 <- ma1.typ extract (ma2.typ, Covariant) - m <- mergeAll(e2, e3, e4) - } yield es withNewExtract m + es1 <- extractWithState(ma1.self, ma2.self)(es).right + es2 <- targExtract(es1).right + es3 <- extractArgss(ma1.argss, ma2.argss)(es2).right + es4 <- (ma1.typ extract (ma2.typ, Covariant)) + .flatMap(merge(_, es3.ex)) + .fold[ExtractState](Left(es))(ex => Right(es3 withNewExtract ex)).right + } yield es4 case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2))) - case _ => None + case _ => Left(es) } } From b43ef099e3d208d0ee744c74b0717186f171b448 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 21 Nov 2017 15:53:48 +0100 Subject: [PATCH 12/66] Factor out extract updates --- src/main/scala/squid/ir/fastanf/FastANF.scala | 95 +++++++------------ 1 file changed, 33 insertions(+), 62 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index ec7739bc..c56b4969 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -478,15 +478,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi after } - (for { + updateExtract(for { e1 <- typ extract (xtee.typ, Covariant) f <- argss.foldRight(Option(emptyFunc(xtee)))(extendFunc) _ = println(s"F: $f") l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } e2 = repExtract(name -> l) m <- mergeAll(e1, e2, es.ex) - } yield m) - .fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)) + } yield m)(es) } def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = (effect(lb1.value), effect(lb2.value)) match { @@ -528,41 +527,24 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag // case _ => bv //} - - (t extract (xtee.typ, Covariant)) - .flatMap(mergeAll(_, es.ex, repExtract(n -> bv))) - .fold[ExtractState](Left(es))(m => Right(es withNewExtract m withMatched bv)) - - //for { - // e <- t extract (xtee.typ, Covariant) - // m <- mergeAll(e, es.ex, repExtract(n -> bv)) - //} yield es withNewExtract m withMatched bv - - case (Hole(n, t), lb: LetBinding) => - (t extract (lb.typ, Covariant)) - .flatMap { e => - val newLB = wrapConstruct(letbind(lb.value)) - if (es.makeUnreachable) lb.value = Unreachable - mergeAll(e, es.ex, repExtract(n -> newLB)) - } - .fold[ExtractState](Left(es))(m => Right(es withNewExtract m withMatched lb.bound)) - - // for { - // e <- t extract (lb.typ, Covariant) - // newLB = wrapConstruct(letbind(lb.value))// flag - // _ = if (es.makeUnreachable) lb.value = Unreachable - // m <- mergeAll(e, es.ex, repExtract(n -> newLB)) - //} yield es withNewExtract m withMatched lb.bound - - case (Hole(n, t), _) => - (t extract (xtee.typ, Covariant)) - .flatMap(mergeAll(_, es.ex, repExtract(n -> xtee))) - .fold[ExtractState](Left(es))(m => Right(es withNewExtract m)) - - // for { - // e <- t extract (xtee.typ, Covariant) - // m <- mergeAll(e, es.ex, repExtract(n -> xtee)) - //} yield es withNewExtract m + updateExtract( + t extract (xtee.typ, Covariant), + Some(repExtract(n -> bv)) + )(es) + .right.map(_ withMatched bv) + + case (Hole(n, t), lb: LetBinding) => + updateExtract( + t extract(lb.typ, Covariant), + Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) + )(es) + .right.map(_ withMatched lb.bound) + + case (Hole(n, t), _) => + updateExtract( + t extract(xtee.typ, Covariant), + Some(repExtract(n -> xtee)) + )(es) } } @@ -618,9 +600,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (_, Ascribe(s, _)) => extractWithState(xtor, s) case (Ascribe(s, t), _) => for { - es1 <- (t extract(xtee.typ, Covariant)) - .flatMap(merge(_, es.ex)) - .fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)).right + es1 <- updateExtract(t extract(xtee.typ, Covariant))(es).right es2 <- extractWithState(s, xtee)(es1).right } yield es2 @@ -642,11 +622,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => Left(es) } - case (Constant(v1), Constant(v2)) if v1 == v2 => (for { - eTyp <- xtor.typ extract (xtee.typ, Covariant) - m <- merge(eTyp, es.ex) - } yield es withNewExtract m) - .fold[ExtractState](Left(es))(Right(_)) + case (Constant(v1), Constant(v2)) if v1 == v2 => + updateExtract(xtor.typ extract (xtee.typ, Covariant))(es) // Assuming if they have the same name the type is the same case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) @@ -654,11 +631,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // Assuming if they have the same name and prefix the type is the same case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - case (NewObject(t1), NewObject(t2)) => (for { - eTyp <- t1 extract (t2, Covariant) - m <- merge(eTyp, es.ex) - } yield es withNewExtract m) - .fold[ExtractState](Left(es))(Right(_)) + case (NewObject(t1), NewObject(t2)) => updateExtract(t1 extract (t2, Covariant))(es) case _ => Left(es) } @@ -739,18 +712,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (l1: Lambda, l2: Lambda) => for { - es1 <- (l1.boundType extract (l2.boundType, Covariant)) - .flatMap(merge(_, es.ex)) - .fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)).right + es1 <- updateExtract(l1.boundType extract (l2.boundType, Covariant))(es).right es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound).right } yield es2 case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => - def targExtract(es: State): ExtractState = mergeAll(for { + def targExtract(es0: State): ExtractState = updateExtract(mergeAll(for { (e1, e2) <- ma1.targs zip ma2.targs - } yield e1 extract(e2, Invariant)) - .flatMap(merge(_, es.ex)) - .fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)) + } yield e1 extract(e2, Invariant)))(es0)(es) def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = { def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): ExtractState = (argss1, argss2) match { @@ -784,9 +753,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es1 <- extractWithState(ma1.self, ma2.self)(es).right es2 <- targExtract(es1).right es3 <- extractArgss(ma1.argss, ma2.argss)(es2).right - es4 <- (ma1.typ extract (ma2.typ, Covariant)) - .flatMap(merge(_, es3.ex)) - .fold[ExtractState](Left(es))(ex => Right(es3 withNewExtract ex)).right + es4 <- updateExtract(ma1.typ extract (ma2.typ, Covariant))(es3).right } yield es4 case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2))) @@ -794,6 +761,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => Left(es) } } + + def updateExtract(e: Option[Extract]*)(es: State)(implicit default: State): ExtractState = { + mergeAll(Some(es.ex) +: e).fold[ExtractState](Left(default))(ex => Right(es withNewExtract ex)) + } def cleanup(r: Rep): Rep = { println("CLEANING!!") @@ -816,7 +787,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi HOPHole2(name, typ, args, visible filterNot (args.flatten contains _)) def substitute(r: => Rep, defs: Map[String, Rep]): Rep = { println(s"Subs: $r with $defs") - if (defs isEmpty) r //|> inlineBlock // TODO works if I remove this... + if (defs isEmpty) r |> inlineBlock // TODO works if I remove this... else bottomUp(r) { case h@Hole(n, _) => defs getOrElse(n, h) case h@SplicedHole(n, _) => defs getOrElse(n, h) From 045ebe1254a0210db93adb6456ffb27f91e02432 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 21 Nov 2017 15:59:19 +0100 Subject: [PATCH 13/66] Make updateExtract infix --- src/main/scala/squid/ir/fastanf/FastANF.scala | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index c56b4969..e78c66c4 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -338,6 +338,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withoutMarkers(xtorMk: BoundVal, xteeMk: BoundVal): State = copy(mks = Markers(mks.xtor - xtorMk, mks.xtee - xteeMk)) def withMatched(bv: BoundVal): State = copy(matchedBVs = matchedBVs + bv) def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) + + def updateExtract(e: Option[Extract]*)(implicit default: State): ExtractState = { + mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) + } } object State { def forRewriting(xtor: Rep, xtee: Rep): State = State(xtor, xtee, true) @@ -478,14 +482,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi after } - updateExtract(for { + es.updateExtract(for { e1 <- typ extract (xtee.typ, Covariant) f <- argss.foldRight(Option(emptyFunc(xtee)))(extendFunc) _ = println(s"F: $f") l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } e2 = repExtract(name -> l) m <- mergeAll(e1, e2, es.ex) - } yield m)(es) + } yield m) } def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = (effect(lb1.value), effect(lb2.value)) match { @@ -527,24 +531,22 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag // case _ => bv //} - updateExtract( + es.updateExtract( t extract (xtee.typ, Covariant), Some(repExtract(n -> bv)) - )(es) - .right.map(_ withMatched bv) + ).right.map(_ withMatched bv) case (Hole(n, t), lb: LetBinding) => - updateExtract( + es.updateExtract( t extract(lb.typ, Covariant), Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) - )(es) - .right.map(_ withMatched lb.bound) + ).right.map(_ withMatched lb.bound) case (Hole(n, t), _) => - updateExtract( + es.updateExtract( t extract(xtee.typ, Covariant), Some(repExtract(n -> xtee)) - )(es) + ) } } @@ -600,7 +602,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (_, Ascribe(s, _)) => extractWithState(xtor, s) case (Ascribe(s, t), _) => for { - es1 <- updateExtract(t extract(xtee.typ, Covariant))(es).right + es1 <- es.updateExtract(t extract(xtee.typ, Covariant)).right es2 <- extractWithState(s, xtee)(es1).right } yield es2 @@ -623,7 +625,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } case (Constant(v1), Constant(v2)) if v1 == v2 => - updateExtract(xtor.typ extract (xtee.typ, Covariant))(es) + es.updateExtract(xtor.typ extract (xtee.typ, Covariant)) // Assuming if they have the same name the type is the same case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) @@ -631,7 +633,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // Assuming if they have the same name and prefix the type is the same case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - case (NewObject(t1), NewObject(t2)) => updateExtract(t1 extract (t2, Covariant))(es) + case (NewObject(t1), NewObject(t2)) => es.updateExtract(t1 extract (t2, Covariant)) case _ => Left(es) } @@ -712,14 +714,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (l1: Lambda, l2: Lambda) => for { - es1 <- updateExtract(l1.boundType extract (l2.boundType, Covariant))(es).right + es1 <- es.updateExtract(l1.boundType extract (l2.boundType, Covariant)).right es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound).right } yield es2 case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => - def targExtract(es0: State): ExtractState = updateExtract(mergeAll(for { + def targExtract(es0: State): ExtractState = es.updateExtract(mergeAll(for { (e1, e2) <- ma1.targs zip ma2.targs - } yield e1 extract(e2, Invariant)))(es0)(es) + } yield e1 extract(e2, Invariant)))(es) def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = { def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): ExtractState = (argss1, argss2) match { @@ -753,7 +755,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es1 <- extractWithState(ma1.self, ma2.self)(es).right es2 <- targExtract(es1).right es3 <- extractArgss(ma1.argss, ma2.argss)(es2).right - es4 <- updateExtract(ma1.typ extract (ma2.typ, Covariant))(es3).right + es4 <- es3.updateExtract(ma1.typ extract (ma2.typ, Covariant)).right } yield es4 case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2))) @@ -761,10 +763,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => Left(es) } } - - def updateExtract(e: Option[Extract]*)(es: State)(implicit default: State): ExtractState = { - mergeAll(Some(es.ex) +: e).fold[ExtractState](Left(default))(ex => Right(es withNewExtract ex)) - } def cleanup(r: Rep): Rep = { println("CLEANING!!") From f2c934c2c10012acac6079a58d401e9378e9deb6 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Wed, 22 Nov 2017 14:59:31 +0100 Subject: [PATCH 14/66] Right-bias either and fix some errors --- src/main/scala/squid/ir/fastanf/FastANF.scala | 135 +++++++++--------- src/main/scala/squid/ir/fastanf/Rep.scala | 1 + 2 files changed, 72 insertions(+), 64 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index e78c66c4..1187e02a 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -133,6 +133,15 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(s"letin: $value --> $bound") value match { case s: Symbol => + s.owner |>? { + case lb: RebindableBinding => + //println(s"LETIN $lb ") + lb.name = bound.name + } + s.owner |>? { + case lb: LetBinding => + lb.userDefined = true + } withSubs(bound, value)(body) //s.owner |>? { @@ -330,6 +339,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } type ExtractState = Either[State, State] + implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right + case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], makeUnreachable: Boolean) { def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) @@ -339,7 +350,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withMatched(bv: BoundVal): State = copy(matchedBVs = matchedBVs + bv) def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) - def updateExtract(e: Option[Extract]*)(implicit default: State): ExtractState = { + def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) } } @@ -426,6 +437,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => r } + def changeDefBVs(d: Def)(f: BoundVal => Option[BoundVal]): Def = d match { case ma: MethodApp => MethodApp( @@ -471,7 +483,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi f <- f _ = println(s"BEFORE $f") body0 <- transformations.foldLeft(Option(fbody(f))) { - case (Some(body), (bv: BoundVal, res)) => Some(changeRepBVs(body)(bv => Some(res))) + case (Some(body), (bv: BoundVal, res)) => Some(changeRepBVs(body)(bv0 => Some(if (bv == bv0) res else bv0))) case (Some(body), (xtor, res)) => rewriteRep0(xtor, body, _ => Some(res))(State.forRewriting(xtor, body), true) case _ => None } @@ -482,40 +494,37 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi after } - es.updateExtract(for { - e1 <- typ extract (xtee.typ, Covariant) - f <- argss.foldRight(Option(emptyFunc(xtee)))(extendFunc) - _ = println(s"F: $f") - l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } - e2 = repExtract(name -> l) - m <- mergeAll(e1, e2, es.ex) - } yield m) + es.updateExtractWith( + typ extract (xtee.typ, Covariant), + for { + f <- argss.foldRight(Option(emptyFunc(xtee)))(extendFunc) + _ = println(s"F: $f") + l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } + } yield repExtract(name -> l) + ) } def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = (effect(lb1.value), effect(lb2.value)) match { case (Impure, Impure) => - extractDefs(lb1.value, lb2.value).right flatMap { es => + extractDefs(lb1.value, lb2.value) match { + case Right(es) => if (es.makeUnreachable) lb2.value = Unreachable - extractWithState(lb1.body, lb2.body)(es withCtx (lb1.bound, lb2.bound) withMatched lb2.bound) + extractWithState(lb1.body, lb2.body)(es withCtx lb1.bound -> lb2.bound withMatched lb2.bound) + case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } - //for { - // es1 <- extractDefs(lb1.value, lb2.value).right - // _ = if (es1.makeUnreachable) lb2.value = Unreachable - // es2 <- extractWithState(lb1.body, lb2.body)(es1 withCtx (lb1.bound, lb2.bound) withMatched lb2.bound).right - //} yield es2 - case (Pure, Pure) => (es.mks.xtorMarker(lb1.bound), es.mks.xteeMarker(lb2.bound)) match { case (EndPoint, EndPoint) => extractWithState(lb1.bound, lb2.bound) match { - case Right(es0) => + case Right(es) => if (es.makeUnreachable) lb2.value = Unreachable - extractWithState(lb1.body, lb2.body)(es0 withoutMarkers(lb1.bound, lb2.bound) withMatched lb2.bound) - case Left(es0) => extractWithState(lb1, lb2.body)(es0) + extractWithState(lb1.body, lb2.body)(es withoutMarkers(lb1.bound, lb2.bound) withMatched lb2.bound) + case Left(es) => extractWithState(lb1, lb2.body)(es) } case (EndPoint, NonEndPoint) => extractWithState(lb1, lb2.body) - case (NonEndPoint, EndPoint) => extractWithState(lb1.body, lb2) - case (NonEndPoint, NonEndPoint) => extractWithState(lb1.body, lb2) + case (NonEndPoint, _) => extractWithState(lb1.body, lb2) + //case (NonEndPoint, EndPoint) => extractWithState(lb1.body, lb2) + //case (NonEndPoint, NonEndPoint) => extractWithState(lb1.body, lb2) } case (Impure, Pure) => extractWithState(lb1, lb2.body) @@ -531,19 +540,19 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag // case _ => bv //} - es.updateExtract( + es.updateExtractWith( t extract (xtee.typ, Covariant), Some(repExtract(n -> bv)) - ).right.map(_ withMatched bv) + ).map(_ withMatched bv) case (Hole(n, t), lb: LetBinding) => - es.updateExtract( + es.updateExtractWith( t extract(lb.typ, Covariant), Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) - ).right.map(_ withMatched lb.bound) + ).map(_ withMatched lb.bound) case (Hole(n, t), _) => - es.updateExtract( + es.updateExtractWith( t extract(xtee.typ, Covariant), Some(repExtract(n -> xtee)) ) @@ -559,11 +568,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => Nil } - bvs(d).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => for { - _ <- acc.left - es1 <- extractWithState(bv, bv2)(es).left - } yield es1 - } alsoApply(s => println(s"FOO: $s")) + bvs(d).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => + for { + es1 <- acc.left + es2 <- extractWithState(bv, bv2)(es1).left + } yield es2 + } alsoApply (s => println(s"FOO: $s")) } def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] @@ -582,9 +592,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case None => extractHole(h, xtee) } - case (h@HOPHole2(name, typ, argss, visible), _) => - //println(s"Hole $h -> $xtee\n\n") - extractHOPHole(name, typ, argss, visible) + case (HOPHole2(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) @@ -602,8 +610,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (_, Ascribe(s, _)) => extractWithState(xtor, s) case (Ascribe(s, t), _) => for { - es1 <- es.updateExtract(t extract(xtee.typ, Covariant)).right - es2 <- extractWithState(s, xtee)(es1).right + es1 <- es.updateExtractWith(t extract(xtee.typ, Covariant)) + es2 <- extractWithState(s, xtee)(es1) } yield es2 case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) @@ -615,17 +623,16 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi else if (es.failedMatches(bv1) contains bv2) Left(es) else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { - case Right(es0) => - if (es0.makeUnreachable) lb2.value = Unreachable - Right(es0 withCtx (lb1.bound, lb2.bound) withMatched lb2.bound) - case Left(es0) => Left(es0 withFailed lb1.bound -> lb2.bound) + case Right(es) => + if (es.makeUnreachable) lb2.value = Unreachable + Right(es withCtx (lb1.bound, lb2.bound) withMatched lb2.bound) + case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } - case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2).right map (_ withCtx l1.bound -> l2.bound) - case _ => Left(es) + case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? + case _ => Left(es withFailed bv1 -> bv2) } - case (Constant(v1), Constant(v2)) if v1 == v2 => - es.updateExtract(xtor.typ extract (xtee.typ, Covariant)) + case (Constant(v1), Constant(v2)) if v1 == v2 => es updateExtractWith (xtor.typ extract(xtee.typ, Covariant)) // Assuming if they have the same name the type is the same case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) @@ -633,7 +640,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // Assuming if they have the same name and prefix the type is the same case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - case (NewObject(t1), NewObject(t2)) => es.updateExtract(t1 extract (t2, Covariant)) + case (NewObject(t1), NewObject(t2)) => es updateExtractWith (t1 extract(t2, Covariant)) case _ => Left(es) } @@ -714,34 +721,34 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (l1: Lambda, l2: Lambda) => for { - es1 <- es.updateExtract(l1.boundType extract (l2.boundType, Covariant)).right - es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound).right + es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) + es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound) } yield es2 case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => - def targExtract(es0: State): ExtractState = es.updateExtract(mergeAll(for { - (e1, e2) <- ma1.targs zip ma2.targs - } yield e1 extract(e2, Invariant)))(es) + def targExtract(es0: State): ExtractState = + es0.updateExtractWith( + (for { + (e1, e2) <- ma1.targs zip ma2.targs + } yield e1 extract(e2, Invariant)): _* + ) def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = { def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): ExtractState = (argss1, argss2) match { case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => for { - es0 <- extractArgss0(h1, h2, acc).right - es1 <- extractArgss0(t1, t2, acc)(es0).right + es0 <- extractArgss0(h1, h2, acc) + es1 <- extractArgss0(t1, t2, acc)(es0) } yield es1 case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => for { - es0 <- extractArgss0(h1, h2, acc).right - es1 <- extractArgss0(t1, t2, acc)(es0).right + es0 <- extractArgss0(h1, h2, acc) + es1 <- extractArgss0(t1, t2, acc)(es0) } yield es1 case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) - case (SplicedArgument(arg), NoArguments) => spliceExtract(arg, Args(acc.reverse: _*)) match { - case Some(ex) => merge(ex, es.ex).fold[ExtractState](Left(es))(ex => Right(es withNewExtract ex)) - case None => Left(es) - } + case (SplicedArgument(arg), NoArguments) => es updateExtractWith spliceExtract(arg, Args(acc.reverse: _*)) case (r1: Rep, r2: Rep) => extractWithState(r1, r2) case (NoArguments, NoArguments) => Right(es) case (NoArgumentLists, NoArgumentLists) => Right(es) @@ -752,10 +759,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } for { - es1 <- extractWithState(ma1.self, ma2.self)(es).right - es2 <- targExtract(es1).right - es3 <- extractArgss(ma1.argss, ma2.argss)(es2).right - es4 <- es3.updateExtract(ma1.typ extract (ma2.typ, Covariant)).right + es1 <- extractWithState(ma1.self, ma2.self) + es2 <- targExtract(es1) + es3 <- extractArgss(ma1.argss, ma2.argss)(es2) + es4 <- es3.updateExtractWith(ma1.typ extract (ma2.typ, Covariant)) } yield es4 case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2))) diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index 160957fd..f2ba0ed7 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -205,6 +205,7 @@ trait RebindableBinding extends Binding { def name_= (newName: String): Unit } class LetBinding(var name: String, var bound: Symbol, var value: Def, private var _body: Rep) extends Rep with RebindableBinding { + var userDefined = false def body = _body def body_= (newBody: Rep) = _body = newBody def boundType = value.typ From 7ffe328be00fa3340411888c6337de9a6430788a Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 23 Nov 2017 17:06:40 +0100 Subject: [PATCH 15/66] Failed rewriting will not affect input --- src/main/scala/squid/ir/fastanf/Effects.scala | 1 - src/main/scala/squid/ir/fastanf/FastANF.scala | 244 ++++++++++-------- src/main/scala/squid/ir/fastanf/Rep.scala | 7 +- 3 files changed, 144 insertions(+), 108 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index 4256f858..2e915f78 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -48,7 +48,6 @@ trait Effects { val mtdEff = mtdEffect(ma.mtd) selfEff |+| argssEff |+| mtdEff case DefHole(_) => Impure - case Unreachable => Pure } sealed trait Effect { diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 1187e02a..4ae7cb39 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -140,7 +140,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } s.owner |>? { case lb: LetBinding => - lb.userDefined = true + lb.isUserDefined = true } withSubs(bound, value)(body) @@ -249,7 +249,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi reinterpretType(ma.typ)) case l: Lambda => newBase.lambda(List(reinterpretBV(l.bound)), reinterpret0(l.body)) case DefHole(h) => newBase.hole(h.name, reinterpretType(h.typ)) - case Unreachable => unsupported } r match { @@ -299,8 +298,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi //new Lambda(l.name, l.bound, l.boundType, transformRepAndDef0(l.body)) l.body = l.body |> transformRepAndDef0 l - case dh: DefHole => dh - case Unreachable => Unreachable + case _ => d }) post(pre(r) match { @@ -333,7 +331,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } yield es.ex } - type Ctx = Map[BoundVal, Set[BoundVal]] + type Ctx = Map[BoundVal, BoundVal] + def reverse[A, B](m: Map[A, B]): Map[B, Set[A]] = m.groupBy(_._2).mapValues(_.keys.toSet) def updateWith(m: Map[BoundVal, Set[BoundVal]])(u: (BoundVal, BoundVal)): Map[BoundVal, Set[BoundVal]] = u match { case (k, v) => m + (k -> (m(k) + v)) } @@ -344,10 +343,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], makeUnreachable: Boolean) { def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) - def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = updateWith(ctx)(p)) + def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) def updateMarkers(newMks: Markers): State = copy(mks = newMks) def withoutMarkers(xtorMk: BoundVal, xteeMk: BoundVal): State = copy(mks = Markers(mks.xtor - xtorMk, mks.xtee - xteeMk)) - def withMatched(bv: BoundVal): State = copy(matchedBVs = matchedBVs + bv) + def withMatched(r: Rep): State = r match { + case lb: LetBinding => copy(matchedBVs = matchedBVs + lb.bound) withMatched lb.body + case bv: BoundVal => copy(matchedBVs = matchedBVs + bv) + case _ => this + } def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { @@ -359,7 +362,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def forExtraction(xtor: Rep, xtee: Rep): State = State(xtor, xtee, false) private def apply(xtor: Rep, xtee: Rep, makeUnreachable: Bool): State = - State(EmptyExtract, ListMap.empty.withDefaultValue(Set.empty), Markers(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), makeUnreachable) + State(EmptyExtract, ListMap.empty, Markers(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), makeUnreachable) } sealed trait Marker @@ -381,8 +384,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (acc, bv: BoundVal) => acc - bv case (acc, _) => acc // Assuming no LBs in self or argument positions } - case _: DefHole => ListSet.empty - case Unreachable => ListSet.empty + case _ => ListSet.empty } def extractionStarts0(r: Rep, acc: ListSet[BoundVal]): ListSet[BoundVal] = r match { @@ -399,13 +401,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def extractWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { - - //println(s"$xtor\n$xtee with $ctx\n\n") - def reverse[A, B](m: Map[A, Set[B]]): Map[B, A] = for { - (a, bs) <- m - b <- bs - } yield b -> a - def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { println("EXTRACTINGHOPHOLE") type Func = List[List[BoundVal]] -> Rep @@ -413,13 +408,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def fargss(f: Func) = f._1 def fbody(f: Func) = f._2 - val ctx0 = es.ctx.mapValues(_.head) - val invCtx = reverse(es.ctx) - - println(s"INVCTX: $invCtx") - println(s"VISIBLE: $visible") - bottomUpPartial(xtee) { case bv: BoundVal if visible contains invCtx.getOrElse(bv, bv) => return Left(es) } - def changeRepBVs(r: Rep)(f: BoundVal => Option[BoundVal]): Rep = r match { case bv: BoundVal => f(bv) match { @@ -460,15 +448,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def extendFunc(args: List[Rep], f: Option[Func]): Option[Func] = { - println(s"ARGS: $args") val args0 = args.map { case bv: BoundVal => bv.owner match { case lb: LetBinding => lb.body = lb.bound // Cut out the HOPHole - changeRepBVs(bv)(ctx0.get) - case _ => ctx0.getOrElse(bv, bv) + changeRepBVs(bv)(es.ctx.get) + case _ => es.ctx.getOrElse(bv, bv) } - case arg => changeRepBVs(arg)(ctx0.get) + case arg => changeRepBVs(arg)(es.ctx.get) } // Traverse bottom up (parent, etc) println(s"ARGS0: $args0") @@ -484,9 +471,16 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi _ = println(s"BEFORE $f") body0 <- transformations.foldLeft(Option(fbody(f))) { case (Some(body), (bv: BoundVal, res)) => Some(changeRepBVs(body)(bv0 => Some(if (bv == bv0) res else bv0))) - case (Some(body), (xtor, res)) => rewriteRep0(xtor, body, _ => Some(res))(State.forRewriting(xtor, body), true) + case (Some(body), (xtor, res)) => rewriteRep0(xtor, body, _ => Some(res))(true)(State.forRewriting(xtor, body)) case _ => None } + invCtx = reverse(es.ctx) + _ = bottomUpPartial(body0) { case bv: BoundVal => + (for { + bvs <- invCtx.get(bv) + if bvs exists (visible contains _) + } yield return None) getOrElse bv + } } yield (xs :: fargss(f)) -> body0 println(s"AFTER: $after") @@ -508,7 +502,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (Impure, Impure) => extractDefs(lb1.value, lb2.value) match { case Right(es) => - if (es.makeUnreachable) lb2.value = Unreachable + if (es.makeUnreachable) lb2.isUnreachable = true extractWithState(lb1.body, lb2.body)(es withCtx lb1.bound -> lb2.bound withMatched lb2.bound) case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } @@ -517,7 +511,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (EndPoint, EndPoint) => extractWithState(lb1.bound, lb2.bound) match { case Right(es) => - if (es.makeUnreachable) lb2.value = Unreachable extractWithState(lb1.body, lb2.body)(es withoutMarkers(lb1.bound, lb2.bound) withMatched lb2.bound) case Left(es) => extractWithState(lb1, lb2.body)(es) } @@ -534,12 +527,15 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { println(s"ExtractHole: $h --> $r") + def makeUnreachable(r: Rep): Rep = r match { + case lb: LetBinding => + lb.isUnreachable = true + makeUnreachable(lb.body) + case _ => r + } + (h, r) match { case (Hole(n, t), bv: BoundVal) => - //val r = bv.owner match { - // case lb: LetBinding => new LetBinding(lb.name, lb.bound, lb.value, lb.bound) // flag - // case _ => bv - //} es.updateExtractWith( t extract (xtee.typ, Covariant), Some(repExtract(n -> bv)) @@ -547,9 +543,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (Hole(n, t), lb: LetBinding) => es.updateExtractWith( - t extract(lb.typ, Covariant), + t extract (lb.typ, Covariant), Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) - ).map(_ withMatched lb.bound) + ).map { + makeUnreachable(lb) + _ withMatched lb + } case (Hole(n, t), _) => es.updateExtractWith( @@ -597,15 +596,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) // Stop at markers? - case (lb: LetBinding, _: Rep) => extractWithState(lb.body, xtee) - + case (lb: LetBinding, _: Rep) => (effect(lb), effect(xtee)) match { + case (Pure, Pure) => extractWithState(lb.body, xtee) + case (Impure, Pure) => Left(es) + case (_, Impure) => Left(es) // Assuming the return value cannot be impure + } + case (bv: BoundVal, lb: LetBinding) => for { es1 <- extractWithState(bv, lb.bound).left es2 <- extractInside(bv, lb.value)(es1).left es3 <- extractWithState(bv, lb.body)(es2).left } yield es3 - case (_: Rep, lb: LetBinding) if lb.value == Unreachable => extractWithState(xtor, lb.body) + case (_: Rep, lb: LetBinding) if lb.isUnreachable => extractWithState(xtor, lb.body) + //case (_: Rep, lb: LetBinding) if es.matchedBVs contains lb.bound => extractWithState(xtor, lb.body) case (_, Ascribe(s, _)) => extractWithState(xtor, s) @@ -619,13 +623,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (bv1: BoundVal, bv2: BoundVal) => println(s"EXTRACTIONSTATE IN BV: $es") println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") - if (bv1 == bv2 || (es.ctx(bv1) contains bv2)) Right(es) + if (es.ctx.getOrElse(bv1, bv1) == bv2) Right(es) else if (es.failedMatches(bv1) contains bv2) Left(es) else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { case Right(es) => - if (es.makeUnreachable) lb2.value = Unreachable - Right(es withCtx (lb1.bound, lb2.bound) withMatched lb2.bound) + if (es.makeUnreachable) lb2.isUnreachable = true + Right(es withCtx lb1.bound -> lb2.bound withMatched lb2.bound) case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? @@ -665,60 +669,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => throw IRException(s"Trying to splice-extract with invalid extractor $xtor") } - override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = - rewriteRep0(xtor, xtee, code)(State.forRewriting(xtor, xtee), false) - - def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(es: State, internalRec: Boolean): Option[Rep] = { - def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { - println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") - - (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1), es.mks.xtorMarker(lb1.bound)), (effect(lb2), es.mks.xteeMarker(lb2.bound))) match { - case ((Pure, NonEndPoint), (Pure, NonEndPoint)) => Left(es) - case _ => extractWithState(lb1, lb2) // TODO With unreachable handling TODO why did I mean? - } - case _ => extractWithState(xtor, xtee) - } - } - - def genCode(es: State): Option[Rep] = { - def check(matchedBVs: Set[BoundVal])(r: Rep): Boolean = r match { - case lb: LetBinding => checkDef(matchedBVs)(lb.value) - case bv: BoundVal => !(matchedBVs contains bv) - case _ => true - } - - def checkDef(matchedBVs: Set[BoundVal])(d: Def): Boolean = d match { - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft(true) { - case (checks, bv: BoundVal) => checks && !(matchedBVs contains bv) - case (checks, lb: LetBinding) => checks && check(matchedBVs)(lb) - case (checks, _) => true - } - case l: Lambda => !(matchedBVs contains l.bound) && check(matchedBVs)(l.body) - case _ => true - } - - for { - code <- code(es.ex) - if check(es.matchedBVs)(cleanup(code)) - } yield code - } - - rewriteRepWithState(xtor, xtee)(es) match { - case Right(es) => genCode(es) - case Left(_) => None - } - - // flatMap genCode alsoApply (c => println(s"Code: $c")) - } - def extractDefs(v1: Def, v2: Def)(implicit es: State): ExtractState = { println(s"VALUES: \n\t$v1\n\t$v2 with $es \n\n") (v1, v2) match { - // Has already been matched... - case (_, Unreachable) => Left(es) - //case (Unreachable, _) => Some(es) - case (l1: Lambda, l2: Lambda) => for { es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) @@ -771,14 +724,101 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - def cleanup(r: Rep): Rep = { - println("CLEANING!!") - r match { - case lb: LetBinding if lb.value == Unreachable => cleanup(lb.body) - case lb: LetBinding => - lb.body = cleanup(lb.body) - lb - case _ => r + override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = { + println(s"Again") + rewriteRep0(xtor, xtee, code)(false)(State.forRewriting(xtor, xtee)) + } + + def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(internalRec: Boolean)(implicit es: State): Option[Rep] = { + def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { + println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") + + (xtor, xtee) match { + case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1), es.mks.xtorMarker(lb1.bound)), (effect(lb2), es.mks.xteeMarker(lb2.bound))) match { + case ((Pure, NonEndPoint), (Pure, NonEndPoint)) => Left(es) + case _ => extractWithState(lb1, lb2) // TODO With unreachable handling TODO why did I mean? + } + case _ => extractWithState(xtor, xtee) + } + } + + def genCode(implicit es: State): Option[Rep] = { + def preCheck(ex: Extract): Boolean = { + def preCheckRep(declaredBVs: Set[BoundVal], invCtx: Map[BoundVal, Set[BoundVal]], r: Rep): Boolean = { + def preCheckDef(declaredBVs: Set[BoundVal], invCtx: Map[BoundVal, Set[BoundVal]], d: Def): Boolean = { + d match { + case l: Lambda => preCheckRep(declaredBVs, invCtx, l.body) + case ma: MethodApp => (ma.self :: ma.argss.argssList) forall { + case bv: BoundVal => + (declaredBVs contains bv) || + ((for { + bvs <- invCtx.get(bv) + isUserDefined = bvs map (_.owner) forall { + case lb: LetBinding => lb.isUserDefined + case _ => true + } + } yield isUserDefined) getOrElse false) + case lb: LetBinding => preCheckRep(declaredBVs, invCtx, lb) + case _ => true + } + case _ => true + } + } + + r match { + case lb: LetBinding => + val acc0 = declaredBVs + lb.bound + preCheckDef(acc0, invCtx, lb.value) && preCheckRep(acc0, invCtx, lb.body) + case _ => true + } + } + + val invCtx = reverse(es.ctx) + (ex._1.values ++ ex._3.values.flatten).forall(preCheckRep(Set.empty, invCtx, _)) + } + + def check(declaredBVs: Set[BoundVal], matchedBVs: Set[BoundVal])(r: Rep): Boolean = { + def checkDef(declaredBVs: Set[BoundVal], matchedBVs: Set[BoundVal])(d: Def): Boolean = d match { + case ma: MethodApp => (ma.self :: ma.argss.argssList) forall { + case bv: BoundVal => (declaredBVs contains bv) || !(matchedBVs contains bv) + case lb: LetBinding => check(declaredBVs + lb.bound, matchedBVs)(lb) + case _ => true + } + case l: Lambda => + ((declaredBVs contains l.bound) || + !(matchedBVs contains l.bound)) && + check(declaredBVs, matchedBVs)(l.body) + case _ => true + } + + r match { + case lb: LetBinding => checkDef(declaredBVs + lb.bound, matchedBVs)(lb.value) + case bv: BoundVal => (declaredBVs contains bv) || !(matchedBVs contains bv) + case _ => true + } + } + + + def cleanup(r: Rep)(implicit es: State): Rep = r match { + case lb: LetBinding if lb.isUnreachable => cleanup(lb.body) + case lb: LetBinding => + lb.body = cleanup(lb.body) + lb + case _ => r + } + + if (preCheck(es.ex) alsoApply println) + for { + code <- code(es.ex) + _ = println(code) + if check(Set.empty, es.matchedBVs)(cleanup(code) alsoApply println) alsoApply println + } yield code + else None + } + + rewriteRepWithState(xtor, xtee) match { + case Right(es) => genCode(es) alsoApply println + case Left(_) => None } } diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index f2ba0ed7..f1dd4c10 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -40,10 +40,6 @@ sealed abstract class Def extends DefOption with DefOrTypeRep with FlatSom[Def] def map(f: Def => Def): Def = f(this) } -case object Unreachable extends Def { - override def typ = ??? -} - /** Expression that can be used as an argument or result; this includes let bindings. */ sealed abstract class Rep extends RepOption with ArgumentList with FlatSom[Rep] { def typ: TypeRep @@ -205,7 +201,8 @@ trait RebindableBinding extends Binding { def name_= (newName: String): Unit } class LetBinding(var name: String, var bound: Symbol, var value: Def, private var _body: Rep) extends Rep with RebindableBinding { - var userDefined = false + var isUserDefined = false + var isUnreachable = false def body = _body def body_= (newBody: Rep) = _body = newBody def boundType = value.typ From 9fcb174cf089a98b08582bc5d8dbd793f702d857 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Fri, 24 Nov 2017 15:56:34 +0100 Subject: [PATCH 16/66] Make HOPHole args by-name args --- core/src/main/scala/squid/quasi/QuasiEmbedder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/squid/quasi/QuasiEmbedder.scala b/core/src/main/scala/squid/quasi/QuasiEmbedder.scala index c9cc12a6..85678543 100644 --- a/core/src/main/scala/squid/quasi/QuasiEmbedder.scala +++ b/core/src/main/scala/squid/quasi/QuasiEmbedder.scala @@ -496,7 +496,7 @@ class QuasiEmbedder[C <: whitebox.Context](val c: C) { val hopvType = FunctionType(args map (_.tpe) : _*)(holeType) termHoleInfo(termName) = Map() -> hopvType - val largs = args map (liftTerm(_,x,None,false)) + val largs = args map (arg => b.byName(liftTerm(arg,x,None,false))) b.hopHole2(name, liftType(holeType), (largs)::Nil, ctx.values.toList) From e7372d5a164c61404dc5defd7f37384aa5f34897 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Fri, 24 Nov 2017 18:15:55 +0100 Subject: [PATCH 17/66] Corrected HOPHole and code removal semantics --- src/main/scala/squid/ir/fastanf/FastANF.scala | 186 +++++++++--------- src/main/scala/squid/ir/fastanf/Rep.scala | 1 - 2 files changed, 97 insertions(+), 90 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 4ae7cb39..51dcd06b 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -340,15 +340,16 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi type ExtractState = Either[State, State] implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right - case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], makeUnreachable: Boolean) { + case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], makeUnreachable: Boolean) { def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) def updateMarkers(newMks: Markers): State = copy(mks = newMks) def withoutMarkers(xtorMk: BoundVal, xteeMk: BoundVal): State = copy(mks = Markers(mks.xtor - xtorMk, mks.xtee - xteeMk)) - def withMatched(r: Rep): State = r match { - case lb: LetBinding => copy(matchedBVs = matchedBVs + lb.bound) withMatched lb.body - case bv: BoundVal => copy(matchedBVs = matchedBVs + bv) + def withMatchedImpures(r: Rep): State = r match { + case lb: LetBinding if isPure(lb.body) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body + case lb: LetBinding => this withMatchedImpures lb.body + case bv: BoundVal => copy(matchedImpureBVs = matchedImpureBVs + bv) case _ => this } def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) @@ -408,80 +409,84 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def fargss(f: Func) = f._1 def fbody(f: Func) = f._2 - def changeRepBVs(r: Rep)(f: BoundVal => Option[BoundVal]): Rep = r match { - case bv: BoundVal => - f(bv) match { - case Some(bv0) => bv0 - case None => bv.owner match { - case lb: LetBinding => - lb.value = changeDefBVs(lb.value)(f) - lb - case _ => r - } - } + def replaceBVs(r: Rep)(f: BoundVal => BoundVal): Rep = r match { + case bv: BoundVal => f(bv) case lb: LetBinding => - lb.value = changeDefBVs(lb.value)(f) + lb.value = replaceBVsInDef(lb.value)(f) lb - case _ => r } - - def changeDefBVs(d: Def)(f: BoundVal => Option[BoundVal]): Def = d match { - case ma: MethodApp => - MethodApp( - changeRepBVs(ma.self)(f), - ma.mtd, - ma.targs, - ma.argss.argssMap(r => changeRepBVs(r)(f)), - ma.typ - ) - case l: Lambda => - new Lambda( - l.name, - l.bound, - l.boundType, - changeRepBVs(l.body)(f) - ) + def replaceBVsInDef(d: Def)(f: BoundVal => BoundVal): Def = { + println(s"CHANGEDDEF: $d") + d match { + case ma: MethodApp => + MethodApp( + replaceBVs(ma.self)(f), + ma.mtd, + ma.targs, + ma.argss.argssMap(r => replaceBVs(r)(f)), + ma.typ + ) + + case l: Lambda => + new Lambda( + l.name, + l.bound, + l.boundType, + replaceBVs(l.body)(f) + ) + + case _ => d + } + } + + def hasUndeclaredBVs(r: Rep): Boolean = { + println(s"Checking $r") + def hasUndeclaredBVs0(r: Rep, declared: Set[BoundVal]): Boolean = r match { + case bv: BoundVal => !(declared contains bv) + case lb: LetBinding => + val declared0 = declared + lb.bound + hasUndeclaredBVsinDef(lb.value, declared0) || hasUndeclaredBVs0(lb.body, declared0) + case _ => false + } + + def hasUndeclaredBVsinDef(d: Def, declared: Set[BoundVal]): Boolean = d match { + case l: Lambda => hasUndeclaredBVs0(l.body, declared + l.bound) + case ma: MethodApp => (ma.self +: ma.argss.argssList) exists (hasUndeclaredBVs0(_, declared)) + case _ => false + } - case _ => d + hasUndeclaredBVs0(r, Set.empty) } - def extendFunc(args: List[Rep], f: Option[Func]): Option[Func] = { - val args0 = args.map { - case bv: BoundVal => bv.owner match { - case lb: LetBinding => - lb.body = lb.bound // Cut out the HOPHole - changeRepBVs(bv)(es.ctx.get) - case _ => es.ctx.getOrElse(bv, bv) - } - case arg => changeRepBVs(arg)(es.ctx.get) - } // Traverse bottom up (parent, etc) + def extendFunc(args: List[Rep], maybeFunc: Option[Func]): Option[Func] = { + val xteeifiedArgs = args.map(arg => replaceBVs(arg)(bv => es.ctx.getOrElse(bv, bv))) - println(s"ARGS0: $args0") + println(s"ARGS0: $xteeifiedArgs") - val xs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) + val hopArgs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) - val transformations = (args0 zip xs).toMap + val transformations = xteeifiedArgs zip hopArgs println(s"Transformation: $transformations") val after = for { - f <- f + f <- maybeFunc _ = println(s"BEFORE $f") - body0 <- transformations.foldLeft(Option(fbody(f))) { - case (Some(body), (bv: BoundVal, res)) => Some(changeRepBVs(body)(bv0 => Some(if (bv == bv0) res else bv0))) - case (Some(body), (xtor, res)) => rewriteRep0(xtor, body, _ => Some(res))(true)(State.forRewriting(xtor, body)) - case _ => None + body0 = transformations.foldLeft(fbody(f)) { + case (body, (bv: BoundVal, hopArg)) => + println("bv") + replaceBVs(body){ bv0 => if (bv0 == bv) hopArg else bv0 } alsoApply (res => println(s"BLA $res")) + case (body, (xtor, hopArg)) => rewriteRep0(xtor, body, _ => Some(hopArg))(true)(State.forRewriting(xtor, body)) getOrElse body } + _ = println(s"PASSED $body0") invCtx = reverse(es.ctx) - _ = bottomUpPartial(body0) { case bv: BoundVal => - (for { - bvs <- invCtx.get(bv) - if bvs exists (visible contains _) - } yield return None) getOrElse bv - } - } yield (xs :: fargss(f)) -> body0 + _ = println(s"INVCTX: $invCtx") + _ = println(s"VISIBLE: $visible") + _ = bottomUpPartial(body0) { case bv: BoundVal if visible contains bv => return None } + _ = println("OOPS") + } yield (hopArgs :: fargss(f)) -> body0 println(s"AFTER: $after") @@ -494,6 +499,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi f <- argss.foldRight(Option(emptyFunc(xtee)))(extendFunc) _ = println(s"F: $f") l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } + if !hasUndeclaredBVs(l) + _ = println(s"GOT THORUGHT") } yield repExtract(name -> l) ) } @@ -502,8 +509,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (Impure, Impure) => extractDefs(lb1.value, lb2.value) match { case Right(es) => - if (es.makeUnreachable) lb2.isUnreachable = true - extractWithState(lb1.body, lb2.body)(es withCtx lb1.bound -> lb2.bound withMatched lb2.bound) + //if (es.makeUnreachable) lb2.isUnreachable = true + extractWithState(lb1.body, lb2.body)(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } @@ -511,13 +518,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (EndPoint, EndPoint) => extractWithState(lb1.bound, lb2.bound) match { case Right(es) => - extractWithState(lb1.body, lb2.body)(es withoutMarkers(lb1.bound, lb2.bound) withMatched lb2.bound) + extractWithState(lb1.body, lb2.body)(es withoutMarkers(lb1.bound, lb2.bound)) // withMatched lb2.bound) case Left(es) => extractWithState(lb1, lb2.body)(es) } case (EndPoint, NonEndPoint) => extractWithState(lb1, lb2.body) - case (NonEndPoint, _) => extractWithState(lb1.body, lb2) - //case (NonEndPoint, EndPoint) => extractWithState(lb1.body, lb2) - //case (NonEndPoint, NonEndPoint) => extractWithState(lb1.body, lb2) + case (NonEndPoint, EndPoint) => extractWithState(lb1.body, lb2) + case (NonEndPoint, NonEndPoint) => extractWithState(lb1.body, lb2.body) } case (Impure, Pure) => extractWithState(lb1, lb2.body) @@ -527,27 +533,27 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { println(s"ExtractHole: $h --> $r") - def makeUnreachable(r: Rep): Rep = r match { - case lb: LetBinding => - lb.isUnreachable = true - makeUnreachable(lb.body) - case _ => r - } + //def makeUnreachable(r: Rep): Rep = r match { + // case lb: LetBinding => + // //lb.isUnreachable = true + // makeUnreachable(lb.body) + // case _ => r + //} (h, r) match { case (Hole(n, t), bv: BoundVal) => es.updateExtractWith( t extract (xtee.typ, Covariant), Some(repExtract(n -> bv)) - ).map(_ withMatched bv) + ).map(_ withMatchedImpures bv) case (Hole(n, t), lb: LetBinding) => es.updateExtractWith( t extract (lb.typ, Covariant), Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) ).map { - makeUnreachable(lb) - _ withMatched lb + //makeUnreachable(lb) + _ withMatchedImpures lb } case (Hole(n, t), _) => @@ -602,13 +608,15 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (_, Impure) => Left(es) // Assuming the return value cannot be impure } - case (bv: BoundVal, lb: LetBinding) => for { - es1 <- extractWithState(bv, lb.bound).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left - } yield es3 + case (bv: BoundVal, lb: LetBinding) => + println("TRING") + for { + es1 <- extractWithState(bv, lb.bound).left + es2 <- extractInside(bv, lb.value)(es1).left + es3 <- extractWithState(bv, lb.body)(es2).left + } yield es3 - case (_: Rep, lb: LetBinding) if lb.isUnreachable => extractWithState(xtor, lb.body) + case (_: Rep, lb: LetBinding) if es.matchedImpureBVs contains lb.bound => extractWithState(xtor, lb.body) //case (_: Rep, lb: LetBinding) if es.matchedBVs contains lb.bound => extractWithState(xtor, lb.body) case (_, Ascribe(s, _)) => extractWithState(xtor, s) @@ -628,8 +636,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { case Right(es) => - if (es.makeUnreachable) lb2.isUnreachable = true - Right(es withCtx lb1.bound -> lb2.bound withMatched lb2.bound) + //if (es.makeUnreachable) lb2.isUnreachable = true + Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? @@ -734,7 +742,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1), es.mks.xtorMarker(lb1.bound)), (effect(lb2), es.mks.xteeMarker(lb2.bound))) match { + case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1.value), es.mks.xtorMarker(lb1.bound)), (effect(lb2.value), es.mks.xteeMarker(lb2.bound))) match { case ((Pure, NonEndPoint), (Pure, NonEndPoint)) => Left(es) case _ => extractWithState(lb1, lb2) // TODO With unreachable handling TODO why did I mean? } @@ -777,7 +785,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi (ex._1.values ++ ex._3.values.flatten).forall(preCheckRep(Set.empty, invCtx, _)) } - def check(declaredBVs: Set[BoundVal], matchedBVs: Set[BoundVal])(r: Rep): Boolean = { + def check(declaredBVs: Set[BoundVal], matchedImpureBVs: Set[BoundVal])(r: Rep): Boolean = { def checkDef(declaredBVs: Set[BoundVal], matchedBVs: Set[BoundVal])(d: Def): Boolean = d match { case ma: MethodApp => (ma.self :: ma.argss.argssList) forall { case bv: BoundVal => (declaredBVs contains bv) || !(matchedBVs contains bv) @@ -792,15 +800,15 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } r match { - case lb: LetBinding => checkDef(declaredBVs + lb.bound, matchedBVs)(lb.value) - case bv: BoundVal => (declaredBVs contains bv) || !(matchedBVs contains bv) + case lb: LetBinding => checkDef(declaredBVs + lb.bound, matchedImpureBVs)(lb.value) + case bv: BoundVal => (declaredBVs contains bv) || !(matchedImpureBVs contains bv) case _ => true } } def cleanup(r: Rep)(implicit es: State): Rep = r match { - case lb: LetBinding if lb.isUnreachable => cleanup(lb.body) + case lb: LetBinding if es.matchedImpureBVs contains lb.bound => cleanup(lb.body) case lb: LetBinding => lb.body = cleanup(lb.body) lb @@ -811,7 +819,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi for { code <- code(es.ex) _ = println(code) - if check(Set.empty, es.matchedBVs)(cleanup(code) alsoApply println) alsoApply println + if check(Set.empty, es.matchedImpureBVs)(cleanup(code) alsoApply println) alsoApply println } yield code else None } diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index f1dd4c10..076b3958 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -202,7 +202,6 @@ trait RebindableBinding extends Binding { } class LetBinding(var name: String, var bound: Symbol, var value: Def, private var _body: Rep) extends Rep with RebindableBinding { var isUserDefined = false - var isUnreachable = false def body = _body def body_= (newBody: Rep) = _body = newBody def boundType = value.typ From 685599780a58686e9f50873578b04a3bc4abf50d Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sat, 25 Nov 2017 21:50:32 +0100 Subject: [PATCH 18/66] Fix HOPHoles --- src/main/scala/squid/ir/fastanf/FastANF.scala | 92 ++++++------------- .../fastir/HigherOrderPatternVariables.scala | 75 +++++++++------ 2 files changed, 75 insertions(+), 92 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 51dcd06b..5eeb7a7d 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -408,39 +408,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def emptyFunc(r: Rep) = List.empty[List[BoundVal]] -> r def fargss(f: Func) = f._1 def fbody(f: Func) = f._2 - - def replaceBVs(r: Rep)(f: BoundVal => BoundVal): Rep = r match { - case bv: BoundVal => f(bv) - case lb: LetBinding => - lb.value = replaceBVsInDef(lb.value)(f) - lb - case _ => r - } - - def replaceBVsInDef(d: Def)(f: BoundVal => BoundVal): Def = { - println(s"CHANGEDDEF: $d") - d match { - case ma: MethodApp => - MethodApp( - replaceBVs(ma.self)(f), - ma.mtd, - ma.targs, - ma.argss.argssMap(r => replaceBVs(r)(f)), - ma.typ - ) - - case l: Lambda => - new Lambda( - l.name, - l.bound, - l.boundType, - replaceBVs(l.body)(f) - ) - - case _ => d - } - } - + def hasUndeclaredBVs(r: Rep): Boolean = { println(s"Checking $r") def hasUndeclaredBVs0(r: Rep, declared: Set[BoundVal]): Boolean = r match { @@ -461,13 +429,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def extendFunc(args: List[Rep], maybeFunc: Option[Func]): Option[Func] = { - val xteeifiedArgs = args.map(arg => replaceBVs(arg)(bv => es.ctx.getOrElse(bv, bv))) - - println(s"ARGS0: $xteeifiedArgs") + println(s"ARGS0: $args") val hopArgs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) - val transformations = xteeifiedArgs zip hopArgs + val transformations = args zip hopArgs println(s"Transformation: $transformations") @@ -475,17 +441,18 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi f <- maybeFunc _ = println(s"BEFORE $f") body0 = transformations.foldLeft(fbody(f)) { - case (body, (bv: BoundVal, hopArg)) => - println("bv") - replaceBVs(body){ bv0 => if (bv0 == bv) hopArg else bv0 } alsoApply (res => println(s"BLA $res")) - case (body, (xtor, hopArg)) => rewriteRep0(xtor, body, _ => Some(hopArg))(true)(State.forRewriting(xtor, body)) getOrElse body + case (body, (bv: BoundVal, hopArg)) => + val replace = es.ctx(bv) + bottomUpPartial(body){ case `replace` => hopArg } + + case (body, (lb: LetBinding, hopArg)) => extractWithState(lb, body) map { es => + val replace = es.ctx(lb.last.bound) + bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } + } getOrElse body + + case (body, (r, hopArg)) => bottomUpPartial(body) { case `r` => hopArg } } - _ = println(s"PASSED $body0") - invCtx = reverse(es.ctx) - _ = println(s"INVCTX: $invCtx") - _ = println(s"VISIBLE: $visible") _ = bottomUpPartial(body0) { case bv: BoundVal if visible contains bv => return None } - _ = println("OOPS") } yield (hopArgs :: fargss(f)) -> body0 println(s"AFTER: $after") @@ -500,7 +467,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi _ = println(s"F: $f") l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } if !hasUndeclaredBVs(l) - _ = println(s"GOT THORUGHT") } yield repExtract(name -> l) ) } @@ -608,13 +574,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (_, Impure) => Left(es) // Assuming the return value cannot be impure } - case (bv: BoundVal, lb: LetBinding) => - println("TRING") - for { - es1 <- extractWithState(bv, lb.bound).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left - } yield es3 + case (bv: BoundVal, lb: LetBinding) => for { + es1 <- extractWithState(bv, lb.bound).left + es2 <- extractInside(bv, lb.value)(es1).left + es3 <- extractWithState(bv, lb.body)(es2).left + } yield es3 case (_: Rep, lb: LetBinding) if es.matchedImpureBVs contains lb.bound => extractWithState(xtor, lb.body) //case (_: Rep, lb: LetBinding) if es.matchedBVs contains lb.bound => extractWithState(xtor, lb.body) @@ -805,21 +769,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => true } } - - - def cleanup(r: Rep)(implicit es: State): Rep = r match { - case lb: LetBinding if es.matchedImpureBVs contains lb.bound => cleanup(lb.body) - case lb: LetBinding => - lb.body = cleanup(lb.body) - lb - case _ => r - } if (preCheck(es.ex) alsoApply println) for { code <- code(es.ex) _ = println(code) - if check(Set.empty, es.matchedImpureBVs)(cleanup(code) alsoApply println) alsoApply println + if check(Set.empty, es.matchedImpureBVs)(filterLBs(code)(es.matchedImpureBVs contains _.bound) alsoApply println) alsoApply println } yield code else None } @@ -829,6 +784,15 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case Left(_) => None } } + + def filterLBs(r: Rep)(p: LetBinding => Boolean): Rep = r match { + case lb: LetBinding if p(lb) => + filterLBs(lb.body)(p) + case lb: LetBinding => + lb.body = filterLBs(lb.body)(p) + lb + case _ => r + } // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index f9298dbc..3a7d9a93 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -2,37 +2,62 @@ package squid package ir package fastir +import scala.util.Try + class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVariables.Embedding) { import HigherOrderPatternVariables.Embedding.Predef._ test("Matching lambda bodies") { - val id = ir"(z:Int) => z" - ir"(a: Int) => a + 1" matches { + ir"(a: Int) => a + 1" match { case ir"(x: Int) => $body(x): Int" => assert(body == ir"(_:Int) + 1") + } + + ir"(a: Int) => a + 1" match { case ir"(x: Int) => $body(x):$t" => assert(body == ir"(_:Int)+1") + } + + ir"(a: Int) => a + 1" match { case ir"(x: Int) => ($exp(x):Int)+1" => assert(exp == id) } - ir"(a: Int, b: Int) => a + 1" matches { - case ir"(x: Int, y: Int) => $body(y):Int" => fail + assert(Try { + ir"(a: Int, b: Int) => a + 1" match { + case ir"(x: Int, y: Int) => $body(y):Int" => fail + } + }.isFailure) + + ir"(a: Int, b: Int) => a + 1" match { case ir"(x: Int, y: Int) => $body(x):Int" => - } and { + } + + ir"(a: Int, b: Int) => a + 1" match { case ir"(x: Int, y: Int) => $body(x,y):Int" => } - ir"(a: Int, b: Int) => a + b" matches { - case ir"(x: Int, y: Int) => $body(x):Int" => fail + + ir"(a: Int, b: Int) => a + b" match { case ir"(x: Int, y: Int) => $body(x,y):Int" => assert(body == ir"(_:Int)+(_:Int)") } - ir"(a: Int, b: Int) => a + b" matches { - case ir"(x: Int, y: Int) => ($lhs(y):Int)+($rhs(y):Int)" => fail + assert(Try { + ir"(a: Int, b: Int) => a + b" match { + case ir"(x: Int, y: Int) => $body(x):Int" => fail + } + }.isFailure) + + ir"(a: Int, b: Int) => a + b" match { case ir"(x: Int, y: Int) => ($lhs(x):Int)+($rhs(y):Int)" => assert(lhs == id) assert(rhs == id) } + + assert(Try { + ir"(a: Int, b: Int) => a + b" match { + case ir"(x: Int, y: Int) => ($lhs(y):Int)+($rhs(y):Int)" => + } + }.isFailure) } test("Matching let-binding bodies") { @@ -50,16 +75,10 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria test("Non-trivial arguments") { val id = ir"(z: Int) => z" - - ir"(a: Int, b: Int) => a + b" matches { - case ir"(x: Int, y: Int) => $body(x + y): Int" => assert(body == id) - case ir"(x: Int, y: Int) => $body(x): Int" => fail - case ir"(x: Int, y: Int) => $body(y): Int" => fail - } - - ir"(a: Int, b: Int, c: Int) => a + b + c" matches { - case ir"(x: Int, y: Int, z: Int) => $body(x + y, z): Int" => assert(body == ir"(r: Int, s: Int) => r + s") - } + + //ir"(a: Int, b: Int, c: Int) => a + b + c" matches { + // case ir"(x: Int, y: Int, z: Int) => $body(x + y, z): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + //} ir"(a: Int, b: Int, c: Int) => a + b + c" matches { case ir"(x: Int, y: Int, z: Int) => $body(x + y + z): Int" => assert(body == id) @@ -71,17 +90,17 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria // TODO doesn't "align", `extract` is too structural // case ir"(x: Int, y: Int, z: Int) => $body(x, y + z)" => assert(body == ir"(r: Int, s: Int) => r + s") - ir"(a: Int) => readInt + a" matches { - case ir"(x: Int) => $body(readInt, x): Int" => assert(body == ir"(r: Int, s: Int) => r + s") - } + //ir"(a: Int) => readInt + a" matches { + // case ir"(x: Int) => $body(readInt, x): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + //} - ir"(a: Int) => readInt + a" matches { - case ir"(x: Int) => $body(x, readInt): Int" => assert(body == ir"(r: Int, s: Int) => s + r") - } + //ir"(a: Int) => readInt + a" matches { + // case ir"(x: Int) => $body(x, readInt): Int" => assert(body == ir"(r: Int, s: Int) => s + r") + //} - ir"(a: Int, b: Int) => readInt + (a + b)" matches { - case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body == ir"(r: Int, s: Int) => r + s") - } + //ir"(a: Int, b: Int) => readInt + (a + b)" matches { + // case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + //} // TODO doesn't "align", `extract` is too structural //ir"(a: Int, b: Int) => readInt + (a + b)" matches { From fedff936b176fc02131ef0cd3bfb88fba3b5c778 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 26 Nov 2017 13:31:18 +0100 Subject: [PATCH 19/66] Fix issue where a HOPHole extraction would happen too early in the extraction process --- src/main/scala/squid/ir/fastanf/Effects.scala | 1 + src/main/scala/squid/ir/fastanf/FastANF.scala | 145 ++++++++---------- .../fastir/HigherOrderPatternVariables.scala | 32 ++-- 3 files changed, 84 insertions(+), 94 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index 2e915f78..04a70b79 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -65,6 +65,7 @@ trait Effects { trait StandardEffects extends Effects { addPureMtd(MethodSymbol(TypeSymbol("scala.Int"),"$plus")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Double"),"$plus")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"),"$times")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"), "toDouble")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"), "toFloat")) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 5eeb7a7d..33c1bca6 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -303,12 +303,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi post(pre(r) match { case lb: LetBinding => // Note: destructive modification of the let-binding! - //new LetBinding( - // lb.name, - // lb.bound, - // transformDef(lb.value), - // transformRepAndDef0(lb.body) - //) lb.value = lb.value |> transformDef lb.body = lb.body |> transformRepAndDef0 lb @@ -327,7 +321,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(s"Extract($xtor, $xtee)") for { es <- extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) - if es.mks.xtor.isEmpty && es.mks.xtee.isEmpty + if es.flags.xtor.isEmpty && es.flags.xtee.isEmpty } yield es.ex } @@ -340,12 +334,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi type ExtractState = Either[State, State] implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right - case class State(ex: Extract, ctx: Ctx, mks: Markers, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], makeUnreachable: Boolean) { + case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], makeUnreachable: Boolean) { def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) - def updateMarkers(newMks: Markers): State = copy(mks = newMks) - def withoutMarkers(xtorMk: BoundVal, xteeMk: BoundVal): State = copy(mks = Markers(mks.xtor - xtorMk, mks.xtee - xteeMk)) + def updateFlags(newFlags: Flags): State = copy(flags = newFlags) + def withoutFlags(xtorFlag: BoundVal, xteeFlag: BoundVal): State = copy(flags = Flags(flags.xtor - xtorFlag, flags.xtee - xteeFlag)) def withMatchedImpures(r: Rep): State = r match { case lb: LetBinding if isPure(lb.body) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body case lb: LetBinding => this withMatchedImpures lb.body @@ -363,41 +357,49 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def forExtraction(xtor: Rep, xtee: Rep): State = State(xtor, xtee, false) private def apply(xtor: Rep, xtee: Rep, makeUnreachable: Bool): State = - State(EmptyExtract, ListMap.empty, Markers(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), makeUnreachable) + State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), makeUnreachable) } - sealed trait Marker - case object EndPoint extends Marker - case object NonEndPoint extends Marker + sealed trait Flag + case object Start extends Flag + case object Skip extends Flag - case class Markers(xtor: ListSet[BoundVal], xtee: ListSet[BoundVal]) { - private def marker(ls: ListSet[BoundVal])(bv: BoundVal) = if (ls contains bv) EndPoint else NonEndPoint - def xtorMarker(bv: BoundVal): Marker = marker(xtor)(bv) - def xteeMarker(bv: BoundVal): Marker = marker(xtee)(bv) + case class Flags(xtor: Set[BoundVal], xtee: Set[BoundVal]) { + private def flags(ls: Set[BoundVal])(bv: BoundVal) = if (ls contains bv) Start else Skip + def xtorFlag(bv: BoundVal): Flag = flags(xtor)(bv) + def xteeFlag(bv: BoundVal): Flag = flags(xtee)(bv) } - object Markers { - def apply(xtor: Rep, xtee: Rep): Markers = Markers(extractionStarts(xtor), extractionStarts(xtee)) + + object Flags { + def apply(xtor: Rep, xtee: Rep): Flags = Flags(genFlags(xtor), genFlags(xtee)) - private def extractionStarts(r: Rep): ListSet[BoundVal] = { - def bvs(d: Def, acc: ListSet[BoundVal]): ListSet[BoundVal] = d match { - case _: Lambda => ListSet.empty // TODO The lambda may never be applied. - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft(acc) { + private def genFlags(r: Rep): Set[BoundVal] = { + def update(d: Def, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = d match { + case l: Lambda => genFlags0(l.body, unusedBVs, impures) + + case ma: MethodApp => ((ma.self :: ma.argss.argssList).foldLeft(unusedBVs) { case (acc, bv: BoundVal) => acc - bv - case (acc, _) => acc // Assuming no LBs in self or argument positions - } - case _ => ListSet.empty + case (acc, _) => acc + }, impures) + + case _ => (unusedBVs, impures) } - def extractionStarts0(r: Rep, acc: ListSet[BoundVal]): ListSet[BoundVal] = r match { - case lb: LetBinding => effect(lb) match { - case Pure => extractionStarts0(lb.body, bvs(lb.value, acc + lb.bound)) - case Impure => extractionStarts0(lb.body, bvs(lb.value, acc)) + def genFlags0(r: Rep, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = r match { + case lb: LetBinding => effect(lb.value) match { + case Pure => + val updated = update(lb.value, unusedBVs + lb.bound, impures) + genFlags0(lb.body, updated._1, updated._2) + case Impure => + val updated = update(lb.value, unusedBVs + lb.bound, impures + lb.bound) + genFlags0(lb.body, updated._1, updated._2) } - case bv: BoundVal => acc - bv - case _ => acc + case bv: BoundVal => (unusedBVs - bv, impures) + case _ => (unusedBVs, impures) } - extractionStarts0(r, ListSet.empty) + val flags = genFlags0(r, Set.empty, Set.empty) + flags._1 ++ flags._2 } } @@ -443,7 +445,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi body0 = transformations.foldLeft(fbody(f)) { case (body, (bv: BoundVal, hopArg)) => val replace = es.ctx(bv) - bottomUpPartial(body){ case `replace` => hopArg } + bottomUpPartial(body){ + case `replace` => hopArg + } case (body, (lb: LetBinding, hopArg)) => extractWithState(lb, body) map { es => val replace = es.ctx(lb.last.bound) @@ -471,41 +475,32 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi ) } - def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = (effect(lb1.value), effect(lb2.value)) match { - case (Impure, Impure) => - extractDefs(lb1.value, lb2.value) match { - case Right(es) => - //if (es.makeUnreachable) lb2.isUnreachable = true - extractWithState(lb1.body, lb2.body)(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) - case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) - } - - case (Pure, Pure) => (es.mks.xtorMarker(lb1.bound), es.mks.xteeMarker(lb2.bound)) match { - case (EndPoint, EndPoint) => - extractWithState(lb1.bound, lb2.bound) match { - case Right(es) => - extractWithState(lb1.body, lb2.body)(es withoutMarkers(lb1.bound, lb2.bound)) // withMatched lb2.bound) - case Left(es) => extractWithState(lb1, lb2.body)(es) - } - case (EndPoint, NonEndPoint) => extractWithState(lb1, lb2.body) - case (NonEndPoint, EndPoint) => extractWithState(lb1.body, lb2) - case (NonEndPoint, NonEndPoint) => extractWithState(lb1.body, lb2.body) - } + def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = { + def extractAndContinue(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = for { + es1 <- extractWithState(lb1.bound, lb2.bound) + es2 <- extractWithState(lb1.body, lb2.body)(es1) + } yield es2 + + (es.flags.xtorFlag(lb1.bound), es.flags.xteeFlag(lb2.bound)) match { + case (Start, Start) => extractAndContinue(lb1, lb2) - case (Impure, Pure) => extractWithState(lb1, lb2.body) - case (Pure, Impure) => extractWithState(lb1.body, lb2) + case (Start, Skip) => for { + es1 <- extractAndContinue(lb1, lb2).left + es2 <- extractWithState(lb1, lb2.body)(es1).left + } yield es2 + + case (Skip, Start) => for { + es1 <- extractAndContinue(lb1, lb2).left + es2 <- extractWithState(lb1.body, lb2)(es1).left + } yield es2 + + case (Skip, Skip) => extractWithState(lb1.body, lb2.body) + } } def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { println(s"ExtractHole: $h --> $r") - //def makeUnreachable(r: Rep): Rep = r match { - // case lb: LetBinding => - // //lb.isUnreachable = true - // makeUnreachable(lb.body) - // case _ => r - //} - (h, r) match { case (Hole(n, t), bv: BoundVal) => es.updateExtractWith( @@ -515,12 +510,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (Hole(n, t), lb: LetBinding) => es.updateExtractWith( - t extract (lb.typ, Covariant), + t extract(lb.typ, Covariant), Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) - ).map { - //makeUnreachable(lb) - _ withMatchedImpures lb - } + ) map (_ withMatchedImpures lb) case (Hole(n, t), _) => es.updateExtractWith( @@ -544,7 +536,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es1 <- acc.left es2 <- extractWithState(bv, bv2)(es1).left } yield es2 - } alsoApply (s => println(s"FOO: $s")) + } } def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] @@ -581,7 +573,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } yield es3 case (_: Rep, lb: LetBinding) if es.matchedImpureBVs contains lb.bound => extractWithState(xtor, lb.body) - //case (_: Rep, lb: LetBinding) if es.matchedBVs contains lb.bound => extractWithState(xtor, lb.body) case (_, Ascribe(s, _)) => extractWithState(xtor, s) @@ -593,15 +584,15 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) case (bv1: BoundVal, bv2: BoundVal) => - println(s"EXTRACTIONSTATE IN BV: $es") println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") if (es.ctx.getOrElse(bv1, bv1) == bv2) Right(es) else if (es.failedMatches(bv1) contains bv2) Left(es) else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { - case Right(es) => - //if (es.makeUnreachable) lb2.isUnreachable = true - Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) + case Right(es) => effect(lb2.value) match { + case Pure => Right(es withCtx lb1.bound -> lb2.bound withoutFlags(bv1, bv2)) + case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound withoutFlags(bv1, bv2)) + } case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? @@ -706,8 +697,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1.value), es.mks.xtorMarker(lb1.bound)), (effect(lb2.value), es.mks.xteeMarker(lb2.bound))) match { - case ((Pure, NonEndPoint), (Pure, NonEndPoint)) => Left(es) + case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1.value), es.flags.xtorFlag(lb1.bound)), (effect(lb2.value), es.flags.xteeFlag(lb2.bound))) match { + case ((Pure, Skip), (Pure, Skip)) => Left(es) case _ => extractWithState(lb1, lb2) // TODO With unreachable handling TODO why did I mean? } case _ => extractWithState(xtor, xtee) diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index 3a7d9a93..7e7d3bff 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -80,7 +80,7 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria // case ir"(x: Int, y: Int, z: Int) => $body(x + y, z): Int" => assert(body == ir"(r: Int, s: Int) => r + s") //} - ir"(a: Int, b: Int, c: Int) => a + b + c" matches { + ir"(a: Int, b: Int, c: Int) => a + b + c" match { case ir"(x: Int, y: Int, z: Int) => $body(x + y + z): Int" => assert(body == id) } @@ -90,17 +90,17 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria // TODO doesn't "align", `extract` is too structural // case ir"(x: Int, y: Int, z: Int) => $body(x, y + z)" => assert(body == ir"(r: Int, s: Int) => r + s") - //ir"(a: Int) => readInt + a" matches { - // case ir"(x: Int) => $body(readInt, x): Int" => assert(body == ir"(r: Int, s: Int) => r + s") - //} + ir"(a: Int) => readInt + a" matches { + case ir"(x: Int) => $body(readInt, x): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + } - //ir"(a: Int) => readInt + a" matches { - // case ir"(x: Int) => $body(x, readInt): Int" => assert(body == ir"(r: Int, s: Int) => s + r") - //} + ir"(a: Int) => readInt + a" matches { + case ir"(x: Int) => $body(x, readInt): Int" => assert(body == ir"(r: Int, s: Int) => s + r") + } - //ir"(a: Int, b: Int) => readInt + (a + b)" matches { - // case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body == ir"(r: Int, s: Int) => r + s") - //} + ir"(a: Int, b: Int) => readInt + (a + b)" matches { + case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + } // TODO doesn't "align", `extract` is too structural //ir"(a: Int, b: Int) => readInt + (a + b)" matches { @@ -108,14 +108,12 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria //} } - //test("Currying") { - // ir"(a: Int, b: Int) => a + b" match { - // case ir"(x: Int, y: Int) => $body(x)(y)" => - // } - //} - test("Match letbindinds") { - // TODO `apply` should call inline + ir"val a = 10.toDouble; a + 1" match { + case ir"val x = 10.toDouble; $body(x)" => assert(body == ir"(_: Double) + 1") + } + + //val a = ir"val a = 10.toDouble; val b = a + 1; val c = b + 2; c" matches { // case ir"val x = 10.toDouble; $body(x):Double" => // assert(ir"$body(42)" == ir"(val a = (x: Int) => (val b = x + 1; val c = b + 2; c)); val tmp = a.apply(42.0); tmp") From 0bb0eccff1466945c12af3f335905bff00ec2eea Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 27 Nov 2017 09:59:06 +0100 Subject: [PATCH 20/66] Small HOPHole optimization and updated tests --- src/main/scala/squid/ir/fastanf/FastANF.scala | 28 +++++----- .../squid/ir/fastir/ExtractingTests.scala | 54 +++++++++++++++++++ .../fastir/HigherOrderPatternVariables.scala | 33 ++++-------- .../squid/ir/fastir/RewritingTests.scala | 6 +-- 4 files changed, 83 insertions(+), 38 deletions(-) create mode 100644 src/test/scala/squid/ir/fastir/ExtractingTests.scala diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 33c1bca6..e7309eff 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -318,7 +318,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi transformRepAndDef(r)(pre, post)(identity, identity) protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { - println(s"Extract($xtor, $xtee)") + println(s"Extract(\n$xtor, \n$xtee)") for { es <- extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) if es.flags.xtor.isEmpty && es.flags.xtee.isEmpty @@ -340,11 +340,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) def updateFlags(newFlags: Flags): State = copy(flags = newFlags) def withoutFlags(xtorFlag: BoundVal, xteeFlag: BoundVal): State = copy(flags = Flags(flags.xtor - xtorFlag, flags.xtee - xteeFlag)) + //def withMatched def withMatchedImpures(r: Rep): State = r match { - case lb: LetBinding if isPure(lb.body) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body + case lb: LetBinding if !isPure(lb.value) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body case lb: LetBinding => this withMatchedImpures lb.body - case bv: BoundVal => copy(matchedImpureBVs = matchedImpureBVs + bv) - case _ => this + //case bv: BoundVal => copy(matchedImpureBVs = matchedImpureBVs + bv) + case _ => this // Everything else is pure so we ignore it } def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) @@ -445,9 +446,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi body0 = transformations.foldLeft(fbody(f)) { case (body, (bv: BoundVal, hopArg)) => val replace = es.ctx(bv) - bottomUpPartial(body){ - case `replace` => hopArg - } + replace rebind hopArg + body case (body, (lb: LetBinding, hopArg)) => extractWithState(lb, body) map { es => val replace = es.ctx(lb.last.bound) @@ -501,18 +501,18 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { println(s"ExtractHole: $h --> $r") - (h, r) match { + val newEs = (h, r) match { case (Hole(n, t), bv: BoundVal) => es.updateExtractWith( - t extract (xtee.typ, Covariant), + t extract(xtee.typ, Covariant), Some(repExtract(n -> bv)) - ).map(_ withMatchedImpures bv) + ) case (Hole(n, t), lb: LetBinding) => es.updateExtractWith( t extract(lb.typ, Covariant), Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) - ) map (_ withMatchedImpures lb) + ) case (Hole(n, t), _) => es.updateExtractWith( @@ -520,6 +520,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi Some(repExtract(n -> xtee)) ) } + + newEs map (_ withMatchedImpures r) } def extractInside(bv: BoundVal, d: Def)(implicit es: State): ExtractState = { @@ -761,11 +763,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - if (preCheck(es.ex) alsoApply println) + if (preCheck(es.ex)) for { code <- code(es.ex) _ = println(code) - if check(Set.empty, es.matchedImpureBVs)(filterLBs(code)(es.matchedImpureBVs contains _.bound) alsoApply println) alsoApply println + if check(Set.empty, es.matchedImpureBVs)(filterLBs(code)(es.matchedImpureBVs contains _.bound)) } yield code else None } diff --git a/src/test/scala/squid/ir/fastir/ExtractingTests.scala b/src/test/scala/squid/ir/fastir/ExtractingTests.scala new file mode 100644 index 00000000..6ac10c70 --- /dev/null +++ b/src/test/scala/squid/ir/fastir/ExtractingTests.scala @@ -0,0 +1,54 @@ +package squid +package ir +package fastir + +import scala.util.Try + +class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { + import ExtractingTests.Embedding.Predef._ + + test("Matching with pure statements") { + ir"42" match { + case ir"$h" => assert(h =~= ir"42") + } + + ir"42.toDouble" match { + case ir"($h: Int).toDouble" => assert(h =~= ir"42") + } + + ir"println(42.toDouble)" match { + case ir"println($h)" => assert(h =~= ir"42.toDouble") + } + + ir"(42, 1337)" match { + case ir"(($l: Int), ($r: Int))" => + assert(l =~= ir"42") + assert(r =~= ir"1337") + } + } + + test("Matching with impure statements") { + ir"val r = readInt; r + 1" match { + case ir"val r = ($h: Int); r + 1" => assert(h =~= ir"readInt") + } + + assert(Try { + ir"val r = 10.toDouble; r + 1" match { + case ir"val rX = 42.toDouble; $body" => fail + } + }.isFailure) + } + + test("Matching with dead-ends") { + ir"val a = 42.toDouble; val b = a + 1; 1337" match { + case ir"val aX = ($h1: Int).toDouble; val bX = aX + ($h2: Int); ($h3: Int)" => + assert(h1 =~= ir"42") + assert(h2 =~= ir"1") + assert(h3 =~= ir"1337") + } + } +} + +object ExtractingTests { + object Embedding extends FastANF +} diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index 7e7d3bff..449f7618 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -11,15 +11,15 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria val id = ir"(z:Int) => z" ir"(a: Int) => a + 1" match { - case ir"(x: Int) => $body(x): Int" => assert(body == ir"(_:Int) + 1") + case ir"(x: Int) => $body(x): Int" => assert(body =~= ir"(_:Int) + 1") } ir"(a: Int) => a + 1" match { - case ir"(x: Int) => $body(x):$t" => assert(body == ir"(_:Int)+1") + case ir"(x: Int) => $body(x):$t" => assert(body =~= ir"(_:Int)+1") } ir"(a: Int) => a + 1" match { - case ir"(x: Int) => ($exp(x):Int)+1" => assert(exp == id) + case ir"(x: Int) => ($exp(x):Int)+1" => assert(exp =~= id) } assert(Try { @@ -38,7 +38,7 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria ir"(a: Int, b: Int) => a + b" match { - case ir"(x: Int, y: Int) => $body(x,y):Int" => assert(body == ir"(_:Int)+(_:Int)") + case ir"(x: Int, y: Int) => $body(x,y):Int" => assert(body =~= ir"(_:Int)+(_:Int)") } assert(Try { @@ -49,8 +49,8 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria ir"(a: Int, b: Int) => a + b" match { case ir"(x: Int, y: Int) => ($lhs(x):Int)+($rhs(y):Int)" => - assert(lhs == id) - assert(rhs == id) + assert(lhs =~= id) + assert(rhs =~= id) } assert(Try { @@ -81,36 +81,25 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria //} ir"(a: Int, b: Int, c: Int) => a + b + c" match { - case ir"(x: Int, y: Int, z: Int) => $body(x + y + z): Int" => assert(body == id) + case ir"(x: Int, y: Int, z: Int) => $body(x + y + z): Int" => assert(body =~= id) } - // TODO `extract` should see the different combinations of `x + y + z` - // case ir"(x: Int, y: Int, z: Int) => $body(x + z, y)" => println(body); assert(body == ir"(r: Int, s: Int) => r + s") - - // TODO doesn't "align", `extract` is too structural - // case ir"(x: Int, y: Int, z: Int) => $body(x, y + z)" => assert(body == ir"(r: Int, s: Int) => r + s") - ir"(a: Int) => readInt + a" matches { - case ir"(x: Int) => $body(readInt, x): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + case ir"(x: Int) => $body(readInt, x): Int" => assert(body =~= ir"(r: Int, s: Int) => r + s") } ir"(a: Int) => readInt + a" matches { - case ir"(x: Int) => $body(x, readInt): Int" => assert(body == ir"(r: Int, s: Int) => s + r") + case ir"(x: Int) => $body(x, readInt): Int" => assert(body =~= ir"(r: Int, s: Int) => s + r") } ir"(a: Int, b: Int) => readInt + (a + b)" matches { - case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body =~= ir"(r: Int, s: Int) => r + s") } - - // TODO doesn't "align", `extract` is too structural - //ir"(a: Int, b: Int) => readInt + (a + b)" matches { - // case ir"(x: Int, y: Int) => $body(x + y, readInt): Int" => assert(body == ir"(r: Int, s: Int) => s + r") - //} } test("Match letbindinds") { ir"val a = 10.toDouble; a + 1" match { - case ir"val x = 10.toDouble; $body(x)" => assert(body == ir"(_: Double) + 1") + case ir"val x = 10.toDouble; $body(x)" => assert(body =~= ir"(_: Double) + 1") } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index ab7123d7..b032d1da 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -18,12 +18,12 @@ class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { val b = ir"42.toFloat" rewrite { case ir"42.toFloat" => ir"42f" } - assert(b =~= ir"42f") + assert(b =~= ir"val t = ${ir"42"}.toFloat; 42f") val c = ir"42.toDouble" rewrite { case ir"(${Const(n)}: Int).toDouble" => ir"${Const(n.toDouble)}" } - assert(c =~= ir"42.0") + assert(c =~= ir"val t = ${ir"42"}.toDouble; 42.0") //assertDoesNotCompile(""" // T.rewrite { case ir"0.5" => ir"42" } @@ -69,7 +69,7 @@ class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { //test("Rewriting simple expressions only once") { // val a = ir"println((50, 60))" rewrite { // case ir"($x:Int,$y:Int)" => ir"($y:Int,$x:Int)" - // case ir"(${Const(n)}:Int)" => Const(n+1) + // //case ir"(${Const(n)}:Int)" => Const(n+1) // } // assert(a =~= ir"println((61,51))") //} From 80f46c5b5acad6694b8cde896f9a49a4ec498e04 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 27 Nov 2017 14:45:56 +0100 Subject: [PATCH 21/66] Add way to shortcut extraction, hophole extraction branches now update the state --- src/main/scala/squid/ir/fastanf/FastANF.scala | 160 ++++++++++-------- .../squid/ir/fastir/RewritingTests.scala | 16 +- 2 files changed, 98 insertions(+), 78 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index e7309eff..d64bd4fe 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -320,7 +320,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract(\n$xtor, \n$xtee)") for { - es <- extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) + es <- extractWithState(xtor, xtee)(_ => false)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) if es.flags.xtor.isEmpty && es.flags.xtee.isEmpty } yield es.ex } @@ -340,7 +340,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) def updateFlags(newFlags: Flags): State = copy(flags = newFlags) def withoutFlags(xtorFlag: BoundVal, xteeFlag: BoundVal): State = copy(flags = Flags(flags.xtor - xtorFlag, flags.xtee - xteeFlag)) - //def withMatched + def withoutXteeFlag(flag: BoundVal): State = copy(flags = flags.copy(xtee = flags.xtee - flag)) def withMatchedImpures(r: Rep): State = r match { case lb: LetBinding if !isPure(lb.value) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body case lb: LetBinding => this withMatchedImpures lb.body @@ -403,8 +403,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi flags._1 ++ flags._2 } } - - def extractWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { + + def extractWithState(xtor: Rep, xtee: Rep)(done: State => Boolean)(implicit es: State): ExtractState = { def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { println("EXTRACTINGHOPHOLE") type Func = List[List[BoundVal]] -> Rep @@ -431,7 +431,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi hasUndeclaredBVs0(r, Set.empty) } - def extendFunc(args: List[Rep], maybeFunc: Option[Func]): Option[Func] = { + def extendFunc(args: List[Rep], maybeFuncAndState: Option[(Func, State)]): Option[(Func, State)] = { println(s"ARGS0: $args") val hopArgs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) @@ -440,79 +440,97 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(s"Transformation: $transformations") + def bvs(r: Rep): List[BoundVal] = { + def bvs0(r: Rep, acc: List[BoundVal]): List[BoundVal] = r match { + case lb: LetBinding => lb.bound :: acc + case _ => acc + } + + bvs0(r, List.empty) + } + val after = for { - f <- maybeFunc + (f, es) <- maybeFuncAndState _ = println(s"BEFORE $f") - body0 = transformations.foldLeft(fbody(f)) { - case (body, (bv: BoundVal, hopArg)) => + newBodyAndState = transformations.foldLeft(fbody(f) -> es) { + case ((body, es), (bv: BoundVal, hopArg)) => val replace = es.ctx(bv) replace rebind hopArg - body - - case (body, (lb: LetBinding, hopArg)) => extractWithState(lb, body) map { es => - val replace = es.ctx(lb.last.bound) - bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } - } getOrElse body - - case (body, (r, hopArg)) => bottomUpPartial(body) { case `r` => hopArg } + body -> es + + case ((body, es), (lb: LetBinding, hopArg)) => + val done = (s: State) => bvs(lb) forall (s.ctx.keySet contains _) + extractWithState(lb, body)(done)(es) map { es => + val replace = es.ctx(lb.last.bound) + bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } -> es + } getOrElse body -> es + + case ((body, es), (r, hopArg)) => bottomUpPartial(body) { case `r` => hopArg } -> es } - _ = bottomUpPartial(body0) { case bv: BoundVal if visible contains bv => return None } - } yield (hopArgs :: fargss(f)) -> body0 + _ = bottomUpPartial(newBodyAndState._1) { case bv: BoundVal if visible contains bv => return None } + } yield ((hopArgs :: fargss(f)) -> newBodyAndState._1) -> newBodyAndState._2 println(s"AFTER: $after") after } - es.updateExtractWith( - typ extract (xtee.typ, Covariant), - for { - f <- argss.foldRight(Option(emptyFunc(xtee)))(extendFunc) - _ = println(s"F: $f") - l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } - if !hasUndeclaredBVs(l) - } yield repExtract(name -> l) - ) + val oe = for { + es1 <- typ extract (xtee.typ, Covariant) + m <- merge(es.ex, es1) + (f, es2) <- argss.foldRight(Option(emptyFunc(xtee) -> (es withNewExtract m)))(extendFunc) + _ = println(s"F: $f") + l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } + if !hasUndeclaredBVs(l) + //m2 <- merge(es2.ex, repExtract(name -> l)) + } yield es2 updateExtractWith Some(repExtract(name -> l)) + + oe getOrElse Left(es) } - def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = { - def extractAndContinue(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = for { - es1 <- extractWithState(lb1.bound, lb2.bound) - es2 <- extractWithState(lb1.body, lb2.body)(es1) + def extractLBs(lb1: LetBinding, lb2: LetBinding)(done: State => Boolean)(implicit es: State): ExtractState = { + def extractAndContinue(lb1: LetBinding, lb2: LetBinding)(done: State => Boolean)(implicit es: State): ExtractState = for { + es1 <- extractWithState(lb1.bound, lb2.bound)(done) + es2 <- extractWithState(lb1.body, lb2.body)(done)(es1) } yield es2 (es.flags.xtorFlag(lb1.bound), es.flags.xteeFlag(lb2.bound)) match { - case (Start, Start) => extractAndContinue(lb1, lb2) + case (Start, Start) => extractAndContinue(lb1, lb2)(done) case (Start, Skip) => for { - es1 <- extractAndContinue(lb1, lb2).left - es2 <- extractWithState(lb1, lb2.body)(es1).left + es1 <- extractAndContinue(lb1, lb2)(done).left + es2 <- extractWithState(lb1, lb2.body)(done)(es1).left } yield es2 case (Skip, Start) => for { - es1 <- extractAndContinue(lb1, lb2).left - es2 <- extractWithState(lb1.body, lb2)(es1).left + es1 <- extractAndContinue(lb1, lb2)(done).left + es2 <- extractWithState(lb1.body, lb2)(done)(es1).left } yield es2 - case (Skip, Skip) => extractWithState(lb1.body, lb2.body) + case (Skip, Skip) => extractWithState(lb1.body, lb2.body)(done) } } def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { println(s"ExtractHole: $h --> $r") + def updateFlags(r: Rep, es: State): State = r match { + case lb: LetBinding => updateFlags(lb.body, es withoutXteeFlag lb.bound) + case _ => es + } + val newEs = (h, r) match { case (Hole(n, t), bv: BoundVal) => es.updateExtractWith( t extract(xtee.typ, Covariant), Some(repExtract(n -> bv)) - ) + ) map (_ withoutXteeFlag bv) case (Hole(n, t), lb: LetBinding) => es.updateExtractWith( t extract(lb.typ, Covariant), Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) - ) + ).map(updateFlags(lb, _)) case (Hole(n, t), _) => es.updateExtractWith( @@ -536,7 +554,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi bvs(d).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => for { es1 <- acc.left - es2 <- extractWithState(bv, bv2)(es1).left + es2 <- extractWithState(bv, bv2)(done)(es1).left } yield es2 } } @@ -544,7 +562,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] println(s"extractWithState: $xtor\n$xtee\n") - xtor -> xtee match { + if (done(es)) { + println(s"FINISHED") + Right(es) + } + else xtor -> xtee match { case (h: Hole, lb: LetBinding) => contentsOf(h) match { case Some(lb1: LetBinding) if lb1.value == lb.value => Right(es) case Some(_) => Left(es) @@ -558,46 +580,49 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } case (HOPHole2(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) - - case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) - + + case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2)(done) + // Stop at markers? case (lb: LetBinding, _: Rep) => (effect(lb), effect(xtee)) match { - case (Pure, Pure) => extractWithState(lb.body, xtee) + case (Pure, Pure) => extractWithState(lb.body, xtee)(done) case (Impure, Pure) => Left(es) case (_, Impure) => Left(es) // Assuming the return value cannot be impure } case (bv: BoundVal, lb: LetBinding) => for { - es1 <- extractWithState(bv, lb.bound).left + es1 <- extractWithState(bv, lb.bound)(done).left es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left + es3 <- extractWithState(bv, lb.body)(done)(es2).left } yield es3 - - case (_: Rep, lb: LetBinding) if es.matchedImpureBVs contains lb.bound => extractWithState(xtor, lb.body) - - case (_, Ascribe(s, _)) => extractWithState(xtor, s) - + + case (_: Rep, lb: LetBinding) if es.matchedImpureBVs contains lb.bound => extractWithState(xtor, lb.body)(done) + + case (_, Ascribe(s, _)) => extractWithState(xtor, s)(done) + case (Ascribe(s, t), _) => for { es1 <- es.updateExtractWith(t extract(xtee.typ, Covariant)) - es2 <- extractWithState(s, xtee)(es1) + es2 <- extractWithState(s, xtee)(done)(es1) } yield es2 case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) case (bv1: BoundVal, bv2: BoundVal) => println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") - if (es.ctx.getOrElse(bv1, bv1) == bv2) Right(es) + if (es.ctx.getOrElse(bv1, bv1) == bv2) { + println(s"SEEYA") + Right(es) + } else if (es.failedMatches(bv1) contains bv2) Left(es) else (bv1.owner, bv2.owner) match { - case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { + case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value)(done) match { case Right(es) => effect(lb2.value) match { case Pure => Right(es withCtx lb1.bound -> lb2.bound withoutFlags(bv1, bv2)) case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound withoutFlags(bv1, bv2)) } case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } - case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? + case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2)(done) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? case _ => Left(es withFailed bv1 -> bv2) } @@ -607,12 +632,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) // Assuming if they have the same name and prefix the type is the same - case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - + case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2)(done) + case (NewObject(t1), NewObject(t2)) => es updateExtractWith (t1 extract(t2, Covariant)) case _ => Left(es) - } + } } alsoApply (res => println(s"Extract: $res")) protected def spliceExtract(xtor: Rep, args: Args): Option[Extract] = xtor match { @@ -634,13 +659,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => throw IRException(s"Trying to splice-extract with invalid extractor $xtor") } - def extractDefs(v1: Def, v2: Def)(implicit es: State): ExtractState = { + def extractDefs(v1: Def, v2: Def)(done: State => Boolean)(implicit es: State): ExtractState = { println(s"VALUES: \n\t$v1\n\t$v2 with $es \n\n") (v1, v2) match { case (l1: Lambda, l2: Lambda) => for { es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) - es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound) + es2 <- extractWithState(l1.body, l2.body)(done)(es1 withCtx l1.bound -> l2.bound) } yield es2 case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => @@ -663,11 +688,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es1 <- extractArgss0(t1, t2, acc)(es0) } yield es1 - case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) + case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2)(done) case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) case (SplicedArgument(arg), NoArguments) => es updateExtractWith spliceExtract(arg, Args(acc.reverse: _*)) - case (r1: Rep, r2: Rep) => extractWithState(r1, r2) + case (r1: Rep, r2: Rep) => extractWithState(r1, r2)(done) case (NoArguments, NoArguments) => Right(es) case (NoArgumentLists, NoArgumentLists) => Right(es) case _ => Left(es) @@ -677,13 +702,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } for { - es1 <- extractWithState(ma1.self, ma2.self) + es1 <- extractWithState(ma1.self, ma2.self)(done) es2 <- targExtract(es1) es3 <- extractArgss(ma1.argss, ma2.argss)(es2) es4 <- es3.updateExtractWith(ma1.typ extract (ma2.typ, Covariant)) } yield es4 - case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2))) + case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2)))(done) case _ => Left(es) } @@ -701,9 +726,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi (xtor, xtee) match { case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1.value), es.flags.xtorFlag(lb1.bound)), (effect(lb2.value), es.flags.xteeFlag(lb2.bound))) match { case ((Pure, Skip), (Pure, Skip)) => Left(es) - case _ => extractWithState(lb1, lb2) // TODO With unreachable handling TODO why did I mean? + case _ => extractWithState(lb1, lb2)(_ => false) // TODO With unreachable handling TODO why did I mean? } - case _ => extractWithState(xtor, xtee) + case _ => extractWithState(xtor, xtee)(_ => false) } } @@ -766,7 +791,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if (preCheck(es.ex)) for { code <- code(es.ex) - _ = println(code) if check(Set.empty, es.matchedImpureBVs)(filterLBs(code)(es.matchedImpureBVs contains _.bound)) } yield code else None diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index b032d1da..d0e740b4 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -5,10 +5,6 @@ package fastir class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { import RewritingTests.Embedding.Predef._ - //object T extends SimpleRuleBasedTransformer with TopDownTransformer { - // val base: DSL.type = DSL - //} - test("Simple rewrites") { val a = ir"123" rewrite { case ir"123" => ir"666" @@ -35,28 +31,28 @@ class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { } test("Rewriting subpatterns") { - val a = ir"(readInt + 111) * .5" rewrite { - case ir"(($n: Int) + 111) * .5" => ir"$n * .25" + val a = ir"(10.toDouble + 111) * .5" match { + case ir"(($n: Double) + 111) * .5" => ir"$n * .25" } - assert(a =~= ir"readInt * .25") + assert(a =~= ir"10.toDouble * .25") val b = ir"Option(42).get" rewrite { case ir"Option(($n: Int)).get" => n } - assert(b =~= ir"42") + assert(b =~= ir"val t = Option(42).get; 42") val c = ir"val a = Option(42).get; a * 5" rewrite { case ir"Option(($n: Int)).get" => n case ir"($a: Int) * 5" => ir"$a * 2" } - assert(c =~= ir"val a = 42; a * 2") + assert(c =~= ir"val t1 = Option(42); val t2 = t1.get; val t3 = ${ir"42"} * 5; ${ir"42"} * 2") } test("Rewriting with dead-ends") { val b = ir"Option(42).get; 20" rewrite { case ir"Option(($n: Int)).get; 20" => n } - assert(b =~= ir"42") + assert(b =~= ir"val t = ${ir"Option(42)"}; 42") } //test("Rewriting with impures") { From bf125b1c0a83bb4719c7a008fcd51b13adad2199 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 28 Nov 2017 10:59:17 +0100 Subject: [PATCH 22/66] Cleanup --- src/main/scala/squid/ir/fastanf/FastANF.scala | 135 +++++++++--------- .../squid/ir/fastir/RewritingTests.scala | 2 +- 2 files changed, 67 insertions(+), 70 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index d64bd4fe..29388547 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -344,7 +344,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withMatchedImpures(r: Rep): State = r match { case lb: LetBinding if !isPure(lb.value) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body case lb: LetBinding => this withMatchedImpures lb.body - //case bv: BoundVal => copy(matchedImpureBVs = matchedImpureBVs + bv) case _ => this // Everything else is pure so we ignore it } def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) @@ -385,16 +384,19 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => (unusedBVs, impures) } - + def genFlags0(r: Rep, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = r match { - case lb: LetBinding => effect(lb.value) match { - case Pure => - val updated = update(lb.value, unusedBVs + lb.bound, impures) - genFlags0(lb.body, updated._1, updated._2) - case Impure => - val updated = update(lb.value, unusedBVs + lb.bound, impures + lb.bound) - genFlags0(lb.body, updated._1, updated._2) - } + case lb: LetBinding => + val updated = update( + lb.value, + unusedBVs + lb.bound, + effect(lb.value) match { + case Pure => impures + case Impure => impures + lb.bound + } + ) + genFlags0(lb.body, updated._1, updated._2) + case bv: BoundVal => (unusedBVs - bv, impures) case _ => (unusedBVs, impures) } @@ -405,6 +407,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def extractWithState(xtor: Rep, xtee: Rep)(done: State => Boolean)(implicit es: State): ExtractState = { + println(es) + def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { println("EXTRACTINGHOPHOLE") type Func = List[List[BoundVal]] -> Rep @@ -432,57 +436,40 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def extendFunc(args: List[Rep], maybeFuncAndState: Option[(Func, State)]): Option[(Func, State)] = { - println(s"ARGS0: $args") - val hopArgs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) - val transformations = args zip hopArgs - println(s"Transformation: $transformations") - - def bvs(r: Rep): List[BoundVal] = { - def bvs0(r: Rep, acc: List[BoundVal]): List[BoundVal] = r match { - case lb: LetBinding => lb.bound :: acc - case _ => acc - } - - bvs0(r, List.empty) - } - - val after = for { + for { (f, es) <- maybeFuncAndState - _ = println(s"BEFORE $f") + newBodyAndState = transformations.foldLeft(fbody(f) -> es) { case ((body, es), (bv: BoundVal, hopArg)) => val replace = es.ctx(bv) replace rebind hopArg - body -> es + (body, es) case ((body, es), (lb: LetBinding, hopArg)) => - val done = (s: State) => bvs(lb) forall (s.ctx.keySet contains _) + val lbBVs = bvs(lb) + val done = (s: State) => lbBVs forall (s.ctx.keySet contains _) // TODO Keep set of matched BVs in state? + extractWithState(lb, body)(done)(es) map { es => val replace = es.ctx(lb.last.bound) bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } -> es } getOrElse body -> es - case ((body, es), (r, hopArg)) => bottomUpPartial(body) { case `r` => hopArg } -> es + case ((body, es), (r, hopArg)) => (bottomUpPartial(body) { case `r` => hopArg }, es) } - _ = bottomUpPartial(newBodyAndState._1) { case bv: BoundVal if visible contains bv => return None } - } yield ((hopArgs :: fargss(f)) -> newBodyAndState._1) -> newBodyAndState._2 - - println(s"AFTER: $after") - - after + + _ = bottomUpPartial(newBodyAndState._1) { case bv: BoundVal if visible contains bv => return None } // TODO is too early to check? If there are more args left + } yield ((hopArgs :: fargss(f)) -> newBodyAndState._1, newBodyAndState._2) } val oe = for { es1 <- typ extract (xtee.typ, Covariant) m <- merge(es.ex, es1) (f, es2) <- argss.foldRight(Option(emptyFunc(xtee) -> (es withNewExtract m)))(extendFunc) - _ = println(s"F: $f") l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } if !hasUndeclaredBVs(l) - //m2 <- merge(es2.ex, repExtract(name -> l)) } yield es2 updateExtractWith Some(repExtract(name -> l)) oe getOrElse Left(es) @@ -542,30 +529,18 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi newEs map (_ withMatchedImpures r) } - def extractInside(bv: BoundVal, d: Def)(implicit es: State): ExtractState = { - def bvs(d: Def): List[BoundVal] = d match { - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldRight(List.empty[BoundVal]) { - case (bv: BoundVal, acc) => bv :: acc - case (_, acc) => acc - } - case _ => Nil - } - + def extractInside(bv: BoundVal, d: Def)(implicit es: State): ExtractState = bvs(d).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => for { es1 <- acc.left es2 <- extractWithState(bv, bv2)(done)(es1).left } yield es2 } - } def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] println(s"extractWithState: $xtor\n$xtee\n") - if (done(es)) { - println(s"FINISHED") - Right(es) - } + if (done(es)) Right(es) else xtor -> xtee match { case (h: Hole, lb: LetBinding) => contentsOf(h) match { case Some(lb1: LetBinding) if lb1.value == lb.value => Right(es) @@ -583,7 +558,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2)(done) - // Stop at markers? + // TODO Stop at markers? case (lb: LetBinding, _: Rep) => (effect(lb), effect(xtee)) match { case (Pure, Pure) => extractWithState(lb.body, xtee)(done) case (Impure, Pure) => Left(es) @@ -609,10 +584,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (bv1: BoundVal, bv2: BoundVal) => println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") - if (es.ctx.getOrElse(bv1, bv1) == bv2) { - println(s"SEEYA") - Right(es) - } + if (es.ctx.getOrElse(bv1, bv1) == bv2) Right(es) else if (es.failedMatches(bv1) contains bv2) Left(es) else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value)(done) match { @@ -714,25 +686,29 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = { - println(s"Again") + override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = rewriteRep0(xtor, xtee, code)(false)(State.forRewriting(xtor, xtee)) - } def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(internalRec: Boolean)(implicit es: State): Option[Rep] = { def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1.value), es.flags.xtorFlag(lb1.bound)), (effect(lb2.value), es.flags.xteeFlag(lb2.bound))) match { - case ((Pure, Skip), (Pure, Skip)) => Left(es) - case _ => extractWithState(lb1, lb2)(_ => false) // TODO With unreachable handling TODO why did I mean? - } + case (lb1: LetBinding, lb2: LetBinding) if !internalRec => + ((effect(lb1.value), es.flags.xtorFlag(lb1.bound)), (effect(lb2.value), es.flags.xteeFlag(lb2.bound))) match { + case ((Pure, Skip), (Pure, Skip)) => Left(es) + case _ => extractWithState(lb1, lb2)(_ => false) + } case _ => extractWithState(xtor, xtee)(_ => false) } } def genCode(implicit es: State): Option[Rep] = { + + /** + * First sanity check on the extraction. + * Checks if the BVs in the extract are declared orwere defined by the user. + */ def preCheck(ex: Extract): Boolean = { def preCheckRep(declaredBVs: Set[BoundVal], invCtx: Map[BoundVal, Set[BoundVal]], r: Rep): Boolean = { def preCheckDef(declaredBVs: Set[BoundVal], invCtx: Map[BoundVal, Set[BoundVal]], d: Def): Boolean = { @@ -767,17 +743,22 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi (ex._1.values ++ ex._3.values.flatten).forall(preCheckRep(Set.empty, invCtx, _)) } + /** + * Final check after rewriting the program. + * Checks if all the BVs are declared and that the removed + * let-binding are not referenced anymore in the code. + */ def check(declaredBVs: Set[BoundVal], matchedImpureBVs: Set[BoundVal])(r: Rep): Boolean = { - def checkDef(declaredBVs: Set[BoundVal], matchedBVs: Set[BoundVal])(d: Def): Boolean = d match { + def checkDef(declaredBVs: Set[BoundVal], matchedImpureBVs: Set[BoundVal])(d: Def): Boolean = d match { case ma: MethodApp => (ma.self :: ma.argss.argssList) forall { - case bv: BoundVal => (declaredBVs contains bv) || !(matchedBVs contains bv) - case lb: LetBinding => check(declaredBVs + lb.bound, matchedBVs)(lb) + case bv: BoundVal => (declaredBVs contains bv) || !(matchedImpureBVs contains bv) + case lb: LetBinding => check(declaredBVs + lb.bound, matchedImpureBVs)(lb) case _ => true } case l: Lambda => ((declaredBVs contains l.bound) || - !(matchedBVs contains l.bound)) && - check(declaredBVs, matchedBVs)(l.body) + !(matchedImpureBVs contains l.bound)) && + check(declaredBVs, matchedImpureBVs)(l.body) case _ => true } @@ -810,7 +791,23 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi lb case _ => r } - + + def bvs(r: Rep): List[BoundVal] = { + def bvs0(r: Rep, acc: List[BoundVal]): List[BoundVal] = r match { + case lb: LetBinding => lb.bound :: acc + case _ => acc + } + + bvs0(r, List.empty) + } + def bvs(d: Def): List[BoundVal] = d match { + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldRight(List.empty[BoundVal]) { + case (bv: BoundVal, acc) => bv :: acc + case (_, acc) => acc + } + case _ => Nil + } + // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * def hole(name: String, typ: TypeRep) = Hole(name, typ) diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index d0e740b4..eb0e6243 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -2,7 +2,7 @@ package squid package ir package fastir -class RewritingTests extends MyFunSuiteBase(BasicTests.Embedding) { +class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { import RewritingTests.Embedding.Predef._ test("Simple rewrites") { From a6f6170e7608ebf8fc68673d8d7f8078da61ae94 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 28 Nov 2017 15:45:43 +0100 Subject: [PATCH 23/66] Fix issue where impure statements could be matched wrongly --- src/main/scala/squid/ir/fastanf/FastANF.scala | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 29388547..22331e93 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -584,18 +584,25 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (bv1: BoundVal, bv2: BoundVal) => println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") - if (es.ctx.getOrElse(bv1, bv1) == bv2) Right(es) - else if (es.failedMatches(bv1) contains bv2) Left(es) - else (bv1.owner, bv2.owner) match { - case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value)(done) match { - case Right(es) => effect(lb2.value) match { - case Pure => Right(es withCtx lb1.bound -> lb2.bound withoutFlags(bv1, bv2)) - case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound withoutFlags(bv1, bv2)) + println(s"STATE: $es") + + es.ctx.get(bv1) map { bv => + if (bv != bv2) Left(es) + else Right(es) + } getOrElse { + if (bv1 == bv2) Right(es) + else if (es.failedMatches(bv1) contains bv2) Left(es) + else (bv1.owner, bv2.owner) match { + case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value)(done) match { + case Right(es) => effect(lb2.value) match { + case Pure => Right(es withCtx lb1.bound -> lb2.bound withoutFlags(bv1, bv2)) + case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound withoutFlags(bv1, bv2)) + } + case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } - case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) + case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2)(done) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? + case _ => Left(es withFailed bv1 -> bv2) } - case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2)(done) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? - case _ => Left(es withFailed bv1 -> bv2) } case (Constant(v1), Constant(v2)) if v1 == v2 => es updateExtractWith (xtor.typ extract(xtee.typ, Covariant)) From 9c8523113f398f083930d65284a49eed52e1df74 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 28 Nov 2017 16:06:46 +0100 Subject: [PATCH 24/66] Some more extraction tests --- .../squid/ir/fastir/ExtractingTests.scala | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/test/scala/squid/ir/fastir/ExtractingTests.scala b/src/test/scala/squid/ir/fastir/ExtractingTests.scala index 6ac10c70..ec55a27f 100644 --- a/src/test/scala/squid/ir/fastir/ExtractingTests.scala +++ b/src/test/scala/squid/ir/fastir/ExtractingTests.scala @@ -46,6 +46,43 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { assert(h2 =~= ir"1") assert(h3 =~= ir"1337") } + + ir"val a = 10.toDouble; val b = 20.toDouble; val c = 30.toDouble; a + b" match { + case ir"val aX = ($a: Int).toDouble; val bX= 20.toDouble; val cX = ($c: Int).toDouble; aX + bX" => + assert(a =~= ir"10") + assert(c =~= ir"30") + } + + ir"val a = 10.toDouble; val b = 20.toDouble; val c = 30.toDouble; a + b" match { + case ir"val bX= 20.toDouble; val aX = ($a: Int).toDouble; val cX = ($c: Int).toDouble; aX + bX" => + assert(a =~= ir"10") + assert(c =~= ir"30") + } + + // TODO better dead-end handling, should try to match dead-ends only with dead-ends? + //ir"val a = 10.toDouble; val b = 20.toDouble; val c = 30.toDouble; a + b" match { + // case ir"val cX = ($c: Int).toDouble; val bX= 20.toDouble; val aX = ($a: Int).toDouble; aX + bX" => + // assert(a =~= ir"10") + // assert(c =~= ir"30") + //} + } + + test("Extracting impure statements") { + ir"val a = readInt; val b = readInt; a + b" match { + case ir"val aX = readInt; val bX = readInt; aX + bX" => + } + + assert(Try { + ir"val a = readInt; val b = readInt; a + b" match { + case ir"val aX = readInt; val bX = readInt; bX + aX" => fail + } + }.isFailure) + + // TODO Holes in impure position and first position get removed :s + //ir"val a = readInt; val b = readInt; b" match { + // //case ir"val t = (${ir"($h: Int)"}: Int); val bX = readInt; bX" => assert(h =~= ir"readInt") + // //case ir"${ir"($h: Int)"}; val bX = readInt; bX" => assert(h =~= ir"readInt") + //} } } From 4497d2680be2c9b60d37a807850887172eab7359 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 28 Nov 2017 18:15:18 +0100 Subject: [PATCH 25/66] Fix issue where a hole would disappear when not referenced --- src/main/scala/squid/ir/fastanf/FastANF.scala | 2 +- src/test/scala/squid/ir/fastir/ExtractingTests.scala | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 22331e93..f1e09fc7 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -164,7 +164,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // letin(x, Hole, Constant(20)) => `val tmp = defHole; 20;` - val dh = wrapConstruct(new LetBinding(bound.name, bound, DefHole(h), bound) alsoApply bound.rebind) // flag wrapConstruct? + val dh = DefHole(h) |> letbind //(dh |>? { // case bv: BoundVal => bv.owner |>? { diff --git a/src/test/scala/squid/ir/fastir/ExtractingTests.scala b/src/test/scala/squid/ir/fastir/ExtractingTests.scala index ec55a27f..b396dcf5 100644 --- a/src/test/scala/squid/ir/fastir/ExtractingTests.scala +++ b/src/test/scala/squid/ir/fastir/ExtractingTests.scala @@ -78,11 +78,9 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { } }.isFailure) - // TODO Holes in impure position and first position get removed :s - //ir"val a = readInt; val b = readInt; b" match { - // //case ir"val t = (${ir"($h: Int)"}: Int); val bX = readInt; bX" => assert(h =~= ir"readInt") - // //case ir"${ir"($h: Int)"}; val bX = readInt; bX" => assert(h =~= ir"readInt") - //} + ir"val a = readInt; val b = readInt; b" match { + case ir"val t = ($h: Int); val bX = readInt; bX" => assert(h =~= ir"readInt") + } } } From 3431b25e4a2fc0762b27380932a6a97a00d76444 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 28 Nov 2017 19:22:32 +0100 Subject: [PATCH 26/66] Fix issue where rewriting would remove the rest of xtee --- src/main/scala/squid/ir/fastanf/Effects.scala | 1 + src/main/scala/squid/ir/fastanf/FastANF.scala | 20 ++++++++++++++++++- .../squid/ir/fastir/RewritingTests.scala | 7 +++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index 04a70b79..cd6a0b17 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -68,6 +68,7 @@ trait StandardEffects extends Effects { addPureMtd(MethodSymbol(TypeSymbol("scala.Double"),"$plus")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"),"$times")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"), "toDouble")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Double"), "toInt")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"), "toFloat")) addPureMtd(MethodSymbol(TypeSymbol("scala.Option$"), "apply")) addPureMtd(MethodSymbol(TypeSymbol("scala.Option"), "get")) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index f1e09fc7..ade91506 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -776,10 +776,28 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } + def appendRestOfXtee(code: Rep, xtor: Rep, ctx: Ctx): Rep = xtor match { + case lb: LetBinding => + val xtorLast = lb.last + val lastXteeMatched = ctx(xtorLast.bound) + lastXteeMatched.owner |>? { + case innerLB: LetBinding => code |>? { + case codeLB: LetBinding => + val codeLast = codeLB.last + codeLast.body = innerLB.body + bottomUpPartial(code) { case `lastXteeMatched` => codeLast.bound } + } + } + code + + case _ => code + } + if (preCheck(es.ex)) for { code <- code(es.ex) - if check(Set.empty, es.matchedImpureBVs)(filterLBs(code)(es.matchedImpureBVs contains _.bound)) + codeWithRest = appendRestOfXtee(code, xtor, es.ctx) + if check(Set.empty, es.matchedImpureBVs)(filterLBs(codeWithRest)(es.matchedImpureBVs contains _.bound)) } yield code else None } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index eb0e6243..1160cfd7 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -55,6 +55,13 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(b =~= ir"val t = ${ir"Option(42)"}; 42") } + test("Rewriting sequences of bindings") { + val a = ir"val a = readInt; val b = readDouble.toInt; a + b" rewrite { + case ir"readDouble.toInt" => ir"readInt" + } + assert(a =~= ir"val a = readInt; val b = readInt; a + b") + } + //test("Rewriting with impures") { // val a = ir"val a = readInt; val b = readInt; (a + b) * 0.5" rewrite { // case ir"(($h1: Int) + ($h2: Int)) * 0.5" => dbg_ir"($h1 * $h2) + 42.0" From 0ff74a5d4b709bc859fca3e94564096f8228b021 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 28 Nov 2017 19:39:49 +0100 Subject: [PATCH 27/66] Holes now extract the entire let-binding --- src/main/scala/squid/ir/fastanf/FastANF.scala | 2 +- src/test/scala/squid/ir/fastir/ExtractingTests.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index ade91506..e2cdfe82 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -516,7 +516,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (Hole(n, t), lb: LetBinding) => es.updateExtractWith( t extract(lb.typ, Covariant), - Some(repExtract(n -> wrapConstruct(letbind(lb.value)))) + Some(repExtract(n -> lb)) ).map(updateFlags(lb, _)) case (Hole(n, t), _) => diff --git a/src/test/scala/squid/ir/fastir/ExtractingTests.scala b/src/test/scala/squid/ir/fastir/ExtractingTests.scala index b396dcf5..04d5edab 100644 --- a/src/test/scala/squid/ir/fastir/ExtractingTests.scala +++ b/src/test/scala/squid/ir/fastir/ExtractingTests.scala @@ -25,6 +25,10 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { assert(l =~= ir"42") assert(r =~= ir"1337") } + + ir"val a = 10.toDouble; val b = 42.toDouble; a + b" match { + case ir"($body: Double)" => assert(body =~= ir"val a = 10.toDouble; val b = 42.toDouble; a + b") + } } test("Matching with impure statements") { From 7f2e64b790e6ca7868936154eaaf5320c896ce5c Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 28 Nov 2017 20:35:51 +0100 Subject: [PATCH 28/66] Remove redundant no flags left check --- src/main/scala/squid/ir/fastanf/FastANF.scala | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index e2cdfe82..38bd135d 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -319,10 +319,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract(\n$xtor, \n$xtee)") - for { - es <- extractWithState(xtor, xtee)(_ => false)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) - if es.flags.xtor.isEmpty && es.flags.xtee.isEmpty - } yield es.ex + extractWithState(xtor, xtee)(_ => false)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) map (_.ex) } type Ctx = Map[BoundVal, BoundVal] @@ -338,9 +335,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) - def updateFlags(newFlags: Flags): State = copy(flags = newFlags) - def withoutFlags(xtorFlag: BoundVal, xteeFlag: BoundVal): State = copy(flags = Flags(flags.xtor - xtorFlag, flags.xtee - xteeFlag)) - def withoutXteeFlag(flag: BoundVal): State = copy(flags = flags.copy(xtee = flags.xtee - flag)) def withMatchedImpures(r: Rep): State = r match { case lb: LetBinding if !isPure(lb.value) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body case lb: LetBinding => this withMatchedImpures lb.body @@ -440,7 +434,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val transformations = args zip hopArgs for { - (f, es) <- maybeFuncAndState + (f, es) <- maybeFuncAndState alsoApply println newBodyAndState = transformations.foldLeft(fbody(f) -> es) { case ((body, es), (bv: BoundVal, hopArg)) => @@ -501,23 +495,18 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { println(s"ExtractHole: $h --> $r") - def updateFlags(r: Rep, es: State): State = r match { - case lb: LetBinding => updateFlags(lb.body, es withoutXteeFlag lb.bound) - case _ => es - } - val newEs = (h, r) match { case (Hole(n, t), bv: BoundVal) => es.updateExtractWith( t extract(xtee.typ, Covariant), Some(repExtract(n -> bv)) - ) map (_ withoutXteeFlag bv) + ) case (Hole(n, t), lb: LetBinding) => es.updateExtractWith( t extract(lb.typ, Covariant), Some(repExtract(n -> lb)) - ).map(updateFlags(lb, _)) + ) case (Hole(n, t), _) => es.updateExtractWith( @@ -595,8 +584,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value)(done) match { case Right(es) => effect(lb2.value) match { - case Pure => Right(es withCtx lb1.bound -> lb2.bound withoutFlags(bv1, bv2)) - case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound withoutFlags(bv1, bv2)) + case Pure => Right(es withCtx lb1.bound -> lb2.bound) + case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) } case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } From 3aef763ff2d349bece02fbecf4604c7590a33a58 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 30 Nov 2017 12:50:55 +0100 Subject: [PATCH 29/66] Check flags when looking for the return value --- src/main/scala/squid/ir/fastanf/FastANF.scala | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 38bd135d..97254c02 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -547,11 +547,16 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2)(done) - // TODO Stop at markers? - case (lb: LetBinding, _: Rep) => (effect(lb), effect(xtee)) match { - case (Pure, Pure) => extractWithState(lb.body, xtee)(done) - case (Impure, Pure) => Left(es) - case (_, Impure) => Left(es) // Assuming the return value cannot be impure + //// TODO Stop at markers? + //case (lb: LetBinding, _: Rep) => (effect(lb), effect(xtee)) match { + // case (Pure, Pure) => extractWithState(lb.body, xtee)(done) + // case (Impure, Pure) => Left(es) + // case (_, Impure) => Left(es) // Assuming the return value cannot be impure + //} + + case (lb: LetBinding, _: Rep) => es.flags.xtorFlag(lb.bound) match { + case Start => Left(es) + case Skip => extractWithState(lb.body, xtee)(done) } case (bv: BoundVal, lb: LetBinding) => for { From 0c20ec6ef2b69888b11c08caac833b9a1160e5fb Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 30 Nov 2017 13:45:23 +0100 Subject: [PATCH 30/66] Fix issue where substitute gets called outside a reification context --- .../src/main/scala/squid/lang/Optimizer.scala | 2 +- .../main/scala/squid/quasi/QuasiBase.scala | 4 +++ src/main/scala/squid/ir/fastanf/FastANF.scala | 29 ++++++++++++++----- .../squid/ir/fastir/RewritingTests.scala | 7 +++++ 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/squid/lang/Optimizer.scala b/core/src/main/scala/squid/lang/Optimizer.scala index f9ded26f..913e7ba7 100644 --- a/core/src/main/scala/squid/lang/Optimizer.scala +++ b/core/src/main/scala/squid/lang/Optimizer.scala @@ -7,7 +7,7 @@ trait Optimizer { //final def optimizeRep(pgrm: Rep): Rep = pipeline(pgrm) final def optimizeRep(pgrm: Rep): Rep = { // TODO do this Transformer's `pipeline`......? val r = pipeline(pgrm) - if (!(r eq pgrm)) substitute(r) else r // only calls substitute if a transformation actually happened + if (!(r eq pgrm)) insertAfterTransformation(r) else r // only calls insertAfterTransformation if a transformation actually happened } final def optimize[T,C](pgrm: IR[T,C]): IR[T,C] = `internal IR`[T,C](optimizeRep(pgrm.rep)) diff --git a/core/src/main/scala/squid/quasi/QuasiBase.scala b/core/src/main/scala/squid/quasi/QuasiBase.scala index d14af403..ffac516d 100644 --- a/core/src/main/scala/squid/quasi/QuasiBase.scala +++ b/core/src/main/scala/squid/quasi/QuasiBase.scala @@ -33,6 +33,8 @@ self: Base => def substitute(r: => Rep, defs: Map[String, Rep]): Rep def substituteLazy(r: => Rep, defs: Map[String, () => Rep]): Rep = substitute(r, defs map (kv => kv._1 -> kv._2())) + def insertAfterTransformation(r: => Rep, defs: Map[String, Rep]): Rep = substitute(r, defs) + /** Not yet used: should eventually be defined by all IRs as something different than hole */ def freeVar(name: String, typ: TypeRep) = hole(name, typ) @@ -70,6 +72,8 @@ self: Base => /* if (defs isEmpty) r else */ // <- This "optimization" is not welcome, as some IRs (ANF) may relie on `substitute` being called for all insertions substitute(r, defs.toMap) + final def insertAfterTransformation(r: => Rep, defs: (String, Rep)*): Rep = insertAfterTransformation(r, defs.toMap) + protected def mkIR[T,C](r: Rep): IR[T,C] = new IR[T,C] { val rep = r override def equals(that: Any): Boolean = that match { diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 97254c02..b3281bcb 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -837,13 +837,28 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi HOPHole2(name, typ, args, visible filterNot (args.flatten contains _)) def substitute(r: => Rep, defs: Map[String, Rep]): Rep = { println(s"Subs: $r with $defs") - if (defs isEmpty) r |> inlineBlock // TODO works if I remove this... - else bottomUp(r) { - case h@Hole(n, _) => defs getOrElse(n, h) - case h@SplicedHole(n, _) => defs getOrElse(n, h) - //case h: BoundVal => defs getOrElse(h.name, h) // TODO FVs in lambda become BVs too early, this should be changed!! - case h => h - } |> inlineBlock + val r0 = + if (defs isEmpty) r + else bottomUp(r) { + case h@Hole(n, _) => defs getOrElse(n, h) + case h@SplicedHole(n, _) => defs getOrElse(n, h) + //case h: BoundVal => defs getOrElse(h.name, h) // TODO FVs in lambda become BVs too early, this should be changed!! + case h => h + } + + r0 |> inlineBlock + } + override def insertAfterTransformation(r: => Rep, defs: Map[String, Rep]): Rep = { + // TODO for now we do nothing to r. Later make sure that after applying the defs it is still valid in ANF! + require(defs.isEmpty) + r + + //if (defs isEmpty) r + //else bottomUp(r) { + // case h@Hole(n, _) => defs getOrElse(n, h) + // case h@SplicedHole(n, _) => defs getOrElse(n, h) + // case h => h + //} } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 1160cfd7..a08130a0 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -55,6 +55,13 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(b =~= ir"val t = ${ir"Option(42)"}; 42") } + test("Substitution should be called from inside a reification context") { + val a = ir"readDouble" rewrite { + case ir"readDouble" => ir"42.toDouble" + } + assert(a =~= ir"42.toDouble") + } + test("Rewriting sequences of bindings") { val a = ir"val a = readInt; val b = readDouble.toInt; a + b" rewrite { case ir"readDouble.toInt" => ir"readInt" From 512732969a143adbc0bee8fa1c6535547d7a4a89 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 30 Nov 2017 14:07:04 +0100 Subject: [PATCH 31/66] Fix issue where the generated code is wrong when there's a hole at the return position --- src/main/scala/squid/ir/fastanf/FastANF.scala | 37 ++++++++++--------- .../squid/ir/fastir/RewritingTests.scala | 7 ++++ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index b3281bcb..fc7b2aaf 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -772,32 +772,35 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def appendRestOfXtee(code: Rep, xtor: Rep, ctx: Ctx): Rep = xtor match { case lb: LetBinding => - val xtorLast = lb.last - val lastXteeMatched = ctx(xtorLast.bound) - lastXteeMatched.owner |>? { - case innerLB: LetBinding => code |>? { - case codeLB: LetBinding => - val codeLast = codeLB.last - codeLast.body = innerLB.body - bottomUpPartial(code) { case `lastXteeMatched` => codeLast.bound } - } + val xtorLast = lb.last + xtorLast.body match { + case _: Hole | _: HOPHole2 => code + case _ => + val lastXteeMatched = ctx(xtorLast.bound) + lastXteeMatched.owner |>? { + case innerLB: LetBinding => code |>? { + case codeLB: LetBinding => + val codeLast = codeLB.last + codeLast.body = innerLB.body + bottomUpPartial(code) { case `lastXteeMatched` => codeLast.bound } + } + } + code } - code case _ => code } - if (preCheck(es.ex)) - for { - code <- code(es.ex) - codeWithRest = appendRestOfXtee(code, xtor, es.ctx) - if check(Set.empty, es.matchedImpureBVs)(filterLBs(codeWithRest)(es.matchedImpureBVs contains _.bound)) - } yield code + if (preCheck(es.ex)) for { + code <- code(es.ex) + codeWithRest = appendRestOfXtee(code, xtor, es.ctx) + if check(Set.empty, es.matchedImpureBVs)(filterLBs(codeWithRest)(es.matchedImpureBVs contains _.bound)) + } yield code else None } rewriteRepWithState(xtor, xtee) match { - case Right(es) => genCode(es) alsoApply println + case Right(es) => genCode(es) alsoApply(c => println(s"GEN: $c")) case Left(_) => None } } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index a08130a0..0fdfaff3 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -62,6 +62,13 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(a =~= ir"42.toDouble") } + test("Code generation should handle a hole in return position") { + val a = ir"val a = 10.toDouble; val b = 22.toDouble; a + b" rewrite { + case ir"val a = 10.toDouble; ($body:Double)" => ir"readDouble" + } + assert(a =~= ir"readDouble") + } + test("Rewriting sequences of bindings") { val a = ir"val a = readInt; val b = readDouble.toInt; a + b" rewrite { case ir"readDouble.toInt" => ir"readInt" From 13ae0a01df91974e9df4259aaa63b1989d673946 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 30 Nov 2017 17:09:49 +0100 Subject: [PATCH 32/66] Remove extremely slow intermediate step in the HOPHole extraction --- src/main/scala/squid/ir/fastanf/FastANF.scala | 71 +++++++++++-------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index fc7b2aaf..ce9bbcf4 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -405,10 +405,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { println("EXTRACTINGHOPHOLE") - type Func = List[List[BoundVal]] -> Rep - def emptyFunc(r: Rep) = List.empty[List[BoundVal]] -> r - def fargss(f: Func) = f._1 - def fbody(f: Func) = f._2 def hasUndeclaredBVs(r: Rep): Boolean = { println(s"Checking $r") @@ -429,42 +425,61 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi hasUndeclaredBVs0(r, Set.empty) } - def extendFunc(args: List[Rep], maybeFuncAndState: Option[(Func, State)]): Option[(Func, State)] = { + def extendFunc(args: List[Rep], maybeCurrFuncAndState: Option[(Rep, State)]): Option[(Rep, State)] = { val hopArgs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) val transformations = args zip hopArgs + def body(r: Rep): Rep = { + def body0(d: Def): Option[Rep] = d match { + case l: Lambda => l.body match { + case lb: LetBinding => Some(body(lb)) + case body => Some(body) + } + case _ => None + } + + + r match { + case lb: LetBinding => body0(lb.value) getOrElse lb + case _ => r + } + } + for { - (f, es) <- maybeFuncAndState alsoApply println - - newBodyAndState = transformations.foldLeft(fbody(f) -> es) { - case ((body, es), (bv: BoundVal, hopArg)) => - val replace = es.ctx(bv) - replace rebind hopArg - (body, es) - - case ((body, es), (lb: LetBinding, hopArg)) => - val lbBVs = bvs(lb) - val done = (s: State) => lbBVs forall (s.ctx.keySet contains _) // TODO Keep set of matched BVs in state? - - extractWithState(lb, body)(done)(es) map { es => - val replace = es.ctx(lb.last.bound) - bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } -> es - } getOrElse body -> es - - case ((body, es), (r, hopArg)) => (bottomUpPartial(body) { case `r` => hopArg }, es) + (f, es) <- maybeCurrFuncAndState alsoApply println + + newBodyAndState = transformations.foldLeft(body(f) -> es) { + case ((body, es), (arg, hopArg)) => arg match { + case bv: BoundVal => + val replace = es.ctx(bv) + replace rebind hopArg + body -> es + + case lb: LetBinding => + val lbBVs = bvs(lb) + val done: State => Boolean = s => lbBVs forall (s.ctx.keySet contains _) // TODO Keep set of matched BVs in state? + + extractWithState(lb, body)(done)(es) map { es => + val replace = es.ctx(lb.last.bound) + bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } -> es + } getOrElse body -> es + + case _ => bottomUpPartial(body) { case `arg` => hopArg } -> es + } } _ = bottomUpPartial(newBodyAndState._1) { case bv: BoundVal if visible contains bv => return None } // TODO is too early to check? If there are more args left - } yield ((hopArgs :: fargss(f)) -> newBodyAndState._1, newBodyAndState._2) + } yield newBodyAndState match { + case (func0, es0) => wrapConstruct(lambda(hopArgs, func0)) -> es0 + } } val oe = for { es1 <- typ extract (xtee.typ, Covariant) m <- merge(es.ex, es1) - (f, es2) <- argss.foldRight(Option(emptyFunc(xtee) -> (es withNewExtract m)))(extendFunc) - l = fargss(f).foldRight(fbody(f)) { case (args, body) => wrapConstruct(lambda(args, body)) } - if !hasUndeclaredBVs(l) - } yield es2 updateExtractWith Some(repExtract(name -> l)) + (f, es2) <- argss.foldRight(Option(xtee -> (es withNewExtract m)))(extendFunc) + if !hasUndeclaredBVs(f) + } yield es2 updateExtractWith Some(repExtract(name -> f)) oe getOrElse Left(es) } From 44145b9dd8dad0284ad02ac9776397b69d6e6376 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 30 Nov 2017 17:35:16 +0100 Subject: [PATCH 33/66] Remove remnants of the unreachable mechanism --- src/main/scala/squid/ir/fastanf/FastANF.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index ce9bbcf4..4fbdd6df 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -319,7 +319,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract(\n$xtor, \n$xtee)") - extractWithState(xtor, xtee)(_ => false)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) map (_.ex) + extractWithState(xtor, xtee)(_ => false)(State.init(xtor, xtee)).fold(_ => None, Some(_)) map (_.ex) } type Ctx = Map[BoundVal, BoundVal] @@ -331,7 +331,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi type ExtractState = Either[State, State] implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right - case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], makeUnreachable: Boolean) { + case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]]) { def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) @@ -347,11 +347,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } object State { - def forRewriting(xtor: Rep, xtee: Rep): State = State(xtor, xtee, true) - def forExtraction(xtor: Rep, xtee: Rep): State = State(xtor, xtee, false) + def init(xtor: Rep, xtee: Rep): State = apply(xtor, xtee) - private def apply(xtor: Rep, xtee: Rep, makeUnreachable: Bool): State = - State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), makeUnreachable) + private def apply(xtor: Rep, xtee: Rep): State = + State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty)) } sealed trait Flag @@ -703,7 +702,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = - rewriteRep0(xtor, xtee, code)(false)(State.forRewriting(xtor, xtee)) + rewriteRep0(xtor, xtee, code)(false)(State.init(xtor, xtee)) def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(internalRec: Boolean)(implicit es: State): Option[Rep] = { def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { From 0764fa9f67b5822d352aa96c7225d4ad6e9a7ebd Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Fri, 1 Dec 2017 15:59:57 +0100 Subject: [PATCH 34/66] Test rewriting with HOPHoles, lambdas, and with mutliple occurences --- .../squid/ir/fastir/RewritingTests.scala | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 0fdfaff3..78c8d800 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -76,6 +76,37 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(a =~= ir"val a = readInt; val b = readInt; a + b") } + test("Rewriting with HOPHoles") { + val a = ir"val a = 11.toDouble; val b = 22.toDouble; val c = 33.toDouble; (a,b,c)" rewrite { + case ir"val a = ($x:Int).toDouble; val b = ($y:Int).toDouble; ($body(a, b): Tuple3[Double, Double, Double])" => + ir"val a = ($x+$y).toDouble/2; $body(a, a)" + } + val f = ir"(x: Double, y: Double) => (x, y, 33.toDouble)" + assert(a =~= ir"val a = (11+${ir"22"}).toDouble/2; $f(a, a)") + } + + test("Rewriting lambdas") { + val l = ir"(x: Double) => x * 2" + + val a = ir"val a = 11.toDouble; val f = $l; f(a)" rewrite { + case ir"(x: Double) => x * 2" => ir"(x: Double) => x * 4" + } + val l0 = ir"(x: Double) => x * 4" + assert(a =~= ir"val a = 11.toDouble; val f = $l0; f(a)") + } + + test("Rewriting should happen at all occurences") { + val a = ir"val a = Option(11).get; val b = 22; val c = Option(33).get; a + b + c" rewrite { + case ir"Option(($n: Int)).get" => n + } + assert(a =~= ir"val a = Option(11).get; val b = 22; val c = Option(33).get; ${ir"11"} + 22 + 33") + + val b = ir"val a = Option(Option(11).get).get; val b = 22; a + b" rewrite { + case ir"Option(($n: Int)).get" => n + } + assert(b =~= ir"val t1 = Option(11).get; val t2 = Option(11).get; ${ir"11"} + 22") + } + //test("Rewriting with impures") { // val a = ir"val a = readInt; val b = readInt; (a + b) * 0.5" rewrite { // case ir"(($h1: Int) + ($h2: Int)) * 0.5" => dbg_ir"($h1 * $h2) + 42.0" From 59eb6588bd12b4ad2577124348623d1d72c15561 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Fri, 1 Dec 2017 22:51:17 +0100 Subject: [PATCH 35/66] Remove pure statements that depend on removed statements --- src/main/scala/squid/ir/fastanf/FastANF.scala | 82 +++++++++++++------ .../squid/ir/fastir/RewritingTests.scala | 11 +++ 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 4fbdd6df..2828e709 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -338,6 +338,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withMatchedImpures(r: Rep): State = r match { case lb: LetBinding if !isPure(lb.value) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body case lb: LetBinding => this withMatchedImpures lb.body + case bv: BoundVal => copy(matchedImpureBVs = matchedImpureBVs + bv) case _ => this // Everything else is pure so we ignore it } def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) @@ -509,7 +510,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { println(s"ExtractHole: $h --> $r") - val newEs = (h, r) match { + (h, r) match { case (Hole(n, t), bv: BoundVal) => es.updateExtractWith( t extract(xtee.typ, Covariant), @@ -520,7 +521,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es.updateExtractWith( t extract(lb.typ, Covariant), Some(repExtract(n -> lb)) - ) + ) map (_ withMatchedImpures lb) case (Hole(n, t), _) => es.updateExtractWith( @@ -528,8 +529,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi Some(repExtract(n -> xtee)) ) } - - newEs map (_ withMatchedImpures r) } def extractInside(bv: BoundVal, d: Def)(implicit es: State): ExtractState = @@ -784,32 +783,69 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - def appendRestOfXtee(code: Rep, xtor: Rep, ctx: Ctx): Rep = xtor match { - case lb: LetBinding => - val xtorLast = lb.last - xtorLast.body match { - case _: Hole | _: HOPHole2 => code - case _ => - val lastXteeMatched = ctx(xtorLast.bound) - lastXteeMatched.owner |>? { - case innerLB: LetBinding => code |>? { - case codeLB: LetBinding => - val codeLast = codeLB.last - codeLast.body = innerLB.body - bottomUpPartial(code) { case `lastXteeMatched` => codeLast.bound } - } + def cleanup(r: Rep, remove: Set[BoundVal]): Rep = r match { + case lb: LetBinding if remove contains lb.bound => cleanup(lb.body, remove) + case lb: LetBinding if isPure(lb.value) => + val bvsInValue = bvs(lb.value) + + if (bvsInValue exists (remove contains)) { + cleanup(lb.body, remove ++ bvsInValue.toSet) + } else { + lb.body = cleanup(lb.body, remove) + lb + } + case lb: LetBinding => + lb.body = cleanup(lb.body, remove) + lb + case _ => r + } + + def finalize(code: Rep, xtor: Rep, filteredXtee: Rep)(ctx: Ctx): Rep = { + println(s"FOO: $code, \n$xtor, \n$filteredXtee") + + xtor match { + case xtorLB: LetBinding => + val xtorLast = xtorLB.last + xtorLast.body match { + case xtorRet: BoundVal => code match { + case codeLB: LetBinding => + val codeLast = codeLB.last + codeLast.body |>? { + case codeRet: BoundVal => + val bv = ctx(xtorRet) + codeLast.body = bottomUpPartial(filteredXtee) { case `bv` => codeRet } + } + code + + case _ => + val bv = ctx(xtorRet) + bottomUpPartial(filteredXtee) { case `bv` => code } + } + + // Hole? + case _ => code + } + + case _ => code match { + case codeLB: LetBinding => + val codeLast = codeLB.last + codeLast.body |>? { + case codeRet: BoundVal => + codeLast.body = bottomUpPartial(filteredXtee) { case `xtor` => codeRet } } code + + case _ => + bottomUpPartial(filteredXtee) { case `xtor` => code } } - - case _ => code + } } if (preCheck(es.ex)) for { code <- code(es.ex) - codeWithRest = appendRestOfXtee(code, xtor, es.ctx) - if check(Set.empty, es.matchedImpureBVs)(filterLBs(codeWithRest)(es.matchedImpureBVs contains _.bound)) - } yield code + code0 = finalize(code, xtor, cleanup(xtee, es.matchedImpureBVs))(es.ctx) + if check(Set.empty, es.matchedImpureBVs)(code0) + } yield code0 else None } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 78c8d800..d47a13ae 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -107,6 +107,17 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(b =~= ir"val t1 = Option(11).get; val t2 = Option(11).get; ${ir"11"} + 22") } + test("Rewriting should remove the dependent pure statements") { + val a = ir"val a = readInt; val b = readDouble.toInt; a + b" rewrite { + case ir"readDouble.toInt" => ir"readInt" + } + assert(a =~= ir"val a = readInt; val b = readInt; a + b") + } + + test("Rewriting should not remove statements that weren't matched") { + + } + //test("Rewriting with impures") { // val a = ir"val a = readInt; val b = readInt; (a + b) * 0.5" rewrite { // case ir"(($h1: Int) + ($h2: Int)) * 0.5" => dbg_ir"($h1 * $h2) + 42.0" From 8c34177c8c5873c899f20ce9b5c0a3a455971239 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Fri, 1 Dec 2017 23:10:54 +0100 Subject: [PATCH 36/66] Fix issue where rewriting would remove statements that are part of the result --- src/main/scala/squid/ir/fastanf/FastANF.scala | 4 ++-- src/test/scala/squid/ir/fastir/RewritingTests.scala | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 2828e709..3c0c2436 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -843,8 +843,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if (preCheck(es.ex)) for { code <- code(es.ex) - code0 = finalize(code, xtor, cleanup(xtee, es.matchedImpureBVs))(es.ctx) - if check(Set.empty, es.matchedImpureBVs)(code0) + code0 = finalize(code, xtor, xtee)(es.ctx) + if check(Set.empty, es.matchedImpureBVs)(cleanup(code0, es.matchedImpureBVs)) } yield code0 else None } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index d47a13ae..36bb7c80 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -114,8 +114,11 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(a =~= ir"val a = readInt; val b = readInt; a + b") } - test("Rewriting should not remove statements that weren't matched") { - + test("Rewriting should not remove statements that are part of the result") { + val a = ir"val a = readInt; val b = readDouble; a + b" rewrite { + case ir"readInt" => ir"readDouble.toInt" + } + assert(a =~= ir"val a = readDouble.toInt; val b = readDouble; a + b") } //test("Rewriting with impures") { From 21facd92c16fce01470b8f9313738bc317527109 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 3 Dec 2017 19:42:42 +0100 Subject: [PATCH 37/66] Fix some rewriting and extraction issues --- src/main/scala/squid/ir/fastanf/Effects.scala | 3 + src/main/scala/squid/ir/fastanf/FastANF.scala | 186 ++++++++++-------- .../squid/ir/fastir/ExtractingTests.scala | 10 +- .../fastir/HigherOrderPatternVariables.scala | 16 +- .../squid/ir/fastir/RewritingTests.scala | 33 +++- 5 files changed, 160 insertions(+), 88 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index cd6a0b17..f70bc24a 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -67,11 +67,14 @@ trait StandardEffects extends Effects { addPureMtd(MethodSymbol(TypeSymbol("scala.Int"),"$plus")) addPureMtd(MethodSymbol(TypeSymbol("scala.Double"),"$plus")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"),"$times")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Double"),"$times")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Double"),"div")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"), "toDouble")) addPureMtd(MethodSymbol(TypeSymbol("scala.Double"), "toInt")) addPureMtd(MethodSymbol(TypeSymbol("scala.Int"), "toFloat")) addPureMtd(MethodSymbol(TypeSymbol("scala.Option$"), "apply")) addPureMtd(MethodSymbol(TypeSymbol("scala.Option"), "get")) addPureMtd(MethodSymbol(TypeSymbol("scala.Tuple2$"), "apply")) + addPureMtd(MethodSymbol(TypeSymbol("scala.Tuple3$"), "apply")) addPureMtd(MethodSymbol(TypeSymbol("squid.lib.package$"),"uncurried2")) } diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 3c0c2436..7da7fa2c 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -319,7 +319,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract(\n$xtor, \n$xtee)") - extractWithState(xtor, xtee)(_ => false)(State.init(xtor, xtee)).fold(_ => None, Some(_)) map (_.ex) + extractWithState(xtor, xtee)(_ => false)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) map (_.ex) } type Ctx = Map[BoundVal, BoundVal] @@ -331,7 +331,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi type ExtractState = Either[State, State] implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right - case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]]) { + case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], partialMatching: Boolean) { + private val _partialMatching = partialMatching + def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) @@ -342,16 +344,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => this // Everything else is pure so we ignore it } def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) - + def withPartialMatching: State = copy(partialMatching = true) + def withCompleteMatching: State = copy(partialMatching = false) + def withDefaultEarlyReturn: State = copy(partialMatching = _partialMatching) + def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) } } object State { - def init(xtor: Rep, xtee: Rep): State = apply(xtor, xtee) + def forExtraction(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, false) + def forRewriting(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, true) - private def apply(xtor: Rep, xtee: Rep): State = - State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty)) + private def apply(xtor: Rep, xtee: Rep, inRewriting: Boolean): State = + State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), inRewriting) } sealed trait Flag @@ -380,18 +386,31 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def genFlags0(r: Rep, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = r match { - case lb: LetBinding => - val updated = update( - lb.value, - unusedBVs + lb.bound, - effect(lb.value) match { - case Pure => impures - case Impure => impures + lb.bound - } - ) - genFlags0(lb.body, updated._1, updated._2) + case lb: LetBinding => lb.value match { + // TODO should just inline the function in App during the transformation. Else you get an "Illegal ANF self argument". + case _: Lambda => + val updated = update( + lb.value, + unusedBVs + lb.bound, + effect(lb.value) match { + case Pure => impures + lb.bound + case Impure => impures + lb.bound + } + ) + genFlags0(lb.body, updated._1, updated._2) + case _ => + val updated = update( + lb.value, + unusedBVs + lb.bound, + effect(lb.value) match { + case Pure => impures + case Impure => impures + lb.bound + } + ) + genFlags0(lb.body, updated._1, updated._2) + } - case bv: BoundVal => (unusedBVs - bv, impures) + case bv: BoundVal => (unusedBVs - bv, impures) alsoApply println case _ => (unusedBVs, impures) } @@ -459,10 +478,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val lbBVs = bvs(lb) val done: State => Boolean = s => lbBVs forall (s.ctx.keySet contains _) // TODO Keep set of matched BVs in state? - extractWithState(lb, body)(done)(es) map { es => + extractWithState(lb, body)(done)(es withPartialMatching) map { es => val replace = es.ctx(lb.last.bound) bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } -> es - } getOrElse body -> es + } getOrElse body -> (es withDefaultEarlyReturn) case _ => bottomUpPartial(body) { case `arg` => hopArg } -> es } @@ -572,13 +591,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case Skip => extractWithState(lb.body, xtee)(done) } - case (bv: BoundVal, lb: LetBinding) => for { - es1 <- extractWithState(bv, lb.bound)(done).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(done)(es2).left - } yield es3 - - case (_: Rep, lb: LetBinding) if es.matchedImpureBVs contains lb.bound => extractWithState(xtor, lb.body)(done) + case (bv: BoundVal, lb: LetBinding) => + if (es.ctx.keySet contains bv) Right(es) + else { + if (!isPure(lb) || !es.partialMatching) Left(es) + else { + for { + es1 <- extractWithState(bv, lb.bound)(done).left + es2 <- extractInside(bv, lb.value)(es1).left + es3 <- extractWithState(bv, lb.body)(done)(es2).left + } yield es3 + } + } + + case (_: Rep, lb: LetBinding) if (es.matchedImpureBVs contains lb.bound) || es.partialMatching => extractWithState(xtor, lb.body)(done) case (_, Ascribe(s, _)) => extractWithState(xtor, s)(done) @@ -651,8 +677,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (l1: Lambda, l2: Lambda) => for { es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) - es2 <- extractWithState(l1.body, l2.body)(done)(es1 withCtx l1.bound -> l2.bound) - } yield es2 + es2 <- extractWithState(l1.body, l2.body)(done)(es1 withCtx l1.bound -> l2.bound withCompleteMatching) + } yield es2 withDefaultEarlyReturn case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => def targExtract(es0: State): ExtractState = @@ -701,7 +727,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = - rewriteRep0(xtor, xtee, code)(false)(State.init(xtor, xtee)) + rewriteRep0(xtor, xtee, code)(false)(State.forRewriting(xtor, xtee)) def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(internalRec: Boolean)(implicit es: State): Option[Rep] = { def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { @@ -783,68 +809,68 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - def cleanup(r: Rep, remove: Set[BoundVal]): Rep = r match { - case lb: LetBinding if remove contains lb.bound => cleanup(lb.body, remove) - case lb: LetBinding if isPure(lb.value) => - val bvsInValue = bvs(lb.value) - - if (bvsInValue exists (remove contains)) { - cleanup(lb.body, remove ++ bvsInValue.toSet) - } else { - lb.body = cleanup(lb.body, remove) + def cleanup(r: Rep, remove: Set[BoundVal], removeLambdas: Boolean)(ctx: Ctx): Rep = r match { + case lb: LetBinding if remove contains lb.bound => cleanup(lb.body, remove, removeLambdas)(ctx) + case lb: LetBinding => lb.value match { + case _: Lambda if (ctx.values.toSet contains lb.bound) && removeLambdas => + cleanup(lb.body, remove, removeLambdas)(ctx) + case v if isPure(v) => + val bvsInValue = bvs(lb.value) + + if (bvsInValue exists (remove contains)) { + cleanup(lb.body, remove ++ bvsInValue.toSet, removeLambdas)(ctx) + } else { + lb.body = cleanup(lb.body, remove, removeLambdas)(ctx) + lb + } + case _ => + lb.body = cleanup(lb.body, remove, removeLambdas)(ctx) lb - } - case lb: LetBinding => - lb.body = cleanup(lb.body, remove) - lb + } + case _ => r } - def finalize(code: Rep, xtor: Rep, filteredXtee: Rep)(ctx: Ctx): Rep = { - println(s"FOO: $code, \n$xtor, \n$filteredXtee") - - xtor match { - case xtorLB: LetBinding => - val xtorLast = xtorLB.last - xtorLast.body match { - case xtorRet: BoundVal => code match { - case codeLB: LetBinding => - val codeLast = codeLB.last - codeLast.body |>? { - case codeRet: BoundVal => - val bv = ctx(xtorRet) - codeLast.body = bottomUpPartial(filteredXtee) { case `bv` => codeRet } - } - code - - case _ => - val bv = ctx(xtorRet) - bottomUpPartial(filteredXtee) { case `bv` => code } - } - - // Hole? - case _ => code + def finalize(code: Rep, xtor: Rep, filteredXtee: Rep)(ctx: Ctx): Rep = xtor match { + case xtorLB: LetBinding => + val xtorLast = xtorLB.last + xtorLast.body match { + case xtorRet: BoundVal => code match { + case codeLB: LetBinding => + val codeLast = codeLB.last + codeLast.body |>? { + case codeRet: BoundVal => + val bv = ctx(xtorRet) + codeLast.body = bottomUpPartial(filteredXtee) { case `bv` => codeRet } + } + code + + case _ => + val bv = ctx(xtorRet) + bottomUpPartial(filteredXtee) { case `bv` => code } } - case _ => code match { - case codeLB: LetBinding => - val codeLast = codeLB.last - codeLast.body |>? { - case codeRet: BoundVal => - codeLast.body = bottomUpPartial(filteredXtee) { case `xtor` => codeRet } - } - code - - case _ => - bottomUpPartial(filteredXtee) { case `xtor` => code } + // Hole? + case _ => code } + + case _ => code match { + case codeLB: LetBinding => + val codeLast = codeLB.last + codeLast.body |>? { + case codeRet: BoundVal => + codeLast.body = bottomUpPartial(filteredXtee) { case `xtor` => codeRet } + } + code + + case _ => bottomUpPartial(filteredXtee) { case `xtor` => code } } } if (preCheck(es.ex)) for { - code <- code(es.ex) - code0 = finalize(code, xtor, xtee)(es.ctx) - if check(Set.empty, es.matchedImpureBVs)(cleanup(code0, es.matchedImpureBVs)) + code <- code(es.ex) alsoApply(c => println(s"CODE: $c")) + code0 = finalize(code, xtor, xtee)(es.ctx) alsoApply (c => println(s"CODE0: $c")) + if check(Set.empty, es.matchedImpureBVs)(cleanup(code0, es.matchedImpureBVs, true)(es.ctx)) } yield code0 else None } diff --git a/src/test/scala/squid/ir/fastir/ExtractingTests.scala b/src/test/scala/squid/ir/fastir/ExtractingTests.scala index 04d5edab..b8138a10 100644 --- a/src/test/scala/squid/ir/fastir/ExtractingTests.scala +++ b/src/test/scala/squid/ir/fastir/ExtractingTests.scala @@ -17,7 +17,7 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { } ir"println(42.toDouble)" match { - case ir"println($h)" => assert(h =~= ir"42.toDouble") + case ir"val t = ($h: Double); println(t)" => assert(h =~= ir"42.toDouble") } ir"(42, 1337)" match { @@ -86,6 +86,14 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { case ir"val t = ($h: Int); val bX = readInt; bX" => assert(h =~= ir"readInt") } } + + test("Extracting should not match a return in the middle of block") { + assert(Try { + ir"val a = readInt; a + 11 + 22" match { + case ir"val a = readInt; a + 11" => fail + } + }.isFailure) + } } object ExtractingTests { diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index 449f7618..368709fe 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -84,24 +84,30 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria case ir"(x: Int, y: Int, z: Int) => $body(x + y + z): Int" => assert(body =~= id) } - ir"(a: Int) => readInt + a" matches { + ir"(a: Int) => readInt + a" match { case ir"(x: Int) => $body(readInt, x): Int" => assert(body =~= ir"(r: Int, s: Int) => r + s") } - ir"(a: Int) => readInt + a" matches { + ir"(a: Int) => readInt + a" match { case ir"(x: Int) => $body(x, readInt): Int" => assert(body =~= ir"(r: Int, s: Int) => s + r") } - ir"(a: Int, b: Int) => readInt + (a + b)" matches { + ir"(a: Int, b: Int) => readInt + (a + b)" match { case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body =~= ir"(r: Int, s: Int) => r + s") } } - test("Match letbindinds") { + test("Match letbindings") { ir"val a = 10.toDouble; a + 1" match { case ir"val x = 10.toDouble; $body(x)" => assert(body =~= ir"(_: Double) + 1") } - + + ir"val a = 11.toDouble; val b = 22.toDouble; val c = 33.toDouble; (a,b,c)" match { + case ir"val a = ($x:Int).toDouble; val b = ($y:Int).toDouble; $body(a, b)" => + assert(x =~= ir"11") + assert(y =~= ir"22") + assert(body =~= ir"(a: Double, b: Double) => (a, b, 33.toDouble)") + } //val a = ir"val a = 10.toDouble; val b = a + 1; val c = b + 2; c" matches { // case ir"val x = 10.toDouble; $body(x):Double" => diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 36bb7c80..37b61ecf 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -2,8 +2,11 @@ package squid package ir package fastir +import squid.test.Test.Embedding + class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { import RewritingTests.Embedding.Predef._ + import RewritingTests.Embedding.Quasicodes._ test("Simple rewrites") { val a = ir"123" rewrite { @@ -31,10 +34,10 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { } test("Rewriting subpatterns") { - val a = ir"(10.toDouble + 111) * .5" match { + val a = ir"(10.toDouble + 111) * .5" rewrite { case ir"(($n: Double) + 111) * .5" => ir"$n * .25" } - assert(a =~= ir"10.toDouble * .25") + assert(a =~= ir"val t = 10.toDouble; val t2 = (t + 111) * .5; t * .25") val b = ir"Option(42).get" rewrite { case ir"Option(($n: Int)).get" => n @@ -121,6 +124,32 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(a =~= ir"val a = readDouble.toInt; val b = readDouble; a + b") } + test("Rewriting should not partially match inside a lambda") { + val a = ir"(x: Double) => x * 2 + 33" rewrite { + case ir"(x: Double) => x * 2" => fail + } + } + + /* --- Hacky solutions --- */ + test("Rewriting should flag lambdas as starting points") { + // Else we get an "Illegal ANF self argument" because the rewriting + // puts a LB in the function position of tha App f(a) + val l = ir"(x: Double) => x * 2" + val a = ir"val a = 11.toDouble; val f = $l; f(a)" rewrite { + case ir"(x: Double) => x * 2" => ir"(x: Double) => x * 4" + } + } + + test("Rewriting should remove matched lambdas") { + // Else since genCode puts code at the top it will rewrite + // again and again the same lambda + val l = ir"(x: Double) => x * 2" + val a = ir"val a = 11.toDouble; val f = $l; f(a)" rewrite { + case ir"(x: Double) => x * 2" => ir"(x: Double) => x * 4" + } + } + /* ----------------------- */ + //test("Rewriting with impures") { // val a = ir"val a = readInt; val b = readInt; (a + b) * 0.5" rewrite { // case ir"(($h1: Int) + ($h2: Int)) * 0.5" => dbg_ir"($h1 * $h2) + 42.0" From c114a94e60fc91f42f00db2b9bb7fa4dbb60c15d Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 3 Dec 2017 22:51:30 +0100 Subject: [PATCH 38/66] Remove explicit of passing done predicate --- src/main/scala/squid/ir/fastanf/FastANF.scala | 84 ++++++++++--------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 7da7fa2c..ecfa2070 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -319,7 +319,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract(\n$xtor, \n$xtee)") - extractWithState(xtor, xtee)(_ => false)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) map (_.ex) + extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) map (_.ex) } type Ctx = Map[BoundVal, BoundVal] @@ -331,7 +331,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi type ExtractState = Either[State, State] implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right - case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], partialMatching: Boolean) { + case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], + failedMatches: Map[BoundVal, Set[BoundVal]], done: State => Boolean, partialMatching: Boolean) { + def isDone: Boolean = done(this) + private val _done = done + private val _partialMatching = partialMatching def withNewExtract(newEx: Extract): State = copy(ex = newEx) @@ -347,17 +351,19 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withPartialMatching: State = copy(partialMatching = true) def withCompleteMatching: State = copy(partialMatching = false) def withDefaultEarlyReturn: State = copy(partialMatching = _partialMatching) + def withTerminationPredicate(done: State => Boolean): State = copy(done = done) + def withDefaultTerminationPredicate: State = copy(done = _done) def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) } } object State { - def forExtraction(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, false) - def forRewriting(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, true) + def forExtraction(xtor: Rep, xtee: Rep, done: State => Boolean = _ => false): State = apply(xtor, xtee, done, false) + def forRewriting(xtor: Rep, xtee: Rep, done: State => Boolean = _ => false): State = apply(xtor, xtee, done, true) - private def apply(xtor: Rep, xtee: Rep, inRewriting: Boolean): State = - State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), inRewriting) + private def apply(xtor: Rep, xtee: Rep, done: State => Boolean, partialMatching: Boolean): State = + State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), done, partialMatching) } sealed trait Flag @@ -419,7 +425,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - def extractWithState(xtor: Rep, xtee: Rep)(done: State => Boolean)(implicit es: State): ExtractState = { + def extractWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { println(es) def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { @@ -478,10 +484,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val lbBVs = bvs(lb) val done: State => Boolean = s => lbBVs forall (s.ctx.keySet contains _) // TODO Keep set of matched BVs in state? - extractWithState(lb, body)(done)(es withPartialMatching) map { es => + extractWithState(lb, body)(es withTerminationPredicate done withPartialMatching) map { es => val replace = es.ctx(lb.last.bound) bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } -> es - } getOrElse body -> (es withDefaultEarlyReturn) + } getOrElse body -> es.withDefaultTerminationPredicate.withDefaultEarlyReturn case _ => bottomUpPartial(body) { case `arg` => hopArg } -> es } @@ -503,26 +509,26 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi oe getOrElse Left(es) } - def extractLBs(lb1: LetBinding, lb2: LetBinding)(done: State => Boolean)(implicit es: State): ExtractState = { - def extractAndContinue(lb1: LetBinding, lb2: LetBinding)(done: State => Boolean)(implicit es: State): ExtractState = for { - es1 <- extractWithState(lb1.bound, lb2.bound)(done) - es2 <- extractWithState(lb1.body, lb2.body)(done)(es1) + def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = { + def extractAndContinue(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = for { + es1 <- extractWithState(lb1.bound, lb2.bound) + es2 <- extractWithState(lb1.body, lb2.body)(es1) } yield es2 (es.flags.xtorFlag(lb1.bound), es.flags.xteeFlag(lb2.bound)) match { - case (Start, Start) => extractAndContinue(lb1, lb2)(done) + case (Start, Start) => extractAndContinue(lb1, lb2) case (Start, Skip) => for { - es1 <- extractAndContinue(lb1, lb2)(done).left - es2 <- extractWithState(lb1, lb2.body)(done)(es1).left + es1 <- extractAndContinue(lb1, lb2).left + es2 <- extractWithState(lb1, lb2.body)(es1).left } yield es2 case (Skip, Start) => for { - es1 <- extractAndContinue(lb1, lb2)(done).left - es2 <- extractWithState(lb1.body, lb2)(done)(es1).left + es1 <- extractAndContinue(lb1, lb2).left + es2 <- extractWithState(lb1.body, lb2)(es1).left } yield es2 - case (Skip, Skip) => extractWithState(lb1.body, lb2.body)(done) + case (Skip, Skip) => extractWithState(lb1.body, lb2.body) } } @@ -554,14 +560,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi bvs(d).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => for { es1 <- acc.left - es2 <- extractWithState(bv, bv2)(done)(es1).left + es2 <- extractWithState(bv, bv2)(es1).left } yield es2 } def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] println(s"extractWithState: $xtor\n$xtee\n") - if (done(es)) Right(es) + if (es.isDone) Right(es) else xtor -> xtee match { case (h: Hole, lb: LetBinding) => contentsOf(h) match { case Some(lb1: LetBinding) if lb1.value == lb.value => Right(es) @@ -577,7 +583,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (HOPHole2(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) - case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2)(done) + case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) //// TODO Stop at markers? //case (lb: LetBinding, _: Rep) => (effect(lb), effect(xtee)) match { @@ -588,7 +594,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (lb: LetBinding, _: Rep) => es.flags.xtorFlag(lb.bound) match { case Start => Left(es) - case Skip => extractWithState(lb.body, xtee)(done) + case Skip => extractWithState(lb.body, xtee) } case (bv: BoundVal, lb: LetBinding) => @@ -597,20 +603,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if (!isPure(lb) || !es.partialMatching) Left(es) else { for { - es1 <- extractWithState(bv, lb.bound)(done).left + es1 <- extractWithState(bv, lb.bound).left es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(done)(es2).left + es3 <- extractWithState(bv, lb.body)(es2).left } yield es3 } } - case (_: Rep, lb: LetBinding) if (es.matchedImpureBVs contains lb.bound) || es.partialMatching => extractWithState(xtor, lb.body)(done) + case (_: Rep, lb: LetBinding) if (es.matchedImpureBVs contains lb.bound) || es.partialMatching => extractWithState(xtor, lb.body) - case (_, Ascribe(s, _)) => extractWithState(xtor, s)(done) + case (_, Ascribe(s, _)) => extractWithState(xtor, s) case (Ascribe(s, t), _) => for { es1 <- es.updateExtractWith(t extract(xtee.typ, Covariant)) - es2 <- extractWithState(s, xtee)(done)(es1) + es2 <- extractWithState(s, xtee)(es1) } yield es2 case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) @@ -626,14 +632,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if (bv1 == bv2) Right(es) else if (es.failedMatches(bv1) contains bv2) Left(es) else (bv1.owner, bv2.owner) match { - case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value)(done) match { + case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { case Right(es) => effect(lb2.value) match { case Pure => Right(es withCtx lb1.bound -> lb2.bound) case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) } case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } - case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2)(done) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? + case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? case _ => Left(es withFailed bv1 -> bv2) } } @@ -644,7 +650,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) // Assuming if they have the same name and prefix the type is the same - case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2)(done) + case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) case (NewObject(t1), NewObject(t2)) => es updateExtractWith (t1 extract(t2, Covariant)) @@ -671,13 +677,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => throw IRException(s"Trying to splice-extract with invalid extractor $xtor") } - def extractDefs(v1: Def, v2: Def)(done: State => Boolean)(implicit es: State): ExtractState = { + def extractDefs(v1: Def, v2: Def)(implicit es: State): ExtractState = { println(s"VALUES: \n\t$v1\n\t$v2 with $es \n\n") (v1, v2) match { case (l1: Lambda, l2: Lambda) => for { es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) - es2 <- extractWithState(l1.body, l2.body)(done)(es1 withCtx l1.bound -> l2.bound withCompleteMatching) + es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound withCompleteMatching) } yield es2 withDefaultEarlyReturn case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => @@ -700,11 +706,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es1 <- extractArgss0(t1, t2, acc)(es0) } yield es1 - case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2)(done) + case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) case (SplicedArgument(arg), NoArguments) => es updateExtractWith spliceExtract(arg, Args(acc.reverse: _*)) - case (r1: Rep, r2: Rep) => extractWithState(r1, r2)(done) + case (r1: Rep, r2: Rep) => extractWithState(r1, r2) case (NoArguments, NoArguments) => Right(es) case (NoArgumentLists, NoArgumentLists) => Right(es) case _ => Left(es) @@ -714,13 +720,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } for { - es1 <- extractWithState(ma1.self, ma2.self)(done) + es1 <- extractWithState(ma1.self, ma2.self) es2 <- targExtract(es1) es3 <- extractArgss(ma1.argss, ma2.argss)(es2) es4 <- es3.updateExtractWith(ma1.typ extract (ma2.typ, Covariant)) } yield es4 - case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2)))(done) + case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2))) case _ => Left(es) } @@ -737,9 +743,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (lb1: LetBinding, lb2: LetBinding) if !internalRec => ((effect(lb1.value), es.flags.xtorFlag(lb1.bound)), (effect(lb2.value), es.flags.xteeFlag(lb2.bound))) match { case ((Pure, Skip), (Pure, Skip)) => Left(es) - case _ => extractWithState(lb1, lb2)(_ => false) + case _ => extractWithState(lb1, lb2) } - case _ => extractWithState(xtor, xtee)(_ => false) + case _ => extractWithState(xtor, xtee) } } From f0616ec031952324be4cc785e54a7a311f4733e8 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 4 Dec 2017 10:07:14 +0100 Subject: [PATCH 39/66] Fix issue where order of pure statements sometimes mattered, defholes now only extracts impure statements --- src/main/scala/squid/ir/fastanf/FastANF.scala | 21 ++++++++++++------- .../squid/ir/fastir/ExtractingTests.scala | 20 +++++++++++++++++- .../squid/ir/fastir/RewritingTests.scala | 20 +++++++++++------- 3 files changed, 46 insertions(+), 15 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index ecfa2070..a6a6cb61 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -416,7 +416,17 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi genFlags0(lb.body, updated._1, updated._2) } - case bv: BoundVal => (unusedBVs - bv, impures) alsoApply println + case bv: BoundVal => (unusedBVs - bv, impures) //alsoApply println + + case HOPHole2(_, _, argss, _) => + val updatedImpures = argss.flatten.foldLeft(impures) { + case (impures, arg) => arg match { + case lb: LetBinding if !isPure(lb.value) => impures + lb.bound + case _ => impures + } + } + (unusedBVs, updatedImpures) + case _ => (unusedBVs, impures) } @@ -523,11 +533,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es2 <- extractWithState(lb1, lb2.body)(es1).left } yield es2 - case (Skip, Start) => for { - es1 <- extractAndContinue(lb1, lb2).left - es2 <- extractWithState(lb1.body, lb2)(es1).left - } yield es2 - + case (Skip, Start) => extractWithState(lb1.body, lb2) + case (Skip, Skip) => extractWithState(lb1.body, lb2.body) } } @@ -726,7 +733,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es4 <- es3.updateExtractWith(ma1.typ extract (ma2.typ, Covariant)) } yield es4 - case (DefHole(h), _) => extractWithState(h, wrapConstruct(letbind(v2))) + case (DefHole(h), _) if !isPure(v2) => extractWithState(h, wrapConstruct(letbind(v2))) case _ => Left(es) } diff --git a/src/test/scala/squid/ir/fastir/ExtractingTests.scala b/src/test/scala/squid/ir/fastir/ExtractingTests.scala index b8138a10..3910655a 100644 --- a/src/test/scala/squid/ir/fastir/ExtractingTests.scala +++ b/src/test/scala/squid/ir/fastir/ExtractingTests.scala @@ -17,7 +17,7 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { } ir"println(42.toDouble)" match { - case ir"val t = ($h: Double); println(t)" => assert(h =~= ir"42.toDouble") + case ir"println(($h: Int).toDouble)" => assert(h =~= ir"42") } ir"(42, 1337)" match { @@ -81,6 +81,24 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { case ir"val aX = readInt; val bX = readInt; bX + aX" => fail } }.isFailure) + + assert(Try { + ir"(readInt, readDouble)" match { + case ir"(readDouble, readInt)" => fail + } + }.isFailure) + + ir"readInt; readDouble" match { + case ir"$a; $b" => + assert(a =~= ir"readInt") + assert(b =~= ir"readDouble") + } + + ir"val r1 = readInt; val a = 20 + r1; val r2 = readDouble; a + r2" match { + case ir"val r1 = ($r1: Int); val r2 = ($r2: Double); val a = 20 + r1; a + r2" => + assert(r1 =~= ir"readInt") + assert(r2 =~= ir"readDouble") + } ir"val a = readInt; val b = readInt; b" match { case ir"val t = ($h: Int); val bX = readInt; bX" => assert(h =~= ir"readInt") diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 37b61ecf..b4d4dc11 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -157,13 +157,19 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { // assert(a =~= ir"(readInt + readInt) + 42.0") //} - //test("Rewriting simple expressions only once") { - // val a = ir"println((50, 60))" rewrite { - // case ir"($x:Int,$y:Int)" => ir"($y:Int,$x:Int)" - // //case ir"(${Const(n)}:Int)" => Const(n+1) - // } - // assert(a =~= ir"println((61,51))") - //} + test("Rewriting simple expressions only once") { + /* + * TODO + * goes into an infinite loop since it rewrites + * at the "current" point. So the topdown transformer + * will apply the transformation again on what was rewritten. + */ + //val a = ir"println((50, 60))" rewrite { + // case ir"($x:Int,$y:Int)" => ir"($y:Int,$x:Int)" + // //case ir"(${Const(n)}:Int)" => Const(n+1) + //} + //assert(a =~= ir"println((61,51))") + } // //test("Function Rewritings") { // val a = ir"(x: Int) => (x-5) * 32" rewrite { From 95c2c4b7e163a3ebeedcc65c50ca309a0032a3fb Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 4 Dec 2017 19:27:40 +0100 Subject: [PATCH 40/66] Fix issue where you couldn't extract an impure statement --- src/main/scala/squid/ir/fastanf/FastANF.scala | 17 +++++++---------- .../scala/squid/ir/fastir/ExtractingTests.scala | 4 ++++ .../ir/fastir/HigherOrderPatternVariables.scala | 7 +++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index a6a6cb61..c5544538 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -606,16 +606,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (bv: BoundVal, lb: LetBinding) => if (es.ctx.keySet contains bv) Right(es) - else { - if (!isPure(lb) || !es.partialMatching) Left(es) - else { - for { - es1 <- extractWithState(bv, lb.bound).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left - } yield es3 - } - } + else if (es.partialMatching) { + for { + es1 <- extractWithState(bv, lb.bound).left + es2 <- extractInside(bv, lb.value)(es1).left + es3 <- extractWithState(bv, lb.body)(es2).left + } yield es3 + } else Left(es) case (_: Rep, lb: LetBinding) if (es.matchedImpureBVs contains lb.bound) || es.partialMatching => extractWithState(xtor, lb.body) diff --git a/src/test/scala/squid/ir/fastir/ExtractingTests.scala b/src/test/scala/squid/ir/fastir/ExtractingTests.scala index 3910655a..5c7bd1ae 100644 --- a/src/test/scala/squid/ir/fastir/ExtractingTests.scala +++ b/src/test/scala/squid/ir/fastir/ExtractingTests.scala @@ -99,6 +99,10 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { assert(r1 =~= ir"readInt") assert(r2 =~= ir"readDouble") } + + ir"val r1 = readInt; val a = 20 + r1; val b = r1 * 2; val r2 = a + readDouble; r2 + b" match { + case ir"val r1 = readInt; val b = r1 * 2; val a = 20 + r1; val r2 = a + readDouble; r2 + b" => + } ir"val a = readInt; val b = readInt; b" match { case ir"val t = ($h: Int); val bX = readInt; bX" => assert(h =~= ir"readInt") diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index 368709fe..3b5c3b25 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -121,6 +121,13 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria // assert(b == ir"readDouble.toInt + 42") // assert(b == ir"val r = readDouble + 1; r") } + + test("HOPHoles should be able to extract an impure statement") { + val f = ir"(x: Int) => println(x + 1)" + ir"val a = 20; $f(a)" match { + case ir"(x: Int) => ($prt(x + 1): Int)" => assert(prt =~= ir"(x: Int) => println(x)") + } + } } object HigherOrderPatternVariables { From 8b145b28c5ec79ef71bd9e9e721624ea26d7f166 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 4 Dec 2017 22:30:08 +0100 Subject: [PATCH 41/66] HOPHoles now replace all the occurences of a pattern --- src/main/scala/squid/ir/fastanf/FastANF.scala | 19 +++++++++++++++---- .../fastir/HigherOrderPatternVariables.scala | 12 +++++++++--- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index c5544538..e8a9cbe6 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -493,11 +493,22 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case lb: LetBinding => val lbBVs = bvs(lb) val done: State => Boolean = s => lbBVs forall (s.ctx.keySet contains _) // TODO Keep set of matched BVs in state? + + def replaceAllOccurences(r: Rep, body: Rep)(es: State): Rep -> State = { + def replaceAllOccurences0(r: Rep, body: Rep)(implicit es: State): Rep -> State = { + extractWithState(lb, body) map { es0 => + println(s"------------------") + val replace = es0.ctx(lb.last.bound) + val body0 = bottomUpPartial(filterLBs(body)(es0.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } + replaceAllOccurences0(r, body0)._1 -> es0 + } getOrElse body -> es + } + + replaceAllOccurences0(r, body)(es withTerminationPredicate done withPartialMatching) + .mapSecond(_.withDefaultTerminationPredicate.withDefaultEarlyReturn) + } - extractWithState(lb, body)(es withTerminationPredicate done withPartialMatching) map { es => - val replace = es.ctx(lb.last.bound) - bottomUpPartial(filterLBs(body)(es.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } -> es - } getOrElse body -> es.withDefaultTerminationPredicate.withDefaultEarlyReturn + replaceAllOccurences(lb, body)(es) case _ => bottomUpPartial(body) { case `arg` => hopArg } -> es } diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index 3b5c3b25..ad0943d1 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -76,9 +76,9 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria test("Non-trivial arguments") { val id = ir"(z: Int) => z" - //ir"(a: Int, b: Int, c: Int) => a + b + c" matches { - // case ir"(x: Int, y: Int, z: Int) => $body(x + y, z): Int" => assert(body == ir"(r: Int, s: Int) => r + s") - //} + ir"(a: Int, b: Int, c: Int) => a + b + c" matches { + case ir"(x: Int, y: Int, z: Int) => $body(x + y, z): Int" => assert(body == ir"(r: Int, s: Int) => r + s") + } ir"(a: Int, b: Int, c: Int) => a + b + c" match { case ir"(x: Int, y: Int, z: Int) => $body(x + y + z): Int" => assert(body =~= id) @@ -128,6 +128,12 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria case ir"(x: Int) => ($prt(x + 1): Int)" => assert(prt =~= ir"(x: Int) => println(x)") } } + + test("HOPHoles should extract all the occurences of the pattern") { + ir"val f = (x: Int, y: Int) => { println(x + y); x + y }; f(11, 22)" match { + case ir"(x: Int, y: Int) => ($h(x + y): Int)" => assert(h =~= ir"(x: Int) => { println(x); x }") + } + } } object HigherOrderPatternVariables { From d928e38807fa837dd5f277b4d85e2bde5e506d2f Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 4 Dec 2017 22:42:38 +0100 Subject: [PATCH 42/66] Fix issue where HOPHoles couldn't extract nestet letbindings --- src/main/scala/squid/ir/fastanf/FastANF.scala | 2 +- .../scala/squid/ir/fastir/HigherOrderPatternVariables.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index e8a9cbe6..ff8f3d8c 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -913,7 +913,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def bvs(r: Rep): List[BoundVal] = { def bvs0(r: Rep, acc: List[BoundVal]): List[BoundVal] = r match { - case lb: LetBinding => lb.bound :: acc + case lb: LetBinding => bvs0(lb.body, lb.bound :: acc) case _ => acc } diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index ad0943d1..25c24809 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -134,6 +134,12 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria case ir"(x: Int, y: Int) => ($h(x + y): Int)" => assert(h =~= ir"(x: Int) => { println(x); x }") } } + + test("HOPHoles should be able to extract nested letbindings") { + ir"val a = readInt; val b = readInt; val a1 = a + 1; val b1 = b + 1; a1 + b1" match { + case ir"($h(readInt + 1): Int)" => assert(h =~= ir"(x: Int) => x + x") + } + } } object HigherOrderPatternVariables { From 5bafcefdc9715980361bb3c63ac07eb5529d6c25 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 5 Dec 2017 13:03:55 +0100 Subject: [PATCH 43/66] Add HOPHole tests --- .../ir/fastir/HigherOrderPatternVariables.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index 25c24809..e4a643b3 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -133,6 +133,18 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria ir"val f = (x: Int, y: Int) => { println(x + y); x + y }; f(11, 22)" match { case ir"(x: Int, y: Int) => ($h(x + y): Int)" => assert(h =~= ir"(x: Int) => { println(x); x }") } + + ir"val f = (x: Int, y: Int) => { println(x + y); val a = x + y; println(x * y); val b = x * y; a + b}; f(11, 22)" match { + case ir"(x: Int, y: Int) => ($h(x + y, x * y): Int)" => assert(h =~= ir"(x: Int, y: Int) => { println(x); println(y); x + y }") + } + + ir"val f = (x: Int, y: Int) => { println(x * y); val a = x + y; println(x + y); val b = x * y; a + b}; f(11, 22)" match { + case ir"(x: Int, y: Int) => ($h(x + y, x * y): Int)" => assert(h =~= ir"(x: Int, y: Int) => { println(y); println(x); x + y }") + } + + ir"val f = (x: Int, y: Int) => { println(x * y); val a = x + y; println(x + y); val b = x * y; println(a + b); a * b}; f(11, 22)" match { + case ir"(x: Int, y: Int) => ($h(x + y, x * y): Int)" => assert(h =~= ir"(x: Int, y: Int) => { println(y); println(x); println(x + y); x * y}") + } } test("HOPHoles should be able to extract nested letbindings") { From 201706bed9129279826ef9fc13015c42842daa21 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 5 Dec 2017 13:09:38 +0100 Subject: [PATCH 44/66] Cleanup tests --- .../scala/squid/ir/fastir/ArgumentTests.scala | 36 +++++ .../scala/squid/ir/fastir/BasicTests.scala | 136 ------------------ 2 files changed, 36 insertions(+), 136 deletions(-) create mode 100644 src/test/scala/squid/ir/fastir/ArgumentTests.scala delete mode 100644 src/test/scala/squid/ir/fastir/BasicTests.scala diff --git a/src/test/scala/squid/ir/fastir/ArgumentTests.scala b/src/test/scala/squid/ir/fastir/ArgumentTests.scala new file mode 100644 index 00000000..d16b9e12 --- /dev/null +++ b/src/test/scala/squid/ir/fastir/ArgumentTests.scala @@ -0,0 +1,36 @@ +package squid +package ir.fastir + +import squid.ir.FastANF + +class ArgumentTests extends MyFunSuiteBase(ArgumentTests.Embedding) { + test("Arguments") { + import squid.ir.fastanf._ + + val c0 = Constant(0) + val c1 = Constant(1) + val c2 = Constant(2) + + assert(c1 ~: c2 ~: NoArguments == ArgumentCons(c1, c2)) + assert(c0 ~~: (c1 ~: c2) ~~: NoArguments == ArgumentListCons(c0, ArgumentListCons(ArgumentCons(c1, c2), NoArguments))) + assert(c0 ~~: (c1 ~: c2 ~: NoArguments) ~~: NoArgumentLists == ArgumentListCons(c0, ArgumentCons(c1, c2))) + } + + test("ArgumentLists Pretty Print") { + import squid.ir.fastanf._ + + val c0 = Constant(0) + val c1 = Constant(1) + val c2 = Constant(2) + val c3 = Constant(3) + + assert((c0 ~: c1).toArgssString == s"($c0, $c1)") + assert((c0 ~~: (c1 ~: c2)).toArgssString == s"($c0)($c1, $c2)") + assert((c0 ~~: (c1 ~: c2 ~: NoArguments)).toArgssString == s"($c0)($c1, $c2)") + assert((c0 ~~: (c1 ~: c2 ~: SplicedArgument(c3))).toArgssString == s"($c0)($c1, $c2, $c3: _*)") + } +} + +object ArgumentTests { + object Embedding extends FastANF +} diff --git a/src/test/scala/squid/ir/fastir/BasicTests.scala b/src/test/scala/squid/ir/fastir/BasicTests.scala deleted file mode 100644 index 4f98c60f..00000000 --- a/src/test/scala/squid/ir/fastir/BasicTests.scala +++ /dev/null @@ -1,136 +0,0 @@ -package squid -package ir.fastir - -import utils._ -import squid.ir.FastANF -import squid.ir.{SimpleRuleBasedTransformer,TopDownTransformer} - -class BasicTests extends MyFunSuiteBase(BasicTests.Embedding) { - import BasicTests.Embedding.Predef._ - - // TODO make proper tests... - - test("Basics") { - - println(ir"(x:Int) => x.toDouble.toInt + 42.0.toInt") - - } - - test("Arguments") { - import squid.ir.fastanf._ - - val c0 = Constant(0) - val c1 = Constant(1) - val c2 = Constant(2) - - assert(c1 ~: c2 ~: NoArguments == ArgumentCons(c1, c2)) - assert(c0 ~~: (c1 ~: c2) ~~: NoArguments == ArgumentListCons(c0, ArgumentListCons(ArgumentCons(c1, c2), NoArguments))) - assert(c0 ~~: (c1 ~: c2 ~: NoArguments) ~~: NoArgumentLists == ArgumentListCons(c0, ArgumentCons(c1, c2))) - - } - - test("ArgumentLists Pretty Print") { - import squid.ir.fastanf._ - - val c0 = Constant(0) - val c1 = Constant(1) - val c2 = Constant(2) - val c3 = Constant(3) - - assert((c0 ~: c1).toArgssString == s"($c0, $c1)") - assert((c0 ~~: (c1 ~: c2)).toArgssString == s"($c0)($c1, $c2)") - assert((c0 ~~: (c1 ~: c2 ~: NoArguments)).toArgssString == s"($c0)($c1, $c2)") - assert((c0 ~~: (c1 ~: c2 ~: SplicedArgument(c3))).toArgssString == s"($c0)($c1, $c2, $c3: _*)") - } - - test("Transformations") { - import squid.ir.fastanf._ - object Embedding extends FastANF - import Embedding.Predef._ - import Embedding.Quasicodes._ - - assert(Embedding.bottomUpPartial(code"42".rep){case Constant(n: Int) => Constant(n * 2)} == code"84".rep) - //assert(Embedding.bottomUpPartial(code"(x: Int) => x + 5".rep){case Constant(5) => Constant(42)} == code"(x: Int) => x + 42".rep) - //assert(Embedding.bottomUpPartial(code{ Some(5) }.rep){ case Constant(5) => Constant(42) } == code{ Some(42) }.rep) - //assert(Embedding.bottomUpPartial(code"foo(4, 5)".rep){ case Constant(4) => Constant(42) } == code"foo(42, 5)".rep) - - //case class Point(x: Int, y: Int) - //val p = Point(0, 1) - // - //assert(Embedding.bottomUpPartial(code{p.x + p.y}.rep){ case Constant(0) => Constant(42) } == code"p.y + p.y") - // - //code{p.x} match { - // case code"p.x" => assert(true) - // case _ => assert(false) - //} - // - //code{p.x + p.y} match { - // case code"p.x" => assert(true) - // case _ => assert(false) - //} - - //assert(Embedding.bottomUpPartial(code"println((1, 23))".rep) { case Constant(1) => Constant(42) } == code"println((42,23))".rep) - } - - test("Transformers") { - //object Tr extends BasicTests.Embedding.SelfTransformer with SimpleRuleBasedTransformer with TopDownTransformer { - // rewrite { - // case ir"123" => ir"readInt + 5" - // case ir"readDouble" => ir"10.0" - // } - //} - // - //val p = - // BasicTests.Embedding.debugFor { - // ir"123+readDouble+123" alsoApply println transformWith Tr alsoApply println - // } - } - - test("Rewriting simple expressions only once") { - val a = ir"println((50,60))" - val b = a rewrite { - case ir"($x:Int,$y:Int)" => - ir"($y:Int,$x:Int)" - case ir"(${Const(n)}:Int)" => Const(n+1) - } - - assert(b =~= ir"println((61,51))") - } - - test("Extract") { - import squid.ir.fastanf._ - object Embedding extends FastANF - import Embedding.Predef._ - import Embedding.Quasicodes._ - - code"42: Int" match { - case code"${Const(x: Int)}" => assert(x == 42) - } - - code"(x: Int) => x" match { - case code"(x: Int) => x" => - } - - code"(x: Int, y: Int) => x + 1 + y" match { - case code"(y: Int, z: Int) => y + 1 + z" => assert(true) - case _ => assert(false) - } - - //case class Point(x: Int, y: Int) - //val p = Point(0, 1) - // - //code{p.x} match { - // case code"p.x" => assert(true) - // case _ => assert(false) - //} - // - //code{p.x + p.y} match { - // case code"p.x" => assert(true) - // case _ => assert(false) - //} - } - -} -object BasicTests { - object Embedding extends FastANF -} From f9d8205a0764546d2f34c2f98f683808546c7522 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Wed, 6 Dec 2017 22:50:08 +0100 Subject: [PATCH 45/66] Automatically convert to ANF valid reps, correct handling of lambdas, and test some examples from the paper --- src/main/scala/squid/ir/fastanf/Effects.scala | 6 +- src/main/scala/squid/ir/fastanf/FastANF.scala | 148 ++++++++---------- src/main/scala/squid/ir/fastanf/Rep.scala | 59 +++++++ .../squid/ir/fastir/ExtractingTests.scala | 22 +++ .../squid/ir/fastir/RewritingTests.scala | 34 ++-- 5 files changed, 158 insertions(+), 111 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index f70bc24a..8240854a 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -29,10 +29,12 @@ trait Effects { case Module(r, _, _) => effect(r) + case HOPHole2(_, _, args, _) => Impure//args.flatten.foldLeft(Pure: Effect) { _ |+| effect(_) } + case Constant(_) | _: Symbol | StaticModule(_) | NewObject(_) | Hole(_, _) | SplicedHole(_, _) | - HOPHole(_, _, _, _) | HOPHole2(_, _, _, _) => Pure + HOPHole(_, _, _, _) => Pure } def mtdEffect(m: MethodSymbol): Effect = { @@ -77,4 +79,6 @@ trait StandardEffects extends Effects { addPureMtd(MethodSymbol(TypeSymbol("scala.Tuple2$"), "apply")) addPureMtd(MethodSymbol(TypeSymbol("scala.Tuple3$"), "apply")) addPureMtd(MethodSymbol(TypeSymbol("squid.lib.package$"),"uncurried2")) + addPureMtd(MethodSymbol(TypeSymbol("scala.collection.LinearSeqOptimized"),"foldLeft")) + addPureMtd(MethodSymbol(TypeSymbol("scala.collection.immutable.List$"),"apply")) } diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index ff8f3d8c..2ce3e731 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -286,36 +286,35 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // * --- * --- * --- * Implementations of `InspectableBase` methods * --- * --- * --- * def extractType(xtor: TypeRep, xtee: TypeRep, va: squid.ir.Variance): Option[Extract] = Some(EmptyExtract) //unsupported - def bottomUp(r: Rep)(f: Rep => Rep): Rep = transformRepAndDef(r)(identity, f)(identity) - def topDown(r: Rep)(f: Rep => Rep): Rep = transformRepAndDef(r)(f)(identity) - def transformRepAndDef(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity)(preDef: Def => Def, postDef: Def => Def = identity): Rep = { - def transformRepAndDef0(r: Rep) = transformRepAndDef(r)(pre, post)(preDef, postDef) - - def transformDef(d: Def): Def = postDef(preDef(d) match { - case App(f, a) => App(transformRepAndDef0(f), transformRepAndDef0(a)) // Note: App is a MethodApp, but we can transform it more efficiently this way - case ma: MethodApp => MethodApp(transformRepAndDef0(ma.self), ma.mtd, ma.targs, ma.argss argssMap (transformRepAndDef0(_)), ma.typ) - case l: Lambda => // Note: destructive modification of the lambda binding! - //new Lambda(l.name, l.bound, l.boundType, transformRepAndDef0(l.body)) - l.body = l.body |> transformRepAndDef0 - l - case _ => d - }) - + def bottomUp(r: Rep)(f: Rep => Rep): Rep = transformRep(r)(identity, f) + def topDown(r: Rep)(f: Rep => Rep): Rep = transformRep(r)(f) + + def transformRep(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity): Rep = { + def transformRep0(r: Rep) = transformRep(r)(pre, post) + + def transformDef(d: Def): Either[Rep, Def] = d match { + case ma: MethodApp => + Left(MethodApp.toANF(transformRep0(ma.self), ma.mtd, ma.targs, ma.argss argssMap transformRep0, ma.typ)) + case l: Lambda => + l.body = l.body |> transformRep0 + Right(l) + case _ => Right(d) + } + post(pre(r) match { - case lb: LetBinding => // Note: destructive modification of the let-binding! - lb.value = lb.value |> transformDef - lb.body = lb.body |> transformRepAndDef0 - lb - case Ascribe(s, t) => - Ascribe(transformRepAndDef0(s), t) - case Module(p, n, t) => - Module(transformRepAndDef0(p), n, t) + case lb: LetBinding => + lb.value |> transformDef match { + case Right(d) => + lb.value = d + lb.body = lb.body |> transformRep0 + lb + case Left(r) => LetBinding.withRepValue(lb.name, lb.bound, r, lb.body |> transformRep0) + } + case Ascribe(s, t) => Ascribe(transformRep0(s), t) + case Module(p, n, t) => Module(transformRep0(p), n, t) case r @ ((_:Constant) | (_:Hole) | (_:Symbol) | (_:SplicedHole) | (_:HOPHole) | (_:HOPHole2) | (_:NewObject) | (_:StaticModule)) => r }) } - - def transformRep(r: Rep)(pre: Rep => Rep, post: Rep => Rep = identity): Rep = - transformRepAndDef(r)(pre, post)(identity, identity) protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract(\n$xtor, \n$xtee)") @@ -392,31 +391,18 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def genFlags0(r: Rep, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = r match { - case lb: LetBinding => lb.value match { - // TODO should just inline the function in App during the transformation. Else you get an "Illegal ANF self argument". - case _: Lambda => - val updated = update( - lb.value, - unusedBVs + lb.bound, - effect(lb.value) match { - case Pure => impures + lb.bound - case Impure => impures + lb.bound - } - ) - genFlags0(lb.body, updated._1, updated._2) - case _ => - val updated = update( - lb.value, - unusedBVs + lb.bound, - effect(lb.value) match { - case Pure => impures - case Impure => impures + lb.bound - } - ) - genFlags0(lb.body, updated._1, updated._2) - } + case lb: LetBinding => + val updated = update( + lb.value, + unusedBVs + lb.bound, + effect(lb.value) match { + case Pure => impures + case Impure => impures + lb.bound + } + ) + genFlags0(lb.body, updated._1, updated._2) - case bv: BoundVal => (unusedBVs - bv, impures) //alsoApply println + case bv: BoundVal => (unusedBVs - bv, impures) case HOPHole2(_, _, argss, _) => val updatedImpures = argss.flatten.foldLeft(impures) { @@ -584,7 +570,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] - println(s"extractWithState: $xtor\n$xtee\n") + println(s"-----\nextractWithState: \n$xtor\n\n$xtee\n-----\n\n") if (es.isDone) Right(es) else xtor -> xtee match { case (h: Hole, lb: LetBinding) => contentsOf(h) match { @@ -602,29 +588,24 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (HOPHole2(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) - - //// TODO Stop at markers? - //case (lb: LetBinding, _: Rep) => (effect(lb), effect(xtee)) match { - // case (Pure, Pure) => extractWithState(lb.body, xtee)(done) - // case (Impure, Pure) => Left(es) - // case (_, Impure) => Left(es) // Assuming the return value cannot be impure - //} case (lb: LetBinding, _: Rep) => es.flags.xtorFlag(lb.bound) match { case Start => Left(es) case Skip => extractWithState(lb.body, xtee) } - case (bv: BoundVal, lb: LetBinding) => + case (bv: BoundVal, lb: LetBinding) => if (es.ctx.keySet contains bv) Right(es) - else if (es.partialMatching) { - for { - es1 <- extractWithState(bv, lb.bound).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left - } yield es3 - } else Left(es) - + else if (es.partialMatching) for { + es1 <- extractWithState(bv, lb.bound).left + es2 <- extractInside(bv, lb.value)(es1).left + es3 <- extractWithState(bv, lb.body)(es2).left + } yield es3 + else es.flags.xteeFlag(lb.bound) match { + case Skip => extractWithState(bv, lb.body) + case Start => Left(es) + } + case (_: Rep, lb: LetBinding) if (es.matchedImpureBVs contains lb.bound) || es.partialMatching => extractWithState(xtor, lb.body) case (_, Ascribe(s, _)) => extractWithState(xtor, s) @@ -830,24 +811,21 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - def cleanup(r: Rep, remove: Set[BoundVal], removeLambdas: Boolean)(ctx: Ctx): Rep = r match { - case lb: LetBinding if remove contains lb.bound => cleanup(lb.body, remove, removeLambdas)(ctx) - case lb: LetBinding => lb.value match { - case _: Lambda if (ctx.values.toSet contains lb.bound) && removeLambdas => - cleanup(lb.body, remove, removeLambdas)(ctx) - case v if isPure(v) => - val bvsInValue = bvs(lb.value) - - if (bvsInValue exists (remove contains)) { - cleanup(lb.body, remove ++ bvsInValue.toSet, removeLambdas)(ctx) - } else { - lb.body = cleanup(lb.body, remove, removeLambdas)(ctx) - lb - } - case _ => - lb.body = cleanup(lb.body, remove, removeLambdas)(ctx) + def cleanup(r: Rep, remove: Set[BoundVal])(ctx: Ctx): Rep = r match { + case lb: LetBinding if remove contains lb.bound => cleanup(lb.body, remove)(ctx) + + case lb: LetBinding if isPure(lb.value) => + val bvsInValue = bvs(lb.value) + + if (bvsInValue exists (remove contains)) { + cleanup(lb.body, remove ++ bvsInValue.toSet)(ctx) + } else { + lb.body = cleanup(lb.body, remove)(ctx) lb - } + } + + case lb: LetBinding => lb.body = cleanup(lb.body, remove)(ctx) + lb case _ => r } @@ -891,7 +869,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if (preCheck(es.ex)) for { code <- code(es.ex) alsoApply(c => println(s"CODE: $c")) code0 = finalize(code, xtor, xtee)(es.ctx) alsoApply (c => println(s"CODE0: $c")) - if check(Set.empty, es.matchedImpureBVs)(cleanup(code0, es.matchedImpureBVs, true)(es.ctx)) + if check(Set.empty, es.matchedImpureBVs)(cleanup(code0, es.matchedImpureBVs)(es.ctx)) } yield code0 else None } diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index 076b3958..8c55ad94 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -101,6 +101,51 @@ object MethodApp { case _ => SimpleMethodApp(self: Rep, mtd: MethodSymbol, targs: List[TypeRep], argss)(typ) } } + + def toANF(self: Rep, mtd: MethodSymbol, targs: List[TypeRep], argss: ArgumentLists, typ0: TypeRep)(implicit base: FastANF): Rep = { + def split(r: Rep): Option[LetBinding] -> Rep = r match { + case lb: LetBinding => Some(lb) -> lb.last.body + case _ => None -> r + } + + val self0 = self |> split + //val argss0 = argss.argssList map split // TODO loosing structure of ArgumentLists... + + def withANFSelf(argss0: ArgumentLists): Rep = self0 match { + case (Some(lb), ret) => + val innerLB = new Symbol { + protected var _parent: SymbolParent = new LetBinding("tmp", this, MethodApp(ret, mtd, targs, argss0, typ0), this) + }.owner.asInstanceOf[LetBinding] + + lb.last.body = innerLB + lb + + case (None, ret) => new Symbol { + protected var _parent: SymbolParent = new LetBinding("tmp", this, MethodApp(ret, mtd, targs, argss0, typ0), this) + }.owner.asInstanceOf[LetBinding] + } + + //val mergedArgss0 = argss0.foldRight((None: Option[LetBinding], List.empty[Rep])) { + // case ((Some(lb), ret), (lbAcc, argssAcc)) => + // val lbAcc0 = lbAcc match { + // case Some(lbAcc) => + // lb.last.body = lbAcc + // lb + // case None => lb + // } + // (Some(lbAcc0), ret :: argssAcc) + // + // case ((None, ret), (lbAcc, argssAcc)) => (lbAcc, ret :: argssAcc) + //} + + //val t = (None, argss) match { + // case (Some(lb), argss0) => + // lb.last.body = withANFSelf(argss)//(argss0.foldRight(NoArguments: ArgumentList)(_ ~: _)) + // lb + // case (None, argss0) => withANFSelf(argss)//(argss0.foldRight(NoArguments: ArgumentList)(_ ~: _)) + //} + withANFSelf(argss) + } } final case class SimpleMethodApp protected(self: Rep, mtd: MethodSymbol, targs: List[TypeRep], argss: ArgumentLists)(val typ: TypeRep) extends MethodApp with CachedHashCode { doChecks } @@ -214,6 +259,20 @@ class LetBinding(var name: String, var bound: Symbol, var value: Def, private va } override def toString: String = s"val $bound = $value; $body" } +object LetBinding { + def withRepValue(name: String, bound: Symbol, value: Rep, mkBody: => Rep): Rep = value match { + case lb: LetBinding => + val last = lb.last + last.name = name + bound rebind last.bound + last.bound = bound + last.body = mkBody + lb + + case _ => throw new IllegalArgumentException + } +} + class Lambda(var name: String, var bound: Symbol, val boundType: TypeRep, var body: Rep)(implicit base: FastANF) extends Def with RebindableBinding { val typ: TypeRep = base.funType(boundType, body.typ) override def toString: String = s"($bound: $boundType) => $body --" diff --git a/src/test/scala/squid/ir/fastir/ExtractingTests.scala b/src/test/scala/squid/ir/fastir/ExtractingTests.scala index 5c7bd1ae..f600f0b2 100644 --- a/src/test/scala/squid/ir/fastir/ExtractingTests.scala +++ b/src/test/scala/squid/ir/fastir/ExtractingTests.scala @@ -116,6 +116,28 @@ class ExtractingTests extends MyFunSuiteBase(ExtractingTests.Embedding) { } }.isFailure) } + + test("Squid paper") { + + // 3.2 + ir"${ir"2"} + 1" match { + case ir"($n: Int) + 1" => + val res = ir"$n - 1" + assert(res =~= ir"${ir"2"} - 1") + } + + // 3.6 + assert(Try { + ir"(x: Int) => x + 1" match { + case ir"(x: Int) => $body: Int" => fail + }}.isFailure + ) + + ir"(x: Int) => x + 1" match { + case ir"(x: Int) => $f(x): Int" => assert(f =~= ir"(x: Int) => x + 1") + } + // --- + } } object ExtractingTests { diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index b4d4dc11..4a2f7503 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -95,7 +95,7 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { case ir"(x: Double) => x * 2" => ir"(x: Double) => x * 4" } val l0 = ir"(x: Double) => x * 4" - assert(a =~= ir"val a = 11.toDouble; val f = $l0; f(a)") + assert(a =~= ir"val a = 11.toDouble; val fOld = $l; val f = $l0; f(a)") } test("Rewriting should happen at all occurences") { @@ -129,33 +129,17 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { case ir"(x: Double) => x * 2" => fail } } - - /* --- Hacky solutions --- */ - test("Rewriting should flag lambdas as starting points") { - // Else we get an "Illegal ANF self argument" because the rewriting - // puts a LB in the function position of tha App f(a) - val l = ir"(x: Double) => x * 2" - val a = ir"val a = 11.toDouble; val f = $l; f(a)" rewrite { - case ir"(x: Double) => x * 2" => ir"(x: Double) => x * 4" - } - } - test("Rewriting should remove matched lambdas") { - // Else since genCode puts code at the top it will rewrite - // again and again the same lambda - val l = ir"(x: Double) => x * 2" - val a = ir"val a = 11.toDouble; val f = $l; f(a)" rewrite { - case ir"(x: Double) => x * 2" => ir"(x: Double) => x * 4" + test("Squid paper") { + + // 3.4 + val a = ir"List(1,2,3).foldLeft(0)((acc,x) => acc+x) + 4" rewrite { + case ir"($ls: List[Int]).foldLeft[Int]($init)($f)" => + ir"var cur = $init; $ls.foreach(x => cur = $f(cur, x)); cur" } + assert(a =~= + ir"val t = List(1, 2, 3); val f = ((acc: Int, x: Int) => acc+x); t.foldLeft(0)(f); var cur = 0; t.foreach(x => cur = f(cur, x)); cur + 4") } - /* ----------------------- */ - - //test("Rewriting with impures") { - // val a = ir"val a = readInt; val b = readInt; (a + b) * 0.5" rewrite { - // case ir"(($h1: Int) + ($h2: Int)) * 0.5" => dbg_ir"($h1 * $h2) + 42.0" - // } - // assert(a =~= ir"(readInt + readInt) + 42.0") - //} test("Rewriting simple expressions only once") { /* From c0386008a76bc637a77e0a137bdd4a3233a0d89d Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 7 Dec 2017 15:17:07 +0100 Subject: [PATCH 46/66] Add lambda rewriting tests --- .../squid/ir/fastir/RewritingTests.scala | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 4a2f7503..d383b040 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -89,13 +89,27 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { } test("Rewriting lambdas") { - val l = ir"(x: Double) => x * 2" - - val a = ir"val a = 11.toDouble; val f = $l; f(a)" rewrite { + val a = ir"val a = 11.toDouble; val f = (x: Double) => x * 2; f(a)" rewrite { case ir"(x: Double) => x * 2" => ir"(x: Double) => x * 4" } - val l0 = ir"(x: Double) => x * 4" - assert(a =~= ir"val a = 11.toDouble; val fOld = $l; val f = $l0; f(a)") + assert(a =~= ir"val a = 11.toDouble; val fOld = (x: Double) => x * 2; val f = (x: Double) => x * 4; f(a)") + + // Rewrite multiple lambdas + val b = ir"val a = 11.toDouble; val f1 = (x: Double) => { println(x); x * 2 }; val f2 = (x: Double) => { val t = x * 2; val p = println(x); t }; f1(a) + f2(a)" rewrite { + case ir"(x: Double) => { println(x); x * 2 }" => ir"(x: Double) => x * 4" + } + assert(b =~= ir"val a = 11.toDouble; val f1 = (x: Double) => x * 4; val f2 = (x: Double) => x * 4; f1(a) + f2(a)") + + // Rewrite nested lambda + val c = ir"val a = 11.toDouble; val f1 = (x: Double) => { val f2 = (x: Double) => { val t = x * 2; val p = println(x); t }; println(x); f2(x) * 2 }; f1(a)" rewrite { + case ir"(x: Double) => { println(x); x * 2 } " => ir"(x: Double) => x * 4" + } + assert(c =~= ir"val a = 11.toDouble; val f1 = (x: Double) => { val f2 = (x: Double) => x * 4; println(x); f2(x) * 2 }; f1(a)") + + val d = c rewrite { + case ir"(x: Double) => { println(x); val f = (x: Double) => x * 4; f(x) * 2 }" => ir"(x: Double) => x * 4" + } + assert(d =~= ir"val a = 11.toDouble; val f = (x: Double) => x * 4; f(a)") } test("Rewriting should happen at all occurences") { From e137ed10556d68f8bba026758c852c723e219314 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 7 Dec 2017 19:21:02 +0100 Subject: [PATCH 47/66] Fix issue where rewriting in the arg position of a methodApp would generate a letbinding --- src/main/scala/squid/ir/fastanf/Rep.scala | 97 ++++++++++++------- .../squid/ir/fastir/RewritingTests.scala | 16 +++ 2 files changed, 80 insertions(+), 33 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index 8c55ad94..59ba124e 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -103,48 +103,79 @@ object MethodApp { } def toANF(self: Rep, mtd: MethodSymbol, targs: List[TypeRep], argss: ArgumentLists, typ0: TypeRep)(implicit base: FastANF): Rep = { + def processArgss(argss: ArgumentLists)(f: Rep => (Option[LetBinding], Rep)): (Option[LetBinding], ArgumentLists) = { + def processArgs(args: ArgumentList)(f: Rep => (Option[LetBinding], Rep)): (Option[LetBinding], ArgumentList) = { + def go(args: ArgumentList)(k: (Option[LetBinding], ArgumentList) => (Option[LetBinding], ArgumentList)): (Option[LetBinding], ArgumentList) = args match { + case NoArguments => k(None, NoArguments) + case ArgumentCons(h, t) => + val (lb, ret) = f(h) + go(t) { case (lb0, args0) => + val newLB = (lb0, lb) match { + case (Some(lb0), Some(lb)) => lb.last.body = lb0; Some(lb) + case (_, Some(lb)) => Some(lb) + case (Some(lb0), _) => Some(lb0) + case _ => None + } + + k(newLB, ArgumentCons(ret, args0)) + } + case SplicedArgument(arg) => k.tupled(f(arg)) + case r: Rep => k.tupled(f(r)) + } + + go(args)((lb, args) => (lb, args)) + } + + def go(argss: ArgumentLists)(k: (Option[LetBinding], ArgumentLists) => (Option[LetBinding], ArgumentLists)): (Option[LetBinding], ArgumentLists) = argss match { + case NoArgumentLists => k(None, NoArgumentLists) + case NoArguments => k(None, NoArguments) + case a: ArgumentCons => processArgs(a)(f) + case s: SplicedArgument => processArgs(s)(f) + case ArgumentListCons(h, t) => + val (lb, args) = processArgs(h)(f) + go(t) { case (lb0, argss0) => + val newLB = (lb0, lb) match { + case (Some(lb0), Some(lb)) => lb.last.body = lb0; Some(lb) + case (_, Some(lb)) => Some(lb) + case (Some(lb0), _) => Some(lb0) + case _ => None + } + k(newLB, ArgumentListCons(args, argss0)) + } + case r: Rep => k.tupled(f(r)) + } + + go(argss)((lb, argss) => (lb, argss)) + } + def split(r: Rep): Option[LetBinding] -> Rep = r match { case lb: LetBinding => Some(lb) -> lb.last.body case _ => None -> r } - val self0 = self |> split - //val argss0 = argss.argssList map split // TODO loosing structure of ArgumentLists... - - def withANFSelf(argss0: ArgumentLists): Rep = self0 match { - case (Some(lb), ret) => - val innerLB = new Symbol { - protected var _parent: SymbolParent = new LetBinding("tmp", this, MethodApp(ret, mtd, targs, argss0, typ0), this) + val (lb0, self0) = self |> split + val (lbFromArgss, argss0) = processArgss(argss)(split) alsoApply println + + val newLB = (lb0, lbFromArgss) match { + case (Some(lb0), Some(lbFromArgss)) => lbFromArgss.last.body = lb0; Some(lbFromArgss) + case (_, Some(lbFromArgss)) => Some(lbFromArgss) + case (Some(lb0), _) => Some(lb0) + case _ => None + } + + val ma = MethodApp(self0, mtd, targs, argss0, typ0) + + newLB match { + case Some(lb) => + lb.last.body = new Symbol { + protected var _parent: SymbolParent = new LetBinding("tmp", this, ma, this) }.owner.asInstanceOf[LetBinding] - - lb.last.body = innerLB lb - - case (None, ret) => new Symbol { - protected var _parent: SymbolParent = new LetBinding("tmp", this, MethodApp(ret, mtd, targs, argss0, typ0), this) + + case None => new Symbol { + protected var _parent: SymbolParent = new LetBinding("tmp", this, ma, this) }.owner.asInstanceOf[LetBinding] } - - //val mergedArgss0 = argss0.foldRight((None: Option[LetBinding], List.empty[Rep])) { - // case ((Some(lb), ret), (lbAcc, argssAcc)) => - // val lbAcc0 = lbAcc match { - // case Some(lbAcc) => - // lb.last.body = lbAcc - // lb - // case None => lb - // } - // (Some(lbAcc0), ret :: argssAcc) - // - // case ((None, ret), (lbAcc, argssAcc)) => (lbAcc, ret :: argssAcc) - //} - - //val t = (None, argss) match { - // case (Some(lb), argss0) => - // lb.last.body = withANFSelf(argss)//(argss0.foldRight(NoArguments: ArgumentList)(_ ~: _)) - // lb - // case (None, argss0) => withANFSelf(argss)//(argss0.foldRight(NoArguments: ArgumentList)(_ ~: _)) - //} - withANFSelf(argss) } } final case class SimpleMethodApp protected(self: Rep, mtd: MethodSymbol, targs: List[TypeRep], argss: ArgumentLists)(val typ: TypeRep) diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index d383b040..080c7ac5 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -144,6 +144,22 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { } } + test("Rewriting should generate ANF valid code") { + val a = ir"List(1,2,3).foldLeft(0)((acc,x) => acc+x) + 4" rewrite { + case ir"($ls: List[Int]).foldLeft[Int]($init)($f)" => + ir"var cur = $init; $ls.foreach(x => cur = $f(cur, x)); cur" + } + assert(a =~= + ir"val t = List(1, 2, 3); val f = ((acc: Int, x: Int) => acc+x); t.foldLeft(0)(f); var cur = 0; t.foreach(x => cur = f(cur, x)); cur + 4") + + val b = ir"4 + List(1,2,3).foldLeft(0)((acc,x) => acc+x)" rewrite { + case ir"($ls: List[Int]).foldLeft[Int]($init)($f)" => + ir"var cur = $init; $ls.foreach(x => cur = $f(cur, x)); cur" + } + assert(b =~= + ir"val t = List(1, 2, 3); val f = ((acc: Int, x: Int) => acc+x); t.foldLeft(0)(f); var cur = 0; t.foreach(x => cur = f(cur, x)); 4 + cur") + } + test("Squid paper") { // 3.4 From 7b4e0c51a8b1b8876fa149b507d1d5f6f995b8a7 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 10 Dec 2017 16:56:44 +0100 Subject: [PATCH 48/66] Fix issue where HOPHoles would disregard pureness and add ByName wrapper --- src/main/scala/squid/ir/fastanf/Effects.scala | 2 + src/main/scala/squid/ir/fastanf/FastANF.scala | 203 ++++++++++-------- src/main/scala/squid/ir/fastanf/Rep.scala | 4 + .../fastir/HigherOrderPatternVariables.scala | 10 + .../squid/ir/fastir/RewritingTests.scala | 16 ++ 5 files changed, 148 insertions(+), 87 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index 8240854a..53bd42ff 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -25,6 +25,8 @@ trait Effects { //} case _: Symbol => Pure + case ByName(r) => effect(r) + case Ascribe(r, _) => effect(r) case Module(r, _, _) => effect(r) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 2ce3e731..1325155c 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -118,7 +118,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => MethodApp(self |> inlineBlock, mtd, targs, argss |> toArgumentLists, tp) |> letbind } - def byName(mkArg: => Rep): Rep = wrapNest(mkArg) + def byName(mkArg: => Rep): Rep = ByName(wrapNest(mkArg)) def letbind(d: Def): Rep = currentScope += d def inlineBlock(r: Rep): Rep = r |>=? { @@ -276,6 +276,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi reinterpret0(lb.body), reinterpretType(lb.typ)) case s: Symbol => extrudedHandle(s) + case ByName(r) => newBase.byName(reinterpret0(r)) } } @@ -310,6 +311,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi lb case Left(r) => LetBinding.withRepValue(lb.name, lb.bound, r, lb.body |> transformRep0) } + case ByName(r) => ByName(transformRep0(r)) case Ascribe(s, t) => Ascribe(transformRep0(s), t) case Module(p, n, t) => Module(transformRep0(p), n, t) case r @ ((_:Constant) | (_:Hole) | (_:Symbol) | (_:SplicedHole) | (_:HOPHole) | (_:HOPHole2) | (_:NewObject) | (_:StaticModule)) => r @@ -330,12 +332,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi type ExtractState = Either[State, State] implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right - case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], - failedMatches: Map[BoundVal, Set[BoundVal]], done: State => Boolean, partialMatching: Boolean) { - def isDone: Boolean = done(this) - private val _done = done - - private val _partialMatching = partialMatching + sealed trait Objective + case object PartialMatching extends Objective + case object CompleteMatching extends Objective + case class Lookup(done: State => Boolean) extends Objective + + case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], + failedMatches: Map[BoundVal, Set[BoundVal]], objective: Objective) { + private val _objective = objective def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) @@ -347,22 +351,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => this // Everything else is pure so we ignore it } def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) - def withPartialMatching: State = copy(partialMatching = true) - def withCompleteMatching: State = copy(partialMatching = false) - def withDefaultEarlyReturn: State = copy(partialMatching = _partialMatching) - def withTerminationPredicate(done: State => Boolean): State = copy(done = done) - def withDefaultTerminationPredicate: State = copy(done = _done) + def withObjective(b: Objective): State = copy(objective = b) + def withDefaultObjective: State = copy(objective = _objective) + def withXtorFlagsOf(xtor: Rep): State = copy(flags = flags.copy(xtor = Flags.genFlags(xtor))) def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) } } object State { - def forExtraction(xtor: Rep, xtee: Rep, done: State => Boolean = _ => false): State = apply(xtor, xtee, done, false) - def forRewriting(xtor: Rep, xtee: Rep, done: State => Boolean = _ => false): State = apply(xtor, xtee, done, true) + def forExtraction(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, CompleteMatching) + def forRewriting(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, PartialMatching) - private def apply(xtor: Rep, xtee: Rep, done: State => Boolean, partialMatching: Boolean): State = - State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), done, partialMatching) + private def apply(xtor: Rep, xtee: Rep, objective: Objective): State = + State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), objective) } sealed trait Flag @@ -378,7 +380,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi object Flags { def apply(xtor: Rep, xtee: Rep): Flags = Flags(genFlags(xtor), genFlags(xtee)) - private def genFlags(r: Rep): Set[BoundVal] = { + def genFlags(r: Rep): Set[BoundVal] = { def update(d: Def, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = d match { case l: Lambda => genFlags0(l.body, unusedBVs, impures) @@ -404,6 +406,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case bv: BoundVal => (unusedBVs - bv, impures) + case ByName(r) => genFlags0(r, unusedBVs, impures) + case HOPHole2(_, _, argss, _) => val updatedImpures = argss.flatten.foldLeft(impures) { case (impures, arg) => arg match { @@ -490,8 +494,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } getOrElse body -> es } - replaceAllOccurences0(r, body)(es withTerminationPredicate done withPartialMatching) - .mapSecond(_.withDefaultTerminationPredicate.withDefaultEarlyReturn) + replaceAllOccurences0(r, body)(es withXtorFlagsOf lb withObjective Lookup(done)) + .mapSecond(_.withXtorFlagsOf(xtor).withDefaultObjective) } replaceAllOccurences(lb, body)(es) @@ -509,7 +513,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val oe = for { es1 <- typ extract (xtee.typ, Covariant) m <- merge(es.ex, es1) - (f, es2) <- argss.foldRight(Option(xtee -> (es withNewExtract m)))(extendFunc) + argss0 = argss.map(_.map { + case ByName(r) => r + case r => r + }) + (f, es2) <- argss0.foldRight(Option(xtee -> (es withNewExtract m)))(extendFunc) if !hasUndeclaredBVs(f) } yield es2 updateExtractWith Some(repExtract(name -> f)) @@ -571,87 +579,108 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] println(s"-----\nextractWithState: \n$xtor\n\n$xtee\n-----\n\n") - if (es.isDone) Right(es) - else xtor -> xtee match { - case (h: Hole, lb: LetBinding) => contentsOf(h) match { - case Some(lb1: LetBinding) if lb1.value == lb.value => Right(es) - case Some(_) => Left(es) - case None => extractHole(h, xtee) - } + + es.objective match { + case Lookup(done) if done(es) => Right(es) + + case _ => xtor -> xtee match { + case (h: Hole, lb: LetBinding) => contentsOf(h) match { + case Some(lb1: LetBinding) if lb1.value == lb.value => Right(es) + case Some(_) => Left(es) + case None => extractHole(h, xtee) + } - case (h: Hole, _) => contentsOf(h) match { - case Some(`xtee`) => Right(es) - case Some(_) => Left(es) - case None => extractHole(h, xtee) - } + case (h: Hole, _) => contentsOf(h) match { + case Some(`xtee`) => Right(es) + case Some(_) => Left(es) + case None => extractHole(h, xtee) + } - case (HOPHole2(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) - - case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) + case (HOPHole2(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) - case (lb: LetBinding, _: Rep) => es.flags.xtorFlag(lb.bound) match { - case Start => Left(es) - case Skip => extractWithState(lb.body, xtee) - } - - case (bv: BoundVal, lb: LetBinding) => - if (es.ctx.keySet contains bv) Right(es) - else if (es.partialMatching) for { - es1 <- extractWithState(bv, lb.bound).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left - } yield es3 - else es.flags.xteeFlag(lb.bound) match { - case Skip => extractWithState(bv, lb.body) + case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) + + case (lb: LetBinding, _: Rep) => es.flags.xtorFlag(lb.bound) match { case Start => Left(es) + case Skip => extractWithState(lb.body, xtee) } - case (_: Rep, lb: LetBinding) if (es.matchedImpureBVs contains lb.bound) || es.partialMatching => extractWithState(xtor, lb.body) - - case (_, Ascribe(s, _)) => extractWithState(xtor, s) - - case (Ascribe(s, t), _) => for { - es1 <- es.updateExtractWith(t extract(xtee.typ, Covariant)) - es2 <- extractWithState(s, xtee)(es1) - } yield es2 + case (bv: BoundVal, lb: LetBinding) => + if (es.ctx.keySet contains bv) Right(es) + else es.objective match { + case CompleteMatching => es.flags.xteeFlag(lb.bound) match { + case Skip => extractWithState(bv, lb.body) + case Start => Left(es) + } - case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) + case PartialMatching => es.flags.xteeFlag(lb.bound) match { + case Skip => for { + es1 <- extractWithState(bv, lb.bound).left + es2 <- extractInside(bv, lb.value)(es1).left + es3 <- extractWithState(bv, lb.body)(es2).left + } yield es3 + case Start => Left(es) + } - case (bv1: BoundVal, bv2: BoundVal) => - println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") - println(s"STATE: $es") - - es.ctx.get(bv1) map { bv => - if (bv != bv2) Left(es) - else Right(es) - } getOrElse { - if (bv1 == bv2) Right(es) - else if (es.failedMatches(bv1) contains bv2) Left(es) - else (bv1.owner, bv2.owner) match { - case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { - case Right(es) => effect(lb2.value) match { - case Pure => Right(es withCtx lb1.bound -> lb2.bound) - case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) + case Lookup(_) => for { + es1 <- extractWithState(bv, lb.bound).left + es2 <- extractInside(bv, lb.value)(es1).left + es3 <- extractWithState(bv, lb.body)(es2).left + } yield es3 + } + + case (_: Rep, lb: LetBinding) if es.matchedImpureBVs contains lb.bound => es.objective match { + case PartialMatching | Lookup(_) => extractWithState(xtor, lb.body) + case CompleteMatching => Left(es) + } + + case (ByName(r1), ByName(r2)) => extractWithState(r1, r2) + + case (_, Ascribe(s, _)) => extractWithState(xtor, s) + + case (Ascribe(s, t), _) => for { + es1 <- es.updateExtractWith(t extract(xtee.typ, Covariant)) + es2 <- extractWithState(s, xtee)(es1) + } yield es2 + + case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) + + case (bv1: BoundVal, bv2: BoundVal) => + println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") + println(s"STATE: $es") + + es.ctx.get(bv1) map { bv => + if (bv != bv2) Left(es) + else Right(es) + } getOrElse { + if (bv1 == bv2) Right(es) + else if (es.failedMatches(bv1) contains bv2) Left(es) + else (bv1.owner, bv2.owner) match { + case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { + case Right(es) => effect(lb2.value) match { + case Pure => Right(es withCtx lb1.bound -> lb2.bound) + case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) + } + case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) } - case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) + case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? + case _ => Left(es withFailed bv1 -> bv2) } - case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? - case _ => Left(es withFailed bv1 -> bv2) } - } - case (Constant(v1), Constant(v2)) if v1 == v2 => es updateExtractWith (xtor.typ extract(xtee.typ, Covariant)) + case (Constant(v1), Constant(v2)) if v1 == v2 => es updateExtractWith (xtor.typ extract(xtee.typ, Covariant)) - // Assuming if they have the same name the type is the same - case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) + // Assuming if they have the same name the type is the same + case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) - // Assuming if they have the same name and prefix the type is the same - case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - - case (NewObject(t1), NewObject(t2)) => es updateExtractWith (t1 extract(t2, Covariant)) + // Assuming if they have the same name and prefix the type is the same + case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - case _ => Left(es) + case (NewObject(t1), NewObject(t2)) => es updateExtractWith (t1 extract(t2, Covariant)) + + case _ => Left(es) } + } } alsoApply (res => println(s"Extract: $res")) protected def spliceExtract(xtor: Rep, args: Args): Option[Extract] = xtor match { @@ -679,8 +708,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (l1: Lambda, l2: Lambda) => for { es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) - es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound withCompleteMatching) - } yield es2 withDefaultEarlyReturn + es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound withObjective CompleteMatching) + } yield es2 withDefaultObjective case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => def targExtract(es0: State): ExtractState = diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index 59ba124e..6adf87dc 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -48,6 +48,10 @@ sealed abstract class Rep extends RepOption with ArgumentList with FlatSom[Rep] def argssList: List[Rep] = this :: Nil } +final case class ByName(r: Rep) extends Rep { + val typ: TypeRep = r.typ +} + final case class Constant(value: Any) extends Rep with CachedHashCode { // TODO impl and rm lazy lazy val typ = value match { diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala index e4a643b3..7ba5d366 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala @@ -152,6 +152,16 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria case ir"($h(readInt + 1): Int)" => assert(h =~= ir"(x: Int) => x + x") } } + + test("HOPHoles should correctly match impure statements") { + ir"val a = readInt; val b = readDouble; a + b" match { + case ir"($h(readInt + readDouble): Double)" => assert(h =~= ir"(x: Int) => x") + } + + ir"val a = readInt; val b = readDouble; b + a" match { + case ir"($h(readInt + readDouble): Double)" => assert(h =~= ir"(x: Int) => { val a = readInt; val b = readDouble; b + a }") + } + } } object HigherOrderPatternVariables { diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 080c7ac5..d25049fb 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -160,6 +160,20 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { ir"val t = List(1, 2, 3); val f = ((acc: Int, x: Int) => acc+x); t.foldLeft(0)(f); var cur = 0; t.foreach(x => cur = f(cur, x)); 4 + cur") } + test("Rewriting a by-name argument rewrites inside it") { + import RewritingTests.f + + val a = ir"f({println(10.toDouble); 10.toDouble})" rewrite { + case ir"10.toDouble" => ir"42.toDouble" + } + assert(a =~= ir"f({val a = 10.toDouble; println(42.toDouble); val b = 10.toDouble; 42.toDouble})") + + val b = ir"val a = 10.toDouble; f(a)" rewrite { + case ir"10.toDouble" => ir"42.toDouble" + } + assert(b =~= ir"val a = 10.toDouble; f(42.toDouble)") + } + test("Squid paper") { // 3.4 @@ -203,4 +217,6 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { object RewritingTests { object Embedding extends FastANF + + def f(x: => Double) = x } From 21886100e303d7ee96dc1464116ef5e85a9e4fd4 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 11 Dec 2017 23:05:37 +0100 Subject: [PATCH 49/66] Documentation and cleanup --- src/main/scala/squid/ir/fastanf/FastANF.scala | 729 ++++++++++-------- .../squid/ir/fastir/RewritingTests.scala | 7 +- 2 files changed, 401 insertions(+), 335 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 1325155c..88e551c5 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -1,9 +1,9 @@ package squid package ir.fastanf -import utils._ -import lang.{Base, InspectableBase, ScalaCore} import squid.ir._ +import squid.lang.{Base, InspectableBase, ScalaCore} +import squid.utils._ import scala.collection.immutable.{ListMap, ListSet} @@ -123,68 +123,62 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def letbind(d: Def): Rep = currentScope += d def inlineBlock(r: Rep): Rep = r |>=? { case lb: LetBinding => - println(s"INLINE: $lb --> $scopes") currentScope += lb - println(s"$scopes") inlineBlock(lb.body) } - override def letin(bound: BoundVal, value: Rep, body: => Rep, bodyType: TypeRep): Rep = { - println(s"letin: $value --> $bound") - value match { - case s: Symbol => - s.owner |>? { - case lb: RebindableBinding => - //println(s"LETIN $lb ") - lb.name = bound.name - } - s.owner |>? { - case lb: LetBinding => - lb.isUserDefined = true - } - withSubs(bound, value)(body) - - //s.owner |>? { - // case lb: RebindableBinding => - // lb.name = bound.name - //} - //bound rebind s - //body - case lb: LetBinding => - // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding - val last = lb.last - val boundName = bound.name - bound rebind last.bound - last.body = body - last.name = boundName // TODO make sure we're only renaming an automatically-named binding? - lb - // case c: Constant => bottomUpPartial(body) { case `bound` => c } - case h: Hole => - //Wrap construct? How? - - // letin(x, Hole, Constant(20)) => `val tmp = defHole; 20;` - - val dh = DefHole(h) |> letbind - - //(dh |>? { - // case bv: BoundVal => bv.owner |>? { - // case lb: LetBinding => - // lb.body = body - // lb - // } - //}).flatten.getOrElse(body) - - //new LetBinding(bound.name, bound, dh, body) alsoApply (currentScope += _) alsoApply (bound.rebind) - withSubs(bound -> dh)(body) - - - case (_:HOPHole) | (_:HOPHole2) | (_:SplicedHole) => - ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes - case _ => - withSubs(bound -> value)(body) - // ^ executing `body` will reify some statements into the reification scope, and likely return a symbol - // during this reification, we need all references to `bound` to be replaced by the actual `value` - } + override def letin(bound: BoundVal, value: Rep, body: => Rep, bodyType: TypeRep): Rep = value match { + case s: Symbol => + s.owner |>? { + case lb: RebindableBinding => + lb.name = bound.name + } + s.owner |>? { + case lb: LetBinding => + lb.isUserDefined = true + } + withSubs(bound, value)(body) + + //s.owner |>? { + // case lb: RebindableBinding => + // lb.name = bound.name + //} + //bound rebind s + //body + case lb: LetBinding => + // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding + val last = lb.last + val boundName = bound.name + bound rebind last.bound + last.body = body + last.name = boundName // TODO make sure we're only renaming an automatically-named binding? + lb + // case c: Constant => bottomUpPartial(body) { case `bound` => c } + case h: Hole => + //Wrap construct? How? + + // letin(x, Hole, Constant(20)) => `val tmp = defHole; 20;` + + val dh = DefHole(h) |> letbind + + //(dh |>? { + // case bv: BoundVal => bv.owner |>? { + // case lb: LetBinding => + // lb.body = body + // lb + // } + //}).flatten.getOrElse(body) + + //new LetBinding(bound.name, bound, dh, body) alsoApply (currentScope += _) alsoApply (bound.rebind) + withSubs(bound -> dh)(body) + + + case (_:HOPHole) | (_:HOPHole2) | (_:SplicedHole) => + ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes + case _ => + withSubs(bound -> value)(body) + // ^ executing `body` will reify some statements into the reification scope, and likely return a symbol + // during this reification, we need all references to `bound` to be replaced by the actual `value` } var curSub: Map[Symbol,Rep] = Map.empty @@ -194,15 +188,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi try k finally curSub = oldSub } - override def tryInline(fun: Rep, arg: Rep)(retTp: TypeRep): Rep = { - println(s"tryInline $fun -- $arg") - fun match { - case lb: LetBinding => lb.value match { - case l: Lambda => letin(l.bound, arg, l.body, l.body.typ) - case _ => super.tryInline(fun, arg)(retTp) - } + override def tryInline(fun: Rep, arg: Rep)(retTp: TypeRep): Rep = fun match { + case lb: LetBinding => lb.value match { + case l: Lambda => letin(l.bound, arg, l.body, l.body.typ) case _ => super.tryInline(fun, arg)(retTp) } + case _ => super.tryInline(fun, arg)(retTp) } override def ascribe(self: Rep, typ: TypeRep): Rep = if (self.typ =:= typ) self else self match { @@ -329,60 +320,122 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (k, v) => m + (k -> (m(k) + v)) } + // * --- * --- * --- * Extraction State * --- * --- * --- * + + /** + * Signals if the current extraction attempt has failed. `Left` it has failed, `Right` it succeeded. + */ type ExtractState = Either[State, State] implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right - - sealed trait Objective - case object PartialMatching extends Objective - case object CompleteMatching extends Objective - case class Lookup(done: State => Boolean) extends Objective - - case class State(ex: Extract, ctx: Ctx, flags: Flags, matchedImpureBVs: Set[BoundVal], - failedMatches: Map[BoundVal, Set[BoundVal]], objective: Objective) { - private val _objective = objective + + /** + * Represents the current of the state of the extraction. + * @param ex what has been extracted by holes. + * @param ctx discovered matchings between bound values in the `xtor` and the `xtee`. + * @param instructions + * @param matchedImpureBVs + * @param failedMatches + * @param strategy + */ + case class State(ex: Extract, ctx: Ctx, instructions: Instructions, matchedImpureBVs: Set[BoundVal], + failedMatches: Map[BoundVal, Set[BoundVal]], strategy: Strategy) { + private val _strategy = strategy def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) def withCtx(p: (BoundVal, BoundVal)): State = copy(ctx = ctx + p) + + /** + * Adds all the BVs referencing impure statements in `r` to `matchedImpureBVs`. + */ def withMatchedImpures(r: Rep): State = r match { case lb: LetBinding if !isPure(lb.value) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body case lb: LetBinding => this withMatchedImpures lb.body case bv: BoundVal => copy(matchedImpureBVs = matchedImpureBVs + bv) case _ => this // Everything else is pure so we ignore it } - def withFailed(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) - def withObjective(b: Objective): State = copy(objective = b) - def withDefaultObjective: State = copy(objective = _objective) - def withXtorFlagsOf(xtor: Rep): State = copy(flags = flags.copy(xtor = Flags.genFlags(xtor))) + + def withFailedMatch(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) + def withStrategy(s: Strategy): State = copy(strategy = s) + def withDefaultStrategy: State = copy(strategy = _strategy) + def withInstructionsFor(r: Rep): State = copy(instructions = instructions.copy(flags = instructions.flags ++ Instructions.gen(r))) def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) } } + object State { def forExtraction(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, CompleteMatching) def forRewriting(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, PartialMatching) - private def apply(xtor: Rep, xtee: Rep, objective: Objective): State = - State(EmptyExtract, ListMap.empty, Flags(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), objective) + private def apply(xtor: Rep, xtee: Rep, strategy: Strategy): State = + State(EmptyExtract, ListMap.empty, Instructions(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), strategy) } - sealed trait Flag - case object Start extends Flag - case object Skip extends Flag + + + + // * --- * --- * --- * Strategy * --- * --- * --- * + + /** + * Specifies the semantics of the extraction. + */ + sealed trait Strategy - case class Flags(xtor: Set[BoundVal], xtee: Set[BoundVal]) { - private def flags(ls: Set[BoundVal])(bv: BoundVal) = if (ls contains bv) Start else Skip - def xtorFlag(bv: BoundVal): Flag = flags(xtor)(bv) - def xteeFlag(bv: BoundVal): Flag = flags(xtee)(bv) + /** + * Allows the return value of the `xtor` to match anywhere in `xtee`. + * This will potentially leave some of the statement of the `xtee` unmatched. + * For instance, this [[Strategy]] is necessary for rewriting as we may only be rewriting parts of the `xtee`. + */ + case object PartialMatching extends Strategy + + // TODO enforces? + /** + * Enforces the fact that `xtor` has to fully match the `xtee`. + * For instance, this [[Strategy]] is necessary for extraction as + * the pattern `xtor` has to match the entire `xtee`. + */ + case object CompleteMatching extends Strategy + + + + + // * --- * --- * --- * Instructions * --- * --- * --- * + + /** + * Specifies what the extraction should do. + */ + sealed trait Instruction + + /** + * Instructs the extraction has to look for a matching statements. + * This instruction is attached to impure statements as well as pure statements + * that are not used by other statements. + */ + case object Start extends Instruction + + /** + * Instructs the extraction can ignore this statement. + * This instruction is attached to all pure statements that are used by the return value. + */ + case object Skip extends Instruction + + + case class Instructions(flags: Set[BoundVal]) { + def get(bv: BoundVal): Instruction = if (flags contains bv) Start else Skip } - object Flags { - def apply(xtor: Rep, xtee: Rep): Flags = Flags(genFlags(xtor), genFlags(xtee)) - - def genFlags(r: Rep): Set[BoundVal] = { + object Instructions { + def apply(xtor: Rep, xtee: Rep): Instructions = Instructions(gen(xtor) ++ gen(xtee)) + + /** + * Generates the instructions that will be flagged with [[Start]]. + * These will include the impure statements and the unused statements. + */ + def gen(r: Rep): Set[BoundVal] = { def update(d: Def, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = d match { - case l: Lambda => genFlags0(l.body, unusedBVs, impures) + case l: Lambda => genInstructions0(l.body, unusedBVs, impures) case ma: MethodApp => ((ma.self :: ma.argss.argssList).foldLeft(unusedBVs) { case (acc, bv: BoundVal) => acc - bv @@ -392,7 +445,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => (unusedBVs, impures) } - def genFlags0(r: Rep, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = r match { + /* + * The unused BVs are kept separate from the impures even thought both will be merged at the end in order to be + * able to keep track of the unused statements when recursing in `r`. + */ + def genInstructions0(r: Rep, unusedBVs: Set[BoundVal], impures: Set[BoundVal]): (Set[BoundVal], Set[BoundVal]) = r match { case lb: LetBinding => val updated = update( lb.value, @@ -402,26 +459,17 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case Impure => impures + lb.bound } ) - genFlags0(lb.body, updated._1, updated._2) + genInstructions0(lb.body, updated._1, updated._2) case bv: BoundVal => (unusedBVs - bv, impures) - case ByName(r) => genFlags0(r, unusedBVs, impures) - - case HOPHole2(_, _, argss, _) => - val updatedImpures = argss.flatten.foldLeft(impures) { - case (impures, arg) => arg match { - case lb: LetBinding if !isPure(lb.value) => impures + lb.bound - case _ => impures - } - } - (unusedBVs, updatedImpures) - + case ByName(r) => genInstructions0(r, unusedBVs, impures) + case _ => (unusedBVs, impures) } - val flags = genFlags0(r, Set.empty, Set.empty) - flags._1 ++ flags._2 + val instructions = genInstructions0(r, Set.empty, Set.empty) + instructions._1 ++ instructions._2 } } @@ -431,30 +479,42 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { println("EXTRACTINGHOPHOLE") - def hasUndeclaredBVs(r: Rep): Boolean = { - println(s"Checking $r") - def hasUndeclaredBVs0(r: Rep, declared: Set[BoundVal]): Boolean = r match { + def usesUndeclaredBVs(r: Rep): Boolean = { + def usesUndeclaredBVs0(r: Rep, declared: Set[BoundVal]): Boolean = r match { case bv: BoundVal => !(declared contains bv) case lb: LetBinding => val declared0 = declared + lb.bound - hasUndeclaredBVsinDef(lb.value, declared0) || hasUndeclaredBVs0(lb.body, declared0) + defUsesUndeclaredBVs(lb.value, declared0) || usesUndeclaredBVs0(lb.body, declared0) case _ => false } - def hasUndeclaredBVsinDef(d: Def, declared: Set[BoundVal]): Boolean = d match { - case l: Lambda => hasUndeclaredBVs0(l.body, declared + l.bound) - case ma: MethodApp => (ma.self +: ma.argss.argssList) exists (hasUndeclaredBVs0(_, declared)) + def defUsesUndeclaredBVs(d: Def, declared: Set[BoundVal]): Boolean = d match { + case l: Lambda => usesUndeclaredBVs0(l.body, declared + l.bound) + case ma: MethodApp => (ma.self +: ma.argss.argssList) exists (usesUndeclaredBVs0(_, declared)) case _ => false } - hasUndeclaredBVs0(r, Set.empty) + usesUndeclaredBVs0(r, Set.empty) } - def extendFunc(args: List[Rep], maybeCurrFuncAndState: Option[(Rep, State)]): Option[(Rep, State)] = { - val hopArgs = args.map(arg => bindVal("hopArg", arg.typ, Nil)) - val transformations = args zip hopArgs - - def body(r: Rep): Rep = { + /** + * Attemps to find the `xtors` in the body of the function and replaces them with newly generated arguments, + * adding the new arguments to the function. Even if the `xtors` are not found in the body, + * arguments representing them will be generated and added to it. They will simply be ignored when it is applied. + * + * @param xtors + * @param maybeFuncAndState the current function and the extraction state after its creation. + * @return + */ + def extendFunc(xtors: List[Rep], maybeFuncAndState: Option[(Rep, State)]): Option[(Rep, State)] = { + val args = xtors.map(arg => bindVal("hopArg", arg.typ, Nil)) + val transformations = xtors zip args + + /** + * Returns the body of function. + * This is the body of the most deeply nested [[Lambda]] as the function is curried. + */ + def body(func: Rep): Rep = { def body0(d: Def): Option[Rep] = d match { case l: Lambda => l.body match { case lb: LetBinding => Some(body(lb)) @@ -463,54 +523,62 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => None } - - r match { + func match { case lb: LetBinding => body0(lb.value) getOrElse lb - case _ => r + case _ => func } } for { - (f, es) <- maybeCurrFuncAndState alsoApply println + (f, es) <- maybeFuncAndState newBodyAndState = transformations.foldLeft(body(f) -> es) { - case ((body, es), (arg, hopArg)) => arg match { - case bv: BoundVal => + case ((body, es), (xtor, arg)) => xtor match { + case bv: BoundVal => + /* + * Assumes the bv is already in the context. + * This is enforced by how the instructions are chosen. + */ val replace = es.ctx(bv) - replace rebind hopArg + replace rebind arg body -> es - case lb: LetBinding => - val lbBVs = bvs(lb) - val done: State => Boolean = s => lbBVs forall (s.ctx.keySet contains _) // TODO Keep set of matched BVs in state? - - def replaceAllOccurences(r: Rep, body: Rep)(es: State): Rep -> State = { - def replaceAllOccurences0(r: Rep, body: Rep)(implicit es: State): Rep -> State = { + case lb: LetBinding => + + /** + * Replaces all occurrences of `lb` in the body by `arg`. + */ + def replaceAllOccurrences(body: Rep)(es: State): Rep -> State = { + def replaceAllOccurrences0(body: Rep)(implicit es: State): Rep -> State = { + + /* + * Extracts the function body with the xtor in order to be able to use the context `ctx` to know + * what to replace with the new argument. + */ extractWithState(lb, body) map { es0 => - println(s"------------------") val replace = es0.ctx(lb.last.bound) - val body0 = bottomUpPartial(filterLBs(body)(es0.ctx.values.toSet contains _.bound)) { case `replace` => hopArg } - replaceAllOccurences0(r, body0)._1 -> es0 + val body0 = bottomUpPartial(filterLBs(body)(es0.ctx.values.toSet contains _.bound)) { case `replace` => arg } + replaceAllOccurrences0(body0) } getOrElse body -> es } - replaceAllOccurences0(r, body)(es withXtorFlagsOf lb withObjective Lookup(done)) - .mapSecond(_.withXtorFlagsOf(xtor).withDefaultObjective) + replaceAllOccurrences0(body)(es withInstructionsFor lb withStrategy PartialMatching) + .mapSecond(_ withInstructionsFor xtor withDefaultStrategy) } - replaceAllOccurences(lb, body)(es) + replaceAllOccurrences(body)(es) - case _ => bottomUpPartial(body) { case `arg` => hopArg } -> es + case _ => bottomUpPartial(body) { case `xtor` => arg } -> es } } - _ = bottomUpPartial(newBodyAndState._1) { case bv: BoundVal if visible contains bv => return None } // TODO is too early to check? If there are more args left + _ = bottomUpPartial(newBodyAndState._1) { case bv: BoundVal if visible contains bv => return None } // TODO is too early to check? If there are more xtors left } yield newBodyAndState match { - case (func0, es0) => wrapConstruct(lambda(hopArgs, func0)) -> es0 + case (func0, es0) => wrapConstruct(lambda(args, func0)) -> es0 } } - val oe = for { + val maybeES = for { es1 <- typ extract (xtee.typ, Covariant) m <- merge(es.ex, es1) argss0 = argss.map(_.map { @@ -518,10 +586,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case r => r }) (f, es2) <- argss0.foldRight(Option(xtee -> (es withNewExtract m)))(extendFunc) - if !hasUndeclaredBVs(f) + if !usesUndeclaredBVs(f) } yield es2 updateExtractWith Some(repExtract(name -> f)) - oe getOrElse Left(es) + maybeES getOrElse Left(es) } def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = { @@ -530,9 +598,18 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi es2 <- extractWithState(lb1.body, lb2.body)(es1) } yield es2 - (es.flags.xtorFlag(lb1.bound), es.flags.xteeFlag(lb2.bound)) match { + (es.instructions.get(lb1.bound), es.instructions.get(lb2.bound)) match { case (Start, Start) => extractAndContinue(lb1, lb2) + /** + * In the following case `aX` has to be extracted since when extracting the + * HOPHole `h` the BV has to be matched before it (see reason in [[extractHOPHole]]. + * + * XTOR: ir"val aX = 10.toDouble; $h(aX)" + * \-------Start------/ + * XTEE: ir"val a = 10.toDouble; a + 1" + * \------Skip------/ + */ case (Start, Skip) => for { es1 <- extractAndContinue(lb1, lb2).left es2 <- extractWithState(lb1, lb2.body)(es1).left @@ -544,143 +621,147 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = { - println(s"ExtractHole: $h --> $r") - - (h, r) match { - case (Hole(n, t), bv: BoundVal) => - es.updateExtractWith( - t extract(xtee.typ, Covariant), - Some(repExtract(n -> bv)) - ) - - case (Hole(n, t), lb: LetBinding) => - es.updateExtractWith( - t extract(lb.typ, Covariant), - Some(repExtract(n -> lb)) - ) map (_ withMatchedImpures lb) - - case (Hole(n, t), _) => - es.updateExtractWith( - t extract(xtee.typ, Covariant), - Some(repExtract(n -> xtee)) - ) - } + def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = (h, r) match { + case (Hole(n, t), bv: BoundVal) => + es.updateExtractWith( + t extract(xtee.typ, Covariant), + Some(repExtract(n -> bv)) + ) + + case (Hole(n, t), lb: LetBinding) => + es.updateExtractWith( + t extract(lb.typ, Covariant), + Some(repExtract(n -> lb)) + ) map (_ withMatchedImpures lb) + + case (Hole(n, t), _) => + es.updateExtractWith( + t extract(xtee.typ, Covariant), + Some(repExtract(n -> xtee)) + ) } - def extractInside(bv: BoundVal, d: Def)(implicit es: State): ExtractState = - bvs(d).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => + /** + * Attemps to extract `bv` by trying to match the BVs used in the [[MethodApp]] from left to right. + * It won't extract inside the [[ByName]] arguments. + */ + def extractInside(bv: BoundVal, ma: MethodApp)(implicit es: State): ExtractState = + collectBVs(ma).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => for { es1 <- acc.left es2 <- extractWithState(bv, bv2)(es1).left } yield es2 } - def contentsOf(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name // TODO check in ex._3, return Option[List[Rep]] + def extractedBy(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name println(s"-----\nextractWithState: \n$xtor\n\n$xtee\n-----\n\n") - es.objective match { - case Lookup(done) if done(es) => Right(es) - - case _ => xtor -> xtee match { - case (h: Hole, lb: LetBinding) => contentsOf(h) match { - case Some(lb1: LetBinding) if lb1.value == lb.value => Right(es) - case Some(_) => Left(es) - case None => extractHole(h, xtee) - } - - case (h: Hole, _) => contentsOf(h) match { - case Some(`xtee`) => Right(es) - case Some(_) => Left(es) - case None => extractHole(h, xtee) - } - - case (HOPHole2(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) - - case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) + xtor -> xtee match { + case (h: Hole, _) => extractedBy(h) match { + case Some(`xtee`) => Right(es) + case Some(_) => Left(es) // Something has gone wrong + case None => extractHole(h, xtee) + } - case (lb: LetBinding, _: Rep) => es.flags.xtorFlag(lb.bound) match { - case Start => Left(es) - case Skip => extractWithState(lb.body, xtee) - } + case (HOPHole2(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) - case (bv: BoundVal, lb: LetBinding) => - if (es.ctx.keySet contains bv) Right(es) - else es.objective match { - case CompleteMatching => es.flags.xteeFlag(lb.bound) match { - case Skip => extractWithState(bv, lb.body) - case Start => Left(es) - } + case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) - case PartialMatching => es.flags.xteeFlag(lb.bound) match { - case Skip => for { - es1 <- extractWithState(bv, lb.bound).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left - } yield es3 - case Start => Left(es) - } + case (lb: LetBinding, _) => es.instructions.get(lb.bound) match { + case Start => Left(es) // The `xtor` has more impure statements than the `xtee`. + case Skip => extractWithState(lb.body, xtee) + } - case Lookup(_) => for { - es1 <- extractWithState(bv, lb.bound).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left - } yield es3 + case (bv: BoundVal, lb: LetBinding) => + if (es.ctx.keySet contains bv) Right(es) // `bv` has already been extracted + else es.strategy match { + case CompleteMatching => es.instructions.get(lb.bound) match { + + // The `xtee` has more impure statements or dead-ends than the `xtor` + case Start => Left(es) + + // The skipped statements of the `xtee` will be matched later + // when extracting its return value or the `Start` ones. + case Skip => extractWithState(bv, lb.body) } - case (_: Rep, lb: LetBinding) if es.matchedImpureBVs contains lb.bound => es.objective match { - case PartialMatching | Lookup(_) => extractWithState(xtor, lb.body) - case CompleteMatching => Left(es) + // Attempts to extract the return value `bv` by trying + // 1. The current let-binding + // 2. The BVs of the value of the let-binding if it's a [[MethodApp]] (see `extractInside`) + // 3. (1. and then 2.) on the next statement + case PartialMatching => for { + es1 <- extractWithState(bv, lb.bound).left + es2 <- (lb.value match { + case ma: MethodApp => extractInside(bv, ma)(es1) + case _ => Left(es) + }).left + es3 <- extractWithState(bv, lb.body)(es2).left + } yield es3 } - case (ByName(r1), ByName(r2)) => extractWithState(r1, r2) + case (Constant(()), lb: LetBinding) => es.strategy match { + // Unit return doesn't have to matched + case PartialMatching => Right(es) + case CompleteMatching => Left(es) + } - case (_, Ascribe(s, _)) => extractWithState(xtor, s) + + case (_: Rep, lb: LetBinding) => es.strategy match { + // `xtor` cannot be a let-binding or BV in this case so it only makes sense to extract the body. + case PartialMatching => extractWithState(xtor, lb.body) + case CompleteMatching => Left(es) + } + + // Only a [[ByName]] can extract a [[ByName]] so that + // the rewriting will rewrite inside the [[ByName]]. + case (ByName(r1), ByName(r2)) => extractWithState(r1, r2) + + case (_, Ascribe(s, _)) => extractWithState(xtor, s) - case (Ascribe(s, t), _) => for { - es1 <- es.updateExtractWith(t extract(xtee.typ, Covariant)) - es2 <- extractWithState(s, xtee)(es1) - } yield es2 + case (Ascribe(s, t), _) => for { + es1 <- es.updateExtractWith(t extract(xtee.typ, Covariant)) + es2 <- extractWithState(s, xtee)(es1) + } yield es2 - case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) - - case (bv1: BoundVal, bv2: BoundVal) => - println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") - println(s"STATE: $es") - - es.ctx.get(bv1) map { bv => - if (bv != bv2) Left(es) - else Right(es) - } getOrElse { - if (bv1 == bv2) Right(es) - else if (es.failedMatches(bv1) contains bv2) Left(es) - else (bv1.owner, bv2.owner) match { - case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { - case Right(es) => effect(lb2.value) match { - case Pure => Right(es withCtx lb1.bound -> lb2.bound) - case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) - } - case Left(es) => Left(es withFailed lb1.bound -> lb2.bound) + case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) + + // The actual extraction happens here. + case (bv1: BoundVal, bv2: BoundVal) => + println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") + println(s"STATE: $es") + + es.ctx.get(bv1) map { bv => + if (bv != bv2) Left(es) // `bv1` has already extracted something else + else Right(es) + } getOrElse { + if (bv1 == bv2) Right(es) + else if (es.failedMatches(bv1) contains bv2) Left(es) // Previously failed to extract `bv1` with `bv2` + else (bv1.owner, bv2.owner) match { + case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { + case Right(es) => effect(lb2.value) match { + case Pure => Right(es withCtx lb1.bound -> lb2.bound) + case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) } - case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) // TODO handle failed extract? - case _ => Left(es withFailed bv1 -> bv2) + case Left(es) => Left(es withFailedMatch lb1.bound -> lb2.bound) } + case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) + case _ => Left(es withFailedMatch bv1 -> bv2) } + } - case (Constant(v1), Constant(v2)) if v1 == v2 => es updateExtractWith (xtor.typ extract(xtee.typ, Covariant)) + case (Constant(v1), Constant(v2)) if v1 == v2 => es updateExtractWith (xtor.typ extract(xtee.typ, Covariant)) - // Assuming if they have the same name the type is the same - case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) + // Assuming if they have the same name the type is the same + case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) - // Assuming if they have the same name and prefix the type is the same - case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) + // Assuming if they have the same name and prefix the type is the same + case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) - case (NewObject(t1), NewObject(t2)) => es updateExtractWith (t1 extract(t2, Covariant)) + case (NewObject(t1), NewObject(t2)) => es updateExtractWith (t1 extract(t2, Covariant)) - case _ => Left(es) + case _ => Left(es) } - } } alsoApply (res => println(s"Extract: $res")) protected def spliceExtract(xtor: Rep, args: Args): Option[Extract] = xtor match { @@ -702,72 +783,70 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => throw IRException(s"Trying to splice-extract with invalid extractor $xtor") } - def extractDefs(v1: Def, v2: Def)(implicit es: State): ExtractState = { - println(s"VALUES: \n\t$v1\n\t$v2 with $es \n\n") - (v1, v2) match { - case (l1: Lambda, l2: Lambda) => - for { - es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) - es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound withObjective CompleteMatching) - } yield es2 withDefaultObjective - - case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => - def targExtract(es0: State): ExtractState = - es0.updateExtractWith( - (for { - (e1, e2) <- ma1.targs zip ma2.targs - } yield e1 extract(e2, Invariant)): _* - ) - - def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = { - def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): ExtractState = (argss1, argss2) match { - case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => for { - es0 <- extractArgss0(h1, h2, acc) - es1 <- extractArgss0(t1, t2, acc)(es0) - } yield es1 - - case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => for { - es0 <- extractArgss0(h1, h2, acc) - es1 <- extractArgss0(t1, t2, acc)(es0) - } yield es1 - - case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) - case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) - case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) - case (SplicedArgument(arg), NoArguments) => es updateExtractWith spliceExtract(arg, Args(acc.reverse: _*)) - case (r1: Rep, r2: Rep) => extractWithState(r1, r2) - case (NoArguments, NoArguments) => Right(es) - case (NoArgumentLists, NoArgumentLists) => Right(es) - case _ => Left(es) - } - - extractArgss0(argss1, argss2, Nil) + def extractDefs(v1: Def, v2: Def)(implicit es: State): ExtractState = (v1, v2) match { + case (l1: Lambda, l2: Lambda) => + for { + es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) + es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound withStrategy CompleteMatching) + } yield es2 withDefaultStrategy + + case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => + def targExtract(es0: State): ExtractState = + es0.updateExtractWith( + (for { + (e1, e2) <- ma1.targs zip ma2.targs + } yield e1 extract(e2, Invariant)): _* + ) + + def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = { + def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): ExtractState = (argss1, argss2) match { + case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => for { + es0 <- extractArgss0(h1, h2, acc) + es1 <- extractArgss0(t1, t2, acc)(es0) + } yield es1 + + case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => for { + es0 <- extractArgss0(h1, h2, acc) + es1 <- extractArgss0(t1, t2, acc)(es0) + } yield es1 + + case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) + case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) + case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) + case (SplicedArgument(arg), NoArguments) => es updateExtractWith spliceExtract(arg, Args(acc.reverse: _*)) + case (r1: Rep, r2: Rep) => extractWithState(r1, r2) + case (NoArguments, NoArguments) => Right(es) + case (NoArgumentLists, NoArgumentLists) => Right(es) + case _ => Left(es) } - for { - es1 <- extractWithState(ma1.self, ma2.self) - es2 <- targExtract(es1) - es3 <- extractArgss(ma1.argss, ma2.argss)(es2) - es4 <- es3.updateExtractWith(ma1.typ extract (ma2.typ, Covariant)) - } yield es4 + extractArgss0(argss1, argss2, Nil) + } - case (DefHole(h), _) if !isPure(v2) => extractWithState(h, wrapConstruct(letbind(v2))) + for { + es1 <- extractWithState(ma1.self, ma2.self) + es2 <- targExtract(es1) + es3 <- extractArgss(ma1.argss, ma2.argss)(es2) + es4 <- es3.updateExtractWith(ma1.typ extract (ma2.typ, Covariant)) + } yield es4 - case _ => Left(es) - } + // Assuming a [[DefHole]] only extracts impure statements + case (DefHole(h), _) if !isPure(v2) => extractWithState(h, wrapConstruct(letbind(v2))) + + case _ => Left(es) } override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = - rewriteRep0(xtor, xtee, code)(false)(State.forRewriting(xtor, xtee)) + rewriteRep0(xtor, xtee, code)(State.forRewriting(xtor, xtee)) - def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(internalRec: Boolean)(implicit es: State): Option[Rep] = { + def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(implicit es: State): Option[Rep] = { def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) if !internalRec => - ((effect(lb1.value), es.flags.xtorFlag(lb1.bound)), (effect(lb2.value), es.flags.xteeFlag(lb2.bound))) match { - case ((Pure, Skip), (Pure, Skip)) => Left(es) + case (lb1: LetBinding, lb2: LetBinding) => + (es.instructions.get(lb1.bound), es.instructions.get(lb2.bound)) match { + case (Skip, Skip) => Left(es) case _ => extractWithState(lb1, lb2) } case _ => extractWithState(xtor, xtee) @@ -844,8 +923,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case lb: LetBinding if remove contains lb.bound => cleanup(lb.body, remove)(ctx) case lb: LetBinding if isPure(lb.value) => - val bvsInValue = bvs(lb.value) - + val bvsInValue = lb.value |>? { + case ma: MethodApp => collectBVs(ma) + } getOrElse ListSet.empty + if (bvsInValue exists (remove contains)) { cleanup(lb.body, remove ++ bvsInValue.toSet)(ctx) } else { @@ -918,21 +999,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => r } - def bvs(r: Rep): List[BoundVal] = { - def bvs0(r: Rep, acc: List[BoundVal]): List[BoundVal] = r match { - case lb: LetBinding => bvs0(lb.body, lb.bound :: acc) + def collectLBBounds(r: Rep): ListSet[BoundVal] = { + def collectLBBounds0(r: Rep, acc: ListSet[BoundVal]): ListSet[BoundVal] = r match { + case lb: LetBinding => collectLBBounds0(lb.body, acc + lb.bound) case _ => acc } - bvs0(r, List.empty) + collectLBBounds0(r, ListSet.empty) } - def bvs(d: Def): List[BoundVal] = d match { - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldRight(List.empty[BoundVal]) { - case (bv: BoundVal, acc) => bv :: acc - case (_, acc) => acc + + def collectBVs(ma: MethodApp): ListSet[BoundVal] = + (ma.self :: ma.argss.argssList).foldLeft(ListSet.empty[BoundVal]) { + case (acc, bv: BoundVal) => acc + bv + case (acc, _) => acc } - case _ => Nil - } // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * @@ -943,13 +1023,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi override def hopHole2(name: String, typ: TypeRep, args: List[List[Rep]], visible: List[BoundVal]) = HOPHole2(name, typ, args, visible filterNot (args.flatten contains _)) def substitute(r: => Rep, defs: Map[String, Rep]): Rep = { - println(s"Subs: $r with $defs") val r0 = if (defs isEmpty) r else bottomUp(r) { case h@Hole(n, _) => defs getOrElse(n, h) case h@SplicedHole(n, _) => defs getOrElse(n, h) - //case h: BoundVal => defs getOrElse(h.name, h) // TODO FVs in lambda become BVs too early, this should be changed!! case h => h } @@ -959,13 +1037,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // TODO for now we do nothing to r. Later make sure that after applying the defs it is still valid in ANF! require(defs.isEmpty) r - - //if (defs isEmpty) r - //else bottomUp(r) { - // case h@Hole(n, _) => defs getOrElse(n, h) - // case h@SplicedHole(n, _) => defs getOrElse(n, h) - // case h => h - //} } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index d25049fb..eb5625b7 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -163,15 +163,10 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { test("Rewriting a by-name argument rewrites inside it") { import RewritingTests.f - val a = ir"f({println(10.toDouble); 10.toDouble})" rewrite { + val a = ir"f({val a = 10.toDouble; println(10.toDouble); a})" rewrite { case ir"10.toDouble" => ir"42.toDouble" } assert(a =~= ir"f({val a = 10.toDouble; println(42.toDouble); val b = 10.toDouble; 42.toDouble})") - - val b = ir"val a = 10.toDouble; f(a)" rewrite { - case ir"10.toDouble" => ir"42.toDouble" - } - assert(b =~= ir"val a = 10.toDouble; f(42.toDouble)") } test("Squid paper") { From 57926a3f5dea5a8641cbc7252341c46160eca8be Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Wed, 13 Dec 2017 14:05:42 +0100 Subject: [PATCH 50/66] Simplify hole extraction and add mechanism to revert to the default instructions --- src/main/scala/squid/ir/fastanf/FastANF.scala | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 88e551c5..ca807d54 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -340,6 +340,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case class State(ex: Extract, ctx: Ctx, instructions: Instructions, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], strategy: Strategy) { private val _strategy = strategy + private val _instructions = instructions def withNewExtract(newEx: Extract): State = copy(ex = newEx) def withCtx(newCtx: Ctx): State = copy(ctx = newCtx) @@ -351,14 +352,17 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withMatchedImpures(r: Rep): State = r match { case lb: LetBinding if !isPure(lb.value) => copy(matchedImpureBVs = matchedImpureBVs + lb.bound) withMatchedImpures lb.body case lb: LetBinding => this withMatchedImpures lb.body - case bv: BoundVal => copy(matchedImpureBVs = matchedImpureBVs + bv) case _ => this // Everything else is pure so we ignore it } + def withMatchedImpure(bv: BoundVal): State = copy(matchedImpureBVs = matchedImpureBVs + bv) def withFailedMatch(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) + def withStrategy(s: Strategy): State = copy(strategy = s) def withDefaultStrategy: State = copy(strategy = _strategy) + def withInstructionsFor(r: Rep): State = copy(instructions = instructions.copy(flags = instructions.flags ++ Instructions.gen(r))) + def withDefaultInstructions: State = copy(instructions = _instructions) def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) @@ -389,10 +393,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * For instance, this [[Strategy]] is necessary for rewriting as we may only be rewriting parts of the `xtee`. */ case object PartialMatching extends Strategy - - // TODO enforces? + /** - * Enforces the fact that `xtor` has to fully match the `xtee`. + * Enforces the fact that `xtor` has to fully match the `xtee`. + * (Enforced by the extraction algorithm but having a final check might be a good idea) * For instance, this [[Strategy]] is necessary for extraction as * the pattern `xtor` has to match the entire `xtee`. */ @@ -506,7 +510,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * @param maybeFuncAndState the current function and the extraction state after its creation. * @return */ - def extendFunc(xtors: List[Rep], maybeFuncAndState: Option[(Rep, State)]): Option[(Rep, State)] = { + def buildFunc(xtors: List[Rep], maybeFuncAndState: Option[(Rep, State)]): Option[(Rep, State)] = { val args = xtors.map(arg => bindVal("hopArg", arg.typ, Nil)) val transformations = xtors zip args @@ -563,7 +567,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } replaceAllOccurrences0(body)(es withInstructionsFor lb withStrategy PartialMatching) - .mapSecond(_ withInstructionsFor xtor withDefaultStrategy) + .mapSecond(_.withDefaultInstructions.withDefaultStrategy) } replaceAllOccurrences(body)(es) @@ -581,11 +585,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val maybeES = for { es1 <- typ extract (xtee.typ, Covariant) m <- merge(es.ex, es1) + argss0 = argss.map(_.map { case ByName(r) => r case r => r }) - (f, es2) <- argss0.foldRight(Option(xtee -> (es withNewExtract m)))(extendFunc) + + (f, es2) <- argss0.foldRight(Option(xtee -> (es withNewExtract m)))(buildFunc) + if !usesUndeclaredBVs(f) } yield es2 updateExtractWith Some(repExtract(name -> f)) @@ -621,24 +628,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = (h, r) match { - case (Hole(n, t), bv: BoundVal) => + def extractHole(h: Hole, r: Rep)(implicit es: State): ExtractState = h match { + case Hole(n, t) => es.updateExtractWith( t extract(xtee.typ, Covariant), - Some(repExtract(n -> bv)) - ) - - case (Hole(n, t), lb: LetBinding) => - es.updateExtractWith( - t extract(lb.typ, Covariant), - Some(repExtract(n -> lb)) - ) map (_ withMatchedImpures lb) - - case (Hole(n, t), _) => - es.updateExtractWith( - t extract(xtee.typ, Covariant), - Some(repExtract(n -> xtee)) - ) + Some(repExtract(n -> r)) + ) map (_ withMatchedImpures r) } /** @@ -741,7 +736,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { case Right(es) => effect(lb2.value) match { case Pure => Right(es withCtx lb1.bound -> lb2.bound) - case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpures lb2.bound) + case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpure lb2.bound) } case Left(es) => Left(es withFailedMatch lb1.bound -> lb2.bound) } From 9afe79010697a28d34a63ef10d75e39bf46ae6ce Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Wed, 13 Dec 2017 14:07:14 +0100 Subject: [PATCH 51/66] Rename HOPV tests file to be more uniform --- ...{HigherOrderPatternVariables.scala => HOPVTests.scala} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename src/test/scala/squid/ir/fastir/{HigherOrderPatternVariables.scala => HOPVTests.scala} (96%) diff --git a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala b/src/test/scala/squid/ir/fastir/HOPVTests.scala similarity index 96% rename from src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala rename to src/test/scala/squid/ir/fastir/HOPVTests.scala index 7ba5d366..426690cb 100644 --- a/src/test/scala/squid/ir/fastir/HigherOrderPatternVariables.scala +++ b/src/test/scala/squid/ir/fastir/HOPVTests.scala @@ -4,8 +4,8 @@ package fastir import scala.util.Try -class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVariables.Embedding) { - import HigherOrderPatternVariables.Embedding.Predef._ +class HOPVTests extends MyFunSuiteBase(HOPVTests.Embedding) { + import HOPVTests.Embedding.Predef._ test("Matching lambda bodies") { val id = ir"(z:Int) => z" @@ -62,7 +62,7 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria test("Matching let-binding bodies") { // Not implemented error in `letin` - //ir"val a = 0; val b = 1; a + b" matches { + //ir"val a = ${ir"0"}; val b = 1; a + b" match { // case ir"val x: Int = $v; $body(x):Int" => // assert(v == ir"0") // body matches { @@ -164,6 +164,6 @@ class HigherOrderPatternVariables extends MyFunSuiteBase(HigherOrderPatternVaria } } -object HigherOrderPatternVariables { +object HOPVTests { object Embedding extends FastANF } From 8d1093c529e1d5b2444cef818545348c12e9935f Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Wed, 13 Dec 2017 15:37:36 +0100 Subject: [PATCH 52/66] Simplify extraction with a SplicedArgument --- src/main/scala/squid/ir/fastanf/FastANF.scala | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index ca807d54..ad1e0926 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -793,29 +793,29 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } yield e1 extract(e2, Invariant)): _* ) - def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = { - def extractArgss0(argss1: ArgumentLists, argss2: ArgumentLists, acc: List[Rep])(implicit es: State): ExtractState = (argss1, argss2) match { - case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => for { - es0 <- extractArgss0(h1, h2, acc) - es1 <- extractArgss0(t1, t2, acc)(es0) - } yield es1 - - case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => for { - es0 <- extractArgss0(h1, h2, acc) - es1 <- extractArgss0(t1, t2, acc)(es0) - } yield es1 - - case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) - case (sa: SplicedArgument, ArgumentCons(h, t)) => extractArgss0(sa, t, h :: acc) - case (sa: SplicedArgument, r: Rep) => extractArgss0(sa, NoArguments, r :: acc) - case (SplicedArgument(arg), NoArguments) => es updateExtractWith spliceExtract(arg, Args(acc.reverse: _*)) - case (r1: Rep, r2: Rep) => extractWithState(r1, r2) - case (NoArguments, NoArguments) => Right(es) - case (NoArgumentLists, NoArgumentLists) => Right(es) - case _ => Left(es) - } - - extractArgss0(argss1, argss2, Nil) + def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = (argss1, argss2) match { + case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => for { + es0 <- extractArgss(h1, h2) + es1 <- extractArgss(t1, t2)(es0) + } yield es1 + + case (ArgumentCons(h1, t1), ArgumentCons(h2, t2)) => for { + es0 <- extractArgss(h1, h2) + es1 <- extractArgss(t1, t2)(es0) + } yield es1 + + case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) + case (SplicedArgument(arg), ac: ArgumentCons) => es updateExtractWith spliceExtract(arg, Args(ac.argssList: _*)) + case (SplicedArgument(arg), r: Rep) => es updateExtractWith spliceExtract(arg, Args(r)) + case (SplicedArgument(_), NoArguments) => Right(es) + + case (r1: Rep, r2: Rep) => extractWithState(r1, r2) + + case (NoArguments, NoArguments) => Right(es) + + case (NoArgumentLists, NoArgumentLists) => Right(es) + + case _ => Left(es) } for { From a9f83e59a4c9c49fce7e89dfd0e6faf34ea64f07 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 14 Dec 2017 00:00:14 +0100 Subject: [PATCH 53/66] Remove wrong and buggy check --- src/main/scala/squid/ir/fastanf/FastANF.scala | 32 ++----------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index ad1e0926..28c224c8 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -887,32 +887,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val invCtx = reverse(es.ctx) (ex._1.values ++ ex._3.values.flatten).forall(preCheckRep(Set.empty, invCtx, _)) } - - /** - * Final check after rewriting the program. - * Checks if all the BVs are declared and that the removed - * let-binding are not referenced anymore in the code. - */ - def check(declaredBVs: Set[BoundVal], matchedImpureBVs: Set[BoundVal])(r: Rep): Boolean = { - def checkDef(declaredBVs: Set[BoundVal], matchedImpureBVs: Set[BoundVal])(d: Def): Boolean = d match { - case ma: MethodApp => (ma.self :: ma.argss.argssList) forall { - case bv: BoundVal => (declaredBVs contains bv) || !(matchedImpureBVs contains bv) - case lb: LetBinding => check(declaredBVs + lb.bound, matchedImpureBVs)(lb) - case _ => true - } - case l: Lambda => - ((declaredBVs contains l.bound) || - !(matchedImpureBVs contains l.bound)) && - check(declaredBVs, matchedImpureBVs)(l.body) - case _ => true - } - - r match { - case lb: LetBinding => checkDef(declaredBVs + lb.bound, matchedImpureBVs)(lb.value) - case bv: BoundVal => (declaredBVs contains bv) || !(matchedImpureBVs contains bv) - case _ => true - } - } def cleanup(r: Rep, remove: Set[BoundVal])(ctx: Ctx): Rep = r match { case lb: LetBinding if remove contains lb.bound => cleanup(lb.body, remove)(ctx) @@ -973,9 +947,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if (preCheck(es.ex)) for { code <- code(es.ex) alsoApply(c => println(s"CODE: $c")) - code0 = finalize(code, xtor, xtee)(es.ctx) alsoApply (c => println(s"CODE0: $c")) - if check(Set.empty, es.matchedImpureBVs)(cleanup(code0, es.matchedImpureBVs)(es.ctx)) - } yield code0 + code0 = finalize(code, xtor, xtee)(es.ctx) alsoApply (c => println(s"CODE0: $c")) + cleanCode = cleanup(code0, es.matchedImpureBVs)(es.ctx) + } yield cleanCode else None } From cdb21e30fa6b689e4c0d148a3e0f1429d25e3267 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Thu, 14 Dec 2017 00:01:34 +0100 Subject: [PATCH 54/66] Better function name --- src/main/scala/squid/ir/fastanf/FastANF.scala | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 28c224c8..48333582 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -482,23 +482,23 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { println("EXTRACTINGHOPHOLE") - - def usesUndeclaredBVs(r: Rep): Boolean = { - def usesUndeclaredBVs0(r: Rep, declared: Set[BoundVal]): Boolean = r match { - case bv: BoundVal => !(declared contains bv) + + def hasNoUndeclaredUsages(r: Rep): Boolean = { + def hasNoUndeclaredUsages0(r: Rep, declared: Set[BoundVal]): Boolean = r match { + case bv: BoundVal => declared contains bv case lb: LetBinding => val declared0 = declared + lb.bound - defUsesUndeclaredBVs(lb.value, declared0) || usesUndeclaredBVs0(lb.body, declared0) - case _ => false + defHasNoUndeclaredUsages(lb.value, declared0) && hasNoUndeclaredUsages0(lb.body, declared0) + case _ => true } - def defUsesUndeclaredBVs(d: Def, declared: Set[BoundVal]): Boolean = d match { - case l: Lambda => usesUndeclaredBVs0(l.body, declared + l.bound) - case ma: MethodApp => (ma.self +: ma.argss.argssList) exists (usesUndeclaredBVs0(_, declared)) - case _ => false + def defHasNoUndeclaredUsages(d: Def, declared: Set[BoundVal]): Boolean = d match { + case l: Lambda => hasNoUndeclaredUsages0(l.body, declared + l.bound) + case ma: MethodApp => (ma.self :: ma.argss.argssList) forall (hasNoUndeclaredUsages0(_, declared)) + case _ => true } - usesUndeclaredBVs0(r, Set.empty) + hasNoUndeclaredUsages0(r, Set.empty) } /** @@ -593,7 +593,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi (f, es2) <- argss0.foldRight(Option(xtee -> (es withNewExtract m)))(buildFunc) - if !usesUndeclaredBVs(f) + if hasNoUndeclaredUsages(f) } yield es2 updateExtractWith Some(repExtract(name -> f)) maybeES getOrElse Left(es) From 3473bc561756959ea09a2854a3c4c2910a30f94c Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Fri, 15 Dec 2017 16:04:29 +0100 Subject: [PATCH 55/66] Fix issue where rewriting with code having unused statements would not be handled correctly --- src/main/scala/squid/ir/fastanf/FastANF.scala | 170 +++++++++++++----- .../squid/ir/fastir/RewritingTests.scala | 14 ++ 2 files changed, 138 insertions(+), 46 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 48333582..392d766b 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -637,16 +637,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } /** - * Attemps to extract `bv` by trying to match the BVs used in the [[MethodApp]] from left to right. + * Attemps to extract `r` by trying to match it with the component of `d`. + * Only extracts inside a [[MethodApp]], fails for all other cases. * It won't extract inside the [[ByName]] arguments. */ - def extractInside(bv: BoundVal, ma: MethodApp)(implicit es: State): ExtractState = - collectBVs(ma).foldLeft[ExtractState](Left(es)) { case (acc, bv2) => - for { - es1 <- acc.left - es2 <- extractWithState(bv, bv2)(es1).left - } yield es2 - } + def extractInside(r: Rep, d: Def)(implicit es: State): ExtractState = d match { + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft[ExtractState](Left(es)) { case (acc, arg) => + for { + es1 <- acc.left + es2 <- extractWithState(r, arg)(es1).left + } yield es2 + } + case _ => Left(es) + } + def extractedBy(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name @@ -683,14 +687,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // Attempts to extract the return value `bv` by trying // 1. The current let-binding - // 2. The BVs of the value of the let-binding if it's a [[MethodApp]] (see `extractInside`) + // 2. The components of the let-bindings value // 3. (1. and then 2.) on the next statement case PartialMatching => for { es1 <- extractWithState(bv, lb.bound).left - es2 <- (lb.value match { - case ma: MethodApp => extractInside(bv, ma)(es1) - case _ => Left(es) - }).left + es2 <- extractInside(bv, lb.value)(es1).left es3 <- extractWithState(bv, lb.body)(es2).left } yield es3 } @@ -703,8 +704,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (_: Rep, lb: LetBinding) => es.strategy match { - // `xtor` cannot be a let-binding or BV in this case so it only makes sense to extract the body. - case PartialMatching => extractWithState(xtor, lb.body) + case PartialMatching => for { + es1 <- extractInside(xtor, lb.value).left + es2 <- extractWithState(xtor, lb.body)(es1).left + } yield es2 + case CompleteMatching => Left(es) } @@ -740,7 +744,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } case Left(es) => Left(es withFailedMatch lb1.bound -> lb2.bound) } - case (l1: Lambda, l2: Lambda) => extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) + case (l1: Lambda, l2: Lambda) => + // We cannot know the owner of the [[Lambda]] + // so in case of a failure to extract there's nothing to do. + extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) case _ => Left(es withFailedMatch bv1 -> bv2) } } @@ -839,9 +846,27 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) => + case (lb1: LetBinding, lb2: LetBinding) => + + /** + * Pure statements (annotated with the instruction [[Skip]] only have to extracted starting from their + * return value and extract each sub-part recursively. Through this mechanism the order of the pure statements + * does not matter. + * For instance, this will successfully match : + * {{{ + * ir"val b = 22.toDouble; val a = 11.toDouble; a + b" match { + * case ir"val aX = 11.toDouble; val bX = 22.toDouble; aX + bX" => ??? + * } + * }}} + * + */ (es.instructions.get(lb1.bound), es.instructions.get(lb2.bound)) match { - case (Skip, Skip) => Left(es) + /** + * The traversal of the code is done externally by [[transformRep()]]. + * Hence, if the current statements don't have to be extracted at this point (both are pure and not return values) + * we simply skip the extraction of the current `xtee`. + */ + case (Skip, Skip) => Left(es) case _ => extractWithState(lb1, lb2) } case _ => extractWithState(xtor, xtee) @@ -851,8 +876,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def genCode(implicit es: State): Option[Rep] = { /** - * First sanity check on the extraction. - * Checks if the BVs in the extract are declared orwere defined by the user. + * Returns true if all the BV usages appearing in the extraction are declared inside the extraction and, if not, + * that it has been declared by the user. + * For instance: + * {{{ + * ir"val r = readInt; r + 1" rewrite { + * case ir"val rX = readInt; $body" => ??? + * } + * }}} + * `$body` will extract `ir"r + 1"` where `r` <-> `rX`. Here the let-binding `val rX = readInt; ...` + * is user-defined. Therefore, on the rhs of the rewriting rule the user can still write valid code + * (e.g. `ir"val r ` + * + * TODO ask about this + * */ def preCheck(ex: Extract): Boolean = { def preCheckRep(declaredBVs: Set[BoundVal], invCtx: Map[BoundVal, Set[BoundVal]], r: Rep): Boolean = { @@ -887,45 +924,86 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val invCtx = reverse(es.ctx) (ex._1.values ++ ex._3.values.flatten).forall(preCheckRep(Set.empty, invCtx, _)) } - - def cleanup(r: Rep, remove: Set[BoundVal])(ctx: Ctx): Rep = r match { - case lb: LetBinding if remove contains lb.bound => cleanup(lb.body, remove)(ctx) - - case lb: LetBinding if isPure(lb.value) => - val bvsInValue = lb.value |>? { - case ma: MethodApp => collectBVs(ma) - } getOrElse ListSet.empty - - if (bvsInValue exists (remove contains)) { - cleanup(lb.body, remove ++ bvsInValue.toSet)(ctx) - } else { - lb.body = cleanup(lb.body, remove)(ctx) + + /** + * Removes all the statements referencing elements from `remove`. + * Also takes care of statements referencing removed statements. + * + * Returns None, if `r`'s return value is also removed. + * For instance in {{{filterNot(ir"val a = readInt; a", Set(a)}}} + */ + def filterNot(remove: Set[BoundVal])(r: Rep): Option[Rep] = { + /** + * Returns the statements to remove based on the initial `remove` set. + * + * Adds all the pure statements referencing removed ones. + * TODO add mathematical term + */ + def buildToRemove(r: Rep, remove: Set[BoundVal]): Set[BoundVal] = r match { + case lb: LetBinding if remove contains lb.bound => buildToRemove(lb.body, remove) + + case lb: LetBinding if isPure(lb.value) => + val bvsInValue = lb.value |>? { + case ma: MethodApp => collectBVs(ma) + } getOrElse ListSet.empty + + if (bvsInValue exists (remove contains)) { + buildToRemove(lb.body, remove + lb.bound) + } else { + buildToRemove(lb.body, remove) + } + + case lb: LetBinding => buildToRemove(lb.body, remove) + + case _ => remove + } + + def filterNot0(remove: Set[BoundVal])(r: Rep): Rep = r match { + case lb: LetBinding if remove contains lb.bound => filterNot0(remove)(lb.body) + case lb: LetBinding => + lb.body = filterNot0(remove)(lb.body) lb + case _ => r + } + + val remove0 = buildToRemove(r, remove) + r match { + case lb: LetBinding => lb.last.body match { + case bv: BoundVal if remove0 contains bv => None + case _ => Some(filterNot0(remove0)(r)) } - - case lb: LetBinding => lb.body = cleanup(lb.body, remove)(ctx) - lb - - case _ => r + case _ => Some(r) // Nothing to do + } } - def finalize(code: Rep, xtor: Rep, filteredXtee: Rep)(ctx: Ctx): Rep = xtor match { + + /** + * Merges the generated `code` with the `xtee` + */ + def merge(code: Rep, xtee: Rep)(xtor: Rep, ctx: Ctx): Rep = xtor match { + // Find what the return value of `xtor` matched in `xtee` and + // replace that by the return value of `code`. case xtorLB: LetBinding => val xtorLast = xtorLB.last xtorLast.body match { case xtorRet: BoundVal => code match { case codeLB: LetBinding => val codeLast = codeLB.last - codeLast.body |>? { + codeLast.body match { case codeRet: BoundVal => val bv = ctx(xtorRet) - codeLast.body = bottomUpPartial(filteredXtee) { case `bv` => codeRet } + codeLast.body = bottomUpPartial(xtee) { case `bv` => codeRet } + + // When code: `ir"val a = readInt; 20"`, + case _ => + val bv = ctx(xtorRet) + codeLast.body = bottomUpPartial(xtee) { case `bv` => codeLast.body } } code case _ => val bv = ctx(xtorRet) - bottomUpPartial(filteredXtee) { case `bv` => code } + bottomUpPartial(xtee) { case `bv` => code } } // Hole? @@ -937,18 +1015,18 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val codeLast = codeLB.last codeLast.body |>? { case codeRet: BoundVal => - codeLast.body = bottomUpPartial(filteredXtee) { case `xtor` => codeRet } + codeLast.body = bottomUpPartial(xtee) { case `xtor` => codeRet } } code - case _ => bottomUpPartial(filteredXtee) { case `xtor` => code } + case _ => bottomUpPartial(xtee) { case `xtor` => code } } } if (preCheck(es.ex)) for { code <- code(es.ex) alsoApply(c => println(s"CODE: $c")) - code0 = finalize(code, xtor, xtee)(es.ctx) alsoApply (c => println(s"CODE0: $c")) - cleanCode = cleanup(code0, es.matchedImpureBVs)(es.ctx) + code0 = merge(code, xtee)(xtor, es.ctx) alsoApply (c => println(s"CODE0: $c")) + cleanCode <- filterNot(es.matchedImpureBVs)(code0) } yield cleanCode else None } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index eb5625b7..5499ce5c 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -169,6 +169,20 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(a =~= ir"f({val a = 10.toDouble; println(42.toDouble); val b = 10.toDouble; 42.toDouble})") } + test("Rewriting with unused statements should inline them") { + val a = ir"val r = readDouble; val a = readInt; val b = r.toInt; b" rewrite { + case ir"readDouble" => ir"val r = readInt; val a = r + 1; 20d" + } + assert(a =~= ir"val r = readInt; val a2 = r + 1; val a1 = readInt; val b = 20d.toInt; b") + } + + test("Rewriting should not duplicate effects") { + val a = ir"val r = readDouble; val b = r.toInt; val c = r.toDouble; b + c" rewrite { + case ir"readDouble" => ir"val r = readInt; val a = r + 1; 20d" + } + assert(a =~= ir"val r = readInt; val a = r + 1; 20d.toInt + 20d.toDouble") + } + test("Squid paper") { // 3.4 From 06c3f8a3a4b6f2c741fb9c57422bb7ad23c0e129 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sat, 16 Dec 2017 23:00:13 +0100 Subject: [PATCH 56/66] Simplify merging of xtee and generated code --- src/main/scala/squid/ir/fastanf/FastANF.scala | 20 +++++-------------- .../squid/ir/fastir/RewritingTests.scala | 12 +++++++++++ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 392d766b..89d847e2 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -989,16 +989,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case xtorRet: BoundVal => code match { case codeLB: LetBinding => val codeLast = codeLB.last - codeLast.body match { - case codeRet: BoundVal => - val bv = ctx(xtorRet) - codeLast.body = bottomUpPartial(xtee) { case `bv` => codeRet } - - // When code: `ir"val a = readInt; 20"`, - case _ => - val bv = ctx(xtorRet) - codeLast.body = bottomUpPartial(xtee) { case `bv` => codeLast.body } - } + val codeRet = codeLast.body + val bv = ctx(xtorRet) + codeLast.body = bottomUpPartial(xtee) { case `bv` => codeRet } code case _ => @@ -1006,17 +999,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi bottomUpPartial(xtee) { case `bv` => code } } - // Hole? case _ => code } case _ => code match { case codeLB: LetBinding => val codeLast = codeLB.last - codeLast.body |>? { - case codeRet: BoundVal => - codeLast.body = bottomUpPartial(xtee) { case `xtor` => codeRet } - } + val codeRet = codeLast.body + codeLast.body = bottomUpPartial(xtee) { case `xtor` => codeRet } code case _ => bottomUpPartial(xtee) { case `xtor` => code } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 5499ce5c..11fe015d 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -183,6 +183,18 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(a =~= ir"val r = readInt; val a = r + 1; 20d.toInt + 20d.toDouble") } + test("Rewriting a non-letbinding with a let-binding") { + val a = ir"val a = 20.toDouble; a + 1" rewrite { + case ir"20" => ir"readInt" + } + assert(a =~= ir"val r = readInt; val a = r.toDouble; a + 1") + + val b = ir"val a = 20.toDouble; a + 1" rewrite { + case ir"20" => ir"val t = readInt; 40" + } + assert(b =~= ir"val t = readInt; val a = 40.toDouble; a + 1") + } + test("Squid paper") { // 3.4 From 6f370951e2d0b7abb634c027d7e70233803835ec Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 17 Dec 2017 10:42:45 +0100 Subject: [PATCH 57/66] Fix documentation and cleanup --- src/main/scala/squid/ir/fastanf/FastANF.scala | 90 +++++++++---------- src/main/scala/squid/ir/fastanf/Rep.scala | 8 +- 2 files changed, 51 insertions(+), 47 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 89d847e2..c663417c 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -7,6 +7,10 @@ import squid.utils._ import scala.collection.immutable.{ListMap, ListSet} +/** + * ANF representation of the code. + * The IR is mutable in order to be able to do O(1) substitutions. + */ class FastANF extends InspectableBase with CurryEncoding with StandardEffects with ScalaCore { private[this] implicit val base = this @@ -100,9 +104,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def module(prefix: Rep, name: String, typ: TypeRep): Rep = Module(prefix, name, typ) def newObject(typ: TypeRep): Rep = NewObject(typ) def methodApp(self: Rep, mtd: MtdSymbol, targs: List[TypeRep], argss: List[ArgList], tp: TypeRep): Rep = mtd match { + // Converts the call to `Imperative` to let-bindings case MethodSymbol(TypeSymbol("squid.lib.package$"), "Imperative") => argss match { case List(h, t) => - val firstArgss = h.reps val holes = h.reps.filter { case Hole(_, _) => true case _ => false @@ -115,18 +119,25 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } - case _ => MethodApp(self |> inlineBlock, mtd, targs, argss |> toArgumentLists, tp) |> letbind } def byName(mkArg: => Rep): Rep = ByName(wrapNest(mkArg)) def letbind(d: Def): Rep = currentScope += d + + /** + * Adds all the statements of `r` to the current reification context. + * Returns the final non-statement unchanged. + */ def inlineBlock(r: Rep): Rep = r |>=? { case lb: LetBinding => currentScope += lb inlineBlock(lb.body) } + /** + * Let binds `value` to `bound` in the `body`. + */ override def letin(bound: BoundVal, value: Rep, body: => Rep, bodyType: TypeRep): Rep = value match { case s: Symbol => s.owner |>? { @@ -311,7 +322,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { println(s"Extract(\n$xtor, \n$xtee)") - extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).fold(_ => None, Some(_)) map (_.ex) + extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).toOption map (_.ex) } type Ctx = Map[BoundVal, BoundVal] @@ -332,10 +343,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * Represents the current of the state of the extraction. * @param ex what has been extracted by holes. * @param ctx discovered matchings between bound values in the `xtor` and the `xtee`. - * @param instructions - * @param matchedImpureBVs - * @param failedMatches - * @param strategy + * @param instructions what has to be done for each let-binding + * @param matchedImpureBVs impure statements that have been matched + * @param failedMatches statements that do not match + * @param strategy see [[Strategy]] */ case class State(ex: Extract, ctx: Ctx, instructions: Instructions, matchedImpureBVs: Set[BoundVal], failedMatches: Map[BoundVal, Set[BoundVal]], strategy: Strategy) { @@ -481,8 +492,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi println(es) def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { - println("EXTRACTINGHOPHOLE") - def hasNoUndeclaredUsages(r: Rep): Boolean = { def hasNoUndeclaredUsages0(r: Rep, declared: Set[BoundVal]): Boolean = r match { case bv: BoundVal => declared contains bv @@ -696,7 +705,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } yield es3 } - case (Constant(()), lb: LetBinding) => es.strategy match { + case (Constant(()), _: LetBinding) => es.strategy match { // Unit return doesn't have to matched case PartialMatching => Right(es) case CompleteMatching => Left(es) @@ -726,13 +735,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (HOPHole(name, typ, argss, visible), _) => extractHOPHole(name, typ, argss, visible) // The actual extraction happens here. - case (bv1: BoundVal, bv2: BoundVal) => - println(s"OWNERS: ${bv1.owner} -- ${bv2.owner}") - println(s"STATE: $es") - - es.ctx.get(bv1) map { bv => - if (bv != bv2) Left(es) // `bv1` has already extracted something else - else Right(es) + case (bv1: BoundVal, bv2: BoundVal) => es.ctx.get(bv1) map { extractedByBV1 => + if (extractedByBV1 == bv2) Right(es) + else Left(es) } getOrElse { if (bv1 == bv2) Right(es) else if (es.failedMatches(bv1) contains bv2) Left(es) // Previously failed to extract `bv1` with `bv2` @@ -939,23 +944,31 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * Adds all the pure statements referencing removed ones. * TODO add mathematical term */ - def buildToRemove(r: Rep, remove: Set[BoundVal]): Set[BoundVal] = r match { - case lb: LetBinding if remove contains lb.bound => buildToRemove(lb.body, remove) - - case lb: LetBinding if isPure(lb.value) => - val bvsInValue = lb.value |>? { - case ma: MethodApp => collectBVs(ma) - } getOrElse ListSet.empty - - if (bvsInValue exists (remove contains)) { - buildToRemove(lb.body, remove + lb.bound) - } else { - buildToRemove(lb.body, remove) + def buildToRemove(r: Rep, remove: Set[BoundVal]): Set[BoundVal] = { + def collectBVs(ma: MethodApp): ListSet[BoundVal] = + (ma.self :: ma.argss.argssList).foldLeft(ListSet.empty[BoundVal]) { + case (acc, bv: BoundVal) => acc + bv + case (acc, _) => acc } + + r match { + case lb: LetBinding if remove contains lb.bound => buildToRemove(lb.body, remove) - case lb: LetBinding => buildToRemove(lb.body, remove) + case lb: LetBinding if isPure(lb.value) => + val bvsInValue = lb.value |>? { + case ma: MethodApp => collectBVs(ma) + } getOrElse ListSet.empty - case _ => remove + if (bvsInValue exists (remove contains)) { + buildToRemove(lb.body, remove + lb.bound) + } else { + buildToRemove(lb.body, remove) + } + + case lb: LetBinding => buildToRemove(lb.body, remove) + + case _ => remove + } } def filterNot0(remove: Set[BoundVal])(r: Rep): Rep = r match { @@ -1035,22 +1048,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi lb case _ => r } - - def collectLBBounds(r: Rep): ListSet[BoundVal] = { - def collectLBBounds0(r: Rep, acc: ListSet[BoundVal]): ListSet[BoundVal] = r match { - case lb: LetBinding => collectLBBounds0(lb.body, acc + lb.bound) - case _ => acc - } - - collectLBBounds0(r, ListSet.empty) - } - def collectBVs(ma: MethodApp): ListSet[BoundVal] = - (ma.self :: ma.argss.argssList).foldLeft(ListSet.empty[BoundVal]) { - case (acc, bv: BoundVal) => acc + bv - case (acc, _) => acc - } - // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * def hole(name: String, typ: TypeRep) = Hole(name, typ) diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index 6adf87dc..ab7eb199 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -105,7 +105,10 @@ object MethodApp { case _ => SimpleMethodApp(self: Rep, mtd: MethodSymbol, targs: List[TypeRep], argss)(typ) } } - + + /** + * Inlines `self` and the arguments `argss` if they are let-bindings in order to transform it into valid ANF. + */ def toANF(self: Rep, mtd: MethodSymbol, targs: List[TypeRep], argss: ArgumentLists, typ0: TypeRep)(implicit base: FastANF): Rep = { def processArgss(argss: ArgumentLists)(f: Rep => (Option[LetBinding], Rep)): (Option[LetBinding], ArgumentLists) = { def processArgs(args: ArgumentList)(f: Rep => (Option[LetBinding], Rep)): (Option[LetBinding], ArgumentList) = { @@ -295,6 +298,9 @@ class LetBinding(var name: String, var bound: Symbol, var value: Def, private va override def toString: String = s"val $bound = $value; $body" } object LetBinding { + /** + * Inlines the `value` to transform it into valid ANF. + */ def withRepValue(name: String, bound: Symbol, value: Rep, mkBody: => Rep): Rep = value match { case lb: LetBinding => val last = lb.last From 0ada46482afb7cccb4df0aaf8672f04321aa4c4e Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 17 Dec 2017 10:56:15 +0100 Subject: [PATCH 58/66] Move filterLBs to call-site --- src/main/scala/squid/ir/fastanf/FastANF.scala | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index c663417c..9f9a167c 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -563,11 +563,19 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi */ def replaceAllOccurrences(body: Rep)(es: State): Rep -> State = { def replaceAllOccurrences0(body: Rep)(implicit es: State): Rep -> State = { - - /* - * Extracts the function body with the xtor in order to be able to use the context `ctx` to know - * what to replace with the new argument. - */ + def filterLBs(r: Rep)(p: LetBinding => Boolean): Rep = r match { + case lb: LetBinding if p(lb) => + filterLBs(lb.body)(p) + case lb: LetBinding => + lb.body = filterLBs(lb.body)(p) + lb + case _ => r + } + + /* + * Extracts the function body with the xtor in order to be able to use the context `ctx` to know + * what to replace with the new argument. + */ extractWithState(lb, body) map { es0 => val replace = es0.ctx(lb.last.bound) val body0 = bottomUpPartial(filterLBs(body)(es0.ctx.values.toSet contains _.bound)) { case `replace` => arg } @@ -1039,15 +1047,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case Left(_) => None } } - - def filterLBs(r: Rep)(p: LetBinding => Boolean): Rep = r match { - case lb: LetBinding if p(lb) => - filterLBs(lb.body)(p) - case lb: LetBinding => - lb.body = filterLBs(lb.body)(p) - lb - case _ => r - } // * --- * --- * --- * Implementations of `QuasiBase` methods * --- * --- * --- * From 7f4388ba56f0f3c62a39d96dcc5c874c1f1e61bd Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 17 Dec 2017 14:36:23 +0100 Subject: [PATCH 59/66] HOPV foldLeft extraction test --- src/test/scala/squid/ir/fastir/HOPVTests.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/test/scala/squid/ir/fastir/HOPVTests.scala b/src/test/scala/squid/ir/fastir/HOPVTests.scala index 426690cb..dfd818ba 100644 --- a/src/test/scala/squid/ir/fastir/HOPVTests.scala +++ b/src/test/scala/squid/ir/fastir/HOPVTests.scala @@ -95,6 +95,11 @@ class HOPVTests extends MyFunSuiteBase(HOPVTests.Embedding) { ir"(a: Int, b: Int) => readInt + (a + b)" match { case ir"(x: Int, y: Int) => $body(readInt, x + y): Int" => assert(body =~= ir"(r: Int, s: Int) => r + s") } + + ir"List(1,2,3).foldLeft(0)((acc,x) => acc+x)" match { + case ir"$foldLeft(List(1, 2, 3), 0, (acc: Int, x: Int) => acc + x)" => + assert(foldLeft =~= ir"(l: List[Int], z: Int, f: (Int, Int) => Int) => l.foldLeft(z)(f)") + } } test("Match letbindings") { From 6b50f95f996e116c7de3d6a0d062b6c2444bbf8e Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 17 Dec 2017 14:37:26 +0100 Subject: [PATCH 60/66] Fix issue where generated code would not appear in the right position in xtee --- src/main/scala/squid/ir/fastanf/Effects.scala | 1 + src/main/scala/squid/ir/fastanf/FastANF.scala | 140 ++++++++++++------ .../squid/ir/fastir/RewritingTests.scala | 12 ++ 3 files changed, 107 insertions(+), 46 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index 53bd42ff..c9941588 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -83,4 +83,5 @@ trait StandardEffects extends Effects { addPureMtd(MethodSymbol(TypeSymbol("squid.lib.package$"),"uncurried2")) addPureMtd(MethodSymbol(TypeSymbol("scala.collection.LinearSeqOptimized"),"foldLeft")) addPureMtd(MethodSymbol(TypeSymbol("scala.collection.immutable.List$"),"apply")) + addPureMtd(MethodSymbol(TypeSymbol("scala.collection.immutable.List$"),"canBuildFrom")) } diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 9f9a167c..d056e7be 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -938,6 +938,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi (ex._1.values ++ ex._3.values.flatten).forall(preCheckRep(Set.empty, invCtx, _)) } + def collectBVs(d: Def): Set[BoundVal] = d match { + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft(ListSet.empty[BoundVal]) { + case (acc, bv: BoundVal) => acc + bv + case (acc, _) => acc + } + case _ => Set.empty + } + /** * Removes all the statements referencing elements from `remove`. * Also takes care of statements referencing removed statements. @@ -952,31 +960,21 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * Adds all the pure statements referencing removed ones. * TODO add mathematical term */ - def buildToRemove(r: Rep, remove: Set[BoundVal]): Set[BoundVal] = { - def collectBVs(ma: MethodApp): ListSet[BoundVal] = - (ma.self :: ma.argss.argssList).foldLeft(ListSet.empty[BoundVal]) { - case (acc, bv: BoundVal) => acc + bv - case (acc, _) => acc - } - - r match { - case lb: LetBinding if remove contains lb.bound => buildToRemove(lb.body, remove) + def buildToRemove(r: Rep, remove: Set[BoundVal]): Set[BoundVal] = r match { + case lb: LetBinding if remove contains lb.bound => buildToRemove(lb.body, remove) - case lb: LetBinding if isPure(lb.value) => - val bvsInValue = lb.value |>? { - case ma: MethodApp => collectBVs(ma) - } getOrElse ListSet.empty + case lb: LetBinding if isPure(lb.value) => + val bvsInValue = collectBVs(lb.value) - if (bvsInValue exists (remove contains)) { - buildToRemove(lb.body, remove + lb.bound) - } else { - buildToRemove(lb.body, remove) - } + if (bvsInValue exists (remove contains)) { + buildToRemove(lb.body, remove + lb.bound) + } else { + buildToRemove(lb.body, remove) + } - case lb: LetBinding => buildToRemove(lb.body, remove) + case lb: LetBinding => buildToRemove(lb.body, remove) - case _ => remove - } + case _ => remove } def filterNot0(remove: Set[BoundVal])(r: Rep): Rep = r match { @@ -1001,36 +999,86 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi /** * Merges the generated `code` with the `xtee` */ - def merge(code: Rep, xtee: Rep)(xtor: Rep, ctx: Ctx): Rep = xtor match { - // Find what the return value of `xtor` matched in `xtee` and - // replace that by the return value of `code`. - case xtorLB: LetBinding => - val xtorLast = xtorLB.last - xtorLast.body match { - case xtorRet: BoundVal => code match { - case codeLB: LetBinding => - val codeLast = codeLB.last - val codeRet = codeLast.body - val bv = ctx(xtorRet) - codeLast.body = bottomUpPartial(xtee) { case `bv` => codeRet } - code - - case _ => - val bv = ctx(xtorRet) - bottomUpPartial(xtee) { case `bv` => code } + def merge(code: Rep, xtee: Rep)(xtor: Rep, ctx: Ctx): Rep = { + /** + * Puts `code` in the right position in `xtee`. + */ + def mergeLBs(code: LetBinding, xtee: LetBinding)(es: State): Rep = { + def collectAllBVs(r: Rep): Set[BoundVal] = { + def collectAllBVs0(r: Rep, acc: Set[BoundVal]): Set[BoundVal] = r match { + case lb: LetBinding => collectAllBVs0(lb.body, acc ++ collectBVs(lb.value)) + case _ => acc } + collectAllBVs0(r, Set.empty) + } - case _ => code + def findPos(r: LetBinding, lookFor: Set[BoundVal]): Option[LetBinding] = { + if (lookFor.isEmpty) None + else r match { + case lb: LetBinding => + val lookFor0 = lookFor -- collectBVs(lb.value) + if (lookFor0.isEmpty) Some(lb) + else lb.body match { + case innerLB: LetBinding => findPos(innerLB, lookFor0) + case _ => None + } + case _ => None + } } - case _ => code match { - case codeLB: LetBinding => - val codeLast = codeLB.last - val codeRet = codeLast.body - codeLast.body = bottomUpPartial(xtee) { case `xtor` => codeRet } - code + // All usages of BVs that come from the xtee + val lookFor = collectAllBVs(code).filter((es.ex._1.values ++ es.ex._3.values).toSet contains) + + findPos(xtee, lookFor) match { + case Some(pos) => + code.last.body = pos.body + pos.body = code + xtee + case None => + code.last.body = xtee + code + } + } - case _ => bottomUpPartial(xtee) { case `xtor` => code } + xtor match { + // Find what the return value of `xtor` matched in `xtee` and + // replace that by the return value of `code`. + case xtorLB: LetBinding => + val xtorLast = xtorLB.last + xtorLast.body match { + case xtorRet: BoundVal => code match { + case codeLB: LetBinding => + val codeLast = codeLB.last + val codeRet = codeLast.body + val bv = ctx(xtorRet) + bottomUpPartial(xtee) { case `bv` => codeRet } match { + case xteeLB: LetBinding => mergeLBs(codeLB, xteeLB)(es) + case r => + codeLast.body = r + code + } + + case _ => + val bv = ctx(xtorRet) + bottomUpPartial(xtee) { case `bv` => code } + } + + case _ => code + } + + case _ => code match { + case codeLB: LetBinding => + val codeLast = codeLB.last + val codeRet = codeLast.body + bottomUpPartial(xtee) { case `xtor` => codeRet } match { + case xteeLB: LetBinding => mergeLBs(codeLB, xteeLB)(es) + case r => + codeLast.body = r + code + } + + case _ => bottomUpPartial(xtee) { case `xtor` => code } + } } } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index 11fe015d..d6cc3875 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -195,6 +195,18 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(b =~= ir"val t = readInt; val a = 40.toDouble; a + 1") } + test("Rewriting should insert the generated code at the right position") { + val a = ir"val f1 = (x: Int) => x + 1; val f2 = (x: Int) => x * 2; f1(f2(5))" rewrite { + case ir"($f1: Int => Int)(($f2: Int => Int)(5))" => ir"($f1 compose $f2)(5)" + } + assert(a =~= ir"val f1 = (x: Int) => x + 1; val f2 = (x: Int) => x * 2; val f3 = f1 compose f2; f3(5)") + + val b = ir"List(1,2,3).map(_ * 2).map(_ * 4)" rewrite { + case ir"($l: List[Int]).map($f1: Int => Int).map($f2: Int => Int)" => ir"$l.map($f1 compose $f2)" + } + assert(b =~= ir"val l = List(1, 2, 3); val f1 = (x: Int) => x * 2; val cbf1 = List.canBuildFrom; val f2 = (x: Int) => x * 4; val cbf2 = List.canBuildFrom; val f3 = f1 compose f2; l.map(f3)") + } + test("Squid paper") { // 3.4 From 0b58dabc9a406b4bf07abc61bf483658921a2135 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 18 Dec 2017 15:38:43 +0100 Subject: [PATCH 61/66] Add HOPV test, documentation, and cleaning --- src/main/scala/squid/ir/fastanf/FastANF.scala | 140 ++++++++++-------- .../scala/squid/ir/fastir/HOPVTests.scala | 5 + 2 files changed, 85 insertions(+), 60 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index d056e7be..9bf1506d 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -27,7 +27,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // * --- * --- * --- * Reification * --- * --- * --- * var scopes: List[ReificationContext] = Nil - + + // TODO what happens here? It looks like `finalize` doesn't anything but run the thunk `r`. But if I remove `scp` it doesn't work @inline final def wrap(r: => Rep, inXtor: Bool): Rep = { val scp = new ReificationContext(inXtor) scopes ::= scp @@ -122,7 +123,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => MethodApp(self |> inlineBlock, mtd, targs, argss |> toArgumentLists, tp) |> letbind } def byName(mkArg: => Rep): Rep = ByName(wrapNest(mkArg)) - + + /** + * Let-bind `d` and add it the current reification scope. + */ def letbind(d: Def): Rep = currentScope += d /** @@ -156,6 +160,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi //} //bound rebind s //body + case lb: LetBinding => // conceptually, does like `inlineBlock`, but additionally rewrites `bound` and renames `lb`'s last binding val last = lb.last @@ -164,25 +169,20 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi last.body = body last.name = boundName // TODO make sure we're only renaming an automatically-named binding? lb - // case c: Constant => bottomUpPartial(body) { case `bound` => c } - case h: Hole => - //Wrap construct? How? - - // letin(x, Hole, Constant(20)) => `val tmp = defHole; 20;` + case h: Hole => val dh = DefHole(h) |> letbind - - //(dh |>? { - // case bv: BoundVal => bv.owner |>? { - // case lb: LetBinding => - // lb.body = body - // lb - // } - //}).flatten.getOrElse(body) - - //new LetBinding(bound.name, bound, dh, body) alsoApply (currentScope += _) alsoApply (bound.rebind) withSubs(bound -> dh)(body) + //(dh |>? { + // case bv: BoundVal => bv.owner |>? { + // case lb: LetBinding => + // lb.body = body + // lb + // } + //}).flatten.getOrElse(body) + + //new LetBinding(bound.name, bound, dh, body) alsoApply (currentScope += _) alsoApply (bound.rebind) case (_:HOPHole) | (_:HOPHole2) | (_:SplicedHole) => ??? // TODO holes should probably be Def's; note that it's not safe to do a substitution for holes @@ -193,6 +193,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } var curSub: Map[Symbol,Rep] = Map.empty + + /** + * Substitutes the [[Symbol]]s in `k` with the based on the mappings in [[curSub]] and `subs`. + */ def withSubs[R](subs: Symbol -> Rep)(k: => R): R = { val oldSub = curSub curSub += subs @@ -220,9 +224,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => None } } - - // /** Artifact of a term extraction: map from hole name to terms, types and spliced term lists */ - + def repEq(a: Rep, b: Rep): Boolean = (a extractRep b) === Some(EmptyExtract) && (b extractRep a) === Some(EmptyExtract) @@ -316,7 +318,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case ByName(r) => ByName(transformRep0(r)) case Ascribe(s, t) => Ascribe(transformRep0(s), t) case Module(p, n, t) => Module(transformRep0(p), n, t) - case r @ ((_:Constant) | (_:Hole) | (_:Symbol) | (_:SplicedHole) | (_:HOPHole) | (_:HOPHole2) | (_:NewObject) | (_:StaticModule)) => r + case r => r }) } @@ -325,11 +327,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).toOption map (_.ex) } + // Context is mapping from xtor BVs to xtee BVs type Ctx = Map[BoundVal, BoundVal] - def reverse[A, B](m: Map[A, B]): Map[B, Set[A]] = m.groupBy(_._2).mapValues(_.keys.toSet) - def updateWith(m: Map[BoundVal, Set[BoundVal]])(u: (BoundVal, BoundVal)): Map[BoundVal, Set[BoundVal]] = u match { - case (k, v) => m + (k -> (m(k) + v)) - } // * --- * --- * --- * Extraction State * --- * --- * --- * @@ -378,6 +377,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) } + + // Add the entry `u` to the immutable multimap `m` + private def updateWith(m: Map[BoundVal, Set[BoundVal]])(u: (BoundVal, BoundVal)): Map[BoundVal, Set[BoundVal]] = u match { + case (k, v) => m + (k -> (m.getOrElse(k, Set.empty) + v)) + } } object State { @@ -386,6 +390,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi private def apply(xtor: Rep, xtee: Rep, strategy: Strategy): State = State(EmptyExtract, ListMap.empty, Instructions(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), strategy) + // \_ Assumed by `extractWithState` } @@ -426,7 +431,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi /** * Instructs the extraction has to look for a matching statements. * This instruction is attached to impure statements as well as pure statements - * that are not used by other statements. + * that are not used by other statements. Giving this instructions to those two cases + * means that the unused statements are handled the same way. So, even though they are pure, + * unused statements will be matched in order. */ case object Start extends Instruction @@ -552,8 +559,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * Assumes the bv is already in the context. * This is enforced by how the instructions are chosen. */ - val replace = es.ctx(bv) - replace rebind arg + es.ctx(bv) rebind arg body -> es case lb: LetBinding => @@ -564,8 +570,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def replaceAllOccurrences(body: Rep)(es: State): Rep -> State = { def replaceAllOccurrences0(body: Rep)(implicit es: State): Rep -> State = { def filterLBs(r: Rep)(p: LetBinding => Boolean): Rep = r match { - case lb: LetBinding if p(lb) => - filterLBs(lb.body)(p) + case lb: LetBinding if p(lb) => filterLBs(lb.body)(p) case lb: LetBinding => lb.body = filterLBs(lb.body)(p) lb @@ -573,8 +578,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } /* - * Extracts the function body with the xtor in order to be able to use the context `ctx` to know - * what to replace with the new argument. + * Replaces every occurrence of `lb` in the body with `arg`. */ extractWithState(lb, body) map { es0 => val replace = es0.ctx(lb.last.bound) @@ -583,17 +587,22 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } getOrElse body -> es } + // We only want to extract `lb` and not rewrite it so the strategy is changed to `PartialMatching`. replaceAllOccurrences0(body)(es withInstructionsFor lb withStrategy PartialMatching) - .mapSecond(_.withDefaultInstructions.withDefaultStrategy) + .mapSecond(_.withDefaultInstructions.withDefaultStrategy) // Reset the strategy so it resume doing a `CompleteMatching`. } replaceAllOccurrences(body)(es) + + // TODO implement this when we allow holes in HOPHoles + // case Hole(n, t) => ??? case _ => bottomUpPartial(body) { case `xtor` => arg } -> es } } - _ = bottomUpPartial(newBodyAndState._1) { case bv: BoundVal if visible contains bv => return None } // TODO is too early to check? If there are more xtors left + // The body of the extracted function should not contains any references to elements of `visible`. + _ = bottomUpPartial(newBodyAndState._1) { case bv: BoundVal if visible contains bv => return None } } yield newBodyAndState match { case (func0, es0) => wrapConstruct(lambda(args, func0)) -> es0 } @@ -604,7 +613,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi m <- merge(es.ex, es1) argss0 = argss.map(_.map { - case ByName(r) => r + case ByName(r) => r // TODO is it ok to unwrap by-name arguments? case r => r }) @@ -799,19 +808,16 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def extractDefs(v1: Def, v2: Def)(implicit es: State): ExtractState = (v1, v2) match { - case (l1: Lambda, l2: Lambda) => - for { - es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) - es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound withStrategy CompleteMatching) - } yield es2 withDefaultStrategy + case (l1: Lambda, l2: Lambda) => for { + es1 <- es updateExtractWith (l1.boundType extract(l2.boundType, Covariant)) + es2 <- extractWithState(l1.body, l2.body)(es1 withCtx l1.bound -> l2.bound withStrategy CompleteMatching) + } yield es2 withDefaultStrategy case (ma1: MethodApp, ma2: MethodApp) if ma1.mtd == ma2.mtd => def targExtract(es0: State): ExtractState = - es0.updateExtractWith( - (for { - (e1, e2) <- ma1.targs zip ma2.targs - } yield e1 extract(e2, Invariant)): _* - ) + es0.updateExtractWith((for { + (e1, e2) <- ma1.targs zip ma2.targs + } yield e1 extract(e2, Invariant)): _*) def extractArgss(argss1: ArgumentLists, argss2: ArgumentLists)(implicit es: State): ExtractState = (argss1, argss2) match { case (ArgumentListCons(h1, t1), ArgumentListCons(h2, t2)) => for { @@ -828,13 +834,12 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (SplicedArgument(arg), ac: ArgumentCons) => es updateExtractWith spliceExtract(arg, Args(ac.argssList: _*)) case (SplicedArgument(arg), r: Rep) => es updateExtractWith spliceExtract(arg, Args(r)) case (SplicedArgument(_), NoArguments) => Right(es) - - case (r1: Rep, r2: Rep) => extractWithState(r1, r2) - + case (NoArguments, NoArguments) => Right(es) - case (NoArgumentLists, NoArgumentLists) => Right(es) + case (r1: Rep, r2: Rep) => extractWithState(r1, r2) + case _ => Left(es) } @@ -874,14 +879,17 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * */ (es.instructions.get(lb1.bound), es.instructions.get(lb2.bound)) match { + /** * The traversal of the code is done externally by [[transformRep()]]. * Hence, if the current statements don't have to be extracted at this point (both are pure and not return values) * we simply skip the extraction of the current `xtee`. */ case (Skip, Skip) => Left(es) + case _ => extractWithState(lb1, lb2) } + case _ => extractWithState(xtor, xtee) } } @@ -934,15 +942,14 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } } + def reverse[A, B](m: Map[A, B]): Map[B, Set[A]] = m.groupBy(_._2).mapValues(_.keys.toSet) + val invCtx = reverse(es.ctx) - (ex._1.values ++ ex._3.values.flatten).forall(preCheckRep(Set.empty, invCtx, _)) + ex._1.values ++ ex._3.values.flatten forall (preCheckRep(Set.empty, invCtx, _)) } def collectBVs(d: Def): Set[BoundVal] = d match { - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft(ListSet.empty[BoundVal]) { - case (acc, bv: BoundVal) => acc + bv - case (acc, _) => acc - } + case ma: MethodApp => ma.self :: ma.argss.argssList collect { case bv: BoundVal => bv } toSet case _ => Set.empty } @@ -1000,6 +1007,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * Merges the generated `code` with the `xtee` */ def merge(code: Rep, xtee: Rep)(xtor: Rep, ctx: Ctx): Rep = { + /** * Puts `code` in the right position in `xtee`. */ @@ -1012,6 +1020,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi collectAllBVs0(r, Set.empty) } + /** + * Returns the statement in `r` at which point all the BVs from `lookFor` have been declared, + * if it exists. + */ def findPos(r: LetBinding, lookFor: Set[BoundVal]): Option[LetBinding] = { if (lookFor.isEmpty) None else r match { @@ -1027,7 +1039,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } // All usages of BVs that come from the xtee - val lookFor = collectAllBVs(code).filter((es.ex._1.values ++ es.ex._3.values).toSet contains) + val lookFor = collectAllBVs(code).filter((es.ex._1.values ++ es.ex._3.values.flatten).toSet contains) findPos(xtee, lookFor) match { case Some(pos) => @@ -1062,7 +1074,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val bv = ctx(xtorRet) bottomUpPartial(xtee) { case `bv` => code } } - case _ => code } @@ -1084,9 +1095,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if (preCheck(es.ex)) for { code <- code(es.ex) alsoApply(c => println(s"CODE: $c")) - code0 = merge(code, xtee)(xtor, es.ctx) alsoApply (c => println(s"CODE0: $c")) - cleanCode <- filterNot(es.matchedImpureBVs)(code0) - } yield cleanCode + mergedCode = merge(code, xtee)(xtor, es.ctx) alsoApply (c => println(s"CODE0: $c")) + finalCode <- filterNot(es.matchedImpureBVs)(mergedCode) + } yield finalCode else None } @@ -1150,6 +1161,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi class ReificationContext(val inExtractor: Bool) { reif => var firstLet: FlatOpt[LetBinding] = Non var curLet: FlatOpt[LetBinding] = Non + + /** + * Updates the current let-binding with `lb`. + */ def += (lb: LetBinding): Unit = { curLet match { case Non => firstLet = lb.som @@ -1157,10 +1172,15 @@ class ReificationContext(val inExtractor: Bool) { reif => } curLet = lb.som } + + /** + * Let-binds `d` and updates the current let-binding with it. + */ def += (d: Def): Symbol = new Symbol { protected var _parent: SymbolParent = new LetBinding("tmp", this, d, this) alsoApply (reif += _) } - def finalize(r: Rep) = { + + def finalize(r: Rep): Rep = { firstLet match { case Non => assert(curLet.isEmpty) diff --git a/src/test/scala/squid/ir/fastir/HOPVTests.scala b/src/test/scala/squid/ir/fastir/HOPVTests.scala index dfd818ba..ca53763d 100644 --- a/src/test/scala/squid/ir/fastir/HOPVTests.scala +++ b/src/test/scala/squid/ir/fastir/HOPVTests.scala @@ -100,6 +100,11 @@ class HOPVTests extends MyFunSuiteBase(HOPVTests.Embedding) { case ir"$foldLeft(List(1, 2, 3), 0, (acc: Int, x: Int) => acc + x)" => assert(foldLeft =~= ir"(l: List[Int], z: Int, f: (Int, Int) => Int) => l.foldLeft(z)(f)") } + + ir"List(0,1,2,3).foldLeft(0)((acc,x) => acc+x)" match { + case ir"$foldLeft(List(0, 1, 2, 3), 0, (acc: Int, x: Int) => acc + x)" => + assert(foldLeft =~= ir"(l: List[Int], z: Int, f: (Int, Int) => Int) => l.foldLeft(z)(f)") + } } test("Match letbindings") { From f8ef360089ef8a3af529c65cc50c6781848b0bb1 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Mon, 18 Dec 2017 22:13:24 +0100 Subject: [PATCH 62/66] Remove catch-all cases --- src/main/scala/squid/ir/fastanf/FastANF.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 9bf1506d..c92854fc 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -318,7 +318,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case ByName(r) => ByName(transformRep0(r)) case Ascribe(s, t) => Ascribe(transformRep0(s), t) case Module(p, n, t) => Module(transformRep0(p), n, t) - case r => r + case r @ (_: Constant | _: Hole | _: Symbol | _: SplicedHole | _: HOPHole | _: HOPHole2 | _: NewObject | _: StaticModule) => r }) } @@ -611,10 +611,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi val maybeES = for { es1 <- typ extract (xtee.typ, Covariant) m <- merge(es.ex, es1) - + + /** + * The arguments of HOPHoles are _always_ passed by-name + */ argss0 = argss.map(_.map { - case ByName(r) => r // TODO is it ok to unwrap by-name arguments? - case r => r + case ByName(r) => r + case _ => die // All HOPHole args have to be by-name! }) (f, es2) <- argss0.foldRight(Option(xtee -> (es withNewExtract m)))(buildFunc) From 02f8c62b692a035019c5c38f9494a815b482effb Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 19 Dec 2017 13:23:40 +0100 Subject: [PATCH 63/66] Simplify handling of failed matches --- src/main/scala/squid/ir/fastanf/FastANF.scala | 98 +++++++++++-------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index c92854fc..3748c8bd 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -333,9 +333,29 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // * --- * --- * --- * Extraction State * --- * --- * --- * /** - * Signals if the current extraction attempt has failed. `Left` it has failed, `Right` it succeeded. + * Mapping of failed matches between xtor BVs and xtee BVs. */ - type ExtractState = Either[State, State] + type Failed = Map[BoundVal, Set[BoundVal]] + + /** + * Signals if the current extraction attempt has failed. `Left` returns all the failed matchings at this point, + * `Right` the update state after a successful extraction. + */ + type ExtractState = Either[Failed, State] + + val fail = Left(Map.empty[BoundVal, Set[BoundVal]]) + + def failWith(currentFailed: Failed)(failed: (BoundVal, BoundVal)): ExtractState = { + // Add the entry `u` to the immutable multimap `m` + def updateWith(m: Failed)(u: (BoundVal, BoundVal)): Failed = u match { + case (k, v) => m + (k -> (m.getOrElse(k, Set.empty) + v)) + } + + Left(updateWith(currentFailed)(failed)) + } + + + implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right /** @@ -344,11 +364,11 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * @param ctx discovered matchings between bound values in the `xtor` and the `xtee`. * @param instructions what has to be done for each let-binding * @param matchedImpureBVs impure statements that have been matched - * @param failedMatches statements that do not match + * @param failed statements that do not match * @param strategy see [[Strategy]] */ case class State(ex: Extract, ctx: Ctx, instructions: Instructions, matchedImpureBVs: Set[BoundVal], - failedMatches: Map[BoundVal, Set[BoundVal]], strategy: Strategy) { + failed: Failed, strategy: Strategy) { private val _strategy = strategy private val _instructions = instructions @@ -366,7 +386,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def withMatchedImpure(bv: BoundVal): State = copy(matchedImpureBVs = matchedImpureBVs + bv) - def withFailedMatch(p: (BoundVal, BoundVal)): State = copy(failedMatches = updateWith(failedMatches)(p)) + def withFailed(failed: Failed): State = copy(failed = failed) def withStrategy(s: Strategy): State = copy(strategy = s) def withDefaultStrategy: State = copy(strategy = _strategy) @@ -375,12 +395,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withDefaultInstructions: State = copy(instructions = _instructions) def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { - mergeAll(Some(ex) +: e).fold[ExtractState](Left(default))(ex => Right(this withNewExtract ex)) - } - - // Add the entry `u` to the immutable multimap `m` - private def updateWith(m: Map[BoundVal, Set[BoundVal]])(u: (BoundVal, BoundVal)): Map[BoundVal, Set[BoundVal]] = u match { - case (k, v) => m + (k -> (m.getOrElse(k, Set.empty) + v)) + mergeAll(Some(ex) +: e).fold[ExtractState](fail)(ex => Right(this withNewExtract ex)) } } @@ -389,8 +404,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def forRewriting(xtor: Rep, xtee: Rep): State = apply(xtor, xtee, PartialMatching) private def apply(xtor: Rep, xtee: Rep, strategy: Strategy): State = - State(EmptyExtract, ListMap.empty, Instructions(xtor, xtee), Set.empty, Map.empty.withDefaultValue(Set.empty), strategy) - // \_ Assumed by `extractWithState` + State(EmptyExtract, ListMap.empty, Instructions(xtor, xtee), Set.empty, Map.empty, strategy) } @@ -625,7 +639,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi if hasNoUndeclaredUsages(f) } yield es2 updateExtractWith Some(repExtract(name -> f)) - maybeES getOrElse Left(es) + maybeES getOrElse fail } def extractLBs(lb1: LetBinding, lb2: LetBinding)(implicit es: State): ExtractState = { @@ -647,9 +661,9 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * \------Skip------/ */ case (Start, Skip) => for { - es1 <- extractAndContinue(lb1, lb2).left - es2 <- extractWithState(lb1, lb2.body)(es1).left - } yield es2 + failed1 <- extractAndContinue(lb1, lb2).left + failed2 <- extractWithState(lb1, lb2.body)(es withFailed failed1).left + } yield failed2 case (Skip, Start) => extractWithState(lb1.body, lb2) @@ -671,13 +685,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * It won't extract inside the [[ByName]] arguments. */ def extractInside(r: Rep, d: Def)(implicit es: State): ExtractState = d match { - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft[ExtractState](Left(es)) { case (acc, arg) => + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft[ExtractState](Left(es.failed)) { case (acc, arg) => for { - es1 <- acc.left - es2 <- extractWithState(r, arg)(es1).left - } yield es2 + failed1 <- acc.left + failed2 <- extractWithState(r, arg)(es withFailed failed1).left + } yield failed2 } - case _ => Left(es) + case _ => fail } @@ -688,7 +702,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi xtor -> xtee match { case (h: Hole, _) => extractedBy(h) match { case Some(`xtee`) => Right(es) - case Some(_) => Left(es) // Something has gone wrong + case Some(_) => fail // Something has gone wrong case None => extractHole(h, xtee) } @@ -697,7 +711,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (lb1: LetBinding, lb2: LetBinding) => extractLBs(lb1, lb2) case (lb: LetBinding, _) => es.instructions.get(lb.bound) match { - case Start => Left(es) // The `xtor` has more impure statements than the `xtee`. + case Start => fail // The `xtor` has more impure statements than the `xtee`. case Skip => extractWithState(lb.body, xtee) } @@ -707,7 +721,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case CompleteMatching => es.instructions.get(lb.bound) match { // The `xtee` has more impure statements or dead-ends than the `xtor` - case Start => Left(es) + case Start => fail // The skipped statements of the `xtee` will be matched later // when extracting its return value or the `Start` ones. @@ -719,26 +733,26 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // 2. The components of the let-bindings value // 3. (1. and then 2.) on the next statement case PartialMatching => for { - es1 <- extractWithState(bv, lb.bound).left - es2 <- extractInside(bv, lb.value)(es1).left - es3 <- extractWithState(bv, lb.body)(es2).left - } yield es3 + failed1 <- extractWithState(bv, lb.bound).left + failed2 <- extractInside(bv, lb.value)(es withFailed failed1).left + failed3 <- extractWithState(bv, lb.body)(es withFailed failed2).left + } yield failed3 } case (Constant(()), _: LetBinding) => es.strategy match { // Unit return doesn't have to matched case PartialMatching => Right(es) - case CompleteMatching => Left(es) + case CompleteMatching => fail } case (_: Rep, lb: LetBinding) => es.strategy match { case PartialMatching => for { - es1 <- extractInside(xtor, lb.value).left - es2 <- extractWithState(xtor, lb.body)(es1).left - } yield es2 + failed1 <- extractInside(xtor, lb.value).left + failed2 <- extractWithState(xtor, lb.body)(es withFailed failed1).left + } yield failed2 - case CompleteMatching => Left(es) + case CompleteMatching => fail } // Only a [[ByName]] can extract a [[ByName]] so that @@ -757,23 +771,23 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // The actual extraction happens here. case (bv1: BoundVal, bv2: BoundVal) => es.ctx.get(bv1) map { extractedByBV1 => if (extractedByBV1 == bv2) Right(es) - else Left(es) + else fail } getOrElse { if (bv1 == bv2) Right(es) - else if (es.failedMatches(bv1) contains bv2) Left(es) // Previously failed to extract `bv1` with `bv2` + else if (es.failed.getOrElse(bv1, Set.empty) contains bv2) fail // Previously failed to extract `bv1` with `bv2` else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { case Right(es) => effect(lb2.value) match { case Pure => Right(es withCtx lb1.bound -> lb2.bound) case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpure lb2.bound) } - case Left(es) => Left(es withFailedMatch lb1.bound -> lb2.bound) + case Left(failed) => failWith(failed)(lb1.bound -> lb2.bound) } case (l1: Lambda, l2: Lambda) => // We cannot know the owner of the [[Lambda]] // so in case of a failure to extract there's nothing to do. extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) - case _ => Left(es withFailedMatch bv1 -> bv2) + case _ => failWith(es.failed)(bv1 -> bv2) } } @@ -787,7 +801,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (NewObject(t1), NewObject(t2)) => es updateExtractWith (t1 extract(t2, Covariant)) - case _ => Left(es) + case _ => fail } } alsoApply (res => println(s"Extract: $res")) @@ -843,7 +857,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (r1: Rep, r2: Rep) => extractWithState(r1, r2) - case _ => Left(es) + case _ => fail } for { @@ -856,7 +870,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // Assuming a [[DefHole]] only extracts impure statements case (DefHole(h), _) if !isPure(v2) => extractWithState(h, wrapConstruct(letbind(v2))) - case _ => Left(es) + case _ => fail } override def rewriteRep(xtor: Rep, xtee: Rep, code: Extract => Option[Rep]): Option[Rep] = @@ -888,7 +902,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * Hence, if the current statements don't have to be extracted at this point (both are pure and not return values) * we simply skip the extraction of the current `xtee`. */ - case (Skip, Skip) => Left(es) + case (Skip, Skip) => fail case _ => extractWithState(lb1, lb2) } From 6c5360ca5808412c591fbff88efca22c97a2e508 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Tue, 19 Dec 2017 14:06:49 +0100 Subject: [PATCH 64/66] Add HOP test and cleanup --- src/main/scala/squid/ir/fastanf/FastANF.scala | 11 +++-- src/main/scala/squid/ir/fastanf/Rep.scala | 41 +++++++++++-------- .../scala/squid/ir/fastir/HOPVTests.scala | 5 +++ .../squid/ir/fastir/RewritingTests.scala | 4 +- 4 files changed, 39 insertions(+), 22 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 3748c8bd..275fda06 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -28,7 +28,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi var scopes: List[ReificationContext] = Nil - // TODO what happens here? It looks like `finalize` doesn't anything but run the thunk `r`. But if I remove `scp` it doesn't work + /** + * Runs the thunk `r`, appends it to the the last let-binding in the reification context, + * and returns the first let-binding. + */ @inline final def wrap(r: => Rep, inXtor: Bool): Rep = { val scp = new ReificationContext(inXtor) scopes ::= scp @@ -354,8 +357,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi Left(updateWith(currentFailed)(failed)) } - - implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right /** @@ -978,11 +979,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * For instance in {{{filterNot(ir"val a = readInt; a", Set(a)}}} */ def filterNot(remove: Set[BoundVal])(r: Rep): Option[Rep] = { + /** * Returns the statements to remove based on the initial `remove` set. * * Adds all the pure statements referencing removed ones. - * TODO add mathematical term + * In essence, computes the the transitive closures of pure statements that depend on the BVs in `remove` + * and adds it to `remove`. */ def buildToRemove(r: Rep, remove: Set[BoundVal]): Set[BoundVal] = r match { case lb: LetBinding if remove contains lb.bound => buildToRemove(lb.body, remove) diff --git a/src/main/scala/squid/ir/fastanf/Rep.scala b/src/main/scala/squid/ir/fastanf/Rep.scala index ab7eb199..64ce9f23 100644 --- a/src/main/scala/squid/ir/fastanf/Rep.scala +++ b/src/main/scala/squid/ir/fastanf/Rep.scala @@ -48,6 +48,10 @@ sealed abstract class Rep extends RepOption with ArgumentList with FlatSom[Rep] def argssList: List[Rep] = this :: Nil } +// TODO check that it only contains args +/** + * Wraps _arguments_ to make them by-name arguments. + */ final case class ByName(r: Rep) extends Rep { val typ: TypeRep = r.typ } @@ -115,16 +119,16 @@ object MethodApp { def go(args: ArgumentList)(k: (Option[LetBinding], ArgumentList) => (Option[LetBinding], ArgumentList)): (Option[LetBinding], ArgumentList) = args match { case NoArguments => k(None, NoArguments) case ArgumentCons(h, t) => - val (lb, ret) = f(h) + val (maybeLB, r) = f(h) go(t) { case (lb0, args0) => - val newLB = (lb0, lb) match { + val maybeNewLB = (lb0, maybeLB) match { case (Some(lb0), Some(lb)) => lb.last.body = lb0; Some(lb) case (_, Some(lb)) => Some(lb) case (Some(lb0), _) => Some(lb0) case _ => None } - k(newLB, ArgumentCons(ret, args0)) + k(maybeNewLB, ArgumentCons(r, args0)) } case SplicedArgument(arg) => k.tupled(f(arg)) case r: Rep => k.tupled(f(r)) @@ -154,33 +158,38 @@ object MethodApp { go(argss)((lb, argss) => (lb, argss)) } - + + /** + * Splits `r` at before the point of its return value + */ def split(r: Rep): Option[LetBinding] -> Rep = r match { case lb: LetBinding => Some(lb) -> lb.last.body case _ => None -> r } - val (lb0, self0) = self |> split - val (lbFromArgss, argss0) = processArgss(argss)(split) alsoApply println + val (maybeSelfLB, selfBV) = self |> split + val (maybeArgssLB, argssBVs) = processArgss(argss)(split) - val newLB = (lb0, lbFromArgss) match { - case (Some(lb0), Some(lbFromArgss)) => lbFromArgss.last.body = lb0; Some(lbFromArgss) + // Merge `maybeSelfLB` and `maybeArgssLB` + val maybeNewLB = (maybeSelfLB, maybeArgssLB) match { + case (Some(selfLB), Some(lbFromArgss)) => lbFromArgss.last.body = selfLB; Some(lbFromArgss) case (_, Some(lbFromArgss)) => Some(lbFromArgss) - case (Some(lb0), _) => Some(lb0) + case (Some(selfLB), _) => Some(selfLB) case _ => None } - val ma = MethodApp(self0, mtd, targs, argss0, typ0) + val maWithBVs = MethodApp(selfBV, mtd, targs, argssBVs, typ0) - newLB match { - case Some(lb) => - lb.last.body = new Symbol { - protected var _parent: SymbolParent = new LetBinding("tmp", this, ma, this) + // Set `maWithBVs` as the body of `maybeNewLB` + maybeNewLB match { + case Some(newLB) => + newLB.last.body = new Symbol { + protected var _parent: SymbolParent = new LetBinding("tmp", this, maWithBVs, this) }.owner.asInstanceOf[LetBinding] - lb + newLB case None => new Symbol { - protected var _parent: SymbolParent = new LetBinding("tmp", this, ma, this) + protected var _parent: SymbolParent = new LetBinding("tmp", this, maWithBVs, this) }.owner.asInstanceOf[LetBinding] } } diff --git a/src/test/scala/squid/ir/fastir/HOPVTests.scala b/src/test/scala/squid/ir/fastir/HOPVTests.scala index ca53763d..dbd01ddf 100644 --- a/src/test/scala/squid/ir/fastir/HOPVTests.scala +++ b/src/test/scala/squid/ir/fastir/HOPVTests.scala @@ -105,6 +105,11 @@ class HOPVTests extends MyFunSuiteBase(HOPVTests.Embedding) { case ir"$foldLeft(List(0, 1, 2, 3), 0, (acc: Int, x: Int) => acc + x)" => assert(foldLeft =~= ir"(l: List[Int], z: Int, f: (Int, Int) => Int) => l.foldLeft(z)(f)") } + + ir"List(0,1,2,3).foldLeft(0)((acc,x) => acc+x)" match { + case ir"$foldLeft(0, List(0, 1, 2, 3), (acc: Int, x: Int) => acc + x)" => + assert(foldLeft =~= ir"(z: Int, _: List[Int], f: (Int, Int) => Int) => List(z, 1, 2, 3).foldLeft(z)(f)") + } } test("Match letbindings") { diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index d6cc3875..c95533b3 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -202,9 +202,9 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { assert(a =~= ir"val f1 = (x: Int) => x + 1; val f2 = (x: Int) => x * 2; val f3 = f1 compose f2; f3(5)") val b = ir"List(1,2,3).map(_ * 2).map(_ * 4)" rewrite { - case ir"($l: List[Int]).map($f1: Int => Int).map($f2: Int => Int)" => ir"$l.map($f1 compose $f2)" + case ir"($l: List[Int]).map($f1: Int => Int).map($f2: Int => Int)" => ir"$l.map($f2 compose $f1)" } - assert(b =~= ir"val l = List(1, 2, 3); val f1 = (x: Int) => x * 2; val cbf1 = List.canBuildFrom; val f2 = (x: Int) => x * 4; val cbf2 = List.canBuildFrom; val f3 = f1 compose f2; l.map(f3)") + assert(b =~= ir"val l = List(1, 2, 3); val f1 = (x: Int) => x * 2; val cbf1 = List.canBuildFrom; val f2 = (x: Int) => x * 4; val cbf2 = List.canBuildFrom; val f3 = f2 compose f1; l.map(f3)") } test("Squid paper") { From 4f8b85db96f830ec0ea80a352a72df55c2375264 Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Wed, 20 Dec 2017 19:03:59 +0100 Subject: [PATCH 65/66] Add helper functions --- src/main/scala/squid/ir/fastanf/FastANF.scala | 76 ++++++++++--------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 275fda06..41437f05 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -341,21 +341,16 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi type Failed = Map[BoundVal, Set[BoundVal]] /** - * Signals if the current extraction attempt has failed. `Left` returns all the failed matchings at this point, + * Signals if the current extraction attempt has failed. + * `Left` returns the matching that failed, * `Right` the update state after a successful extraction. */ - type ExtractState = Either[Failed, State] - - val fail = Left(Map.empty[BoundVal, Set[BoundVal]]) - - def failWith(currentFailed: Failed)(failed: (BoundVal, BoundVal)): ExtractState = { - // Add the entry `u` to the immutable multimap `m` - def updateWith(m: Failed)(u: (BoundVal, BoundVal)): Failed = u match { - case (k, v) => m + (k -> (m.getOrElse(k, Set.empty) + v)) - } - - Left(updateWith(currentFailed)(failed)) - } + type ExtractState = Either[Set[(BoundVal, BoundVal)], State] + + /* Helper functions for `ExtractState` */ + def succeed(implicit es: State): ExtractState = Right(es) + def fail = Left(Set.empty[(BoundVal, BoundVal)]) + def failWith(failed: Set[(BoundVal, BoundVal)]): ExtractState = Left(failed) implicit def rightBias[A, B](e: Either[A, B]): Either.RightProjection[A,B] = e.right @@ -387,7 +382,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def withMatchedImpure(bv: BoundVal): State = copy(matchedImpureBVs = matchedImpureBVs + bv) - def withFailed(failed: Failed): State = copy(failed = failed) + def withFailed(newFailed: Set[(BoundVal, BoundVal)]): State = { + val updatedFailed = newFailed.foldLeft(failed) { case (mergedF, (k, v)) => + println(s"($mergedF, ($k -> $v)") + mergedF + (k -> mergedF.get(k).map(_ + v).getOrElse(Set(v))) + } + copy(failed = updatedFailed) + } def withStrategy(s: Strategy): State = copy(strategy = s) def withDefaultStrategy: State = copy(strategy = _strategy) @@ -396,7 +397,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withDefaultInstructions: State = copy(instructions = _instructions) def updateExtractWith(e: Option[Extract]*)(implicit default: State): ExtractState = { - mergeAll(Some(ex) +: e).fold[ExtractState](fail)(ex => Right(this withNewExtract ex)) + mergeAll(Some(ex) +: e).fold[ExtractState](fail)(ex => succeed(this withNewExtract ex)) } } @@ -664,7 +665,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (Start, Skip) => for { failed1 <- extractAndContinue(lb1, lb2).left failed2 <- extractWithState(lb1, lb2.body)(es withFailed failed1).left - } yield failed2 + } yield failed1 ++ failed2 case (Skip, Start) => extractWithState(lb1.body, lb2) @@ -685,14 +686,17 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * Only extracts inside a [[MethodApp]], fails for all other cases. * It won't extract inside the [[ByName]] arguments. */ - def extractInside(r: Rep, d: Def)(implicit es: State): ExtractState = d match { - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft[ExtractState](Left(es.failed)) { case (acc, arg) => + def extractInside(r: Rep, d: Def)(implicit es: State): ExtractState = { + println(s"INSIDE: $r -> $d") + d match { + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft[ExtractState](fail) { case (acc, arg) => for { failed1 <- acc.left failed2 <- extractWithState(r, arg)(es withFailed failed1).left - } yield failed2 + } yield failed1 ++ failed2 } - case _ => fail + case _ => fail + } } @@ -702,7 +706,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi xtor -> xtee match { case (h: Hole, _) => extractedBy(h) match { - case Some(`xtee`) => Right(es) + case Some(`xtee`) => succeed case Some(_) => fail // Something has gone wrong case None => extractHole(h, xtee) } @@ -717,7 +721,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } case (bv: BoundVal, lb: LetBinding) => - if (es.ctx.keySet contains bv) Right(es) // `bv` has already been extracted + if (es.ctx.keySet contains bv) succeed // `bv` has already been extracted else es.strategy match { case CompleteMatching => es.instructions.get(lb.bound) match { @@ -736,13 +740,13 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case PartialMatching => for { failed1 <- extractWithState(bv, lb.bound).left failed2 <- extractInside(bv, lb.value)(es withFailed failed1).left - failed3 <- extractWithState(bv, lb.body)(es withFailed failed2).left - } yield failed3 + failed3 <- extractWithState(bv, lb.body)(es withFailed (failed1 ++ failed2)).left + } yield failed1 ++ failed2 ++ failed3 } case (Constant(()), _: LetBinding) => es.strategy match { // Unit return doesn't have to matched - case PartialMatching => Right(es) + case PartialMatching => succeed case CompleteMatching => fail } @@ -751,7 +755,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case PartialMatching => for { failed1 <- extractInside(xtor, lb.value).left failed2 <- extractWithState(xtor, lb.body)(es withFailed failed1).left - } yield failed2 + } yield failed1 ++ failed2 case CompleteMatching => fail } @@ -771,31 +775,31 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi // The actual extraction happens here. case (bv1: BoundVal, bv2: BoundVal) => es.ctx.get(bv1) map { extractedByBV1 => - if (extractedByBV1 == bv2) Right(es) + if (extractedByBV1 == bv2) succeed else fail } getOrElse { - if (bv1 == bv2) Right(es) + if (bv1 == bv2) succeed else if (es.failed.getOrElse(bv1, Set.empty) contains bv2) fail // Previously failed to extract `bv1` with `bv2` else (bv1.owner, bv2.owner) match { case (lb1: LetBinding, lb2: LetBinding) => extractDefs(lb1.value, lb2.value) match { case Right(es) => effect(lb2.value) match { - case Pure => Right(es withCtx lb1.bound -> lb2.bound) - case Impure => Right(es withCtx lb1.bound -> lb2.bound withMatchedImpure lb2.bound) + case Pure => succeed(es withCtx lb1.bound -> lb2.bound) + case Impure => succeed(es withCtx lb1.bound -> lb2.bound withMatchedImpure lb2.bound) } - case Left(failed) => failWith(failed)(lb1.bound -> lb2.bound) + case Left(failed) => failWith(failed + (lb1.bound -> lb2.bound)) } case (l1: Lambda, l2: Lambda) => // We cannot know the owner of the [[Lambda]] // so in case of a failure to extract there's nothing to do. extractDefs(l1, l2) map (_ withCtx l1.bound -> l2.bound) - case _ => failWith(es.failed)(bv1 -> bv2) + case _ => failWith(Set(bv1 -> bv2)) } } case (Constant(v1), Constant(v2)) if v1 == v2 => es updateExtractWith (xtor.typ extract(xtee.typ, Covariant)) // Assuming if they have the same name the type is the same - case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => Right(es) + case (StaticModule(fn1), StaticModule(fn2)) if fn1 == fn2 => succeed // Assuming if they have the same name and prefix the type is the same case (Module(p1, n1, _), Module(p2, n2, _)) if n1 == n2 => extractWithState(p1, p2) @@ -851,10 +855,10 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case (SplicedArgument(arg1), SplicedArgument(arg2)) => extractWithState(arg1, arg2) case (SplicedArgument(arg), ac: ArgumentCons) => es updateExtractWith spliceExtract(arg, Args(ac.argssList: _*)) case (SplicedArgument(arg), r: Rep) => es updateExtractWith spliceExtract(arg, Args(r)) - case (SplicedArgument(_), NoArguments) => Right(es) + case (SplicedArgument(_), NoArguments) => succeed - case (NoArguments, NoArguments) => Right(es) - case (NoArgumentLists, NoArgumentLists) => Right(es) + case (NoArguments, NoArguments) => succeed + case (NoArgumentLists, NoArgumentLists) => succeed case (r1: Rep, r2: Rep) => extractWithState(r1, r2) From 6593318d25376f9fbc9141f8523fc39e4169d58e Mon Sep 17 00:00:00 2001 From: Dennis van der Bij Date: Sun, 7 Jan 2018 11:25:41 +0100 Subject: [PATCH 66/66] Add documentation and cleanup --- src/main/scala/squid/ir/fastanf/Effects.scala | 18 ++- src/main/scala/squid/ir/fastanf/FastANF.scala | 109 +++++++----------- src/main/scala/squid/test/Test.scala | 45 -------- .../{HOPVTests.scala => HOPTests.scala} | 6 +- .../squid/ir/fastir/RewritingTests.scala | 31 +++-- 5 files changed, 79 insertions(+), 130 deletions(-) delete mode 100644 src/main/scala/squid/test/Test.scala rename src/test/scala/squid/ir/fastir/{HOPVTests.scala => HOPTests.scala} (98%) diff --git a/src/main/scala/squid/ir/fastanf/Effects.scala b/src/main/scala/squid/ir/fastanf/Effects.scala index c9941588..00e69845 100644 --- a/src/main/scala/squid/ir/fastanf/Effects.scala +++ b/src/main/scala/squid/ir/fastanf/Effects.scala @@ -4,8 +4,12 @@ import squid.utils.Bool import scala.collection.mutable +/** + * Basic effect system where statements can be `Pure` or `Impure`. + * By default everything is impure. + */ trait Effects { - protected val pureMtds = mutable.Set[MethodSymbol]() + private val pureMtds = mutable.Set[MethodSymbol]() //protected val pureTyps = mutable.Set[TypeSymbol]() def addPureMtd(m: MethodSymbol): Unit = pureMtds += m @@ -39,10 +43,7 @@ trait Effects { HOPHole(_, _, _, _) => Pure } - def mtdEffect(m: MethodSymbol): Effect = { - //println(m) - if (pureMtds contains m) Pure else Impure - } + def mtdEffect(m: MethodSymbol): Effect = if (pureMtds contains m) Pure else Impure def defEffect(d: Def): Effect = d match { case l: Lambda => effect(l.body) @@ -55,6 +56,9 @@ trait Effects { } sealed trait Effect { + /** + * Combine effects. + */ def |+|(e: Effect): Effect } @@ -67,6 +71,10 @@ trait Effects { } } +/** + * Standard effect system. + * Defines the set of pure methods used in tests. + */ trait StandardEffects extends Effects { addPureMtd(MethodSymbol(TypeSymbol("scala.Int"),"$plus")) addPureMtd(MethodSymbol(TypeSymbol("scala.Double"),"$plus")) diff --git a/src/main/scala/squid/ir/fastanf/FastANF.scala b/src/main/scala/squid/ir/fastanf/FastANF.scala index 41437f05..74b9c1c6 100644 --- a/src/main/scala/squid/ir/fastanf/FastANF.scala +++ b/src/main/scala/squid/ir/fastanf/FastANF.scala @@ -325,10 +325,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi }) } - protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = { - println(s"Extract(\n$xtor, \n$xtee)") - extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).toOption map (_.ex) - } + protected def extract(xtor: Rep, xtee: Rep): Option[Extract] = extractWithState(xtor, xtee)(State.forExtraction(xtor, xtee)).toOption map (_.ex) // Context is mapping from xtor BVs to xtee BVs type Ctx = Map[BoundVal, BoundVal] @@ -342,7 +339,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi /** * Signals if the current extraction attempt has failed. - * `Left` returns the matching that failed, + * `Left` returns the matchings that failed (xtor -> xtee), * `Right` the update state after a successful extraction. */ type ExtractState = Either[Set[(BoundVal, BoundVal)], State] @@ -384,7 +381,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi def withFailed(newFailed: Set[(BoundVal, BoundVal)]): State = { val updatedFailed = newFailed.foldLeft(failed) { case (mergedF, (k, v)) => - println(s"($mergedF, ($k -> $v)") mergedF + (k -> mergedF.get(k).map(_ + v).getOrElse(Set(v))) } copy(failed = updatedFailed) @@ -512,8 +508,6 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } def extractWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { - println(es) - def extractHOPHole(name: String, typ: TypeRep, argss: List[List[Rep]], visible: List[BoundVal])(implicit es: State): ExtractState = { def hasNoUndeclaredUsages(r: Rep): Boolean = { def hasNoUndeclaredUsages0(r: Rep, declared: Set[BoundVal]): Boolean = r match { @@ -536,11 +530,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi /** * Attemps to find the `xtors` in the body of the function and replaces them with newly generated arguments, * adding the new arguments to the function. Even if the `xtors` are not found in the body, - * arguments representing them will be generated and added to it. They will simply be ignored when it is applied. - * - * @param xtors - * @param maybeFuncAndState the current function and the extraction state after its creation. - * @return + * arguments representing them will be generated and added to it. They will simply not appear in the function's body. */ def buildFunc(xtors: List[Rep], maybeFuncAndState: Option[(Rep, State)]): Option[(Rep, State)] = { val args = xtors.map(arg => bindVal("hopArg", arg.typ, Nil)) @@ -548,7 +538,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi /** * Returns the body of function. - * This is the body of the most deeply nested [[Lambda]] as the function is curried. + * This is the body of the most deeply nested [[Lambda]] since the function is curried. */ def body(func: Rep): Rep = { def body0(d: Def): Option[Rep] = d match { @@ -596,6 +586,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi /* * Replaces every occurrence of `lb` in the body with `arg`. */ + // TODO single pass extractWithState(lb, body) map { es0 => val replace = es0.ctx(lb.last.bound) val body0 = bottomUpPartial(filterLBs(body)(es0.ctx.values.toSet contains _.bound)) { case `replace` => arg } @@ -637,6 +628,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi }) (f, es2) <- argss0.foldRight(Option(xtee -> (es withNewExtract m)))(buildFunc) + // ^ so arguments are in the right order if hasNoUndeclaredUsages(f) } yield es2 updateExtractWith Some(repExtract(name -> f)) @@ -686,28 +678,23 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * Only extracts inside a [[MethodApp]], fails for all other cases. * It won't extract inside the [[ByName]] arguments. */ - def extractInside(r: Rep, d: Def)(implicit es: State): ExtractState = { - println(s"INSIDE: $r -> $d") - d match { - case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft[ExtractState](fail) { case (acc, arg) => - for { - failed1 <- acc.left - failed2 <- extractWithState(r, arg)(es withFailed failed1).left - } yield failed1 ++ failed2 - } - case _ => fail + def extractInside(r: Rep, d: Def)(implicit es: State): ExtractState = d match { + case ma: MethodApp => (ma.self :: ma.argss.argssList).foldLeft[ExtractState](fail) { case (acc, arg) => + for { + failed1 <- acc.left + failed2 <- extractWithState(r, arg)(es withFailed failed1).left + } yield failed1 ++ failed2 } + case _ => fail } def extractedBy(h: Hole)(implicit es: State): Option[Rep] = es.ex._1 get h.name - - println(s"-----\nextractWithState: \n$xtor\n\n$xtee\n-----\n\n") xtor -> xtee match { case (h: Hole, _) => extractedBy(h) match { case Some(`xtee`) => succeed - case Some(_) => fail // Something has gone wrong + case Some(_) => die // Something has gone wrong case None => extractHole(h, xtee) } @@ -808,7 +795,7 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi case _ => fail } - } alsoApply (res => println(s"Extract: $res")) + } protected def spliceExtract(xtor: Rep, args: Args): Option[Extract] = xtor match { // Should check that type matches, but don't see how to access it for Args @@ -882,38 +869,34 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi rewriteRep0(xtor, xtee, code)(State.forRewriting(xtor, xtee)) def rewriteRep0(xtor: Rep, xtee: Rep, code: Extract => Option[Rep])(implicit es: State): Option[Rep] = { - def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = { - println(s"rewriteRepWithState(\n\t$xtor\n\t$xtee)($es)") - - (xtor, xtee) match { - case (lb1: LetBinding, lb2: LetBinding) => + def rewriteRepWithState(xtor: Rep, xtee: Rep)(implicit es: State): ExtractState = (xtor, xtee) match { + case (lb1: LetBinding, lb2: LetBinding) => + /** + * Pure statements (annotated with the instruction [[Skip]] only have to extracted starting from their + * return value and extract each sub-part recursively. Through this mechanism the order of the pure statements + * does not matter. + * For instance, this will successfully match : + * {{{ + * ir"val b = 22.toDouble; val a = 11.toDouble; a + b" match { + * case ir"val aX = 11.toDouble; val bX = 22.toDouble; aX + bX" => ??? + * } + * }}} + * + */ + (es.instructions.get(lb1.bound), es.instructions.get(lb2.bound)) match { + /** - * Pure statements (annotated with the instruction [[Skip]] only have to extracted starting from their - * return value and extract each sub-part recursively. Through this mechanism the order of the pure statements - * does not matter. - * For instance, this will successfully match : - * {{{ - * ir"val b = 22.toDouble; val a = 11.toDouble; a + b" match { - * case ir"val aX = 11.toDouble; val bX = 22.toDouble; aX + bX" => ??? - * } - * }}} - * + * The traversal of the code is done externally by [[transformRep()]]. + * Hence, if the current statements don't have to be extracted at this point (both are pure and not return values) + * we simply skip the extraction of the current `xtee`. */ - (es.instructions.get(lb1.bound), es.instructions.get(lb2.bound)) match { + case (Skip, Skip) => fail - /** - * The traversal of the code is done externally by [[transformRep()]]. - * Hence, if the current statements don't have to be extracted at this point (both are pure and not return values) - * we simply skip the extraction of the current `xtee`. - */ - case (Skip, Skip) => fail - - case _ => extractWithState(lb1, lb2) - } - - case _ => extractWithState(xtor, xtee) - } + case _ => extractWithState(lb1, lb2) + } + + case _ => extractWithState(xtor, xtee) } def genCode(implicit es: State): Option[Rep] = { @@ -927,12 +910,8 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi * case ir"val rX = readInt; $body" => ??? * } * }}} - * `$body` will extract `ir"r + 1"` where `r` <-> `rX`. Here the let-binding `val rX = readInt; ...` - * is user-defined. Therefore, on the rhs of the rewriting rule the user can still write valid code - * (e.g. `ir"val r ` - * - * TODO ask about this - * + * `$body` will extract `ir"r + 1"` where `r` <-> `rX`. Since the let-binding `val rX = readInt; ...` + * is user-defined. */ def preCheck(ex: Extract): Boolean = { def preCheckRep(declaredBVs: Set[BoundVal], invCtx: Map[BoundVal, Set[BoundVal]], r: Rep): Boolean = { @@ -1118,15 +1097,15 @@ class FastANF extends InspectableBase with CurryEncoding with StandardEffects wi } if (preCheck(es.ex)) for { - code <- code(es.ex) alsoApply(c => println(s"CODE: $c")) - mergedCode = merge(code, xtee)(xtor, es.ctx) alsoApply (c => println(s"CODE0: $c")) + code <- code(es.ex) + mergedCode = merge(code, xtee)(xtor, es.ctx) finalCode <- filterNot(es.matchedImpureBVs)(mergedCode) } yield finalCode else None } rewriteRepWithState(xtor, xtee) match { - case Right(es) => genCode(es) alsoApply(c => println(s"GEN: $c")) + case Right(es) => genCode(es) case Left(_) => None } } diff --git a/src/main/scala/squid/test/Test.scala b/src/main/scala/squid/test/Test.scala deleted file mode 100644 index 344e8744..00000000 --- a/src/main/scala/squid/test/Test.scala +++ /dev/null @@ -1,45 +0,0 @@ -package squid -package test - -import squid.ir.FastANF - -object Test extends App { - object Embedding extends FastANF - import Embedding.Predef._ - import Embedding.Quasicodes._ - - //def odd(x: Int, y: Int)(z: Int)(r: Double*): Int = x + y + z - //def foo = 42 - // - //val bla = ir"foo" - // - //val program = ir"(a: Int, b: Int, c: Int, d: Double, e: Double) => odd(a, b)($bla)(d, e)" - - //val program = dbg_ir"(y: Int) => y + 5" - // - //println(s"$program") - - //val p1 = code"(x: Int) => x + 4" - - //case class Point(x: Int, y: Int) - //val p = Point(0, 1) - // - //code{p.x} match { - // case code"p.x" => println(code"p.y") - //} - // val a = ir"val a = 10.toDouble; val b = readDouble; (a + b) * 5" - // val b = a rewrite { - // case dbg_ir"val x = 10.toDouble; val y = readDouble; $body(x + y)" => ir"$body(222)" - // } alsoApply println - - // ir"(a: Int, b: Int) => a + 1" match { - //// case ir"(x: Int, y: Int) => $body(y):Int" => - // case db_ir"(x: Int, y: Int) => $body(x):Int" => - // } - - val a = ir"(readInt + 111) * .5" rewrite { - case ir"(($n: Int) + 111) * .5" => ir"$n * .25" - } alsoApply println - println(ir"readInt * .25") - assert(a =~= ir"readInt * .25") -} diff --git a/src/test/scala/squid/ir/fastir/HOPVTests.scala b/src/test/scala/squid/ir/fastir/HOPTests.scala similarity index 98% rename from src/test/scala/squid/ir/fastir/HOPVTests.scala rename to src/test/scala/squid/ir/fastir/HOPTests.scala index dbd01ddf..fc784aaa 100644 --- a/src/test/scala/squid/ir/fastir/HOPVTests.scala +++ b/src/test/scala/squid/ir/fastir/HOPTests.scala @@ -4,8 +4,8 @@ package fastir import scala.util.Try -class HOPVTests extends MyFunSuiteBase(HOPVTests.Embedding) { - import HOPVTests.Embedding.Predef._ +class HOPTests extends MyFunSuiteBase(HOPTests.Embedding) { + import HOPTests.Embedding.Predef._ test("Matching lambda bodies") { val id = ir"(z:Int) => z" @@ -179,6 +179,6 @@ class HOPVTests extends MyFunSuiteBase(HOPVTests.Embedding) { } } -object HOPVTests { +object HOPTests { object Embedding extends FastANF } diff --git a/src/test/scala/squid/ir/fastir/RewritingTests.scala b/src/test/scala/squid/ir/fastir/RewritingTests.scala index c95533b3..a3e08c44 100644 --- a/src/test/scala/squid/ir/fastir/RewritingTests.scala +++ b/src/test/scala/squid/ir/fastir/RewritingTests.scala @@ -2,8 +2,6 @@ package squid package ir package fastir -import squid.test.Test.Embedding - class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { import RewritingTests.Embedding.Predef._ import RewritingTests.Embedding.Quasicodes._ @@ -23,14 +21,6 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { case ir"(${Const(n)}: Int).toDouble" => ir"${Const(n.toDouble)}" } assert(c =~= ir"val t = ${ir"42"}.toDouble; 42.0") - - //assertDoesNotCompile(""" - // T.rewrite { case ir"0.5" => ir"42" } - //""") - - //assertDoesNotCompile(""" - // T.rewrite { case ir"123" => ir"($$n:Int)" } - //""") } test("Rewriting subpatterns") { @@ -49,13 +39,30 @@ class RewritingTests extends MyFunSuiteBase(RewritingTests.Embedding) { case ir"($a: Int) * 5" => ir"$a * 2" } assert(c =~= ir"val t1 = Option(42); val t2 = t1.get; val t3 = ${ir"42"} * 5; ${ir"42"} * 2") + + val d = ir"val r = (() => readInt.toDouble)(); r * r" rewrite { + case ir"readInt.toDouble" => ir"readDouble" + } + assert(d =~= ir"val r = (() => readDouble)(); r * r") } test("Rewriting with dead-ends") { - val b = ir"Option(42).get; 20" rewrite { + val a = ir"Option(42).get; 20" rewrite { case ir"Option(($n: Int)).get; 20" => n } - assert(b =~= ir"val t = ${ir"Option(42)"}; 42") + assert(a =~= ir"val t = ${ir"Option(42)"}; 42") + + val b = ir"val r1 = readDouble.toInt; val a = r1 + 2; val b = r1 + 4; r1" rewrite { + case ir"readDouble.toInt" => ir"readInt" + case ir"2" => ir"4" + case ir"4" => ir"8" + } + assert(b =~= ir"val r1 = readInt; val a = r1 + 8; val b = r1 + 8; r1") + + val c = ir"val r1 = readInt; val f = (x: Int) => { val r2 = readInt; val a = x + 1; r2 }; f(r1)" rewrite { + case ir"(x: Int) => { val rX = readInt; val a = x + 1; rX }" => ir"(x: Int) => { val a = x + 22; x }" + } + assert(c =~= ir"val r1 = readInt; val f = (x: Int) => { val a = x + 22; x }; f(r1)") } test("Substitution should be called from inside a reification context") {