Skip to content

Rust: Rework type inference for impl Trait in return position #19954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 15 additions & 51 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ newtype TType =
TTrait(Trait t) or
TArrayType() or // todo: add size?
TRefType() or // todo: add mut?
TImplTraitType(ImplTraitTypeRepr impl) or
TImplTraitArgumentType(Function function, ImplTraitTypeRepr impl) {
impl = function.getAParam().getTypeRepr()
} or
TSliceType() or
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
Expand Down Expand Up @@ -196,53 +198,6 @@ class RefType extends Type, TRefType {
override Location getLocation() { result instanceof EmptyLocation }
}

/**
* An [impl Trait][1] type.
*
* Each syntactic `impl Trait` type gives rise to its own type, even if
* two `impl Trait` types have the same bounds.
*
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html
*/
class ImplTraitType extends Type, TImplTraitType {
ImplTraitTypeRepr impl;

ImplTraitType() { this = TImplTraitType(impl) }

/** Gets the underlying AST node. */
ImplTraitTypeRepr getImplTraitTypeRepr() { result = impl }

/** Gets the function that this `impl Trait` belongs to. */
abstract Function getFunction();

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }

override string toString() { result = impl.toString() }

override Location getLocation() { result = impl.getLocation() }
}

/**
* An [impl Trait in return position][1] type, for example:
*
* ```rust
* fn foo() -> impl Trait
* ```
*
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.return
*/
class ImplTraitReturnType extends ImplTraitType {
private Function function;

ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }

override Function getFunction() { result = function }
}

/**
* A slice type.
*
Expand Down Expand Up @@ -386,18 +341,27 @@ class SelfTypeParameter extends TypeParameter, TSelfTypeParameter {
*
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.param
*/
class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter {
class ImplTraitArgumentType extends TypeParameter, TImplTraitArgumentType {
private Function function;
private ImplTraitTypeRepr impl;

ImplTraitTypeTypeParameter() { impl = function.getAParam().getTypeRepr() }
ImplTraitArgumentType() { this = TImplTraitArgumentType(function, impl) }

override Function getFunction() { result = function }
/** Gets the function that this `impl Trait` belongs to. */
Function getFunction() { result = function }

/** Gets the underlying AST node. */
ImplTraitTypeRepr getImplTraitTypeRepr() { result = impl }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }

override string toString() { result = impl.toString() }

override Location getLocation() { result = impl.getLocation() }
}

/**
Expand Down
19 changes: 8 additions & 11 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ private module Input1 implements InputSig1<Location> {
node = tp0.(TypeParamTypeParameter).getTypeParam() or
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
node = tp0.(SelfTypeParameter).getTrait() or
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
node = tp0.(ImplTraitArgumentType).getImplTraitTypeRepr()
)
|
tp0 order by kind, id
Expand Down Expand Up @@ -132,11 +132,7 @@ private module Input2 implements InputSig2 {
result = tp.(SelfTypeParameter).getTrait()
or
result =
tp.(ImplTraitTypeTypeParameter)
.getImplTraitTypeRepr()
.getTypeBoundList()
.getABound()
.getTypeRepr()
tp.(ImplTraitArgumentType).getImplTraitTypeRepr().getTypeBoundList().getABound().getTypeRepr()
}

/**
Expand Down Expand Up @@ -670,7 +666,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
)
or
ppos.isImplicit() and
this = result.(ImplTraitTypeTypeParameter).getFunction()
this = result.(ImplTraitArgumentType).getFunction()
}

override Type getParameterType(DeclarationPosition dpos, TypePath path) {
Expand Down Expand Up @@ -1476,7 +1472,7 @@ private Function getTypeParameterMethod(TypeParameter tp, string name) {
or
result = getMethodSuccessor(tp.(SelfTypeParameter).getTrait(), name)
or
result = getMethodSuccessor(tp.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr(), name)
result = getMethodSuccessor(tp.(ImplTraitArgumentType).getImplTraitTypeRepr(), name)
}

pragma[nomagic]
Expand Down Expand Up @@ -1655,8 +1651,8 @@ private Function getMethodFromImpl(MethodCall mc) {

bindingset[trait, name]
pragma[inline_late]
private Function getTraitMethod(ImplTraitReturnType trait, string name) {
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
private Function getTraitMethod(TraitType trait, string name) {
result = getMethodSuccessor(trait.getTrait(), name)
}

pragma[nomagic]
Expand All @@ -1669,7 +1665,8 @@ private Function resolveMethodCallTarget(MethodCall mc) {
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName()) and
not exists(mc.getTrait())
}

pragma[nomagic]
Expand Down
8 changes: 7 additions & 1 deletion rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,13 @@ class PathTypeReprMention extends TypeMention, PathTypeRepr {
class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr {
override Type resolveTypeAt(TypePath typePath) {
typePath.isEmpty() and
result.(ImplTraitType).getImplTraitTypeRepr() = this
result.(ImplTraitArgumentType).getImplTraitTypeRepr() = this
or
exists(Function f |
this = f.getRetType().getTypeRepr() and
result =
super.getTypeBoundList().getABound().getTypeRepr().(TypeMention).resolveTypeAt(typePath)
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
multipleCallTargets
| dereference.rs:61:15:61:24 | e1.deref() |
| main.rs:2186:13:2186:31 | ...::from(...) |
| main.rs:2187:13:2187:31 | ...::from(...) |
| main.rs:2188:13:2188:31 | ...::from(...) |
| main.rs:2194:13:2194:31 | ...::from(...) |
| main.rs:2195:13:2195:31 | ...::from(...) |
| main.rs:2196:13:2196:31 | ...::from(...) |
| main.rs:2200:13:2200:31 | ...::from(...) |
| main.rs:2201:13:2201:31 | ...::from(...) |
| main.rs:2202:13:2202:31 | ...::from(...) |
| main.rs:2208:13:2208:31 | ...::from(...) |
| main.rs:2209:13:2209:31 | ...::from(...) |
| main.rs:2210:13:2210:31 | ...::from(...) |
14 changes: 14 additions & 0 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1854,8 +1854,10 @@ mod async_ {
}

mod impl_trait {
#[derive(Clone)]
struct S1;
struct S2;
struct S3<T3>(T3);

trait Trait1 {
fn f1(&self) {} // Trait1f1
Expand Down Expand Up @@ -1887,10 +1889,21 @@ mod impl_trait {
}
}

impl<T: Clone> MyTrait<T> for S3<T> {
fn get_a(&self) -> T {
let S3(t) = self;
t.clone()
}
}

fn get_a_my_trait() -> impl MyTrait<S2> {
S1
}

fn get_a_my_trait2<T: Clone>(x: T) -> impl MyTrait<T> {
S3(x)
}

fn uses_my_trait1<A, B: MyTrait<A>>(t: B) -> A {
t.get_a() // $ method=MyTrait::get_a
}
Expand All @@ -1908,6 +1921,7 @@ mod impl_trait {
let a = get_a_my_trait(); // $ method=get_a_my_trait
let c = uses_my_trait2(a); // $ type=c:S2 method=uses_my_trait2
let d = uses_my_trait2(S1); // $ type=d:S2 method=uses_my_trait2
let e = get_a_my_trait2(S1).get_a(); // $ method=get_a_my_trait2 method=MyTrait::get_a $ type=e:S1
}
}

Expand Down
Loading