Skip to content

Commit 7647bc3

Browse files
committed
Rust: Improve type inference for for loops and range expressions
1 parent e33ddce commit 7647bc3

File tree

9 files changed

+348
-38
lines changed

9 files changed

+348
-38
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/**
2+
* This module provides sub classes of the `RangeExpr` class.
3+
*/
4+
5+
private import rust
6+
7+
/**
8+
* A range-from expression. For example:
9+
* ```rust
10+
* let x = 10..;
11+
* ```
12+
*/
13+
final class RangeFromExpr extends RangeExpr {
14+
RangeFromExpr() {
15+
this.getOperatorName() = ".." and
16+
not this.hasEnd()
17+
}
18+
}
19+
20+
/**
21+
* A range-to expression. For example:
22+
* ```rust
23+
* let x = ..10;
24+
* ```
25+
*/
26+
final class RangeToExpr extends RangeExpr {
27+
RangeToExpr() {
28+
this.getOperatorName() = ".." and
29+
not this.hasStart()
30+
}
31+
}
32+
33+
/**
34+
* A range-from-to expression. For example:
35+
* ```rust
36+
* let x = 10..20;
37+
* ```
38+
*/
39+
final class RangeFromToExpr extends RangeExpr {
40+
RangeFromToExpr() {
41+
this.getOperatorName() = ".." and
42+
this.hasStart() and
43+
this.hasEnd()
44+
}
45+
}
46+
47+
/**
48+
* A range-inclusive expression. For example:
49+
* ```rust
50+
* let x = 1..=10;
51+
* ```
52+
*/
53+
final class RangeInclusiveExpr extends RangeExpr {
54+
RangeInclusiveExpr() {
55+
this.getOperatorName() = "..=" and
56+
this.hasStart() and
57+
this.hasEnd()
58+
}
59+
}
60+
61+
/**
62+
* A range-to-inclusive expression. For example:
63+
* ```rust
64+
* let x = ..=10;
65+
* ```
66+
*/
67+
final class RangeToInclusiveExpr extends RangeExpr {
68+
RangeToInclusiveExpr() {
69+
this.getOperatorName() = "..=" and
70+
not this.hasStart() and
71+
this.hasEnd()
72+
}
73+
}

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,72 @@ class ResultEnum extends Enum {
5050
Variant getErr() { result = this.getVariant("Err") }
5151
}
5252

