Skip to content

Commit 411aa6d

Browse files
authored
Merge pull request #19971 from hvitved/rust/type-inference-for-range
Rust: Improve type inference for `for` loops and range expressions
2 parents 52abf3b + 1518cad commit 411aa6d

File tree

10 files changed

+341
-38
lines changed

10 files changed

+341
-38
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
* Type inference has been improved for `for` loops and range expressions, which improves call resolution and may ultimately lead to more query results.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/**
2+
* This module provides sub classes of the `RangeExpr` class.
3+
*/
4+
5+
private import rust
6+
7+
/**
8+
* A range-from expression. For example:
9+
* ```rust
10+
* let x = 10..;
11+
* ```
12+
*/
13+
final class RangeFromExpr extends RangeExpr {
14+
RangeFromExpr() {
15+
this.getOperatorName() = ".." and
16+
this.hasStart() and
17+
not this.hasEnd()
18+
}
19+
}
20+
21+
/**
22+
* A range-to expression. For example:
23+
* ```rust
24+
* let x = ..10;
25+
* ```
26+
*/
27+
final class RangeToExpr extends RangeExpr {
28+
RangeToExpr() {
29+
this.getOperatorName() = ".." and
30+
not this.hasStart() and
31+
this.hasEnd()
32+
}
33+
}
34+
35+
/**
36+
* A range-from-to expression. For example:
37+
* ```rust
38+
* let x = 10..20;
39+
* ```
40+
*/
41+
final class RangeFromToExpr extends RangeExpr {
42+
RangeFromToExpr() {
43+
this.getOperatorName() = ".." and
44+
this.hasStart() and
45+
this.hasEnd()
46+
}
47+
}
48+
49+
/**
50+
* A range-inclusive expression. For example:
51+
* ```rust
52+
* let x = 1..=10;
53+
* ```
54+
*/
55+
final class RangeInclusiveExpr extends RangeExpr {
56+
RangeInclusiveExpr() {
57+
this.getOperatorName() = "..=" and
58+
this.hasStart() and
59+
this.hasEnd()
60+
}
61+
}
62+
63+
/**
64+
* A range-to-inclusive expression. For example:
65+
* ```rust
66+
* let x = ..=10;
67+
* ```
68+
*/
69+
final class RangeToInclusiveExpr extends RangeExpr {
70+
RangeToInclusiveExpr() {
71+
this.getOperatorName() = "..=" and
72+
not this.hasStart() and
73+
this.hasEnd()
74+
}
75+
}

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,72 @@ class ResultEnum extends Enum {
5050
Variant getErr() { result = this.getVariant("Err") }
5151
}
5252

