
Commit 7d7fb92

chore: improvements and fix building of dispatch middleware
1 parent 05bb075 commit 7d7fb92

File tree

15 files changed: +3182 -630 lines changed


.gitattributes

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+internal/mocks/*.go linguist-generated=true
+*.pb.go linguist-generated=true
+*.pb.*.go linguist-generated=true
+proto/internal/buf.lock linguist-generated=true
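
Marking these paths linguist-generated=true tells GitHub to collapse them in diff views and exclude them from the repository's language statistics, so generated mocks and protobuf output no longer dominate review noise for hand-written changes.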

development/prometheus.yaml

Lines changed: 5 additions & 2 deletions
@@ -6,6 +6,9 @@ global:
 scrape_configs:
   - job_name: "spicedb"
     static_configs:
-      - targets: ["spicedb:9090"]
+      - targets: ["spicedb-1:9090"]
         labels:
-          service: "spicedb"
+          service: "spicedb-1"
+      - targets: ["spicedb-2:9090"]
+        labels:
+          service: "spicedb-2"
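
With the scrape config split into two targets, each development instance now reports under its own service label (spicedb-1 and spicedb-2), so per-instance dashboards and alerts can distinguish the two processes.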

go.mod

Lines changed: 3 additions & 0 deletions
@@ -112,6 +112,7 @@ require (
 	go.opentelemetry.io/otel/trace v1.38.0
 	go.uber.org/atomic v1.11.0
 	go.uber.org/goleak v1.3.0
+	go.uber.org/mock v0.6.0
 	golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
 	golang.org/x/mod v0.28.0
 	golang.org/x/sync v0.17.0
@@ -137,6 +138,8 @@ tool (
 	github.com/golangci/golangci-lint/v2/cmd/golangci-lint
 	// support running mage with go run mage.go
 	github.com/magefile/mage/mage
+	// mocks are generated with go:generate directives.
+	go.uber.org/mock/mockgen
 	// vulncheck always uses the current directory's go.mod.
 	golang.org/x/vuln/cmd/govulncheck
 )

go.sum

Lines changed: 2 additions & 0 deletions
@@ -2595,6 +2595,8 @@ go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwE
 go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
 go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
 go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
+go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
+go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
 go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
 go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
 go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=

internal/dispatch/dispatch.go

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+//go:generate go run go.uber.org/mock/mockgen -source dispatch.go -destination ../mocks/mock_dispatcher.go -package mocks Dispatcher
+
 package dispatch
 
 import (
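
Because go.uber.org/mock/mockgen is now pinned in go.mod, the mock referenced by this directive should be reproducible with the standard workflow: running go generate ./internal/dispatch/... (or go generate ./... for the whole module) from the module root regenerates internal/mocks/mock_dispatcher.go using the pinned mockgen version.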

internal/middleware/memoryprotection/memoryprotection.go

Lines changed: 100 additions & 62 deletions
@@ -18,27 +18,29 @@ import (
 	log "github.com/authzed/spicedb/internal/logging"
 )
 
+const DefaultSampleIntervalSeconds = 1
+
 var (
 	// RejectedRequestsCounter tracks requests rejected due to memory pressure
 	RejectedRequestsCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 		Namespace: "spicedb",
-		Subsystem: "admission",
-		Name:      "memory_overload_rejected_requests_total",
+		Subsystem: "memory_admission",
+		Name:      "rejected_requests_total",
 		Help:      "Total requests rejected due to memory pressure",
 	}, []string{"endpoint"})
 
 	// MemoryUsageGauge tracks current memory usage percentage
 	MemoryUsageGauge = promauto.NewGauge(prometheus.GaugeOpts{
 		Namespace: "spicedb",
-		Subsystem: "admission",
+		Subsystem: "memory_admission",
 		Name:      "memory_usage_percent",
 		Help:      "Current memory usage as percentage of GOMEMLIMIT",
 	})
 )
 
 // Config holds configuration for the memory protection middleware
 type Config struct {
-	// ThresholdPercent is the memory usage threshold for requests (0-100)
+	// ThresholdPercent is the memory usage threshold for requests. If zero or negative, this middleware has no effect
 	ThresholdPercent int
 	// SampleIntervalSeconds controls how often memory usage is sampled
 	SampleIntervalSeconds int
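
Under the Prometheus client's namespace/subsystem/name convention, the renamed metrics should now be exposed as spicedb_memory_admission_rejected_requests_total (labelled by endpoint) and spicedb_memory_admission_memory_usage_percent; dashboards or alerts still referencing the old spicedb_admission_* names will need updating.
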
@@ -60,56 +62,90 @@ func DefaultDispatchConfig() Config {
 	}
 }
 
-// AdmissionMiddleware implements memory-based admission control
-type AdmissionMiddleware struct {
-	config          Config
-	memoryLimit     int64
-	lastMemoryUsage atomic.Int64
-	metricsSamples  []metrics.Sample
-	ctx             context.Context
+// MemoryLimitProvider gets and sets the limit of memory usage.
+// In production, use DefaultMemoryLimitProvider.
+// For testing, use HardCodedMemoryLimitProvider.
+type MemoryLimitProvider interface {
+	Get() int64
+	Set(int64)
+}
+
+var (
+	_ MemoryLimitProvider = (*DefaultMemoryLimitProvider)(nil)
+	_ MemoryLimitProvider = (*HardCodedMemoryLimitProvider)(nil)
+)
+
+type DefaultMemoryLimitProvider struct{}
+
+func (p *DefaultMemoryLimitProvider) Get() int64 {
+	// SetMemoryLimit returns the previously set memory limit.
+	// A negative input does not adjust the limit, and allows for retrieval of the currently set memory limit
+	return debug.SetMemoryLimit(-1)
+}
+
+func (p *DefaultMemoryLimitProvider) Set(limit int64) {
+	debug.SetMemoryLimit(limit)
+}
+
+type HardCodedMemoryLimitProvider struct {
+	Hardcodedlimit int64
 }
 
-// New creates a new memory protection middleware with the given context
-func New(ctx context.Context, config Config) *AdmissionMiddleware {
-	// Use the provided context directly
-	mwCtx := ctx
+func (p *HardCodedMemoryLimitProvider) Get() int64 {
+	return p.Hardcodedlimit
+}
+
+func (p *HardCodedMemoryLimitProvider) Set(limit int64) {
+	p.Hardcodedlimit = limit
+}
+
+type MemoryAdmissionMiddleware struct {
+	config         Config
+	memoryLimit    int64 // -1 means no limit
+	metricsSamples []metrics.Sample
+	ctx            context.Context // to stop the background process
+
+	lastMemorySampleInBytes   *atomic.Uint64 // atomic because it's written inside a goroutine but can be read from anywhere
+	timestampLastMemorySample *atomic.Pointer[time.Time]
+}
+
+// New creates a new memory admission middleware with the given context.
+// When the context is cancelled, this middleware stops its background processing.
+func New(ctx context.Context, config Config, limitProvider MemoryLimitProvider) MemoryAdmissionMiddleware {
+	am := MemoryAdmissionMiddleware{
+		config:                    config,
+		lastMemorySampleInBytes:   &atomic.Uint64{},
+		timestampLastMemorySample: &atomic.Pointer[time.Time]{},
+		memoryLimit:               -1, // disabled initially
+		ctx:                       ctx,
+	}
 
 	// Get the current GOMEMLIMIT
-	memoryLimit := debug.SetMemoryLimit(-1)
+	memoryLimit := limitProvider.Get()
 	if memoryLimit < 0 {
 		// If no limit is set, we can't provide memory protection
 		log.Info().Msg("GOMEMLIMIT not set, memory protection disabled")
-		return &AdmissionMiddleware{
-			config:      config,
-			memoryLimit: -1, // Disabled
-			ctx:         mwCtx,
-		}
+		return am
 	}
 
-	// Check if memory protection is disabled via config
 	if config.ThresholdPercent <= 0 {
 		log.Info().Msg("memory protection disabled via configuration")
-		return &AdmissionMiddleware{
-			config:      config,
-			memoryLimit: -1, // Disabled
-			ctx:         mwCtx,
-		}
+		return am
 	}
 
-	am := &AdmissionMiddleware{
-		config:      config,
-		memoryLimit: memoryLimit,
-		metricsSamples: []metrics.Sample{
-			{Name: "/memory/classes/heap/objects:bytes"},
-		},
-		ctx: mwCtx,
+	if config.SampleIntervalSeconds <= 0 {
+		log.Info().Msgf("memory protection sample interval cannot be zero or negative; using default value of %q seconds", DefaultSampleIntervalSeconds)
+		am.config.SampleIntervalSeconds = DefaultSampleIntervalSeconds
 	}
 
-	// Initialize with current memory usage
-	if err := am.sampleMemory(); err != nil {
-		log.Warn().Err(err).Msg("failed to get initial memory sample")
+	am.memoryLimit = memoryLimit
+	am.metricsSamples = []metrics.Sample{
+		{Name: "/memory/classes/heap/objects:bytes"},
 	}
 
+	// Initialize with current memory usage
+	am.sampleMemory()
+
 	// Start background sampling with context
 	am.startBackgroundSampling()

@@ -122,9 +158,9 @@ func New(ctx context.Context, config Config) *AdmissionMiddleware {
 	return am
 }
 
-// UnaryServerInterceptor returns a unary server interceptor that implements admission control
-func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
-	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+// UnaryServerInterceptor returns a unary server interceptor that rejects incoming requests if memory usage is too high
+func (am *MemoryAdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
 		if am.memoryLimit < 0 {
 			// Memory protection is disabled
 			return handler(ctx, req)
@@ -139,9 +175,9 @@ func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerIntercep
 	}
 }
 
-// StreamServerInterceptor returns a stream server interceptor that implements admission control
-func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
-	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+// StreamServerInterceptor returns a stream server interceptor that rejects incoming requests if memory usage is too high
+func (am *MemoryAdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
+	return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
 		if am.memoryLimit < 0 {
 			// Memory protection is disabled
 			return handler(srv, stream)
@@ -158,8 +194,8 @@ func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterc
 }
 
 // checkAdmission determines if a request should be admitted based on current memory usage
-func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
-	memoryUsage := am.getCurrentMemoryUsage()
+func (am *MemoryAdmissionMiddleware) checkAdmission(fullMethod string) error {
+	memoryUsage := am.getLastMemorySampleInBytes()
 
 	usagePercent := float64(memoryUsage) / float64(am.memoryLimit) * 100
 
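As a rough worked example of the check above: with GOMEMLIMIT at 4 GiB and the last heap sample around 3.4 GiB, usagePercent comes out near 85, so a ThresholdPercent of 80 would presumably cause new requests to be rejected (and counted in RejectedRequestsCounter) until usage falls back below the threshold.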

@@ -181,24 +217,23 @@ func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
 }
 
 // startBackgroundSampling starts a background goroutine that samples memory usage periodically
-func (am *AdmissionMiddleware) startBackgroundSampling() {
+func (am *MemoryAdmissionMiddleware) startBackgroundSampling() {
 	interval := time.Duration(am.config.SampleIntervalSeconds) * time.Second
 	ticker := time.NewTicker(interval)
 
 	go func() {
 		defer ticker.Stop()
-		defer log.Debug().Msg("memory protection background sampling stopped")
-
-		log.Debug().
-			Dur("interval", interval).
-			Msg("memory protection background sampling started")
+		// TODO this code might start running before the logger is setup, therefore we have a data race
+		//defer log.Debug().Msg("memory protection background sampling stopped")
+		//
+		//log.Debug().
+		//	Dur("interval", interval).
+		//	Msg("memory protection background sampling started")
 
 		for {
 			select {
 			case <-ticker.C:
-				if err := am.sampleMemory(); err != nil {
-					log.Warn().Err(err).Msg("background memory sampling failed")
-				}
+				am.sampleMemory()
 			case <-am.ctx.Done():
 				return
 			}
@@ -207,33 +242,36 @@ func (am *AdmissionMiddleware) sampleMemory() error {
 }
 
 // sampleMemory samples the current memory usage and updates the cached value
-func (am *AdmissionMiddleware) sampleMemory() error {
+func (am *MemoryAdmissionMiddleware) sampleMemory() {
 	defer func() {
 		if r := recover(); r != nil {
 			log.Warn().Interface("panic", r).Msg("memory sampling panicked")
 		}
 	}()
 
+	now := time.Now()
 	metrics.Read(am.metricsSamples)
-	newUsage := int64(am.metricsSamples[0].Value.Uint64())
-	am.lastMemoryUsage.Store(newUsage)
+	newUsage := am.metricsSamples[0].Value.Uint64()
+	am.lastMemorySampleInBytes.Store(newUsage)
+	am.timestampLastMemorySample.Store(&now)
 
 	// Update metrics gauge
 	if am.memoryLimit > 0 {
 		usagePercent := float64(newUsage) / float64(am.memoryLimit) * 100
 		MemoryUsageGauge.Set(usagePercent)
 	}
+}
 
-	return nil
+func (am *MemoryAdmissionMiddleware) getLastMemorySampleInBytes() uint64 {
+	return am.lastMemorySampleInBytes.Load()
 }
 
-// getCurrentMemoryUsage returns the cached memory usage in bytes
-func (am *AdmissionMiddleware) getCurrentMemoryUsage() int64 {
-	return am.lastMemoryUsage.Load()
+func (am *MemoryAdmissionMiddleware) getTimestampLastMemorySample() *time.Time {
+	return am.timestampLastMemorySample.Load()
 }
 
 // recordRejection records metrics for rejected requests
-func (am *AdmissionMiddleware) recordRejection(fullMethod string) {
+func (am *MemoryAdmissionMiddleware) recordRejection(fullMethod string) {
 	endpointType := "api"
 	if strings.HasPrefix(fullMethod, "/dispatch.v1.DispatchService") {
 		endpointType = "dispatch"
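
For context, a minimal sketch of how the rebuilt middleware could be wired into a gRPC server. Only the memoryprotection API shown in this diff is taken from the commit; the server setup itself is an assumption for illustration, not how SpiceDB necessarily registers it.

package main

import (
	"context"

	"google.golang.org/grpc"

	"github.com/authzed/spicedb/internal/middleware/memoryprotection"
)

func main() {
	// Cancelling this context stops the middleware's background sampling goroutine.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Production path: read the limit configured via GOMEMLIMIT from the runtime.
	am := memoryprotection.New(ctx, memoryprotection.DefaultDispatchConfig(), &memoryprotection.DefaultMemoryLimitProvider{})

	// Tests can pin a limit instead, e.g.:
	//   memoryprotection.New(ctx, cfg, &memoryprotection.HardCodedMemoryLimitProvider{Hardcodedlimit: 1 << 30})

	srv := grpc.NewServer(
		grpc.ChainUnaryInterceptor(am.UnaryServerInterceptor()),
		grpc.ChainStreamInterceptor(am.StreamServerInterceptor()),
	)
	_ = srv // register services and call srv.Serve(...) as usual
}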
