@@ -18,27 +18,29 @@ import (
 	log "github.com/authzed/spicedb/internal/logging"
 )
 
+const DefaultSampleIntervalSeconds = 1
+
 var (
 	// RejectedRequestsCounter tracks requests rejected due to memory pressure
 	RejectedRequestsCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 		Namespace: "spicedb",
-		Subsystem: "admission",
-		Name:      "memory_overload_rejected_requests_total",
+		Subsystem: "memory_admission",
+		Name:      "rejected_requests_total",
 		Help:      "Total requests rejected due to memory pressure",
 	}, []string{"endpoint"})
 
 	// MemoryUsageGauge tracks current memory usage percentage
 	MemoryUsageGauge = promauto.NewGauge(prometheus.GaugeOpts{
 		Namespace: "spicedb",
-		Subsystem: "admission",
+		Subsystem: "memory_admission",
 		Name:      "memory_usage_percent",
 		Help:      "Current memory usage as percentage of GOMEMLIMIT",
 	})
 )
 
 // Config holds configuration for the memory protection middleware
 type Config struct {
-	// ThresholdPercent is the memory usage threshold for requests (0-100)
+	// ThresholdPercent is the memory usage threshold for requests. If zero or negative, this middleware has no effect
 	ThresholdPercent int
 	// SampleIntervalSeconds controls how often memory usage is sampled
 	SampleIntervalSeconds int
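With the subsystem rename above, the counter is exported as spicedb_memory_admission_rejected_requests_total and the gauge as spicedb_memory_admission_memory_usage_percent. A minimal sketch of a Config using these two knobs follows; the 90 and 1 values are illustrative assumptions, not the package defaults:

```go
// Illustrative configuration only; the real defaults come from DefaultDispatchConfig.
var exampleConfig = Config{
	ThresholdPercent:      90, // start rejecting new work once heap usage exceeds 90% of GOMEMLIMIT
	SampleIntervalSeconds: 1,  // re-sample heap usage once per second
}
```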
@@ -60,56 +62,90 @@ func DefaultDispatchConfig() Config {
 	}
 }
 
-// AdmissionMiddleware implements memory-based admission control
-type AdmissionMiddleware struct {
-	config          Config
-	memoryLimit     int64
-	lastMemoryUsage atomic.Int64
-	metricsSamples  []metrics.Sample
-	ctx             context.Context
+// MemoryLimitProvider gets and sets the limit of memory usage.
+// In production, use DefaultMemoryLimitProvider.
+// For testing, use HardCodedMemoryLimitProvider.
+type MemoryLimitProvider interface {
+	Get() int64
+	Set(int64)
+}
+
+var (
+	_ MemoryLimitProvider = (*DefaultMemoryLimitProvider)(nil)
+	_ MemoryLimitProvider = (*HardCodedMemoryLimitProvider)(nil)
+)
+
+type DefaultMemoryLimitProvider struct{}
+
+func (p *DefaultMemoryLimitProvider) Get() int64 {
+	// SetMemoryLimit returns the previously set memory limit.
+	// A negative input does not adjust the limit, and allows for retrieval of the currently set memory limit
+	return debug.SetMemoryLimit(-1)
+}
+
+func (p *DefaultMemoryLimitProvider) Set(limit int64) {
+	debug.SetMemoryLimit(limit)
+}
+
+type HardCodedMemoryLimitProvider struct {
+	Hardcodedlimit int64
 }
 
-// New creates a new memory protection middleware with the given context
-func New(ctx context.Context, config Config) *AdmissionMiddleware {
-	// Use the provided context directly
-	mwCtx := ctx
+func (p *HardCodedMemoryLimitProvider) Get() int64 {
+	return p.Hardcodedlimit
+}
+
+func (p *HardCodedMemoryLimitProvider) Set(limit int64) {
+	p.Hardcodedlimit = limit
+}
+
+type MemoryAdmissionMiddleware struct {
+	config         Config
+	memoryLimit    int64 // -1 means no limit
+	metricsSamples []metrics.Sample
+	ctx            context.Context // to stop the background process
+
+	lastMemorySampleInBytes   *atomic.Uint64 // atomic because it's written inside a goroutine but can be read from anywhere
+	timestampLastMemorySample *atomic.Pointer[time.Time]
+}
+
+// New creates a new memory admission middleware with the given context.
+// When the context is cancelled, this middleware stops its background processing.
+func New(ctx context.Context, config Config, limitProvider MemoryLimitProvider) MemoryAdmissionMiddleware {
+	am := MemoryAdmissionMiddleware{
+		config:                    config,
+		lastMemorySampleInBytes:   &atomic.Uint64{},
+		timestampLastMemorySample: &atomic.Pointer[time.Time]{},
+		memoryLimit:               -1, // disabled initially
+		ctx:                       ctx,
+	}
 
 	// Get the current GOMEMLIMIT
-	memoryLimit := debug.SetMemoryLimit(-1)
+	memoryLimit := limitProvider.Get()
 	if memoryLimit < 0 {
 		// If no limit is set, we can't provide memory protection
 		log.Info().Msg("GOMEMLIMIT not set, memory protection disabled")
-		return &AdmissionMiddleware{
-			config:      config,
-			memoryLimit: -1, // Disabled
-			ctx:         mwCtx,
-		}
+		return am
 	}
 
-	// Check if memory protection is disabled via config
 	if config.ThresholdPercent <= 0 {
 		log.Info().Msg("memory protection disabled via configuration")
-		return &AdmissionMiddleware{
-			config:      config,
-			memoryLimit: -1, // Disabled
-			ctx:         mwCtx,
-		}
+		return am
 	}
 
-	am := &AdmissionMiddleware{
-		config:      config,
-		memoryLimit: memoryLimit,
-		metricsSamples: []metrics.Sample{
-			{Name: "/memory/classes/heap/objects:bytes"},
-		},
-		ctx: mwCtx,
+	if config.SampleIntervalSeconds <= 0 {
+		log.Info().Msgf("memory protection sample interval cannot be zero or negative; using default value of %d seconds", DefaultSampleIntervalSeconds)
+		am.config.SampleIntervalSeconds = DefaultSampleIntervalSeconds
 	}
 
-	// Initialize with current memory usage
-	if err := am.sampleMemory(); err != nil {
-		log.Warn().Err(err).Msg("failed to get initial memory sample")
+	am.memoryLimit = memoryLimit
+	am.metricsSamples = []metrics.Sample{
+		{Name: "/memory/classes/heap/objects:bytes"},
 	}
 
+	// Initialize with current memory usage
+	am.sampleMemory()
+
 	// Start background sampling with context
 	am.startBackgroundSampling()
 
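A sketch of how the new constructor signature and the two limit providers might be used; the exampleWiring helper, the 80% threshold, and the 512 MiB limit are illustrative assumptions, not part of this change:

```go
// exampleWiring builds one middleware against the real GOMEMLIMIT and one with a
// pinned limit so admission decisions are deterministic in tests.
func exampleWiring(ctx context.Context) (prod, test MemoryAdmissionMiddleware) {
	prod = New(ctx, DefaultDispatchConfig(), &DefaultMemoryLimitProvider{})
	test = New(ctx, Config{ThresholdPercent: 80, SampleIntervalSeconds: 1},
		&HardCodedMemoryLimitProvider{Hardcodedlimit: 512 * 1024 * 1024})
	return prod, test
}
```

Because New returns MemoryAdmissionMiddleware by value, the atomic sample fields being pointers means copies of the struct still observe the same cached sample.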
@@ -122,9 +158,9 @@ func New(ctx context.Context, config Config) *AdmissionMiddleware {
 	return am
 }
 
-// UnaryServerInterceptor returns a unary server interceptor that implements admission control
-func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
-	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+// UnaryServerInterceptor returns a unary server interceptor that rejects incoming requests if memory usage is too high
+func (am *MemoryAdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
 		if am.memoryLimit < 0 {
 			// Memory protection is disabled
 			return handler(ctx, req)
@@ -139,9 +175,9 @@ func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerIntercep
 	}
 }
 
-// StreamServerInterceptor returns a stream server interceptor that implements admission control
-func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
-	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+// StreamServerInterceptor returns a stream server interceptor that rejects incoming requests if memory usage is too high
+func (am *MemoryAdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
+	return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
 		if am.memoryLimit < 0 {
 			// Memory protection is disabled
 			return handler(srv, stream)
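A sketch of how both interceptors might be attached to a gRPC server; newServerWithAdmission is an assumed helper and this wiring is not part of the diff, though the chaining options are standard grpc-go calls:

```go
// newServerWithAdmission returns a server whose unary and stream paths both go
// through the memory admission check before reaching their handlers.
func newServerWithAdmission(am MemoryAdmissionMiddleware) *grpc.Server {
	return grpc.NewServer(
		grpc.ChainUnaryInterceptor(am.UnaryServerInterceptor()),
		grpc.ChainStreamInterceptor(am.StreamServerInterceptor()),
	)
}
```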
@@ -158,8 +194,8 @@ func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterc
 }
 
 // checkAdmission determines if a request should be admitted based on current memory usage
-func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
-	memoryUsage := am.getCurrentMemoryUsage()
+func (am *MemoryAdmissionMiddleware) checkAdmission(fullMethod string) error {
+	memoryUsage := am.getLastMemorySampleInBytes()
 
 	usagePercent := float64(memoryUsage) / float64(am.memoryLimit) * 100
 
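As a worked example of this check (the numbers are illustrative): with GOMEMLIMIT set to 8 GiB and a last heap sample of 6 GiB, usagePercent = 6 / 8 * 100 = 75, so a ThresholdPercent of 80 still admits the request while a ThresholdPercent of 70 rejects it.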
@@ -181,24 +217,23 @@ func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
 	}
 }
 
 // startBackgroundSampling starts a background goroutine that samples memory usage periodically
-func (am *AdmissionMiddleware) startBackgroundSampling() {
+func (am *MemoryAdmissionMiddleware) startBackgroundSampling() {
 	interval := time.Duration(am.config.SampleIntervalSeconds) * time.Second
 	ticker := time.NewTicker(interval)
 
 	go func() {
 		defer ticker.Stop()
-		defer log.Debug().Msg("memory protection background sampling stopped")
-
-		log.Debug().
-			Dur("interval", interval).
-			Msg("memory protection background sampling started")
+		// TODO this code might start running before the logger is set up, therefore we have a data race
+		//defer log.Debug().Msg("memory protection background sampling stopped")
+		//
+		//log.Debug().
+		//	Dur("interval", interval).
+		//	Msg("memory protection background sampling started")
 
 		for {
 			select {
 			case <-ticker.C:
-				if err := am.sampleMemory(); err != nil {
-					log.Warn().Err(err).Msg("background memory sampling failed")
-				}
+				am.sampleMemory()
 			case <-am.ctx.Done():
 				return
 			}
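A sketch of the sampler's lifecycle; exampleShutdown is a made-up helper to illustrate that cancelling the constructor's context, rather than calling a Stop method, is what ends this goroutine:

```go
func exampleShutdown() {
	ctx, cancel := context.WithCancel(context.Background())
	am := New(ctx, DefaultDispatchConfig(), &DefaultMemoryLimitProvider{})

	// ... attach am's interceptors and serve traffic ...

	// Cancelling the context makes the goroutine return via <-am.ctx.Done();
	// admission checks keep using the last cached sample after that.
	cancel()
	_ = am
}
```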
@@ -207,33 +242,36 @@ func (am *AdmissionMiddleware) startBackgroundSampling() {
 }
 
 // sampleMemory samples the current memory usage and updates the cached value
-func (am *AdmissionMiddleware) sampleMemory() error {
+func (am *MemoryAdmissionMiddleware) sampleMemory() {
 	defer func() {
 		if r := recover(); r != nil {
 			log.Warn().Interface("panic", r).Msg("memory sampling panicked")
 		}
 	}()
 
+	now := time.Now()
 	metrics.Read(am.metricsSamples)
-	newUsage := int64(am.metricsSamples[0].Value.Uint64())
-	am.lastMemoryUsage.Store(newUsage)
+	newUsage := am.metricsSamples[0].Value.Uint64()
+	am.lastMemorySampleInBytes.Store(newUsage)
+	am.timestampLastMemorySample.Store(&now)
 
 	// Update metrics gauge
 	if am.memoryLimit > 0 {
 		usagePercent := float64(newUsage) / float64(am.memoryLimit) * 100
 		MemoryUsageGauge.Set(usagePercent)
 	}
+}
 
-	return nil
+func (am *MemoryAdmissionMiddleware) getLastMemorySampleInBytes() uint64 {
+	return am.lastMemorySampleInBytes.Load()
 }
 
-// getCurrentMemoryUsage returns the cached memory usage in bytes
-func (am *AdmissionMiddleware) getCurrentMemoryUsage() int64 {
-	return am.lastMemoryUsage.Load()
+func (am *MemoryAdmissionMiddleware) getTimestampLastMemorySample() *time.Time {
+	return am.timestampLastMemorySample.Load()
 }
 
 // recordRejection records metrics for rejected requests
-func (am *AdmissionMiddleware) recordRejection(fullMethod string) {
+func (am *MemoryAdmissionMiddleware) recordRejection(fullMethod string) {
 	endpointType := "api"
 	if strings.HasPrefix(fullMethod, "/dispatch.v1.DispatchService") {
 		endpointType = "dispatch"