Skip to content

Commit 401495b

Browse files
committed
Rust: Associated types are inherited as type parameters by traits and dyn traits
1 parent 9671547 commit 401495b

File tree

6 files changed

+148
-77
lines changed

6 files changed

+148
-77
lines changed

rust/ql/lib/codeql/rust/elements/internal/TraitImpl.qll

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66

77
private import codeql.rust.elements.internal.generated.Trait
8+
private import codeql.rust.internal.PathResolution as PathResolution
89

910
/**
1011
* INTERNAL: This module contains the customizable definition of `Trait` and should not
@@ -67,5 +68,11 @@ module Impl {
6768
* `where` clauses for `Self`.
6869
*/
6970
TypeBound getATypeBound() { result = this.getTypeBound(_) }
71+
72+
/** Gets a direct supertrait of this trait, if any. */
73+
Trait getSupertrait() {
74+
result =
75+
PathResolution::resolvePath(this.getATypeBound().getTypeRepr().(PathTypeRepr).getPath())
76+
}
7077
}
7178
}

rust/ql/lib/codeql/rust/internal/typeinference/Type.qll

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,17 @@ private import codeql.rust.elements.internal.generated.Synth
99
private import codeql.rust.frameworks.stdlib.Stdlib
1010
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
1111

12+
/** Gets a type alias of `trait` or of a supertrait of `trait`. */
13+
private TypeAlias getTraitTypeAlias(Trait trait) {
14+
result = trait.getSupertrait*().getAssocItemList().getAnAssocItem()
15+
}
16+
1217
/**
13-
* Holds if a dyn trait type should have a type parameter associated with `n`. A
14-
* dyn trait type inherits the type parameters of the trait it implements. That
15-
* includes the type parameters corresponding to associated types.
18+
* Holds if a dyn trait type for the trait `trait` should have a type parameter
19+
* associated with `n`.
20+
*
21+
* A dyn trait type inherits the type parameters of the trait it implements.
22+
* That includes the type parameters corresponding to associated types.
1623
*
1724
* For instance in
1825
* ```rust
@@ -24,10 +31,7 @@ private import codeql.rust.frameworks.stdlib.Builtins as Builtins
2431
*/
2532
private predicate dynTraitTypeParameter(Trait trait, AstNode n) {
2633
trait = any(DynTraitTypeRepr dt).getTrait() and
27-
(
28-
n = trait.getGenericParamList().getATypeParam() or
29-
n = trait.(TraitItemNode).getAnAssocItem().(TypeAlias)
30-
)
34+
n = [trait.getGenericParamList().getATypeParam().(AstNode), getTraitTypeAlias(trait)]
3135
}
3236

3337
cached
@@ -39,8 +43,10 @@ newtype TType =
3943
TNeverType() or
4044
TUnknownType() or
4145
TTypeParamTypeParameter(TypeParam t) or
42-
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
43-
TDynTraitTypeParameter(AstNode n) { dynTraitTypeParameter(_, n) } or
46+
TAssociatedTypeTypeParameter(Trait trait, TypeAlias typeAlias) {
47+
getTraitTypeAlias(trait) = typeAlias
48+
} or
49+
TDynTraitTypeParameter(Trait trait, AstNode n) { dynTraitTypeParameter(trait, n) } or
4450
TImplTraitTypeParameter(ImplTraitTypeRepr implTrait, TypeParam tp) {
4551
implTraitTypeParam(implTrait, _, tp)
4652
} or
@@ -270,17 +276,10 @@ class DynTraitType extends Type, TDynTraitType {
270276
DynTraitType() { this = TDynTraitType(trait) }
271277

272278
override DynTraitTypeParameter getPositionalTypeParameter(int i) {
273-
result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i))
279+
result.getTypeParam() = trait.getGenericParamList().getTypeParam(i)
274280
}
275281

276-
override TypeParameter getATypeParameter() {
277-
result = super.getATypeParameter()
278-
or
279-
exists(AstNode n |
280-
dynTraitTypeParameter(trait, n) and
281-
result = TDynTraitTypeParameter(n)
282-
)
283-
}
282+
override DynTraitTypeParameter getATypeParameter() { result.getTrait() = trait }
284283

285284
Trait getTrait() { result = trait }
286285

