Skip to content

Commit e05b287

Browse files
committed
Rust: Restrict type propagation into arguments
1 parent 90b7a1d commit e05b287

File tree

11 files changed

+201
-582
lines changed

11 files changed

+201
-582
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ newtype TType =
5151
TSliceType() or
5252
TNeverType() or
5353
TPtrType() or
54+
TContextType() or
5455
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
5556
TTypeParamTypeParameter(TypeParam t) or
5657
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
@@ -371,6 +372,30 @@ class PtrType extends Type, TPtrType {
371372
override Location getLocation() { result instanceof EmptyLocation }
372373
}
373374

375+
/**
376+
* A special pseudo type used to indicate that the actual type may have to be
377+
* inferred from a context.
378+
*
379+
* For example, a call like `Default::default()` is assigned this type, which
380+
* means that the actual type is to be inferred from the context in which the call
381+
* occurs.
382+
*
383+
* Context types are not restricted to root types, for example in a call like
384+
* `Vec::new()` we assign this type at the type path corresponding to the type
385+
* parameter of `Vec`.
386+
*
387+
* Context types are used to restrict when type information is allowed to flow
388+
* into call arguments (including method call receivers), in order to avoid
389+
* combinatorial explosions.
390+
*/
391+
class ContextType extends Type, TContextType {
392+
override TypeParameter getPositionalTypeParameter(int i) { none() }
393+
394+
override string toString() { result = "(context typed)" }
395+
396+
override Location getLocation() { result instanceof EmptyLocation }
397+
}
398+
374399
/** A type parameter. */
375400
abstract class TypeParameter extends Type {
376401
override TypeParameter getPositionalTypeParameter(int i) { none() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 153 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,113 @@ private Type getCallExprTypeQualifier(CallExpr ce, TypePath path) {
909909
)
910910
}
911911

912+
/**
913+
* Provides functionality related to context-based typing of calls.
914+
*/
915+
private module ContextTyping {
916+
/**
917+
* Holds if the return type of the function `f` at path `path` is `tp`,
918+
* and `tp` does not appear in the type of any parameter of `f`.
919+
*
920+
* In this case, the context in which `f` is called may be needed to infer
921+
* the instantiation of `tp`.
922+
*/
923+
pragma[nomagic]
924+
private predicate assocFunctionReturnContextTypedAt(
925+
Function f, FunctionPosition pos, TypePath path, TypeParameter tp
926+
) {
927+
exists(ImplOrTraitItemNode i |
928+
pos.isReturn() and
929+
tp = getAssocFunctionTypeAt(f, i, pos, path) and
930+
not exists(FunctionPosition nonResPos |
931+
not nonResPos.isReturn() and
932+
tp = getAssocFunctionTypeAt(f, i, nonResPos, _)
933+
)
934+
)
935+
}
936+
937+
/**
938+
* A call where the type of the result may have to be inferred from the
939+
* context in which the call appears, for example a call like
940+
* `Default::default()`.
941+
*/
942+
abstract class ContextTypedCallCand extends AstNode {
943+
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
944+
945+
private predicate hasTypeArgument(TypeArgumentPosition apos) {
946+
exists(this.getTypeArgument(apos, _))
947+
}
948+
949+
/**
950+
* Holds if `this` call resolves to `target` and the type at `pos` and `path`
951+
* may have to be inferred from the context.
952+
*/
953+
bindingset[this, target]
954+
predicate isContextTypedAt(Function target, TypePath path, FunctionPosition pos) {
955+
exists(TypeParameter tp |
956+
assocFunctionReturnContextTypedAt(target, pos, path, tp) and
957+
// check that no explicit type arguments have been supplied for `tp`
958+
not exists(TypeArgumentPosition tapos | this.hasTypeArgument(tapos) |
959+
exists(int i |
960+
i = tapos.asMethodTypeArgumentPosition() and
961+
tp = TTypeParamTypeParameter(target.getGenericParamList().getTypeParam(i))
962+
)
963+
or
964+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
965+
) and
966+
not (
967+
tp instanceof TSelfTypeParameter and
968+
exists(getCallExprTypeQualifier(this, _))
969+
)
970+
)
971+
}
972+
}
973+
974+
pragma[nomagic]
975+
private predicate isContextTyped(AstNode n, TypePath path) { inferType(n, path) = TContextType() }
976+
977+
pragma[nomagic]
978+
private predicate isContextTyped(AstNode n) { isContextTyped(n, _) }
979+
980+
signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);
981+
982+
/**
983+
* Given a predicate `inferCallType` for inferring the type of a call at a given
984+
* position, this module exposes the predicate `check`, which wraps the input
985+
* predicate and checks that types are only propagated into arguments when they
986+
* are context-typed.
987+
*/
988+
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
989+
pragma[nomagic]
990+
private Type inferCallTypeFromContextCand(
991+
AstNode n, FunctionPosition pos, TypePath path, TypePath prefix
992+
) {
993+
result = inferCallType(n, pos, path) and
994+
not pos.isReturn() and
995+
isContextTyped(n) and
996+
prefix = path
997+
or
998+
exists(TypePath mid |
999+
result = inferCallTypeFromContextCand(n, pos, path, mid) and
1000+
mid.isSnoc(prefix, _)
1001+
)
1002+
}
1003+
1004+
pragma[nomagic]
1005+
Type check(AstNode n, TypePath path) {
1006+
exists(FunctionPosition pos |
1007+
result = inferCallType(n, pos, path) and
1008+
pos.isReturn()
1009+
or
1010+
exists(TypePath prefix |
1011+
result = inferCallTypeFromContextCand(n, pos, path, prefix) and
1012+
isContextTyped(n, prefix)
1013+
)
1014+
)
1015+
}
1016+
}
1017+
}
1018+
9121019
/**
9131020
* Holds if function `f` with the name `name` and the arity `arity` exists in
9141021
* `i`, and the type at position `pos` is `t`.
@@ -1569,7 +1676,8 @@ private module MethodResolution {
15691676

15701677
Type getTypeAt(TypePath path) {
15711678
result = mc_.getACandidateReceiverTypeAtSubstituteLookupTraits(derefChain, borrow, path) and
1572-
not result = TNeverType()
1679+
not result = TNeverType() and
1680+
not result = TContextType()
15731681
}
15741682

15751683
pragma[nomagic]
@@ -1918,14 +2026,14 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi
19182026

19192027
final private class MethodCallFinal = MethodResolution::MethodCall;
19202028

1921-
class Access extends MethodCallFinal {
2029+
class Access extends MethodCallFinal, ContextTyping::ContextTypedCallCand {
19222030
Access() {
19232031
// handled in the `OperationMatchingInput` module
19242032
not this instanceof Operation
19252033
}
19262034

19272035
pragma[nomagic]
1928-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
2036+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
19292037
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
19302038
arg =
19312039
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
@@ -1989,7 +2097,12 @@ private Type inferMethodCallType0(
19892097
) {
19902098
exists(TypePath path0 |
19912099
n = a.getNodeAt(apos) and
1992-
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
2100+
(
2101+
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
2102+
or
2103+
a.isContextTypedAt(a.getTarget(derefChainBorrow), path0, apos) and
2104+
result = TContextType()
2105+
)
19932106
|
19942107
if
19952108
// index expression `x[i]` desugars to `*x.index(i)`, so we must account for
@@ -2001,16 +2114,11 @@ private Type inferMethodCallType0(
20012114
)
20022115
}
20032116

2004-
/**
2005-
* Gets the type of `n` at `path`, where `n` is either a method call or an
2006-
* argument/receiver of a method call.
2007-
*/
20082117
pragma[nomagic]
2009-
private Type inferMethodCallType(AstNode n, TypePath path) {
2010-
exists(
2011-
MethodCallMatchingInput::Access a, MethodCallMatchingInput::AccessPosition apos,
2012-
string derefChainBorrow, TypePath path0
2013-
|
2118+
private Type inferMethodCallType1(
2119+
AstNode n, MethodCallMatchingInput::AccessPosition apos, TypePath path
2120+
) {
2121+
exists(MethodCallMatchingInput::Access a, string derefChainBorrow, TypePath path0 |
20142122
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0)
20152123
|
20162124
(
@@ -2032,6 +2140,13 @@ private Type inferMethodCallType(AstNode n, TypePath path) {
20322140
)
20332141
}
20342142

2143+
/**
2144+
* Gets the type of `n` at `path`, where `n` is either a method call or an
2145+
* argument/receiver of a method call.
2146+
*/
2147+
private predicate inferMethodCallType =
2148+
ContextTyping::CheckContextTyping<inferMethodCallType1/3>::check/2;
2149+
20352150
/**
20362151
* Provides logic for resolving calls to non-method items. This includes
20372152
* "calls" to tuple variants and tuple structs.
@@ -2199,6 +2314,12 @@ private module NonMethodResolution {
21992314
or
22002315
result = this.resolveCallTargetRec()
22012316
}
2317+
2318+
pragma[nomagic]
2319+
Function resolveTraitFunction() {
2320+
this.(Call).hasTrait() and
2321+
result = this.getPathResolutionResolved()
2322+
}
22022323
}
22032324

22042325
private newtype TCallAndBlanketPos =
@@ -2433,9 +2554,9 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
24332554
}
24342555
}
24352556

2436-
class Access extends NonMethodResolution::NonMethodCall {
2557+
class Access extends NonMethodResolution::NonMethodCall, ContextTyping::ContextTypedCallCand {
24372558
pragma[nomagic]
2438-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
2559+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
24392560
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
24402561
}
24412562

@@ -2456,13 +2577,20 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
24562577
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;
24572578

24582579
pragma[nomagic]
2459-
private Type inferNonMethodCallType(AstNode n, TypePath path) {
2460-
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
2461-
n = a.getNodeAt(apos) and
2580+
private Type inferNonMethodCallType0(
2581+
AstNode n, NonMethodCallMatchingInput::AccessPosition apos, TypePath path
2582+
) {
2583+
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(apos) |
24622584
result = NonMethodCallMatching::inferAccessType(a, apos, path)
2585+
or
2586+
a.isContextTypedAt([a.resolveCallTarget().(Function), a.resolveTraitFunction()], path, apos) and
2587+
result = TContextType()
24632588
)
24642589
}
24652590

2591+
private predicate inferNonMethodCallType =
2592+
ContextTyping::CheckContextTyping<inferNonMethodCallType0/3>::check/2;
2593+
24662594
/**
24672595
* A matching configuration for resolving types of operations like `a + b`.
24682596
*/
@@ -2535,13 +2663,18 @@ private module OperationMatchingInput implements MatchingInputSig {
25352663
private module OperationMatching = Matching<OperationMatchingInput>;
25362664

25372665
pragma[nomagic]
2538-
private Type inferOperationType(AstNode n, TypePath path) {
2539-
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
2666+
private Type inferOperationType0(
2667+
AstNode n, OperationMatchingInput::AccessPosition apos, TypePath path
2668+
) {
2669+
exists(OperationMatchingInput::Access a |
25402670
n = a.getNodeAt(apos) and
25412671
result = OperationMatching::inferAccessType(a, apos, path)
25422672
)
25432673
}
25442674

2675+
private predicate inferOperationType =
2676+
ContextTyping::CheckContextTyping<inferOperationType0/3>::check/2;
2677+
25452678
pragma[nomagic]
25462679
private Type getFieldExprLookupType(FieldExpr fe, string name) {
25472680
exists(TypePath path |

rust/ql/lib/codeql/rust/internal/typeinference/BlanketImplementation.qll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ module SatisfiesBlanketConstraint<
9292

9393
Type getTypeAt(TypePath path) {
9494
result = at.getTypeAt(blanketPath.appendInverse(path)) and
95-
not result = TNeverType()
95+
not result = TNeverType() and
96+
not result = TContextType()
9697
}
9798

9899
string toString() { result = at.toString() + " [blanket at " + blanketPath.toString() + "]" }

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,8 @@ module ArgIsInstantiationOf<
229229
private class ArgSubst extends ArgFinal {
230230
Type getTypeAt(TypePath path) {
231231
result = substituteLookupTraits(super.getTypeAt(path)) and
232-
not result = TNeverType()
232+
not result = TNeverType() and
233+
not result = TContextType()
233234
}
234235
}
235236

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
multipleCallTargets
2+
| test.rs:288:7:288:36 | ... .as_str() |

0 commit comments

Comments
 (0)