Skip to content

Commit 31d1947

Browse files
committed
feat: add Memory Protection Middleware
The commit introduces a new memory protection middleware to help prevent out-of-memory conditions in SpiceDB by implementing admission control based on current memory usage. This is not a perfect solution (doesn't prevent non-traffic-related sources of OOM) and is meant to support other future improvements to resource sharing in a single SpiceDB node. The middleware is installed both in the main api and in dispatch, but at different thresholds. Memory usage is polled in the background, and if in-flight memory rises above the threshold, backpressure is placed on incoming requests. The API threshold is higher than the dispatch threshold to preserve already admitted traffic as much as possible.
1 parent 9638aca commit 31d1947

File tree

11 files changed

+1056
-31
lines changed

11 files changed

+1056
-31
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -309,6 +309,7 @@ require (
309309
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
310310
github.com/kulti/thelper v0.7.1 // indirect
311311
github.com/kunwardeep/paralleltest v1.0.14 // indirect
312+
github.com/kylelemons/godebug v1.1.0 // indirect
312313
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect
313314
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect
314315
github.com/lasiar/canonicalheader v1.1.2 // indirect
Lines changed: 243 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,243 @@
1+
package memoryprotection
2+
3+
import (
	"context"
	"fmt"
	"math"
	"runtime/debug"
	"runtime/metrics"
	"strings"
	"sync/atomic"
	"time"

	middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	log "github.com/authzed/spicedb/internal/logging"
)
20+
21+
var (
	// RejectedRequestsCounter tracks the total number of requests rejected
	// due to memory pressure, labeled by endpoint type ("api" or "dispatch"
	// as classified by recordRejection).
	RejectedRequestsCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "admission",
		Name:      "memory_overload_rejected_requests_total",
		Help:      "Total requests rejected due to memory pressure",
	}, []string{"endpoint"})

	// MemoryUsageGauge tracks current heap usage as a percentage of
	// GOMEMLIMIT; it is updated by the background sampling goroutine, not
	// per request.
	MemoryUsageGauge = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: "spicedb",
		Subsystem: "admission",
		Name:      "memory_usage_percent",
		Help:      "Current memory usage as percentage of GOMEMLIMIT",
	})
)
38+
39+
// Config holds configuration for the memory protection middleware.
type Config struct {
	// ThresholdPercent is the memory usage threshold for requests, expressed
	// as a percentage of GOMEMLIMIT (0-100). Requests are rejected while
	// sampled usage exceeds this value; values <= 0 disable protection.
	ThresholdPercent int
	// SampleIntervalSeconds controls how often memory usage is sampled by
	// the background goroutine started in New.
	SampleIntervalSeconds int
}
46+
47+
// DefaultConfig returns reasonable default configuration for API requests
48+
func DefaultConfig() Config {
49+
return Config{
50+
ThresholdPercent: 90,
51+
SampleIntervalSeconds: 1,
52+
}
53+
}
54+
55+
// DefaultDispatchConfig returns reasonable default configuration for dispatch requests
56+
func DefaultDispatchConfig() Config {
57+
return Config{
58+
ThresholdPercent: 95,
59+
SampleIntervalSeconds: 1,
60+
}
61+
}
62+
63+
// AdmissionMiddleware implements memory-based admission control: requests
// are rejected with ResourceExhausted while the sampled heap usage exceeds
// the configured percentage of GOMEMLIMIT.
type AdmissionMiddleware struct {
	config Config
	// memoryLimit is GOMEMLIMIT in bytes; a negative value means memory
	// protection is disabled and all requests are admitted.
	memoryLimit int64
	// lastMemoryUsage caches the most recent heap-usage sample (in bytes)
	// so per-request admission checks avoid reading runtime metrics.
	lastMemoryUsage atomic.Int64
	// metricsSamples is the reusable buffer passed to metrics.Read.
	metricsSamples []metrics.Sample
	// ctx bounds the lifetime of the background sampling goroutine.
	// NOTE(review): storing a context in a struct is discouraged in Go;
	// consider a stop channel passed to startBackgroundSampling instead.
	ctx context.Context
}
71+
72+
// New creates a new memory protection middleware with the given context
73+
func New(ctx context.Context, config Config) *AdmissionMiddleware {
74+
// Use the provided context directly
75+
mwCtx := ctx
76+
77+
// Get the current GOMEMLIMIT
78+
memoryLimit := debug.SetMemoryLimit(-1)
79+
if memoryLimit < 0 {
80+
// If no limit is set, we can't provide memory protection
81+
log.Info().Msg("GOMEMLIMIT not set, memory protection disabled")
82+
return &AdmissionMiddleware{
83+
config: config,
84+
memoryLimit: -1, // Disabled
85+
ctx: mwCtx,
86+
}
87+
}
88+
89+
// Check if memory protection is disabled via config
90+
if config.ThresholdPercent <= 0 {
91+
log.Info().Msg("memory protection disabled via configuration")
92+
return &AdmissionMiddleware{
93+
config: config,
94+
memoryLimit: -1, // Disabled
95+
ctx: mwCtx,
96+
}
97+
}
98+
99+
am := &AdmissionMiddleware{
100+
config: config,
101+
memoryLimit: memoryLimit,
102+
metricsSamples: []metrics.Sample{
103+
{Name: "/memory/classes/heap/objects:bytes"},
104+
},
105+
ctx: mwCtx,
106+
}
107+
108+
// Initialize with current memory usage
109+
if err := am.sampleMemory(); err != nil {
110+
log.Warn().Err(err).Msg("failed to get initial memory sample")
111+
}
112+
113+
// Start background sampling with context
114+
am.startBackgroundSampling()
115+
116+
log.Info().
117+
Int64("memory_limit_bytes", memoryLimit).
118+
Int("threshold_percent", config.ThresholdPercent).
119+
Int("sample_interval_seconds", config.SampleIntervalSeconds).
120+
Msg("memory protection middleware initialized with background sampling")
121+
122+
return am
123+
}
124+
125+
// UnaryServerInterceptor returns a unary server interceptor that implements admission control
126+
func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
127+
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
128+
if am.memoryLimit < 0 {
129+
// Memory protection is disabled
130+
return handler(ctx, req)
131+
}
132+
133+
if err := am.checkAdmission(info.FullMethod); err != nil {
134+
am.recordRejection(info.FullMethod)
135+
return nil, err
136+
}
137+
138+
return handler(ctx, req)
139+
}
140+
}
141+
142+
// StreamServerInterceptor returns a stream server interceptor that implements admission control
143+
func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
144+
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
145+
if am.memoryLimit < 0 {
146+
// Memory protection is disabled
147+
return handler(srv, stream)
148+
}
149+
150+
if err := am.checkAdmission(info.FullMethod); err != nil {
151+
am.recordRejection(info.FullMethod)
152+
return err
153+
}
154+
155+
wrapped := middleware.WrapServerStream(stream)
156+
return handler(srv, wrapped)
157+
}
158+
}
159+
160+
// checkAdmission determines if a request should be admitted based on current memory usage
161+
func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
162+
memoryUsage := am.getCurrentMemoryUsage()
163+
164+
usagePercent := float64(memoryUsage) / float64(am.memoryLimit) * 100
165+
166+
// Metrics gauge is updated in background sampling
167+
168+
if usagePercent > float64(am.config.ThresholdPercent) {
169+
log.Warn().
170+
Float64("memory_usage_percent", usagePercent).
171+
Int("threshold_percent", am.config.ThresholdPercent).
172+
Str("method", fullMethod).
173+
Msg("rejecting request due to memory pressure")
174+
175+
return status.Errorf(codes.ResourceExhausted,
176+
"server is experiencing memory pressure (%.1f%% usage, threshold: %d%%)",
177+
usagePercent, am.config.ThresholdPercent)
178+
}
179+
180+
return nil
181+
}
182+
183+
// startBackgroundSampling starts a background goroutine that samples memory usage periodically
184+
func (am *AdmissionMiddleware) startBackgroundSampling() {
185+
interval := time.Duration(am.config.SampleIntervalSeconds) * time.Second
186+
ticker := time.NewTicker(interval)
187+
188+
go func() {
189+
defer ticker.Stop()
190+
defer log.Debug().Msg("memory protection background sampling stopped")
191+
192+
log.Debug().
193+
Dur("interval", interval).
194+
Msg("memory protection background sampling started")
195+
196+
for {
197+
select {
198+
case <-ticker.C:
199+
if err := am.sampleMemory(); err != nil {
200+
log.Warn().Err(err).Msg("background memory sampling failed")
201+
}
202+
case <-am.ctx.Done():
203+
return
204+
}
205+
}
206+
}()
207+
}
208+
209+
// sampleMemory samples the current memory usage and updates the cached value
210+
func (am *AdmissionMiddleware) sampleMemory() error {
211+
defer func() {
212+
if r := recover(); r != nil {
213+
log.Warn().Interface("panic", r).Msg("memory sampling panicked")
214+
}
215+
}()
216+
217+
metrics.Read(am.metricsSamples)
218+
newUsage := int64(am.metricsSamples[0].Value.Uint64())
219+
am.lastMemoryUsage.Store(newUsage)
220+
221+
// Update metrics gauge
222+
if am.memoryLimit > 0 {
223+
usagePercent := float64(newUsage) / float64(am.memoryLimit) * 100
224+
MemoryUsageGauge.Set(usagePercent)
225+
}
226+
227+
return nil
228+
}
229+
230+
// getCurrentMemoryUsage returns the most recently sampled heap usage in
// bytes, as cached by the background sampler; admission checks use this
// instead of reading runtime metrics on every request.
func (am *AdmissionMiddleware) getCurrentMemoryUsage() int64 {
	return am.lastMemoryUsage.Load()
}
234+
235+
// recordRejection records metrics for rejected requests
236+
func (am *AdmissionMiddleware) recordRejection(fullMethod string) {
237+
endpointType := "api"
238+
if strings.HasPrefix(fullMethod, "/dispatch.v1.DispatchService") {
239+
endpointType = "dispatch"
240+
}
241+
242+
RejectedRequestsCounter.WithLabelValues(endpointType).Inc()
243+
}

0 commit comments

Comments
 (0)