53+
/**
54+
* The [`Range` struct][1].
55+
*
56+
* [1]: https://doc.rust-lang.org/core/ops/struct.Range.html
57+
*/
58+
class RangeStruct extends Struct {
59+
RangeStruct() { this.getCanonicalPath() = "core::ops::range::Range" }
60+
61+
/** Gets the `start` field. */
62+
StructField getStart() { result = this.getStructField("start") }
63+
64+
/** Gets the `end` field. */
65+
StructField getEnd() { result = this.getStructField("end") }
66+
}
67+
68+
/**
69+
* The [`RangeFrom` struct][1].
70+
*
71+
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeFrom.html
72+
*/
73+
class RangeFromStruct extends Struct {
74+
RangeFromStruct() { this.getCanonicalPath() = "core::ops::range::RangeFrom" }
75+
76+
/** Gets the `start` field. */
77+
StructField getStart() { result = this.getStructField("start") }
78+
}
79+
80+
/**
81+
* The [`RangeTo` struct][1].
82+
*
83+
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeTo.html
84+
*/
85+
class RangeToStruct extends Struct {
86+
RangeToStruct() { this.getCanonicalPath() = "core::ops::range::RangeTo" }
87+
88+
/** Gets the `end` field. */
89+
StructField getEnd() { result = this.getStructField("end") }
90+
}
91+
92+
/**
93+
* The [`RangeInclusive` struct][1].
94+
*
95+
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeInclusive.html
96+
*/
97+
class RangeInclusiveStruct extends Struct {
98+
RangeInclusiveStruct() { this.getCanonicalPath() = "core::ops::range::RangeInclusive" }
99+
100+
/** Gets the `start` field. */
101+
StructField getStart() { result = this.getStructField("start") }
102+
103+
/** Gets the `end` field. */
104+
StructField getEnd() { result = this.getStructField("end") }
105+
}
106+
107+
/**
108+
* The [`RangeToInclusive` struct][1].
109+
*
110+
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeToInclusive.html
111+
*/
112+
class RangeToInclusiveStruct extends Struct {
113+
RangeToInclusiveStruct() { this.getCanonicalPath() = "core::ops::range::RangeToInclusive" }
114+
115+
/** Gets the `end` field. */
116+
StructField getEnd() { result = this.getStructField("end") }
117+
}
118+
53119
/**
54120
* The [`Future` trait][1].
55121
*
@@ -66,6 +132,38 @@ class FutureTrait extends Trait {
66132
}
67133
}
68134

135+
/**
136+
* The [`Iterator` trait][1].
137+
*
138+
* [1]: https://doc.rust-lang.org/std/iter/trait.Iterator.html
139+
*/
140+
class IteratorTrait extends Trait {
141+
IteratorTrait() { this.getCanonicalPath() = "core::iter::traits::iterator::Iterator" }
142+
143+
/** Gets the `Item` associated type. */
144+
pragma[nomagic]
145+
TypeAlias getItemType() {
146+
result = this.getAssocItemList().getAnAssocItem() and
147+
result.getName().getText() = "Item"
148+
}
149+
}
150+
151+
/**
152+
* The [`IntoIterator` trait][1].
153+
*
154+
* [1]: https://doc.rust-lang.org/std/iter/trait.IntoIterator.html
155+
*/
156+
class IntoIteratorTrait extends Trait {
157+
IntoIteratorTrait() { this.getCanonicalPath() = "core::iter::traits::collect::IntoIterator" }
158+
159+
/** Gets the `Item` associated type. */
160+
pragma[nomagic]
161+
TypeAlias getItemType() {
162+
result = this.getAssocItemList().getAnAssocItem() and
163+
result.getName().getText() = "Item"
164+
}
165+
}
166+
69167
/**
70168
* The [`String` struct][1].
71169
*

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -421,21 +421,25 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
421421
}
422422

423423
final class TraitTypeAbstraction extends TypeAbstraction, Trait {
424-
override TypeParamTypeParameter getATypeParameter() {
425-
result.getTypeParam() = this.getGenericParamList().getATypeParam()
424+
override TypeParameter getATypeParameter() {
425+
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()
426+
or
427+
result.(AssociatedTypeTypeParameter).getTrait() = this
426428
}
427429
}
428430

429431
final class TypeBoundTypeAbstraction extends TypeAbstraction, TypeBound {
430-
override TypeParamTypeParameter getATypeParameter() { none() }
432+
override TypeParameter getATypeParameter() { none() }
431433
}
432434

433435
final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
434-
SelfTypeBoundTypeAbstraction() { any(Trait trait).getName() = this }
436+
private TraitTypeAbstraction trait;
437+
438+
SelfTypeBoundTypeAbstraction() { trait.getName() = this }
435439

436-
override TypeParamTypeParameter getATypeParameter() { none() }
440+
override TypeParameter getATypeParameter() { none() }
437441
}
438442

439443
final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
440-
override TypeParamTypeParameter getATypeParameter() { none() }
444+
override TypeParameter getATypeParameter() { none() }
441445
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,27 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
296296
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
297297
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
298298
prefix2.isEmpty()
299+
or
300+
exists(Struct s |
301+
n2 = [n1.(RangeExpr).getStart(), n1.(RangeExpr).getEnd()] and
302+
prefix1 = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
303+
prefix2.isEmpty()
304+
|
305+
n1 instanceof RangeFromExpr and
306+
s instanceof RangeFromStruct
307+
or
308+
n1 instanceof RangeToExpr and
309+
s instanceof RangeToStruct
310+
or
311+
n1 instanceof RangeFromToExpr and
312+
s instanceof RangeStruct
313+
or
314+
n1 instanceof RangeInclusiveExpr and
315+
s instanceof RangeInclusiveStruct
316+
or
317+
n1 instanceof RangeToInclusiveExpr and
318+
s instanceof RangeToInclusiveStruct
319+
)
299320
}
300321

301322
pragma[nomagic]
@@ -1062,7 +1083,7 @@ private TraitType inferAsyncBlockExprRootType(AsyncBlockExpr abe) {
10621083
result = getFutureTraitType()
10631084
}
10641085

1065-
final class AwaitTarget extends Expr {
1086+
final private class AwaitTarget extends Expr {
10661087
AwaitTarget() { this = any(AwaitExpr ae).getExpr() }
10671088

10681089
Type getTypeAt(TypePath path) { result = inferType(this, path) }
@@ -1098,6 +1119,29 @@ private class Vec extends Struct {
10981119
pragma[nomagic]
10991120
private Type inferArrayExprType(ArrayExpr ae) { exists(ae) and result = TArrayType() }
11001121

1122+
/**
1123+
* Gets the root type of the range expression `re`.
1124+
*/
1125+
pragma[nomagic]
1126+
private Type inferRangeExprType(RangeExpr re) {
1127+
exists(Struct s | result = TStruct(s) |
1128+
re instanceof RangeFromExpr and
1129+
s instanceof RangeFromStruct
1130+
or
1131+
re instanceof RangeToExpr and
1132+
s instanceof RangeToStruct
1133+
or
1134+
re instanceof RangeFromToExpr and
1135+
s instanceof RangeStruct
1136+
or
1137+
re instanceof RangeInclusiveExpr and
1138+
s instanceof RangeInclusiveStruct
1139+
or
1140+
re instanceof RangeToInclusiveExpr and
1141+
s instanceof RangeToInclusiveStruct
1142+
)
1143+
}
1144+
11011145
/**
11021146
* According to [the Rust reference][1]: _"array and slice-typed expressions
11031147
* can be indexed with a `usize` index ... For other types an index expression
@@ -1134,23 +1178,49 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
11341178
)
11351179
}
11361180

1181+
final private class ForIterableExpr extends Expr {
1182+
ForIterableExpr() { this = any(ForExpr fe).getIterable() }
1183+
1184+
Type getTypeAt(TypePath path) { result = inferType(this, path) }
1185+
}
1186+
1187+
private module ForIterableSatisfiesConstraintInput implements
1188+
SatisfiesConstraintInputSig<ForIterableExpr>
1189+
{
1190+
predicate relevantConstraint(ForIterableExpr term, Type constraint) {
1191+
exists(term) and
1192+
exists(Trait t | t = constraint.(TraitType).getTrait() |
1193+
// TODO: Remove the line below once we can handle the `impl<I: Iterator> IntoIterator for I` implementation
1194+
t instanceof IteratorTrait or
1195+
t instanceof IntoIteratorTrait
1196+
)
1197+
}
1198+
}
1199+
1200+
pragma[nomagic]
1201+
private AssociatedTypeTypeParameter getIteratorItemTypeParameter() {
1202+
result.getTypeAlias() = any(IteratorTrait t).getItemType()
1203+
}
1204+
1205+
pragma[nomagic]
1206+
private AssociatedTypeTypeParameter getIntoIteratorItemTypeParameter() {
1207+
result.getTypeAlias() = any(IntoIteratorTrait t).getItemType()
1208+
}
1209+
11371210
pragma[nomagic]
11381211
private Type inferForLoopExprType(AstNode n, TypePath path) {
11391212
// type of iterable -> type of pattern (loop variable)
1140-
exists(ForExpr fe, Type iterableType, TypePath iterablePath |
1213+
exists(ForExpr fe, TypePath exprPath, AssociatedTypeTypeParameter tp |
11411214
n = fe.getPat() and
1142-
iterableType = inferType(fe.getIterable(), iterablePath) and
1143-
result = iterableType and
1144-
(
1145-
iterablePath.isCons(any(Vec v).getElementTypeParameter(), path)
1146-
or
1147-
iterablePath.isCons(any(ArrayTypeParameter tp), path)
1148-
or
1149-
iterablePath
1150-
.stripPrefix(TypePath::cons(TRefTypeParameter(),
1151-
TypePath::singleton(any(SliceTypeParameter tp)))) = path
1152-
// TODO: iterables (general case for containers, ranges etc)
1153-
)
1215+
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>::satisfiesConstraintType(fe.getIterable(),
1216+
_, exprPath, result) and
1217+
exprPath.isCons(tp, path)
1218+
|
1219+
tp = getIntoIteratorItemTypeParameter()
1220+
or
1221+
// TODO: Remove once we can handle the `impl<I: Iterator> IntoIterator for I` implementation
1222+
tp = getIteratorItemTypeParameter() and
1223+
inferType(fe.getIterable()) != TArrayType()
11541224
)
11551225
}
11561226

@@ -1589,6 +1659,9 @@ private module Cached {
15891659
result = inferArrayExprType(n) and
15901660
path.isEmpty()
15911661
or
1662+
result = inferRangeExprType(n) and
1663+
path.isEmpty()
1664+
or
15921665
result = inferIndexExprType(n, path)
15931666
or
15941667
result = inferForLoopExprType(n, path)

rust/ql/lib/rust.qll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ import codeql.rust.elements.AsyncBlockExpr
1515
import codeql.rust.elements.Variable
1616
import codeql.rust.elements.NamedFormatArgument
1717
import codeql.rust.elements.PositionalFormatArgument
18+
import codeql.rust.elements.RangeExprExt

rust/ql/test/library-tests/dataflow/sources/TaintSources.expected

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,16 @@
5555
| test.rs:412:31:412:38 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
5656
| test.rs:417:22:417:39 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
5757
| test.rs:417:22:417:39 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
58+
| test.rs:423:22:423:25 | path | Flow source 'FileSource' of type file (DEFAULT). |
59+
| test.rs:424:27:424:35 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
5860
| test.rs:430:22:430:34 | ...::read_link | Flow source 'FileSource' of type file (DEFAULT). |
5961
| test.rs:439:31:439:45 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
6062
| test.rs:444:31:444:45 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
6163
| test.rs:449:22:449:46 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
64+
| test.rs:455:26:455:29 | path | Flow source 'FileSource' of type file (DEFAULT). |
65+
| test.rs:455:26:455:29 | path | Flow source 'FileSource' of type file (DEFAULT). |
66+
| test.rs:456:31:456:39 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
67+
| test.rs:456:31:456:39 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
6268
| test.rs:462:22:462:41 | ...::read_link | Flow source 'FileSource' of type file (DEFAULT). |
6369
| test.rs:472:20:472:38 | ...::open | Flow source 'FileSource' of type file (DEFAULT). |
6470
| test.rs:506:21:506:39 | ...::open | Flow source 'FileSource' of type file (DEFAULT). |

0 commit comments

Comments
 (0)