Commit 5ff92c2
Ensure that when we invalidate a WorkExecutor, we close the MapTaskExecutor, which in turn calls teardown on the DoFns.
Ensure that when an exception is encountered during setup of DoFns, previously created DoFns are torn down.
1 parent af748d0 commit 5ff92c2

File tree: 4 files changed (+121, -58 lines)
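Both fixes hinge on the same cleanup pattern: when building or running a chain of operations fails partway, every operation created so far must be aborted (which tears down its DoFns), and failures during cleanup must not mask the original exception. A minimal sketch of that pattern, using a hypothetical Op interface rather than the worker's actual Operation type:

import java.util.ArrayList;
import java.util.List;

// Hypothetical Op interface standing in for the worker's Operation type.
interface Op {
  void setup() throws Exception;

  void abort() throws Exception;
}

final class SetupChain {
  // Sets up each operation in order; on failure, aborts everything already
  // set up and rethrows with cleanup failures attached as suppressed.
  static List<Op> setupAll(List<Op> ops) throws Exception {
    List<Op> created = new ArrayList<>();
    try {
      for (Op op : ops) {
        op.setup();
        created.add(op); // track only operations that finished setup
      }
      return created;
    } catch (Exception exn) {
      for (Op op : created) {
        try {
          op.abort();
        } catch (Exception exn2) {
          exn.addSuppressed(exn2); // don't let cleanup mask the original failure
        }
      }
      throw exn;
    }
  }
}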

runners/google-cloud-dataflow-java/build.gradle

Lines changed: 0 additions & 22 deletions

@@ -205,7 +205,6 @@ def commonLegacyExcludeCategories = [
   'org.apache.beam.sdk.testing.UsesGaugeMetrics',
   'org.apache.beam.sdk.testing.UsesMultimapState',
   'org.apache.beam.sdk.testing.UsesTestStream',
-  'org.apache.beam.sdk.testing.UsesParDoLifecycle', // doesn't support remote runner
   'org.apache.beam.sdk.testing.UsesMetricsPusher',
   'org.apache.beam.sdk.testing.UsesBundleFinalizer',
   'org.apache.beam.sdk.testing.UsesBoundedTrieMetrics', // Dataflow QM as of now does not support returning back BoundedTrie in metric result.
@@ -520,17 +519,6 @@ task validatesRunnerV2 {
     excludedTests: [
       'org.apache.beam.sdk.transforms.ReshuffleTest.testReshuffleWithTimestampsStreaming',

-      // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle.
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testFnCallSequenceStateful',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful',
-
       'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testCombiningAccumulatingProcessingTime',
       'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testLargeKeys100MB',
       'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testLargeKeys10MB',
@@ -563,16 +551,6 @@ task validatesRunnerV2Streaming {
       'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState',
       'org.apache.beam.sdk.transforms.GroupByKeyTest.testCombiningAccumulatingProcessingTime',

-      // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle.
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle',
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful',
-
       // TODO(https://github.com/apache/beam/issues/20931): Identify whether it's bug or a feature gap.
       'org.apache.beam.sdk.transforms.GroupByKeyTest$WindowTests.testRewindowWithTimestampCombiner',

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java

Lines changed: 39 additions & 13 deletions

@@ -105,11 +105,32 @@ public DataflowMapTaskExecutor create(
     Networks.replaceDirectedNetworkNodes(
         network, createOutputReceiversTransform(stageName, counterSet));

-    // Swap out all the ParallelInstruction nodes with Operation nodes
-    Networks.replaceDirectedNetworkNodes(
-        network,
-        createOperationTransformForParallelInstructionNodes(
-            stageName, network, options, readerFactory, sinkFactory, executionContext));
+    // Swap out all the ParallelInstruction nodes with Operation nodes. While
+    // updating the network, we keep track of the created Operations so that
+    // if an exception is encountered we can properly abort started
+    // operations.
+    ArrayList<Operation> createdOperations = new ArrayList<>();
+    try {
+      Networks.replaceDirectedNetworkNodes(
+          network,
+          createOperationTransformForParallelInstructionNodes(
+              stageName,
+              network,
+              options,
+              readerFactory,
+              sinkFactory,
+              executionContext,
+              createdOperations));
+    } catch (Exception exn) {
+      for (Operation o : createdOperations) {
+        try {
+          o.abort();
+        } catch (Exception exn2) {
+          exn.addSuppressed(exn2);
+        }
+      }
+      throw exn;
+    }

     // Collect all the operations within the network and attach all the operations as receivers
     // to preceding output receivers.
@@ -144,7 +165,8 @@ Function<Node, Node> createOperationTransformForParallelInstructionNodes(
       final PipelineOptions options,
       final ReaderFactory readerFactory,
       final SinkFactory sinkFactory,
-      final DataflowExecutionContext<?> executionContext) {
+      final DataflowExecutionContext<?> executionContext,
+      final ArrayList<Operation> createdOperations) {

     return new TypeSafeNodeFunction<ParallelInstructionNode>(ParallelInstructionNode.class) {
       @Override
@@ -156,27 +178,31 @@ public Node typedApply(ParallelInstructionNode node) {
                 instruction.getOriginalName(),
                 instruction.getSystemName(),
                 instruction.getName());
+        OperationNode result;
         try {
           DataflowOperationContext context = executionContext.createOperationContext(nameContext);
           if (instruction.getRead() != null) {
-            return createReadOperation(
-                network, node, options, readerFactory, executionContext, context);
+            result =
+                createReadOperation(
+                    network, node, options, readerFactory, executionContext, context);
           } else if (instruction.getWrite() != null) {
-            return createWriteOperation(node, options, sinkFactory, executionContext, context);
+            result = createWriteOperation(node, options, sinkFactory, executionContext, context);
           } else if (instruction.getParDo() != null) {
-            return createParDoOperation(network, node, options, executionContext, context);
+            result = createParDoOperation(network, node, options, executionContext, context);
           } else if (instruction.getPartialGroupByKey() != null) {
-            return createPartialGroupByKeyOperation(
-                network, node, options, executionContext, context);
+            result =
+                createPartialGroupByKeyOperation(network, node, options, executionContext, context);
           } else if (instruction.getFlatten() != null) {
-            return createFlattenOperation(network, node, context);
+            result = createFlattenOperation(network, node, context);
           } else {
             throw new IllegalArgumentException(
                 String.format("Unexpected instruction: %s", instruction));
           }
         } catch (Exception e) {
           throw new RuntimeException(e);
         }
+        createdOperations.add(result.getOperation());
+        return result;
       }
     };
   }
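The catch block above relies on Throwable.addSuppressed so that a failed abort() never hides the original construction failure. A tiny self-contained demo of how suppressed exceptions surface (the messages are illustrative):

public class SuppressedDemo {
  public static void main(String[] args) {
    try {
      doWork();
    } catch (Exception e) {
      System.out.println("primary: " + e.getMessage());
      for (Throwable s : e.getSuppressed()) {
        System.out.println("suppressed: " + s.getMessage());
      }
    }
  }

  static void doWork() throws Exception {
    try {
      throw new IllegalStateException("setup failed"); // primary failure
    } catch (Exception exn) {
      // Cleanup also fails; attach it instead of losing either exception.
      exn.addSuppressed(new RuntimeException("abort failed"));
      throw exn;
    }
  }
}

Running this prints "primary: setup failed" followed by "suppressed: abort failed": both failures are preserved, with the original one primary.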

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java

Lines changed: 58 additions & 17 deletions

@@ -18,8 +18,8 @@
 package org.apache.beam.runners.dataflow.worker.util.common.worker;

 import java.io.Closeable;
+import java.util.ArrayList;
 import java.util.List;
-import java.util.ListIterator;
 import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
@@ -36,7 +36,9 @@ public class MapTaskExecutor implements WorkExecutor {
   private static final Logger LOG = LoggerFactory.getLogger(MapTaskExecutor.class);

   /** The operations in the map task, in execution order. */
-  public final List<Operation> operations;
+  public final ArrayList<Operation> operations;
+
+  private boolean closed = false;

   private final ExecutionStateTracker executionStateTracker;

@@ -54,7 +56,7 @@ public MapTaskExecutor(
       CounterSet counters,
       ExecutionStateTracker executionStateTracker) {
     this.counters = counters;
-    this.operations = operations;
+    this.operations = new ArrayList<>(operations);
     this.executionStateTracker = executionStateTracker;
   }

@@ -63,6 +65,11 @@ public CounterSet getOutputCounters() {
     return counters;
   }

+  /**
+   * May be reused if execute() returns without an exception being thrown.
+   *
+   * @throws Exception
+   */
   @Override
   public void execute() throws Exception {
     LOG.debug("Executing map task");
@@ -74,13 +81,11 @@ public void execute() throws Exception {
       // Starting a root operation such as a ReadOperation does the work
       // of processing the input dataset.
       LOG.debug("Starting operations");
-      ListIterator<Operation> iterator = operations.listIterator(operations.size());
-      while (iterator.hasPrevious()) {
+      for (int i = operations.size() - 1; i >= 0; --i) {
         if (Thread.currentThread().isInterrupted()) {
           throw new InterruptedException("Worker aborted");
         }
-        Operation op = iterator.previous();
-        op.start();
+        operations.get(i).start();
       }

       // Finish operations, in forward-execution-order, so that a
@@ -94,16 +99,13 @@ public void execute() throws Exception {
         op.finish();
       }
     } catch (Exception | Error exn) {
-      LOG.debug("Aborting operations", exn);
-      for (Operation op : operations) {
-        try {
-          op.abort();
-        } catch (Exception | Error exn2) {
-          exn.addSuppressed(exn2);
-          if (exn2 instanceof InterruptedException) {
-            Thread.currentThread().interrupt();
-          }
-        }
+      try {
+        closeInternal();
+      } catch (Exception closeExn) {
+        exn.addSuppressed(closeExn);
+      }
+      if (exn instanceof InterruptedException) {
+        Thread.currentThread().interrupt();
       }
       throw exn;
     }
@@ -164,6 +166,45 @@ public void abort() {
     }
   }

+  private void closeInternal() throws Exception {
+    Preconditions.checkState(!closed);
+    LOG.debug("Aborting operations");
+    @Nullable Exception exn = null;
+    for (Operation op : operations) {
+      try {
+        op.abort();
+      } catch (Exception | Error exn2) {
+        if (exn2 instanceof InterruptedException) {
+          Thread.currentThread().interrupt();
+        }
+        if (exn == null) {
+          if (exn2 instanceof Exception) {
+            exn = (Exception) exn2;
+          } else {
+            exn = new RuntimeException(exn2);
+          }
+        } else {
+          exn.addSuppressed(exn2);
+        }
+      }
+    }
+    closed = true;
+    if (exn != null) {
+      throw exn;
+    }
+  }
+
+  @Override
+  public void close() {
+    if (!closed) {
+      try {
+        closeInternal();
+      } catch (Exception e) {
+        LOG.error("Exception while closing MapTaskExecutor, ignoring", e);
+      }
+    }
+  }
+
   @Override
   public List<Integer> reportProducedEmptyOutput() {
     List<Integer> emptyOutputSinkIndexes = Lists.newArrayList();
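MapTaskExecutor now follows an idempotent-close idiom: the closed flag guarantees that operations are aborted at most once, whether cleanup runs through the failure path in execute() or through close() when the cached WorkExecutor is invalidated. A reduced sketch of the idiom under illustrative names (Resource and release() are not the worker's API):

import java.util.ArrayList;
import java.util.List;

// Illustrative Resource/release() names; not the worker's actual API.
final class GuardedCloser implements AutoCloseable {
  interface Resource {
    void release() throws Exception;
  }

  private final List<Resource> resources;
  private boolean closed = false;

  GuardedCloser(List<Resource> resources) {
    // Defensive copy, mirroring `new ArrayList<>(operations)` above.
    this.resources = new ArrayList<>(resources);
  }

  @Override
  public void close() {
    if (closed) {
      return; // cleanup already ran, e.g. via the failure path
    }
    closed = true;
    for (Resource r : resources) {
      try {
        r.release();
      } catch (Exception e) {
        // Best-effort: log and keep releasing the remaining resources.
        System.err.println("Ignoring failure while closing: " + e);
      }
    }
  }
}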

runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java

Lines changed: 24 additions & 6 deletions

@@ -24,6 +24,7 @@
 import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray;
 import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.hasItems;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertEquals;
@@ -330,6 +331,7 @@ public void testCreateReadOperation() throws Exception {
                 PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -338,11 +340,13 @@ public void testCreateReadOperation() throws Exception {
                 PipelineOptionsFactory.create(),
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ReadOperation.class));
     ReadOperation readOperation = (ReadOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(readOperation));

     assertEquals(1, readOperation.receivers.length);
     assertEquals(0, readOperation.receivers[0].getReceiverCount());
@@ -391,6 +395,7 @@ public void testCreateWriteOperation() throws Exception {
         ParallelInstructionNode.create(
             createWriteInstruction(producerIndex, producerOutputNum, "WriteOperation"),
             ExecutionLocation.UNKNOWN);
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -399,11 +404,13 @@ public void testCreateWriteOperation() throws Exception {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(WriteOperation.class));
     WriteOperation writeOperation = (WriteOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(writeOperation));

     assertEquals(0, writeOperation.receivers.length);
     assertEquals(Operation.InitializationState.UNSTARTED, writeOperation.initializationState);
@@ -541,14 +548,16 @@ public void testCreateParDoOperation() throws Exception {
                     .getMultiOutputInfos()
                     .get(0))));

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
-                STAGE, network, options, readerRegistry, sinkRegistry, context)
+                STAGE, network, options, readerRegistry, sinkRegistry, context, createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation parDoOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(parDoOperation));

     assertEquals(1, parDoOperation.receivers.length);
     assertEquals(0, parDoOperation.receivers[0].getReceiverCount());
@@ -608,6 +617,7 @@ public void testCreatePartialGroupByKeyOperation() throws Exception {
                 PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -616,11 +626,13 @@ public void testCreatePartialGroupByKeyOperation() throws Exception {
                 PipelineOptionsFactory.create(),
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(pgbkOperation));

     assertEquals(1, pgbkOperation.receivers.length);
     assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
@@ -660,6 +672,7 @@ public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception {
                 PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -668,11 +681,13 @@ public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(pgbkOperation));

     assertEquals(1, pgbkOperation.receivers.length);
     assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
@@ -738,6 +753,7 @@ public void testCreateFlattenOperation() throws Exception {
                 PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -746,12 +762,14 @@ public void testCreateFlattenOperation() throws Exception {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(FlattenOperation.class));
     FlattenOperation flattenOperation =
         (FlattenOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(flattenOperation));

     assertEquals(1, flattenOperation.receivers.length);
     assertEquals(0, flattenOperation.receivers[0].getReceiverCount());
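Each test pins the tracking behavior with Hamcrest's contains, which asserts the list holds exactly the one operation produced, in order. What that tracking buys at runtime is easiest to see in isolation; the following hypothetical JUnit 4 sketch (not from the Beam test suite) checks that an operation created before a later setup failure gets aborted:

import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import java.util.ArrayList;
import java.util.List;
import org.junit.Test;

// Hypothetical sketch of the cleanup behavior the tracking hooks enable.
public class CreatedOperationsCleanupSketchTest {

  static final class FakeOp {
    boolean aborted = false;

    void abort() {
      aborted = true;
    }
  }

  @Test
  public void earlierOpsAbortedWhenLaterSetupFails() {
    List<FakeOp> created = new ArrayList<>();
    FakeOp first = new FakeOp();

    assertThrows(
        RuntimeException.class,
        () -> {
          try {
            created.add(first); // first op was created successfully
            throw new RuntimeException("setup of second op failed");
          } catch (RuntimeException exn) {
            created.forEach(FakeOp::abort); // the cleanup path under test
            throw exn;
          }
        });

    assertTrue(first.aborted);
  }
}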
