@@ -18,27 +18,29 @@ import (
1818 log "github.com/authzed/spicedb/internal/logging"
1919)
2020
21+ const DefaultSampleIntervalSeconds = 1 // DefaultSampleIntervalSeconds is the sampling period, in seconds, that New substitutes when Config.SampleIntervalSeconds is zero or negative.
22+
2123var (
2224 // RejectedRequestsCounter counts requests rejected due to memory pressure, labelled by endpoint type.
2325 RejectedRequestsCounter = promauto .NewCounterVec (prometheus.CounterOpts {
2426 Namespace : "spicedb" ,
25- Subsystem : "admission " ,
26- Name : "memory_overload_rejected_requests_total " ,
27+ Subsystem : "memory_admission" ,
28+ Name : "rejected_requests_total" ,
2729 Help : "Total requests rejected due to memory pressure" ,
2830 }, []string {"endpoint" })
2931
3032 // MemoryUsageGauge tracks current memory usage as a percentage of the memory limit.
3133 MemoryUsageGauge = promauto .NewGauge (prometheus.GaugeOpts {
3234 Namespace : "spicedb" ,
33- Subsystem : "admission " ,
35+ Subsystem : "memory_admission" ,
3436 Name : "memory_usage_percent" ,
3537 Help : "Current memory usage as percentage of GOMEMLIMIT" ,
3638 })
3739)
3840
3941// Config holds configuration for the memory protection middleware
4042type Config struct {
41- // ThresholdPercent is the memory usage threshold for requests (0-100)
43+ // ThresholdPercent is the memory usage threshold for requests. If zero or negative, this middleware has no effect
4244 ThresholdPercent int
4345 // SampleIntervalSeconds controls how often memory usage is sampled
4446 SampleIntervalSeconds int
@@ -60,60 +62,94 @@ func DefaultDispatchConfig() Config {
6062 }
6163}
6264
63- // AdmissionMiddleware implements memory-based admission control
64- type AdmissionMiddleware struct {
65- config Config
66- memoryLimit int64
67- lastMemoryUsage atomic.Int64
68- metricsSamples []metrics.Sample
69- ctx context.Context
65+ // MemoryLimitProvider gets and sets the limit of memory usage.
66+ // In production, use DefaultMemoryLimitProvider.
67+ // For testing, use HardCodedMemoryLimitProvider.
68+ type MemoryLimitProvider interface {
69+ Get () int64 // Get returns the current limit; New treats a negative value as "no limit configured".
70+ Set (int64 ) // Set replaces the current limit with the given value.
71+ }
72+
73+ var ( // compile-time assertions that both provider types satisfy MemoryLimitProvider
74+ _ MemoryLimitProvider = (* DefaultMemoryLimitProvider )(nil )
75+ _ MemoryLimitProvider = (* HardCodedMemoryLimitProvider )(nil )
76+ )
77+
78+ type DefaultMemoryLimitProvider struct {} // DefaultMemoryLimitProvider is a stateless MemoryLimitProvider backed by debug.SetMemoryLimit.
79+
80+ func (p * DefaultMemoryLimitProvider ) Get () int64 { // Get reports the currently configured memory limit without changing it.
81+ // SetMemoryLimit returns the previously set memory limit.
82+ // A negative input does not adjust the limit, which makes this call a pure read of the currently set memory limit.
83+ return debug .SetMemoryLimit (- 1 )
84+ }
85+
86+ func (p * DefaultMemoryLimitProvider ) Set (limit int64 ) { // Set passes limit to debug.SetMemoryLimit; the previous limit that call returns is discarded.
87+ debug .SetMemoryLimit (limit )
88+ }
89+
90+ type HardCodedMemoryLimitProvider struct { // HardCodedMemoryLimitProvider is a MemoryLimitProvider that keeps the limit in a plain struct field.
91+ Hardcodedlimit int64 // value returned by Get and overwritten by Set
92+ }
93+
94+ func (p * HardCodedMemoryLimitProvider ) Get () int64 { // Get returns the stored Hardcodedlimit value.
95+ return p .Hardcodedlimit
96+ }
97+
98+ func (p * HardCodedMemoryLimitProvider ) Set (limit int64 ) { // Set overwrites Hardcodedlimit with a plain, unsynchronized assignment.
99+ p .Hardcodedlimit = limit
100+ }
101+
102+ type MemoryAdmissionMiddleware struct { // MemoryAdmissionMiddleware rejects gRPC requests while the last sampled memory usage exceeds a configured percentage of the memory limit.
103+ config Config // threshold and sampling settings
104+ memoryLimit int64 // -1 means no limit
105+ metricsSamples []metrics.Sample // buffer handed to metrics.Read on every sample; element 0 holds the value that is cached
106+ ctx context.Context // to stop the background process
107+
108+ lastMemorySampleInBytes * atomic.Uint64 // atomic because it's written inside a goroutine but can be read from anywhere
109+ timestampLastMemorySample * atomic.Pointer [time.Time ] // time captured just before the read that produced lastMemorySampleInBytes
70110}
71111
72- // New creates a new memory protection middleware with the given context
73- func New (ctx context.Context , config Config ) * AdmissionMiddleware {
74- // Use the provided context directly
75- mwCtx := ctx
112+ // New creates a new memory admission middleware with the given context.
113+ // When the context is cancelled, this middleware stops its background processing.
114+ func New (ctx context.Context , config Config , limitProvider MemoryLimitProvider , name string ) MemoryAdmissionMiddleware {
115+ am := MemoryAdmissionMiddleware {
116+ config : config ,
117+ lastMemorySampleInBytes : & atomic.Uint64 {},
118+ timestampLastMemorySample : & atomic.Pointer [time.Time ]{},
119+ memoryLimit : - 1 , // disabled initially
120+ ctx : ctx ,
121+ }
76122
77123 // Get the current GOMEMLIMIT
78- memoryLimit := debug . SetMemoryLimit ( - 1 )
124+ memoryLimit := limitProvider . Get ( )
79125 if memoryLimit < 0 {
80- // If no limit is set, we can't provide memory protection
81- log .Info ().Msg ("GOMEMLIMIT not set, memory protection disabled" )
82- return & AdmissionMiddleware {
83- config : config ,
84- memoryLimit : - 1 , // Disabled
85- ctx : mwCtx ,
86- }
126+ log .Info ().Str ("name" , name ).Msg ("GOMEMLIMIT not set, memory protection disabled" )
127+ return am
87128 }
88129
89- // Check if memory protection is disabled via config
90130 if config .ThresholdPercent <= 0 {
91- log .Info ().Msg ("memory protection disabled via configuration" )
92- return & AdmissionMiddleware {
93- config : config ,
94- memoryLimit : - 1 , // Disabled
95- ctx : mwCtx ,
96- }
131+ log .Info ().Str ("name" , name ).Msg ("threshold is non-positive; memory protection disabled" )
132+ return am
97133 }
98134
99- am := & AdmissionMiddleware {
100- config : config ,
101- memoryLimit : memoryLimit ,
102- metricsSamples : []metrics.Sample {
103- {Name : "/memory/classes/heap/objects:bytes" },
104- },
105- ctx : mwCtx ,
135+ if config .SampleIntervalSeconds <= 0 { // a non-positive period would make time.NewTicker panic, so fall back to the default
136+ log .Info ().Str ("name" , name ).Msgf ("memory protection sample interval cannot be zero or negative; using default value of %d seconds" , DefaultSampleIntervalSeconds )
137+ am .config .SampleIntervalSeconds = DefaultSampleIntervalSeconds
106138 }
107139
108- // Initialize with current memory usage
109- if err := am .sampleMemory (); err != nil {
110- log . Warn (). Err ( err ). Msg ( "failed to get initial memory sample" )
140+ am . memoryLimit = memoryLimit
141+ am .metricsSamples = []metrics. Sample {
142+ { Name : "/memory/classes/heap/objects:bytes" }, // must match a supported runtime/metrics name exactly, otherwise the sample has no usable value
111143 }
112144
145+ // Initialize with current memory usage
146+ am .sampleMemory ()
147+
113148 // Start background sampling with context
114149 am .startBackgroundSampling ()
115150
116151 log .Info ().
152+ Str ("name" , name ).
117153 Int64 ("memory_limit_bytes" , memoryLimit ).
118154 Int ("threshold_percent" , config .ThresholdPercent ).
119155 Int ("sample_interval_seconds" , config .SampleIntervalSeconds ).
@@ -122,33 +158,23 @@ func New(ctx context.Context, config Config) *AdmissionMiddleware {
122158 return am
123159}
124160
125- // UnaryServerInterceptor returns a unary server interceptor that implements admission control
126- func (am * AdmissionMiddleware ) UnaryServerInterceptor () grpc.UnaryServerInterceptor {
127- return func (ctx context.Context , req interface {}, info * grpc.UnaryServerInfo , handler grpc.UnaryHandler ) (interface {}, error ) {
128- if am .memoryLimit < 0 {
129- // Memory protection is disabled
130- return handler (ctx , req )
131- }
132-
133- if err := am .checkAdmission (info .FullMethod ); err != nil {
134- am .recordRejection (info .FullMethod )
161+ // UnaryServerInterceptor returns a unary server interceptor that rejects incoming requests if memory usage is too high
162+ func (am * MemoryAdmissionMiddleware ) UnaryServerInterceptor () grpc.UnaryServerInterceptor {
163+ return func (ctx context.Context , req any , info * grpc.UnaryServerInfo , handler grpc.UnaryHandler ) (any , error ) {
164+ if err := am .checkAdmission (); err != nil {
165+ _ = am .recordRejection (info .FullMethod ) // the returned endpoint label is not needed here; only the counter increment matters
135166 return nil , err
136167 }
137168
138169 return handler (ctx , req )
139170 }
140171}
141172
142- // StreamServerInterceptor returns a stream server interceptor that implements admission control
143- func (am * AdmissionMiddleware ) StreamServerInterceptor () grpc.StreamServerInterceptor {
144- return func (srv interface {}, stream grpc.ServerStream , info * grpc.StreamServerInfo , handler grpc.StreamHandler ) error {
145- if am .memoryLimit < 0 {
146- // Memory protection is disabled
147- return handler (srv , stream )
148- }
149-
150- if err := am .checkAdmission (info .FullMethod ); err != nil {
151- am .recordRejection (info .FullMethod )
173+ // StreamServerInterceptor returns a stream server interceptor that rejects incoming requests if memory usage is too high
174+ func (am * MemoryAdmissionMiddleware ) StreamServerInterceptor () grpc.StreamServerInterceptor {
175+ return func (srv any , stream grpc.ServerStream , info * grpc.StreamServerInfo , handler grpc.StreamHandler ) error {
176+ if err := am .checkAdmission (); err != nil {
177+ _ = am .recordRejection (info .FullMethod )
152178 return err
153179 }
154180
@@ -157,21 +183,20 @@ func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterc
157183 }
158184}
159185
160- // checkAdmission determines if a request should be admitted based on current memory usage
161- func (am * AdmissionMiddleware ) checkAdmission (fullMethod string ) error {
162- memoryUsage := am .getCurrentMemoryUsage ()
186+ // checkAdmission returns an error if the request should be denied because memory usage is too high.
187+ func (am * MemoryAdmissionMiddleware ) checkAdmission () error {
188+ if am .memoryLimit < 0 || am .config .ThresholdPercent <= 0 {
189+ // Memory protection is disabled
190+ return nil
191+ }
192+
193+ memoryUsage := am .getLastMemorySampleInBytes ()
163194
164195 usagePercent := float64 (memoryUsage ) / float64 (am .memoryLimit ) * 100
165196
166197 // Metrics gauge is updated in background sampling
167198
168199 if usagePercent > float64 (am .config .ThresholdPercent ) {
169- log .Warn ().
170- Float64 ("memory_usage_percent" , usagePercent ).
171- Int ("threshold_percent" , am .config .ThresholdPercent ).
172- Str ("method" , fullMethod ).
173- Msg ("rejecting request due to memory pressure" )
174-
175200 return status .Errorf (codes .ResourceExhausted ,
176201 "server is experiencing memory pressure (%.1f%% usage, threshold: %d%%)" ,
177202 usagePercent , am .config .ThresholdPercent )
@@ -181,24 +206,23 @@ func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
181206}
182207
183208// startBackgroundSampling starts a background goroutine that samples memory usage periodically
184- func (am * AdmissionMiddleware ) startBackgroundSampling () {
209+ func (am * MemoryAdmissionMiddleware ) startBackgroundSampling () {
185210 interval := time .Duration (am .config .SampleIntervalSeconds ) * time .Second
186211 ticker := time .NewTicker (interval )
187212
188213 go func () {
189214 defer ticker .Stop ()
190- defer log .Debug ().Msg ("memory protection background sampling stopped" )
191-
192- log .Debug ().
193- Dur ("interval" , interval ).
194- Msg ("memory protection background sampling started" )
215+ // TODO this code might start running before the logger is setup, therefore we have a data race with the logger object :/
216+ //defer log.Debug().Msg("memory protection background sampling stopped")
217+ //
218+ //log.Debug().
219+ // Dur("interval", interval).
220+ // Msg("memory protection background sampling started")
195221
196222 for {
197223 select {
198224 case <- ticker .C :
199- if err := am .sampleMemory (); err != nil {
200- log .Warn ().Err (err ).Msg ("background memory sampling failed" )
201- }
225+ am .sampleMemory ()
202226 case <- am .ctx .Done ():
203227 return
204228 }
@@ -207,37 +231,41 @@ func (am *AdmissionMiddleware) startBackgroundSampling() {
207231}
208232
209233// sampleMemory reads the configured runtime metric, caches its value together with the sample time, and refreshes MemoryUsageGauge. A panic while sampling is logged and swallowed.
210- func (am * AdmissionMiddleware ) sampleMemory () error {
234+ func (am * MemoryAdmissionMiddleware ) sampleMemory () {
211235 defer func () {
212236 if r := recover (); r != nil {
213237 log .Warn ().Interface ("panic" , r ).Msg ("memory sampling panicked" )
214238 }
215239 }()
216240
241+ now := time .Now () // captured before the read, so the stored timestamp never post-dates the sample
217242 metrics .Read (am .metricsSamples )
218- newUsage := int64 (am .metricsSamples [0 ].Value .Uint64 ())
219- am .lastMemoryUsage .Store (newUsage )
243+ newUsage := am .metricsSamples [0 ].Value .Uint64 ()
244+ am .lastMemorySampleInBytes .Store (newUsage )
245+ am .timestampLastMemorySample .Store (& now )
220246
221247 // Update metrics gauge; skipped when the limit is zero or negative to avoid a meaningless percentage
222248 if am .memoryLimit > 0 {
223249 usagePercent := float64 (newUsage ) / float64 (am .memoryLimit ) * 100
224250 MemoryUsageGauge .Set (usagePercent )
225251 }
252+ }
226253
227- return nil
254+ func (am * MemoryAdmissionMiddleware ) getLastMemorySampleInBytes () uint64 { // returns the most recently cached sample; zero until the first sample has been stored
255+ return am .lastMemorySampleInBytes .Load ()
228256}
229257
230- // getCurrentMemoryUsage returns the cached memory usage in bytes
231- func (am * AdmissionMiddleware ) getCurrentMemoryUsage () int64 {
232- return am .lastMemoryUsage .Load ()
258+ func (am * MemoryAdmissionMiddleware ) getTimestampLastMemorySample () * time.Time { // returns the time recorded with the cached sample; nil until the first sample has been stored
259+ return am .timestampLastMemorySample .Load ()
233260}
234261
235262// recordRejection increments RejectedRequestsCounter for a rejected call and returns the endpoint label it used: "dispatch" when fullMethod starts with the dispatch service prefix, "api" otherwise.
236- func (am * AdmissionMiddleware ) recordRejection (fullMethod string ) {
263+ func (am * MemoryAdmissionMiddleware ) recordRejection (fullMethod string ) string {
237264 endpointType := "api"
238265 if strings .HasPrefix (fullMethod , "/dispatch.v1.DispatchService" ) {
239266 endpointType = "dispatch"
240267 }
241268
242269 RejectedRequestsCounter .WithLabelValues (endpointType ).Inc ()
270+ return endpointType
243271}
0 commit comments