Skip to content

Commit 9ff3096

Browse files
committed
chore: improvements and fix building of dispatch middleware
1 parent 05bb075 commit 9ff3096

File tree

16 files changed

+3209
-699
lines changed

16 files changed

+3209
-699
lines changed

.gitattributes

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
internal/mocks/*.go linguist-generated=true
2+
*.pb.go linguist-generated=true
3+
*.pb.*.go linguist-generated=true
4+
proto/internal/buf.lock linguist-generated=true

development/prometheus.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ global:
66
scrape_configs:
77
- job_name: "spicedb"
88
static_configs:
9-
- targets: ["spicedb:9090"]
9+
- targets: ["spicedb-1:9090"]
1010
labels:
11-
service: "spicedb"
11+
service: "spicedb-1"
12+
- targets: ["spicedb-2:9090"]
13+
labels:
14+
service: "spicedb-2"

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ require (
112112
go.opentelemetry.io/otel/trace v1.38.0
113113
go.uber.org/atomic v1.11.0
114114
go.uber.org/goleak v1.3.0
115+
go.uber.org/mock v0.6.0
115116
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
116117
golang.org/x/mod v0.28.0
117118
golang.org/x/sync v0.17.0
@@ -137,6 +138,8 @@ tool (
137138
github.com/golangci/golangci-lint/v2/cmd/golangci-lint
138139
// support running mage with go run mage.go
139140
github.com/magefile/mage/mage
141+
// mocks are generated with go:generate directives.
142+
go.uber.org/mock/mockgen
140143
// vulncheck always uses the current directory's go.mod.
141144
golang.org/x/vuln/cmd/govulncheck
142145
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2595,6 +2595,8 @@ go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwE
25952595
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
25962596
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
25972597
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
2598+
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
2599+
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
25982600
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
25992601
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
26002602
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=

internal/dispatch/dispatch.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:generate go run go.uber.org/mock/mockgen -source dispatch.go -destination ../mocks/mock_dispatcher.go -package mocks Dispatcher
2+
13
package dispatch
24

35
import (

internal/middleware/memoryprotection/memoryprotection.go

Lines changed: 114 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,29 @@ import (
1818
log "github.com/authzed/spicedb/internal/logging"
1919
)
2020

21+
const DefaultSampleIntervalSeconds = 1
22+
2123
var (
2224
// RejectedRequestsCounter tracks requests rejected due to memory pressure
2325
RejectedRequestsCounter = promauto.NewCounterVec(prometheus.CounterOpts{
2426
Namespace: "spicedb",
25-
Subsystem: "admission",
26-
Name: "memory_overload_rejected_requests_total",
27+
Subsystem: "memory_admission",
28+
Name: "rejected_requests_total",
2729
Help: "Total requests rejected due to memory pressure",
2830
}, []string{"endpoint"})
2931

3032
// MemoryUsageGauge tracks current memory usage percentage
3133
MemoryUsageGauge = promauto.NewGauge(prometheus.GaugeOpts{
3234
Namespace: "spicedb",
33-
Subsystem: "admission",
35+
Subsystem: "memory_admission",
3436
Name: "memory_usage_percent",
3537
Help: "Current memory usage as percentage of GOMEMLIMIT",
3638
})
3739
)
3840

3941
// Config holds configuration for the memory protection middleware
4042
type Config struct {
41-
// ThresholdPercent is the memory usage threshold for requests (0-100)
43+
// ThresholdPercent is the memory usage threshold for requests. If zero or negative, this middleware has no effect
4244
ThresholdPercent int
4345
// SampleIntervalSeconds controls how often memory usage is sampled
4446
SampleIntervalSeconds int
@@ -60,60 +62,94 @@ func DefaultDispatchConfig() Config {
6062
}
6163
}
6264

63-
// AdmissionMiddleware implements memory-based admission control
64-
type AdmissionMiddleware struct {
65-
config Config
66-
memoryLimit int64
67-
lastMemoryUsage atomic.Int64
68-
metricsSamples []metrics.Sample
69-
ctx context.Context
65+
// MemoryLimitProvider gets and sets the limit of memory usage.
66+
// In production, use DefaultMemoryLimitProvider.
67+
// For testing, use HardCodedMemoryLimitProvider.
68+
type MemoryLimitProvider interface {
69+
Get() int64
70+
Set(int64)
71+
}
72+
73+
var (
74+
_ MemoryLimitProvider = (*DefaultMemoryLimitProvider)(nil)
75+
_ MemoryLimitProvider = (*HardCodedMemoryLimitProvider)(nil)
76+
)
77+
78+
type DefaultMemoryLimitProvider struct{}
79+
80+
func (p *DefaultMemoryLimitProvider) Get() int64 {
81+
// SetMemoryLimit returns the previously set memory limit.
82+
// A negative input does not adjust the limit, and allows for retrieval of the currently set memory limit
83+
return debug.SetMemoryLimit(-1)
84+
}
85+
86+
func (p *DefaultMemoryLimitProvider) Set(limit int64) {
87+
debug.SetMemoryLimit(limit)
88+
}
89+
90+
type HardCodedMemoryLimitProvider struct {
91+
Hardcodedlimit int64
92+
}
93+
94+
func (p *HardCodedMemoryLimitProvider) Get() int64 {
95+
return p.Hardcodedlimit
96+
}
97+
98+
func (p *HardCodedMemoryLimitProvider) Set(limit int64) {
99+
p.Hardcodedlimit = limit
100+
}
101+
102+
type MemoryAdmissionMiddleware struct {
103+
config Config
104+
memoryLimit int64 // -1 means no limit
105+
metricsSamples []metrics.Sample
106+
ctx context.Context // to stop the background process
107+
108+
lastMemorySampleInBytes *atomic.Uint64 // atomic because it's written inside a goroutine but can be read from anywhere
109+
timestampLastMemorySample *atomic.Pointer[time.Time]
70110
}
71111

72-
// New creates a new memory protection middleware with the given context
73-
func New(ctx context.Context, config Config) *AdmissionMiddleware {
74-
// Use the provided context directly
75-
mwCtx := ctx
112+
// New creates a new memory admission middleware with the given context.
113+
// Whe the context is cancelled, this middleware stops its background processing.
114+
func New(ctx context.Context, config Config, limitProvider MemoryLimitProvider, name string) MemoryAdmissionMiddleware {
115+
am := MemoryAdmissionMiddleware{
116+
config: config,
117+
lastMemorySampleInBytes: &atomic.Uint64{},
118+
timestampLastMemorySample: &atomic.Pointer[time.Time]{},
119+
memoryLimit: -1, // disabled initially
120+
ctx: ctx,
121+
}
76122

77123
// Get the current GOMEMLIMIT
78-
memoryLimit := debug.SetMemoryLimit(-1)
124+
memoryLimit := limitProvider.Get()
79125
if memoryLimit < 0 {
80-
// If no limit is set, we can't provide memory protection
81-
log.Info().Msg("GOMEMLIMIT not set, memory protection disabled")
82-
return &AdmissionMiddleware{
83-
config: config,
84-
memoryLimit: -1, // Disabled
85-
ctx: mwCtx,
86-
}
126+
log.Info().Str("name", name).Msg("GOMEMLIMIT not set, memory protection disabled")
127+
return am
87128
}
88129

89-
// Check if memory protection is disabled via config
90130
if config.ThresholdPercent <= 0 {
91-
log.Info().Msg("memory protection disabled via configuration")
92-
return &AdmissionMiddleware{
93-
config: config,
94-
memoryLimit: -1, // Disabled
95-
ctx: mwCtx,
96-
}
131+
log.Info().Str("name", name).Msg("threshold is non-positive; memory protection disabled")
132+
return am
97133
}
98134

99-
am := &AdmissionMiddleware{
100-
config: config,
101-
memoryLimit: memoryLimit,
102-
metricsSamples: []metrics.Sample{
103-
{Name: "/memory/classes/heap/objects:bytes"},
104-
},
105-
ctx: mwCtx,
135+
if config.SampleIntervalSeconds <= 0 {
136+
log.Info().Str("name", name).Msgf("memory protection sample interval cannot be zero or negative; using default value of %q seconds", DefaultSampleIntervalSeconds)
137+
am.config.SampleIntervalSeconds = DefaultSampleIntervalSeconds
106138
}
107139

108-
// Initialize with current memory usage
109-
if err := am.sampleMemory(); err != nil {
110-
log.Warn().Err(err).Msg("failed to get initial memory sample")
140+
am.memoryLimit = memoryLimit
141+
am.metricsSamples = []metrics.Sample{
142+
{Name: "/memory/classes/heap/objects:bytes"},
111143
}
112144

145+
// Initialize with current memory usage
146+
am.sampleMemory()
147+
113148
// Start background sampling with context
114149
am.startBackgroundSampling()
115150

116151
log.Info().
152+
Str("name", name).
117153
Int64("memory_limit_bytes", memoryLimit).
118154
Int("threshold_percent", config.ThresholdPercent).
119155
Int("sample_interval_seconds", config.SampleIntervalSeconds).
@@ -122,33 +158,23 @@ func New(ctx context.Context, config Config) *AdmissionMiddleware {
122158
return am
123159
}
124160

125-
// UnaryServerInterceptor returns a unary server interceptor that implements admission control
126-
func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
127-
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
128-
if am.memoryLimit < 0 {
129-
// Memory protection is disabled
130-
return handler(ctx, req)
131-
}
132-
133-
if err := am.checkAdmission(info.FullMethod); err != nil {
134-
am.recordRejection(info.FullMethod)
161+
// UnaryServerInterceptor returns a unary server interceptor that rejects incoming requests is memory usage is too high
162+
func (am *MemoryAdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
163+
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
164+
if err := am.checkAdmission(); err != nil {
165+
_ = am.recordRejection(info.FullMethod)
135166
return nil, err
136167
}
137168

138169
return handler(ctx, req)
139170
}
140171
}
141172

142-
// StreamServerInterceptor returns a stream server interceptor that implements admission control
143-
func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
144-
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
145-
if am.memoryLimit < 0 {
146-
// Memory protection is disabled
147-
return handler(srv, stream)
148-
}
149-
150-
if err := am.checkAdmission(info.FullMethod); err != nil {
151-
am.recordRejection(info.FullMethod)
173+
// StreamServerInterceptor returns a stream server interceptor that rejects incoming requests is memory usage is too high
174+
func (am *MemoryAdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
175+
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
176+
if err := am.checkAdmission(); err != nil {
177+
_ = am.recordRejection(info.FullMethod)
152178
return err
153179
}
154180

@@ -157,21 +183,20 @@ func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterc
157183
}
158184
}
159185

160-
// checkAdmission determines if a request should be admitted based on current memory usage
161-
func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
162-
memoryUsage := am.getCurrentMemoryUsage()
186+
// checkAdmission returns an error if the request should be denied because memory usage is too high.
187+
func (am *MemoryAdmissionMiddleware) checkAdmission() error {
188+
if am.memoryLimit < 0 || am.config.ThresholdPercent <= 0 {
189+
// Memory protection is disabled
190+
return nil
191+
}
192+
193+
memoryUsage := am.getLastMemorySampleInBytes()
163194

164195
usagePercent := float64(memoryUsage) / float64(am.memoryLimit) * 100
165196

166197
// Metrics gauge is updated in background sampling
167198

168199
if usagePercent > float64(am.config.ThresholdPercent) {
169-
log.Warn().
170-
Float64("memory_usage_percent", usagePercent).
171-
Int("threshold_percent", am.config.ThresholdPercent).
172-
Str("method", fullMethod).
173-
Msg("rejecting request due to memory pressure")
174-
175200
return status.Errorf(codes.ResourceExhausted,
176201
"server is experiencing memory pressure (%.1f%% usage, threshold: %d%%)",
177202
usagePercent, am.config.ThresholdPercent)
@@ -181,24 +206,23 @@ func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
181206
}
182207

183208
// startBackgroundSampling starts a background goroutine that samples memory usage periodically
184-
func (am *AdmissionMiddleware) startBackgroundSampling() {
209+
func (am *MemoryAdmissionMiddleware) startBackgroundSampling() {
185210
interval := time.Duration(am.config.SampleIntervalSeconds) * time.Second
186211
ticker := time.NewTicker(interval)
187212

188213
go func() {
189214
defer ticker.Stop()
190-
defer log.Debug().Msg("memory protection background sampling stopped")
191-
192-
log.Debug().
193-
Dur("interval", interval).
194-
Msg("memory protection background sampling started")
215+
// TODO this code might start running before the logger is setup, therefore we have a data race with the logger object :/
216+
//defer log.Debug().Msg("memory protection background sampling stopped")
217+
//
218+
//log.Debug().
219+
// Dur("interval", interval).
220+
// Msg("memory protection background sampling started")
195221

196222
for {
197223
select {
198224
case <-ticker.C:
199-
if err := am.sampleMemory(); err != nil {
200-
log.Warn().Err(err).Msg("background memory sampling failed")
201-
}
225+
am.sampleMemory()
202226
case <-am.ctx.Done():
203227
return
204228
}
@@ -207,37 +231,41 @@ func (am *AdmissionMiddleware) startBackgroundSampling() {
207231
}
208232

209233
// sampleMemory samples the current memory usage and updates the cached value
210-
func (am *AdmissionMiddleware) sampleMemory() error {
234+
func (am *MemoryAdmissionMiddleware) sampleMemory() {
211235
defer func() {
212236
if r := recover(); r != nil {
213237
log.Warn().Interface("panic", r).Msg("memory sampling panicked")
214238
}
215239
}()
216240

241+
now := time.Now()
217242
metrics.Read(am.metricsSamples)
218-
newUsage := int64(am.metricsSamples[0].Value.Uint64())
219-
am.lastMemoryUsage.Store(newUsage)
243+
newUsage := am.metricsSamples[0].Value.Uint64()
244+
am.lastMemorySampleInBytes.Store(newUsage)
245+
am.timestampLastMemorySample.Store(&now)
220246

221247
// Update metrics gauge
222248
if am.memoryLimit > 0 {
223249
usagePercent := float64(newUsage) / float64(am.memoryLimit) * 100
224250
MemoryUsageGauge.Set(usagePercent)
225251
}
252+
}
226253

227-
return nil
254+
func (am *MemoryAdmissionMiddleware) getLastMemorySampleInBytes() uint64 {
255+
return am.lastMemorySampleInBytes.Load()
228256
}
229257

230-
// getCurrentMemoryUsage returns the cached memory usage in bytes
231-
func (am *AdmissionMiddleware) getCurrentMemoryUsage() int64 {
232-
return am.lastMemoryUsage.Load()
258+
func (am *MemoryAdmissionMiddleware) getTimestampLastMemorySample() *time.Time {
259+
return am.timestampLastMemorySample.Load()
233260
}
234261

235262
// recordRejection records metrics for rejected requests
236-
func (am *AdmissionMiddleware) recordRejection(fullMethod string) {
263+
func (am *MemoryAdmissionMiddleware) recordRejection(fullMethod string) string {
237264
endpointType := "api"
238265
if strings.HasPrefix(fullMethod, "/dispatch.v1.DispatchService") {
239266
endpointType = "dispatch"
240267
}
241268

242269
RejectedRequestsCounter.WithLabelValues(endpointType).Inc()
270+
return endpointType
243271
}

0 commit comments

Comments
 (0)