Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,19 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean)(cls: Symbol) exte
case _ => false

override def mapWith(tm: TypeMap)(using Context) =
val elems = refs.elems.toList
val elems1 = elems.mapConserve(tm.mapCapability(_))
if elems1 eq elems then this
else if elems1.forall:
case elem1: Capability => elem1.isWellformed
case _ => false
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[Capability]]*), boxed)
else EmptyAnnotation
if ctx.phase.id > Phases.checkCapturesPhase.id then
// Annotation is no longer relevant, can be dropped.
// This avoids running into illegal states in mapCapability.
EmptyAnnotation
else
val elems = refs.elems.toList
val elems1 = elems.mapConserve(tm.mapCapability(_))
if elems1 eq elems then this
else if elems1.forall:
case elem1: Capability => elem1.isWellformed
case _ => false
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[Capability]]*), boxed)
else EmptyAnnotation

override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
refs.elems.exists {
Expand Down
69 changes: 24 additions & 45 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,28 @@ class CheckCaptures extends Recheck, SymTransformer:

override def isRunnable(using Context) = super.isRunnable && Feature.ccEnabledSomewhere

/** We normally need a recompute if the prefix is a SingletonType and the
* last denotation is not a SymDenotation. The SingletonType requirement is
* so that we don't widen TermRefs with non-path prefixes to their underlying
* type when recomputing their denotations with asSeenFrom. Such widened types
* would become illegal members of capture sets.
*
* The SymDenotation requirement is so that we don't recompute termRefs of Symbols
* which should be handled by SymTransformers alone. However, if the underlying type
* of the prefix is a capturing type, we do need to recompute since in that case
* the prefix might carry a parameter refinement created in Setup, and we need to
* take these refinements into account.
*/
override def needsRecompute(tp: NamedType, lastDenotation: SingleDenotation)(using Context): Boolean =
tp.prefix match
case prefix: TermRef =>
!lastDenotation.isInstanceOf[SymDenotation]
|| !prefix.info.captureSet.isAlwaysEmpty
case prefix: SingletonType =>
!lastDenotation.isInstanceOf[SymDenotation]
case _ =>
false

def newRechecker()(using Context) = CaptureChecker(ctx)

override def run(using Context): Unit =
Expand Down Expand Up @@ -682,12 +704,6 @@ class CheckCaptures extends Recheck, SymTransformer:
markFree(ref.readOnly, tree)
else
val sel = ref.select(pt.select.symbol).asInstanceOf[TermRef]
sel.recomputeDenot()
// We need to do a recomputeDenot here since we have not yet properly
// computed the type of the full path. This means that we erroneously
// think the denotation is the same as in the previous phase so no
// member computation is performed. A test case where this matters is
// read-only-use.scala, where the error on r3 goes unreported.
markPathFree(sel, pt.pt, pt.select)
case _ =>
markFree(ref.adjustReadOnly(pt), tree)
Expand Down Expand Up @@ -1087,11 +1103,11 @@ class CheckCaptures extends Recheck, SymTransformer:
if sym.is(Module) then sym.info // Modules are checked by checking the module class
else
if sym.is(Mutable) && !sym.hasAnnotation(defn.UncheckedCapturesAnnot) then
val addendum = capturedBy.get(sym) match
val addendum = setup.capturedBy.get(sym) match
case Some(encl) =>
val enclStr =
if encl.isAnonymousFunction then
val location = anonFunCallee.get(encl) match
val location = setup.anonFunCallee.get(encl) match
case Some(meth) if meth.exists => i" argument in a call to $meth"
case _ => ""
s"an anonymous function$location"
Expand Down Expand Up @@ -1907,49 +1923,12 @@ class CheckCaptures extends Recheck, SymTransformer:
traverseChildren(t)
end checkOverrides

/** Used for error reporting:
* Maps mutable variables to the symbols that capture them (in the
* CheckCaptures sense, i.e. symbol is referred to from a different method
* than the one it is defined in).
*/
private val capturedBy = util.HashMap[Symbol, Symbol]()

/** Used for error reporting:
* Maps anonymous functions appearing as function arguments to
* the function that is called.
*/
private val anonFunCallee = util.HashMap[Symbol, Symbol]()

/** Used for error reporting:
* Populates `capturedBy` and `anonFunCallee`. Called by `checkUnit`.
*/
private def collectCapturedMutVars(using Context) = new TreeTraverser:
def traverse(tree: Tree)(using Context) = tree match
case id: Ident =>
val sym = id.symbol
if sym.isMutableVar && sym.owner.isTerm then
val enclMeth = ctx.owner.enclosingMethod
if sym.enclosingMethod != enclMeth then
capturedBy(sym) = enclMeth
case Apply(fn, args) =>
for case closureDef(mdef) <- args do
anonFunCallee(mdef.symbol) = fn.symbol
traverseChildren(tree)
case Inlined(_, bindings, expansion) =>
traverse(bindings)
traverse(expansion)
case mdef: DefDef =>
if !mdef.symbol.isInlineMethod then traverseChildren(tree)
case _ =>
traverseChildren(tree)

private val setup: SetupAPI = thisPhase.prev.asInstanceOf[Setup]

override def checkUnit(unit: CompilationUnit)(using Context): Unit =
capt.println(i"cc check ${unit.source}")
ccState.start()
setup.setupUnit(unit.tpdTree, this)
collectCapturedMutVars.traverse(unit.tpdTree)

if ctx.settings.YccPrintSetup.value then
val echoHeader = "[[syntax tree at end of cc setup]]"
Expand Down
36 changes: 35 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,20 @@ trait SetupAPI:
/** Check to do after the capture checking traversal */
def postCheck()(using Context): Unit

/** Used for error reporting:
* Maps mutable variables to the symbols that capture them (in the
* CheckCaptures sense, i.e. symbol is referred to from a different method
* than the one it is defined in).
*/
def capturedBy: collection.Map[Symbol, Symbol]

/** Used for error reporting:
* Maps anonymous functions appearing as function arguments to
* the function that is called.
*/
def anonFunCallee: collection.Map[Symbol, Symbol]
end SetupAPI

object Setup:

val name: String = "setupCC"
Expand Down Expand Up @@ -518,6 +532,18 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:

def traverse(tree: Tree)(using Context): Unit =
tree match
case tree: Ident =>
val sym = tree.symbol
if sym.isMutableVar && sym.owner.isTerm then
val enclMeth = ctx.owner.enclosingMethod
if sym.enclosingMethod != enclMeth then
capturedBy(sym) = enclMeth

case Apply(fn, args) =>
for case closureDef(mdef) <- args do
anonFunCallee(mdef.symbol) = fn.symbol
traverseChildren(tree)

case tree @ DefDef(_, paramss, tpt: TypeTree, _) =>
val meth = tree.symbol
if isExcluded(meth) then
Expand Down Expand Up @@ -567,9 +593,12 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
traverse(body)
catches.foreach(traverse)
traverse(finalizer)

case tree: New =>

case _ =>
traverseChildren(tree)

postProcess(tree)
checkProperUseOrConsume(tree)
end traverse
Expand Down Expand Up @@ -889,11 +918,16 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
else t
case _ => mapFollowingAliases(t)

val capturedBy: mutable.HashMap[Symbol, Symbol] = mutable.HashMap[Symbol, Symbol]()

val anonFunCallee: mutable.HashMap[Symbol, Symbol] = mutable.HashMap[Symbol, Symbol]()

/** Run setup on a compilation unit with given `tree`.
* @param recheckDef the function to run for completing a val or def
*/
def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit =
setupTraverser(checker).traverse(tree)(using ctx.withPhase(thisPhase))
inContext(ctx.withPhase(thisPhase)):
setupTraverser(checker).traverse(tree)

// ------ Checks to run at Setup ----------------------------------------

Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/ContextOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,8 @@ object ContextOps:
if (pkg.is(Package)) ctx.fresh.setOwner(pkg.moduleClass).setTree(tree).setNewScope
else ctx
}

