Skip to content

Commit 0bf6b2d

Browse files
committed
Rust: Disambiguate calls to associated functions
1 parent 526b784 commit 0bf6b2d

File tree

5 files changed

+79
-23
lines changed

5 files changed

+79
-23
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallExprBaseImpl.qll

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ private import codeql.rust.elements.Resolvable
1313
*/
1414
module Impl {
1515
private import rust
16+
private import codeql.rust.internal.TypeInference as TypeInference
1617

1718
pragma[nomagic]
1819
Resolvable getCallResolvable(CallExprBase call) {
@@ -27,7 +28,11 @@ module Impl {
2728
*/
2829
class CallExprBase extends Generated::CallExprBase {
2930
/** Gets the static target of this call, if any. */
30-
Callable getStaticTarget() { none() } // overridden by subclasses, but cannot be made abstract
31+
final Callable getStaticTarget() {
32+
result = TypeInference::resolveMethodCallTarget(this)
33+
or
34+
result = TypeInference::resolveCallTarget(this)
35+
}
3136

3237
override Expr getArg(int index) { result = this.getArgList().getArg(index) }
3338
}

rust/ql/lib/codeql/rust/elements/internal/CallExprImpl.qll

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ private import codeql.rust.elements.PathExpr
1414
module Impl {
1515
private import rust
1616
private import codeql.rust.internal.PathResolution as PathResolution
17-
private import codeql.rust.internal.TypeInference as TypeInference
1817

1918
pragma[nomagic]
2019
Path getFunctionPath(CallExpr ce) { result = ce.getFunction().(PathExpr).getPath() }
@@ -37,15 +36,6 @@ module Impl {
3736
class CallExpr extends Generated::CallExpr {
3837
override string toStringImpl() { result = this.getFunction().toAbbreviatedString() + "(...)" }
3938

40-
override Callable getStaticTarget() {
41-
// If this call is to a trait method, e.g., `Trait::foo(bar)`, then check
42-
// if type inference can resolve it to the correct trait implementation.
43-
result = TypeInference::resolveMethodCallTarget(this)
44-
or
45-
not exists(TypeInference::resolveMethodCallTarget(this)) and
46-
result = getResolvedFunction(this)
47-
}
48-
4939
/** Gets the struct that this call resolves to, if any. */
5040
Struct getStruct() { result = getResolvedFunction(this) }
5141

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ module Impl {
6262
Function getStaticTarget() {
6363
result = TypeInference::resolveMethodCallTarget(this)
6464
or
65-
not exists(TypeInference::resolveMethodCallTarget(this)) and
6665
result = this.(CallExpr).getStaticTarget()
6766
}
6867

rust/ql/lib/codeql/rust/elements/internal/MethodCallExprImpl.qll

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
private import rust
88
private import codeql.rust.elements.internal.generated.MethodCallExpr
9-
private import codeql.rust.internal.PathResolution
10-
private import codeql.rust.internal.TypeInference
119

1210
/**
1311
* INTERNAL: This module contains the customizable definition of `MethodCallExpr` and should not
@@ -23,8 +21,6 @@ module Impl {
2321
* ```
2422
*/
2523
class MethodCallExpr extends Generated::MethodCallExpr {
26-
override Function getStaticTarget() { result = resolveMethodCallTarget(this) }
27-
2824
private string toStringPart(int index) {
2925
index = 0 and
3026
result = this.getReceiver().toAbbreviatedString()

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ private import codeql.rust.frameworks.stdlib.Stdlib
1111
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
1212
private import codeql.rust.elements.Call
1313
private import codeql.rust.elements.internal.CallImpl::Impl as CallImpl
14+
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
1415

1516
class Type = T::Type;
1617

@@ -724,8 +725,6 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
724725
}
725726
}
726727

727-
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
728-
729728
final class Access extends Call {
730729
pragma[nomagic]
731730
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
@@ -771,7 +770,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
771770
Declaration getTarget() {
772771
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
773772
or
774-
result = CallExprImpl::getResolvedFunction(this)
773+
result = resolveCallTargetSimple(this)
774+
or
775+
result = resolveCallTargetComplex(this) // mutual recursion
775776
}
776777
}
777778

@@ -1347,7 +1348,7 @@ private predicate implSiblingCandidate(
13471348
// siblings).
13481349
not exists(impl.getAttributeMacroExpansion()) and
13491350
// We use this for resolving methods, so exclude traits that do not have methods.
1350-
exists(Function f | f = trait.getASuccessor(_) and f.getParamList().hasSelfParam()) and
1351+
exists(Function f | f = trait.getASuccessor(_)) and
13511352
selfTy = impl.getSelfTy() and
13521353
rootType = selfTy.resolveType()
13531354
}
@@ -1496,6 +1497,58 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
14961497
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
14971498
}
14981499

1500+
pragma[nomagic]
1501+
private predicate assocFuncResolutionDependsOnArgument(Function f, Impl impl, int pos) {
1502+
methodResolutionDependsOnArgument(impl, _, f, pos, _, _)
1503+
}
1504+
1505+
private class AssocFuncCallExpr extends CallExpr {
1506+
private int pos;
1507+
1508+
AssocFuncCallExpr() {
1509+
assocFuncResolutionDependsOnArgument(CallExprImpl::getResolvedFunction(this), _, pos)
1510+
}
1511+
1512+
Function getACandidate(Impl impl) {
1513+
result = CallExprImpl::getResolvedFunction(this) and
1514+
assocFuncResolutionDependsOnArgument(result, impl, pos)
1515+
}
1516+
1517+
int getPosition() { result = pos }
1518+
1519+
/** Gets the type of the receiver of the associated function call at `path`. */
1520+
Type getTypeAt(TypePath path) { result = inferType(this.getArg(pos), path) }
1521+
}
1522+
1523+
private module AssocFuncIsInstantiationOfInput implements
1524+
IsInstantiationOfInputSig<AssocFuncCallExpr>
1525+
{
1526+
pragma[nomagic]
1527+
predicate potentialInstantiationOf(
1528+
AssocFuncCallExpr ce, TypeAbstraction impl, TypeMention constraint
1529+
) {
1530+
exists(Function cand |
1531+
cand = ce.getACandidate(impl) and
1532+
constraint = cand.getParam(ce.getPosition()).getTypeRepr()
1533+
)
1534+
}
1535+
}
1536+
1537+
pragma[nomagic]
1538+
ItemNode resolveCallTargetSimple(CallExpr ce) {
1539+
result = CallExprImpl::getResolvedFunction(ce) and
1540+
not assocFuncResolutionDependsOnArgument(result, _, _)
1541+
}
1542+
1543+
pragma[nomagic]
1544+
Function resolveCallTargetComplex(AssocFuncCallExpr ce) {
1545+
exists(Impl impl |
1546+
IsInstantiationOf<AssocFuncCallExpr, AssocFuncIsInstantiationOfInput>::isInstantiationOf(ce,
1547+
impl, _) and
1548+
result = getMethodSuccessor(impl, ce.getACandidate(_).getName().getText())
1549+
)
1550+
}
1551+
14991552
cached
15001553
private module Cached {
15011554
private import codeql.rust.internal.CachedStages
@@ -1538,6 +1591,14 @@ private module Cached {
15381591
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
15391592
}
15401593

1594+
/** Gets a method that the method call `mc` resolves to, if any. */
1595+
cached
1596+
Function resolveCallTarget(CallExpr ce) {
1597+
result = resolveCallTargetSimple(ce)
1598+
or
1599+
result = resolveCallTargetComplex(ce)
1600+
}
1601+
15411602
pragma[inline]
15421603
private Type inferRootTypeDeref(AstNode n) {
15431604
result = inferType(n) and
@@ -1682,9 +1743,9 @@ private module Debug {
16821743
result = inferType(n, path)
16831744
}
16841745

1685-
Function debugResolveMethodCallTarget(Call mce) {
1686-
mce = getRelevantLocatable() and
1687-
result = resolveMethodCallTarget(mce)
1746+
Function debugResolveCallTarget(Call c) {
1747+
c = getRelevantLocatable() and
1748+
result = [resolveMethodCallTarget(c), resolveCallTarget(c)]
16881749
}
16891750

16901751
predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
@@ -1702,6 +1763,11 @@ private module Debug {
17021763
tm.resolveTypeAt(path) = type
17031764
}
17041765

1766+
Type debugInferAnnotatedType(AstNode n, TypePath path) {
1767+
n = getRelevantLocatable() and
1768+
result = inferAnnotatedType(n, path)
1769+
}
1770+
17051771
pragma[nomagic]
17061772
private int countTypesAtPath(AstNode n, TypePath path, Type t) {
17071773
t = inferType(n, path) and

0 commit comments

Comments
 (0)