Skip to content

Commit 54da7bb

Browse files
authored
Merge pull request #1314 from nsakharenko/feat/loadhed-add-custom-onshed-handler
feat(): add possibility to define custom OnShed handler;
2 parents 5362f24 + 9a53f2d commit 54da7bb

File tree

4 files changed

+239
-33
lines changed

4 files changed

+239
-33
lines changed

go.work

Lines changed: 0 additions & 23 deletions
This file was deleted.

loadshed/README.md

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,45 @@ loadshed.New(config ...loadshed.Config) fiber.Handler
3131

3232
To use the LoadShed middleware in your Fiber application, import it and apply it to your Fiber app. Here's an example:
3333

34+
### Basic
35+
36+
```go
37+
package main
38+
39+
import (
40+
"time"
41+
"github.com/gofiber/fiber/v2"
42+
loadshed "github.com/gofiber/contrib/loadshed"
43+
)
44+
45+
func main() {
46+
app := fiber.New()
47+
48+
// Configure and use LoadShed middleware
49+
app.Use(loadshed.New(loadshed.Config{
50+
Criteria: &loadshed.CPULoadCriteria{
51+
LowerThreshold: 0.75, // Set your own lower threshold
52+
UpperThreshold: 0.90, // Set your own upper threshold
53+
Interval: 10 * time.Second,
54+
Getter: &loadshed.DefaultCPUPercentGetter{},
55+
},
56+
}))
57+
58+
app.Get("/", func(c *fiber.Ctx) error {
59+
return c.SendString("Welcome!")
60+
})
61+
62+
app.Listen(":3000")
63+
}
64+
```
65+
66+
### With a custom rejection handler
67+
3468
```go
3569
package main
3670

3771
import (
72+
"time"
3873
"github.com/gofiber/fiber/v2"
3974
loadshed "github.com/gofiber/contrib/loadshed"
4075
)
@@ -50,6 +85,19 @@ func main() {
5085
Interval: 10 * time.Second,
5186
Getter: &loadshed.DefaultCPUPercentGetter{},
5287
},
88+
OnShed: func(ctx *fiber.Ctx) error {
89+
if ctx.Method() == fiber.MethodGet {
90+
return ctx.
91+
Status(fiber.StatusTooManyRequests).
92+
Send([]byte{})
93+
}
94+
95+
return ctx.
96+
Status(fiber.StatusTooManyRequests).
97+
JSON(fiber.Map{
98+
"error": "Keep calm",
99+
})
100+
},
53101
}))
54102

55103
app.Get("/", func(c *fiber.Ctx) error {
@@ -64,10 +112,11 @@ func main() {
64112

65113
The LoadShed middleware in Fiber offers various configuration options to tailor the load shedding behavior according to the needs of your application.
66114

67-
| Property | Type | Description | Default |
68-
| :------- | :---------------------- | :--------------------------------------------------- | :---------------------- |
69-
| Next | `func(*fiber.Ctx) bool` | Function to skip this middleware when returned true. | `nil` |
70-
| Criteria | `LoadCriteria` | Interface for defining load shedding criteria. | `&CPULoadCriteria{...}` |
115+
| Property | Type | Description | Default |
116+
|:---------|:---------------------------|:--------------------------------------------------------|:------------------------|
117+
| Next | `func(*fiber.Ctx) bool` | Function to skip this middleware when returned true. | `nil` |
118+
| Criteria | `LoadCriteria` | Interface for defining load shedding criteria. | `&CPULoadCriteria{...}` |
119+
| OnShed | `func(c *fiber.Ctx) error` | Function to be executed if a request should be declined | `nil` |
71120

72121
## LoadCriteria
73122

@@ -80,7 +129,7 @@ LoadCriteria is an interface in the LoadShed middleware that defines the criteri
80129
#### Properties
81130

82131
| Property | Type | Description |
83-
| :------------- | :----------------- | :------------------------------------------------------------------------------------------------------------------------------------ |
132+
|:---------------|:-------------------|:--------------------------------------------------------------------------------------------------------------------------------------|
84133
| LowerThreshold | `float64` | The lower CPU usage threshold as a fraction (0.0 to 1.0). Requests are considered for shedding when CPU usage exceeds this threshold. |
85134
| UpperThreshold | `float64` | The upper CPU usage threshold as a fraction (0.0 to 1.0). All requests are shed when CPU usage exceeds this threshold. |
86135
| Interval | `time.Duration` | The time interval over which the CPU usage is averaged for decision making. |
@@ -110,10 +159,11 @@ This is the default configuration for `LoadCriteria` in the LoadShed middleware.
110159
var ConfigDefault = Config{
111160
Next: nil,
112161
Criteria: &CPULoadCriteria{
113-
LowerThreshold: 0.90, // 90% CPU usage as the start point for considering shedding
114-
UpperThreshold: 0.95, // 95% CPU usage as the point where all requests are shed
115-
Interval: 10 * time.Second, // CPU usage is averaged over 10 seconds
116-
Getter: &DefaultCPUPercentGetter{}, // Default method for getting CPU usage
117-
},
162+
LowerThreshold: 0.90, // 90% CPU usage as the start point for considering shedding
163+
UpperThreshold: 0.95, // 95% CPU usage as the point where all requests are shed
164+
Interval: 10 * time.Second, // CPU usage is averaged over 10 seconds
165+
Getter: &DefaultCPUPercentGetter{}, // Default method for getting CPU usage
166+
},
167+
OnShed: nil,
118168
}
119169
```

