Skip to content

Commit 8840779

Browse files
committed
Try to only instantiate one type variable bounded function types
1 parent 19cec83 commit 8840779

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,21 +1921,39 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19211921
NoType
19221922
}
19231923

1924-
def tryToInstantiateInUnion(tp: Type): Unit = tp match
1925-
case tp: OrType =>
1926-
tryToInstantiateInUnion(tp.tp1)
1927-
tryToInstantiateInUnion(tp.tp2)
1924+
/** Try to instantiate one type variable bounded by function types that appear
1925+
* deeply inside `tp`, including union or intersection types.
1926+
*/
1927+
def tryToInstantiateDeeply(tp: Type): Boolean = tp match
1928+
case tp: AndOrType =>
1929+
tryToInstantiateDeeply(tp.tp1)
1930+
|| tryToInstantiateDeeply(tp.tp2)
19281931
case tp: FlexibleType =>
1929-
tryToInstantiateInUnion(tp.hi)
1930-
case tp: TypeVar =>
1932+
tryToInstantiateDeeply(tp.hi)
1933+
case tp: TypeVar if isConstrainedByFunctionType(tp) =>
1934+
// Only instantiate if the type variable is constrained by function types
19311935
isFullyDefined(tp, ForceDegree.flipBottom)
1932-
case _ =>
1936+
case _ => false
1937+
1938+
def isConstrainedByFunctionType(tvar: TypeVar): Boolean =
1939+
val origin = tvar.origin
1940+
val bounds = ctx.typerState.constraint.bounds(origin)
1941+
def containsFunctionType(tp: Type): Boolean = tp.dealias match
1942+
case tp if defn.isFunctionType(tp) => true
1943+
case SAMType(_, _) => true
1944+
case tp: AndOrType =>
1945+
containsFunctionType(tp.tp1) || containsFunctionType(tp.tp2)
1946+
case tp: FlexibleType =>
1947+
containsFunctionType(tp.hi)
1948+
case _ => false
1949+
containsFunctionType(bounds.lo) || containsFunctionType(bounds.hi)
19331950

19341951
if untpd.isFunctionWithUnknownParamType(tree) && !calleeType.exists then
1935-
// Try to instantiate `pt` when possible, including type variables in union types
1936-
// to help finding function types. If it does not work the error will be reported
1937-
// later in `inferredParam`, when we try to infer the parameter type.
1938-
tryToInstantiateInUnion(pt)
1952+
// Try to instantiate `pt` when possible, by searching a nested type variable
1953+
// bounded by function types to help infer parameter types.
1954+
// If it does not work the error will be reported later in `inferredParam`,
1955+
// when we try to infer the parameter type.
1956+
tryToInstantiateDeeply(pt)
19391957

19401958
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)
19411959

tests/neg-custom-args/captures/i15923.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
|
77
|Note that capability lcap cannot be included in outer capture set 's1 of parameter cap.
88
|
9-
|where: => refers to a fresh root capability created in anonymous function of type (using lcap: scala.caps.Capability): test2.Cap^{lcap} -> [T] => (op: test2.Cap^{lcap} => T) -> T when instantiating expected result type test2.Cap^{lcap} ->{cap²} [T] => (op: test2.Cap^'s6 ->'s7 T) ->'s8 T of function literal
9+
|where: => refers to a fresh root capability created in anonymous function of type (using lcap: scala.caps.Capability): test2.Cap^{lcap} -> [T] => (op: test2.Cap => T) -> T when instantiating expected result type test2.Cap^{lcap} ->{cap²} [T] => (op: test2.Cap^'s6 ->'s7 T) ->'s8 T of function literal
1010
|
1111
| longer explanation available when compiling with `-explain`
1212
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15923.scala:12:21 ---------------------------------------

tests/pos/infer-function-type-in-union.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,14 @@ def test =
2929
}
3030
val o3: MyOption[(String, String) => Boolean] = MyOption {
3131
(x, y) => x.length > y.length
32-
}
32+
}
33+
34+
35+
class Box[T]
36+
val box: Box[Unit] = ???
37+
def ff1[T, U](x: T | U, y: Box[U]): T = ???
38+
def ff2[T, U](x: T & U): T = ???
39+
40+
def test2 =
41+
val a1: Any => Any = ff1(x => x, box)
42+
val a2: Any => Any = ff2(x => x)

0 commit comments

Comments
 (0)