53+
/**
54+
* The [`Range` struct][1].
55+
*
56+
* [1]: https://doc.rust-lang.org/core/ops/struct.Range.html
57+
*/
58+
class RangeStruct extends Struct {
59+
RangeStruct() { this.getCanonicalPath() = "core::ops::range::Range" }
60+
61+
/** Gets the `start` field. */
62+
StructField getStart() { result = this.getStructField("start") }
63+
64+
/** Gets the `end` field. */
65+
StructField getEnd() { result = this.getStructField("end") }
66+
}
67+
68+
/**
69+
* The [`RangeFrom` struct][1].
70+
*
71+
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeFrom.html
72+
*/
73+
class RangeFromStruct extends Struct {
74+
RangeFromStruct() { this.getCanonicalPath() = "core::ops::range::RangeFrom" }
75+
76+
/** Gets the `start` field. */
77+
StructField getStart() { result = this.getStructField("start") }
78+
}
79+
80+
/**
81+
* The [`RangeTo` struct][1].
82+
*
83+
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeTo.html
84+
*/
85+
class RangeToStruct extends Struct {
86+
RangeToStruct() { this.getCanonicalPath() = "core::ops::range::RangeTo" }
87+
88+
/** Gets the `end` field. */
89+
StructField getEnd() { result = this.getStructField("end") }
90+
}
91+
92+
/**
93+
* The [`RangeInclusive` struct][1].
94+
*
95+
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeInclusive.html
96+
*/
97+
class RangeInclusiveStruct extends Struct {
98+
RangeInclusiveStruct() { this.getCanonicalPath() = "core::ops::range::RangeInclusive" }
99+
100+
/** Gets the `start` field. */
101+
StructField getStart() { result = this.getStructField("start") }
102+
103+
/** Gets the `end` field. */
104+
StructField getEnd() { result = this.getStructField("end") }
105+
}
106+
107+
/**
108+
* The [`RangeToInclusive` struct][1].
109+
*
110+
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeToInclusive.html
111+
*/
112+
class RangeToInclusiveStruct extends Struct {
113+
RangeToInclusiveStruct() { this.getCanonicalPath() = "core::ops::range::RangeToInclusive" }
114+
115+
/** Gets the `end` field. */
116+
StructField getEnd() { result = this.getStructField("end") }
117+
}
118+
53119
/**
54120
* The [`Future` trait][1].
55121
*
@@ -66,6 +132,38 @@ class FutureTrait extends Trait {
66132
}
67133
}
68134

135+
/**
136+
* The [`Iterator` trait][1].
137+
*
138+
* [1]: https://doc.rust-lang.org/std/iter/trait.Iterator.html
139+
*/
140+
class IteratorTrait extends Trait {
141+
IteratorTrait() { this.getCanonicalPath() = "core::iter::traits::iterator::Iterator" }
142+
143+
/** Gets the `Item` associated type. */
144+
pragma[nomagic]
145+
TypeAlias getItemType() {
146+
result = this.getAssocItemList().getAnAssocItem() and
147+
result.getName().getText() = "Item"
148+
}
149+
}
150+
151+
/**
152+
* The [`IntoIterator` trait][1].
153+
*
154+
* [1]: https://doc.rust-lang.org/std/iter/trait.IntoIterator.html
155+
*/
156+
class IntoIteratorTrait extends Trait {
157+
IntoIteratorTrait() { this.getCanonicalPath() = "core::iter::traits::collect::IntoIterator" }
158+
159+
/** Gets the `Item` associated type. */
160+
pragma[nomagic]
161+
TypeAlias getItemType() {
162+
result = this.getAssocItemList().getAnAssocItem() and
163+
result.getName().getText() = "Item"
164+
}
165+
}
166+
69167
/**
70168
* The [`String` struct][1].
71169
*

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -421,21 +421,25 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
421421
}
422422

423423
final class TraitTypeAbstraction extends TypeAbstraction, Trait {
424-
override TypeParamTypeParameter getATypeParameter() {
425-
result.getTypeParam() = this.getGenericParamList().getATypeParam()
424+
override TypeParameter getATypeParameter() {
425+
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()
426+
or
427+
result.(AssociatedTypeTypeParameter).getTrait() = this
426428
}
427429
}
428430

429431
final class TypeBoundTypeAbstraction extends TypeAbstraction, TypeBound {
430-
override TypeParamTypeParameter getATypeParameter() { none() }
432+
override TypeParameter getATypeParameter() { none() }
431433
}
432434

433435
final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
434-
SelfTypeBoundTypeAbstraction() { any(Trait trait).getName() = this }
436+
private TraitTypeAbstraction trait;
437+
438+
SelfTypeBoundTypeAbstraction() { trait.getName() = this }
435439

436-
override TypeParamTypeParameter getATypeParameter() { none() }
440+
override TypeParameter getATypeParameter() { none() }
437441
}
438442

439443
final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
440-
override TypeParamTypeParameter getATypeParameter() { none() }
444+
override TypeParameter getATypeParameter() { none() }
441445
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 74 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,24 @@ private Type inferAssignmentOperationType(AstNode n, TypePath path) {
231231
result = TUnit()
232232
}
233233

234+
pragma[nomagic]
235+
private Struct getRangeType(RangeExpr re) {
236+
re instanceof RangeFromExpr and
237+
result instanceof RangeFromStruct
238+
or
239+
re instanceof RangeToExpr and
240+
result instanceof RangeToStruct
241+
or
242+
re instanceof RangeFromToExpr and
243+
result instanceof RangeStruct
244+
or
245+
re instanceof RangeInclusiveExpr and
246+
result instanceof RangeInclusiveStruct
247+
or
248+
re instanceof RangeToInclusiveExpr and
249+
result instanceof RangeToInclusiveStruct
250+
}
251+
234252
/**
235253
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
236254
* of `n2` at `prefix2` and type information should propagate in both directions
@@ -296,6 +314,13 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
296314
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
297315
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
298316
prefix2.isEmpty()
317+
or
318+
exists(Struct s |
319+
n2 = [n1.(RangeExpr).getStart(), n1.(RangeExpr).getEnd()] and
320+
prefix1 = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
321+
prefix2.isEmpty() and
322+
s = getRangeType(n1)
323+
)
299324
}
300325

301326
pragma[nomagic]
@@ -1062,7 +1087,7 @@ private TraitType inferAsyncBlockExprRootType(AsyncBlockExpr abe) {
10621087
result = getFutureTraitType()
10631088
}
10641089

1065-
final class AwaitTarget extends Expr {
1090+
final private class AwaitTarget extends Expr {
10661091
AwaitTarget() { this = any(AwaitExpr ae).getExpr() }
10671092

10681093
Type getTypeAt(TypePath path) { result = inferType(this, path) }
@@ -1098,6 +1123,12 @@ private class Vec extends Struct {
10981123
pragma[nomagic]
10991124
private Type inferArrayExprType(ArrayExpr ae) { exists(ae) and result = TArrayType() }
11001125

1126+
/**
1127+
* Gets the root type of the range expression `re`.
1128+
*/
1129+
pragma[nomagic]
1130+
private Type inferRangeExprType(RangeExpr re) { result = TStruct(getRangeType(re)) }
1131+
11011132
/**
11021133
* According to [the Rust reference][1]: _"array and slice-typed expressions
11031134
* can be indexed with a `usize` index ... For other types an index expression
@@ -1134,23 +1165,49 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
11341165
)
11351166
}
11361167

1168+
final private class ForIterableExpr extends Expr {
1169+
ForIterableExpr() { this = any(ForExpr fe).getIterable() }
1170+
1171+
Type getTypeAt(TypePath path) { result = inferType(this, path) }
1172+
}
1173+
1174+
private module ForIterableSatisfiesConstraintInput implements
1175+
SatisfiesConstraintInputSig<ForIterableExpr>
1176+
{
1177+
predicate relevantConstraint(ForIterableExpr term, Type constraint) {
1178+
exists(term) and
1179+
exists(Trait t | t = constraint.(TraitType).getTrait() |
1180+
// TODO: Remove the line below once we can handle the `impl<I: Iterator> IntoIterator for I` implementation
1181+
t instanceof IteratorTrait or
1182+
t instanceof IntoIteratorTrait
1183+
)
1184+
}
1185+
}
1186+
1187+
pragma[nomagic]
1188+
private AssociatedTypeTypeParameter getIteratorItemTypeParameter() {
1189+
result.getTypeAlias() = any(IteratorTrait t).getItemType()
1190+
}
1191+
1192+
pragma[nomagic]
1193+
private AssociatedTypeTypeParameter getIntoIteratorItemTypeParameter() {
1194+
result.getTypeAlias() = any(IntoIteratorTrait t).getItemType()
1195+
}
1196+
11371197
pragma[nomagic]
11381198
private Type inferForLoopExprType(AstNode n, TypePath path) {
11391199
// type of iterable -> type of pattern (loop variable)
1140-
exists(ForExpr fe, Type iterableType, TypePath iterablePath |
1200+
exists(ForExpr fe, TypePath exprPath, AssociatedTypeTypeParameter tp |
11411201
n = fe.getPat() and
1142-
iterableType = inferType(fe.getIterable(), iterablePath) and
1143-
result = iterableType and
1144-
(
1145-
iterablePath.isCons(any(Vec v).getElementTypeParameter(), path)
1146-
or
1147-
iterablePath.isCons(any(ArrayTypeParameter tp), path)
1148-
or
1149-
iterablePath
1150-
.stripPrefix(TypePath::cons(TRefTypeParameter(),
1151-
TypePath::singleton(any(SliceTypeParameter tp)))) = path
1152-
// TODO: iterables (general case for containers, ranges etc)
1153-
)
1202+
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>::satisfiesConstraintType(fe.getIterable(),
1203+
_, exprPath, result) and
1204+
exprPath.isCons(tp, path)
1205+
|
1206+
tp = getIntoIteratorItemTypeParameter()
1207+
or
1208+
// TODO: Remove once we can handle the `impl<I: Iterator> IntoIterator for I` implementation
1209+
tp = getIteratorItemTypeParameter() and
1210+
inferType(fe.getIterable()) != TArrayType()
11541211
)
11551212
}
11561213

@@ -1589,6 +1646,9 @@ private module Cached {
15891646
result = inferArrayExprType(n) and
15901647
path.isEmpty()
15911648
or
1649+
result = inferRangeExprType(n) and
1650+
path.isEmpty()
1651+
or
15921652
result = inferIndexExprType(n, path)
15931653
or
15941654
result = inferForLoopExprType(n, path)

rust/ql/lib/rust.qll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import codeql.rust.elements.AsyncBlockExpr
1515
import codeql.rust.elements.Variable
1616
import codeql.rust.elements.NamedFormatArgument
1717
import codeql.rust.elements.PositionalFormatArgument
18+
import codeql.rust.elements.RangeExprExt
1819
private import codeql.rust.elements.Call as Call
1920

2021
class Call = Call::Call;

rust/ql/test/library-tests/dataflow/sources/TaintSources.expected

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,16 @@
5555
| test.rs:412:31:412:38 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
5656
| test.rs:417:22:417:39 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
5757
| test.rs:417:22:417:39 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
58+
| test.rs:423:22:423:25 | path | Flow source 'FileSource' of type file (DEFAULT). |
59+
| test.rs:424:27:424:35 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
5860
| test.rs:430:22:430:34 | ...::read_link | Flow source 'FileSource' of type file (DEFAULT). |
5961
| test.rs:439:31:439:45 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
6062
| test.rs:444:31:444:45 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
6163
| test.rs:449:22:449:46 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
64+
| test.rs:455:26:455:29 | path | Flow source 'FileSource' of type file (DEFAULT). |
65+
| test.rs:455:26:455:29 | path | Flow source 'FileSource' of type file (DEFAULT). |
66+
| test.rs:456:31:456:39 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
67+
| test.rs:456:31:456:39 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
6268
| test.rs:462:22:462:41 | ...::read_link | Flow source 'FileSource' of type file (DEFAULT). |
6369
| test.rs:472:20:472:38 | ...::open | Flow source 'FileSource' of type file (DEFAULT). |
6470
| test.rs:506:21:506:39 | ...::open | Flow source 'FileSource' of type file (DEFAULT). |

0 commit comments

Comments
 (0)