Skip to content

Commit 9d72fab

Browse files
authored
Merge pull request #20119 from paldepind/rust/type-inference-assoc-type-tp
Rust: Type inference for impl trait types with type parameters
2 parents 37b508b + 92bce4e commit 9d72fab

File tree

9 files changed

+1320
-1148
lines changed

9 files changed

+1320
-1148
lines changed

rust/ql/.generated.list

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/.gitattributes

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
// generated by codegen, remove this comment if you wish to edit this file
21
/**
32
* This module provides a hand-modifiable wrapper around the generated class `ImplTraitTypeRepr`.
43
*
54
* INTERNAL: Do not use.
65
*/
76

87
private import codeql.rust.elements.internal.generated.ImplTraitTypeRepr
8+
private import rust
99

1010
/**
1111
* INTERNAL: This module contains the customizable definition of `ImplTraitTypeRepr` and should not
1212
* be referenced directly.
1313
*/
1414
module Impl {
15+
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1516
/**
1617
* An `impl Trait` type.
1718
*
@@ -21,5 +22,15 @@ module Impl {
2122
* // ^^^^^^^^^^^^^^^^^^^^^^^^^^
2223
* ```
2324
*/
24-
class ImplTraitTypeRepr extends Generated::ImplTraitTypeRepr { }
25+
class ImplTraitTypeRepr extends Generated::ImplTraitTypeRepr {
26+
/** Gets the function for which this impl trait type occurs, if any. */
27+
Function getFunction() {
28+
this.getParentNode*() = [result.getRetType().getTypeRepr(), result.getAParam().getTypeRepr()]
29+
}
30+
31+
/** Holds if this impl trait type occurs in the return type of a function. */
32+
predicate isInReturnPos() {
33+
this.getParentNode*() = this.getFunction().getRetType().getTypeRepr()
34+
}
35+
}
2536
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,21 @@ newtype TType =
5252
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
5353
TArrayTypeParameter() or
5454
TDynTraitTypeParameter(AstNode n) { dynTraitTypeParameter(_, n) } or
55+
TImplTraitTypeParameter(ImplTraitTypeRepr implTrait, TypeParam tp) {
56+
implTraitTypeParam(implTrait, _, tp)
57+
} or
5558
TRefTypeParameter() or
5659
TSelfTypeParameter(Trait t) or
5760
TSliceTypeParameter()
5861

62+
predicate implTraitTypeParam(ImplTraitTypeRepr implTrait, int i, TypeParam tp) {
63+
implTrait.isInReturnPos() and
64+
tp = implTrait.getFunction().getGenericParamList().getTypeParam(i) and
65+
// Only include type parameters of the function that occur inside the impl
66+
// trait type.
67+
exists(Path path | path.getParentNode*() = implTrait and resolvePath(path) = tp)
68+
}
69+
5970
/**
6071
* A type without type arguments.
6172
*
@@ -263,7 +274,12 @@ class ImplTraitType extends Type, TImplTraitType {
263274

264275
override TupleField getTupleField(int i) { none() }
265276

266-
override TypeParameter getTypeParameter(int i) { none() }
277+
override TypeParameter getTypeParameter(int i) {
278+
exists(TypeParam tp |
279+
implTraitTypeParam(impl, i, tp) and
280+
result = TImplTraitTypeParameter(impl, tp)
281+
)
282+
}
267283

268284
override string toString() { result = impl.toString() }
269285

@@ -302,7 +318,7 @@ class DynTraitType extends Type, TDynTraitType {
302318
class ImplTraitReturnType extends ImplTraitType {
303319
private Function function;
304320

305-
ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }
321+
ImplTraitReturnType() { impl.isInReturnPos() and function = impl.getFunction() }
306322

307323
override Function getFunction() { result = function }
308324
}
@@ -456,6 +472,21 @@ class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
456472
override Location getLocation() { result = n.getLocation() }
457473
}
458474

475+
class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
476+
private TypeParam typeParam;
477+
private ImplTraitTypeRepr implTrait;
478+
479+
ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTrait, typeParam) }
480+
481+
TypeParam getTypeParam() { result = typeParam }
482+
483+
ImplTraitTypeRepr getImplTraitTypeRepr() { result = implTrait }
484+
485+
override string toString() { result = "impl(" + typeParam.toString() + ")" }
486+
487+
override Location getLocation() { result = typeParam.getLocation() }
488+
}
489+
459490
/** An implicit reference type parameter. */
460491
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
461492
override string toString() { result = "&T" }
@@ -569,5 +600,7 @@ final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
569600
}
570601

