Skip to content

Commit 4218569

Browse files
committed
add a thread safe context that can be used when UseInternalContext is enabled.
1 parent c0048f6 commit 4218569

File tree

4 files changed

+233
-14
lines changed

4 files changed

+233
-14
lines changed

context.go

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package gin
66

77
import (
8+
"context"
89
"errors"
910
"fmt"
1011
"io"
@@ -92,6 +93,9 @@ type Context struct {
9293
// SameSite allows a server to define a cookie attribute making it impossible for
9394
// the browser to send this cookie along with cross-site requests.
9495
sameSite http.SameSite
96+
97+
internalContextMu sync.RWMutex
98+
internalContext context.Context
9599
}
96100

97101
/************************************/
@@ -113,6 +117,10 @@ func (c *Context) reset() {
113117
c.sameSite = 0
114118
*c.params = (*c.params)[:0]
115119
*c.skippedNodes = (*c.skippedNodes)[:0]
120+
121+
if c.useInternalContext() {
122+
c.WithInternalContext(context.Background())
123+
}
116124
}
117125

118126
// Copy returns a copy of the current context that can be safely used outside the request's scope.
@@ -1358,6 +1366,49 @@ func (c *Context) SetAccepted(formats ...string) {
13581366
/***** GOLANG.ORG/X/NET/CONTEXT *****/
13591367
/************************************/
13601368

1369+
// WithInternalContext replaces the internal context stored with the provided one in a thread safe manner.
1370+
// It's important that any context you pass in is not something the wraps *gin.Context,
1371+
// if you want to wrap a context and then provide it to WithInternalContext, use InternalContext().
1372+
// If you don't plan to provide the context back to WithInternalContext you can safely use *Context directly.
1373+
// Otherwise you'll end up with a stack overflow.
1374+
//
1375+
// For example:
1376+
// var c *Context // given a context
1377+
// // you can safely wrap it and pass it downstream
1378+
// myDownstreamFunction(context.WithValue(c, ...))
1379+
//
1380+
// // but when you want to call WithInternalContext you should do it like this
1381+
// c.WithInternalContext(context.WithValue(c.InternalContext(), ...))
1382+
func (c *Context) WithInternalContext(ctx context.Context) {
1383+
if !c.useInternalContext() {
1384+
panic("Can't use WithInternalContext when UseInternalContext is false")
1385+
}
1386+
1387+
c.internalContextMu.Lock()
1388+
defer c.internalContextMu.Unlock()
1389+
1390+
c.internalContext = ctx
1391+
}
1392+
1393+
// InternalContext provides the currently stored internal context in a thread safe manner.
1394+
// Use this if you want to wrap a context.Context which you'll end up providing to WithInternalContext.
1395+
// If you don't plan to provide the context back to WithInternalContext you can safely use *Context directly.
1396+
func (c *Context) InternalContext() context.Context {
1397+
if !c.useInternalContext() {
1398+
panic("Can't use InternalContext when UseInternalContext is false")
1399+
}
1400+
1401+
c.internalContextMu.RLock()
1402+
defer c.internalContextMu.RUnlock()
1403+
1404+
return c.internalContext
1405+
}
1406+
1407+
// hasRequestContext returns whether c.Request has Context and fallback.
1408+
func (c *Context) useInternalContext() bool {
1409+
return c.engine != nil && c.engine.UseInternalContext
1410+
}
1411+
13611412
// hasRequestContext returns whether c.Request has Context and fallback.
13621413
func (c *Context) hasRequestContext() bool {
13631414
hasFallback := c.engine != nil && c.engine.ContextWithFallback
@@ -1367,26 +1418,44 @@ func (c *Context) hasRequestContext() bool {
13671418

13681419
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
13691420
func (c *Context) Deadline() (deadline time.Time, ok bool) {
1370-
if !c.hasRequestContext() {
1371-
return
1421+
if c.useInternalContext() {
1422+
c.internalContextMu.RLock()
1423+
defer c.internalContextMu.RUnlock()
1424+
1425+
return c.internalContext.Deadline()
1426+
} else if c.hasRequestContext() {
1427+
return c.Request.Context().Deadline()
13721428
}
1373-
return c.Request.Context().Deadline()
1429+
1430+
return
13741431
}
13751432

13761433
// Done returns nil (chan which will wait forever) when c.Request has no Context.
13771434
func (c *Context) Done() <-chan struct{} {
1378-
if !c.hasRequestContext() {
1379-
return nil
1435+
if c.useInternalContext() {
1436+
c.internalContextMu.RLock()
1437+
defer c.internalContextMu.RUnlock()
1438+
1439+
return c.internalContext.Done()
1440+
} else if c.hasRequestContext() {
1441+
return c.Request.Context().Done()
13801442
}
1381-
return c.Request.Context().Done()
1443+
1444+
return nil
13821445
}
13831446

13841447
// Err returns nil when c.Request has no Context.
13851448
func (c *Context) Err() error {
1386-
if !c.hasRequestContext() {
1387-
return nil
1449+
if c.useInternalContext() {
1450+
c.internalContextMu.RLock()
1451+
defer c.internalContextMu.RUnlock()
1452+
1453+
return c.internalContext.Err()
1454+
} else if c.hasRequestContext() {
1455+
return c.Request.Context().Err()
13881456
}
1389-
return c.Request.Context().Err()
1457+
1458+
return nil
13901459
}
13911460

13921461
// Value returns the value associated with this context for key, or nil
@@ -1404,8 +1473,14 @@ func (c *Context) Value(key any) any {
14041473
return val
14051474
}
14061475
}
1407-
if !c.hasRequestContext() {
1408-
return nil
1476+
if c.useInternalContext() {
1477+
c.internalContextMu.RLock()
1478+
defer c.internalContextMu.RUnlock()
1479+
1480+
return c.internalContext.Value(key)
1481+
} else if c.hasRequestContext() {
1482+
return c.Request.Context().Value(key)
14091483
}
1410-
return c.Request.Context().Value(key)
1484+
1485+
return nil
14111486
}

context_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3138,6 +3138,142 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
31383138
}
31393139
}
31403140

3141+
func TestContextUseInternalContextDeadline(t *testing.T) {
3142+
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3143+
// enable UseInternalContext feature flag
3144+
c.engine.UseInternalContext = true
3145+
})
3146+
3147+
deadline, ok := c.Deadline()
3148+
assert.Zero(t, deadline)
3149+
assert.False(t, ok)
3150+
3151+
c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3152+
// enable UseInternalContext feature flag
3153+
c.engine.UseInternalContext = true
3154+
})
3155+
3156+
d := time.Now().Add(time.Second)
3157+
ctx, cancel := context.WithDeadline(context.Background(), d)
3158+
defer cancel()
3159+
c2.WithInternalContext(ctx)
3160+
deadline, ok = c2.Deadline()
3161+
assert.Equal(t, d, deadline)
3162+
assert.True(t, ok)
3163+
}
3164+
3165+
func TestContextUseInternalContextDone(t *testing.T) {
3166+
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3167+
// enable UseInternalContext feature flag
3168+
c.engine.UseInternalContext = true
3169+
})
3170+
3171+
assert.Nil(t, c.Done())
3172+
3173+
c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3174+
// enable UseInternalContext feature flag
3175+
c.engine.UseInternalContext = true
3176+
})
3177+
3178+
ctx, cancel := context.WithCancel(context.Background())
3179+
c2.WithInternalContext(ctx)
3180+
cancel()
3181+
assert.NotNil(t, <-c2.Done())
3182+
}
3183+
3184+
func TestContextUseInternalContextErr(t *testing.T) {
3185+
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3186+
// enable UseInternalContext feature flag
3187+
c.engine.UseInternalContext = true
3188+
})
3189+
3190+
require.NoError(t, c.Err())
3191+
3192+
c2, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3193+
// enable UseInternalContext feature flag
3194+
c.engine.UseInternalContext = true
3195+
})
3196+
3197+
ctx, cancel := context.WithCancel(context.Background())
3198+
c2.WithInternalContext(ctx)
3199+
cancel()
3200+
3201+
assert.EqualError(t, c2.Err(), context.Canceled.Error())
3202+
}
3203+
3204+
func TestContextUseInternalContextValue(t *testing.T) {
3205+
type contextKey string
3206+
3207+
tests := []struct {
3208+
name string
3209+
getContextAndKey func() (*Context, any)
3210+
value any
3211+
}{
3212+
{
3213+
name: "c with struct context key",
3214+
getContextAndKey: func() (*Context, any) {
3215+
type KeyStruct struct{} // https://staticcheck.dev/docs/checks/#SA1029
3216+
var key KeyStruct
3217+
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3218+
// enable UseInternalContext feature flag
3219+
c.engine.UseInternalContext = true
3220+
})
3221+
c.WithInternalContext(context.WithValue(context.TODO(), key, "value"))
3222+
return c, key
3223+
},
3224+
value: "value",
3225+
},
3226+
{
3227+
name: "c with struct context key and request context with different value",
3228+
getContextAndKey: func() (*Context, any) {
3229+
type KeyStruct struct{} // https://staticcheck.dev/docs/checks/#SA1029
3230+
var key KeyStruct
3231+
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3232+
// enable UseInternalContext feature flag
3233+
c.engine.UseInternalContext = true
3234+
// enable ContextWithFallback feature flag
3235+
c.engine.ContextWithFallback = true
3236+
c.Request, _ = http.NewRequest(http.MethodPost, "/", nil)
3237+
})
3238+
c.WithInternalContext(context.WithValue(context.TODO(), key, "value"))
3239+
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "other value"))
3240+
return c, key
3241+
},
3242+
value: "value",
3243+
},
3244+
{
3245+
name: "c with string context key",
3246+
getContextAndKey: func() (*Context, any) {
3247+
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3248+
// enable UseInternalContext feature flag
3249+
c.engine.UseInternalContext = true
3250+
})
3251+
c.WithInternalContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
3252+
return c, contextKey("key")
3253+
},
3254+
value: "value",
3255+
},
3256+
{
3257+
name: "c with background internal context",
3258+
getContextAndKey: func() (*Context, any) {
3259+
c, _ := CreateTestContext(httptest.NewRecorder(), func(c *Context) {
3260+
// enable UseInternalContext feature flag
3261+
c.engine.UseInternalContext = true
3262+
})
3263+
c.WithInternalContext(context.Background())
3264+
return c, "key"
3265+
},
3266+
value: nil,
3267+
},
3268+
}
3269+
for _, tt := range tests {
3270+
t.Run(tt.name, func(t *testing.T) {
3271+
c, key := tt.getContextAndKey()
3272+
assert.Equal(t, tt.value, c.Value(key))
3273+
})
3274+
}
3275+
}
3276+
31413277
func TestContextCopyShouldNotCancel(t *testing.T) {
31423278
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
31433279
w.WriteHeader(http.StatusOK)

gin.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,14 @@ type Engine struct {
161161
// UseH2C enable h2c support.
162162
UseH2C bool
163163

164-
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
164+
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value()
165+
// through Context.Request when Context.Request.Context() is not nil.
165166
ContextWithFallback bool
166167

168+
// UseInternalContext enable fallback Context.Deadline(), Context.Done(), Context.Err()
169+
// through InternalContext and supersedes ContextWithFallback
170+
UseInternalContext bool
171+
167172
delims render.Delims
168173
secureJSONPrefix string
169174
HTMLRender render.HTMLRender

test_helpers.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@ import "net/http"
1010
// This is useful for tests that need to set up a new Gin engine instance
1111
// along with a context, for example, to test middleware that doesn't depend on
1212
// specific routes. The ResponseWriter `w` is used to initialize the context's writer.
13-
func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) {
13+
func CreateTestContext(w http.ResponseWriter, opts ...func(c *Context)) (c *Context, r *Engine) {
1414
r = New()
1515
c = r.allocateContext(0)
16+
for _, opt := range opts {
17+
opt(c)
18+
}
1619
c.reset()
1720
c.writermem.reset(w)
1821
return

0 commit comments

Comments
 (0)