Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 50 additions & 49 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ class QuoteMatcher(debug: Boolean) {
private type MatchingExprs = Seq[MatchResult]

/** TODO-18271: update
* A map relating equivalent symbols from the scrutinee and the pattern
* For example in
* ```
* '{val a = 4; a * a} match case '{ val x = 4; x * x }
* ```
* when matching `a * a` with `x * x` the environment will contain `Map(a -> x)`.
*/
* A map relating equivalent symbols from the scrutinee and the pattern
* For example in
* ```
* '{val a = 4; a * a} match case '{ val x = 4; x * x }
* ```
* when matching `a * a` with `x * x` the environment will contain `Map(a -> x)`.
*/
private case class Env(val termEnv: Map[Symbol, Symbol], val typeEnv: Map[Symbol, Symbol])

private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)
Expand Down Expand Up @@ -198,14 +198,14 @@ class QuoteMatcher(debug: Boolean) {
extension (scrutinee0: Tree)

/** Check that the trees match and return the contents from the pattern holes.
* Return a sequence containing all the contents in the holes.
* If it does not match, continues to the `optional` with `None`.
*
* @param scrutinee The tree being matched
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
* @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @return The sequence with the contents of the holes of the matched expression.
*/
* Return a sequence containing all the contents in the holes.
* If it does not match, continues to the `optional` with `None`.
*
* @param scrutinee The tree being matched
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
* @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @return The sequence with the contents of the holes of the matched expression.
*/
private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] =

/* Match block flattening */ // TODO move to cases
Expand Down Expand Up @@ -238,19 +238,19 @@ class QuoteMatcher(debug: Boolean) {
case _ => None
end TypeTreeTypeTest

/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
* e.g.
* g: (Int) => Int
* => {
* def $anonfun(y: Int): Int = g(y)
* closure($anonfun)
* }
*
* f: (using Int) => Int
* => f(using x)
* This function restores the symbol of the original method from
* the eta-expanded function.
*/
/** Some of method symbols in arguments of higher-order term hole are eta-expanded.
* e.g.
* g: (Int) => Int
* => {
* def $anonfun(y: Int): Int = g(y)
* closure($anonfun)
* }
*
* f: (using Int) => Int
* => f(using x)
* This function restores the symbol of the original method from
* the eta-expanded function.
*/
def getCapturedIdent(arg: Tree)(using Context): Ident =
arg match
case id: Ident => id
Expand Down Expand Up @@ -454,19 +454,21 @@ class QuoteMatcher(debug: Boolean) {
notMatched
case _ => matched

/**
* Implementation restriction: The current implementation matches type parameters
* only when they have empty bounds (>: Nothing <: Any)
*/
def matchTypeDef(sctypedef: TypeDef, pttypedef: TypeDef): MatchingExprs = sctypedef match
case TypeDef(_, TypeBoundsTree(sclo, schi, EmptyTree))
if sclo.tpe.isNothingType && schi.tpe.isAny =>
pttypedef match
case TypeDef(_, TypeBoundsTree(ptlo, pthi, EmptyTree))
if sclo.tpe.isNothingType && schi.tpe.isAny =>
matched
/** Implementation restriction: The current implementation matches type parameters
* only when they have empty bounds (>: Nothing <: Any)
*/
def matchTypeDef(sctypedef: TypeDef, pttypedef: TypeDef): MatchingExprs =
inline def recur(i: Int): MatchingExprs =
if i == 2 then
matched
else
val td = if i == 0 then sctypedef else pttypedef
td.rhs match
case tbt: TypeBoundsTree
if tbt.lo.tpe.isNothingType && tbt.hi.tpe.isAny && tbt.alias.isEmpty
=> recur(i + 1)
case _ => notMatched
case _ => notMatched
recur(i = 0)

def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] =
(scparamss, ptparamss) match {
Expand Down Expand Up @@ -550,10 +552,10 @@ class QuoteMatcher(debug: Boolean) {
end extension

/** Does the scrutinee symbol match the pattern symbol? It matches if:
* - They are the same symbol
* - The scrutinee has is in the environment and they are equivalent
* - The scrutinee overrides the symbol of the pattern
*/
* - They are the same symbol
* - The scrutinee is in the environment and they are equivalent
* - The scrutinee overrides the symbol of the pattern
*/
private def symbolMatch(scrutineeTree: Tree, patternTree: Tree)(using Env, Context): Boolean =
val scrutinee = scrutineeTree.symbol

Expand Down Expand Up @@ -673,11 +675,10 @@ class QuoteMatcher(debug: Boolean) {
treeMap = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
*/
/* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
*/
case Apply(fun, args) if termEnv.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
case tree: Ident => termEnv.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
Expand Down
Loading