Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions vault/transit.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/base64"
"fmt"
"net/http"

"k8s.io/apimachinery/pkg/util/json"
)
Expand All @@ -22,25 +23,32 @@ type (
Plaintext string `json:"plaintext"`
}

// TransitOption modifies parameters for a Transit Encrypt/Decrypt request.
// Individual options may only apply to certain operations.
TransitOption func(m map[string]any)
TransitRequestOptions struct {
Params map[string]any
Headers http.Header
}

// TransitOption modifies parameters and/or headers for a Transit request.
TransitOption func(*TransitRequestOptions)
)

// EncryptWithTransit encrypts data using Vault Transit.
func EncryptWithTransit(ctx context.Context, vaultClient Client, mount, key string, data []byte, opts ...TransitOption) ([]byte, error) {
path := fmt.Sprintf("%s/encrypt/%s", mount, key)

params := map[string]any{
"name": key,
"plaintext": base64.StdEncoding.EncodeToString(data),
req := &TransitRequestOptions{
Params: map[string]any{
"name": key,
"plaintext": base64.StdEncoding.EncodeToString(data),
},
Headers: make(http.Header),
}

for _, opt := range opts {
opt(params)
opt(req)
}

resp, err := vaultClient.Write(ctx, NewWriteRequest(path, params, nil))
resp, err := vaultClient.Write(ctx, NewWriteRequest(path, req.Params, req.Headers))
if err != nil {
return nil, err
}
Expand All @@ -63,14 +71,21 @@ func DecryptWithTransit(ctx context.Context, vaultClient Client, mount, key stri
}

// DecryptCiphertextWithTransit decrypts a ciphertext value using Vault Transit.
func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount, key, ciphertext string) ([]byte, error) {
func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount, key, ciphertext string, opts ...TransitOption) ([]byte, error) {
path := fmt.Sprintf("%s/decrypt/%s", mount, key)
params := map[string]interface{}{
"name": key,
"ciphertext": ciphertext,
req := &TransitRequestOptions{
Params: map[string]interface{}{
"name": key,
"ciphertext": ciphertext,
},
Headers: make(http.Header),
}

for _, opt := range opts {
opt(req)
}

resp, err := vaultClient.Write(ctx, NewWriteRequest(path, params, nil))
resp, err := vaultClient.Write(ctx, NewWriteRequest(path, req.Params, req.Headers))
if err != nil {
return nil, err
}
Expand All @@ -95,5 +110,16 @@ func DecryptCiphertextWithTransit(ctx context.Context, vaultClient Client, mount
// WithKeyVersion sets the key version for EncryptWithTransit.
// It is ignored when passed to DecryptWithTransit.
func WithKeyVersion(v uint) TransitOption {
return func(m map[string]any) { m["key_version"] = v }
return func(opt *TransitRequestOptions) {
opt.Params["key_version"] = v
}
}

func WithNamespace(namespace string) TransitOption {
return func(opt *TransitRequestOptions) {
if namespace == "" {
return
}
opt.Headers.Set("X-Vault-Namespace", namespace)
}
}
54 changes: 49 additions & 5 deletions vault/transit_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vault

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -28,13 +29,56 @@ func TestWithKeyVersion(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := WithKeyVersion(tt.v)
m := make(map[string]any)
opt := WithKeyVersion(tt.v)

opts(m)
req := &TransitRequestOptions{
Params: make(map[string]any),
Headers: make(http.Header),
}

require.Contains(t, m, "key_version")
assert.Equal(t, tt.v, m["key_version"])
opt(req)

require.Contains(t, req.Params, "key_version")
assert.Equal(t, tt.v, req.Params["key_version"])
})
}
}

func TestWithNamespace(t *testing.T) {
tests := []struct {
name string
namespace string
wantSet bool
}{
{
name: "sets namespace header",
namespace: "foo/namespace",
wantSet: true,
},
{
name: "empty namespace does nothing",
namespace: "",
wantSet: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opt := WithNamespace(tt.namespace)

req := &TransitRequestOptions{
Params: make(map[string]any),
Headers: make(http.Header),
}

opt(req)

got := req.Headers.Get("X-Vault-Namespace")
if tt.wantSet {
require.Equal(t, tt.namespace, got)
} else {
require.Equal(t, "", got)
}
})
}
}
Loading