Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ newtype TType =
TSliceType() or
TNeverType() or
TPtrType() or
TContextType() or
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
Expand Down Expand Up @@ -371,6 +372,26 @@ class PtrType extends Type, TPtrType {
override Location getLocation() { result instanceof EmptyLocation }
}

/**
* A special pseudo type used to indicate that the actual type is to be inferred
* from a context.
*
* For example, a call like `Default::default()` is assigned this type, which
* means that the actual type is to be inferred from the context in which the call
* occurs.
*
* Context types are not restricted to root types, for example in a call like
* `Vec::new()` we assign this type at the type path corresponding to the type
* parameter of `Vec`.
*/
class ContextType extends Type, TContextType {
override TypeParameter getPositionalTypeParameter(int i) { none() }

override string toString() { result = "(context typed)" }

override Location getLocation() { result instanceof EmptyLocation }
}

/** A type parameter. */
abstract class TypeParameter extends Type {
override TypeParameter getPositionalTypeParameter(int i) { none() }
Expand Down
170 changes: 151 additions & 19 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,113 @@
)
}

/**
* Provides functionality related to context-based typing of calls.
*/
private module ContextTyping {
/**
* Holds if the return type of the function `f` at path `path` is `tp`,
* and `tp` does not appear in the type of any parameter of `f`.
*
* In this case, the context in which `f` is called may be needed to infer
* the instantiation of `tp`.
*/
pragma[nomagic]
private predicate assocFunctionReturnContextTypedAt(
Function f, FunctionPosition pos, TypePath path, TypeParameter tp
) {
exists(ImplOrTraitItemNode i |
pos.isReturn() and
assocFunctionTypeAt(f, i, pos, path, tp) and
not exists(FunctionPosition nonResPos |
not nonResPos.isReturn() and
assocFunctionTypeAt(f, i, nonResPos, _, tp)
)
)
}

/**
* A call where the type of the result may have to be inferred from the
* context in which the call appears, for example a call like
* `Default::default()`.
*/
abstract class ContextTypedCallCand extends AstNode {
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);

private predicate hasTypeArgument(TypeArgumentPosition apos) {
exists(this.getTypeArgument(apos, _))
}

/**
* Holds if `this` call resolves to `target` and the type at `pos` and `path`
* may have to be inferred from the context.
*/
bindingset[this, target]
predicate isContextTypedAt(Function target, TypePath path, FunctionPosition pos) {
exists(TypeParameter tp |
assocFunctionReturnContextTypedAt(target, pos, path, tp) and
// check that no explicit type arguments have been supplied for `tp`
not exists(TypeArgumentPosition tapos | this.hasTypeArgument(tapos) |
exists(int i |
i = tapos.asMethodTypeArgumentPosition() and
tp = TTypeParamTypeParameter(target.getGenericParamList().getTypeParam(i))
)
or
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
) and
not (
tp instanceof TSelfTypeParameter and
exists(getCallExprTypeQualifier(this, _))
)
)
}
}

pragma[nomagic]
private predicate isContextTyped(AstNode n, TypePath path) { inferType(n, path) = TContextType() }

pragma[nomagic]
private predicate isContextTyped(AstNode n) { isContextTyped(n, _) }

signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);

/**
* Given a predicate `inferCallType` for inferring the type of a call at a given
* position, this module exposes the predicate `check`, which wraps the input
* predicate and checks that types are only propagated into arguments when they
* are context-typed.
*/
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
pragma[nomagic]
private Type inferCallTypeFromContextCand(
AstNode n, FunctionPosition pos, TypePath path, TypePath prefix
) {
result = inferCallType(n, pos, path) and
not pos.isReturn() and
isContextTyped(n) and
prefix = path
or
exists(TypePath mid |
result = inferCallTypeFromContextCand(n, pos, path, mid) and
mid.isSnoc(prefix, _)
)
}

pragma[nomagic]
Type check(AstNode n, TypePath path) {
exists(FunctionPosition pos |
result = inferCallType(n, pos, path) and
pos.isReturn()
or
exists(TypePath prefix |
result = inferCallTypeFromContextCand(n, pos, path, prefix) and
isContextTyped(n, prefix)
)
)
}
}
}

