Skip to content

Commit 27dc8f3

Browse files
winoroszz-jason
authored andcommitted
expression/agg: add ResetContext() for stream aggregation. (#6097)
1 parent 4241d70 commit 27dc8f3

File tree

10 files changed

+132
-2
lines changed

10 files changed

+132
-2
lines changed

executor/aggregate.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,12 +347,12 @@ func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Ch
347347
}
348348

349349
// appendResult2Chunk appends result of all the aggregation functions to the
350-
// result chunk, and realloc the evaluation context for each aggregation.
350+
// result chunk, and reset the evaluation context for each aggregation.
351351
func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) {
352352
e.rowBuffer = e.rowBuffer[:0]
353353
for i, af := range e.AggFuncs {
354354
e.rowBuffer = append(e.rowBuffer, af.GetResult(e.aggCtxs[i]))
355-
e.aggCtxs[i] = af.CreateContext(e.ctx.GetSessionVars().StmtCtx)
355+
af.ResetContext(e.ctx.GetSessionVars().StmtCtx, e.aggCtxs[i])
356356
}
357357
e.mutableRow.SetDatums(e.rowBuffer...)
358358
chk.AppendRow(e.mutableRow.ToRow())

expression/aggregation/aggregation.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ type Aggregation interface {
3838

3939
// Create a new AggEvaluateContext for the aggregation function.
4040
CreateContext(sc *stmtctx.StatementContext) *AggEvaluateContext
41+
42+
// Reset the content of the evaluate context.
43+
ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext)
4144
}
4245

4346
// NewDistAggFunc creates new Aggregate function for mock tikv.
@@ -115,6 +118,13 @@ func (af *aggFunction) CreateContext(sc *stmtctx.StatementContext) *AggEvaluateC
115118
return evalCtx
116119
}
117120

121+
func (af *aggFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
122+
if af.HasDistinct {
123+
evalCtx.DistinctChecker = createDistinctChecker(sc)
124+
}
125+
evalCtx.Value.SetNull()
126+
}
127+
118128
func (af *aggFunction) updateSum(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext, row types.Row) error {
119129
a := af.Args[0]
120130
value, err := a.Eval(row)

expression/aggregation/avg.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ func (af *avgFunction) updateAvg(sc *stmtctx.StatementContext, evalCtx *AggEvalu
4646
return nil
4747
}
4848

49+
func (af *avgFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
50+
if af.HasDistinct {
51+
evalCtx.DistinctChecker = createDistinctChecker(sc)
52+
}
53+
evalCtx.Value.SetNull()
54+
evalCtx.Count = 0
55+
}
56+
4957
// Update implements Aggregation interface.
5058
func (af *avgFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row types.Row) error {
5159
if af.Mode == FinalMode {
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright 2018 PingCAP, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package aggregation
15+
16+
import (
17+
"testing"
18+
19+
"github.com/pingcap/tidb/ast"
20+
"github.com/pingcap/tidb/expression"
21+
"github.com/pingcap/tidb/mysql"
22+
"github.com/pingcap/tidb/types"
23+
"github.com/pingcap/tidb/util/mock"
24+
)
25+
26+
func BenchmarkCreateContext(b *testing.B) {
27+
col := &expression.Column{
28+
Index: 0,
29+
RetType: types.NewFieldType(mysql.TypeLonglong),
30+
}
31+
ctx := mock.NewContext()
32+
fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc()
33+
b.StartTimer()
34+
for i := 0; i < b.N; i++ {
35+
fun.CreateContext(ctx.GetSessionVars().StmtCtx)
36+
}
37+
b.ReportAllocs()
38+
}
39+
40+
func BenchmarkResetContext(b *testing.B) {
41+
col := &expression.Column{
42+
Index: 0,
43+
RetType: types.NewFieldType(mysql.TypeLonglong),
44+
}
45+
ctx := mock.NewContext()
46+
fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc()
47+
evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx)
48+
b.StartTimer()
49+
for i := 0; i < b.N; i++ {
50+
fun.ResetContext(ctx.GetSessionVars().StmtCtx, evalCtx)
51+
}
52+
b.ReportAllocs()
53+
}
54+
55+
func BenchmarkCreateDistinctContext(b *testing.B) {
56+
col := &expression.Column{
57+
Index: 0,
58+
RetType: types.NewFieldType(mysql.TypeLonglong),
59+
}
60+
ctx := mock.NewContext()
61+
fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc()
62+
b.StartTimer()
63+
for i := 0; i < b.N; i++ {
64+
fun.CreateContext(ctx.GetSessionVars().StmtCtx)
65+
}
66+
b.ReportAllocs()
67+
}
68+
69+
func BenchmarkResetDistinctContext(b *testing.B) {
70+
col := &expression.Column{
71+
Index: 0,
72+
RetType: types.NewFieldType(mysql.TypeLonglong),
73+
}
74+
ctx := mock.NewContext()
75+
fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc()
76+
evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx)
77+
b.StartTimer()
78+
for i := 0; i < b.N; i++ {
79+
fun.ResetContext(ctx.GetSessionVars().StmtCtx, evalCtx)
80+
}
81+
b.ReportAllocs()
82+
}

expression/aggregation/bit_and.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ func (bf *bitAndFunction) CreateContext(sc *stmtctx.StatementContext) *AggEvalua
3131
return evalCtx
3232
}
3333

