Skip to content

Commit d22f70e

Browse files
committed
Rust: Non-symmetric type propagation for lub coercions
1 parent 96fc13b commit d22f70e

File tree

2 files changed

+79
-235
lines changed

2 files changed

+79
-235
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,14 @@ private Struct getRangeType(RangeExpr re) {
524524
result instanceof RangeToInclusiveStruct
525525
}
526526

527+
private predicate bodyReturns(Expr body, Expr e) {
528+
exists(ReturnExpr re, Callable c |
529+
e = re.getExpr() and
530+
c = re.getEnclosingCallable() and
531+
body = c.getBody()
532+
)
533+
}
534+
527535
/**
528536
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
529537
* of `n2` at `prefix2` and type information should propagate in both directions
@@ -540,9 +548,11 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
540548
let.getInitializer() = n2
541549
)
542550
or
543-
n1 = n2.(IfExpr).getABranch()
544-
or
545-
n1 = n2.(MatchExpr).getAnArm().getExpr()
551+
n2 =
552+
any(MatchExpr me |
553+
n1 = me.getAnArm().getExpr() and
554+
me.getNumberOfArms() = 1
555+
)
546556
or
547557
exists(LetExpr let |
548558
n1 = let.getScrutinee() and
@@ -573,6 +583,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
573583
n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion()
574584
or
575585
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
586+
or
587+
bodyReturns(n1, n2) and
588+
strictcount(Expr e | bodyReturns(n1, e)) = 1
576589
)
577590
or
578591
(
@@ -606,8 +619,12 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
606619
)
607620
)
608621
or
609-
// an array list expression (`[1, 2, 3]`) has the type of the first (any) element
610-
n1.(ArrayListExpr).getExpr(_) = n2 and
622+
// an array list expression (`[1, 2, 3]`) has the type of the element
623+
n1 =
624+
any(ArrayListExpr ale |
625+
ale.getAnExpr() = n2 and
626+
ale.getNumberOfExprs() = 1
627+
) and
611628
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
612629
prefix2.isEmpty()
613630
or
@@ -635,6 +652,61 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
635652
prefix2.isEmpty()
636653
}
637654

655+
/**
656+
* Holds if `child` is a child of `parent`, and the Rust compiler applies [least
657+
* upper bound (LUB) coercion](1) to infer the type of `parent` from the type of
658+
* `child`.
659+
*
660+
* In this case, we want type information to only flow from `child` to `parent`,
661+
* to avoid (a) either having to model LUB coercions, or (b) risk combinatorial
662+
* explosion in inferred types.
663+
*
664+
* [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound
665+
*/
666+
private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
667+
child = parent.(IfExpr).getABranch() and
668+
prefix.isEmpty()
669+
or
670+
parent =
671+
any(MatchExpr me |
672+
child = me.getAnArm().getExpr() and
673+
me.getNumberOfArms() > 1
674+
) and
675+
prefix.isEmpty()
676+
or
677+
parent =
678+
any(ArrayListExpr ale |
679+
child = ale.getAnExpr() and
680+
ale.getNumberOfExprs() > 1
681+
) and
682+
prefix = TypePath::singleton(TArrayTypeParameter())
683+
or
684+
bodyReturns(parent, child) and
685+
strictcount(Expr e | bodyReturns(parent, e)) > 1 and
686+
prefix.isEmpty()
687+
}
688+
689+
/**
690+
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
691+
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
692+
* `n2`.
693+
*/
694+
private predicate typeEqualityNonSymmetric(
695+
AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2
696+
) {
697+
lubCoercion(n2, n1, prefix2) and
698+
prefix1.isEmpty()
699+
or
700+
exists(AstNode mid, TypePath prefixMid, TypePath suffix |
701+
typeEquality(n1, prefixMid, mid, prefix2) or
702+
typeEquality(mid, prefix2, n1, prefixMid)
703+
|
704+
lubCoercion(mid, n2, suffix) and
705+
not lubCoercion(mid, n1, _) and
706+
prefix1 = prefixMid.append(suffix)
707+
)
708+
}
709+
638710
pragma[nomagic]
639711
private Type inferTypeEquality(AstNode n, TypePath path) {
640712
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
@@ -644,6 +716,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
644716
typeEquality(n, prefix1, n2, prefix2)
645717
or
646718
typeEquality(n2, prefix2, n, prefix1)
719+
or
720+
typeEqualityNonSymmetric(n2, prefix2, n, prefix1)
647721
)
648722
}
649723

0 commit comments

Comments
 (0)