loadshed/loadshed.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ type Config struct {
1212

1313
// Criteria defines the criteria to be used for load shedding.
1414
Criteria LoadCriteria
15+
16+
// OnShed defines a custom handler that will be executed if a request should
17+
// be rejected.
18+
//
19+
// Returning `nil` without writing to the response context allows the
20+
// request to proceed to the next handler
21+
OnShed func(c *fiber.Ctx) error
1522
}
1623

1724
var ConfigDefault = Config{
@@ -45,6 +52,11 @@ func New(config ...Config) fiber.Handler {
4552

4653
// Shed load if the criteria's ShouldShed method returns true
4754
if cfg.Criteria.ShouldShed(metric) {
55+
// Call the custom OnShed function
56+
if cfg.OnShed != nil {
57+
return cfg.OnShed(c)
58+
}
59+
4860
return fiber.NewError(fiber.StatusServiceUnavailable)
4961
}
5062

loadshed/loadshed_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package loadshed
22

33
import (
44
"context"
5+
"io"
56
"net/http/httptest"
67
"testing"
78
"time"
@@ -98,3 +99,169 @@ func Test_Loadshed_UpperThreshold(t *testing.T) {
9899
utils.AssertEqual(t, nil, err)
99100
utils.AssertEqual(t, fiber.StatusServiceUnavailable, resp.StatusCode)
100101
}
102+
103+
func Test_Loadshed_CustomOnShed(t *testing.T) {
104+
app := fiber.New()
105+
106+
mockGetter := &MockCPUPercentGetter{MockedPercentage: []float64{96.0}}
107+
var cfg Config
108+
cfg.Criteria = &CPULoadCriteria{
109+
LowerThreshold: 0.90,
110+
UpperThreshold: 0.95,
111+
Interval: time.Second,
112+
Getter: mockGetter,
113+
}
114+
cfg.OnShed = func(c *fiber.Ctx) error {
115+
return c.Status(fiber.StatusTooManyRequests).Send([]byte{})
116+
}
117+
118+
app.Use(New(cfg))
119+
app.Get("/", ReturnOK)
120+
121+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
122+
utils.AssertEqual(t, nil, err)
123+
utils.AssertEqual(t, fiber.StatusTooManyRequests, resp.StatusCode)
124+
}
125+
126+
func Test_Loadshed_CustomOnShedWithResponse(t *testing.T) {
127+
app := fiber.New()
128+
129+
mockGetter := &MockCPUPercentGetter{MockedPercentage: []float64{96.0}}
130+
var cfg Config
131+
cfg.Criteria = &CPULoadCriteria{
132+
LowerThreshold: 0.90,
133+
UpperThreshold: 0.95,
134+
Interval: time.Second,
135+
Getter: mockGetter,
136+
}
137+
138+
// This OnShed directly sets a response without returning it
139+
cfg.OnShed = func(c *fiber.Ctx) error {
140+
c.Status(fiber.StatusTooManyRequests)
141+
return nil
142+
}
143+
144+
app.Use(New(cfg))
145+
app.Get("/", ReturnOK)
146+
147+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
148+
utils.AssertEqual(t, nil, err)
149+
utils.AssertEqual(t, fiber.StatusTooManyRequests, resp.StatusCode)
150+
}
151+
152+
func Test_Loadshed_CustomOnShedWithNilReturn(t *testing.T) {
153+
app := fiber.New()
154+
155+
mockGetter := &MockCPUPercentGetter{MockedPercentage: []float64{96.0}}
156+
var cfg Config
157+
cfg.Criteria = &CPULoadCriteria{
158+
LowerThreshold: 0.90,
159+
UpperThreshold: 0.95,
160+
Interval: time.Second,
161+
Getter: mockGetter,
162+
}
163+
164+
// OnShed returns nil without setting a response
165+
cfg.OnShed = func(c *fiber.Ctx) error {
166+
return nil
167+
}
168+
169+
app.Use(New(cfg))
170+
app.Get("/", ReturnOK)
171+
172+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
173+
utils.AssertEqual(t, nil, err)
174+
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
175+
}
176+
177+
func Test_Loadshed_CustomOnShedWithCustomError(t *testing.T) {
178+
app := fiber.New()
179+
180+
mockGetter := &MockCPUPercentGetter{MockedPercentage: []float64{96.0}}
181+
var cfg Config
182+
cfg.Criteria = &CPULoadCriteria{
183+
LowerThreshold: 0.90,
184+
UpperThreshold: 0.95,
185+
Interval: time.Second,
186+
Getter: mockGetter,
187+
}
188+
189+
// OnShed returns a custom error
190+
cfg.OnShed = func(c *fiber.Ctx) error {
191+
return fiber.NewError(fiber.StatusForbidden, "Custom error message")
192+
}
193+
194+
app.Use(New(cfg))
195+
app.Get("/", ReturnOK)
196+
197+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
198+
utils.AssertEqual(t, nil, err)
199+
utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode)
200+
}
201+
202+
func Test_Loadshed_CustomOnShedWithResponseAndCustomError(t *testing.T) {
203+
app := fiber.New()
204+
205+
mockGetter := &MockCPUPercentGetter{MockedPercentage: []float64{96.0}}
206+
var cfg Config
207+
cfg.Criteria = &CPULoadCriteria{
208+
LowerThreshold: 0.90,
209+
UpperThreshold: 0.95,
210+
Interval: time.Second,
211+
Getter: mockGetter,
212+
}
213+
214+
// OnShed sets a response and returns a different error
215+
// The NewError have higher priority since executed last
216+
cfg.OnShed = func(c *fiber.Ctx) error {
217+
c.
218+
Status(fiber.StatusTooManyRequests).
219+
SendString("Too many requests")
220+
221+
return fiber.NewError(
222+
fiber.StatusInternalServerError,
223+
"Shed happened",
224+
)
225+
}
226+
227+
app.Use(New(cfg))
228+
app.Get("/", ReturnOK)
229+
230+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
231+
payload, readErr := io.ReadAll(resp.Body)
232+
defer resp.Body.Close()
233+
234+
utils.AssertEqual(t, string(payload), "Shed happened")
235+
utils.AssertEqual(t, nil, err)
236+
utils.AssertEqual(t, nil, readErr)
237+
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
238+
}
239+
240+
func Test_Loadshed_CustomOnShedWithJSON(t *testing.T) {
241+
app := fiber.New()
242+
243+
mockGetter := &MockCPUPercentGetter{MockedPercentage: []float64{96.0}}
244+
var cfg Config
245+
cfg.Criteria = &CPULoadCriteria{
246+
LowerThreshold: 0.90,
247+
UpperThreshold: 0.95,
248+
Interval: time.Second,
249+
Getter: mockGetter,
250+
}
251+
252+
// OnShed returns JSON response
253+
cfg.OnShed = func(c *fiber.Ctx) error {
254+
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
255+
"error": "Service is currently unavailable due to high load",
256+
"retry_after": 30,
257+
})
258+
}
259+
260+
app.Use(New(cfg))
261+
app.Get("/", ReturnOK)
262+
263+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
264+
utils.AssertEqual(t, nil, err)
265+
utils.AssertEqual(t, fiber.StatusServiceUnavailable, resp.StatusCode)
266+
utils.AssertEqual(t, "application/json", resp.Header.Get("Content-Type"))
267+
}

0 commit comments

Comments
 (0)