34+
func (bf bitAndFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
35+
evalCtx.Value.SetUint64(math.MaxUint64)
36+
}
37+
3438
// Update implements Aggregation interface.
3539
func (bf *bitAndFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row types.Row) error {
3640
a := bf.Args[0]

expression/aggregation/bit_or.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ func (bf *bitOrFunction) CreateContext(sc *stmtctx.StatementContext) *AggEvaluat
2929
return evalCtx
3030
}
3131

32+
func (bf *bitOrFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
33+
evalCtx.Value.SetUint64(0)
34+
}
35+
3236
// Update implements Aggregation interface.
3337
func (bf *bitOrFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row types.Row) error {
3438
a := bf.Args[0]

expression/aggregation/bit_xor.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ func (bf *bitXorFunction) CreateContext(sc *stmtctx.StatementContext) *AggEvalua
2929
return evalCtx
3030
}
3131

32+
func (bf *bitXorFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
33+
evalCtx.Value.SetUint64(0)
34+
}
35+
3236
// Update implements Aggregation interface.
3337
func (bf *bitXorFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row types.Row) error {
3438
a := bf.Args[0]

expression/aggregation/concat.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ func (cf *concatFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Datum)
102102
return d
103103
}
104104

105+
func (cf *concatFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
106+
if cf.HasDistinct {
107+
evalCtx.DistinctChecker = createDistinctChecker(sc)
108+
}
109+
evalCtx.Buffer = nil
110+
}
111+
105112
// GetPartialResult implements Aggregation interface.
106113
func (cf *concatFunction) GetPartialResult(evalCtx *AggEvaluateContext) []types.Datum {
107114
return []types.Datum{cf.GetResult(evalCtx)}

expression/aggregation/count.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ func (cf *countFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.Stateme
5959
return nil
6060
}
6161

62+
func (cf *countFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
63+
if cf.HasDistinct {
64+
evalCtx.DistinctChecker = createDistinctChecker(sc)
65+
}
66+
evalCtx.Count = 0
67+
}
68+
6269
// GetResult implements Aggregation interface.
6370
func (cf *countFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Datum) {
6471
d.SetInt64(evalCtx.Count)

expression/aggregation/first_row.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ func (ff *firstRowFunction) GetResult(evalCtx *AggEvaluateContext) types.Datum {
4545
return evalCtx.Value
4646
}
4747

48+
func (ff *firstRowFunction) ResetContext(_ *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
49+
evalCtx.GotFirstRow = false
50+
}
51+
4852
// GetPartialResult implements Aggregation interface.
4953
func (ff *firstRowFunction) GetPartialResult(evalCtx *AggEvaluateContext) []types.Datum {
5054
return []types.Datum{ff.GetResult(evalCtx)}

0 commit comments

Comments
 (0)