def isRechecking: Boolean =
(ctx.base.recheckPhaseIds & (1L << ctx.phaseId)) != 0

end ContextOps
12 changes: 12 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ object Phases {
// drop NoPhase at beginning
def allPhases: Array[Phase] = (if (fusedPhases.nonEmpty) fusedPhases else phases).tail

private var myRecheckPhaseIds: Long = 0

/** A bitset of the ids of the phases extending `transform.Recheck`.
* Recheck phases must have id 63 or less.
*/
def recheckPhaseIds: Long = myRecheckPhaseIds

def recordRecheckPhase(phase: Recheck): Unit =
val id = phase.id
assert(id < 64, s"Recheck phase with id $id outside permissible range 0..63")
myRecheckPhaseIds |= (1L << id)

object SomePhase extends Phase {
def phaseName: String = "<some phase>"
def run(using Context): Unit = unsupported("run")
Expand Down
34 changes: 26 additions & 8 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ import config.Printers.{core, typr, matchTypes}
import reporting.{trace, Message}
import java.lang.ref.WeakReference
import compiletime.uninitialized
import ContextOps.isRechecking
import cc.*
import CaptureSet.IdentityCaptRefMap
import Capabilities.*
import transform.Recheck.currentRechecker

import scala.annotation.internal.sharable
import scala.annotation.threadUnsafe
Expand Down Expand Up @@ -2509,15 +2511,31 @@ object Types extends TypeUtils {
lastDenotation match {
case lastd0: SingleDenotation =>
val lastd = lastd0.skipRemoved
if lastd.validFor.runId == ctx.runId && checkedPeriod.code != NowhereCode then
var needsRecompute = false
if lastd.validFor.runId == ctx.runId
&& checkedPeriod.code != NowhereCode
&& !(ctx.isRechecking
&& {
needsRecompute = currentRechecker.needsRecompute(this, lastd)
needsRecompute
}
)
then
finish(lastd.current)
else lastd match {
case lastd: SymDenotation =>
if stillValid(lastd) && checkedPeriod.code != NowhereCode then finish(lastd.current)
else finish(memberDenot(lastd.initial.name, allowPrivate = lastd.is(Private)))
case _ =>
fromDesignator
}
else
val newd = lastd match
case lastd: SymDenotation =>
if stillValid(lastd) && checkedPeriod.code != NowhereCode && !needsRecompute
then finish(lastd.current)
else finish(memberDenot(lastd.initial.name, allowPrivate = lastd.is(Private)))
case _ =>
fromDesignator
if needsRecompute && (newd.info ne lastd.info) then
// Record the previous denotation, so that it can be reset at the end
// of the rechecker phase
currentRechecker.prevSelDenots(this) = lastd
//println(i"NEW PATH $this: ${newd.info} at ${ctx.phase}, prefix = $prefix")
newd
case _ => fromDesignator
}
}
Expand Down
26 changes: 21 additions & 5 deletions compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ object Recheck:
case None =>
tree

/** The currently running rechecker
* @pre ctx.isRechecking
*/
def currentRechecker(using Context): Recheck =
ctx.phase.asInstanceOf[Recheck]

extension (sym: Symbol)(using Context)

/** Update symbol's info to newInfo after `prevPhase`.
Expand Down Expand Up @@ -143,6 +149,7 @@ abstract class Recheck extends Phase, SymTransformer:
else symd

def run(using Context): Unit =
ctx.base.recordRecheckPhase(this)
val rechecker = newRechecker()
rechecker.checkUnit(ctx.compilationUnit)
rechecker.reset()
Expand All @@ -151,6 +158,19 @@ abstract class Recheck extends Phase, SymTransformer:
try super.runOn(units)
finally preRecheckPhase.pastRecheck = true

/** A hook to determine whether the denotation of a NamedType should be recomputed
* from its symbol and prefix, instead of just evolving the previous denotation with
* `current`. This should return true if there are complex changes to types that
* are not reflected in `current`.
*/
def needsRecompute(tp: NamedType, lastDenotation: SingleDenotation)(using Context): Boolean =
false

/** A map from NamedTypes to the denotations they had before this phase.
* Needed so that we can `reset` them after this phase.
*/
val prevSelDenots = util.HashMap[NamedType, Denotation]()

def newRechecker()(using Context): Rechecker

/** The typechecker pass */
Expand Down Expand Up @@ -192,17 +212,13 @@ abstract class Recheck extends Phase, SymTransformer:
def resetNuTypes()(using Context): Unit =
nuTypes.clear(resetToInitial = false)

/** A map from NamedTypes to the denotations they had before this phase.
* Needed so that we can `reset` them after this phase.
*/
private val prevSelDenots = util.HashMap[NamedType, Denotation]()

/** Reset all references in `prevSelDenots` to the denotations they had
* before this phase.
*/
def reset()(using Context): Unit =
for (ref, mbr) <- prevSelDenots.iterator do
ref.withDenot(mbr)
prevSelDenots.clear()

/** Constant-folded rechecked type `tp` of tree `tree` */
protected def constFold(tree: Tree, tp: Type)(using Context): Type =
Expand Down
20 changes: 10 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,16 @@ object TreeChecker {
assert(false, s"The type of a non-Super tree must not be a SuperType, but $tree has type $tp")
case _ =>

override def typed(tree: untpd.Tree, pt: Type = WildcardType)(using Context): Tree = {
val tpdTree = super.typed(tree, pt)
Typer.assertPositioned(tree)
checkSuper(tpdTree)
if (ctx.erasedTypes)
// Can't be checked in earlier phases since `checkValue` is only run in
// Erasure (because running it in Typer would force too much)
checkIdentNotJavaClass(tpdTree)
tpdTree
}
override def typed(tree: untpd.Tree, pt: Type = WildcardType)(using Context): Tree =
trace(i"checking $tree against $pt"):
val tpdTree = super.typed(tree, pt)
Typer.assertPositioned(tree)
checkSuper(tpdTree)
if (ctx.erasedTypes)
// Can't be checked in earlier phases since `checkValue` is only run in
// Erasure (because running it in Typer would force too much)
checkIdentNotJavaClass(tpdTree)
tpdTree

override def typedUnadapted(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
try
Expand Down
13 changes: 13 additions & 0 deletions tests/neg-custom-args/captures/i23582.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i23582.scala:27:26 ---------------------------------------
27 | parReduce(1 to 1000): (x, y) => // error
| ^
|Found: (x: Int, y: Int) ->{write, read} Int
|Required: (Int, Int) ->{cap.only[Read]} Int
|
|Note that capability write is not included in capture set {cap.only[Read]}.
|
|where: cap is a fresh root capability created in method test when checking argument to parameter op of method parReduce
28 | write(x)
29 | x + y + read()
|
| longer explanation available when compiling with `-explain`
Loading
Loading