2424import static org .apache .beam .sdk .util .SerializableUtils .serializeToByteArray ;
2525import static org .apache .beam .sdk .util .StringUtils .byteArrayToJsonString ;
2626import static org .hamcrest .MatcherAssert .assertThat ;
27+ import static org .hamcrest .Matchers .contains ;
2728import static org .hamcrest .Matchers .hasItems ;
2829import static org .hamcrest .Matchers .instanceOf ;
2930import static org .junit .Assert .assertEquals ;
@@ -330,6 +331,7 @@ public void testCreateReadOperation() throws Exception {
330331 PCOLLECTION_ID ))));
331332 when (network .outDegree (instructionNode )).thenReturn (1 );
332333
334+ ArrayList <Operation > createdOperations = new ArrayList <>();
333335 Node operationNode =
334336 mapTaskExecutorFactory
335337 .createOperationTransformForParallelInstructionNodes (
@@ -338,11 +340,13 @@ public void testCreateReadOperation() throws Exception {
338340 PipelineOptionsFactory .create (),
339341 readerRegistry ,
340342 sinkRegistry ,
341- BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ))
343+ BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ),
344+ createdOperations )
342345 .apply (instructionNode );
343346 assertThat (operationNode , instanceOf (OperationNode .class ));
344347 assertThat (((OperationNode ) operationNode ).getOperation (), instanceOf (ReadOperation .class ));
345348 ReadOperation readOperation = (ReadOperation ) ((OperationNode ) operationNode ).getOperation ();
349+ assertThat (createdOperations , contains (readOperation ));
346350
347351 assertEquals (1 , readOperation .receivers .length );
348352 assertEquals (0 , readOperation .receivers [0 ].getReceiverCount ());
@@ -391,6 +395,7 @@ public void testCreateWriteOperation() throws Exception {
391395 ParallelInstructionNode .create (
392396 createWriteInstruction (producerIndex , producerOutputNum , "WriteOperation" ),
393397 ExecutionLocation .UNKNOWN );
398+ ArrayList <Operation > createdOperations = new ArrayList <>();
394399 Node operationNode =
395400 mapTaskExecutorFactory
396401 .createOperationTransformForParallelInstructionNodes (
@@ -399,11 +404,13 @@ public void testCreateWriteOperation() throws Exception {
399404 options ,
400405 readerRegistry ,
401406 sinkRegistry ,
402- BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ))
407+ BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ),
408+ createdOperations )
403409 .apply (instructionNode );
404410 assertThat (operationNode , instanceOf (OperationNode .class ));
405411 assertThat (((OperationNode ) operationNode ).getOperation (), instanceOf (WriteOperation .class ));
406412 WriteOperation writeOperation = (WriteOperation ) ((OperationNode ) operationNode ).getOperation ();
413+ assertThat (createdOperations , contains (writeOperation ));
407414
408415 assertEquals (0 , writeOperation .receivers .length );
409416 assertEquals (Operation .InitializationState .UNSTARTED , writeOperation .initializationState );
@@ -541,14 +548,16 @@ public void testCreateParDoOperation() throws Exception {
541548 .getMultiOutputInfos ()
542549 .get (0 ))));
543550
551+ ArrayList <Operation > createdOperations = new ArrayList <>();
544552 Node operationNode =
545553 mapTaskExecutorFactory
546554 .createOperationTransformForParallelInstructionNodes (
547- STAGE , network , options , readerRegistry , sinkRegistry , context )
555+ STAGE , network , options , readerRegistry , sinkRegistry , context , createdOperations )
548556 .apply (instructionNode );
549557 assertThat (operationNode , instanceOf (OperationNode .class ));
550558 assertThat (((OperationNode ) operationNode ).getOperation (), instanceOf (ParDoOperation .class ));
551559 ParDoOperation parDoOperation = (ParDoOperation ) ((OperationNode ) operationNode ).getOperation ();
560+ assertThat (createdOperations , contains (parDoOperation ));
552561
553562 assertEquals (1 , parDoOperation .receivers .length );
554563 assertEquals (0 , parDoOperation .receivers [0 ].getReceiverCount ());
@@ -608,6 +617,7 @@ public void testCreatePartialGroupByKeyOperation() throws Exception {
608617 PCOLLECTION_ID ))));
609618 when (network .outDegree (instructionNode )).thenReturn (1 );
610619
620+ ArrayList <Operation > createdOperations = new ArrayList <>();
611621 Node operationNode =
612622 mapTaskExecutorFactory
613623 .createOperationTransformForParallelInstructionNodes (
@@ -616,11 +626,13 @@ public void testCreatePartialGroupByKeyOperation() throws Exception {
616626 PipelineOptionsFactory .create (),
617627 readerRegistry ,
618628 sinkRegistry ,
619- BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ))
629+ BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ),
630+ createdOperations )
620631 .apply (instructionNode );
621632 assertThat (operationNode , instanceOf (OperationNode .class ));
622633 assertThat (((OperationNode ) operationNode ).getOperation (), instanceOf (ParDoOperation .class ));
623634 ParDoOperation pgbkOperation = (ParDoOperation ) ((OperationNode ) operationNode ).getOperation ();
635+ assertThat (createdOperations , contains (pgbkOperation ));
624636
625637 assertEquals (1 , pgbkOperation .receivers .length );
626638 assertEquals (0 , pgbkOperation .receivers [0 ].getReceiverCount ());
@@ -660,6 +672,7 @@ public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception {
660672 PCOLLECTION_ID ))));
661673 when (network .outDegree (instructionNode )).thenReturn (1 );
662674
675+ ArrayList <Operation > createdOperations = new ArrayList <>();
663676 Node operationNode =
664677 mapTaskExecutorFactory
665678 .createOperationTransformForParallelInstructionNodes (
@@ -668,11 +681,13 @@ public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception {
668681 options ,
669682 readerRegistry ,
670683 sinkRegistry ,
671- BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ))
684+ BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ),
685+ createdOperations )
672686 .apply (instructionNode );
673687 assertThat (operationNode , instanceOf (OperationNode .class ));
674688 assertThat (((OperationNode ) operationNode ).getOperation (), instanceOf (ParDoOperation .class ));
675689 ParDoOperation pgbkOperation = (ParDoOperation ) ((OperationNode ) operationNode ).getOperation ();
690+ assertThat (createdOperations , contains (pgbkOperation ));
676691
677692 assertEquals (1 , pgbkOperation .receivers .length );
678693 assertEquals (0 , pgbkOperation .receivers [0 ].getReceiverCount ());
@@ -738,6 +753,7 @@ public void testCreateFlattenOperation() throws Exception {
738753 PCOLLECTION_ID ))));
739754 when (network .outDegree (instructionNode )).thenReturn (1 );
740755
756+ ArrayList <Operation > createdOperations = new ArrayList <>();
741757 Node operationNode =
742758 mapTaskExecutorFactory
743759 .createOperationTransformForParallelInstructionNodes (
@@ -746,12 +762,14 @@ public void testCreateFlattenOperation() throws Exception {
746762 options ,
747763 readerRegistry ,
748764 sinkRegistry ,
749- BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ))
765+ BatchModeExecutionContext .forTesting (options , counterSet , "testStage" ),
766+ createdOperations )
750767 .apply (instructionNode );
751768 assertThat (operationNode , instanceOf (OperationNode .class ));
752769 assertThat (((OperationNode ) operationNode ).getOperation (), instanceOf (FlattenOperation .class ));
753770 FlattenOperation flattenOperation =
754771 (FlattenOperation ) ((OperationNode ) operationNode ).getOperation ();
772+ assertThat (createdOperations , contains (flattenOperation ));
755773
756774 assertEquals (1 , flattenOperation .receivers .length );
757775 assertEquals (0 , flattenOperation .receivers [0 ].getReceiverCount ());
0 commit comments