Skip to content

Commit 630783f

Browse files
committed
refactor: adopt context-aware token operations throughout codebase
- Refactor token store and middleware methods to accept context for improved cancellation and timeout handling - Update all usages of token store methods to pass context, including tests and examples - Revise TokenGenerator and TokenGeneratorWithRevocation signatures to require context - Update documentation and code comments to reflect context usage in token operations Signed-off-by: Bo-Yi Wu <[email protected]>
1 parent 144cf55 commit 630783f

File tree

12 files changed

+154
-129
lines changed

12 files changed

+154
-129
lines changed

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ The `TokenGenerator` functionality allows you to create JWT tokens directly with
331331
package main
332332

333333
import (
334+
"context"
334335
"fmt"
335336
"log"
336337
"time"
@@ -356,9 +357,12 @@ func main() {
356357
log.Fatal("JWT Error:" + err.Error())
357358
}
358359

360+
// Create context for token operations
361+
ctx := context.Background()
362+
359363
// Generate a complete token pair (access + refresh tokens)
360364
userData := "user123"
361-
tokenPair, err := authMiddleware.TokenGenerator(userData)
365+
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
362366
if err != nil {
363367
log.Fatal("Failed to generate token pair:", err)
364368
}
@@ -392,7 +396,7 @@ Use `TokenGeneratorWithRevocation` to refresh tokens and automatically revoke ol
392396

393397
```go
394398
// Refresh with automatic revocation of old token
395-
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, oldRefreshToken)
399+
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, oldRefreshToken)
396400
if err != nil {
397401
log.Fatal("Failed to refresh token:", err)
398402
}

README.zh-CN.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ import "github.com/appleboy/gin-jwt/v3"
166166
package main
167167

168168
import (
169+
"context"
169170
"fmt"
170171
"log"
171172
"time"
@@ -191,9 +192,12 @@ func main() {
191192
log.Fatal("JWT Error:" + err.Error())
192193
}
193194

195+
// 创建 Token 操作的 context
196+
ctx := context.Background()
197+
194198
// 生成完整的 Token 组(访问 + 刷新 Token)
195199
userData := "user123"
196-
tokenPair, err := authMiddleware.TokenGenerator(userData)
200+
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
197201
if err != nil {
198202
log.Fatal("Failed to generate token pair:", err)
199203
}

README.zh-TW.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ import "github.com/appleboy/gin-jwt/v3"
166166
package main
167167

168168
import (
169+
"context"
169170
"fmt"
170171
"log"
171172
"time"
@@ -191,9 +192,12 @@ func main() {
191192
log.Fatal("JWT Error:" + err.Error())
192193
}
193194

195+
// 建立 Token 操作的 context
196+
ctx := context.Background()
197+
194198
// 產生完整的 Token 組(存取 + 刷新 Token)
195199
userData := "user123"
196-
tokenPair, err := authMiddleware.TokenGenerator(userData)
200+
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
197201
if err != nil {
198202
log.Fatal("Failed to generate token pair:", err)
199203
}
@@ -227,7 +231,7 @@ func (t *Token) ExpiresIn() int64 // 回傳到期前的秒數
227231

228232
```go
229233
// 刷新並自動撤銷舊 Token
230-
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, oldRefreshToken)
234+
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, oldRefreshToken)
231235
if err != nil {
232236
log.Fatal("Failed to refresh token:", err)
233237
}

_example/token_generator/main.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package main
33

