|
| 1 | +package memoryprotection |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "runtime/debug" |
| 6 | + "runtime/metrics" |
| 7 | + "strings" |
| 8 | + "sync/atomic" |
| 9 | + "time" |
| 10 | + |
| 11 | + middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" |
| 12 | + "github.com/prometheus/client_golang/prometheus" |
| 13 | + "github.com/prometheus/client_golang/prometheus/promauto" |
| 14 | + "google.golang.org/grpc" |
| 15 | + "google.golang.org/grpc/codes" |
| 16 | + "google.golang.org/grpc/status" |
| 17 | + |
| 18 | + log "github.com/authzed/spicedb/internal/logging" |
| 19 | +) |
| 20 | + |
var (
	// RejectedRequestsCounter counts requests rejected by the admission
	// middleware due to memory pressure, labeled by endpoint type
	// ("api" or "dispatch" — see recordRejection).
	RejectedRequestsCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "admission",
		Name:      "memory_overload_rejected_requests_total",
		Help:      "Total requests rejected due to memory pressure",
	}, []string{"endpoint"})

	// MemoryUsageGauge exposes the most recently sampled heap usage as a
	// percentage of GOMEMLIMIT. It is refreshed by the background sampling
	// goroutine, not on the request path.
	MemoryUsageGauge = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: "spicedb",
		Subsystem: "admission",
		Name:      "memory_usage_percent",
		Help:      "Current memory usage as percentage of GOMEMLIMIT",
	})
)
| 38 | + |
// Config holds configuration for the memory protection middleware.
type Config struct {
	// ThresholdPercent is the memory usage threshold for requests (0-100),
	// expressed as a percentage of GOMEMLIMIT. A value <= 0 disables
	// memory protection entirely.
	ThresholdPercent int
	// SampleIntervalSeconds controls how often memory usage is sampled by
	// the background goroutine.
	SampleIntervalSeconds int
}
| 46 | + |
| 47 | +// DefaultConfig returns reasonable default configuration for API requests |
| 48 | +func DefaultConfig() Config { |
| 49 | + return Config{ |
| 50 | + ThresholdPercent: 90, |
| 51 | + SampleIntervalSeconds: 1, |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +// DefaultDispatchConfig returns reasonable default configuration for dispatch requests |
| 56 | +func DefaultDispatchConfig() Config { |
| 57 | + return Config{ |
| 58 | + ThresholdPercent: 95, |
| 59 | + SampleIntervalSeconds: 1, |
| 60 | + } |
| 61 | +} |
| 62 | + |
// AdmissionMiddleware implements memory-based admission control: when the
// sampled heap usage exceeds the configured percentage of GOMEMLIMIT,
// incoming gRPC requests are rejected with ResourceExhausted.
type AdmissionMiddleware struct {
	// config holds the rejection threshold and sampling interval.
	config Config
	// memoryLimit is the effective GOMEMLIMIT in bytes; -1 is a sentinel
	// meaning protection is disabled and the interceptors are pass-through.
	memoryLimit int64
	// lastMemoryUsage caches the most recent heap sample (bytes) so the
	// request path only does an atomic load, never a metrics.Read.
	lastMemoryUsage atomic.Int64
	// metricsSamples is the reusable runtime/metrics sample buffer.
	metricsSamples []metrics.Sample
	// ctx stops the background sampling goroutine when cancelled.
	// NOTE(review): storing a context in a struct is discouraged by Go
	// convention; acceptable here since it only bounds the sampler's life.
	ctx context.Context
}
| 71 | + |
| 72 | +// New creates a new memory protection middleware with the given context |
| 73 | +func New(ctx context.Context, config Config) *AdmissionMiddleware { |
| 74 | + // Use the provided context directly |
| 75 | + mwCtx := ctx |
| 76 | + |
| 77 | + // Get the current GOMEMLIMIT |
| 78 | + memoryLimit := debug.SetMemoryLimit(-1) |
| 79 | + if memoryLimit < 0 { |
| 80 | + // If no limit is set, we can't provide memory protection |
| 81 | + log.Info().Msg("GOMEMLIMIT not set, memory protection disabled") |
| 82 | + return &AdmissionMiddleware{ |
| 83 | + config: config, |
| 84 | + memoryLimit: -1, // Disabled |
| 85 | + ctx: mwCtx, |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + // Check if memory protection is disabled via config |
| 90 | + if config.ThresholdPercent <= 0 { |
| 91 | + log.Info().Msg("memory protection disabled via configuration") |
| 92 | + return &AdmissionMiddleware{ |
| 93 | + config: config, |
| 94 | + memoryLimit: -1, // Disabled |
| 95 | + ctx: mwCtx, |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + am := &AdmissionMiddleware{ |
| 100 | + config: config, |
| 101 | + memoryLimit: memoryLimit, |
| 102 | + metricsSamples: []metrics.Sample{ |
| 103 | + {Name: "/memory/classes/heap/objects:bytes"}, |
| 104 | + }, |
| 105 | + ctx: mwCtx, |
| 106 | + } |
| 107 | + |
| 108 | + // Initialize with current memory usage |
| 109 | + if err := am.sampleMemory(); err != nil { |
| 110 | + log.Warn().Err(err).Msg("failed to get initial memory sample") |
| 111 | + } |
| 112 | + |
| 113 | + // Start background sampling with context |
| 114 | + am.startBackgroundSampling() |
| 115 | + |
| 116 | + log.Info(). |
| 117 | + Int64("memory_limit_bytes", memoryLimit). |
| 118 | + Int("threshold_percent", config.ThresholdPercent). |
| 119 | + Int("sample_interval_seconds", config.SampleIntervalSeconds). |
| 120 | + Msg("memory protection middleware initialized with background sampling") |
| 121 | + |
| 122 | + return am |
| 123 | +} |
| 124 | + |
| 125 | +// UnaryServerInterceptor returns a unary server interceptor that implements admission control |
| 126 | +func (am *AdmissionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor { |
| 127 | + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { |
| 128 | + if am.memoryLimit < 0 { |
| 129 | + // Memory protection is disabled |
| 130 | + return handler(ctx, req) |
| 131 | + } |
| 132 | + |
| 133 | + if err := am.checkAdmission(info.FullMethod); err != nil { |
| 134 | + am.recordRejection(info.FullMethod) |
| 135 | + return nil, err |
| 136 | + } |
| 137 | + |
| 138 | + return handler(ctx, req) |
| 139 | + } |
| 140 | +} |
| 141 | + |
| 142 | +// StreamServerInterceptor returns a stream server interceptor that implements admission control |
| 143 | +func (am *AdmissionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor { |
| 144 | + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { |
| 145 | + if am.memoryLimit < 0 { |
| 146 | + // Memory protection is disabled |
| 147 | + return handler(srv, stream) |
| 148 | + } |
| 149 | + |
| 150 | + if err := am.checkAdmission(info.FullMethod); err != nil { |
| 151 | + am.recordRejection(info.FullMethod) |
| 152 | + return err |
| 153 | + } |
| 154 | + |
| 155 | + wrapped := middleware.WrapServerStream(stream) |
| 156 | + return handler(srv, wrapped) |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +// checkAdmission determines if a request should be admitted based on current memory usage |
| 161 | +func (am *AdmissionMiddleware) checkAdmission(fullMethod string) error { |
| 162 | + memoryUsage := am.getCurrentMemoryUsage() |
| 163 | + |
| 164 | + usagePercent := float64(memoryUsage) / float64(am.memoryLimit) * 100 |
| 165 | + |
| 166 | + // Metrics gauge is updated in background sampling |
| 167 | + |
| 168 | + if usagePercent > float64(am.config.ThresholdPercent) { |
| 169 | + log.Warn(). |
| 170 | + Float64("memory_usage_percent", usagePercent). |
| 171 | + Int("threshold_percent", am.config.ThresholdPercent). |
| 172 | + Str("method", fullMethod). |
| 173 | + Msg("rejecting request due to memory pressure") |
| 174 | + |
| 175 | + return status.Errorf(codes.ResourceExhausted, |
| 176 | + "server is experiencing memory pressure (%.1f%% usage, threshold: %d%%)", |
| 177 | + usagePercent, am.config.ThresholdPercent) |
| 178 | + } |
| 179 | + |
| 180 | + return nil |
| 181 | +} |
| 182 | + |
| 183 | +// startBackgroundSampling starts a background goroutine that samples memory usage periodically |
| 184 | +func (am *AdmissionMiddleware) startBackgroundSampling() { |
| 185 | + interval := time.Duration(am.config.SampleIntervalSeconds) * time.Second |
| 186 | + ticker := time.NewTicker(interval) |
| 187 | + |
| 188 | + go func() { |
| 189 | + defer ticker.Stop() |
| 190 | + defer log.Debug().Msg("memory protection background sampling stopped") |
| 191 | + |
| 192 | + log.Debug(). |
| 193 | + Dur("interval", interval). |
| 194 | + Msg("memory protection background sampling started") |
| 195 | + |
| 196 | + for { |
| 197 | + select { |
| 198 | + case <-ticker.C: |
| 199 | + if err := am.sampleMemory(); err != nil { |
| 200 | + log.Warn().Err(err).Msg("background memory sampling failed") |
| 201 | + } |
| 202 | + case <-am.ctx.Done(): |
| 203 | + return |
| 204 | + } |
| 205 | + } |
| 206 | + }() |
| 207 | +} |
| 208 | + |
| 209 | +// sampleMemory samples the current memory usage and updates the cached value |
| 210 | +func (am *AdmissionMiddleware) sampleMemory() error { |
| 211 | + defer func() { |
| 212 | + if r := recover(); r != nil { |
| 213 | + log.Warn().Interface("panic", r).Msg("memory sampling panicked") |
| 214 | + } |
| 215 | + }() |
| 216 | + |
| 217 | + metrics.Read(am.metricsSamples) |
| 218 | + newUsage := int64(am.metricsSamples[0].Value.Uint64()) |
| 219 | + am.lastMemoryUsage.Store(newUsage) |
| 220 | + |
| 221 | + // Update metrics gauge |
| 222 | + if am.memoryLimit > 0 { |
| 223 | + usagePercent := float64(newUsage) / float64(am.memoryLimit) * 100 |
| 224 | + MemoryUsageGauge.Set(usagePercent) |
| 225 | + } |
| 226 | + |
| 227 | + return nil |
| 228 | +} |
| 229 | + |
// getCurrentMemoryUsage returns the most recently cached heap usage in
// bytes. It is a single atomic load, so the request path never blocks on
// metrics collection.
func (am *AdmissionMiddleware) getCurrentMemoryUsage() int64 {
	return am.lastMemoryUsage.Load()
}
| 234 | + |
| 235 | +// recordRejection records metrics for rejected requests |
| 236 | +func (am *AdmissionMiddleware) recordRejection(fullMethod string) { |
| 237 | + endpointType := "api" |
| 238 | + if strings.HasPrefix(fullMethod, "/dispatch.v1.DispatchService") { |
| 239 | + endpointType = "dispatch" |
| 240 | + } |
| 241 | + |
| 242 | + RejectedRequestsCounter.WithLabelValues(endpointType).Inc() |
| 243 | +} |
0 commit comments