@@ -427,30 +426,54 @@ class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
427426
* // ...
428427
* }
429428
* ```
429+
* Furthermore, associated types of a supertrait induce a corresponding type
430+
* parameter in any subtraits. E.g., if we have a trait `SubTrait: ATrait` then
431+
* `SubTrait` also has a type parameter for the associated type
432+
* `AssociatedType`.
430433
*/
431434
class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypeParameter {
435+
private Trait trait;
432436
private TypeAlias typeAlias;
433437

434-
AssociatedTypeTypeParameter() { this = TAssociatedTypeTypeParameter(typeAlias) }
438+
AssociatedTypeTypeParameter() { this = TAssociatedTypeTypeParameter(trait, typeAlias) }
435439

436440
TypeAlias getTypeAlias() { result = typeAlias }
437441

438442
/** Gets the trait that contains this associated type declaration. */
439-
TraitItemNode getTrait() { result.getAnAssocItem() = typeAlias }
443+
TraitItemNode getTrait() { result = trait }
440444

441-
override ItemNode getDeclaringItem() { result = this.getTrait() }
445+
/**
446+
* Holds if this associated type type parameter corresponds directly its
447+
* trait, that is, it is not inherited from a supertrait.
448+
*/
449+
predicate isDirect() { trait.(TraitItemNode).getAnAssocItem() = typeAlias }
450+
451+
override ItemNode getDeclaringItem() { result = trait }
442452

443-
override string toString() { result = typeAlias.getName().getText() }
453+
override string toString() {
454+
result = typeAlias.getName().getText() + "[" + trait.getName().toString() + "]"
455+
}
444456

445457
override Location getLocation() { result = typeAlias.getLocation() }
446458
}
447459

460+
/** Gets the associated type type parameter corresponding directly to `typeAlias`. */
461+
AssociatedTypeTypeParameter getAssociatedTypeTypeParameter(TypeAlias typeAlias) {
462+
result.isDirect() and result.getTypeAlias() = typeAlias
463+
}
464+
465+
/** Gets the dyn type type parameter corresponding directly to `typeAlias`. */
466+
DynTraitTypeParameter getDynTraitTypeParameter(TypeAlias typeAlias) {
467+
result.getTraitTypeParameter() = getAssociatedTypeTypeParameter(typeAlias)
468+
}
469+
448470
class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
471+
private Trait trait;
449472
private AstNode n;
450473

451-
DynTraitTypeParameter() { this = TDynTraitTypeParameter(n) }
474+
DynTraitTypeParameter() { this = TDynTraitTypeParameter(trait, n) }
452475

453-
Trait getTrait() { dynTraitTypeParameter(result, n) }
476+
Trait getTrait() { result = trait }
454477

455478
/** Gets the dyn trait type that this type parameter belongs to. */
456479
DynTraitType getDynTraitType() { result.getTrait() = this.getTrait() }
@@ -465,7 +488,7 @@ class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
465488
TypeParameter getTraitTypeParameter() {
466489
result.(TypeParamTypeParameter).getTypeParam() = n
467490
or
468-
result.(AssociatedTypeTypeParameter).getTypeAlias() = n
491+
result = TAssociatedTypeTypeParameter(trait, n)
469492
}
470493

471494
private string toStringInner() {

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ private module Input1 implements InputSig1<Location> {
9090
tp =
9191
rank[result](TypeParameter tp0, int kind, int id1, int id2 |
9292
kind = 1 and
93-
id1 = 0 and
93+
id1 = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTrait()) and
9494
id2 =
9595
idOfTypeParameterAstNode([
9696
tp0.(DynTraitTypeParameter).getTypeParam().(AstNode),
@@ -102,10 +102,13 @@ private module Input1 implements InputSig1<Location> {
102102
id2 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getTypeParam())
103103
or
104104
kind = 3 and
105+
id1 = idOfTypeParameterAstNode(tp0.(AssociatedTypeTypeParameter).getTrait()) and
106+
id2 = idOfTypeParameterAstNode(tp0.(AssociatedTypeTypeParameter).getTypeAlias())
107+
or
108+
kind = 4 and
105109
id1 = 0 and
106110
exists(AstNode node | id2 = idOfTypeParameterAstNode(node) |
107111
node = tp0.(TypeParamTypeParameter).getTypeParam() or
108-
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
109112
node = tp0.(SelfTypeParameter).getTrait() or
110113
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
111114
)
@@ -3507,12 +3510,12 @@ private DynTraitType getFutureTraitType() { result.getTrait() instanceof FutureT
35073510

35083511
pragma[nomagic]
35093512
private AssociatedTypeTypeParameter getFutureOutputTypeParameter() {
3510-
result.getTypeAlias() = any(FutureTrait ft).getOutputType()
3513+
result = getAssociatedTypeTypeParameter(any(FutureTrait ft).getOutputType())
35113514
}
35123515

35133516
pragma[nomagic]
35143517
private DynTraitTypeParameter getDynFutureOutputTypeParameter() {
3515-
result = TDynTraitTypeParameter(any(FutureTrait ft).getOutputType())
3518+
result.getTraitTypeParameter() = getFutureOutputTypeParameter()
35163519
}
35173520

35183521
pragma[nomagic]
@@ -3824,20 +3827,20 @@ private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
38243827

38253828
/** Gets the path to a closure's return type. */
38263829
private TypePath closureReturnPath() {
3827-
result = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
3830+
result = TypePath::singleton(getDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
38283831
}
38293832

38303833
/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
38313834
pragma[nomagic]
38323835
private TypePath closureParameterPath(int arity, int index) {
38333836
result =
3834-
TypePath::cons(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam()),
3837+
TypePath::cons(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam()),
38353838
TypePath::singleton(getTupleTypeParameter(arity, index)))
38363839
}
38373840

38383841
/** Gets the path to the return type of the `FnOnce` trait. */
38393842
private TypePath fnReturnPath() {
3840-
result = TypePath::singleton(TAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
3843+
result = TypePath::singleton(getAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
38413844
}
38423845

38433846
/**
@@ -3898,7 +3901,7 @@ private Type inferClosureExprType(AstNode n, TypePath path) {
38983901
result = TDynTraitType(any(FnOnceTrait t)) // always exists because of the mention in `builtins/mentions.rs`
38993902
or
39003903
n = ce and
3901-
path = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam())) and
3904+
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam())) and
39023905
result.(TupleType).getArity() = ce.getNumberOfParams()
39033906
or
39043907
// Propagate return type annotation to body

rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -148,30 +148,11 @@ class NonAliasPathTypeMention extends PathTypeMention {
148148

149149
TypeItemNode getResolved() { result = resolved }
150150

151-
/**
152-
* Gets a type alias with the name `name` of the trait that this path resolves
153-
* to, if any.
154-
*/
155-
pragma[nomagic]
156-
private TypeAlias getResolvedTraitAlias(string name) {
157-
result = resolved.(TraitItemNode).getAnAssocItem() and
158-
name = result.getName().getText()
159-
}
160-
161151
pragma[nomagic]
162152
private TypeRepr getAssocTypeArg(string name) {
163153
result = this.getSegment().getGenericArgList().getAssocTypeArg(name)
164154
}
165155

166-
/** Gets the type argument for the associated type `alias`, if any. */
167-
pragma[nomagic]
168-
private TypeRepr getAnAssocTypeArgument(TypeAlias alias) {
169-
exists(string name |
170-
alias = this.getResolvedTraitAlias(name) and
171-
result = this.getAssocTypeArg(name)
172-
)
173-
}
174-
175156
/**
176157
* Gets the type mention that instantiates the implicit `Self` type parameter
177158
* for this path, if it occurs in the position of a trait bound.
@@ -239,7 +220,7 @@ class NonAliasPathTypeMention extends PathTypeMention {
239220
tp = TTypeParamTypeParameter(t.getTypeParam()) and
240221
result = s.getParenthesizedArgList().(TypeMention).resolveTypeAt(path)
241222
or
242-
tp = TAssociatedTypeTypeParameter(t.getOutputType()) and
223+
tp = TAssociatedTypeTypeParameter(t, t.getOutputType()) and
243224
(
244225
result = s.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
245226
or
@@ -249,6 +230,28 @@ class NonAliasPathTypeMention extends PathTypeMention {
249230
path.isEmpty()
250231
)
251232
)
233+
or
234+
// If `path` is the supertrait of a trait block then any associated types
235+
// of the supertrait should be instantiated with the subtrait's
236+
// corresponding copies.
237+
//
238+
// As an example, for
239+
// ```rust
240+
// trait Sub: Super {
241+
// // ^^^^^ this
242+
// ```
243+
// we do something to the effect of:
244+
// ```rust
245+
// trait Sub: Super<Assoc=Assoc[Sub]>
246+
// ```
247+
// Where `Assoc` is an associated type of `Super` and `Assoc[Sub]` denotes
248+
// the copy of the type parameter inherited into `Sub`.
249+
exists(Trait subtrait, TypeAlias alias |
250+
subtrait.getATypeBound().getTypeRepr().(PathTypeRepr).getPath() = this and
251+
result = TAssociatedTypeTypeParameter(subtrait, alias) and
252+
tp = TAssociatedTypeTypeParameter(resolved, alias) and
253+
path.isEmpty()
254+
)
252255
}
253256