44
import (
5+
"context"
56
"fmt"
67
"log"
78
"time"
@@ -30,9 +31,12 @@ func main() {
3031
// Example user data
3132
userData := "user123"
3233

34+
// Create context for token operations
35+
ctx := context.Background()
36+
3337
// Generate a complete token pair (access + refresh tokens)
3438
fmt.Println("=== Generating Token Pair ===")
35-
tokenPair, err := authMiddleware.TokenGenerator(userData)
39+
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
3640
if err != nil {
3741
log.Fatal("Failed to generate token pair:", err)
3842
}
@@ -46,7 +50,7 @@ func main() {
4650

4751
// Simulate refresh token usage
4852
fmt.Println("\n=== Refreshing Token Pair ===")
49-
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, tokenPair.RefreshToken)
53+
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, tokenPair.RefreshToken)
5054
if err != nil {
5155
log.Fatal("Failed to refresh token pair:", err)
5256
}
@@ -57,7 +61,7 @@ func main() {
5761

5862
// Verify old refresh token is invalid
5963
fmt.Println("\n=== Verifying Old Token Revocation ===")
60-
_, err = authMiddleware.TokenGeneratorWithRevocation(userData, tokenPair.RefreshToken)
64+
_, err = authMiddleware.TokenGeneratorWithRevocation(ctx, userData, tokenPair.RefreshToken)
6165
if err != nil {
6266
fmt.Printf("Old refresh token correctly rejected: %s\n", err)
6367
}

auth_jwt.go

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package jwt
22

33
import (
4+
"context"
45
"crypto/rand"
56
"crypto/rsa"
67
"encoding/base64"
@@ -568,7 +569,7 @@ func (mw *GinJWTMiddleware) LoginHandler(c *gin.Context) {
568569
}
569570

570571
// Generate complete token pair
571-
tokenPair, err := mw.TokenGenerator(data)
572+
tokenPair, err := mw.TokenGenerator(c.Request.Context(), data)
572573
if err != nil {
573574
mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(c,ErrFailedTokenCreation))
574575
return
@@ -612,7 +613,7 @@ func (mw *GinJWTMiddleware) LogoutHandler(c *gin.Context) {
612613
// Handle refresh token revocation (RFC 6749 compliant)
613614
refreshToken := mw.extractRefreshToken(c)
614615
if refreshToken != "" {
615-
if err := mw.revokeRefreshToken(refreshToken); err != nil {
616+
if err := mw.revokeRefreshToken(c.Request.Context(), refreshToken); err != nil {
616617
log.Printf("Failed to revoke refresh token on logout: %v", err)
617618
}
618619
}
@@ -658,14 +659,14 @@ func (mw *GinJWTMiddleware) generateRefreshToken() (string, error) {
658659
}
659660

660661
// storeRefreshToken stores a refresh token with user data
661-
func (mw *GinJWTMiddleware) storeRefreshToken(token string, userData any) error {
662+
func (mw *GinJWTMiddleware) storeRefreshToken(ctx context.Context, token string, userData any) error {
662663
expiry := mw.TimeFunc().Add(mw.RefreshTokenTimeout)
663-
return mw.RefreshTokenStore.Set(token, userData, expiry)
664+
return mw.RefreshTokenStore.Set(ctx, token, userData, expiry)
664665
}
665666

666667
// validateRefreshToken validates a refresh token and returns associated user data
667-
func (mw *GinJWTMiddleware) validateRefreshToken(token string) (any, error) {
668-
userData, err := mw.RefreshTokenStore.Get(token)
668+
func (mw *GinJWTMiddleware) validateRefreshToken(ctx context.Context, token string) (any, error) {
669+
userData, err := mw.RefreshTokenStore.Get(ctx, token)
669670
if err != nil {
670671
if err == core.ErrRefreshTokenNotFound {
671672
return nil, ErrInvalidRefreshToken
@@ -676,8 +677,8 @@ func (mw *GinJWTMiddleware) validateRefreshToken(token string) (any, error) {
676677
}
677678

678679
// revokeRefreshToken removes a refresh token from storage
679-
func (mw *GinJWTMiddleware) revokeRefreshToken(token string) error {
680-
return mw.RefreshTokenStore.Delete(token)
680+
func (mw *GinJWTMiddleware) revokeRefreshToken(ctx context.Context, token string) error {
681+
return mw.RefreshTokenStore.Delete(ctx, token)
681682
}
682683

683684
// RefreshHandler can be used to refresh a token using RFC 6749 compliant refresh tokens.
@@ -692,14 +693,14 @@ func (mw *GinJWTMiddleware) RefreshHandler(c *gin.Context) {
692693
}
693694

694695
// Validate refresh token
695-
userData, err := mw.validateRefreshToken(refreshToken)
696+
userData, err := mw.validateRefreshToken(c.Request.Context(), refreshToken)
696697
if err != nil {
697698
mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(c,err))
698699
return
699700
}
700701

701702
// Generate new token pair and revoke old refresh token
702-
tokenPair, err := mw.TokenGeneratorWithRevocation(userData, refreshToken)
703+
tokenPair, err := mw.TokenGeneratorWithRevocation(c.Request.Context(), userData, refreshToken)
703704
if err != nil {
704705
mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(c,err))
705706
return
@@ -795,7 +796,7 @@ func (mw *GinJWTMiddleware) generateAccessToken(data any) (string, time.Time, er
795796
}
796797

797798
// TokenGenerator generates a complete token pair (access + refresh) with RFC 6749 compliance
798-
func (mw *GinJWTMiddleware) TokenGenerator(data any) (*core.Token, error) {
799+
func (mw *GinJWTMiddleware) TokenGenerator(ctx context.Context, data any) (*core.Token, error) {
799800
// Generate access token
800801
accessToken, expire, err := mw.generateAccessToken(data)
801802
if err != nil {
@@ -809,7 +810,7 @@ func (mw *GinJWTMiddleware) TokenGenerator(data any) (*core.Token, error) {
809810
}
810811

811812
// Store refresh token
812-
if err := mw.storeRefreshToken(refreshToken, data); err != nil {
813+
if err := mw.storeRefreshToken(ctx, refreshToken, data); err != nil {
813814
return nil, err
814815
}
815816

@@ -824,15 +825,15 @@ func (mw *GinJWTMiddleware) TokenGenerator(data any) (*core.Token, error) {
824825
}
825826

826827
// TokenGeneratorWithRevocation generates a new token pair and revokes the old refresh token
827-
func (mw *GinJWTMiddleware) TokenGeneratorWithRevocation(data any, oldRefreshToken string) (*core.Token, error) {
828+
func (mw *GinJWTMiddleware) TokenGeneratorWithRevocation(ctx context.Context, data any, oldRefreshToken string) (*core.Token, error) {
828829
// Generate new token pair
829-
tokenPair, err := mw.TokenGenerator(data)
830+
tokenPair, err := mw.TokenGenerator(ctx, data)
830831
if err != nil {
831832
return nil, err
832833
}
833834

834835
// Revoke old refresh token, ignore if token already doesn't exist
835-
if err := mw.revokeRefreshToken(oldRefreshToken); err != nil && !errors.Is(err, core.ErrRefreshTokenNotFound) {
836+
if err := mw.revokeRefreshToken(ctx, oldRefreshToken); err != nil && !errors.Is(err, core.ErrRefreshTokenNotFound) {
836837
return nil, err
837838
}
838839

auth_jwt_redis_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -401,33 +401,34 @@ func testRedisStoreOperations(t *testing.T, middleware *GinJWTMiddleware) {
401401
require.True(t, ok, "should be using Redis store")
402402

403403
// Test store operations directly
404+
ctx := context.Background()
404405
testToken := "direct-test-token"
405406
testData := map[string]any{"test": "data"}
406407
expiry := time.Now().Add(time.Hour)
407408

408409
// Test Set
409-
err := redisStore.Set(testToken, testData, expiry)
410+
err := redisStore.Set(ctx, testToken, testData, expiry)
410411
assert.NoError(t, err, "direct set should succeed")
411412

412413
// Test Get
413-
retrievedData, err := redisStore.Get(testToken)
414+
retrievedData, err := redisStore.Get(ctx, testToken)
414415
assert.NoError(t, err, "direct get should succeed")
415416
assert.Equal(t, testData, retrievedData, "retrieved data should match")
416417

417418
// Test Count
418-
count, err := redisStore.Count()
419+
count, err := redisStore.Count(ctx)
419420
assert.NoError(t, err, "count should succeed")
420421
assert.GreaterOrEqual(t, count, 1, "count should include our test token")
421422

422423
// Test Delete
423-
err = redisStore.Delete(testToken)
424+
err = redisStore.Delete(ctx, testToken)
424425
assert.NoError(t, err, "direct delete should succeed")
425426

426427
// Verify deletion - wait for cache TTL to expire
427428
time.Sleep(100 * time.Millisecond)
428429

429430
// The Get method should return an error for deleted tokens
430-
_, err = redisStore.Get(testToken)
431+
_, err = redisStore.Get(ctx, testToken)
431432
assert.Error(t, err, "token should not exist after deletion")
432433
}
433434

auth_jwt_test.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package jwt
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"log"
@@ -1470,7 +1471,8 @@ func TestTokenGenerator(t *testing.T) {
14701471
assert.NoError(t, err)
14711472

14721473
userData := "admin"
1473-
tokenPair, err := authMiddleware.TokenGenerator(userData)
1474+
ctx := context.Background()
1475+
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
14741476

14751477
assert.NoError(t, err)
14761478
assert.NotNil(t, tokenPair)
@@ -1510,40 +1512,41 @@ func TestTokenGeneratorWithRevocation(t *testing.T) {
15101512
assert.NoError(t, err)
15111513

15121514
userData := "admin"
1515+
ctx := context.Background()
15131516

15141517
// Generate first token pair
1515-
oldTokenPair, err := authMiddleware.TokenGenerator(userData)
1518+
oldTokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
15161519
assert.NoError(t, err)
15171520

15181521
// Verify old refresh token exists in store
1519-
storedData, err := authMiddleware.validateRefreshToken(oldTokenPair.RefreshToken)
1522+
storedData, err := authMiddleware.validateRefreshToken(ctx, oldTokenPair.RefreshToken)
15201523
assert.NoError(t, err)
15211524
assert.Equal(t, userData, storedData)
15221525

15231526
// Generate new token pair with revocation
1524-
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, oldTokenPair.RefreshToken)
1527+
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, oldTokenPair.RefreshToken)
15251528
assert.NoError(t, err)
15261529
assert.NotNil(t, newTokenPair)
15271530

15281531
// Verify refresh tokens are different (access tokens might be the same if generated in same second)
15291532
assert.NotEqual(t, oldTokenPair.RefreshToken, newTokenPair.RefreshToken)
15301533

15311534
// Verify old refresh token is revoked
1532-
_, err = authMiddleware.validateRefreshToken(oldTokenPair.RefreshToken)
1535+
_, err = authMiddleware.validateRefreshToken(ctx, oldTokenPair.RefreshToken)
15331536
assert.Error(t, err)
15341537

15351538
// Verify new refresh token works
1536-
storedData, err = authMiddleware.validateRefreshToken(newTokenPair.RefreshToken)
1539+
storedData, err = authMiddleware.validateRefreshToken(ctx, newTokenPair.RefreshToken)
15371540
assert.NoError(t, err)
15381541
assert.Equal(t, userData, storedData)
15391542

15401543
// Test revoking already revoked token (should not fail)
1541-
anotherTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, oldTokenPair.RefreshToken)
1544+
anotherTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, oldTokenPair.RefreshToken)
15421545
assert.NoError(t, err)
15431546
assert.NotNil(t, anotherTokenPair)
15441547

15451548
// Test revoking non-existent token (should not fail)
1546-
finalTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, "non_existent_token")
1549+
finalTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, "non_existent_token")
15471550
assert.NoError(t, err)
15481551
assert.NotNil(t, finalTokenPair)
15491552
}
@@ -1562,7 +1565,8 @@ func TestTokenStruct(t *testing.T) {
15621565
assert.NoError(t, err)
15631566

15641567
userData := "admin"
1565-
tokenPair, err := authMiddleware.TokenGenerator(userData)
1568+
ctx := context.Background()
1569+
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
15661570
assert.NoError(t, err)
15671571

15681572
// Test ExpiresIn method

core/store.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package core
33

44
import (
5+
"context"
56
"errors"
67
"time"
78
)
@@ -18,23 +19,23 @@ var (
1819
type TokenStore interface {
1920
// Set stores a refresh token with associated user data and expiration
2021
// Returns an error if the operation fails
21-
Set(token string, userData any, expiry time.Time) error
22+
Set(ctx context.Context, token string, userData any, expiry time.Time) error
2223

2324
// Get retrieves user data associated with a refresh token
2425
// Returns ErrRefreshTokenNotFound if token doesn't exist or is expired
25-
Get(token string) (any, error)
26+
Get(ctx context.Context, token string) (any, error)
2627

2728
// Delete removes a refresh token from storage
2829
// Returns an error if the operation fails, but should not error if token doesn't exist
29-
Delete(token string) error
30+
Delete(ctx context.Context, token string) error
3031

3132
// Cleanup removes expired tokens (optional, for cleanup routines)
3233
// Returns the number of tokens cleaned up and any error encountered
33-
Cleanup() (int, error)
34+
Cleanup(ctx context.Context) (int, error)
3435

3536
// Count returns the total number of active refresh tokens
3637
// Useful for monitoring and debugging
37-
Count() (int, error)
38+
Count(ctx context.Context) (int, error)
3839
}
3940

4041
// RefreshTokenData holds the data stored with each refresh token

0 commit comments

Comments
 (0)