|
1 | 1 | package com.microsoft.durabletask; |
2 | 2 |
|
3 | 3 | import org.junit.jupiter.api.extension.ExtensionContext; |
4 | | -import org.junit.jupiter.api.extension.TestExecutionExceptionHandler; |
5 | | -import org.junit.jupiter.api.extension.BeforeEachCallback; |
6 | | -import org.junit.jupiter.api.extension.AfterEachCallback; |
| 4 | +import org.junit.jupiter.api.extension.InvocationInterceptor; |
| 5 | +import org.junit.jupiter.api.extension.ReflectiveInvocationContext; |
| 6 | +import java.lang.reflect.Method; |
7 | 7 |
|
8 | | -public class TestRetryExtension implements TestExecutionExceptionHandler, BeforeEachCallback, AfterEachCallback { |
| 8 | +public class TestRetryExtension implements InvocationInterceptor { |
9 | 9 | private static final int MAX_RETRIES = 3; |
10 | | - private int currentRetries = 0; |
11 | 10 |
|
12 | 11 | @Override |
13 | | - public void beforeEach(ExtensionContext context) { |
14 | | - currentRetries = 0; |
15 | | - } |
16 | | - |
17 | | - @Override |
18 | | - public void handleTestExecutionException(ExtensionContext context, Throwable throwable) throws Throwable { |
19 | | - if (currentRetries < MAX_RETRIES - 1) { |
20 | | - currentRetries++; |
21 | | - System.err.println(String.format("Test '%s' failed on attempt %d/%d", context.getDisplayName(), currentRetries + 1, MAX_RETRIES)); |
22 | | - context.getRequiredTestMethod().invoke(context.getRequiredTestInstance()); |
23 | | - } else { |
24 | | - System.err.println(String.format("Test '%s' failed after %d attempts", context.getDisplayName(), MAX_RETRIES)); |
25 | | - throw throwable; |
| 12 | + public void interceptTestMethod(Invocation<Void> invocation, |
| 13 | + ReflectiveInvocationContext<Method> invocationContext, |
| 14 | + ExtensionContext extensionContext) throws Throwable { |
| 15 | + |
| 16 | + Throwable lastException = null; |
| 17 | + |
| 18 | + for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) { |
| 19 | + try { |
| 20 | + invocation.proceed(); |
| 21 | + return; // Success, exit the retry loop |
| 22 | + } catch (Throwable throwable) { |
| 23 | + lastException = throwable; |
| 24 | + if (attempt < MAX_RETRIES) { |
| 25 | + System.err.println(String.format("Test '%s' failed on attempt %d/%d: %s", |
| 26 | + extensionContext.getDisplayName(), attempt, MAX_RETRIES, throwable.getMessage())); |
| 27 | + } else { |
| 28 | + System.err.println(String.format("Test '%s' failed after %d attempts", |
| 29 | + extensionContext.getDisplayName(), MAX_RETRIES)); |
| 30 | + } |
| 31 | + } |
26 | 32 | } |
27 | | - } |
28 | | - |
29 | | - @Override |
30 | | - public void afterEach(ExtensionContext context) { |
31 | | - currentRetries = 0; |
| 33 | + |
| 34 | + // If we get here, all retries failed |
| 35 | + throw lastException; |
32 | 36 | } |
33 | 37 | } |
0 commit comments