571602
final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
572-
override TypeParameter getATypeParameter() { none() }
603+
override TypeParameter getATypeParameter() {
604+
implTraitTypeParam(this, _, result.(TypeParamTypeParameter).getTypeParam())
605+
}
573606
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,42 +83,48 @@ private module Input1 implements InputSig1<Location> {
8383

8484
int getTypeParameterId(TypeParameter tp) {
8585
tp =
86-
rank[result](TypeParameter tp0, int kind, int id |
86+
rank[result](TypeParameter tp0, int kind, int id1, int id2 |
8787
tp0 instanceof ArrayTypeParameter and
8888
kind = 0 and
89-
id = 0
89+
id1 = 0 and
90+
id2 = 0
9091
or
9192
tp0 instanceof RefTypeParameter and
9293
kind = 0 and
93-
id = 1
94+
id1 = 0 and
95+
id2 = 1
9496
or
9597
tp0 instanceof SliceTypeParameter and
9698
kind = 0 and
97-
id = 2
99+
id1 = 0 and
100+
id2 = 2
98101
or
99102
kind = 1 and
100-
id =
103+
id1 = 0 and
104+
id2 =
101105
idOfTypeParameterAstNode([
102106
tp0.(DynTraitTypeParameter).getTypeParam().(AstNode),
103107
tp0.(DynTraitTypeParameter).getTypeAlias()
104108
])
105109
or
106110
kind = 2 and
107-
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
111+
id1 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getImplTraitTypeRepr()) and
112+
id2 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getTypeParam())
113+
or
114+
kind = 3 and
115+
id1 = 0 and
116+
exists(AstNode node | id2 = idOfTypeParameterAstNode(node) |
108117
node = tp0.(TypeParamTypeParameter).getTypeParam() or
109118
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
110119
node = tp0.(SelfTypeParameter).getTrait() or
111120
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
112121
)
113122
or
114-
exists(TupleTypeParameter ttp, int maxArity |
115-
maxArity = max(int i | i = any(TupleType tt).getArity()) and
116-
tp0 = ttp and
117-
kind = 3 and
118-
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
119-
)
123+
kind = 4 and
124+
id1 = tp0.(TupleTypeParameter).getTupleType().getArity() and
125+
id2 = tp0.(TupleTypeParameter).getIndex()
120126
|
121-
tp0 order by kind, id
127+
tp0 order by kind, id1, id2
122128
)
123129
}
124130
}

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,12 @@ class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr
258258
override Type resolveTypeAt(TypePath typePath) {
259259
typePath.isEmpty() and
260260
result.(ImplTraitType).getImplTraitTypeRepr() = this
261+
or
262+
exists(ImplTraitTypeParameter tp |
263+
this = tp.getImplTraitTypeRepr() and
264+
typePath = TypePath::singleton(tp) and
265+
result = TTypeParamTypeParameter(tp.getTypeParam())
266+
)
261267
}
262268
}
263269

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
multipleCallTargets
22
| dereference.rs:61:15:61:24 | e1.deref() |
3-
| main.rs:2253:13:2253:31 | ...::from(...) |
4-
| main.rs:2254:13:2254:31 | ...::from(...) |
5-
| main.rs:2255:13:2255:31 | ...::from(...) |
6-
| main.rs:2261:13:2261:31 | ...::from(...) |
7-
| main.rs:2262:13:2262:31 | ...::from(...) |
8-
| main.rs:2263:13:2263:31 | ...::from(...) |
3+
| main.rs:2278:13:2278:31 | ...::from(...) |
4+
| main.rs:2279:13:2279:31 | ...::from(...) |
5+
| main.rs:2280:13:2280:31 | ...::from(...) |
6+
| main.rs:2286:13:2286:31 | ...::from(...) |
7+
| main.rs:2287:13:2287:31 | ...::from(...) |
8+
| main.rs:2288:13:2288:31 | ...::from(...) |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1913,8 +1913,10 @@ mod async_ {
19131913
}
19141914

19151915
mod impl_trait {
1916+
#[derive(Copy, Clone)]
19161917
struct S1;
19171918
struct S2;
1919+
struct S3<T3>(T3);
19181920

19191921
trait Trait1 {
19201922
fn f1(&self) {} // Trait1f1
@@ -1946,6 +1948,13 @@ mod impl_trait {
19461948
}
19471949
}
19481950

1951+
impl<T: Clone> MyTrait<T> for S3<T> {
1952+
fn get_a(&self) -> T {
1953+
let S3(t) = self;
1954+
t.clone()
1955+
}
1956+
}
1957+
19491958
fn get_a_my_trait() -> impl MyTrait<S2> {
19501959
S1
19511960
}
@@ -1954,6 +1963,18 @@ mod impl_trait {
19541963
t.get_a() // $ target=MyTrait::get_a
19551964
}
19561965

1966+
fn get_a_my_trait2<T: Clone>(x: T) -> impl MyTrait<T> {
1967+
S3(x)
1968+
}
1969+
1970+
fn get_a_my_trait3<T: Clone>(x: T) -> Option<impl MyTrait<T>> {
1971+
Some(S3(x))
1972+
}
1973+
1974+
fn get_a_my_trait4<T: Clone>(x: T) -> (impl MyTrait<T>, impl MyTrait<T>) {
1975+
(S3(x.clone()), S3(x)) // $ target=clone
1976+
}
1977+
19571978
fn uses_my_trait2<A>(t: impl MyTrait<A>) -> A {
19581979
t.get_a() // $ target=MyTrait::get_a
19591980
}
@@ -1967,6 +1988,10 @@ mod impl_trait {
19671988
let a = get_a_my_trait(); // $ target=get_a_my_trait
19681989
let c = uses_my_trait2(a); // $ type=c:S2 target=uses_my_trait2
19691990
let d = uses_my_trait2(S1); // $ type=d:S2 target=uses_my_trait2
1991+
let e = get_a_my_trait2(S1).get_a(); // $ target=get_a_my_trait2 target=MyTrait::get_a type=e:S1
1992+
// For this function the `impl` type does not appear in the root of the return type
1993+
let f = get_a_my_trait3(S1).unwrap().get_a(); // $ target=get_a_my_trait3 target=unwrap target=MyTrait::get_a type=f:S1
1994+
let g = get_a_my_trait4(S1).0.get_a(); // $ target=get_a_my_trait4 target=MyTrait::get_a type=g:S1
19701995
}
19711996
}
19721997

@@ -2425,7 +2450,7 @@ mod tuples {
24252450

24262451
let pair = [1, 1].into(); // $ type=pair:(T_2) type=pair:0(2).i32 type=pair:1(2).i32 MISSING: target=into
24272452
match pair {
2428-
(0,0) => print!("unexpected"),
2453+
(0, 0) => print!("unexpected"),
24292454
_ => print!("expected"),
24302455
}
24312456
let x = pair.0; // $ type=x:i32

0 commit comments

Comments
 (0)