Skip to content

Commit ab5e159

Browse files
committed
add WithNamespace transit option
1 parent 3e6bef8 commit ab5e159

File tree

3 files changed

+101
-19
lines changed

3 files changed

+101
-19
lines changed

vault/client.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ func (c *defaultClient) Close(revoke bool) {
489489
// This will ensure that the auth token is periodically renewed.
490490
// If the Client's token is not renewable an error will be returned.
491491
func (c *defaultClient) startLifetimeWatcher(ctx context.Context) error {
492+
fmt.Printf("skip renewal: %v\n", c.skipRenewal)
492493
if c.skipRenewal {
493494
return nil
494495
}
@@ -517,12 +518,16 @@ func (c *defaultClient) startLifetimeWatcher(ctx context.Context) error {
517518
watcherID := uuid.NewString()
518519
wg := sync.WaitGroup{}
519520
wg.Add(1)
521+
fmt.Println("ready to go go go")
520522
go func(ctx context.Context, c *defaultClient, watcher *api.LifetimeWatcher) {
523+
fmt.Println("go go go")
521524
logger := log.FromContext(ctx).WithName("lifetimeWatcher").WithValues(
522525
"id", watcherID, "entityID", c.authSecret.Auth.EntityID,
523526
"clientID", c.id, "cacheKey", cacheKey)
524527
logger.Info("Starting")
528+
fmt.Println("starting")
525529
defer func() {
530+
fmt.Println("stopping")
526531
logger.Info("Stopping")
527532
watcher.Stop()
528533
}()
@@ -531,12 +536,15 @@ func (c *defaultClient) startLifetimeWatcher(ctx context.Context) error {
531536
c.watcher = watcher
532537
wg.Done()
533538
logger.V(consts.LogLevelDebug).Info("Started")
539+
fmt.Println("started")
534540
for {
535541
select {
536542
case result := <-ctx.Done():
543+
fmt.Println("context done")
537544
logger.V(consts.LogLevelTrace).Info("Context done", "result", result)
538545
return
539546
case err := <-watcher.DoneCh():
547+
fmt.Println("watcher done ch")
540548
if err != nil {
541549
logger.Error(err, "LifetimeWatcher completed with an error")
542550
c.lastWatcherErr = err
@@ -559,6 +567,7 @@ func (c *defaultClient) startLifetimeWatcher(ctx context.Context) error {
559567

560568
return
561569
case renewal := <-watcher.RenewCh():
570+
fmt.Println("renew")
562571
logger.V(consts.LogLevelDebug).Info("Successfully renewed the client")
563572

564573
c.authSecret = renewal.Secret
@@ -639,8 +648,11 @@ func (c *defaultClient) Login(ctx context.Context, client ctrlclient.Client) err
639648

640649
c.id = id
641650

651+
fmt.Println("hii")
652+
fmt.Println(resp.Secret().Auth.Renewable)
642653
if resp.Secret().Auth.Renewable {
643654
if err := c.startLifetimeWatcher(ctx); err != nil {
655+
fmt.Printf("failed to atcher: %v\n", err)
644656
errs = err
645657
return errs
646658
}

vault/transit.go

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"context"
88
"encoding/base64"
99
"fmt"
10+
"net/http"
1011

1112
"k8s.io/apimachinery/pkg/util/json"
1213
)
@@ -22,25 +23,32 @@ type (
2223
Plaintext string `json:"plaintext"`
2324
}
2425

25-
// TransitOption modifies parameters for a Transit Encrypt/Decrypt request.
26-
// Individual options may only apply to certain operations.
27-
TransitOption func(m map[string]any)
26+
TransitRequestOptions struct {
27+
Params map[string]any
28+
Headers http.Header
29+
}
30+
31+
// TransitOption modifies parameters and/or headers for a Transit request.
32+
TransitOption func(*TransitRequestOptions)
2833
)
2934

3035
// EncryptWithTransit encrypts data using Vault Transit.
3136
func EncryptWithTransit(ctx context.Context, vaultClient Client, mount, key string, data []byte, opts ...TransitOption) ([]byte, error) {
3237
path := fmt.Sprintf("%s/encrypt/%s", mount, key)
3338

34-
params := map[string]any{
35-
"name": key,
36-
"plaintext": base64.StdEncoding.EncodeToString(data),
39+
req := &TransitRequestOptions{
40+
Params: map[string]any{
41+
"name": key,
42+
"plaintext": base64.StdEncoding.EncodeToString(data),
43+
},
44+
Headers: make(http.Header),
3745
}
3846

3947
for _, opt := range opts {
40-
opt(params)
48+
opt(req)
4149
}
4250

43-
resp, err := vaultClient.Write(ctx, NewWriteRequest(path, params, nil))
51+
resp, err := vaultClient.Write(ctx, NewWriteRequest(path, req.Params, req.Headers))
4452
if err != nil {
4553
return nil, err
4654
}
@@ -63,14 +71,21 @@ func DecryptWithTransit(ctx context.Context, vaultClient Client, mount, key stri
6371
}
6472

6573
// DecryptCiphertextWithTransit decrypts a ciphertext value using Vault Transit.
66-
func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount, key, ciphertext string) ([]byte, error) {
74+
func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount, key, ciphertext string, opts ...TransitOption) ([]byte, error) {
6775
path := fmt.Sprintf("%s/decrypt/%s", mount, key)
68-
params := map[string]interface{}{
69-
"name": key,
70-
"ciphertext": ciphertext,
76+
req := &TransitRequestOptions{
77+
Params: map[string]interface{}{
78+
"name": key,
79+
"ciphertext": ciphertext,
80+
},
81+
Headers: make(http.Header),
82+
}
83+
84+
for _, opt := range opts {
85+
opt(req)
7186
}
7287

73-
resp, err := vaultClient.Write(ctx, NewWriteRequest(path, params, nil))
88+
resp, err := vaultClient.Write(ctx, NewWriteRequest(path, req.Params, req.Headers))
7489
if err != nil {
7590
return nil, err
7691
}
@@ -95,5 +110,16 @@ func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount
95110
// WithKeyVersion sets the key version for EncryptWithTransit.
96111
// It is ignored when passed to DecryptWithTransit.
97112
func WithKeyVersion(v uint) TransitOption {
98-
return func(m map[string]any) { m["key_version"] = v }
113+
return func(opt *TransitRequestOptions) {
114+
opt.Params["key_version"] = v
115+
}
116+
}
117+
118+
func WithNamespace(namespace string) TransitOption {
119+
return func(opt *TransitRequestOptions) {
120+
if namespace == "" {
121+
return
122+
}
123+
opt.Headers.Set("X-Vault-Namespace", namespace)
124+
}
99125
}

vault/transit_test.go

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package vault
22

33
import (
4+
"net/http"
45
"testing"
56

67
"github.com/stretchr/testify/assert"
@@ -28,13 +29,56 @@ func TestWithKeyVersion(t *testing.T) {
2829

2930
for _, tt := range tests {
3031
t.Run(tt.name, func(t *testing.T) {
31-
opts := WithKeyVersion(tt.v)
32-
m := make(map[string]any)
32+
opt := WithKeyVersion(tt.v)
3333

34-
opts(m)
34+
req := &TransitRequestOptions{
35+
Params: make(map[string]any),
36+
Headers: make(http.Header),
37+
}
3538

36-
require.Contains(t, m, "key_version")
37-
assert.Equal(t, tt.v, m["key_version"])
39+
opt(req)
40+
41+
require.Contains(t, req.Params, "key_version")
42+
assert.Equal(t, tt.v, req.Params["key_version"])
43+
})
44+
}
45+
}
46+
47+
func TestWithNamespace(t *testing.T) {
48+
tests := []struct {
49+
name string
50+
namespace string
51+
wantSet bool
52+
}{
53+
{
54+
name: "sets namespace header",
55+
namespace: "foo/namespace",
56+
wantSet: true,
57+
},
58+
{
59+
name: "empty namespace does nothing",
60+
namespace: "",
61+
wantSet: false,
62+
},
63+
}
64+
65+
for _, tt := range tests {
66+
t.Run(tt.name, func(t *testing.T) {
67+
opt := WithNamespace(tt.namespace)
68+
69+
req := &TransitRequestOptions{
70+
Params: make(map[string]any),
71+
Headers: make(http.Header),
72+
}
73+
74+
opt(req)
75+
76+
got := req.Headers.Get("X-Vault-Namespace")
77+
if tt.wantSet {
78+
require.Equal(t, tt.namespace, got)
79+
} else {
80+
require.Equal(t, "", got)
81+
}
3882
})
3983
}
4084
}

0 commit comments

Comments
 (0)