55
66use crate :: sync:: thread;
77use crate :: { Knobs , KnobsDatabase } ;
8+ use std:: fmt;
89use std:: panic:: catch_unwind;
910
1011use salsa:: CycleRecoveryAction ;
@@ -17,7 +18,6 @@ const MAX: CycleValue = CycleValue(3);
1718
1819#[ salsa:: tracked( cycle_fn=cycle_fn, cycle_initial=initial) ]
1920fn query_a ( db : & dyn KnobsDatabase ) -> CycleValue {
20- db. signal ( 1 ) ;
2121 query_b ( db)
2222}
2323
@@ -27,16 +27,14 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue {
2727 CycleValue ( c_value. 0 + 1 ) . min ( MAX )
2828}
2929
30- #[ salsa:: tracked( cycle_fn=cycle_fn , cycle_initial=initial ) ]
30+ #[ salsa:: tracked]
3131fn query_c ( db : & dyn KnobsDatabase ) -> CycleValue {
3232 let d_value = query_d ( db) ;
3333
3434 if d_value > CycleValue ( 0 ) {
35- let _e_value = query_e ( db) ;
36- let _b = query_b ( db) ;
37- db. wait_for ( 2 ) ;
38- db. signal ( 3 ) ;
39- panic ! ( "Dragons are real" ) ;
35+ let e_value = query_e ( db) ;
36+ let b_value = query_b ( db) ;
37+ CycleValue ( d_value. 0 . max ( e_value. 0 ) . max ( b_value. 0 ) )
4038 } else {
4139 let a_value = query_a ( db) ;
4240 CycleValue ( d_value. 0 . max ( a_value. 0 ) )
@@ -45,7 +43,7 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue {
4543
4644#[ salsa:: tracked( cycle_fn=cycle_fn, cycle_initial=initial) ]
4745fn query_d ( db : & dyn KnobsDatabase ) -> CycleValue {
48- query_c ( db)
46+ query_b ( db)
4947}
5048
5149#[ salsa:: tracked( cycle_fn=cycle_fn, cycle_initial=initial) ]
@@ -75,29 +73,62 @@ fn the_test() {
7573
7674 let t1 = thread:: spawn ( move || {
7775 let _span = tracing:: debug_span!( "t1" , thread_id = ?thread:: current( ) . id( ) ) . entered ( ) ;
78-
79- query_a ( & db_t1)
76+ catch_unwind ( || {
77+ db_t1. wait_for ( 1 ) ;
78+ query_a ( & db_t1)
79+ } )
8080 } ) ;
8181 let t2 = thread:: spawn ( move || {
82- let _span = tracing:: debug_span!( "t4" , thread_id = ?thread:: current( ) . id( ) ) . entered ( ) ;
83- db_t4. wait_for ( 1 ) ;
84- db_t4. signal ( 2 ) ;
85- query_b ( & db_t4)
82+ let _span = tracing:: debug_span!( "t2" , thread_id = ?thread:: current( ) . id( ) ) . entered ( ) ;
83+ catch_unwind ( || {
84+ db_t2. wait_for ( 1 ) ;
85+
86+ query_b ( & db_t2)
87+ } )
8688 } ) ;
8789 let t3 = thread:: spawn ( move || {
88- let _span = tracing:: debug_span!( "t2" , thread_id = ?thread:: current( ) . id( ) ) . entered ( ) ;
89- db_t2. wait_for ( 1 ) ;
90- query_d ( & db_t2)
90+ let _span = tracing:: debug_span!( "t3" , thread_id = ?thread:: current( ) . id( ) ) . entered ( ) ;
91+ catch_unwind ( || {
92+ db_t3. signal ( 2 ) ;
93+ query_d ( & db_t3)
94+ } )
9195 } ) ;
9296
93- let r_t1 = t1. join ( ) ;
94- let r_t2 = t2. join ( ) ;
95- let r_t3 = t3. join ( ) ;
97+ let r_t1 = t1. join ( ) . unwrap ( ) ;
98+ let r_t2 = t2. join ( ) . unwrap ( ) ;
99+ let r_t3 = t3. join ( ) . unwrap ( ) ;
96100
97- assert ! ( r_t1. is_err ( ) ) ;
98- assert ! ( r_t2. is_err ( ) ) ;
99- assert ! ( r_t3. is_err ( ) ) ;
101+ assert_is_set_cycle_error ( r_t1) ;
102+ assert_is_set_cycle_error ( r_t2) ;
103+ assert_is_set_cycle_error ( r_t3) ;
100104
101105 // Pulling the cycle again at a later point should still result in a panic.
102- assert ! ( catch_unwind( || query_d( & db_t3) ) . is_err( ) ) ;
106+ assert_is_set_cycle_error ( catch_unwind ( || query_d ( & db_t4) ) ) ;
107+ }
108+
109+ #[ track_caller]
110+ fn assert_is_set_cycle_error < T > ( result : Result < T , Box < dyn std:: any:: Any + Send > > )
111+ where
112+ T : fmt:: Debug ,
113+ {
114+ let err = result. expect_err ( "expected an error" ) ;
115+
116+ if let Some ( message) = err. downcast_ref :: < & str > ( ) {
117+ assert ! (
118+ message. contains( "set cycle_fn/cycle_initial to fixpoint iterate" ) ,
119+ "Expected error message to contain 'set cycle_fn/cycle_initial to fixpoint iterate', but got: {}" ,
120+ message
121+ ) ;
122+ } else if let Some ( message) = err. downcast_ref :: < String > ( ) {
123+ assert ! (
124+ message. contains( "set cycle_fn/cycle_initial to fixpoint iterate" ) ,
125+ "Expected error message to contain 'set cycle_fn/cycle_initial to fixpoint iterate', but got: {}" ,
126+ message
127+ ) ;
128+ } else if let Some ( _) = err. downcast_ref :: < salsa:: Cancelled > ( ) {
129+ // This is okay, because Salsa throws a Cancelled::PropagatedPanic when a panic occurs in a query
130+ // that it blocks on.
131+ } else {
132+ std:: panic:: resume_unwind ( err) ;
133+ }
103134}
0 commit comments