Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions effekt/shared/src/main/scala/effekt/context/Context.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package context

import effekt.namer.NamerOps
import effekt.typer.{TyperOps, Unification}
import effekt.core.TransformerOps
import effekt.core.{BindingDB, TransformerOps}
import effekt.source.Tree
import effekt.util.messages.{EffektMessages, ErrorReporter}
import effekt.util.Timers
Expand Down Expand Up @@ -42,12 +42,14 @@ abstract class Context
extends NamerOps
with TyperOps
with ModuleDB
with TransformerOps
with Timers {

// bring the context itself in scope
implicit val context: Context = this

// Storage for bindings
var bindingDB: BindingDB = new BindingDB

// the currently processed module
var module: Module = _

Expand Down
15 changes: 15 additions & 0 deletions effekt/shared/src/main/scala/effekt/core/BindingDB.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package effekt.core


import scala.collection.mutable.ListBuffer

/**
* Storage for bindings
*/
final class BindingDB {

/**
* A _mutable_ ListBuffer that stores all bindings to be inserted at the current scope
*/
var bindings: ListBuffer[Binding] = ListBuffer()
}
92 changes: 46 additions & 46 deletions effekt/shared/src/main/scala/effekt/core/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {

def run(input: Typechecked)(using Context) =
val Typechecked(source, tree, mod) = input
Context.initTransformerState()
TransformerOps.initTransformerState()

if (Context.messaging.hasErrors) {
None
Expand Down Expand Up @@ -90,7 +90,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {

case v @ source.DefDef(id, captures, annot, binding, doc, span) =>
val sym = v.symbol
val (definition, bindings) = Context.withBindings {
val (definition, bindings) = TransformerOps.withBindings {
Toplevel.Def(sym, transformAsBlock(binding))
}

Expand Down Expand Up @@ -186,13 +186,13 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
case v @ source.RegDef(id, _, reg, binding, doc, span) =>
val sym = v.symbol
insertBindings {
Alloc(sym, Context.bind(transform(binding)), sym.region, transform(rest))
Alloc(sym, TransformerOps.bind(transform(binding)), sym.region, transform(rest))
}

case v @ source.VarDef(id, _, binding, doc, span) =>
val sym = v.symbol
insertBindings {
Var(sym, Context.bind(transform(binding)), sym.capture, transform(rest))
Var(sym, TransformerOps.bind(transform(binding)), sym.capture, transform(rest))
}

case d: source.Def.Extern => Context.panic("Only allowed on the toplevel")
Expand Down Expand Up @@ -321,7 +321,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
val tpe = TState.extractType(stateType)
val stateId = Id("s")
// emits `let s = !ref; return s`
Context.bind(Get(stateId, transform(tpe), sym, transform(Context.captureOf(sym)), Return(core.ValueVar(stateId, transform(tpe)))))
TransformerOps.bind(Get(stateId, transform(tpe), sym, transform(Context.captureOf(sym)), Return(core.ValueVar(stateId, transform(tpe)))))
case sym: ValueSymbol => ValueVar(sym)
case sym: BlockSymbol => transformBox(tree)
}
Expand Down Expand Up @@ -356,7 +356,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
val tpe = transform(substitution.substitute(f.returnType))
core.ValueParam(if f == field then selected else Id("_"), tpe)
}
Context.bind(Stmt.Match(transformAsExpr(receiver),
TransformerOps.bind(Stmt.Match(transformAsExpr(receiver),
List((constructor, BlockLit(Nil, Nil, params, Nil, Stmt.Return(Expr.ValueVar(selected, tpe))))), None))

case source.Box(capt, block, _) =>
Expand All @@ -373,12 +373,12 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {

case source.If(List(MatchGuard.BooleanGuard(cond, _)), thn, els, _) =>
val c = transformAsExpr(cond)
Context.bind(If(c, transform(thn), transform(els)))
TransformerOps.bind(If(c, transform(thn), transform(els)))

case source.If(guards, thn, els, _) =>
val thnClause = preprocess("thn", Nil, guards, transform(thn))
val elsClause = preprocess("els", Nil, Nil, transform(els))
Context.bind(PatternMatchingCompiler.compile(List(thnClause, elsClause)))
TransformerOps.bind(PatternMatchingCompiler.compile(List(thnClause, elsClause)))

// case i @ source.If(guards, thn, els) =>
// val compiled = collectClauses(i)
Expand Down Expand Up @@ -412,22 +412,22 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
}
}

Context.bind(loopName, Block.BlockLit(Nil, Nil, Nil, Nil, loopBody))
TransformerOps.bind(loopName, Block.BlockLit(Nil, Nil, Nil, Nil, loopBody))

Context.bind(loopCall)
TransformerOps.bind(loopCall)

// Empty match (matching on Nothing)
case source.Match(List(sc), Nil, None, _) =>
val scrutinee: ValueVar = Context.bind(transformAsExpr(sc))
Context.bind(core.Match(scrutinee, Nil, None))
val scrutinee: ValueVar = TransformerOps.bind(transformAsExpr(sc))
TransformerOps.bind(core.Match(scrutinee, Nil, None))

case source.Match(scs, cs, default, _) =>
// (1) Bind scrutinee and all clauses so we do not have to deal with sharing on demand.
val scrutinees: List[ValueVar] = scs.map{ sc => Context.bind(transformAsExpr(sc)) }
val scrutinees: List[ValueVar] = scs.map{ sc => TransformerOps.bind(transformAsExpr(sc)) }
val clauses = cs.zipWithIndex.map((c, i) => preprocess(s"k${i}", scrutinees, c))
val defaultClause = default.map(stmt => preprocess("k_els", Nil, Nil, transform(stmt))).toList
val compiledMatch = PatternMatchingCompiler.compile(clauses ++ defaultClause)
Context.bind(compiledMatch)
TransformerOps.bind(compiledMatch)

case source.TryHandle(prog, handlers, _) =>

Expand All @@ -451,21 +451,21 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
val body: BlockLit = BlockLit(Nil, List(promptCapt), Nil, List(promptParam),
Binding(transformedHandlers, transform(prog)))

Context.bind(Reset(body))
TransformerOps.bind(Reset(body))

case r @ source.Region(name, body, _) =>
val region = r.symbol
val tpe = Context.blockTypeOf(region)
val cap: core.BlockParam = core.BlockParam(region, transform(tpe), Set(region.capture))
Context.bind(Region(BlockLit(Nil, List(region.capture), Nil, List(cap), transform(body))))
TransformerOps.bind(Region(BlockLit(Nil, List(region.capture), Nil, List(cap), transform(body))))

case source.Hole(id, stmts, span) =>
Context.bind(core.Hole(span))
TransformerOps.bind(core.Hole(span))

case a @ source.Assign(id, expr, _) =>
val sym = a.definition
// emits `ref := value; return ()`
Context.bind(Put(sym, transform(Context.captureOf(sym)), transformAsExpr(expr), Return(Literal((), core.Type.TUnit))))
TransformerOps.bind(Put(sym, transform(Context.captureOf(sym)), transformAsExpr(expr), Return(Literal((), core.Type.TUnit))))
Literal((), core.Type.TUnit)

// methods are dynamically dispatched, so we have to assume they are `control`, hence no PureApp.
Expand All @@ -485,7 +485,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
// Do not pass type arguments for the type constructor of the receiver.
val remainingTypeArgs = typeArgs.drop(operation.interface.tparams.size)

Context.bind(Invoke(rec, operation, opType, remainingTypeArgs, valueArgs, blockArgs))
TransformerOps.bind(Invoke(rec, operation, opType, remainingTypeArgs, valueArgs, blockArgs))

case c @ source.Call(source.ExprTarget(source.Unbox(expr, _)), targs, vargs, bargs, _) =>

Expand All @@ -499,7 +499,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
val blockArgs = bargs.map(transformAsBlock)
// val captArgs = blockArgs.map(b => b.capt) //transform(Context.inferredCapture(b)))

Context.bind(App(Unbox(e), typeArgs, valueArgs, blockArgs))
TransformerOps.bind(App(Unbox(e), typeArgs, valueArgs, blockArgs))

case c @ source.Call(fun: source.IdTarget, _, vargs, bargs, _) =>
// assumption: typer removed all ambiguous references, so there is exactly one
Expand Down Expand Up @@ -694,7 +694,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
// create joinpoint
val tparams = patterns.flatMap { case (sc, p) => boundTypesInPattern(p) } ++ guards.flatMap(boundTypesInGuard)
val params = patterns.flatMap { case (sc, p) => boundInPattern(p) } ++ guards.flatMap(boundInGuard)
val joinpoint = Context.bind(TmpBlock(label), BlockLit(tparams, Nil, params, Nil, body))
val joinpoint = TransformerOps.bind(TmpBlock(label), BlockLit(tparams, Nil, params, Nil, body))

def transformPattern(p: source.MatchPattern): Pattern = p match {
case source.AnyPattern(id, _) =>
Expand All @@ -721,12 +721,12 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
}

def transformGuard(p: source.MatchGuard): List[Condition] =
val (cond, bindings) = Context.withBindings {
val (cond, bindings) = TransformerOps.withBindings {
p match {
case MatchGuard.BooleanGuard(condition, _) =>
Condition.Predicate(transformAsExpr(condition))
case MatchGuard.PatternGuard(scrutinee, pattern, _) =>
val x = Context.bind(transformAsExpr(scrutinee))
val x = TransformerOps.bind(transformAsExpr(scrutinee))
Condition.Patterns(Map(x -> transformPattern(pattern)))
}
}
Expand Down Expand Up @@ -807,7 +807,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
case f: Callable if callingConvention(f) == CallingConvention.Pure =>
PureApp(BlockVar(f), targs, vargsT)
case f: Callable if callingConvention(f) == CallingConvention.Direct =>
Context.bind(BlockVar(f), targs, vargsT, bargsT)
TransformerOps.bind(BlockVar(f), targs, vargsT, bargsT)
case r: Constructor =>
if (bargs.nonEmpty) Context.abort("Constructors cannot take block arguments.")
val universals = targs.take(r.tpe.tparams.length)
Expand All @@ -818,9 +818,9 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
case f: Field =>
Context.panic("Should have been translated to a select!")
case f: BlockSymbol =>
Context.bind(App(BlockVar(f), targs, vargsT, bargsT))
TransformerOps.bind(App(BlockVar(f), targs, vargsT, bargsT))
case f: ValueSymbol =>
Context.bind(App(Unbox(ValueVar(f)), targs, vargsT, bargsT))
TransformerOps.bind(App(Unbox(ValueVar(f)), targs, vargsT, bargsT))
}
}

Expand All @@ -829,7 +829,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
def transform(p: source.ValueParam)(using Context): core.ValueParam = ValueParam(p.symbol)

def insertBindings(stmt: => Stmt)(using Context): Stmt = {
val (body, bindings) = Context.withBindings { stmt }
val (body, bindings) = TransformerOps.withBindings { stmt }
Binding(bindings, body)
}

Expand Down Expand Up @@ -901,15 +901,15 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {

}

trait TransformerOps extends ContextOps { Context: Context =>
/**
* Helper for dealing with bindings in the [[BindingDB]]
* As [[Context]] stores the [[BindingDB]], it is implicitly passed to all the functions
*/
object TransformerOps {

/**
* A _mutable_ ListBuffer that stores all bindings to be inserted at the current scope
*/
private var bindings: ListBuffer[Binding] = ListBuffer()

private[core] def initTransformerState() = {
bindings = ListBuffer()
private[core] def initTransformerState()(using Context) = {
Context.bindingDB.bindings = ListBuffer()
}

/**
Expand All @@ -918,50 +918,50 @@ trait TransformerOps extends ContextOps { Context: Context =>
* @param tpe the type of the bound statement
* @param s the statement to be bound
*/
private[core] def bind(s: Stmt): ValueVar = {
private[core] def bind(s: Stmt)(using Context): ValueVar = {

// create a fresh symbol and assign the type
val x = TmpValue("r")

val binding = Binding.Val(x, s.tpe, s)
bindings += binding
Context.bindingDB.bindings += binding

ValueVar(x, s.tpe)
}

private[core] def bind(e: Expr): ValueVar = e match {
private[core] def bind(e: Expr)(using Context): ValueVar = e match {
case x: ValueVar => x
case e =>
// create a fresh symbol and assign the type
val x = TmpValue("r")

val binding = Binding.Let(x, e.tpe, e)
bindings += binding
Context.bindingDB.bindings += binding

ValueVar(x, e.tpe)
}

private[core] def bind(callee: Block.BlockVar, targs: List[core.ValueType], vargs: List[Expr], bargs: List[Block]): ValueVar = {
private[core] def bind(callee: Block.BlockVar, targs: List[core.ValueType], vargs: List[Expr], bargs: List[Block])(using Context): ValueVar = {
// create a fresh symbol and assign the type
val x = TmpValue("r")
val binding: Binding.ImpureApp = Binding.ImpureApp(x, callee, targs, vargs, bargs)
bindings += binding
Context.bindingDB.bindings += binding

ValueVar(x, Type.bindingType(binding))
}

private[core] def bind(name: BlockSymbol, b: Block): BlockVar = {
private[core] def bind(name: BlockSymbol, b: Block)(using Context): BlockVar = {
val binding = Binding.Def(name, b)
bindings += binding
Context.bindingDB.bindings += binding
BlockVar(name, b.tpe, b.capt)
}

private[core] def withBindings[R](block: => R): (R, List[Binding]) = Context in {
val before = bindings
private[core] def withBindings[R](block: => R)(using Context): (R, List[Binding]) = Context in {
val before = Context.bindingDB.bindings
val b = ListBuffer.empty[Binding]
bindings = b
Context.bindingDB.bindings = b
val result = block
bindings = before
Context.bindingDB.bindings = before
(result, b.toList)
}
}