254257
pragma[nomagic]
@@ -259,9 +262,10 @@ class NonAliasPathTypeMention extends PathTypeMention {
259262
/** Gets the type mention in this path for the type parameter `tp`, if any. */
260263
pragma[nomagic]
261264
private TypeMention getTypeMentionForTypeParameter(TypeParameter tp) {
262-
exists(TypeAlias alias |
263-
result = this.getAnAssocTypeArgument(alias) and
264-
tp = TAssociatedTypeTypeParameter(alias)
265+
exists(TypeAlias alias, string name |
266+
result = this.getAssocTypeArg(name) and
267+
tp = TAssociatedTypeTypeParameter(resolved, alias) and
268+
alias = resolved.(TraitItemNode).getASuccessor(name)
265269
)
266270
or
267271
// If `path` is the trait of an `impl` block then any associated types
@@ -281,7 +285,8 @@ class NonAliasPathTypeMention extends PathTypeMention {
281285
this = impl.getTraitPath() and
282286
alias = impl.getASuccessor(pragma[only_bind_into](name)) and
283287
result = alias.getTypeRepr() and
284-
tp = TAssociatedTypeTypeParameter(this.getResolvedAlias(pragma[only_bind_into](name)))
288+
tp =
289+
TAssociatedTypeTypeParameter(resolved, this.getResolvedAlias(pragma[only_bind_into](name)))
285290
)
286291
}
287292

@@ -299,7 +304,7 @@ class NonAliasPathTypeMention extends PathTypeMention {
299304
or
300305
result = TTypeParamTypeParameter(resolved)
301306
or
302-
result = TAssociatedTypeTypeParameter(resolved)
307+
result = TAssociatedTypeTypeParameter(resolvePath(this.getQualifier()), resolved)
303308
}
304309

305310
override Type resolvePathTypeAt(TypePath typePath) {
@@ -384,9 +389,8 @@ class TraitMention extends TypeMention instanceof TraitItemNode {
384389
result = TSelfTypeParameter(this)
385390
or
386391
exists(TypeAlias alias |
387-
alias = super.getAnAssocItem() and
388392
typePath = TypePath::singleton(result) and
389-
result = TAssociatedTypeTypeParameter(alias)
393+
result = TAssociatedTypeTypeParameter(this, alias)
390394
)
391395
or
392396
exists(TypeParam tp |
@@ -540,7 +544,7 @@ class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
540544
// impl<A, B, ..> Trait<A, B, ..> for (dyn Trait)<A, B, ..>
541545
// ```
542546
// To achieve this:
543-
// - `DynTypeAbstraction` is an abstraction over type parameters of the trait.
547+
// - `DynTypeAbstraction` is an abstraction over the type parameters of the trait.
544548
// - `DynTypeBoundListMention` (this class) is a type mention which has `dyn
545549
// Trait` at the root and which for every type parameter of `dyn Trait` has the
546550
// corresponding type parameter of the trait.

rust/ql/test/library-tests/type-inference/associated_types.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ mod default_method_using_associated_type {
6969
Self::AssociatedType: Default,
7070
Self: Sized,
7171
{
72-
self.m1(); // $ target=MyTrait::m1 type=self.m1():AssociatedType
72+
self.m1(); // $ target=MyTrait::m1 type=self.m1():AssociatedType[MyTrait]
7373
let _default = Self::AssociatedType::default(); // $ MISSING: target=default _default:AssociatedType
7474
Self::AssociatedType::default() // $ MISSING: target=default
7575
}
@@ -144,8 +144,8 @@ mod equality_on_associated_type {
144144
where
145145
T: AnotherGet<Output = i32, AnotherOutput = bool>,
146146
{
147-
let _a1 = x.get(); // $ target=GetSet::get MISSING: type=_a1:i32
148-
let _a2 = get(&x); // $ target=get MISSING: type=_a2:i32
147+
let _a1 = x.get(); // $ target=GetSet::get type=_a1:i32
148+
let _a2 = get(&x); // $ target=get type=_a2:i32
149149
let _b = x.get_another(); // $ type=_b:bool target=AnotherGet::get_another
150150
}
151151

@@ -346,9 +346,9 @@ mod dyn_trait {
346346
}
347347

348348
fn _assoc_type_from_supertrait(t: &dyn AnotherGet<Output = i32, AnotherOutput = bool>) {
349-
let _a1 = (*t).get(); // $ target=deref target=GetSet::get MISSING: type=_a1:i32
350-
let _a2 = t.get(); // $ target=GetSet::get MISSING: type=_a2:i32
351-
let _a3 = get(t); // $ target=get MISSING: type=_a3:i32
349+
let _a1 = (*t).get(); // $ target=deref target=GetSet::get type=_a1:i32
350+
let _a2 = t.get(); // $ target=GetSet::get type=_a2:i32
351+
let _a3 = get(t); // $ target=get type=_a3:i32
352352
let _b1 = (*t).get_another(); // $ target=deref target=AnotherGet::get_another type=_b1:bool
353353
let _b2 = t.get_another(); // $ target=AnotherGet::get_another type=_b2:bool
354354
}

0 commit comments

Comments
 (0)