Skip to content

Rust: Disambiguate associated function calls #19995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ private import codeql.rust.elements.Resolvable
*/
module Impl {
private import rust
private import codeql.rust.internal.TypeInference as TypeInference

pragma[nomagic]
Resolvable getCallResolvable(CallExprBase call) {
Expand All @@ -27,7 +28,7 @@ module Impl {
*/
class CallExprBase extends Generated::CallExprBase {
/** Gets the static target of this call, if any. */
Callable getStaticTarget() { none() } // overridden by subclasses, but cannot be made abstract
final Function getStaticTarget() { result = TypeInference::resolveCallTarget(this) }

override Expr getArg(int index) { result = this.getArgList().getArg(index) }
}
Expand Down
10 changes: 0 additions & 10 deletions rust/ql/lib/codeql/rust/elements/internal/CallExprImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ private import codeql.rust.elements.PathExpr
module Impl {
private import rust
private import codeql.rust.internal.PathResolution as PathResolution
private import codeql.rust.internal.TypeInference as TypeInference

pragma[nomagic]
Path getFunctionPath(CallExpr ce) { result = ce.getFunction().(PathExpr).getPath() }
Expand All @@ -37,15 +36,6 @@ module Impl {
class CallExpr extends Generated::CallExpr {
override string toStringImpl() { result = this.getFunction().toAbbreviatedString() + "(...)" }

override Callable getStaticTarget() {
// If this call is to a trait method, e.g., `Trait::foo(bar)`, then check
// if type inference can resolve it to the correct trait implementation.
result = TypeInference::resolveMethodCallTarget(this)
or
not exists(TypeInference::resolveMethodCallTarget(this)) and
result = getResolvedFunction(this)
}

/** Gets the struct that this call resolves to, if any. */
Struct getStruct() { result = getResolvedFunction(this) }

Expand Down
7 changes: 1 addition & 6 deletions rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,7 @@ module Impl {
Expr getReceiver() { result = this.getArgument(TSelfArgumentPosition()) }

/** Gets the static target of this call, if any. */
Function getStaticTarget() {
result = TypeInference::resolveMethodCallTarget(this)
or
not exists(TypeInference::resolveMethodCallTarget(this)) and
result = this.(CallExpr).getStaticTarget()
}
Function getStaticTarget() { result = TypeInference::resolveCallTarget(this) }

/** Gets a runtime target of this call, if any. */
pragma[nomagic]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

private import rust
private import codeql.rust.elements.internal.generated.MethodCallExpr
private import codeql.rust.internal.PathResolution
private import codeql.rust.internal.TypeInference

/**
* INTERNAL: This module contains the customizable definition of `MethodCallExpr` and should not
Expand All @@ -23,8 +21,6 @@ module Impl {
* ```
*/
class MethodCallExpr extends Generated::MethodCallExpr {
override Function getStaticTarget() { result = resolveMethodCallTarget(this) }

private string toStringPart(int index) {
index = 0 and
result = this.getReceiver().toAbbreviatedString()
Expand Down
149 changes: 122 additions & 27 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
private import codeql.rust.elements.Call
private import codeql.rust.elements.internal.CallImpl::Impl as CallImpl
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl

class Type = T::Type;

Expand Down Expand Up @@ -724,8 +725,6 @@
}
}

private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl

final class Access extends Call {
pragma[nomagic]
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
Expand Down Expand Up @@ -771,7 +770,7 @@
Declaration getTarget() {
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
or
result = CallExprImpl::getResolvedFunction(this)
result = resolveFunctionCallTarget(this) // potential mutual recursion; resolving some associated function calls requires resolving types
}
}

Expand Down Expand Up @@ -1045,14 +1044,12 @@
}

pragma[nomagic]
private StructType inferLiteralType(LiteralExpr le) {
private Type inferLiteralType(LiteralExpr le, TypePath path) {
path.isEmpty() and
exists(Builtins::BuiltinType t | result = TStruct(t) |
le instanceof CharLiteralExpr and
t instanceof Builtins::Char
or
le instanceof StringLiteralExpr and
t instanceof Builtins::Str
or
le =
any(NumberLiteralExpr ne |
t.getName() = ne.getSuffix()
Expand All @@ -1070,6 +1067,14 @@
le instanceof BooleanLiteralExpr and
t instanceof Builtins::Bool
)
or
le instanceof StringLiteralExpr and
(
path.isEmpty() and result = TRefType()
or
path = TypePath::singleton(TRefTypeParameter()) and
result = TStruct(any(Builtins::Str s))
)
}

pragma[nomagic]
Expand Down Expand Up @@ -1214,12 +1219,22 @@
final class MethodCall extends Call {
MethodCall() { exists(this.getReceiver()) }

private Type getReceiverTypeAt(TypePath path) {
result = inferType(super.getReceiver(), path)
or
exists(PathExpr pe, TypeMention tm |
pe = this.(CallExpr).getFunction() and
tm = pe.getPath().getQualifier() and
result = tm.resolveTypeAt(path)
)
}

/** Gets the type of the receiver of the method call at `path`. */
Type getTypeAt(TypePath path) {
if this.receiverImplicitlyBorrowed()
then
exists(TypePath path0, Type t0 |
t0 = inferType(super.getReceiver(), path0) and
t0 = this.getReceiverTypeAt(path0) and
(
path0.isCons(TRefTypeParameter(), path)
or
Expand Down Expand Up @@ -1247,10 +1262,14 @@
t0.(StructType).asItemNode() instanceof StringStruct and
result.(StructType).asItemNode() instanceof Builtins::Str
)
else result = inferType(super.getReceiver(), path)
else result = this.getReceiverTypeAt(path)
}
}

final private class FunctionCallExpr extends CallExpr {
FunctionCallExpr() { not this instanceof MethodCall }
}

/**
* Holds if a method for `type` with the name `name` and the arity `arity`
* exists in `impl`.
Expand All @@ -1266,7 +1285,7 @@
}

/**
* Holds if a method for `type` for `trait` with the name `name` and the arity
* `arity` exists in `impl`.
*/
pragma[nomagic]
Expand Down Expand Up @@ -1341,7 +1360,7 @@
// siblings).
not exists(impl.getAttributeMacroExpansion()) and
// We use this for resolving methods, so exclude traits that do not have methods.
exists(Function f | f = trait.getASuccessor(_) and f.getParamList().hasSelfParam()) and
exists(Function f | f = trait.getASuccessor(_)) and

Check warning

Code scanning / CodeQL

Expression can be replaced with a cast Warning

The assignment to
f
in the exists(..) can replaced with an instanceof expression.
selfTy = impl.getSelfTy() and
rootType = selfTy.resolveType()
}
Expand Down Expand Up @@ -1490,6 +1509,84 @@
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
}

pragma[nomagic]
private Function resolveMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}

pragma[nomagic]
private predicate assocFuncResolutionDependsOnArgument(Function f, Impl impl, int pos) {
methodResolutionDependsOnArgument(impl, _, f, pos, _, _)
}

private class AssocFunctionCallExpr extends FunctionCallExpr {
private int pos;

AssocFunctionCallExpr() {
assocFuncResolutionDependsOnArgument(CallExprImpl::getResolvedFunction(this), _, pos)
}

Function getACandidate(Impl impl) {
result = CallExprImpl::getResolvedFunction(this) and
assocFuncResolutionDependsOnArgument(result, impl, pos)
}

int getPosition() { result = pos }

/** Gets the type of the receiver of the associated function call at `path`. */
Type getTypeAt(TypePath path) { result = inferType(this.getArg(pos), path) }
}

private module AssocFuncIsInstantiationOfInput implements
IsInstantiationOfInputSig<AssocFunctionCallExpr>
{
pragma[nomagic]
predicate potentialInstantiationOf(
AssocFunctionCallExpr ce, TypeAbstraction impl, TypeMention constraint
) {
exists(Function cand |
cand = ce.getACandidate(impl) and
constraint = cand.getParam(ce.getPosition()).getTypeRepr()
)
}
}

/**
* Gets the target of `call`, where resolution does not rely on type inference.
*/
pragma[nomagic]
private ItemNode resolveFunctionCallTargetSimple(FunctionCallExpr call) {
result = CallExprImpl::getResolvedFunction(call) and
not assocFuncResolutionDependsOnArgument(result, _, _)
}

/**
* Gets the target of `call`, where resolution relies on type inference.
*/
pragma[nomagic]
private Function resolveFunctionCallTargetComplex(AssocFunctionCallExpr call) {
exists(Impl impl |
IsInstantiationOf<AssocFunctionCallExpr, AssocFuncIsInstantiationOfInput>::isInstantiationOf(call,
impl, _) and
result = getMethodSuccessor(impl, call.getACandidate(_).getName().getText())
)
}

pragma[inline]
private ItemNode resolveFunctionCallTarget(FunctionCallExpr call) {
result = resolveFunctionCallTargetSimple(call)
or
result = resolveFunctionCallTargetComplex(call)
}

cached
private module Cached {
private import codeql.rust.internal.CachedStages
Expand Down Expand Up @@ -1518,18 +1615,12 @@
)
}

/** Gets a method that the method call `mc` resolves to, if any. */
/** Gets a function that `call` resolves to, if any. */
cached
Function resolveMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
Function resolveCallTarget(Call call) {
result = resolveMethodCallTarget(call)
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
result = resolveFunctionCallTarget(call)
}

pragma[inline]
Expand Down Expand Up @@ -1635,8 +1726,7 @@
or
result = inferTryExprType(n, path)
or
result = inferLiteralType(n) and
path.isEmpty()
result = inferLiteralType(n, path)
or
result = inferAsyncBlockExprRootType(n) and
path.isEmpty()
Expand Down Expand Up @@ -1667,8 +1757,8 @@
private Locatable getRelevantLocatable() {
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
filepath.matches("%/sqlx.rs") and
startline = [56 .. 60]
filepath.matches("%/main.rs") and
startline = 120
)
}

Expand All @@ -1677,9 +1767,9 @@
result = inferType(n, path)
}

Function debugResolveMethodCallTarget(Call mce) {
mce = getRelevantLocatable() and
result = resolveMethodCallTarget(mce)
Function debugResolveCallTarget(Call c) {
c = getRelevantLocatable() and
result = [resolveMethodCallTarget(c), resolveCallTarget(c)]
}

predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
Expand All @@ -1697,6 +1787,11 @@
tm.resolveTypeAt(path) = type
}

Type debugInferAnnotatedType(AstNode n, TypePath path) {
n = getRelevantLocatable() and
result = inferAnnotatedType(n, path)
}

pragma[nomagic]
private int countTypesAtPath(AstNode n, TypePath path, Type t) {
t = inferType(n, path) and
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
multipleCallTargets
| proc_macro.rs:6:18:6:61 | ...::from(...) |
| proc_macro.rs:7:15:7:58 | ...::from(...) |
| proc_macro.rs:15:5:17:5 | ...::new(...) |
| proc_macro.rs:16:12:16:16 | ...::to_tokens(...) |
| proc_macro.rs:22:15:22:58 | ...::from(...) |
| proc_macro.rs:25:5:28:5 | ...::new(...) |
| proc_macro.rs:26:10:26:12 | ...::to_tokens(...) |
| proc_macro.rs:27:10:27:16 | ...::to_tokens(...) |
| proc_macro.rs:38:15:38:64 | ...::from(...) |
| proc_macro.rs:41:5:49:5 | ...::new(...) |
| proc_macro.rs:41:5:49:5 | ...::new(...) |
| proc_macro.rs:41:5:49:5 | ...::new(...) |
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
multipleCallTargets
| main.rs:272:14:272:29 | ...::deref(...) |
| main.rs:301:30:301:54 | ...::take_self(...) |
| main.rs:306:30:306:56 | ...::take_second(...) |
| main.rs:311:30:311:54 | ...::take_self(...) |
Loading