Skip to content

Commit b7387d1

Browse files
committed
Refine treatment of fields
- Charge the use set of the initializer to the class constructor - Charge the declared capture set to the class
1 parent 137d3b9 commit b7387d1

File tree

9 files changed

+67
-30
lines changed

9 files changed

+67
-30
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ class CheckCaptures extends Recheck, SymTransformer:
464464
* environments. At each stage, only include references from `cs` that are outside
465465
* the environment's owner
466466
*/
467-
def markFree(cs: CaptureSet, tree: Tree)(using Context): Unit =
467+
def markFree(cs: CaptureSet, tree: Tree, addUseInfo: Boolean = true)(using Context): Unit =
468468
// A captured reference with the symbol `sym` is visible from the environment
469469
// if `sym` is not defined inside the owner of the environment.
470470
inline def isVisibleFromEnv(sym: Symbol, env: Env) =
@@ -546,7 +546,7 @@ class CheckCaptures extends Recheck, SymTransformer:
546546

547547
if !cs.isAlwaysEmpty then
548548
recur(cs, curEnv, null)
549-
useInfos += ((tree, cs, curEnv))
549+
if addUseInfo then useInfos += ((tree, cs, curEnv))
550550
end markFree
551551

552552
/** If capability `c` refers to a parameter that is not implicitly or explicitly
@@ -988,6 +988,8 @@ class CheckCaptures extends Recheck, SymTransformer:
988988
* - Interpolate contravariant capture set variables in result type.
989989
*/
990990
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type =
991+
val savedEnv = curEnv
992+
val runInConstructor = !sym.isOneOf(Param | ParamAccessor | Lazy | NonMember)
991993
try
992994
if sym.is(Module) then sym.info // Modules are checked by checking the module class
993995
else
@@ -1006,6 +1008,8 @@ class CheckCaptures extends Recheck, SymTransformer:
10061008
""
10071009
disallowBadRootsIn(
10081010
tree.tpt.nuType, NoSymbol, i"Mutable $sym", "have type", addendum, sym.srcPos)
1011+
if runInConstructor then
1012+
pushConstructorEnv()
10091013
checkInferredResult(super.recheckValDef(tree, sym), tree)
10101014
finally
10111015
if !sym.is(Param) then
@@ -1015,6 +1019,11 @@ class CheckCaptures extends Recheck, SymTransformer:
10151019
// function is compiled since we do not propagate expected types into blocks.
10161020
interpolateIfInferred(tree.tpt, sym)
10171021

1022+
if runInConstructor && savedEnv.owner.isClass then
1023+
curEnv = savedEnv
1024+
markFree(tree.tpt.nuType.captureSet, tree, addUseInfo = false)
1025+
end recheckValDef
1026+
10181027
/** Recheck method definitions:
10191028
* - check body in a nested environment that tracks uses, in a nested level,
10201029
* and in a nested context that knows abaout Contains parameters so that we
@@ -1241,6 +1250,24 @@ class CheckCaptures extends Recheck, SymTransformer:
12411250
recheckFinish(result, arg, pt)
12421251
*/
12431252

