@@ -20,7 +20,6 @@ import (
2020 "fmt"
2121 "runtime/trace"
2222 "strconv"
23- "sync"
2423 "sync/atomic"
2524 "time"
2625
@@ -67,6 +66,8 @@ type HashJoinExec struct {
6766
6867 // closeCh add a lock for closing executor.
6968 closeCh chan struct {}
69+ worker util.WaitGroupWrapper
70+ waiter util.WaitGroupWrapper
7071 joinType plannercore.JoinType
7172 requiredRows int64
7273
@@ -89,9 +90,7 @@ type HashJoinExec struct {
8990 prepared bool
9091 isOuterJoin bool
9192
92- // joinWorkerWaitGroup is for sync multiple join workers.
93- joinWorkerWaitGroup sync.WaitGroup
94- finished atomic.Value
93+ finished atomic.Value
9594
9695 stats * hashJoinRuntimeStats
9796}
@@ -146,6 +145,7 @@ func (e *HashJoinExec) Close() error {
146145 e .probeChkResourceCh = nil
147146 e .joinChkResourceCh = nil
148147 terror .Call (e .rowContainer .Close )
148+ e .waiter .Wait ()
149149 }
150150 e .outerMatchedStatus = e .outerMatchedStatus [:0 ]
151151
@@ -168,9 +168,10 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
168168 e .diskTracker = disk .NewTracker (e .id , - 1 )
169169 e .diskTracker .AttachTo (e .ctx .GetSessionVars ().StmtCtx .DiskTracker )
170170
171+ e .worker = util.WaitGroupWrapper {}
172+ e .waiter = util.WaitGroupWrapper {}
171173 e .closeCh = make (chan struct {})
172174 e .finished .Store (false )
173- e .joinWorkerWaitGroup = sync.WaitGroup {}
174175
175176 if e .probeTypes == nil {
176177 e .probeTypes = retTypes (e .probeSideExec )
@@ -264,13 +265,13 @@ func (e *HashJoinExec) wait4BuildSide() (emptyBuild bool, err error) {
264265
265266// fetchBuildSideRows fetches all rows from build side executor, and append them
266267// to e.buildSideResult.
267- func (e * HashJoinExec ) fetchBuildSideRows (ctx context.Context , chkCh chan <- * chunk.Chunk , doneCh <- chan struct {}) {
268+ func (e * HashJoinExec ) fetchBuildSideRows (ctx context.Context , chkCh chan <- * chunk.Chunk , errCh chan <- error , doneCh <- chan struct {}) {
268269 defer close (chkCh )
269270 var err error
270271 failpoint .Inject ("issue30289" , func (val failpoint.Value ) {
271272 if val .(bool ) {
272273 err = errors .Errorf ("issue30289 build return error" )
273- e . buildFinished <- errors .Trace (err )
274+ errCh <- errors .Trace (err )
274275 return
275276 }
276277 })
@@ -281,7 +282,7 @@ func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chu
281282 chk := chunk .NewChunkWithCapacity (e .buildSideExec .base ().retFieldTypes , e .ctx .GetSessionVars ().MaxChunkSize )
282283 err = Next (ctx , e .buildSideExec , chk )
283284 if err != nil {
284- e . buildFinished <- errors .Trace (err )
285+ errCh <- errors .Trace (err )
285286 return
286287 }
287288 failpoint .Inject ("errorFetchBuildSideRowsMockOOMPanic" , nil )
@@ -332,8 +333,7 @@ func (e *HashJoinExec) initializeForProbe() {
332333
333334func (e * HashJoinExec ) fetchAndProbeHashTable (ctx context.Context ) {
334335 e .initializeForProbe ()
335- e .joinWorkerWaitGroup .Add (1 )
336- go util .WithRecovery (func () {
336+ e .worker .RunWithRecover (func () {
337337 defer trace .StartRegion (ctx , "HashJoinProbeSideFetcher" ).End ()
338338 e .fetchProbeSideChunks (ctx )
339339 }, e .handleProbeSideFetcherPanic )
@@ -344,14 +344,13 @@ func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) {
344344 }
345345
346346 for i := uint (0 ); i < e .concurrency ; i ++ {
347- e .joinWorkerWaitGroup .Add (1 )
348347 workID := i
349- go util . WithRecovery (func () {
348+ e . worker . RunWithRecover (func () {
350349 defer trace .StartRegion (ctx , "HashJoinWorker" ).End ()
351350 e .runJoinWorker (workID , probeKeyColIdx )
352351 }, e .handleJoinWorkerPanic )
353352 }
354- go util . WithRecovery (e .waitJoinWorkersAndCloseResultChan , nil )
353+ e . waiter . RunWithRecover (e .waitJoinWorkersAndCloseResultChan , nil )
355354}
356355
357356func (e * HashJoinExec ) handleProbeSideFetcherPanic (r interface {}) {
@@ -361,14 +360,12 @@ func (e *HashJoinExec) handleProbeSideFetcherPanic(r interface{}) {
361360 if r != nil {
362361 e .joinResultCh <- & hashjoinWorkerResult {err : errors .Errorf ("%v" , r )}
363362 }
364- e .joinWorkerWaitGroup .Done ()
365363}
366364
367365func (e * HashJoinExec ) handleJoinWorkerPanic (r interface {}) {
368366 if r != nil {
369367 e .joinResultCh <- & hashjoinWorkerResult {err : errors .Errorf ("%v" , r )}
370368 }
371- e .joinWorkerWaitGroup .Done ()
372369}
373370
374371// Concurrently handling unmatched rows from the hash table
@@ -408,15 +405,14 @@ func (e *HashJoinExec) handleUnmatchedRowsFromHashTable(workerID uint) {
408405}
409406
410407func (e * HashJoinExec ) waitJoinWorkersAndCloseResultChan () {
411- e .joinWorkerWaitGroup .Wait ()
408+ e .worker .Wait ()
412409 if e .useOuterToBuild {
413410 // Concurrently handling unmatched rows from the hash table at the tail
414411 for i := uint (0 ); i < e .concurrency ; i ++ {
415412 var workerID = i
416- e .joinWorkerWaitGroup .Add (1 )
417- go util .WithRecovery (func () { e .handleUnmatchedRowsFromHashTable (workerID ) }, e .handleJoinWorkerPanic )
413+ e .worker .RunWithRecover (func () { e .handleUnmatchedRowsFromHashTable (workerID ) }, e .handleJoinWorkerPanic )
418414 }
419- e .joinWorkerWaitGroup .Wait ()
415+ e .worker .Wait ()
420416 }
421417 close (e .joinResultCh )
422418}
@@ -682,7 +678,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
682678 e .rowContainerForProbe [i ] = e .rowContainer .ShallowCopy ()
683679 }
684680 }
685- go util . WithRecovery (func () {
681+ e . worker . RunWithRecover (func () {
686682 defer trace .StartRegion (ctx , "HashJoinHashTableBuilder" ).End ()
687683 e .fetchAndBuildHashTable (ctx )
688684 }, e .handleFetchAndBuildHashTablePanic )
@@ -725,10 +721,10 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
725721 buildSideResultCh := make (chan * chunk.Chunk , 1 )
726722 doneCh := make (chan struct {})
727723 fetchBuildSideRowsOk := make (chan error , 1 )
728- go util . WithRecovery (
724+ e . worker . RunWithRecover (
729725 func () {
730726 defer trace .StartRegion (ctx , "HashJoinBuildSideFetcher" ).End ()
731- e .fetchBuildSideRows (ctx , buildSideResultCh , doneCh )
727+ e .fetchBuildSideRows (ctx , buildSideResultCh , fetchBuildSideRowsOk , doneCh )
732728 },
733729 func (r interface {}) {
734730 if r != nil {
0 commit comments