@@ -524,6 +524,14 @@ private Struct getRangeType(RangeExpr re) {
524524 result instanceof RangeToInclusiveStruct
525525}
526526
527+ private predicate bodyReturns ( Expr body , Expr e ) {
528+ exists ( ReturnExpr re , Callable c |
529+ e = re .getExpr ( ) and
530+ c = re .getEnclosingCallable ( ) and
531+ body = c .getBody ( )
532+ )
533+ }
534+
527535/**
528536 * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
529537 * of `n2` at `prefix2` and type information should propagate in both directions
@@ -540,9 +548,11 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
540548 let .getInitializer ( ) = n2
541549 )
542550 or
543- n1 = n2 .( IfExpr ) .getABranch ( )
544- or
545- n1 = n2 .( MatchExpr ) .getAnArm ( ) .getExpr ( )
551+ n2 =
552+ any ( MatchExpr me |
553+ n1 = me .getAnArm ( ) .getExpr ( ) and
554+ me .getNumberOfArms ( ) = 1
555+ )
546556 or
547557 exists ( LetExpr let |
548558 n1 = let .getScrutinee ( ) and
@@ -573,6 +583,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
573583 n1 = n2 .( MacroExpr ) .getMacroCall ( ) .getMacroCallExpansion ( )
574584 or
575585 n1 = n2 .( MacroPat ) .getMacroCall ( ) .getMacroCallExpansion ( )
586+ or
587+ bodyReturns ( n1 , n2 ) and
588+ strictcount ( Expr e | bodyReturns ( n1 , e ) ) = 1
576589 )
577590 or
578591 (
@@ -606,8 +619,12 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
606619 )
607620 )
608621 or
609- // an array list expression (`[1, 2, 3]`) has the type of the first (any) element
610- n1 .( ArrayListExpr ) .getExpr ( _) = n2 and
622+ // an array list expression (`[1, 2, 3]`) has the type of the element
623+ n1 =
624+ any ( ArrayListExpr ale |
625+ ale .getAnExpr ( ) = n2 and
626+ ale .getNumberOfExprs ( ) = 1
627+ ) and
611628 prefix1 = TypePath:: singleton ( TArrayTypeParameter ( ) ) and
612629 prefix2 .isEmpty ( )
613630 or
@@ -635,6 +652,61 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
635652 prefix2 .isEmpty ( )
636653}
637654
655+ /**
656+ * Holds if `child` is a child of `parent`, and the Rust compiler applies [least
657+ * upper bound (LUB) coercion](1) to infer the type of `parent` from the type of
658+ * `child`.
659+ *
660+ * In this case, we want type information to only flow from `child` to `parent`,
661+ * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial
662+ * explosion in inferred types.
663+ *
664+ * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound
665+ */
666+ private predicate lubCoercion ( AstNode parent , AstNode child , TypePath prefix ) {
667+ child = parent .( IfExpr ) .getABranch ( ) and
668+ prefix .isEmpty ( )
669+ or
670+ parent =
671+ any ( MatchExpr me |
672+ child = me .getAnArm ( ) .getExpr ( ) and
673+ me .getNumberOfArms ( ) > 1
674+ ) and
675+ prefix .isEmpty ( )
676+ or
677+ parent =
678+ any ( ArrayListExpr ale |
679+ child = ale .getAnExpr ( ) and
680+ ale .getNumberOfExprs ( ) > 1
681+ ) and
682+ prefix = TypePath:: singleton ( TArrayTypeParameter ( ) )
683+ or
684+ bodyReturns ( parent , child ) and
685+ strictcount ( Expr e | bodyReturns ( parent , e ) ) > 1 and
686+ prefix .isEmpty ( )
687+ }
688+
689+ /**
690+ * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
691+ * of `n2` at `prefix2`, but type information should only propagate from `n1` to
692+ * `n2`.
693+ */
694+ private predicate typeEqualityNonSymmetric (
695+ AstNode n1 , TypePath prefix1 , AstNode n2 , TypePath prefix2
696+ ) {
697+ lubCoercion ( n2 , n1 , prefix2 ) and
698+ prefix1 .isEmpty ( )
699+ or
700+ exists ( AstNode mid , TypePath prefixMid , TypePath suffix |
701+ typeEquality ( n1 , prefixMid , mid , prefix2 ) or
702+ typeEquality ( mid , prefix2 , n1 , prefixMid )
703+ |
704+ lubCoercion ( mid , n2 , suffix ) and
705+ not lubCoercion ( mid , n1 , _) and
706+ prefix1 = prefixMid .append ( suffix )
707+ )
708+ }
709+
638710pragma [ nomagic]
639711private Type inferTypeEquality ( AstNode n , TypePath path ) {
640712 exists ( TypePath prefix1 , AstNode n2 , TypePath prefix2 , TypePath suffix |
@@ -644,6 +716,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
644716 typeEquality ( n , prefix1 , n2 , prefix2 )
645717 or
646718 typeEquality ( n2 , prefix2 , n , prefix1 )
719+ or
720+ typeEqualityNonSymmetric ( n2 , prefix2 , n , prefix1 )
647721 )
648722}
649723
0 commit comments