/**
* Holds if function `f` with the name `name` and the arity `arity` exists in
* `i`, and the type at position `pos` is `t`.
Expand Down Expand Up @@ -1890,14 +1997,14 @@

final private class MethodCallFinal = MethodResolution::MethodCall;

class Access extends MethodCallFinal {
class Access extends MethodCallFinal, ContextTyping::ContextTypedCallCand {
Access() {
// handled in the `OperationMatchingInput` module
not this instanceof Operation
}

pragma[nomagic]
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
arg =
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
Expand Down Expand Up @@ -1961,7 +2068,12 @@
) {
exists(TypePath path0 |
n = a.getNodeAt(apos) and
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
(
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
or
a.isContextTypedAt(a.getTarget(derefChainBorrow), path0, apos) and
result = TContextType()
)
|
if
// index expression `x[i]` desugars to `*x.index(i)`, so we must account for
Expand All @@ -1973,16 +2085,11 @@
)
}

/**
* Gets the type of `n` at `path`, where `n` is either a method call or an
* argument/receiver of a method call.
*/
pragma[nomagic]
private Type inferMethodCallType(AstNode n, TypePath path) {
exists(
MethodCallMatchingInput::Access a, MethodCallMatchingInput::AccessPosition apos,
string derefChainBorrow, TypePath path0
|
private Type inferMethodCallType1(
AstNode n, MethodCallMatchingInput::AccessPosition apos, TypePath path
) {
exists(MethodCallMatchingInput::Access a, string derefChainBorrow, TypePath path0 |
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0)
|
(
Expand All @@ -2004,6 +2111,13 @@
)
}

/**
* Gets the type of `n` at `path`, where `n` is either a method call or an
* argument/receiver of a method call.
*/
private predicate inferMethodCallType =
ContextTyping::CheckContextTyping<inferMethodCallType1/3>::check/2;

/**
* Provides logic for resolving calls to non-method items. This includes
* "calls" to tuple variants and tuple structs.
Expand Down Expand Up @@ -2171,6 +2285,12 @@
or
result = this.resolveCallTargetRec()
}

pragma[nomagic]
Function resolveTraitFunction() {
this.(Call).hasTrait() and
result = this.getPathResolutionResolved()
}
}

private newtype TCallAndBlanketPos =
Expand Down Expand Up @@ -2405,9 +2525,9 @@
}
}

class Access extends NonMethodResolution::NonMethodCall {
class Access extends NonMethodResolution::NonMethodCall, ContextTyping::ContextTypedCallCand {
pragma[nomagic]
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
}

Expand All @@ -2428,13 +2548,20 @@
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;

pragma[nomagic]
private Type inferNonMethodCallType(AstNode n, TypePath path) {
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
n = a.getNodeAt(apos) and
private Type inferNonMethodCallType0(
AstNode n, NonMethodCallMatchingInput::AccessPosition apos, TypePath path
) {
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(apos) |
result = NonMethodCallMatching::inferAccessType(a, apos, path)
or
a.isContextTypedAt([a.resolveCallTarget().(Function), a.resolveTraitFunction()], path, apos) and
result = TContextType()
)
}

private predicate inferNonMethodCallType =
ContextTyping::CheckContextTyping<inferNonMethodCallType0/3>::check/2;

/**
* A matching configuration for resolving types of operations like `a + b`.
*/
Expand Down Expand Up @@ -2507,13 +2634,18 @@
private module OperationMatching = Matching<OperationMatchingInput>;

pragma[nomagic]
private Type inferOperationType(AstNode n, TypePath path) {
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
private Type inferOperationType0(
AstNode n, OperationMatchingInput::AccessPosition apos, TypePath path
) {
exists(OperationMatchingInput::Access a |
n = a.getNodeAt(apos) and
result = OperationMatching::inferAccessType(a, apos, path)
)
}

private predicate inferOperationType =
ContextTyping::CheckContextTyping<inferOperationType0/3>::check/2;

pragma[nomagic]
private Type getFieldExprLookupType(FieldExpr fe, string name) {
exists(TypePath path |
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
multipleCallTargets
| test.rs:24:24:24:34 | row.take(...) |
| test.rs:111:24:111:34 | row.take(...) |
multiplePathResolutions
| test.rs:10:28:10:65 | Result::<...> |
| test.rs:97:40:97:49 | Result::<...> |
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
multipleCallTargets
| test.rs:288:7:288:36 | ... .as_str() |
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ multipleCallTargets
| dereference.rs:184:17:184:30 | ... .foo() |
| dereference.rs:186:17:186:25 | S.bar(...) |
| dereference.rs:187:17:187:29 | S.bar(...) |
| main.rs:2437:13:2437:31 | ...::from(...) |
| main.rs:2438:13:2438:31 | ...::from(...) |
| main.rs:2439:13:2439:31 | ...::from(...) |
| main.rs:2445:13:2445:31 | ...::from(...) |
| main.rs:2446:13:2446:31 | ...::from(...) |
| main.rs:2447:13:2447:31 | ...::from(...) |
| main.rs:2459:13:2459:31 | ...::from(...) |
| main.rs:2460:13:2460:31 | ...::from(...) |
| main.rs:2461:13:2461:31 | ...::from(...) |
| main.rs:2467:13:2467:31 | ...::from(...) |
| main.rs:2468:13:2468:31 | ...::from(...) |
| main.rs:2469:13:2469:31 | ...::from(...) |
22 changes: 22 additions & 0 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,26 @@ mod function_trait_bounds {
fn assoc(x: Self) -> A;
}

impl<T: Default> MyTrait<T> for S2 {
fn m1(self) -> T {
Default::default() // $ target=default
}

fn assoc(x: Self) -> T {
Default::default() // $ target=default
}
}

impl MyTrait<i32> for S1 {
fn m1(self) -> i32 {
0
}

fn assoc(x: Self) -> i32 {
0
}
}

// Type parameter with bound occurs in the root of a parameter type.

fn call_trait_m1<T1, T2: MyTrait<T1> + Copy>(x: T2) -> T1 {
Expand Down Expand Up @@ -781,6 +801,8 @@ mod function_trait_bounds {
println!("{:?}", b);
let b = call_trait_thing_m1_3(y3); // $ type=b:S2 target=call_trait_thing_m1_3
println!("{:?}", b);
let x = S1::m2(S1); // $ target=m2 $ type=x:i32
let y: i32 = S2::m2(S2); // $ target=m2
}
}

Expand Down
Loading