1253+
/** If environment is owned by a class, run in a new environment owned by
1254+
* its primary constructor instead.
1255+
*/
1256+
def pushConstructorEnv()(using Context): Unit =
1257+
if curEnv.owner.isClass then
1258+
val constr = curEnv.owner.primaryConstructor
1259+
if constr.exists then
1260+
val constrSet = capturedVars(constr)
1261+
if capturedVars(constr) ne CaptureSet.empty then
1262+
curEnv = Env(constr, EnvKind.Regular, constrSet, curEnv)
1263+
1264+
override def recheckStat(stat: Tree)(using Context): Unit =
1265+
val saved = curEnv
1266+
if !stat.isInstanceOf[MemberDef] then
1267+
pushConstructorEnv()
1268+
try recheck(stat)
1269+
finally curEnv = saved
1270+
12441271
/** The main recheck method does some box adapation for all nodes:
12451272
* - If expected type `pt` is boxed and the tree is a lambda or a reference,
12461273
* don't propagate free variables.

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,12 +493,15 @@ abstract class Recheck extends Phase, SymTransformer:
493493
recheckStats(tree.stats)
494494
NoType
495495

496+
def recheckStat(stat: Tree)(using Context): Unit =
497+
recheck(stat)
498+
496499
def recheckStats(stats: List[Tree])(using Context): Unit =
497500
@tailrec def traverse(stats: List[Tree])(using Context): Unit = stats match
498501
case (imp: Import) :: rest =>
499502
traverse(rest)(using ctx.importContext(imp, imp.symbol))
500503
case stat :: rest =>
501-
recheck(stat)
504+
recheckStat(stat)
502505
traverse(rest)
503506
case _ =>
504507
traverse(stats)

tests/neg-custom-args/captures/class-constr.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ def test(a: Cap, b: Cap) =
1919
println(b)
2020
2
2121
val d = () => new D()
22-
val d_ok1: () ->{a, b} D^{a, b} = d
23-
val d_ok2: () -> D^{a, b} = d // because of function shorthand
24-
val d_ok3: () ->{a, b} D^{b} = d // error, but should work
22+
val d_ok1: () ->{a, b} D^{a, b} = d // ok
23+
val d_ok2: () ->{a} D^{b} = d // ok
24+
val d_ok3: () -> D^{a, b} = d // error

tests/neg-custom-args/captures/exception-definitions.check

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
3 | self: Err^ => // error
33
| ^^^^
44
| Err is a pure type, it makes no sense to add a capture set to it
5-
-- Error: tests/neg-custom-args/captures/exception-definitions.scala:7:12 ----------------------------------------------
5+
-- Error: tests/neg-custom-args/captures/exception-definitions.scala:7:8 -----------------------------------------------
66
7 | val x = c // error
7-
| ^
8-
| reference (c : Any^) is not included in the allowed capture set {} of the self type of class Err2
7+
| ^^^^^^^^^
8+
| reference (c : Object^) is not included in the allowed capture set {} of the self type of class Err2
99
-- Error: tests/neg-custom-args/captures/exception-definitions.scala:8:13 ----------------------------------------------
10-
8 | class Err3(c: Any^) extends Exception // error
10+
8 | class Err3(c: Object^) extends Exception // error
1111
| ^
12-
| reference (Err3.this.c : Any^) is not included in the allowed capture set {} of the self type of class Err3
12+
| reference (Err3.this.c : Object^) is not included in the allowed capture set {} of the self type of class Err3

tests/neg-custom-args/captures/exception-definitions.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
class Err extends Exception:
33
self: Err^ => // error
44

5-
def test(c: Any^) =
5+
def test(c: Object^) =
66
class Err2 extends Exception:
77
val x = c // error
8-
class Err3(c: Any^) extends Exception // error
8+
class Err3(c: Object^) extends Exception // error
99

10-
class Err4(c: Any^) extends AnyVal // was error, now ok
10+
class Err4(c: Object^) extends AnyVal // was error, now ok
1111

1212

tests/neg-custom-args/captures/method-uses.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test(xs: List[() => Unit]) =
1111

1212
foo // error
1313
bar() // error
14-
Foo() // OK, but could be error
14+
Foo() // error
1515

1616
def test2(xs: List[() => Unit]) =
1717
def foo = xs.head // error, ok under deferredReaches
@@ -23,8 +23,8 @@ def test3(xs: List[() => Unit]): () ->{xs*} Unit = () =>
2323
def test4(xs: List[() => Unit]) = () => xs.head // error, ok under deferredReaches
2424

2525
def test5(xs: List[() => Unit]) = new:
26-
println(xs.head) // error, ok under deferredReaches
26+
println(xs.head) // error, ok under deferredReaches // error
2727

2828
def test6(xs: List[() => Unit]) =
29-
val x= new { println(xs.head) } // error
29+
val x= new { println(xs.head) } // error // error
3030
x

tests/neg-custom-args/captures/uses.check

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
1-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/uses.scala:8:17 ------------------------------------------
2-
8 | val _: D^{y} = d // error, should be ok
3-
| ^
4-
| Found: (d : D^{x, y})
5-
| Required: D^{y}
6-
|
7-
| Note that capability x is not included in capture set {y}.
8-
|
9-
| longer explanation available when compiling with `-explain`
101
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/uses.scala:9:13 ------------------------------------------
112
9 | val _: D = d // error
123
| ^
13-
| Found: (d : D^{x, y})
4+
| Found: (d : D^{y})
145
| Required: D
156
|
16-
| Note that capability x is not included in capture set {}.
7+
| Note that capability y is not included in capture set {}.
178
|
189
| longer explanation available when compiling with `-explain`
10+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/uses.scala:13:22 -----------------------------------------
11+
13 | val _: () -> Unit = f // error
12+
| ^
13+
| Found: (f : () ->{x} Unit)
14+
| Required: () -> Unit
15+
|
16+
| Note that capability x is not included in capture set {}.
17+
|
18+
| longer explanation available when compiling with `-explain`
1919
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/uses.scala:18:34 -----------------------------------------
2020
18 | val _: () ->{x} () ->{y} Unit = g // error, should be ok
2121
| ^

tests/neg-custom-args/captures/uses.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ def test(x: C^, y: C^) =
55
def foo() = println(y)
66
}
77
val d = D()
8-
val _: D^{y} = d // error, should be ok
8+
val _: D^{y} = d
99
val _: D = d // error
1010

1111
val f = () => println(D())
1212
val _: () ->{x} Unit = f // ok
13-
val _: () -> Unit = f // should be error
13+
val _: () -> Unit = f // error
1414

1515
def g = () =>
1616
println(x)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
trait MapOps[K]:
2+
this: MapOps[K]^ =>
3+
def keysIterator: Iterator[K]
4+
5+
trait GenKeySet[K]:
6+
this: collection.Set[K] =>
7+
private[MapOps] val allKeys = MapOps.this.keysIterator.toSet

0 commit comments

Comments
 (0)