diff --git a/azure-servicebus/src/main/java/com/microsoft/azure/servicebus/primitives/ClientEntity.java b/azure-servicebus/src/main/java/com/microsoft/azure/servicebus/primitives/ClientEntity.java index 86156dc5..5a938060 100644 --- a/azure-servicebus/src/main/java/com/microsoft/azure/servicebus/primitives/ClientEntity.java +++ b/azure-servicebus/src/main/java/com/microsoft/azure/servicebus/primitives/ClientEntity.java @@ -4,6 +4,8 @@ */ package com.microsoft.azure.servicebus.primitives; +import java.util.ArrayList; +import java.util.List; import java.util.Locale; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -21,6 +23,7 @@ public abstract class ClientEntity private boolean isClosing; private boolean isClosed; + private List> closeFutures; protected ClientEntity(final String clientId) { @@ -44,8 +47,7 @@ protected boolean getIsClosed() } protected boolean getIsClosingOrClosed() - { - + { synchronized (this.syncClose) { return this.isClosing || this.isClosed; @@ -63,27 +65,49 @@ protected final void setClosed() public final CompletableFuture closeAsync() { - if(this.getIsClosingOrClosed()) - { - return CompletableFuture.completedFuture(null); - } - synchronized (this.syncClose) { - this.isClosing = true; + if(this.isClosed) { + return CompletableFuture.completedFuture(null); + }else if (this.isClosing) { + if(this.closeFutures == null) { + this.closeFutures = new ArrayList<>(); + } + + CompletableFuture closeFuture = new CompletableFuture<>(); + this.closeFutures.add(closeFuture); + return closeFuture; + }else { + this.isClosing = true; + } } - - return this.onClose().thenRunAsync(new Runnable() - { - @Override - public void run() + + return this.onClose().whenCompleteAsync((v, t) -> { + if(t == null) { - synchronized (ClientEntity.this.syncClose) - { + synchronized (ClientEntity.this.syncClose) { ClientEntity.this.isClosing = false; ClientEntity.this.isClosed = true; + if(this.closeFutures != null) { + for(CompletableFuture future : this.closeFutures) { + AsyncUtil.completeFuture(future, null); + } + } + } + } + else + { + // onClose failed with some exception. set isClosing to false, so client can call close again. + synchronized (ClientEntity.this.syncClose) { + ClientEntity.this.isClosing = false; + if(this.closeFutures != null) { + for(CompletableFuture future : this.closeFutures) { + AsyncUtil.completeFutureExceptionally(future, t); + } + } } - }}, MessagingFactory.INTERNAL_THREAD_POOL); + } + }, MessagingFactory.INTERNAL_THREAD_POOL); } public final void close() throws ServiceBusException diff --git a/azure-servicebus/src/test/java/com/microsoft/azure/servicebus/primitives/ClientEntityTests.java b/azure-servicebus/src/test/java/com/microsoft/azure/servicebus/primitives/ClientEntityTests.java new file mode 100644 index 00000000..d344a165 --- /dev/null +++ b/azure-servicebus/src/test/java/com/microsoft/azure/servicebus/primitives/ClientEntityTests.java @@ -0,0 +1,199 @@ +package com.microsoft.azure.servicebus.primitives; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinTask; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Assert; +import org.junit.Test; + +public class ClientEntityTests { + + @Test + public void closeMultipleTimesTest() { + int expectedNumberOfCloseFailures = 5; + TestClientEntity clientEntity = new TestClientEntity(expectedNumberOfCloseFailures); + int actualCloseFailures = 0; + for (int i=0; i>> tasks = new ArrayList<>(); + + for (int i=0; i clientEntity.closeAsync())); + } + + for (ForkJoinTask> task : tasks) { + Assert.assertFalse("Entity closed early.", task.get().isDone()); + } + + Assert.assertTrue("Entity not closing even after calling close.", clientEntity.getIsClosingOrClosed()); + Assert.assertFalse("Entity closed without sleeping.", clientEntity.getIsClosed()); + + Thread.sleep(sleepInCloseDuration.toMillis() + 500); // 500 millis buffer + + for (ForkJoinTask> task : tasks) { + Assert.assertTrue("Entity not closed even after delay.", task.get().isDone()); + } + + Assert.assertTrue("Entity not closed even after calling close.", clientEntity.getIsClosed()); + Assert.assertEquals("OnClose called more than expected number of times", 1, clientEntity.getTotalNumberOfCloseCalls()); + } + + @Test + public void concurrentEntityCloseFailureTest() throws InterruptedException, ExecutionException { + int numConcurrentCalls = 10; + Duration sleepInCloseDuration = Duration.ofSeconds(5); + TestClientEntity clientEntity = new TestClientEntity(1, sleepInCloseDuration, true); + ArrayList>> tasks = new ArrayList<>(); + + for (int i=0; i clientEntity.closeAsync())); + } + + for (ForkJoinTask> task : tasks) { + Assert.assertFalse("Entity close failed too early.", task.get().isDone()); + } + + Assert.assertTrue("Entity not closing even after calling close.", clientEntity.getIsClosingOrClosed()); + Assert.assertFalse("Entity closed without sleeping.", clientEntity.getIsClosed()); + + Thread.sleep(sleepInCloseDuration.toMillis() + 500); // 500 millis buffer + + Throwable failureException = null; + for (ForkJoinTask> task : tasks) { + CompletableFuture closeFuture = task.get(); + Assert.assertTrue("Entity close didn't fail even after delay.", closeFuture.isDone()); + try { + closeFuture.get(); + Assert.fail("Entity close didn't fail."); + }catch (ExecutionException ex) { + Throwable cause = ex.getCause(); + if (failureException == null) { + failureException = cause; + } + else + { + Assert.assertEquals("All concurrent close failures didn't fail with the same exception.", failureException, cause); + } + } + } + } + + class TestClientEntity extends ClientEntity { + + private final Duration sleepDurationInClose; + private final int targetNumberOfCloseFailures; + private final boolean shouldSleepInFailure; + private AtomicInteger numberOfCloseFailures = new AtomicInteger(0); + private AtomicInteger numberOfCloseCalls = new AtomicInteger(0); + + TestClientEntity(int targetNumberOfCloseFailures) { + this(targetNumberOfCloseFailures, Duration.ZERO); + } + + TestClientEntity(int targetNumberOfCloseFailures, Duration sleepDurationInClose) { + this(targetNumberOfCloseFailures, sleepDurationInClose, false); + } + + TestClientEntity(int targetNumberOfCloseFailures, Duration sleepDurationInClose, boolean shouldSleepInFailure) { + super(UUID.randomUUID().toString()); + this.sleepDurationInClose = sleepDurationInClose; + this.targetNumberOfCloseFailures = targetNumberOfCloseFailures; + this.shouldSleepInFailure = shouldSleepInFailure; + } + + @Override + protected CompletableFuture onClose() { + numberOfCloseCalls.incrementAndGet(); + if(numberOfCloseFailures.get() < targetNumberOfCloseFailures) + { + numberOfCloseFailures.incrementAndGet(); + CompletableFuture failFuture = new CompletableFuture<>(); + if (this.shouldSleepInFailure) { + if (this.sleepDurationInClose.isZero()) { + failFuture.completeExceptionally(new Exception("Close failed.")); + } else { + Thread completionThread = new Thread(() -> { + try { + Thread.sleep(this.sleepDurationInClose.toMillis()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + failFuture.completeExceptionally(new Exception("Close failed.")); + }); + + completionThread.start(); + } + } else { + failFuture.completeExceptionally(new Exception("Close failed.")); + } + + return failFuture; + } + else + { + if (this.sleepDurationInClose.isZero()) { + return CompletableFuture.completedFuture(null); + } else { + CompletableFuture closeFuture = new CompletableFuture<>(); + + Thread completionThread = new Thread(() -> { + try { + Thread.sleep(this.sleepDurationInClose.toMillis()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + closeFuture.complete(null); + }); + + completionThread.start(); + return closeFuture; + } + } + } + + int getTotalNumberOfCloseCalls() { + return this.numberOfCloseCalls.get(); + } + } +}