@@ -10,6 +10,7 @@ private import codeql.typeinference.internal.TypeInference
1010private import codeql.rust.frameworks.stdlib.Stdlib
1111private import codeql.rust.frameworks.stdlib.Builtins as Builtins
1212private import codeql.rust.elements.Call
13+ private import codeql.rust.elements.internal.CallImpl:: Impl as CallImpl
1314
1415class Type = T:: Type ;
1516
@@ -353,19 +354,6 @@ private Type inferImplicitSelfType(SelfParam self, TypePath path) {
353354 )
354355}
355356
356- /**
357- * Gets any of the types mentioned in `path` that corresponds to the type
358- * parameter `tp`.
359- */
360- private TypeMention getExplicitTypeArgMention ( Path path , TypeParam tp ) {
361- exists ( int i |
362- result = path .getSegment ( ) .getGenericArgList ( ) .getTypeArg ( pragma [ only_bind_into ] ( i ) ) and
363- tp = resolvePath ( path ) .getTypeParam ( pragma [ only_bind_into ] ( i ) )
364- )
365- or
366- result = getExplicitTypeArgMention ( path .getQualifier ( ) , tp )
367- }
368-
369357/**
370358 * A matching configuration for resolving types of struct expressions
371359 * like `Foo { bar = baz }`.
@@ -452,9 +440,7 @@ private module StructExprMatchingInput implements MatchingInputSig {
452440 class AccessPosition = DeclarationPosition ;
453441
454442 class Access extends StructExpr {
455- Type getTypeArgument ( TypeArgumentPosition apos , TypePath path ) {
456- result = getExplicitTypeArgMention ( this .getPath ( ) , apos .asTypeParam ( ) ) .resolveTypeAt ( path )
457- }
443+ Type getTypeArgument ( TypeArgumentPosition apos , TypePath path ) { none ( ) }
458444
459445 AstNode getNodeAt ( AccessPosition apos ) {
460446 result = this .getFieldExpr ( apos .asFieldPos ( ) ) .getExpr ( )
@@ -465,6 +451,16 @@ private module StructExprMatchingInput implements MatchingInputSig {
465451
466452 Type getInferredType ( AccessPosition apos , TypePath path ) {
467453 result = inferType ( this .getNodeAt ( apos ) , path )
454+ or
455+ // The struct type is supplied explicitly as a type qualifier, e.g.
456+ // `Foo<Bar>::Variant { ... }`.
457+ apos .isStructPos ( ) and
458+ exists ( Path p , TypeMention tm |
459+ p = this .getPath ( ) and
460+ if resolvePath ( p ) instanceof Variant then tm = p .getQualifier ( ) else tm = p
461+ |
462+ result = tm .resolveTypeAt ( path )
463+ )
468464 }
469465
470466 Declaration getTarget ( ) { result = resolvePath ( this .getPath ( ) ) }
@@ -537,15 +533,24 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
537533
538534 abstract Type getReturnType ( TypePath path ) ;
539535
540- final Type getDeclaredType ( DeclarationPosition dpos , TypePath path ) {
536+ Type getDeclaredType ( DeclarationPosition dpos , TypePath path ) {
541537 result = this .getParameterType ( dpos , path )
542538 or
543539 dpos .isReturn ( ) and
544540 result = this .getReturnType ( path )
545541 }
546542 }
547543
548- private class TupleStructDecl extends Declaration , Struct {
544+ abstract private class TupleDeclaration extends Declaration {
545+ override Type getDeclaredType ( DeclarationPosition dpos , TypePath path ) {
546+ result = super .getDeclaredType ( dpos , path )
547+ or
548+ dpos .isSelf ( ) and
549+ result = this .getReturnType ( path )
550+ }
551+ }
552+
553+ private class TupleStructDecl extends TupleDeclaration , Struct {
549554 TupleStructDecl ( ) { this .isTuple ( ) }
550555
551556 override TypeParameter getTypeParameter ( TypeParameterPosition ppos ) {
@@ -568,7 +573,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
568573 }
569574 }
570575
571- private class TupleVariantDecl extends Declaration , Variant {
576+ private class TupleVariantDecl extends TupleDeclaration , Variant {
572577 TupleVariantDecl ( ) { this .isTuple ( ) }
573578
574579 override TypeParameter getTypeParameter ( TypeParameterPosition ppos ) {
@@ -597,13 +602,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
597602 override TypeParameter getTypeParameter ( TypeParameterPosition ppos ) {
598603 typeParamMatchPosition ( this .getGenericParamList ( ) .getATypeParam ( ) , result , ppos )
599604 or
600- exists ( TraitItemNode trait | this = trait .getAnAssocItem ( ) |
601- typeParamMatchPosition ( trait .getTypeParam ( _) , result , ppos )
605+ exists ( ImplOrTraitItemNode i | this = i .getAnAssocItem ( ) |
606+ typeParamMatchPosition ( i .getTypeParam ( _) , result , ppos )
602607 or
603- ppos .isImplicit ( ) and result = TSelfTypeParameter ( trait )
608+ ppos .isImplicit ( ) and result = TSelfTypeParameter ( i )
604609 or
605610 ppos .isImplicit ( ) and
606- result .( AssociatedTypeTypeParameter ) .getTrait ( ) = trait
611+ result .( AssociatedTypeTypeParameter ) .getTrait ( ) = i
607612 )
608613 or
609614 ppos .isImplicit ( ) and
@@ -625,6 +630,33 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
625630 or
626631 result = inferImplicitSelfType ( self , path ) // `self` parameter without type annotation
627632 )
633+ or
634+ // For associated functions, we may also need to match type arguments against
635+ // the `Self` type. For example, in
636+ //
637+ // ```rust
638+ // struct Foo<T>(T);
639+ //
640+ // impl<T : Default> Foo<T> {
641+ // fn default() -> Self {
642+ // Foo(Default::default())
643+ // }
644+ // }
645+ //
646+ // Foo::<i32>::default();
647+ // ```
648+ //
649+ // we need to match `i32` against the type parameter `T` of the `impl` block.
650+ exists ( ImplOrTraitItemNode i |
651+ this = i .getAnAssocItem ( ) and
652+ dpos .isSelf ( ) and
653+ not this .getParamList ( ) .hasSelfParam ( )
654+ |
655+ result = TSelfTypeParameter ( i ) and
656+ path .isEmpty ( )
657+ or
658+ result = resolveImplSelfType ( i , path )
659+ )
628660 }
629661
630662 private Type resolveRetType ( TypePath path ) {
@@ -670,9 +702,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
670702 private import codeql.rust.elements.internal.CallExprImpl:: Impl as CallExprImpl
671703
672704 final class Access extends Call {
705+ pragma [ nomagic]
673706 Type getTypeArgument ( TypeArgumentPosition apos , TypePath path ) {
674707 exists ( TypeMention arg | result = arg .resolveTypeAt ( path ) |
675- arg = getExplicitTypeArgMention ( CallExprImpl:: getFunctionPath ( this ) , apos .asTypeParam ( ) )
708+ exists ( Path p , int i |
709+ p = CallExprImpl:: getFunctionPath ( this ) and
710+ arg = p .getSegment ( ) .getGenericArgList ( ) .getTypeArg ( pragma [ only_bind_into ] ( i ) ) and
711+ apos .asTypeParam ( ) = resolvePath ( p ) .getTypeParam ( pragma [ only_bind_into ] ( i ) )
712+ )
676713 or
677714 arg =
678715 this .( MethodCallExpr ) .getGenericArgList ( ) .getTypeArg ( apos .asMethodTypeArgumentPosition ( ) )
@@ -696,6 +733,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
696733
697734 Type getInferredType ( AccessPosition apos , TypePath path ) {
698735 result = inferType ( this .getNodeAt ( apos ) , path )
736+ or
737+ // The `Self` type is supplied explicitly as a type qualifier, e.g. `Foo::<Bar>::baz()`
738+ apos = TArgumentAccessPosition ( CallImpl:: TSelfArgumentPosition ( ) , false , false ) and
739+ exists ( PathExpr pe , TypeMention tm |
740+ pe = this .( CallExpr ) .getFunction ( ) and
741+ tm = pe .getPath ( ) .getQualifier ( ) and
742+ result = tm .resolveTypeAt ( path )
743+ )
699744 }
700745
701746 Declaration getTarget ( ) {
@@ -1110,12 +1155,7 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
11101155}
11111156
11121157final class MethodCall extends Call {
1113- MethodCall ( ) {
1114- exists ( this .getReceiver ( ) ) and
1115- // We want the method calls that don't have a path to a concrete method in
1116- // an impl block. We need to exclude calls like `MyType::my_method(..)`.
1117- ( this instanceof CallExpr implies exists ( this .getTrait ( ) ) )
1118- }
1158+ MethodCall ( ) { exists ( this .getReceiver ( ) ) }
11191159
11201160 /** Gets the type of the receiver of the method call at `path`. */
11211161 Type getTypeAt ( TypePath path ) {
@@ -1582,19 +1622,51 @@ private module Debug {
15821622 result = resolveMethodCallTarget ( mce )
15831623 }
15841624
1625+ predicate debugInferImplicitSelfType ( SelfParam self , TypePath path , Type t ) {
1626+ self = getRelevantLocatable ( ) and
1627+ t = inferImplicitSelfType ( self , path )
1628+ }
1629+
1630+ predicate debugInferCallExprBaseType ( AstNode n , TypePath path , Type t ) {
1631+ n = getRelevantLocatable ( ) and
1632+ t = inferCallExprBaseType ( n , path )
1633+ }
1634+
15851635 predicate debugTypeMention ( TypeMention tm , TypePath path , Type type ) {
15861636 tm = getRelevantLocatable ( ) and
15871637 tm .resolveTypeAt ( path ) = type
15881638 }
15891639
15901640 pragma [ nomagic]
1591- private int countTypes ( AstNode n , TypePath path , Type t ) {
1641+ private int countTypesAtPath ( AstNode n , TypePath path , Type t ) {
15921642 t = inferType ( n , path ) and
15931643 result = strictcount ( Type t0 | t0 = inferType ( n , path ) )
15941644 }
15951645
15961646 predicate maxTypes ( AstNode n , TypePath path , Type t , int c ) {
1597- c = countTypes ( n , path , t ) and
1598- c = max ( countTypes ( _, _, _) )
1647+ c = countTypesAtPath ( n , path , t ) and
1648+ c = max ( countTypesAtPath ( _, _, _) )
1649+ }
1650+
1651+ pragma [ nomagic]
1652+ private predicate typePathLength ( AstNode n , TypePath path , Type t , int len ) {
1653+ t = inferType ( n , path ) and
1654+ len = path .length ( )
1655+ }
1656+
1657+ predicate maxTypePath ( AstNode n , TypePath path , Type t , int len ) {
1658+ typePathLength ( n , path , t , len ) and
1659+ len = max ( int i | typePathLength ( _, _, _, i ) )
1660+ }
1661+
1662+ pragma [ nomagic]
1663+ private int countTypePaths ( AstNode n , TypePath path , Type t ) {
1664+ t = inferType ( n , path ) and
1665+ result = strictcount ( TypePath path0 , Type t0 | t0 = inferType ( n , path0 ) )
1666+ }
1667+
1668+ predicate maxTypePaths ( AstNode n , TypePath path , Type t , int c ) {
1669+ c = countTypePaths ( n , path , t ) and
1670+ c = max ( countTypePaths ( _, _, _) )
15991671 }
16001672}
0 commit comments