Skip to content

Commit d65a7c0

Browse files
committed
chore: improvements and fix building of dispatch middleware
1 parent 05bb075 commit d65a7c0

File tree

5 files changed

+311
-627
lines changed

5 files changed

+311
-627
lines changed

internal/middleware/memoryprotection/memoryprotection.go

Lines changed: 92 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ var (
2222
// RejectedRequestsCounter tracks requests rejected due to memory pressure
2323
RejectedRequestsCounter = promauto.NewCounterVec(prometheus.CounterOpts{
2424
Namespace: "spicedb",
25-
Subsystem: "admission",
26-
Name: "memory_overload_rejected_requests_total",
25+
Subsystem: "memory_admission",
26+
Name: "rejected_requests_total",
2727
Help: "Total requests rejected due to memory pressure",
2828
}, []string{"endpoint"})
2929

3030
// MemoryUsageGauge tracks current memory usage percentage
3131
MemoryUsageGauge = promauto.NewGauge(prometheus.GaugeOpts{
3232
Namespace: "spicedb",
33-
Subsystem: "admission",
33+
Subsystem: "memory_admission",
3434
Name: "memory_usage_percent",
3535
Help: "Current memory usage as percentage of GOMEMLIMIT",
3636
})
@@ -60,55 +60,84 @@ func DefaultDispatchConfig() Config {
6060
}
6161
}
6262

