@@ -22,15 +22,15 @@ var (
2222 // RejectedRequestsCounter tracks requests rejected due to memory pressure
2323 RejectedRequestsCounter = promauto .NewCounterVec (prometheus.CounterOpts {
2424 Namespace : "spicedb" ,
25- Subsystem : "admission " ,
26- Name : "memory_overload_rejected_requests_total " ,
25+ Subsystem : "memory_admission " ,
26+ Name : "rejected_requests_total " ,
2727 Help : "Total requests rejected due to memory pressure" ,
2828 }, []string {"endpoint" })
2929
3030 // MemoryUsageGauge tracks current memory usage percentage
3131 MemoryUsageGauge = promauto .NewGauge (prometheus.GaugeOpts {
3232 Namespace : "spicedb" ,
33- Subsystem : "admission " ,
33+ Subsystem : "memory_admission " ,
3434 Name : "memory_usage_percent" ,
3535 Help : "Current memory usage as percentage of GOMEMLIMIT" ,
3636 })
@@ -60,55 +60,84 @@ func DefaultDispatchConfig() Config {
6060 }
6161}
6262
63- // AdmissionMiddleware implements memory-based admission control
64- type AdmissionMiddleware struct {
65- config Config
66- memoryLimit int64
67- lastMemoryUsage atomic.Int64
68- metricsSamples []metrics.Sample
69- ctx context.Context
63+ // MemoryLimitProvider gets and sets the limit of memory usage.
64+ // In production, use DefaultMemoryLimitProvider.
65+ // For testing, use HardCodedMemoryLimitProvider.
66+ type MemoryLimitProvider interface {
67+ Get () int64
68+ Set (int64 )
7069}
7170
72- // New creates a new memory protection middleware with the given context
73- func New (ctx context.Context , config Config ) * AdmissionMiddleware {
74- // Use the provided context directly
75- mwCtx := ctx
71+ var (
72+ _ MemoryLimitProvider = (* DefaultMemoryLimitProvider )(nil )
73+ _ MemoryLimitProvider = (* HardCodedMemoryLimitProvider )(nil )
74+ )
75+
76+ type DefaultMemoryLimitProvider struct {}
77+
78+ func (p * DefaultMemoryLimitProvider ) Get () int64 {
79+ // SetMemoryLimit returns the previously set memory limit.
80+ // A negative input does not adjust the limit, and allows for retrieval of the currently set memory limit
81+ return debug .SetMemoryLimit (- 1 )
82+ }
83+
84+ func (p * DefaultMemoryLimitProvider ) Set (limit int64 ) {
85+ debug .SetMemoryLimit (limit )
86+ }
87+
88+ type HardCodedMemoryLimitProvider struct {
89+ Hardcodedlimit int64
90+ }
91+
92+ func (p * HardCodedMemoryLimitProvider ) Get () int64 {
93+ return p .Hardcodedlimit
94+ }
95+
96+ func (p * HardCodedMemoryLimitProvider ) Set (limit int64 ) {
97+ p .Hardcodedlimit = limit
98+ }
99+
100+ type MemoryAdmissionMiddleware struct {
101+ config Config
102+ memoryLimit int64 // -1 means no limit
103+ metricsSamples []metrics.Sample
104+ ctx context.Context // to stop the background process
105+
106+ lastMemorySampleInBytes * atomic.Uint64 // atomic because it's written inside a goroutine but can be read from anywhere
107+ timestampLastMemorySample * atomic.Pointer [time.Time ]
108+ }
109+
110+ // New creates a new memory admission middleware with the given context.
111+ // Whe the context is cancelled, this middleware stops its background processing.
112+ func New (ctx context.Context , config Config , limitProvider MemoryLimitProvider ) MemoryAdmissionMiddleware {
113+ am := MemoryAdmissionMiddleware {
114+ config : config ,
115+ lastMemorySampleInBytes : & atomic.Uint64 {},
116+ timestampLastMemorySample : & atomic.Pointer [time.Time ]{},
117+ memoryLimit : - 1 , // disabled initially
118+ ctx : ctx ,
119+ }
76120
77121 // Get the current GOMEMLIMIT
78- memoryLimit := debug . SetMemoryLimit ( - 1 )
122+ memoryLimit := limitProvider . Get ( )
79123 if memoryLimit < 0 {
80124 // If no limit is set, we can't provide memory protection
81125 log .Info ().Msg ("GOMEMLIMIT not set, memory protection disabled" )
82- return & AdmissionMiddleware {
83- config : config ,
84- memoryLimit : - 1 , // Disabled
85- ctx : mwCtx ,
86- }
126+ return am
87127 }
88128
89- // Check if memory protection is disabled via config
90129 if config .ThresholdPercent <= 0 {
91130 log .Info ().Msg ("memory protection disabled via configuration" )
92- return & AdmissionMiddleware {
93- config : config ,
94- memoryLimit : - 1 , // Disabled
95- ctx : mwCtx ,
96- }
131+ return am
97132 }
98133
99- am := & AdmissionMiddleware {
100- config : config ,
101- memoryLimit : memoryLimit ,
102- metricsSamples : []metrics.Sample {
103- {Name : "/memory/classes/heap/objects:bytes" },
104- },
105- ctx : mwCtx ,
134+ am .memoryLimit = memoryLimit
135+ am .metricsSamples = []metrics.Sample {
136+ {Name : "/memory/classes/heap/objects:bytes" },
106137 }
107138
108139 // Initialize with current memory usage
109- if err := am .sampleMemory (); err != nil {
110- log .Warn ().Err (err ).Msg ("failed to get initial memory sample" )
111- }
140+ am .sampleMemory ()
112141
113142 // Start background sampling with context
114143 am .startBackgroundSampling ()
@@ -122,9 +151,9 @@ func New(ctx context.Context, config Config) *AdmissionMiddleware {
122151 return am
123152}
124153
125- // UnaryServerInterceptor returns a unary server interceptor that implements admission control
126- func (am * AdmissionMiddleware ) UnaryServerInterceptor () grpc.UnaryServerInterceptor {
127- return func (ctx context.Context , req interface {} , info * grpc.UnaryServerInfo , handler grpc.UnaryHandler ) (interface {} , error ) {
154+ // UnaryServerInterceptor returns a unary server interceptor that rejects incoming requests is memory usage is too high
155+ func (am * MemoryAdmissionMiddleware ) UnaryServerInterceptor () grpc.UnaryServerInterceptor {
156+ return func (ctx context.Context , req any , info * grpc.UnaryServerInfo , handler grpc.UnaryHandler ) (any , error ) {
128157 if am .memoryLimit < 0 {
129158 // Memory protection is disabled
130159 return handler (ctx , req )
@@ -139,9 +168,9 @@ func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerIntercep
139168 }
140169}
141170
142- // StreamServerInterceptor returns a stream server interceptor that implements admission control
143- func (am * AdmissionMiddleware ) StreamServerInterceptor () grpc.StreamServerInterceptor {
144- return func (srv interface {} , stream grpc.ServerStream , info * grpc.StreamServerInfo , handler grpc.StreamHandler ) error {
171+ // StreamServerInterceptor returns a stream server interceptor that rejects incoming requests is memory usage is too high
172+ func (am * MemoryAdmissionMiddleware ) StreamServerInterceptor () grpc.StreamServerInterceptor {
173+ return func (srv any , stream grpc.ServerStream , info * grpc.StreamServerInfo , handler grpc.StreamHandler ) error {
145174 if am .memoryLimit < 0 {
146175 // Memory protection is disabled
147176 return handler (srv , stream )
@@ -158,8 +187,8 @@ func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterc
158187}
159188
160189// checkAdmission determines if a request should be admitted based on current memory usage
161- func (am * AdmissionMiddleware ) checkAdmission (fullMethod string ) error {
162- memoryUsage := am .getCurrentMemoryUsage ()
190+ func (am * MemoryAdmissionMiddleware ) checkAdmission (fullMethod string ) error {
191+ memoryUsage := am .getLastMemorySampleInBytes ()
163192
164193 usagePercent := float64 (memoryUsage ) / float64 (am .memoryLimit ) * 100
165194
@@ -181,24 +210,23 @@ func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error {
181210}
182211
183212// startBackgroundSampling starts a background goroutine that samples memory usage periodically
184- func (am * AdmissionMiddleware ) startBackgroundSampling () {
213+ func (am * MemoryAdmissionMiddleware ) startBackgroundSampling () {
185214 interval := time .Duration (am .config .SampleIntervalSeconds ) * time .Second
186215 ticker := time .NewTicker (interval )
187216
188217 go func () {
189218 defer ticker .Stop ()
190- defer log .Debug ().Msg ("memory protection background sampling stopped" )
191-
192- log .Debug ().
193- Dur ("interval" , interval ).
194- Msg ("memory protection background sampling started" )
219+ // TODO this code might start running before the logger is setup, therefore we have a data race
220+ //defer log.Debug().Msg("memory protection background sampling stopped")
221+ //
222+ //log.Debug().
223+ // Dur("interval", interval).
224+ // Msg("memory protection background sampling started")
195225
196226 for {
197227 select {
198228 case <- ticker .C :
199- if err := am .sampleMemory (); err != nil {
200- log .Warn ().Err (err ).Msg ("background memory sampling failed" )
201- }
229+ am .sampleMemory ()
202230 case <- am .ctx .Done ():
203231 return
204232 }
@@ -207,33 +235,36 @@ func (am *AdmissionMiddleware) startBackgroundSampling() {
207235}
208236
209237// sampleMemory samples the current memory usage and updates the cached value
210- func (am * AdmissionMiddleware ) sampleMemory () error {
238+ func (am * MemoryAdmissionMiddleware ) sampleMemory () {
211239 defer func () {
212240 if r := recover (); r != nil {
213241 log .Warn ().Interface ("panic" , r ).Msg ("memory sampling panicked" )
214242 }
215243 }()
216244
245+ now := time .Now ()
217246 metrics .Read (am .metricsSamples )
218- newUsage := int64 (am .metricsSamples [0 ].Value .Uint64 ())
219- am .lastMemoryUsage .Store (newUsage )
247+ newUsage := am .metricsSamples [0 ].Value .Uint64 ()
248+ am .lastMemorySampleInBytes .Store (newUsage )
249+ am .timestampLastMemorySample .Store (& now )
220250
221251 // Update metrics gauge
222252 if am .memoryLimit > 0 {
223253 usagePercent := float64 (newUsage ) / float64 (am .memoryLimit ) * 100
224254 MemoryUsageGauge .Set (usagePercent )
225255 }
256+ }
226257
227- return nil
258+ func (am * MemoryAdmissionMiddleware ) getLastMemorySampleInBytes () uint64 {
259+ return am .lastMemorySampleInBytes .Load ()
228260}
229261
230- // getCurrentMemoryUsage returns the cached memory usage in bytes
231- func (am * AdmissionMiddleware ) getCurrentMemoryUsage () int64 {
232- return am .lastMemoryUsage .Load ()
262+ func (am * MemoryAdmissionMiddleware ) getTimestampLastMemorySample () * time.Time {
263+ return am .timestampLastMemorySample .Load ()
233264}
234265
235266// recordRejection records metrics for rejected requests
236- func (am * AdmissionMiddleware ) recordRejection (fullMethod string ) {
267+ func (am * MemoryAdmissionMiddleware ) recordRejection (fullMethod string ) {
237268 endpointType := "api"
238269 if strings .HasPrefix (fullMethod , "/dispatch.v1.DispatchService" ) {
239270 endpointType = "dispatch"
0 commit comments