Skip to content

Commit 3e7d84a

Browse files
committed
Prevent leaking FreshCaps created for parameters
1 parent b160322 commit 3e7d84a

File tree

5 files changed

+40
-22
lines changed

5 files changed

+40
-22
lines changed

compiler/src/dotty/tools/dotc/cc/Capability.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ object Capabilities:
152152
val hiddenSet = CaptureSet.HiddenSet(owner, this: @unchecked)
153153
// fails initialization check without the @unchecked
154154

155+
//assert(rootId != 6, origin)
156+
155157
override def equals(that: Any) = that match
156158
case that: FreshCap => this eq that
157159
case _ => false
@@ -808,6 +810,7 @@ object Capabilities:
808810
case LambdaActual(restp: Type)
809811
case OverriddenType(member: Symbol)
810812
case DeepCS(ref: TypeRef)
813+
case Parameter(param: Symbol)
811814
case Unknown
812815

813816
def explanation(using Context): String = this match
@@ -841,6 +844,8 @@ object Capabilities:
841844
i" when instantiating upper bound of member overridden by $member"
842845
case DeepCS(ref: TypeRef) =>
843846
i" when computing deep capture set of $ref"
847+
case Parameter(param) =>
848+
i" of parameter $param of ${param.owner}"
844849
case Unknown =>
845850
""
846851
end Origin
@@ -907,8 +912,8 @@ object Capabilities:
907912
CapToFresh(origin)(tp)
908913

909914
/** Maps fresh to cap */
910-
def freshToCap(tp: Type)(using Context): Type =
911-
CapToFresh(Origin.Unknown).inverse(tp)
915+
def freshToCap(param: Symbol, tp: Type)(using Context): Type =
916+
CapToFresh(Origin.Parameter(param)).inverse(tp)
912917

913918
/** Map top-level free existential variables one-to-one to Fresh instances */
914919
def resultToFresh(tp: Type, origin: Origin)(using Context): Type =

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ object CaptureSet:
698698
private def narrowClassifier(cls: ClassSymbol)(using Context): Unit =
699699
val newClassifier = leastClassifier(classifier, cls)
700700
if newClassifier == defn.NothingClass then
701-
println(i"conflicting classifications for $this, was $classifier, now $cls")
701+
capt.println(i"conflicting classifications for $this, was $classifier, now $cls")
702702
myClassifier = newClassifier
703703

704704
override def adoptClassifier(cls: ClassSymbol)(using Context): Unit =

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ object CheckCaptures:
6868
cur = cur.outer
6969
res
7070

71-
def ownerString(using Context): String =
72-
if owner.isAnonymousFunction then "enclosing function" else owner.show
71+
def ownerString(using Context): String = ownerStr(owner)
7372
end Env
7473

