Skip to content

Commit 5ccc10b

Browse files
authored
executor: fix hashjoin goleak (#39023) (#39054)
close #39026
1 parent 33261f2 commit 5ccc10b

File tree

2 files changed

+71
-22
lines changed

2 files changed

+71
-22
lines changed

executor/join.go

Lines changed: 18 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@ import (
2020
"fmt"
2121
"runtime/trace"
2222
"strconv"
23-
"sync"
2423
"sync/atomic"
2524
"time"
2625

@@ -67,6 +66,8 @@ type HashJoinExec struct {
6766

6867
// closeCh add a lock for closing executor.
6968
closeCh chan struct{}
69+
worker util.WaitGroupWrapper
70+
waiter util.WaitGroupWrapper
7071
joinType plannercore.JoinType
7172
requiredRows int64
7273

@@ -89,9 +90,7 @@ type HashJoinExec struct {
8990
prepared bool
9091
isOuterJoin bool
9192

92-
// joinWorkerWaitGroup is for sync multiple join workers.
93-
joinWorkerWaitGroup sync.WaitGroup
94-
finished atomic.Value
93+
finished atomic.Value
9594

9695
stats *hashJoinRuntimeStats
9796
}
@@ -146,6 +145,7 @@ func (e *HashJoinExec) Close() error {
146145
e.probeChkResourceCh = nil
147146
e.joinChkResourceCh = nil
148147
terror.Call(e.rowContainer.Close)
148+
e.waiter.Wait()
149149
}
150150
e.outerMatchedStatus = e.outerMatchedStatus[:0]
151151

@@ -168,9 +168,10 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
168168
e.diskTracker = disk.NewTracker(e.id, -1)
169169
e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker)
170170

171+
e.worker = util.WaitGroupWrapper{}
172+
e.waiter = util.WaitGroupWrapper{}
171173
e.closeCh = make(chan struct{})
172174
e.finished.Store(false)
173-
e.joinWorkerWaitGroup = sync.WaitGroup{}
174175

175176
if e.probeTypes == nil {
176177
e.probeTypes = retTypes(e.probeSideExec)
@@ -264,13 +265,13 @@ func (e *HashJoinExec) wait4BuildSide() (emptyBuild bool, err error) {
264265

265266
// fetchBuildSideRows fetches all rows from build side executor, and append them
266267
// to e.buildSideResult.
267-
func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, doneCh <-chan struct{}) {
268+
func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) {
268269
defer close(chkCh)
269270
var err error
270271
failpoint.Inject("issue30289", func(val failpoint.Value) {
271272
if val.(bool) {
272273
err = errors.Errorf("issue30289 build return error")
273-
e.buildFinished <- errors.Trace(err)
274+
errCh <- errors.Trace(err)
274275
return
275276
}
276277
})
@@ -281,7 +282,7 @@ func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chu
281282
chk := chunk.NewChunkWithCapacity(e.buildSideExec.base().retFieldTypes, e.ctx.GetSessionVars().MaxChunkSize)
282283
err = Next(ctx, e.buildSideExec, chk)
283284
if err != nil {
284-
e.buildFinished <- errors.Trace(err)
285+
errCh <- errors.Trace(err)
285286
return
286287
}
287288
failpoint.Inject("errorFetchBuildSideRowsMockOOMPanic", nil)
@@ -332,8 +333,7 @@ func (e *HashJoinExec) initializeForProbe() {
332333

333334
func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) {
334335
e.initializeForProbe()
335-
e.joinWorkerWaitGroup.Add(1)
336-
go util.WithRecovery(func() {
336+
e.worker.RunWithRecover(func() {
337337
defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End()
338338
e.fetchProbeSideChunks(ctx)
339339
}, e.handleProbeSideFetcherPanic)
@@ -344,14 +344,13 @@ func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) {
344344
}
345345

346346
for i := uint(0); i < e.concurrency; i++ {
347-
e.joinWorkerWaitGroup.Add(1)
348347
workID := i
349-
go util.WithRecovery(func() {
348+
e.worker.RunWithRecover(func() {
350349
defer trace.StartRegion(ctx, "HashJoinWorker").End()
351350
e.runJoinWorker(workID, probeKeyColIdx)
352351
}, e.handleJoinWorkerPanic)
353352
}
354-
go util.WithRecovery(e.waitJoinWorkersAndCloseResultChan, nil)
353+
e.waiter.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil)
355354
}
356355

357356
func (e *HashJoinExec) handleProbeSideFetcherPanic(r interface{}) {
@@ -361,14 +360,12 @@ func (e *HashJoinExec) handleProbeSideFetcherPanic(r interface{}) {
361360
if r != nil {
362361
e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)}
363362
}
364-
e.joinWorkerWaitGroup.Done()
365363
}
366364

367365
// handleJoinWorkerPanic is passed as the recover callback for the join worker
// goroutines. r is the value returned by recover(); when it is non-nil it is
// formatted into an error and sent on e.joinResultCh as a hashjoinWorkerResult.
// When r is nil nothing is sent.
// NOTE(review): this is a diff rendering; the line following the "371-" marker
// (the explicit joinWorkerWaitGroup.Done() call) is the pre-change line.
func (e *HashJoinExec) handleJoinWorkerPanic(r interface{}) {
368366
if r != nil {
369367
e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)}
370368
}
371-
e.joinWorkerWaitGroup.Done()
372369
}
373370

374371
// Concurrently handling unmatched rows from the hash table
@@ -408,15 +405,14 @@ func (e *HashJoinExec) handleUnmatchedRowsFromHashTable(workerID uint) {
408405
}
409406

410407
// waitJoinWorkersAndCloseResultChan blocks until the goroutines tracked by the
// worker wait group have finished. If e.useOuterToBuild is set, it then starts
// e.concurrency goroutines that each run handleUnmatchedRowsFromHashTable
// (with handleJoinWorkerPanic as the recover callback) and waits for those as
// well. Finally it closes e.joinResultCh.
// NOTE(review): this is a diff rendering; lines following an "NNN-" marker are
// the pre-change form and lines following an "NNN+" marker are the post-change form.
func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() {
411-
e.joinWorkerWaitGroup.Wait()
408+
e.worker.Wait()
412409
if e.useOuterToBuild {
413410
// Concurrently handling unmatched rows from the hash table at the tail
414411
for i := uint(0); i < e.concurrency; i++ {
415412
var workerID = i
416-
e.joinWorkerWaitGroup.Add(1)
417-
go util.WithRecovery(func() { e.handleUnmatchedRowsFromHashTable(workerID) }, e.handleJoinWorkerPanic)
413+
e.worker.RunWithRecover(func() { e.handleUnmatchedRowsFromHashTable(workerID) }, e.handleJoinWorkerPanic)
418414
}
419-
e.joinWorkerWaitGroup.Wait()
415+
e.worker.Wait()
420416
}
421417
// Closed only after every wait above has returned.
close(e.joinResultCh)
422418
}
@@ -682,7 +678,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
682678
e.rowContainerForProbe[i] = e.rowContainer.ShallowCopy()
683679
}
684680
}
685-
go util.WithRecovery(func() {
681+
e.worker.RunWithRecover(func() {
686682
defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End()
687683
e.fetchAndBuildHashTable(ctx)
688684
}, e.handleFetchAndBuildHashTablePanic)
@@ -725,10 +721,10 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
725721
buildSideResultCh := make(chan *chunk.Chunk, 1)
726722
doneCh := make(chan struct{})
727723
fetchBuildSideRowsOk := make(chan error, 1)
728-
go util.WithRecovery(
724+
e.worker.RunWithRecover(
729725
func() {
730726
defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End()
731-
e.fetchBuildSideRows(ctx, buildSideResultCh, doneCh)
727+
e.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh)
732728
},
733729
func(r interface{}) {
734730
if r != nil {

util/wait_group_wrapper.go

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2021 PingCAP, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package util
16+
17+
import (
18+
"sync"
19+
)
20+
21+
// WaitGroupWrapper embeds sync.WaitGroup and adds helpers that start a
// goroutine while keeping the Add/Done bookkeeping balanced for the caller.
// All sync.WaitGroup methods (Add, Done, Wait) remain available through the
// embedded field.
type WaitGroupWrapper struct {
	sync.WaitGroup
}

// Run adds 1 to the WaitGroup, starts exec in a new goroutine and calls Done
// when exec returns.
//
// Run does not recover from panics: exec must not panic. Use RunWithRecover
// when exec may panic.
func (w *WaitGroupWrapper) Run(exec func()) {
	w.Add(1)
	go func() {
		defer w.Done()
		exec()
	}()
}

// RunWithRecover adds 1 to the WaitGroup, starts exec in a new goroutine,
// recovers any panic raised by exec and calls Done once the goroutine is
// finished.
//
// recoverFn, when non-nil, is invoked on every exit of exec with the value
// returned by recover(): the panic value if exec panicked, or nil if exec
// returned normally. Callbacks that only care about panics must therefore
// check r != nil themselves. Passing a nil recoverFn silently discards the
// panic value. This function itself performs no logging.
func (w *WaitGroupWrapper) RunWithRecover(exec func(), recoverFn func(r interface{})) {
	w.Add(1)
	go func() {
		// Deferred first so that it runs last: Done is signalled only after
		// recoverFn has completed, matching the defer style used by Run.
		defer w.Done()
		defer func() {
			// recover must be called directly inside the deferred function.
			r := recover()
			if recoverFn != nil {
				recoverFn(r)
			}
		}()
		exec()
	}()
}

0 commit comments

Comments
 (0)