Skip to content

Commit 65effb8

Browse files
authored
Add use and middleware (#50)
* Add use and middleware * Add runtime checks to prevent wrong handling of constructor
1 parent 77e01ba commit 65effb8

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed

xhttp/xhttp.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package xhttp
2+
3+
import (
4+
"errors"
5+
"net/http"
6+
)
7+
8+
type Middleware func(m http.RoundTripper) http.RoundTripper
9+
10+
// Use is wrapping up a RoundTripper with a set of middleware
11+
func Use(client *http.Client, middlewares ...Middleware) *http.Client {
12+
if client == nil {
13+
panic(errors.New("client is nil"))
14+
}
15+
if len(middlewares) == 0 {
16+
panic(errors.New("middlewares is nil"))
17+
}
18+
if client.Transport == nil {
19+
client.Transport = http.DefaultTransport
20+
}
21+
current := client.Transport
22+
for _, middleware := range middlewares {
23+
if middleware == nil {
24+
panic(errors.New("middleware is nil"))
25+
}
26+
current = middleware(current)
27+
}
28+
client.Transport = current
29+
return client
30+
}

xhttp/xhttp_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package xhttp_test
2+
3+
import (
4+
"github.com/donutloop/toolkit/xhttp"
5+
"log"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
)
10+
11+
type TestMiddleware struct {
12+
roundtripper http.RoundTripper
13+
Log func(v ...interface{})
14+
ID int
15+
}
16+
17+
func (m *TestMiddleware) RoundTrip(req *http.Request) (*http.Response, error) {
18+
m.Log("hitted middleware ", m.ID)
19+
resp, err := m.roundtripper.RoundTrip(req)
20+
if err != nil {
21+
return resp, nil
22+
}
23+
return resp, nil
24+
}
25+
26+
func TestInjectMiddlware(t *testing.T) {
27+
28+
handler := func(w http.ResponseWriter, r *http.Request) {
29+
log.Println("hitted handler")
30+
}
31+
s := httptest.NewServer(http.HandlerFunc(handler))
32+
defer s.Close()
33+
34+
m1 := func(m http.RoundTripper) http.RoundTripper {
35+
return &TestMiddleware{m, log.Println, 1}
36+
}
37+
38+
m2 := func(m http.RoundTripper) http.RoundTripper {
39+
return &TestMiddleware{m, log.Println, 2}
40+
}
41+
42+
m3 := func(m http.RoundTripper) http.RoundTripper {
43+
return &TestMiddleware{m, log.Println, 3}
44+
}
45+
46+
httpClient := new(http.Client)
47+
httpClient = xhttp.Use(httpClient, m1, m2)
48+
httpClient = xhttp.Use(httpClient, m3)
49+
resp, err := httpClient.Get(s.URL)
50+
if err != nil {
51+
log.Fatal(err)
52+
}
53+
54+
if resp.StatusCode != http.StatusOK {
55+
log.Fatal(err)
56+
}
57+
}
58+
59+
func TestPanicNilClient(t *testing.T) {
60+
defer func() {
61+
v := recover()
62+
err := v.(error)
63+
if err.Error() != "client is nil" {
64+
t.Errorf("error message is bad (%v)", v)
65+
}
66+
}()
67+
68+
xhttp.Use(nil, nil)
69+
}
70+
71+
func TestPanicNilMiddleware(t *testing.T) {
72+
defer func() {
73+
v := recover()
74+
err := v.(error)
75+
if err.Error() != "middleware is nil" {
76+
t.Errorf("error message is bad (%v)", v)
77+
}
78+
}()
79+
80+
xhttp.Use(new(http.Client), nil)
81+
}
82+
83+
func TestPanicNilMiddlewares(t *testing.T) {
84+
defer func() {
85+
v := recover()
86+
err := v.(error)
87+
if err.Error() != "middlewares is nil" {
88+
t.Errorf("error message is bad (%v)", v)
89+
}
90+
}()
91+
92+
xhttp.Use(new(http.Client))
93+
}

0 commit comments

Comments
 (0)