63-
// AdmissionMiddleware implements memory-based admission control
64-
type AdmissionMiddleware struct {
65-
config Config
66-
memoryLimit int64
67-
lastMemoryUsage atomic.Int64
68-
metricsSamples []metrics.Sample
69-
ctx context.Context
63+
// MemoryLimitProvider gets and sets the limit of memory usage.
64+
// In production, use DefaultMemoryLimitProvider.
65+
// For testing, use HardCodedMemoryLimitProvider.
66+
type MemoryLimitProvider interface {
67+
Get() int64
68+
Set(int64)
7069
}
7170

72-
// New creates a new memory protection middleware with the given context
73-
func New(ctx context.Context, config Config) *AdmissionMiddleware {
74-
// Use the provided context directly
75-
mwCtx := ctx
71+
var (
72+
_ MemoryLimitProvider = (*DefaultMemoryLimitProvider)(nil)
73+
_ MemoryLimitProvider = (*HardCodedMemoryLimitProvider)(nil)
74+
)
75+
76+
type DefaultMemoryLimitProvider struct{}
77+
78+
func (p *DefaultMemoryLimitProvider) Get() int64 {
79+
// SetMemoryLimit returns the previously set memory limit.
80+
// A negative input does not adjust the limit, and allows for retrieval of the currently set memory limit
81+
return debug.SetMemoryLimit(-1)
82+
}
83+
84+
func (p *DefaultMemoryLimitProvider) Set(limit int64) {
85+
debug.SetMemoryLimit(limit)
86+
}
87+
88+
type HardCodedMemoryLimitProvider struct {
89+
Hardcodedlimit int64
90+
}
91+
92+
func (p *HardCodedMemoryLimitProvider) Get() int64 {
93+
return p.Hardcodedlimit
94+
}
95+
96+
func (p *HardCodedMemoryLimitProvider) Set(limit int64) {
97+
p.Hardcodedlimit = limit
98+
}
99+
100+
type MemoryAdmissionMiddleware struct {
101+
config Config
102+
memoryLimit int64 // -1 means no limit
103+
metricsSamples []metrics.Sample
104+
ctx context.Context // to stop the background process
105+
106+
lastMemorySampleInBytes *atomic.Uint64 // atomic because it's written inside a goroutine but can be read from anywhere
107+
timestampLastMemorySample *atomic.Pointer[time.Time]
108+
}
109+
110+
// New creates a new memory admission middleware with the given context.
111+
// Whe the context is cancelled, this middleware stops its background processing.
112+
func New(ctx context.Context, config Config, limitProvider MemoryLimitProvider) MemoryAdmissionMiddleware {
113+
am := MemoryAdmissionMiddleware{
114+
config: config,
115+
lastMemorySampleInBytes: &atomic.Uint64{},
116+
timestampLastMemorySample: &atomic.Pointer[time.Time]{},
117+
memoryLimit: -1, // disabled initially
118+
ctx: ctx,
119+
}
76120

77121
// Get the current GOMEMLIMIT
78-
memoryLimit := debug.SetMemoryLimit(-1)
122+
memoryLimit := limitProvider.Get()
79123
if memoryLimit < 0 {
80124
// If no limit is set, we can't provide memory protection
81125
log.Info().Msg("GOMEMLIMIT not set, memory protection disabled")
82-
return &AdmissionMiddleware{
83-
config: config,
84-
memoryLimit: -1, // Disabled
85-
ctx: mwCtx,
86-
}
126+
return am
87127
}
88128

89-
// Check if memory protection is disabled via config
90129
if config.ThresholdPercent <= 0 {
91130
log.Info().Msg("memory protection disabled via configuration")
92-
return &AdmissionMiddleware{
93-
config: config,
94-
memoryLimit: -1, // Disabled
95-
ctx: mwCtx,
96-
}
131+
return am
97132
}
98133

99-
am := &AdmissionMiddleware{
100-
config: config,
101-
memoryLimit: memoryLimit,
102-
metricsSamples: []metrics.Sample{
103-
{Name: "/memory/classes/heap/objects:bytes"},
104-
},
105-
ctx: mwCtx,
134+
am.memoryLimit = memoryLimit
135+
am.metricsSamples = []metrics.Sample{
136+
{Name: "/memory/classes/heap/objects:bytes"},
106137
}
107138

108139
// Initialize with current memory usage
109-
if err := am.sampleMemory(); err != nil {
110-
log.Warn().Err(err).Msg("failed to get initial memory sample")
111-
}
140+
am.sampleMemory()
112141

113142
// Start background sampling with context
114143
am.startBackgroundSampling()
@@ -122,9 +151,9 @@ func New(ctx context.Context, config Config) *AdmissionMiddleware {
122151
return am
123152
}
124153

125-
// UnaryServerInterceptor returns a unary server interceptor that implements admission control
126-
func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
127-
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
154+
// UnaryServerInterceptor returns a unary server interceptor that rejects incoming requests is memory usage is too high
155+
func (am *MemoryAdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
156+
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
128157
if am.memoryLimit < 0 {
129158
// Memory protection is disabled
130159
return handler(ctx, req)
@@ -139,9 +168,9 @@ func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerIntercep
139168
}
140169
}
141170

142-
// StreamServerInterceptor returns a stream server interceptor that implements admission control
143-
func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
144-
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
171+
// StreamServerInterceptor returns a stream server interceptor that rejects incoming requests is memory usage is too high
172+
func (am *MemoryAdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
173+
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
145174
if am.memoryLimit < 0 {
146175
// Memory protection is disabled
147176
return handler(srv, stream)
@@ -158,8 +187,8 @@ func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterc
158187
}
159188

160189
// checkAdmission determines if a request should be admitted based on current memory usage
161-
func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
162-
memoryUsage := am.getCurrentMemoryUsage()
190+
func (am *MemoryAdmissionMiddleware) checkAdmission(fullMethod string) error {
191+
memoryUsage := am.getLastMemorySampleInBytes()
163192

164193
usagePercent := float64(memoryUsage) / float64(am.memoryLimit) * 100
165194

@@ -181,24 +210,23 @@ func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
181210
}
182211

183212
// startBackgroundSampling starts a background goroutine that samples memory usage periodically
184-
func (am *AdmissionMiddleware) startBackgroundSampling() {
213+
func (am *MemoryAdmissionMiddleware) startBackgroundSampling() {
185214
interval := time.Duration(am.config.SampleIntervalSeconds) * time.Second
186215
ticker := time.NewTicker(interval)
187216

188217
go func() {
189218
defer ticker.Stop()
190-
defer log.Debug().Msg("memory protection background sampling stopped")
191-
192-
log.Debug().
193-
Dur("interval", interval).
194-
Msg("memory protection background sampling started")
219+
// TODO this code might start running before the logger is setup, therefore we have a data race
220+
//defer log.Debug().Msg("memory protection background sampling stopped")
221+
//
222+
//log.Debug().
223+
// Dur("interval", interval).
224+
// Msg("memory protection background sampling started")
195225

196226
for {
197227
select {
198228
case <-ticker.C:
199-
if err := am.sampleMemory(); err != nil {
200-
log.Warn().Err(err).Msg("background memory sampling failed")
201-
}
229+
am.sampleMemory()
202230
case <-am.ctx.Done():
203231
return
204232
}
@@ -207,33 +235,36 @@ func (am *AdmissionMiddleware) startBackgroundSampling() {
207235
}
208236

209237
// sampleMemory samples the current memory usage and updates the cached value
210-
func (am *AdmissionMiddleware) sampleMemory() error {
238+
func (am *MemoryAdmissionMiddleware) sampleMemory() {
211239
defer func() {
212240
if r := recover(); r != nil {
213241
log.Warn().Interface("panic", r).Msg("memory sampling panicked")
214242
}
215243
}()
216244

245+
now := time.Now()
217246
metrics.Read(am.metricsSamples)
218-
newUsage := int64(am.metricsSamples[0].Value.Uint64())
219-
am.lastMemoryUsage.Store(newUsage)
247+
newUsage := am.metricsSamples[0].Value.Uint64()
248+
am.lastMemorySampleInBytes.Store(newUsage)
249+
am.timestampLastMemorySample.Store(&now)
220250

221251
// Update metrics gauge
222252
if am.memoryLimit > 0 {
223253
usagePercent := float64(newUsage) / float64(am.memoryLimit) * 100
224254
MemoryUsageGauge.Set(usagePercent)
225255
}
256+
}
226257

227-
return nil
258+
func (am *MemoryAdmissionMiddleware) getLastMemorySampleInBytes() uint64 {
259+
return am.lastMemorySampleInBytes.Load()
228260
}
229261

230-
// getCurrentMemoryUsage returns the cached memory usage in bytes
231-
func (am *AdmissionMiddleware) getCurrentMemoryUsage() int64 {
232-
return am.lastMemoryUsage.Load()
262+
func (am *MemoryAdmissionMiddleware) getTimestampLastMemorySample() *time.Time {
263+
return am.timestampLastMemorySample.Load()
233264
}
234265

235266
// recordRejection records metrics for rejected requests
236-
func (am *AdmissionMiddleware) recordRejection(fullMethod string) {
267+
func (am *MemoryAdmissionMiddleware) recordRejection(fullMethod string) {
237268
endpointType := "api"
238269
if strings.HasPrefix(fullMethod, "/dispatch.v1.DispatchService") {
239270
endpointType = "dispatch"

0 commit comments

Comments
 (0)