diff --git a/vault/transit.go b/vault/transit.go index 88fce061..d6be5f75 100644 --- a/vault/transit.go +++ b/vault/transit.go @@ -7,6 +7,7 @@ import ( "context" "encoding/base64" "fmt" + "net/http" "k8s.io/apimachinery/pkg/util/json" ) @@ -22,25 +23,32 @@ type ( Plaintext string `json:"plaintext"` } - // TransitOption modifies parameters for a Transit Encrypt/Decrypt request. - // Individual options may only apply to certain operations. - TransitOption func(m map[string]any) + TransitRequestOptions struct { + Params map[string]any + Headers http.Header + } + + // TransitOption modifies parameters and/or headers for a Transit request. + TransitOption func(*TransitRequestOptions) ) // EncryptWithTransit encrypts data using Vault Transit. func EncryptWithTransit(ctx context.Context, vaultClient Client, mount, key string, data []byte, opts ...TransitOption) ([]byte, error) { path := fmt.Sprintf("%s/encrypt/%s", mount, key) - params := map[string]any{ - "name": key, - "plaintext": base64.StdEncoding.EncodeToString(data), + req := &TransitRequestOptions{ + Params: map[string]any{ + "name": key, + "plaintext": base64.StdEncoding.EncodeToString(data), + }, + Headers: make(http.Header), } for _, opt := range opts { - opt(params) + opt(req) } - resp, err := vaultClient.Write(ctx, NewWriteRequest(path, params, nil)) + resp, err := vaultClient.Write(ctx, NewWriteRequest(path, req.Params, req.Headers)) if err != nil { return nil, err } @@ -63,14 +71,21 @@ func DecryptWithTransit(ctx context.Context, vaultClient Client, mount, key stri } // DecryptCiphertextWithTransit decrypts a ciphertext value using Vault Transit. -func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount, key, ciphertext string) ([]byte, error) { +func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount, key, ciphertext string, opts ...TransitOption) ([]byte, error) { path := fmt.Sprintf("%s/decrypt/%s", mount, key) - params := map[string]interface{}{ - "name": key, - "ciphertext": ciphertext, + req := &TransitRequestOptions{ + Params: map[string]interface{}{ + "name": key, + "ciphertext": ciphertext, + }, + Headers: make(http.Header), + } + + for _, opt := range opts { + opt(req) } - resp, err := vaultClient.Write(ctx, NewWriteRequest(path, params, nil)) + resp, err := vaultClient.Write(ctx, NewWriteRequest(path, req.Params, req.Headers)) if err != nil { return nil, err } @@ -95,5 +110,16 @@ func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount // WithKeyVersion sets the key version for EncryptWithTransit. // It is ignored when passed to DecryptWithTransit. func WithKeyVersion(v uint) TransitOption { - return func(m map[string]any) { m["key_version"] = v } + return func(opt *TransitRequestOptions) { + opt.Params["key_version"] = v + } +} + +func WithNamespace(namespace string) TransitOption { + return func(opt *TransitRequestOptions) { + if namespace == "" { + return + } + opt.Headers.Set("X-Vault-Namespace", namespace) + } } diff --git a/vault/transit_test.go b/vault/transit_test.go index a90db0ff..ffa15bdf 100644 --- a/vault/transit_test.go +++ b/vault/transit_test.go @@ -1,6 +1,7 @@ package vault import ( + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -28,13 +29,56 @@ func TestWithKeyVersion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - opts := WithKeyVersion(tt.v) - m := make(map[string]any) + opt := WithKeyVersion(tt.v) - opts(m) + req := &TransitRequestOptions{ + Params: make(map[string]any), + Headers: make(http.Header), + } - require.Contains(t, m, "key_version") - assert.Equal(t, tt.v, m["key_version"]) + opt(req) + + require.Contains(t, req.Params, "key_version") + assert.Equal(t, tt.v, req.Params["key_version"]) + }) + } +} + +func TestWithNamespace(t *testing.T) { + tests := []struct { + name string + namespace string + wantSet bool + }{ + { + name: "sets namespace header", + namespace: "foo/namespace", + wantSet: true, + }, + { + name: "empty namespace does nothing", + namespace: "", + wantSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opt := WithNamespace(tt.namespace) + + req := &TransitRequestOptions{ + Params: make(map[string]any), + Headers: make(http.Header), + } + + opt(req) + + got := req.Headers.Get("X-Vault-Namespace") + if tt.wantSet { + require.Equal(t, tt.namespace, got) + } else { + require.Equal(t, "", got) + } }) } }