7574
/** Similar normal substParams, but this is an approximating type map that
@@ -184,6 +183,9 @@ object CheckCaptures:
184183
check.traverse(tp)
185184
end disallowBadRootsIn
186185

186+
private def ownerStr(owner: Symbol)(using Context): String =
187+
if owner.isAnonymousFunction then "enclosing function" else owner.show
188+
187189
trait CheckerAPI:
188190
/** Complete symbol info of a val or a def */
189191
def completeDef(tree: ValOrDefDef, sym: Symbol, completer: LazyType)(using Context): Type
@@ -548,10 +550,8 @@ class CheckCaptures extends Recheck, SymTransformer:
548550
c.paramPathRoot match
549551
case ref: NamedType if !ref.symbol.isUseParam =>
550552
val what = if ref.isType then "Capture set parameter" else "Local reach capability"
551-
val owner = ref.symbol.owner
552-
val ownerStr = if owner.isAnonymousFunction then "enclosing function" else owner.show
553553
report.error(
554-
em"""$what $c leaks into capture scope of $ownerStr.
554+
em"""$what $c leaks into capture scope of ${ownerStr(ref.symbol.owner)}.
555555
|To allow this, the ${ref.symbol} should be declared with a @use annotation""", pos)
556556
case _ =>
557557

@@ -925,7 +925,7 @@ class CheckCaptures extends Recheck, SymTransformer:
925925
assert(params.hasSameLengthAs(argTypes), i"$mdef vs $pt, ${params}")
926926
for (argType, param) <- argTypes.lazyZip(params) do
927927
val paramTpt = param.asInstanceOf[ValDef].tpt
928-
val paramType = freshToCap(paramTpt.nuType)
928+
val paramType = freshToCap(param.symbol, paramTpt.nuType)
929929
checkConformsExpr(argType, paramType, param)
930930
.showing(i"compared expected closure formal $argType against $param with ${paramTpt.nuType}", capt)
931931
if ccConfig.preTypeClosureResults && !(isEtaExpansion(mdef) && ccConfig.handleEtaExpansionsSpecially) then
@@ -2030,7 +2030,13 @@ class CheckCaptures extends Recheck, SymTransformer:
20302030
case c: DerivedCapability =>
20312031
checkElem(c.underlying)
20322032
case c: FreshCap =>
2033-
check(c.hiddenSet)
2033+
c.origin match
2034+
case Origin.Parameter(param) =>
2035+
report.error(
2036+
em"Local $c created in type of $param leaks into capture scope of ${ownerStr(param.owner)}",
2037+
tree.srcPos)
2038+
case _ =>
2039+
check(c.hiddenSet)
20342040
case _ =>
20352041

20362042
check(uses)

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -646,17 +646,18 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
646646
def ownerChanges =
647647
ctx.owner.name.is(TryOwnerName)
648648

649-
def paramsToCap(mt: Type)(using Context): Type = mt match
649+
def paramsToCap(psymss: List[List[Symbol]], mt: Type)(using Context): Type = mt match
650650
case mt: MethodType =>
651651
try
652652
mt.derivedLambdaType(
653-
paramInfos = mt.paramInfos.map(freshToCap),
654-
resType = paramsToCap(mt.resType))
653+
paramInfos =
654+
psymss.head.lazyZip(mt.paramInfos).map(freshToCap),
655+
resType = paramsToCap(psymss.tail, mt.resType))
655656
catch case ex: AssertionError =>
656657
println(i"error while mapping params ${mt.paramInfos} of $sym")
657658
throw ex
658659
case mt: PolyType =>
659-
mt.derivedLambdaType(resType = paramsToCap(mt.resType))
660+
mt.derivedLambdaType(resType = paramsToCap(psymss.tail, mt.resType))
660661
case _ => mt
661662

662663
// If there's a change in the signature or owner, update the info of `sym`
@@ -668,7 +669,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
668669
toResultInResults(sym, report.error(_, tree.srcPos)):
669670
if sym.is(Method) then
670671
inContext(ctx.withOwner(sym)):
671-
paramsToCap(methodType(paramSymss, localReturnType))
672+
paramsToCap(paramSymss, methodType(paramSymss, localReturnType))
672673
else tree.tpt.nuType
673674
if tree.tpt.isInstanceOf[InferredTypeTree]
674675
&& !sym.is(Param) && !sym.is(ParamAccessor)

tests/neg-custom-args/captures/leak-problem-unboxed.scala

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import language.experimental.captureChecking
2-
import caps.use
32

43
// Some capabilities that should be used locally
54
trait Async:
@@ -19,16 +18,23 @@ def useBoxedAsync1[C^](x: Box[Async^{C}]): Unit = x.get.read() // ok
1918
def test(): Unit =
2019

2120
val f: Box[Async^] => Unit = (x: Box[Async^]) => useBoxedAsync(x) // error
22-
val t0: Box[Async^] => Unit = x => useBoxedAsync(x) // TODO hould be error!
21+
val f0: Box[Async^] => Unit = x => useBoxedAsync(x) // // error
2322

24-
val t1: Box[Async^] => Unit = useBoxedAsync(_) // TODO should be error!
25-
val t2: Box[Async^] => Unit = useBoxedAsync // TODO should be error!
26-
val t3 = useBoxedAsync(_) // was error, now ok
27-
val t4 = useBoxedAsync // was error, now ok
23+
val f1: Box[Async^] => Unit = useBoxedAsync(_) // error
24+
val f2: Box[Async^] => Unit = useBoxedAsync // error
25+
val f3 = useBoxedAsync(_) // was error, now ok, but bang below fails
26+
val f4 = useBoxedAsync // was error, now ok, but bang2 below fails
2827

2928
def boom(x: Async^): () ->{f} Unit =
3029
() => f(Box(x))
3130

3231
val leaked = usingAsync[() ->{f} Unit](boom)
3332

34-
leaked() // scope violation
33+
leaked() // was scope violation
34+
35+
def bang(x: Async^) =
36+
() => f3(Box(x)) // error
37+
38+
def bang2(x: Async^) =
39+
() => f3(Box(x)) // error
40+

0 commit comments

Comments
 (0)