diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala index 60183126bca2..c4e5d1eb1162 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala @@ -62,14 +62,19 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean)(cls: Symbol) exte case _ => false override def mapWith(tm: TypeMap)(using Context) = - val elems = refs.elems.toList - val elems1 = elems.mapConserve(tm.mapCapability(_)) - if elems1 eq elems then this - else if elems1.forall: - case elem1: Capability => elem1.isWellformed - case _ => false - then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[Capability]]*), boxed) - else EmptyAnnotation + if ctx.phase.id > Phases.checkCapturesPhase.id then + // Annotation is no longer relevant, can be dropped. + // This avoids running into illegal states in mapCapability. + EmptyAnnotation + else + val elems = refs.elems.toList + val elems1 = elems.mapConserve(tm.mapCapability(_)) + if elems1 eq elems then this + else if elems1.forall: + case elem1: Capability => elem1.isWellformed + case _ => false + then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[Capability]]*), boxed) + else EmptyAnnotation override def refersToParamOf(tl: TermLambda)(using Context): Boolean = refs.elems.exists { diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 3e00907332f2..34fe3c63a747 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -246,6 +246,28 @@ class CheckCaptures extends Recheck, SymTransformer: override def isRunnable(using Context) = super.isRunnable && Feature.ccEnabledSomewhere + /** We normally need a recompute if the prefix is a SingletonType and the + * last denotation is not a SymDenotation. The SingletonType requirement is + * so that we don't widen TermRefs with non-path prefixes to their underlying + * type when recomputing their denotations with asSeenFrom. Such widened types + * would become illegal members of capture sets. + * + * The SymDenotation requirement is so that we don't recompute termRefs of Symbols + * which should be handled by SymTransformers alone. However, if the underlying type + * of the prefix is a capturing type, we do need to recompute since in that case + * the prefix might carry a parameter refinement created in Setup, and we need to + * take these refinements into account. + */ + override def needsRecompute(tp: NamedType, lastDenotation: SingleDenotation)(using Context): Boolean = + tp.prefix match + case prefix: TermRef => + !lastDenotation.isInstanceOf[SymDenotation] + || !prefix.info.captureSet.isAlwaysEmpty + case prefix: SingletonType => + !lastDenotation.isInstanceOf[SymDenotation] + case _ => + false + def newRechecker()(using Context) = CaptureChecker(ctx) override def run(using Context): Unit = @@ -682,12 +704,6 @@ class CheckCaptures extends Recheck, SymTransformer: markFree(ref.readOnly, tree) else val sel = ref.select(pt.select.symbol).asInstanceOf[TermRef] - sel.recomputeDenot() - // We need to do a recomputeDenot here since we have not yet properly - // computed the type of the full path. This means that we erroneously - // think the denotation is the same as in the previous phase so no - // member computation is performed. A test case where this matters is - // read-only-use.scala, where the error on r3 goes unreported. markPathFree(sel, pt.pt, pt.select) case _ => markFree(ref.adjustReadOnly(pt), tree) @@ -1087,11 +1103,11 @@ class CheckCaptures extends Recheck, SymTransformer: if sym.is(Module) then sym.info // Modules are checked by checking the module class else if sym.is(Mutable) && !sym.hasAnnotation(defn.UncheckedCapturesAnnot) then - val addendum = capturedBy.get(sym) match + val addendum = setup.capturedBy.get(sym) match case Some(encl) => val enclStr = if encl.isAnonymousFunction then - val location = anonFunCallee.get(encl) match + val location = setup.anonFunCallee.get(encl) match case Some(meth) if meth.exists => i" argument in a call to $meth" case _ => "" s"an anonymous function$location" @@ -1907,49 +1923,12 @@ class CheckCaptures extends Recheck, SymTransformer: traverseChildren(t) end checkOverrides - /** Used for error reporting: - * Maps mutable variables to the symbols that capture them (in the - * CheckCaptures sense, i.e. symbol is referred to from a different method - * than the one it is defined in). - */ - private val capturedBy = util.HashMap[Symbol, Symbol]() - - /** Used for error reporting: - * Maps anonymous functions appearing as function arguments to - * the function that is called. - */ - private val anonFunCallee = util.HashMap[Symbol, Symbol]() - - /** Used for error reporting: - * Populates `capturedBy` and `anonFunCallee`. Called by `checkUnit`. - */ - private def collectCapturedMutVars(using Context) = new TreeTraverser: - def traverse(tree: Tree)(using Context) = tree match - case id: Ident => - val sym = id.symbol - if sym.isMutableVar && sym.owner.isTerm then - val enclMeth = ctx.owner.enclosingMethod - if sym.enclosingMethod != enclMeth then - capturedBy(sym) = enclMeth - case Apply(fn, args) => - for case closureDef(mdef) <- args do - anonFunCallee(mdef.symbol) = fn.symbol - traverseChildren(tree) - case Inlined(_, bindings, expansion) => - traverse(bindings) - traverse(expansion) - case mdef: DefDef => - if !mdef.symbol.isInlineMethod then traverseChildren(tree) - case _ => - traverseChildren(tree) - private val setup: SetupAPI = thisPhase.prev.asInstanceOf[Setup] override def checkUnit(unit: CompilationUnit)(using Context): Unit = capt.println(i"cc check ${unit.source}") ccState.start() setup.setupUnit(unit.tpdTree, this) - collectCapturedMutVars.traverse(unit.tpdTree) if ctx.settings.YccPrintSetup.value then val echoHeader = "[[syntax tree at end of cc setup]]" diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index fef4adc41b7d..38b60744c6f7 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -40,6 +40,20 @@ trait SetupAPI: /** Check to do after the capture checking traversal */ def postCheck()(using Context): Unit + /** Used for error reporting: + * Maps mutable variables to the symbols that capture them (in the + * CheckCaptures sense, i.e. symbol is referred to from a different method + * than the one it is defined in). + */ + def capturedBy: collection.Map[Symbol, Symbol] + + /** Used for error reporting: + * Maps anonymous functions appearing as function arguments to + * the function that is called. + */ + def anonFunCallee: collection.Map[Symbol, Symbol] +end SetupAPI + object Setup: val name: String = "setupCC" @@ -518,6 +532,18 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: def traverse(tree: Tree)(using Context): Unit = tree match + case tree: Ident => + val sym = tree.symbol + if sym.isMutableVar && sym.owner.isTerm then + val enclMeth = ctx.owner.enclosingMethod + if sym.enclosingMethod != enclMeth then + capturedBy(sym) = enclMeth + + case Apply(fn, args) => + for case closureDef(mdef) <- args do + anonFunCallee(mdef.symbol) = fn.symbol + traverseChildren(tree) + case tree @ DefDef(_, paramss, tpt: TypeTree, _) => val meth = tree.symbol if isExcluded(meth) then @@ -567,9 +593,12 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: traverse(body) catches.foreach(traverse) traverse(finalizer) + case tree: New => + case _ => traverseChildren(tree) + postProcess(tree) checkProperUseOrConsume(tree) end traverse @@ -889,11 +918,16 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: else t case _ => mapFollowingAliases(t) + val capturedBy: mutable.HashMap[Symbol, Symbol] = mutable.HashMap[Symbol, Symbol]() + + val anonFunCallee: mutable.HashMap[Symbol, Symbol] = mutable.HashMap[Symbol, Symbol]() + /** Run setup on a compilation unit with given `tree`. * @param recheckDef the function to run for completing a val or def */ def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit = - setupTraverser(checker).traverse(tree)(using ctx.withPhase(thisPhase)) + inContext(ctx.withPhase(thisPhase)): + setupTraverser(checker).traverse(tree) // ------ Checks to run at Setup ---------------------------------------- diff --git a/compiler/src/dotty/tools/dotc/core/ContextOps.scala b/compiler/src/dotty/tools/dotc/core/ContextOps.scala index d4890cc02a1f..8e1cafd56628 100644 --- a/compiler/src/dotty/tools/dotc/core/ContextOps.scala +++ b/compiler/src/dotty/tools/dotc/core/ContextOps.scala @@ -135,4 +135,8 @@ object ContextOps: if (pkg.is(Package)) ctx.fresh.setOwner(pkg.moduleClass).setTree(tree).setNewScope else ctx } + + def isRechecking: Boolean = + (ctx.base.recheckPhaseIds & (1L << ctx.phaseId)) != 0 + end ContextOps diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index b24dbf2e8d7d..df102c514fdf 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -41,6 +41,18 @@ object Phases { // drop NoPhase at beginning def allPhases: Array[Phase] = (if (fusedPhases.nonEmpty) fusedPhases else phases).tail + private var myRecheckPhaseIds: Long = 0 + + /** A bitset of the ids of the phases extending `transform.Recheck`. + * Recheck phases must have id 63 or less. + */ + def recheckPhaseIds: Long = myRecheckPhaseIds + + def recordRecheckPhase(phase: Recheck): Unit = + val id = phase.id + assert(id < 64, s"Recheck phase with id $id outside permissible range 0..63") + myRecheckPhaseIds |= (1L << id) + object SomePhase extends Phase { def phaseName: String = "" def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 71699c992ab6..ca9eda247cc0 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -38,9 +38,11 @@ import config.Printers.{core, typr, matchTypes} import reporting.{trace, Message} import java.lang.ref.WeakReference import compiletime.uninitialized +import ContextOps.isRechecking import cc.* import CaptureSet.IdentityCaptRefMap import Capabilities.* +import transform.Recheck.currentRechecker import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -2509,15 +2511,31 @@ object Types extends TypeUtils { lastDenotation match { case lastd0: SingleDenotation => val lastd = lastd0.skipRemoved - if lastd.validFor.runId == ctx.runId && checkedPeriod.code != NowhereCode then + var needsRecompute = false + if lastd.validFor.runId == ctx.runId + && checkedPeriod.code != NowhereCode + && !(ctx.isRechecking + && { + needsRecompute = currentRechecker.needsRecompute(this, lastd) + needsRecompute + } + ) + then finish(lastd.current) - else lastd match { - case lastd: SymDenotation => - if stillValid(lastd) && checkedPeriod.code != NowhereCode then finish(lastd.current) - else finish(memberDenot(lastd.initial.name, allowPrivate = lastd.is(Private))) - case _ => - fromDesignator - } + else + val newd = lastd match + case lastd: SymDenotation => + if stillValid(lastd) && checkedPeriod.code != NowhereCode && !needsRecompute + then finish(lastd.current) + else finish(memberDenot(lastd.initial.name, allowPrivate = lastd.is(Private))) + case _ => + fromDesignator + if needsRecompute && (newd.info ne lastd.info) then + // Record the previous denotation, so that it can be reset at the end + // of the rechecker phase + currentRechecker.prevSelDenots(this) = lastd + //println(i"NEW PATH $this: ${newd.info} at ${ctx.phase}, prefix = $prefix") + newd case _ => fromDesignator } } diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index 34e3773ba147..6d9f8e4e90ed 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -50,6 +50,12 @@ object Recheck: case None => tree + /** The currently running rechecker + * @pre ctx.isRechecking + */ + def currentRechecker(using Context): Recheck = + ctx.phase.asInstanceOf[Recheck] + extension (sym: Symbol)(using Context) /** Update symbol's info to newInfo after `prevPhase`. @@ -143,6 +149,7 @@ abstract class Recheck extends Phase, SymTransformer: else symd def run(using Context): Unit = + ctx.base.recordRecheckPhase(this) val rechecker = newRechecker() rechecker.checkUnit(ctx.compilationUnit) rechecker.reset() @@ -151,6 +158,19 @@ abstract class Recheck extends Phase, SymTransformer: try super.runOn(units) finally preRecheckPhase.pastRecheck = true + /** A hook to determine whether the denotation of a NamedType should be recomputed + * from its symbol and prefix, instead of just evolving the previous denotation with + * `current`. This should return true if there are complex changes to types that + * are not reflected in `current`. + */ + def needsRecompute(tp: NamedType, lastDenotation: SingleDenotation)(using Context): Boolean = + false + + /** A map from NamedTypes to the denotations they had before this phase. + * Needed so that we can `reset` them after this phase. + */ + val prevSelDenots = util.HashMap[NamedType, Denotation]() + def newRechecker()(using Context): Rechecker /** The typechecker pass */ @@ -192,17 +212,13 @@ abstract class Recheck extends Phase, SymTransformer: def resetNuTypes()(using Context): Unit = nuTypes.clear(resetToInitial = false) - /** A map from NamedTypes to the denotations they had before this phase. - * Needed so that we can `reset` them after this phase. - */ - private val prevSelDenots = util.HashMap[NamedType, Denotation]() - /** Reset all references in `prevSelDenots` to the denotations they had * before this phase. */ def reset()(using Context): Unit = for (ref, mbr) <- prevSelDenots.iterator do ref.withDenot(mbr) + prevSelDenots.clear() /** Constant-folded rechecked type `tp` of tree `tree` */ protected def constFold(tree: Tree, tp: Type)(using Context): Type = diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index 9f3629866a52..d9c6da8d97eb 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -412,16 +412,16 @@ object TreeChecker { assert(false, s"The type of a non-Super tree must not be a SuperType, but $tree has type $tp") case _ => - override def typed(tree: untpd.Tree, pt: Type = WildcardType)(using Context): Tree = { - val tpdTree = super.typed(tree, pt) - Typer.assertPositioned(tree) - checkSuper(tpdTree) - if (ctx.erasedTypes) - // Can't be checked in earlier phases since `checkValue` is only run in - // Erasure (because running it in Typer would force too much) - checkIdentNotJavaClass(tpdTree) - tpdTree - } + override def typed(tree: untpd.Tree, pt: Type = WildcardType)(using Context): Tree = + trace(i"checking $tree against $pt"): + val tpdTree = super.typed(tree, pt) + Typer.assertPositioned(tree) + checkSuper(tpdTree) + if (ctx.erasedTypes) + // Can't be checked in earlier phases since `checkValue` is only run in + // Erasure (because running it in Typer would force too much) + checkIdentNotJavaClass(tpdTree) + tpdTree override def typedUnadapted(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree = { try diff --git a/tests/neg-custom-args/captures/i23582.check b/tests/neg-custom-args/captures/i23582.check new file mode 100644 index 000000000000..7e78d9e8e93b --- /dev/null +++ b/tests/neg-custom-args/captures/i23582.check @@ -0,0 +1,13 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i23582.scala:27:26 --------------------------------------- +27 | parReduce(1 to 1000): (x, y) => // error + | ^ + |Found: (x: Int, y: Int) ->{write, read} Int + |Required: (Int, Int) ->{cap.only[Read]} Int + | + |Note that capability write is not included in capture set {cap.only[Read]}. + | + |where: cap is a fresh root capability created in method test when checking argument to parameter op of method parReduce +28 | write(x) +29 | x + y + read() + | + | longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/i23582.scala b/tests/neg-custom-args/captures/i23582.scala new file mode 100644 index 000000000000..1c3dd84e8458 --- /dev/null +++ b/tests/neg-custom-args/captures/i23582.scala @@ -0,0 +1,30 @@ +import caps.* +object Levels: + + trait Read extends Classifier, SharedCapability + trait ReadWrite extends Classifier, SharedCapability + + class Box[T](acc: T): + val access: T = acc + + def parReduce(xs: Seq[Int])(op: (Int, Int) ->{cap.only[Read]} Int): Int = xs.reduce(op) + + @main def test = + val r: Box[Read^] = ??? + val rw: Box[ReadWrite^] = ??? + val read: () ->{r.access*} Int = ??? + val write: Int ->{rw.access*} Unit = ??? + val checkRead: () ->{cap.only[Read]} Int = read + + //read() // causes error with and without the println below + parReduce(1 to 1000): (x, y) => + //println(r.access) // ok only if this is uncommented + read() + read() // should be ok + + parReduce(1 to 1000): (x, y) => + x + y + read() // should be ok + + parReduce(1 to 1000): (x, y) => // error + write(x) + x + y + read() +