Skip to content

Commit 224e5de

Browse files
authored
Merge pull request #21170 from paldepind/rust/type-inference-fns
Rust: Improve type inference for closures and function traits
2 parents 6fe76b3 + dd73399 commit 224e5de

File tree

7 files changed

+424
-110
lines changed

7 files changed

+424
-110
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
* Added type inference support for the `FnMut(..) -> ..` and `Fn(..) -> ..` traits. They now work in type parameter bounds and are implemented by closures.

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,23 +143,48 @@ class FutureTrait extends Trait {
143143
TypeAlias getOutputType() { result = this.(TraitItemNode).getAssocItem("Output") }
144144
}
145145

146+
/** A function trait `FnOnce`, `FnMut`, or `Fn`. */
147+
abstract private class AnyFnTraitImpl extends Trait {
148+
/** Gets the `Args` type parameter of this trait. */
149+
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
150+
}
151+
152+
final class AnyFnTrait = AnyFnTraitImpl;
153+
146154
/**
147155
* The [`FnOnce` trait][1].
148156
*
149157
* [1]: https://doc.rust-lang.org/std/ops/trait.FnOnce.html
150158
*/
151-
class FnOnceTrait extends Trait {
159+
class FnOnceTrait extends AnyFnTraitImpl {
152160
pragma[nomagic]
153161
FnOnceTrait() { this.getCanonicalPath() = "core::ops::function::FnOnce" }
154162

155-
/** Gets the type parameter of this trait. */
156-
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
157-
158163
/** Gets the `Output` associated type. */
159164
pragma[nomagic]
160165
TypeAlias getOutputType() { result = this.(TraitItemNode).getAssocItem("Output") }
161166
}
162167

168+
/**
169+
* The [`FnMut` trait][1].
170+
*
171+
* [1]: https://doc.rust-lang.org/std/ops/trait.FnMut.html
172+
*/
173+
class FnMutTrait extends AnyFnTraitImpl {
174+
pragma[nomagic]
175+
FnMutTrait() { this.getCanonicalPath() = "core::ops::function::FnMut" }
176+
}
177+
178+
/**
179+
* The [`Fn` trait][1].
180+
*
181+
* [1]: https://doc.rust-lang.org/std/ops/trait.Fn.html
182+
*/
183+
class FnTrait extends AnyFnTraitImpl {
184+
pragma[nomagic]
185+
FnTrait() { this.getCanonicalPath() = "core::ops::function::Fn" }
186+
}
187+
163188
/**
164189
* The [`Iterator` trait][1].
165190
*

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3827,16 +3827,29 @@ private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
38273827
_, path, result)
38283828
}
38293829

3830+
/**
3831+
* Gets the root type of a closure.
3832+
*
3833+
* We model closures as `dyn Fn` trait object types. A closure might implement
3834+
* only `Fn`, `FnMut`, or `FnOnce`. But since `Fn` is a subtrait of the others,
3835+
* giving closures the type `dyn Fn` works well in practice -- even if not
3836+
* entirely accurate.
3837+
*/
3838+
private DynTraitType closureRootType() {
3839+
result = TDynTraitType(any(FnTrait t)) // always exists because of the mention in `builtins/mentions.rs`
3840+
}
3841+
38303842
/** Gets the path to a closure's return type. */
38313843
private TypePath closureReturnPath() {
3832-
result = TypePath::singleton(getDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
3844+
result =
3845+
TypePath::singleton(TDynTraitTypeParameter(any(FnTrait t), any(FnOnceTrait t).getOutputType()))
38333846
}
38343847

3835-
/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
3848+
/** Gets the path to a closure with arity `arity`'s `index`th parameter type. */
38363849
pragma[nomagic]
38373850
private TypePath closureParameterPath(int arity, int index) {
38383851
result =
3839-
TypePath::cons(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam()),
3852+
TypePath::cons(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam()),
38403853
TypePath::singleton(getTupleTypeParameter(arity, index)))
38413854
}
38423855

@@ -3874,9 +3887,7 @@ private Type inferDynamicCallExprType(Expr n, TypePath path) {
38743887
or
38753888
// _If_ the invoked expression has the type of a closure, then we propagate
38763889
// the surrounding types into the closure.
3877-
exists(int arity, TypePath path0 |
3878-
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
3879-
|
3890+
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
38803891
// Propagate the type of arguments to the parameter types of closure
38813892
exists(int index, ArgList args |
38823893
n = ce and
@@ -3900,10 +3911,10 @@ private Type inferClosureExprType(AstNode n, TypePath path) {
39003911
exists(ClosureExpr ce |
39013912
n = ce and
39023913
path.isEmpty() and
3903-
result = TDynTraitType(any(FnOnceTrait t)) // always exists because of the mention in `builtins/mentions.rs`
3914+
result = closureRootType()
39043915
or
39053916
n = ce and
3906-
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam())) and
3917+
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
39073918
result.(TupleType).getArity() = ce.getNumberOfParams()
39083919
or
39093920
// Propagate return type annotation to body

rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,15 @@ class NonAliasPathTypeMention extends PathTypeMention {
213213
// associated types of `Fn` and `FnMut` yet.
214214
//
215215
// [1]: https://doc.rust-lang.org/reference/paths.html#grammar-TypePathFn
216-
exists(FnOnceTrait t, PathSegment s |
216+
exists(AnyFnTrait t, PathSegment s |
217217
t = resolved and
218218
s = this.getSegment() and
219219
s.hasParenthesizedArgList()
220220
|
221221
tp = TTypeParamTypeParameter(t.getTypeParam()) and
222222
result = s.getParenthesizedArgList().(TypeMention).resolveTypeAt(path)
223223
or
224-
tp = TAssociatedTypeTypeParameter(t, t.getOutputType()) and
224+
tp = TAssociatedTypeTypeParameter(t, any(FnOnceTrait tr).getOutputType()) and
225225
(
226226
result = s.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
227227
or

rust/ql/test/library-tests/type-inference/closure.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,80 @@ mod fn_once_trait {
6868
}
6969
}
7070

71+
mod fn_mut_trait {
72+
fn return_type<F: FnMut(bool) -> i64>(mut f: F) {
73+
let _return = f(true); // $ type=_return:i64
74+
}
75+
76+
fn return_type_omitted<F: FnMut(bool)>(mut f: F) {
77+
let _return = f(true); // $ type=_return:()
78+
}
79+
80+
fn argument_type<F: FnMut(bool) -> i64>(mut f: F) {
81+
let arg = Default::default(); // $ target=default type=arg:bool
82+
f(arg);
83+
}
84+
85+
fn apply<A, B, F: FnMut(A) -> B>(mut f: F, a: A) -> B {
86+
f(a)
87+
}
88+
89+
fn apply_two(mut f: impl FnMut(i64) -> i64) -> i64 {
90+
f(2)
91+
}
92+
93+
fn test() {
94+
let f = |x: bool| -> i64 {
95+
if x {
96+
1
97+
} else {
98+
0
99+
}
100+
};
101+
let _r = apply(f, true); // $ target=apply type=_r:i64
102+
103+
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
104+
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
105+
}
106+
}
107+
108+
mod fn_trait {
109+
fn return_type<F: Fn(bool) -> i64>(f: F) {
110+
let _return = f(true); // $ type=_return:i64
111+
}
112+
113+
fn return_type_omitted<F: Fn(bool)>(f: F) {
114+
let _return = f(true); // $ type=_return:()
115+
}
116+
117+
fn argument_type<F: Fn(bool) -> i64>(f: F) {
118+
let arg = Default::default(); // $ target=default type=arg:bool
119+
f(arg);
120+
}
121+
122+
fn apply<A, B, F: Fn(A) -> B>(f: F, a: A) -> B {
123+
f(a)
124+
}
125+
126+
fn apply_two(f: impl Fn(i64) -> i64) -> i64 {
127+
f(2)
128+
}
129+
130+
fn test() {
131+
let f = |x: bool| -> i64 {
132+
if x {
133+
1
134+
} else {
135+
0
136+
}
137+
};
138+
let _r = apply(f, true); // $ target=apply type=_r:i64
139+
140+
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
141+
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
142+
}
143+
}
144+
71145
mod dyn_fn_once {
72146
fn apply_boxed<A, B, F: FnOnce(A) -> B + ?Sized>(f: Box<F>, arg: A) -> B {
73147
f(arg)

0 commit comments

Comments
 (0)