From af22e226531a62d17f15bdff76c2cb241b1902cc Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 12 Dec 2025 16:33:45 -0800 Subject: [PATCH 01/27] Platform KMS (WIP) --- kms/apiv1/options.go | 25 ++++++++++++ kms/platform/kms.go | 76 +++++++++++++++++++++++++++++++++++++ kms/platform/kms_darwin.go | 33 ++++++++++++++++ kms/platform/kms_other.go | 13 +++++++ kms/platform/kms_tpm.go | 41 ++++++++++++++++++++ kms/platform/kms_windows.go | 38 +++++++++++++++++++ 6 files changed, 226 insertions(+) create mode 100644 kms/platform/kms.go create mode 100644 kms/platform/kms_darwin.go create mode 100644 kms/platform/kms_other.go create mode 100644 kms/platform/kms_tpm.go create mode 100644 kms/platform/kms_windows.go diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index 3b50b942..667d2fb5 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -18,6 +18,17 @@ type KeyManager interface { Close() error } +// KeyDeleter is an optional interface for KMS implementations that support +// deleting keys. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +type KeyDeleter interface { + DeleteKey(req *DeleteKeyRequest) error +} + // SearchableKeyManager is an optional interface for KMS implementations // that support searching for keys based on certain attributes. // @@ -54,6 +65,17 @@ type CertificateChainManager interface { StoreCertificateChain(req *StoreCertificateChainRequest) error } +// CertificateDeleter is an optional interface for KMS implementations that +// support deleting certificates. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +type CertificateDeleter interface { + DeleteCertificate(req *DeleteCertificateRequest) error +} + // NameValidator is an interface that KeyManager can implement to validate a // given name or URI. type NameValidator interface { @@ -151,6 +173,9 @@ const ( TPMKMS Type = "tpmkms" // MacKMS is the KMS implementation using macOS Keychain and Secure Enclave. MacKMS Type = "mackms" + // PlatformKMS is the KMS implementation that uses TPMKMS on Windows and + // Linux and MacKMS on macOS.. + PlatformKMS Type = "kms" ) // TypeOf returns the type of of the given uri. diff --git a/kms/platform/kms.go b/kms/platform/kms.go new file mode 100644 index 00000000..3015d09a --- /dev/null +++ b/kms/platform/kms.go @@ -0,0 +1,76 @@ +package platform + +import ( + "context" + "crypto" + "crypto/x509" + + "go.step.sm/crypto/kms/apiv1" +) + +const Scheme = "kms" + +func init() { + apiv1.Register(apiv1.PlatformKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) { + return New(ctx, opts) + }) +} + +type extendedKeyManager interface { + apiv1.KeyManager + apiv1.KeyDeleter + apiv1.CertificateManager + apiv1.CertificateChainManager +} + +var _ apiv1.KeyManager = (*KMS)(nil) +var _ apiv1.CertificateManager = (*KMS)(nil) +var _ apiv1.CertificateChainManager = (*KMS)(nil) + +type KMS struct { + backend extendedKeyManager +} + +func New(ctx context.Context, opts apiv1.Options) (*KMS, error) { + return newKMS(ctx, opts) +} + +func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { + return k.backend.GetPublicKey(req) +} + +func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + return k.backend.CreateKey(req) +} + +func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + return k.backend.CreateSigner(req) +} + +func (k *KMS) Close() error { + return k.backend.Close() +} + +func (k *KMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { + return k.backend.DeleteKey(req) +} + +func (k *KMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { + return k.backend.LoadCertificate(req) +} + +func (k *KMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { + return k.backend.StoreCertificate(req) +} + +func (k *KMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { + return k.backend.LoadCertificateChain(req) +} + +func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { + if km, ok := k.backend.(apiv1.CertificateChainManager); ok { + return km.StoreCertificateChain(req) + } + + return apiv1.NotImplementedError{} +} diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go new file mode 100644 index 00000000..cda75366 --- /dev/null +++ b/kms/platform/kms_darwin.go @@ -0,0 +1,33 @@ +package platform + +import ( + "context" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/mackms" +) + +var _ apiv1.SearchableKeyManager = (*KMS)(nil) + +func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.Type == apiv1.TPMKMS { + return newTPMKMS(ctx, opts) + } + + km, err := mackms.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + backend: km, + }, nil +} + +func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { + if km, ok := k.backend.(apiv1.SearchableKeyManager); ok { + return km.SearchKeys(req) + } + + return nil, apiv1.NotImplementedError{} +} diff --git a/kms/platform/kms_other.go b/kms/platform/kms_other.go new file mode 100644 index 00000000..fb844bad --- /dev/null +++ b/kms/platform/kms_other.go @@ -0,0 +1,13 @@ +//go:build !darwin && !windows + +package platform + +import ( + "context" + + "go.step.sm/crypto/kms/apiv1" +) + +func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + return newTPMKMS(ctx, opts) +} diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go new file mode 100644 index 00000000..bca38f2f --- /dev/null +++ b/kms/platform/kms_tpm.go @@ -0,0 +1,41 @@ +package platform + +import ( + "context" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/tpmkms" + "go.step.sm/crypto/tpm" +) + +var _ apiv1.Attester = (*KMS)(nil) + +func newTPMKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + km, err := tpmkms.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + backend: km, + }, nil +} + +func NewWithTPM(ctx context.Context, t *tpm.TPM, opts ...tpmkms.Option) (*KMS, error) { + km, err := tpmkms.NewWithTPM(ctx, t, opts...) + if err != nil { + return nil, err + } + + return &KMS{ + backend: km, + }, nil +} + +func (k *KMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.CreateAttestationResponse, error) { + if km, ok := k.backend.(apiv1.Attester); ok { + return km.CreateAttestation(req) + } + + return nil, apiv1.NotImplementedError{} +} diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go new file mode 100644 index 00000000..3de31245 --- /dev/null +++ b/kms/platform/kms_windows.go @@ -0,0 +1,38 @@ +//go:build windows + +package platform + +import ( + "context" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/capi" + "go.step.sm/crypto/kms/uri" +) + +func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.Type == apiv1.CAPIKMS { + km, err := capi.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + backend: km, + }, nil + } + + if opts.URI != "" { + u, err := uri.Parse(opts.URI) + if err != nil { + return nil, err + } + + if !u.Has("enable-cng") { + u.Values.Set("enable-cng", "true") + } + opts.URI = u.String() + } + + return newTPMKMS(ctx, opts) +} From d0d6ca3ac9ec9175e7d7b9559af4936fe73b77e2 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 8 Jan 2026 13:05:04 -0800 Subject: [PATCH 02/27] wip --- kms/apiv1/options.go | 2 +- kms/platform/kms.go | 130 ++++++++++++++++++++++++++++++++++-- kms/platform/kms_darwin.go | 42 ++++++++++-- kms/platform/kms_tpm.go | 34 +++++++++- kms/platform/kms_windows.go | 59 ++++++++++------ 5 files changed, 234 insertions(+), 33 deletions(-) diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index 667d2fb5..b557b812 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -206,7 +206,7 @@ func (t Type) Validate() error { return nil case YubiKey, PKCS11, TPMKMS: // Hardware based kms. return nil - case SSHAgentKMS, CAPIKMS, MacKMS: // Others + case SSHAgentKMS, CAPIKMS, MacKMS, PlatformKMS: // Others return nil } diff --git a/kms/platform/kms.go b/kms/platform/kms.go index 3015d09a..ef7237e5 100644 --- a/kms/platform/kms.go +++ b/kms/platform/kms.go @@ -4,8 +4,11 @@ import ( "context" "crypto" "crypto/x509" + "net/url" + "strings" "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" ) const Scheme = "kms" @@ -16,6 +19,42 @@ func init() { }) } +const ( + backendKey = "backend" + nameKey = "name" + hwKey = "hw" +) + +type kmsURI struct { + uri *uri.URI + backend apiv1.Type + name string + hw bool + extraValues url.Values +} + +func parseURI(rawuri string) (*kmsURI, error) { + u, err := uri.ParseWithScheme(Scheme, rawuri) + if err != nil { + return nil, err + } + + extraValues := make(url.Values) + for k, v := range u.Values { + if k != nameKey && k != hwKey && k != backendKey { + extraValues[k] = v + } + } + + return &kmsURI{ + uri: u, + backend: apiv1.Type(strings.ToLower(u.Get(backendKey))), + name: u.Get(nameKey), + hw: u.GetBool(hwKey), + extraValues: extraValues, + }, nil +} + type extendedKeyManager interface { apiv1.KeyManager apiv1.KeyDeleter @@ -28,7 +67,8 @@ var _ apiv1.CertificateManager = (*KMS)(nil) var _ apiv1.CertificateChainManager = (*KMS)(nil) type KMS struct { - backend extendedKeyManager + backend extendedKeyManager + transformURI func(*kmsURI) string } func New(ctx context.Context, opts apiv1.Options) (*KMS, error) { @@ -36,14 +76,34 @@ func New(ctx context.Context, opts apiv1.Options) (*KMS, error) { } func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { - return k.backend.GetPublicKey(req) + name, err := k.transform(req.Name) + if err != nil { + return nil, err + } + return k.backend.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: name, + }) } func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + name, err := k.transform(req.Name) + if err != nil { + return nil, err + } + + req = clone(req) + req.Name = name return k.backend.CreateKey(req) } func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + signingKey, err := k.transform(req.SigningKey) + if err != nil { + return nil, err + } + + req = clone(req) + req.SigningKey = signingKey return k.backend.CreateSigner(req) } @@ -52,25 +112,85 @@ func (k *KMS) Close() error { } func (k *KMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { + name, err := k.transform(req.Name) + if err != nil { + return err + } + + req = clone(req) + req.Name = name return k.backend.DeleteKey(req) } func (k *KMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { + name, err := k.transform(req.Name) + if err != nil { + return nil, err + } + + req = clone(req) + req.Name = name return k.backend.LoadCertificate(req) } func (k *KMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { + name, err := k.transform(req.Name) + if err != nil { + return err + } + + req = clone(req) + req.Name = name return k.backend.StoreCertificate(req) } func (k *KMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { + name, err := k.transform(req.Name) + if err != nil { + return nil, err + } + + req = clone(req) + req.Name = name return k.backend.LoadCertificateChain(req) } func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { - if km, ok := k.backend.(apiv1.CertificateChainManager); ok { - return km.StoreCertificateChain(req) + name, err := k.transform(req.Name) + if err != nil { + return err } - return apiv1.NotImplementedError{} + req = clone(req) + req.Name = name + return k.backend.StoreCertificateChain(req) +} + +func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { + if km, ok := k.backend.(apiv1.SearchableKeyManager); ok { + query, err := k.transform(req.Query) + if err != nil { + return nil, err + } + + req = clone(req) + req.Query = query + return km.SearchKeys(req) + } + + return nil, apiv1.NotImplementedError{} +} + +func (k *KMS) transform(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + + return k.transformURI(u), nil +} + +func clone[T any](v *T) *T { + c := *v + return &c } diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go index cda75366..fb747af1 100644 --- a/kms/platform/kms_darwin.go +++ b/kms/platform/kms_darwin.go @@ -2,32 +2,62 @@ package platform import ( "context" + "fmt" + "net/url" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/mackms" + "go.step.sm/crypto/kms/uri" ) var _ apiv1.SearchableKeyManager = (*KMS)(nil) func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { - if opts.Type == apiv1.TPMKMS { + if opts.URI == "" { + return newMacKMS(ctx, opts) + } + + u, err := parseURI(opts.URI) + if err != nil { + return nil, err + } + + switch u.backend { + case apiv1.TPMKMS: return newTPMKMS(ctx, opts) + case apiv1.DefaultKMS, apiv1.MacKMS: + opts.URI = transformToMacKMS(u) + return newMacKMS(ctx, opts) + default: + return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend) } +} +func newMacKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { km, err := mackms.New(ctx, opts) if err != nil { return nil, err } return &KMS{ - backend: km, + backend: km, + transformURI: transformToMacKMS, }, nil } -func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { - if km, ok := k.backend.(apiv1.SearchableKeyManager); ok { - return km.SearchKeys(req) +func transformToMacKMS(u *kmsURI) string { + uv := url.Values{ + "label": []string{u.name}, + } + if u.hw { + uv.Set("se", "true") + uv.Set("keychain", "dataProtection") + } + + // Add custom extra values that might be mackms specific. + for k, v := range u.extraValues { + uv[k] = v } - return nil, apiv1.NotImplementedError{} + return uri.New(mackms.Scheme, uv).String() } diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go index bca38f2f..54a3b374 100644 --- a/kms/platform/kms_tpm.go +++ b/kms/platform/kms_tpm.go @@ -2,22 +2,35 @@ package platform import ( "context" + "net/url" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/tpmkms" + "go.step.sm/crypto/kms/uri" "go.step.sm/crypto/tpm" ) var _ apiv1.Attester = (*KMS)(nil) func newTPMKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI == "" { + return newTPMKMS(ctx, opts) + } + + u, err := parseURI(opts.URI) + if err != nil { + return nil, err + } + + opts.URI = transformToTPMKMS(u) km, err := tpmkms.New(ctx, opts) if err != nil { return nil, err } return &KMS{ - backend: km, + backend: km, + transformURI: transformToTPMKMS, }, nil } @@ -28,7 +41,8 @@ func NewWithTPM(ctx context.Context, t *tpm.TPM, opts ...tpmkms.Option) (*KMS, e } return &KMS{ - backend: km, + backend: km, + transformURI: transformToTPMKMS, }, nil } @@ -39,3 +53,19 @@ func (k *KMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.Cre return nil, apiv1.NotImplementedError{} } + +func transformToTPMKMS(u *kmsURI) string { + uv := url.Values{ + "name": []string{u.name}, + } + if u.hw { + uv.Set("ak", "true") + } + + // Add custom extra values that might be tpmkms specific. + for k, v := range u.extraValues { + uv[k] = v + } + + return uri.New(tpmkms.Scheme, uv).String() +} diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index 3de31245..82854d2e 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -4,6 +4,8 @@ package platform import ( "context" + "fmt" + "net/url" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/capi" @@ -11,28 +13,47 @@ import ( ) func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { - if opts.Type == apiv1.CAPIKMS { - km, err := capi.New(ctx, opts) - if err != nil { - return nil, err - } - - return &KMS{ - backend: km, - }, nil + if opts.URI == "" { + return newTPMKMS(ctx, opts) } - if opts.URI != "" { - u, err := uri.Parse(opts.URI) - if err != nil { - return nil, err - } + u, err := parseURI(opts.URI) + if err != nil { + return nil, err + } + + switch u.backend { + case apiv1.CAPIKMS: + opts.URI = transformToCapiKMS(u) + return newCAPIKMS(ctx, opts) + case apiv1.DefaultKMS, apiv1.TPMKMS: + return newTPMKMS(ctx, opts) + default: + return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend) + } +} + +func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + km, err := capi.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + backend: km, + transformURI: transformToCapiKMS, + }, nil +} + +func transformToCapiKMS(u *kmsURI) string { + uv := url.Values{ + "key": []string{u.name}, + } - if !u.Has("enable-cng") { - u.Values.Set("enable-cng", "true") - } - opts.URI = u.String() + // Add custom extra values that might be tpmkms specific. + for k, v := range u.extraValues { + uv[k] = v } - return newTPMKMS(ctx, opts) + return uri.New(capi.Scheme, uv).String() } From c453374f211444fcba17495b42ad27a0bfef6246 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 21 Jan 2026 15:27:11 -0800 Subject: [PATCH 03/27] Fix transformation of URIs for search methods --- kms/platform/kms_darwin.go | 10 +++++----- kms/platform/kms_tpm.go | 10 +++++----- kms/platform/kms_windows.go | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go index fb747af1..c10887d8 100644 --- a/kms/platform/kms_darwin.go +++ b/kms/platform/kms_darwin.go @@ -3,6 +3,7 @@ package platform import ( "context" "fmt" + "maps" "net/url" "go.step.sm/crypto/kms/apiv1" @@ -46,8 +47,9 @@ func newMacKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { } func transformToMacKMS(u *kmsURI) string { - uv := url.Values{ - "label": []string{u.name}, + uv := url.Values{} + if u.name != "" { + uv.Set("label", u.name) } if u.hw { uv.Set("se", "true") @@ -55,9 +57,7 @@ func transformToMacKMS(u *kmsURI) string { } // Add custom extra values that might be mackms specific. - for k, v := range u.extraValues { - uv[k] = v - } + maps.Copy(uv, u.extraValues) return uri.New(mackms.Scheme, uv).String() } diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go index 54a3b374..bfdf6f7a 100644 --- a/kms/platform/kms_tpm.go +++ b/kms/platform/kms_tpm.go @@ -2,6 +2,7 @@ package platform import ( "context" + "maps" "net/url" "go.step.sm/crypto/kms/apiv1" @@ -55,17 +56,16 @@ func (k *KMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.Cre } func transformToTPMKMS(u *kmsURI) string { - uv := url.Values{ - "name": []string{u.name}, + uv := url.Values{} + if u.name != "" { + uv.Set("name", u.name) } if u.hw { uv.Set("ak", "true") } // Add custom extra values that might be tpmkms specific. - for k, v := range u.extraValues { - uv[k] = v - } + maps.Copy(uv, u.extraValues) return uri.New(tpmkms.Scheme, uv).String() } diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index 82854d2e..2c6beed6 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -5,6 +5,7 @@ package platform import ( "context" "fmt" + "maps" "net/url" "go.step.sm/crypto/kms/apiv1" @@ -46,14 +47,13 @@ func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { } func transformToCapiKMS(u *kmsURI) string { - uv := url.Values{ - "key": []string{u.name}, + uv := url.Values{} + if u.name != "" { + uv.Set("key", u.name) } - // Add custom extra values that might be tpmkms specific. - for k, v := range u.extraValues { - uv[k] = v - } + // Add custom extra values that might be CAPI specific. + maps.Copy(uv, u.extraValues) return uri.New(capi.Scheme, uv).String() } From fc79c123a8cd9c5f0c5f271c9bbb7ec7ba843228 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 29 Jan 2026 17:37:54 -0800 Subject: [PATCH 04/27] Add LoadCertificate and LoadCertificateChain to softkms --- kms/softkms/softkms.go | 26 +++++++++++++ kms/softkms/softkms_test.go | 68 +++++++++++++++++++++++++++++++++- kms/softkms/testdata/chain.crt | 12 ++++++ kms/softkms/testdata/chain.key | 5 +++ 4 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 kms/softkms/testdata/chain.crt create mode 100644 kms/softkms/testdata/chain.key diff --git a/kms/softkms/softkms.go b/kms/softkms/softkms.go index 46a73ebe..fb8ad78a 100644 --- a/kms/softkms/softkms.go +++ b/kms/softkms/softkms.go @@ -7,6 +7,7 @@ import ( "crypto/ed25519" "crypto/rsa" "crypto/x509" + "fmt" "github.com/pkg/errors" "go.step.sm/crypto/keyutil" @@ -190,6 +191,31 @@ func (k *SoftKMS) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Dec } } +// LoadCertificate returns a x509.Certificate from the file passed in the +// request name. +func (k *SoftKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { + if req.Name == "" { + return nil, fmt.Errorf("loadCertificateRequest 'name' cannot be empty") + } + + bundle, err := pemutil.ReadCertificateBundle(filename(req.Name)) + if err != nil { + return nil, err + } + + return bundle[0], nil +} + +// LoadCertificateChain returns a slice of x509.Certificate from the file passed +// in the request name. +func (k *SoftKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { + if req.Name == "" { + return nil, fmt.Errorf("loadCertificateChainRequest 'name' cannot be empty") + } + + return pemutil.ReadCertificateBundle(filename(req.Name)) +} + func filename(s string) string { if u, err := uri.ParseWithScheme(Scheme, s); err == nil { if f := u.Get("path"); f != "" { diff --git a/kms/softkms/softkms_test.go b/kms/softkms/softkms_test.go index 11aa4ddd..dc86ee67 100644 --- a/kms/softkms/softkms_test.go +++ b/kms/softkms/softkms_test.go @@ -17,7 +17,7 @@ import ( "testing" "github.com/stretchr/testify/assert" - + "github.com/stretchr/testify/require" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x25519" @@ -439,3 +439,69 @@ func Test_filename(t *testing.T) { }) } } + +func TestSoftKMS_LoadCertificate(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/cert.crt") + require.NoError(t, err) + + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + k *SoftKMS + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "testdata/cert.crt"}}, cert, assert.NoError}, + {"ok uri", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "testdata/cert.crt"}}, cert, assert.NoError}, + {"ok uri with path", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "softkms:path=testdata/cert.crt"}}, cert, assert.NoError}, + {"fail empty", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{}}, nil, assert.Error}, + {"fail missing", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "testdata/missing.crt"}}, nil, assert.Error}, + {"fail not a certificate", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "testdata/cert.key"}}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &SoftKMS{} + got, err := k.LoadCertificate(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestSoftKMS_LoadCertificateChain(t *testing.T) { + chain, err := pemutil.ReadCertificateBundle("testdata/chain.crt") + require.NoError(t, err) + + cert, err := pemutil.ReadCertificate("testdata/cert.crt") + require.NoError(t, err) + + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + k *SoftKMS + args args + want []*x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "testdata/chain.crt"}}, chain, assert.NoError}, + {"ok uri", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "testdata/chain.crt"}}, chain, assert.NoError}, + {"ok uri with path", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "softkms:path=testdata/chain.crt"}}, chain, assert.NoError}, + {"ok cert", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "softkms:testdata/cert.crt"}}, []*x509.Certificate{cert}, assert.NoError}, + {"fail empty", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{}}, nil, assert.Error}, + {"fail missing", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "testdata/missing.crt"}}, nil, assert.Error}, + {"fail not a certificate", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "testdata/cert.key"}}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &SoftKMS{} + got, err := k.LoadCertificateChain(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/softkms/testdata/chain.crt b/kms/softkms/testdata/chain.crt new file mode 100644 index 00000000..4289f3e3 --- /dev/null +++ b/kms/softkms/testdata/chain.crt @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIB0TCCAXegAwIBAgIQLnp4754BEk4JSoK7md7K4DAKBggqhkjOPQQDAjAkMSIw +IAYDVQQDExlTbWFsbHN0ZXAgSW50ZXJtZWRpYXRlIENBMB4XDTI2MDEzMDAxMzM0 +M1oXDTI2MDEzMTAxMzM0M1owHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29t +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIPeCsTFXBmkeJlX4dC6jo+2oe3US +m3Yt1LfJRrV2cBNB5t+OQqKZNajirkBCv/IqUBFJILtrNAZ8tSNwTtCTeKOBkTCB +jjAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMC +MB0GA1UdDgQWBBRv+K/kt+Ho6CDVpfvl3YRnbqxBjzAfBgNVHSMEGDAWgBQOlOB3 +dUUKfp7vXBdzBsMJREGiZTAdBgNVHREEFjAUghJ0ZXN0LnNtYWxsc3RlcC5jb20w +CgYIKoZIzj0EAwIDSAAwRQIhAKzVQCITpTCITRhQ/VWrUHqhYuD+Q/a7yUrM1gix +ZJV3AiAgRBrPNWYAIK24hjVWE21OPaJnSZ7Q7VRJNE0/vQzrZQ== +-----END CERTIFICATE----- diff --git a/kms/softkms/testdata/chain.key b/kms/softkms/testdata/chain.key new file mode 100644 index 00000000..b9b8e11c --- /dev/null +++ b/kms/softkms/testdata/chain.key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIKDTRUVuCrQVHjAlUjaDkmfEayOiIkLk1SZPU7MPMhyxoAoGCCqGSM49 +AwEHoUQDQgAEIPeCsTFXBmkeJlX4dC6jo+2oe3USm3Yt1LfJRrV2cBNB5t+OQqKZ +NajirkBCv/IqUBFJILtrNAZ8tSNwTtCTeA== +-----END EC PRIVATE KEY----- From 5c565357e83a8ff963e4d60911f49528faed13b2 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 29 Jan 2026 17:39:05 -0800 Subject: [PATCH 05/27] Add softkms as a backend for platformkms --- kms/platform/kms.go | 12 ++++ kms/platform/kms_darwin.go | 2 + kms/platform/kms_other.go | 19 +++++- kms/platform/kms_softkms.go | 113 ++++++++++++++++++++++++++++++++++++ kms/platform/kms_windows.go | 2 + 5 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 kms/platform/kms_softkms.go diff --git a/kms/platform/kms.go b/kms/platform/kms.go index ef7237e5..a6b209fb 100644 --- a/kms/platform/kms.go +++ b/kms/platform/kms.go @@ -60,6 +60,7 @@ type extendedKeyManager interface { apiv1.KeyDeleter apiv1.CertificateManager apiv1.CertificateChainManager + apiv1.CertificateDeleter } var _ apiv1.KeyManager = (*KMS)(nil) @@ -166,6 +167,17 @@ func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) err return k.backend.StoreCertificateChain(req) } +func (k *KMS) DeleteCertificater(req *apiv1.DeleteCertificateRequest) error { + name, err := k.transform(req.Name) + if err != nil { + return err + } + + req = clone(req) + req.Name = name + return k.backend.DeleteCertificate(req) +} + func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { if km, ok := k.backend.(apiv1.SearchableKeyManager); ok { query, err := k.transform(req.Query) diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go index c10887d8..32fdb981 100644 --- a/kms/platform/kms_darwin.go +++ b/kms/platform/kms_darwin.go @@ -26,6 +26,8 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { switch u.backend { case apiv1.TPMKMS: return newTPMKMS(ctx, opts) + case apiv1.SoftKMS: + return newSoftKMS(ctx, opts) case apiv1.DefaultKMS, apiv1.MacKMS: opts.URI = transformToMacKMS(u) return newMacKMS(ctx, opts) diff --git a/kms/platform/kms_other.go b/kms/platform/kms_other.go index fb844bad..9901783f 100644 --- a/kms/platform/kms_other.go +++ b/kms/platform/kms_other.go @@ -4,10 +4,27 @@ package platform import ( "context" + "fmt" "go.step.sm/crypto/kms/apiv1" ) func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { - return newTPMKMS(ctx, opts) + if opts.URI == "" { + return newTPMKMS(ctx, opts) + } + + u, err := parseURI(opts.URI) + if err != nil { + return nil, err + } + + switch u.backend { + case apiv1.SoftKMS: + return newSoftKMS(ctx, opts) + case apiv1.DefaultKMS, apiv1.TPMKMS: + return newTPMKMS(ctx, opts) + default: + return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend) + } } diff --git a/kms/platform/kms_softkms.go b/kms/platform/kms_softkms.go new file mode 100644 index 00000000..0d6e4975 --- /dev/null +++ b/kms/platform/kms_softkms.go @@ -0,0 +1,113 @@ +package platform + +import ( + "bytes" + "context" + "encoding/pem" + "fmt" + "os" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/softkms" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/pemutil" +) + +func newSoftKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + km, err := softkms.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + backend: &softKMS{SoftKMS: km}, + transformURI: transformToSoftKMS, + }, nil +} + +type softKMS struct { + *softkms.SoftKMS +} + +func (k *softKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + resp, err := k.SoftKMS.CreateKey(req) + if err != nil { + return nil, err + } + + if _, err := pemutil.Serialize(resp.PrivateKey, pemutil.ToFile(resp.Name, 0o600)); err != nil { + return nil, err + } + + return resp, nil +} + +func (k *softKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { + if req.Name == "" { + return fmt.Errorf("deleteKeyRequest 'name' cannot be empty") + } + + return os.Remove(filename(req.Name)) +} + +func (k *softKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { + switch { + case req.Name == "": + return fmt.Errorf("storeCertificateRequest 'name' cannot be empty") + case req.Certificate == nil: + return fmt.Errorf("storeCertificateRequest 'certificate' cannot be empty") + } + + return os.WriteFile(filename(req.Name), pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: req.Certificate.Raw, + }), 0o600) +} + +func (k *softKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { + switch { + case req.Name == "": + return fmt.Errorf("storeCertificateChainRequest 'name' cannot be empty") + case len(req.CertificateChain) == 0: + return fmt.Errorf("storeCertificateChainRequest 'certificateChain' cannot be empty") + } + + var buf bytes.Buffer + for _, crt := range req.CertificateChain { + if err := pem.Encode(&buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: crt.Raw, + }); err != nil { + return err + } + } + + return os.WriteFile(filename(req.Name), buf.Bytes(), 0o600) +} + +func (k *softKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { + if req.Name == "" { + return fmt.Errorf("deleteCertificateRequest 'name' cannot be empty") + } + + return os.Remove(filename(req.Name)) +} + +func filename(s string) string { + if u, err := uri.ParseWithScheme(Scheme, s); err == nil { + if f := u.Get("path"); f != "" { + return f + } + switch { + case u.Path != "": + return u.Path + default: + return u.Opaque + } + } + return s +} + +func transformToSoftKMS(u *kmsURI) string { + return uri.NewOpaque(softkms.Scheme, u.name).String() +} diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index 2c6beed6..30260f18 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -27,6 +27,8 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { case apiv1.CAPIKMS: opts.URI = transformToCapiKMS(u) return newCAPIKMS(ctx, opts) + case apiv1.SoftKMS: + return newSoftKMS(ctx, opts) case apiv1.DefaultKMS, apiv1.TPMKMS: return newTPMKMS(ctx, opts) default: From 741a53fcd168e299ff063479824f9edc9e75e07d Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 11 Feb 2026 11:53:28 -0800 Subject: [PATCH 06/27] Fix typo --- kms/platform/kms.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kms/platform/kms.go b/kms/platform/kms.go index a6b209fb..a88162b3 100644 --- a/kms/platform/kms.go +++ b/kms/platform/kms.go @@ -167,7 +167,7 @@ func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) err return k.backend.StoreCertificateChain(req) } -func (k *KMS) DeleteCertificater(req *apiv1.DeleteCertificateRequest) error { +func (k *KMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { name, err := k.transform(req.Name) if err != nil { return err From 6d85f02a267a08555e88a4dfb529332888e92f25 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 17 Feb 2026 17:18:11 -0800 Subject: [PATCH 07/27] Add methods to convert URIs --- kms/platform/kms.go | 68 ++++++- kms/platform/kms_darwin.go | 28 ++- kms/platform/kms_softkms.go | 22 ++- kms/platform/kms_test.go | 356 ++++++++++++++++++++++++++++++++++++ kms/platform/kms_tpm.go | 32 +++- kms/platform/kms_windows.go | 25 ++- 6 files changed, 509 insertions(+), 22 deletions(-) create mode 100644 kms/platform/kms_test.go diff --git a/kms/platform/kms.go b/kms/platform/kms.go index a88162b3..68a8e538 100644 --- a/kms/platform/kms.go +++ b/kms/platform/kms.go @@ -68,14 +68,24 @@ var _ apiv1.CertificateManager = (*KMS)(nil) var _ apiv1.CertificateChainManager = (*KMS)(nil) type KMS struct { - backend extendedKeyManager - transformURI func(*kmsURI) string + typ apiv1.Type + backend extendedKeyManager + transformToURI func(*kmsURI) string + transformFromURI func(string) (string, error) } func New(ctx context.Context, opts apiv1.Options) (*KMS, error) { return newKMS(ctx, opts) } +func (k *KMS) Type() apiv1.Type { + return k.typ +} + +func (k *KMS) Close() error { + return k.backend.Close() +} + func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { name, err := k.transform(req.Name) if err != nil { @@ -94,10 +104,19 @@ func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, req = clone(req) req.Name = name - return k.backend.CreateKey(req) + resp, err := k.backend.CreateKey(req) + if err != nil { + return nil, err + } + + return k.patchCreateKeyResponse(resp) } func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + if req.Signer != nil { + return req.Signer, nil + } + signingKey, err := k.transform(req.SigningKey) if err != nil { return nil, err @@ -108,10 +127,6 @@ func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error return k.backend.CreateSigner(req) } -func (k *KMS) Close() error { - return k.backend.Close() -} - func (k *KMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { name, err := k.transform(req.Name) if err != nil { @@ -187,7 +202,12 @@ func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysRespons req = clone(req) req.Query = query - return km.SearchKeys(req) + resp, err := km.SearchKeys(req) + if err != nil { + return nil, err + } + + return k.patchSearchKeysResponse(resp) } return nil, apiv1.NotImplementedError{} @@ -199,7 +219,37 @@ func (k *KMS) transform(rawuri string) (string, error) { return "", err } - return k.transformURI(u), nil + return k.transformToURI(u), nil +} + +func (k *KMS) patchCreateKeyResponse(resp *apiv1.CreateKeyResponse) (*apiv1.CreateKeyResponse, error) { + name, err := k.transformFromURI(resp.Name) + if err != nil { + return nil, err + } + + resp.Name = name + if resp.CreateSignerRequest.SigningKey != "" { + resp.CreateSignerRequest.SigningKey = name + } + + return resp, nil +} + +func (k *KMS) patchSearchKeysResponse(resp *apiv1.SearchKeysResponse) (*apiv1.SearchKeysResponse, error) { + for i := range resp.Results { + name, err := k.transformFromURI(resp.Results[i].Name) + if err != nil { + return nil, err + } + + resp.Results[i].Name = name + if resp.Results[i].CreateSignerRequest.SigningKey != "" { + resp.Results[i].CreateSignerRequest.SigningKey = name + } + } + + return resp, nil } func clone[T any](v *T) *T { diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go index 32fdb981..4ef36bc8 100644 --- a/kms/platform/kms_darwin.go +++ b/kms/platform/kms_darwin.go @@ -43,8 +43,10 @@ func newMacKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { } return &KMS{ - backend: km, - transformURI: transformToMacKMS, + typ: apiv1.MacKMS, + backend: km, + transformToURI: transformToMacKMS, + transformFromURI: transformFromMacKMS, }, nil } @@ -63,3 +65,25 @@ func transformToMacKMS(u *kmsURI) string { return uri.New(mackms.Scheme, uv).String() } + +func transformFromMacKMS(rawuri string) (string, error) { + u, err := uri.ParseWithScheme(mackms.Scheme, rawuri) + if err != nil { + return "", err + } + + uv := url.Values{ + "name": []string{u.Get("label")}, + } + if u.GetBool("se") { + uv.Set("hw", "true") + } + + for k, v := range u.Values { + if k != "label" && k != "se" { + uv[k] = v + } + } + + return uri.New(Scheme, uv).String(), nil +} diff --git a/kms/platform/kms_softkms.go b/kms/platform/kms_softkms.go index 0d6e4975..9dbe8483 100644 --- a/kms/platform/kms_softkms.go +++ b/kms/platform/kms_softkms.go @@ -5,6 +5,7 @@ import ( "context" "encoding/pem" "fmt" + "net/url" "os" "go.step.sm/crypto/kms/apiv1" @@ -20,8 +21,10 @@ func newSoftKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { } return &KMS{ - backend: &softKMS{SoftKMS: km}, - transformURI: transformToSoftKMS, + typ: apiv1.SoftKMS, + backend: &softKMS{SoftKMS: km}, + transformToURI: transformToSoftKMS, + transformFromURI: transformFromSoftKMS, }, nil } @@ -57,7 +60,7 @@ func (k *softKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { case req.Certificate == nil: return fmt.Errorf("storeCertificateRequest 'certificate' cannot be empty") } - + fmt.Println(filename(req.Name)) return os.WriteFile(filename(req.Name), pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: req.Certificate.Raw, @@ -94,7 +97,7 @@ func (k *softKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } func filename(s string) string { - if u, err := uri.ParseWithScheme(Scheme, s); err == nil { + if u, err := uri.ParseWithScheme(softkms.Scheme, s); err == nil { if f := u.Get("path"); f != "" { return f } @@ -109,5 +112,14 @@ func filename(s string) string { } func transformToSoftKMS(u *kmsURI) string { - return uri.NewOpaque(softkms.Scheme, u.name).String() + if u.name != "" { + return uri.NewOpaque(softkms.Scheme, u.name).String() + } + return uri.NewOpaque(softkms.Scheme, u.uri.Path).String() +} + +func transformFromSoftKMS(rawuri string) (string, error) { + return uri.New(Scheme, url.Values{ + "name": []string{rawuri}, + }).String(), nil } diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go new file mode 100644 index 00000000..fe363b6c --- /dev/null +++ b/kms/platform/kms_test.go @@ -0,0 +1,356 @@ +package platform + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/pemutil" +) + +func mustKMS(t *testing.T, rawuri string) *KMS { + t.Helper() + + km, err := New(t.Context(), apiv1.Options{ + URI: rawuri, + }) + require.NoError(t, err) + return km +} + +func mustSigner(t *testing.T, path string) crypto.Signer { + t.Helper() + + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + _, err = pemutil.Serialize(signer, pemutil.ToFile(path, 0o600)) + require.NoError(t, err) + + return signer +} + +func mustReadSigner(t *testing.T, path string) crypto.Signer { + t.Helper() + + k, err := pemutil.Read(path) + require.NoError(t, err) + + signer, ok := k.(crypto.Signer) + require.True(t, ok) + + return signer +} + +func mustCertificate(t *testing.T, path string) []*x509.Certificate { + t.Helper() + + ca, err := minica.New() + require.NoError(t, err) + + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + cert, err := ca.Sign(&x509.Certificate{ + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + PublicKey: signer.Public(), + DNSNames: []string{"example.com"}, + }) + require.NoError(t, err) + + var buf bytes.Buffer + require.NoError(t, pem.Encode(&buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })) + require.NoError(t, pem.Encode(&buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: ca.Intermediate.Raw, + })) + require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o600)) + + return []*x509.Certificate{ + cert, ca.Intermediate, + } +} + +func TestKMS_Type(t *testing.T) { + softKMS := mustKMS(t, "kms:backend=softkms") + assert.Equal(t, apiv1.SoftKMS, softKMS.Type()) +} + +func TestKMS_Close(t *testing.T) { + softKMS := mustKMS(t, "kms:backend=softkms") + assert.NoError(t, softKMS.Close()) +} + +func TestKMS_GetPublicKey(t *testing.T) { + dir := t.TempDir() + privateKeyPath := filepath.Join(dir, "private.key") + signer := mustSigner(t, privateKeyPath) + softKMS := mustKMS(t, "kms:backend=softkms") + + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + want crypto.PublicKey + assertion assert.ErrorAssertionFunc + }{ + {"ok SoftKMS", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=" + privateKeyPath, + }}, signer.Public(), assert.NoError}, + {"ok SoftKMS escape", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=" + url.QueryEscape(privateKeyPath), + }}, signer.Public(), assert.NoError}, + {"ok SoftKMS path", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:" + privateKeyPath, + }}, signer.Public(), assert.NoError}, + {"fail empty name", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "", + }}, nil, assert.Error}, + {"fail SoftKMS missing", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:" + filepath.Join(dir, "notfound.key"), + }}, nil, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "softkms:" + privateKeyPath, + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.GetPublicKey(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_CreateKey(t *testing.T) { + dir := t.TempDir() + privateKeyPath := filepath.Join(dir, "private.key") + softKMS := mustKMS(t, "kms:backend=softkms") + + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(t *testing.T, got *apiv1.CreateKeyResponse) + assertion assert.ErrorAssertionFunc + }{ + {"ok softKMS", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=" + privateKeyPath, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + signer := mustReadSigner(t, privateKeyPath) + assert.IsType(t, &ecdsa.PrivateKey{}, signer) + name := "kms:name=" + url.QueryEscape(privateKeyPath) + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: name, + PublicKey: signer.Public(), + PrivateKey: signer, + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: name, + }, + }) + }, assert.NoError}, + {"ok softKMS escape", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=" + url.QueryEscape(privateKeyPath), + SignatureAlgorithm: apiv1.SHA256WithRSA, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + signer := mustReadSigner(t, privateKeyPath) + if assert.IsType(t, &rsa.PrivateKey{}, signer) { + assert.Equal(t, 3072/8, signer.(*rsa.PrivateKey).Size()) + } + name := "kms:name=" + url.QueryEscape(privateKeyPath) + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: name, + PublicKey: signer.Public(), + PrivateKey: signer, + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: name, + }, + }) + }, assert.NoError}, + {"ok softKMS path", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:" + privateKeyPath, + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + signer := mustReadSigner(t, privateKeyPath) + name := "kms:name=" + url.QueryEscape(privateKeyPath) + if assert.IsType(t, &rsa.PrivateKey{}, signer) { + assert.Equal(t, 2048/8, signer.(*rsa.PrivateKey).Size()) + } + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: name, + PublicKey: signer.Public(), + PrivateKey: signer, + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: name, + }, + }) + }, assert.NoError}, + {"fail parseURI", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "softkms:" + privateKeyPath, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateKey(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_CreateSigner(t *testing.T) { + dir := t.TempDir() + softKMS := mustKMS(t, "kms:backend=softkms") + privateKeyPath := filepath.Join(dir, "private.key") + resp, err := softKMS.CreateKey(&apiv1.CreateKeyRequest{ + Name: "kms:name=" + url.QueryEscape(privateKeyPath), + }) + require.NoError(t, err) + signer := mustReadSigner(t, privateKeyPath) + + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + kms *KMS + args args + want crypto.Signer + assertion assert.ErrorAssertionFunc + }{ + {"ok softKMS", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=" + url.QueryEscape(privateKeyPath), + }}, signer, assert.NoError}, + {"ok softKMS with signer", softKMS, args{&apiv1.CreateSignerRequest{ + Signer: resp.CreateSignerRequest.Signer, + SigningKey: resp.CreateSignerRequest.SigningKey, + }}, signer, assert.NoError}, + {"fail missing", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=" + url.QueryEscape(filepath.Join(dir, "missing.key")), + }}, nil, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: privateKeyPath, + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateSigner(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_DeleteKey(t *testing.T) { + dir := t.TempDir() + softKMS := mustKMS(t, "kms:backend=softkms") + keyPath1 := filepath.Join(dir, "key1.key") + _, err := softKMS.CreateKey(&apiv1.CreateKeyRequest{ + Name: "kms:name=" + url.QueryEscape(keyPath1), + }) + require.NoError(t, err) + + keyPath2 := filepath.Join(dir, "key2.key") + _, err = softKMS.CreateKey(&apiv1.CreateKeyRequest{ + Name: "kms:name=" + keyPath2, + }) + require.NoError(t, err) + + type args struct { + req *apiv1.DeleteKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok softKMS", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=" + url.QueryEscape(keyPath1), + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.NoFileExists(t, keyPath1) + }}, + {"fail missing", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "missing.key")), + }}, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: keyPath2, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.Error(t, err) && + assert.FileExists(t, keyPath2) + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteKey(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificate(t *testing.T) { + dir := t.TempDir() + + chainPath := filepath.Join(dir, "chain.crt") + chain := mustCertificate(t, chainPath) + + certPath := filepath.Join(dir, "certificate.crt") + softKMS := mustKMS(t, "kms:backend=softkms") + require.NoError(t, softKMS.StoreCertificate(&apiv1.StoreCertificateRequest{ + Name: "kms:name=" + url.QueryEscape(certPath), + Certificate: chain[0], + })) + + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok softKMS", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "kms:" + certPath, + }}, chain[0], assert.NoError}, + {"ok softKMS from chain", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=" + chainPath, + }}, chain[0], assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificate(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go index bfdf6f7a..47c6e1c9 100644 --- a/kms/platform/kms_tpm.go +++ b/kms/platform/kms_tpm.go @@ -30,8 +30,10 @@ func newTPMKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { } return &KMS{ - backend: km, - transformURI: transformToTPMKMS, + typ: apiv1.TPMKMS, + backend: km, + transformToURI: transformToTPMKMS, + transformFromURI: transformFromTPMKMS, }, nil } @@ -42,8 +44,8 @@ func NewWithTPM(ctx context.Context, t *tpm.TPM, opts ...tpmkms.Option) (*KMS, e } return &KMS{ - backend: km, - transformURI: transformToTPMKMS, + backend: km, + transformToURI: transformToTPMKMS, }, nil } @@ -69,3 +71,25 @@ func transformToTPMKMS(u *kmsURI) string { return uri.New(tpmkms.Scheme, uv).String() } + +func transformFromTPMKMS(rawuri string) (string, error) { + u, err := uri.ParseWithScheme(tpmkms.Scheme, rawuri) + if err != nil { + return "", err + } + + uv := url.Values{ + "name": []string{u.Get("name")}, + } + if u.GetBool("ak") { + uv.Set("hw", "true") + } + + for k, v := range u.Values { + if k != "name" && k != "ak" { + uv[k] = v + } + } + + return uri.New(Scheme, uv).String(), nil +} diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index 30260f18..cc73bdbc 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -43,8 +43,10 @@ func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { } return &KMS{ - backend: km, - transformURI: transformToCapiKMS, + typ: apiv1.CAPIKMS, + backend: km, + transformToURI: transformToCapiKMS, + transformFromURI: transformFromCapiKMS, }, nil } @@ -59,3 +61,22 @@ func transformToCapiKMS(u *kmsURI) string { return uri.New(capi.Scheme, uv).String() } + +func transformFromCapiKMS(rawuri string) (string, error) { + u, err := uri.ParseWithScheme(capi.Scheme, rawuri) + if err != nil { + return "", err + } + + uv := url.Values{ + "name": []string{u.Get("key")}, + } + + for k, v := range u.Values { + if k != "name" { + uv[k] = v + } + } + + return uri.New(Scheme, uv).String(), nil +} From 1d7ca518a3133f0d75fbe557abb3b28a67db1ec7 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 19 Feb 2026 18:58:33 -0800 Subject: [PATCH 08/27] Add some tests to platform kms --- kms/platform/kms_darwin.go | 2 + kms/platform/kms_darwin_test.go | 15 + kms/platform/kms_other_test.go | 17 + kms/platform/kms_softkms.go | 29 +- kms/platform/kms_test.go | 388 +++++++++++- kms/platform/kms_tpm.go | 8 +- kms/platform/kms_tpmsimulator_test.go | 816 ++++++++++++++++++++++++++ kms/platform/kms_windows_test.go | 17 + tpm/key.go | 20 + 9 files changed, 1288 insertions(+), 24 deletions(-) create mode 100644 kms/platform/kms_darwin_test.go create mode 100644 kms/platform/kms_other_test.go create mode 100644 kms/platform/kms_tpmsimulator_test.go create mode 100644 kms/platform/kms_windows_test.go diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go index 4ef36bc8..469dbd77 100644 --- a/kms/platform/kms_darwin.go +++ b/kms/platform/kms_darwin.go @@ -58,6 +58,8 @@ func transformToMacKMS(u *kmsURI) string { if u.hw { uv.Set("se", "true") uv.Set("keychain", "dataProtection") + } else if u.uri.Has("hw") { + uv.Set("se", "false") } // Add custom extra values that might be mackms specific. diff --git a/kms/platform/kms_darwin_test.go b/kms/platform/kms_darwin_test.go new file mode 100644 index 00000000..e9cb7681 --- /dev/null +++ b/kms/platform/kms_darwin_test.go @@ -0,0 +1,15 @@ +package platform + +import "testing" + +func mustPlatformKMS(t *testing.T) *KMS { + t.Helper() + + return mustKMS(t, "kms:") +} + +// SkipTest is a method implemented on tests that allow skipping the test on +// this platform. +func (k *KMS) SkipTests() bool { + return false +} diff --git a/kms/platform/kms_other_test.go b/kms/platform/kms_other_test.go new file mode 100644 index 00000000..48eddbca --- /dev/null +++ b/kms/platform/kms_other_test.go @@ -0,0 +1,17 @@ +//go:build !darwin && !windows + +package platform + +import ( + "testing" +) + +func mustPlatformKMS(t *testing.T) *KMS { + return &KMS{} +} + +// SkipTest is a method implemented on tests that allow skipping the test on +// this platform. +func (k *KMS) SkipTests() bool { + return true +} diff --git a/kms/platform/kms_softkms.go b/kms/platform/kms_softkms.go index 9dbe8483..4e6af8c8 100644 --- a/kms/platform/kms_softkms.go +++ b/kms/platform/kms_softkms.go @@ -33,12 +33,17 @@ type softKMS struct { } func (k *softKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + name := filename(req.Name) + if name == "" { + return nil, fmt.Errorf("createKeyRequest 'name' cannot be empty") + } + resp, err := k.SoftKMS.CreateKey(req) if err != nil { return nil, err } - if _, err := pemutil.Serialize(resp.PrivateKey, pemutil.ToFile(resp.Name, 0o600)); err != nil { + if _, err := pemutil.Serialize(resp.PrivateKey, pemutil.ToFile(name, 0o600)); err != nil { return nil, err } @@ -46,30 +51,33 @@ func (k *softKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon } func (k *softKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { - if req.Name == "" { + name := filename(req.Name) + if name == "" { return fmt.Errorf("deleteKeyRequest 'name' cannot be empty") } - return os.Remove(filename(req.Name)) + return os.Remove(name) } func (k *softKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { + name := filename(req.Name) switch { - case req.Name == "": + case name == "": return fmt.Errorf("storeCertificateRequest 'name' cannot be empty") case req.Certificate == nil: return fmt.Errorf("storeCertificateRequest 'certificate' cannot be empty") } - fmt.Println(filename(req.Name)) - return os.WriteFile(filename(req.Name), pem.EncodeToMemory(&pem.Block{ + + return os.WriteFile(name, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: req.Certificate.Raw, }), 0o600) } func (k *softKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { + name := filename(req.Name) switch { - case req.Name == "": + case name == "": return fmt.Errorf("storeCertificateChainRequest 'name' cannot be empty") case len(req.CertificateChain) == 0: return fmt.Errorf("storeCertificateChainRequest 'certificateChain' cannot be empty") @@ -85,15 +93,16 @@ func (k *softKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) } } - return os.WriteFile(filename(req.Name), buf.Bytes(), 0o600) + return os.WriteFile(name, buf.Bytes(), 0o600) } func (k *softKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { - if req.Name == "" { + name := filename(req.Name) + if name == "" { return fmt.Errorf("deleteCertificateRequest 'name' cannot be empty") } - return os.Remove(filename(req.Name)) + return os.Remove(name) } func filename(s string) string { diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index fe363b6c..6dcba2c3 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto" "crypto/ecdsa" + "crypto/elliptic" "crypto/rsa" "crypto/x509" "encoding/pem" @@ -18,8 +19,17 @@ import ( "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/randutil" ) +func shouldSkipNow(t *testing.T, km *KMS) { + t.Helper() + + if km.Type() != apiv1.SoftKMS && km.SkipTests() { + t.SkipNow() + } +} + func mustKMS(t *testing.T, rawuri string) *KMS { t.Helper() @@ -27,6 +37,10 @@ func mustKMS(t *testing.T, rawuri string) *KMS { URI: rawuri, }) require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, km.Close()) + }) return km } @@ -71,16 +85,38 @@ func mustCertificate(t *testing.T, path string) []*x509.Certificate { }) require.NoError(t, err) - var buf bytes.Buffer - require.NoError(t, pem.Encode(&buf, &pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - })) - require.NoError(t, pem.Encode(&buf, &pem.Block{ - Type: "CERTIFICATE", - Bytes: ca.Intermediate.Raw, - })) - require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o600)) + if path != "" { + var buf bytes.Buffer + require.NoError(t, pem.Encode(&buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })) + require.NoError(t, pem.Encode(&buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: ca.Intermediate.Raw, + })) + + require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o600)) + } + + return []*x509.Certificate{ + cert, ca.Intermediate, + } +} + +func mustCertificateWithKey(t *testing.T, key crypto.PublicKey) []*x509.Certificate { + t.Helper() + + ca, err := minica.New() + require.NoError(t, err) + + cert, err := ca.Sign(&x509.Certificate{ + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + PublicKey: key, + DNSNames: []string{"example.com"}, + }) + require.NoError(t, err) return []*x509.Certificate{ cert, ca.Intermediate, @@ -93,7 +129,10 @@ func TestKMS_Type(t *testing.T) { } func TestKMS_Close(t *testing.T) { - softKMS := mustKMS(t, "kms:backend=softkms") + softKMS, err := New(t.Context(), apiv1.Options{ + URI: "kms:backend=softkms", + }) + require.NoError(t, err) assert.NoError(t, softKMS.Close()) } @@ -134,6 +173,8 @@ func TestKMS_GetPublicKey(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + got, err := tt.kms.GetPublicKey(tt.args.req) tt.assertion(t, err) assert.Equal(t, tt.want, got) @@ -144,8 +185,12 @@ func TestKMS_GetPublicKey(t *testing.T) { func TestKMS_CreateKey(t *testing.T) { dir := t.TempDir() privateKeyPath := filepath.Join(dir, "private.key") + platformKMS := mustPlatformKMS(t) softKMS := mustKMS(t, "kms:backend=softkms") + suffix, err := randutil.Alphanumeric(8) + require.NoError(t, err) + type args struct { req *apiv1.CreateKeyRequest } @@ -156,6 +201,42 @@ func TestKMS_CreateKey(t *testing.T) { equal func(t *testing.T, got *apiv1.CreateKeyResponse) assertion assert.ErrorAssertionFunc }{ + {"ok default", platformKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test1-" + suffix, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test1-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=test1-.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, elliptic.P256(), got.PublicKey.(*ecdsa.PublicKey).Curve) + } + }, assert.NoError}, + {"ok rsa", platformKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test2-" + suffix, + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test2-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=test2-.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + if assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) + } + }, assert.NoError}, {"ok softKMS", softKMS, args{&apiv1.CreateKeyRequest{ Name: "kms:name=" + privateKeyPath, }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { @@ -211,14 +292,42 @@ func TestKMS_CreateKey(t *testing.T) { }, }) }, assert.NoError}, + {"fail createKey", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:" + privateKeyPath, + SignatureAlgorithm: apiv1.SignatureAlgorithm(100), + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, {"fail parseURI", softKMS, args{&apiv1.CreateKeyRequest{ Name: "softkms:" + privateKeyPath, }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { assert.Nil(t, got) }, assert.Error}, + {"fail empty name", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail empty uri name", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail empty uri path", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:path=", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + got, err := tt.kms.CreateKey(tt.args.req) tt.assertion(t, err) tt.equal(t, got) @@ -259,9 +368,17 @@ func TestKMS_CreateSigner(t *testing.T) { {"fail parseURI", softKMS, args{&apiv1.CreateSignerRequest{ SigningKey: privateKeyPath, }}, nil, assert.Error}, + {"fail signingKey", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: "", + }}, nil, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:", + }}, nil, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + got, err := tt.kms.CreateSigner(tt.args.req) tt.assertion(t, err) assert.Equal(t, tt.want, got) @@ -308,9 +425,17 @@ func TestKMS_DeleteKey(t *testing.T) { return assert.Error(t, err) && assert.FileExists(t, keyPath2) }}, + {"fail name", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: "", + }}, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: "kms:", + }}, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + tt.assertion(t, tt.kms.DeleteKey(tt.args.req)) }) } @@ -343,14 +468,253 @@ func TestKMS_LoadCertificate(t *testing.T) { Name: "kms:" + certPath, }}, chain[0], assert.NoError}, {"ok softKMS from chain", softKMS, args{&apiv1.LoadCertificateRequest{ - Name: "kms:name=" + chainPath, + Name: "kms:name=" + url.QueryEscape(chainPath), }}, chain[0], assert.NoError}, + {"fail missing", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=" + filepath.Join(dir, "missing.crt"), + }}, nil, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "foo:name=" + certPath, + }}, nil, assert.Error}, + {"fail name", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "", + }}, nil, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + got, err := tt.kms.LoadCertificate(tt.args.req) tt.assertion(t, err) assert.Equal(t, tt.want, got) }) } } + +func TestKMS_StoreCertificate(t *testing.T) { + dir := t.TempDir() + chain := mustCertificate(t, "") + softKMS := mustKMS(t, "kms:backend=softkms") + + type args struct { + req *apiv1.StoreCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok softKMS", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=" + filepath.Join(dir, "cert.crt"), + Certificate: chain[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.FileExists(t, filepath.Join(dir, "cert.crt")) + }}, + {"ok softKMS simple", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:" + filepath.Join(dir, "intermediate.crt"), + Certificate: chain[1], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.FileExists(t, filepath.Join(dir, "intermediate.crt")) + }}, + {"ok softKMS overwrite", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:" + filepath.Join(dir, "cert.crt"), + Certificate: chain[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.FileExists(t, filepath.Join(dir, "cert.crt")) + }}, + {"fail parseURI", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "foo:" + filepath.Join(dir, "fail.crt"), + Certificate: chain[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.Error(t, err) && + assert.NoFileExists(t, filepath.Join(dir, "fail.crt")) + }}, + {"fail name", softKMS, args{&apiv1.StoreCertificateRequest{ + Certificate: chain[0], + }}, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:", + Certificate: chain[0], + }}, assert.Error}, + {"fail certificate", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=" + filepath.Join(dir, "cert.crt"), + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + tt.assertion(t, tt.kms.StoreCertificate(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificateChain(t *testing.T) { + dir := t.TempDir() + chainPath := filepath.Join(dir, "chain.crt") + chain := mustCertificate(t, chainPath) + softKMS := mustKMS(t, "kms:backend=softkms") + + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + want []*x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok softKMS", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=" + chainPath, + }}, chain, assert.NoError}, + {"fail parseURI", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "foo:name=" + chainPath, + }}, nil, assert.Error}, + {"fail missing", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=" + filepath.Join(dir, "missing.crt"), + }}, nil, assert.Error}, + {"fail parseuri", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "softkms:name=" + chainPath, + }}, nil, assert.Error}, + {"fail name", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.LoadCertificateChain(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificateChain(t *testing.T) { + dir := t.TempDir() + chain := mustCertificate(t, "") + softKMS := mustKMS(t, "kms:backend=softkms") + + type args struct { + req *apiv1.StoreCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok softKMS", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=" + filepath.Join(dir, "chain.crt"), + CertificateChain: chain, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && assert.FileExists(t, filepath.Join(dir, "chain.crt")) + }}, + {"ok softKMS escape", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "leaf.crt")), + CertificateChain: chain[:1], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && assert.FileExists(t, filepath.Join(dir, "leaf.crt")) + }}, + {"fail parseURI", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "foo:name=" + filepath.Join(dir, "other.crt"), + CertificateChain: chain, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.Error(t, err) && assert.NoFileExists(t, filepath.Join(dir, "other.crt")) + }}, + {"fail name", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "", + CertificateChain: chain, + }}, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:", + CertificateChain: chain, + }}, assert.Error}, + {"fail certificateChain", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=" + filepath.Join(dir, "other.crt"), + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + tt.assertion(t, tt.kms.StoreCertificateChain(tt.args.req)) + }) + } +} + +func TestKMS_DeleteCertificate(t *testing.T) { + dir := t.TempDir() + _ = mustCertificate(t, filepath.Join(dir, "chain.crt")) + softKMS := mustKMS(t, "kms:backend=softkms") + + type args struct { + req *apiv1.DeleteCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok softKMS", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "chain.crt")), + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.NoFileExists(t, filepath.Join(dir, "chain.crt")) + }}, + {"fail missing", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "chain.crt")), + }}, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "foo", + }}, assert.Error}, + {"fail name", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "", + }}, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:", + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + tt.assertion(t, tt.kms.DeleteCertificate(tt.args.req)) + }) + } +} + +func TestKMS_SearchKeys(t *testing.T) { + dir := t.TempDir() + softKMS := mustKMS(t, "kms:backend=softkms") + + type args struct { + req *apiv1.SearchKeysRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.SearchKeysResponse + assertion assert.ErrorAssertionFunc + }{ + {"fail softKMS", softKMS, args{&apiv1.SearchKeysRequest{ + Query: "kms:name=" + url.QueryEscape(dir), + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.SearchKeys(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go index 47c6e1c9..c8580662 100644 --- a/kms/platform/kms_tpm.go +++ b/kms/platform/kms_tpm.go @@ -44,8 +44,10 @@ func NewWithTPM(ctx context.Context, t *tpm.TPM, opts ...tpmkms.Option) (*KMS, e } return &KMS{ - backend: km, - transformToURI: transformToTPMKMS, + typ: apiv1.TPMKMS, + backend: km, + transformToURI: transformToTPMKMS, + transformFromURI: transformFromTPMKMS, }, nil } @@ -64,6 +66,8 @@ func transformToTPMKMS(u *kmsURI) string { } if u.hw { uv.Set("ak", "true") + } else if u.uri.Has("hw") { + uv.Set("ak", "false") } // Add custom extra values that might be tpmkms specific. diff --git a/kms/platform/kms_tpmsimulator_test.go b/kms/platform/kms_tpmsimulator_test.go new file mode 100644 index 00000000..d7acab3e --- /dev/null +++ b/kms/platform/kms_tpmsimulator_test.go @@ -0,0 +1,816 @@ +//go:build tpmsimulator + +package platform + +import ( + "context" + "crypto" + "crypto/x509" + "net" + "net/url" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/tpm" + "go.step.sm/crypto/tpm/simulator" + "go.step.sm/crypto/tpm/storage" +) + +func mustTPM(t *testing.T) *tpm.TPM { + t.Helper() + + sim, err := simulator.New() + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, sim.Close()) + }) + require.NoError(t, sim.Open()) + + dir := t.TempDir() + + stpm, err := tpm.New(tpm.WithSimulator(sim), tpm.WithStore(storage.NewDirstore(dir))) + require.NoError(t, err) + + return stpm +} + +func mustTPMDevice(t *testing.T) (*tpm.TPM, string, string) { + t.Helper() + + sim, err := simulator.New() + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, sim.Close()) + }) + require.NoError(t, sim.Open()) + + dir := t.TempDir() + stpm, err := tpm.New(tpm.WithSimulator(sim), tpm.WithStore(storage.NewDirstore(dir))) + require.NoError(t, err) + + listener := &net.ListenConfig{} + socket := filepath.Join(dir, "tpm.sock") + ln, err := listener.Listen(context.TODO(), "unix", socket) + require.NoError(t, err) + + go func() { + for { + conn, err := ln.Accept() + require.NoError(t, err) + + go func(conn net.Conn) { + defer conn.Close() + + readBuf := make([]byte, 4096) + n, err := conn.Read(readBuf) + require.NoError(t, err) + + _, err = sim.Write(readBuf[:n]) + require.NoError(t, err) + + writeBuf := make([]byte, 4096) + nr, err := sim.Read(writeBuf) + require.NoError(t, err) + + _, err = conn.Write(writeBuf[:nr]) + require.NoError(t, err) + }(conn) + } + }() + + return stpm, socket, dir +} + +func mustTPMKMS(t *testing.T) (*KMS, *tpm.TPM) { + t.Helper() + + stpm, sock, dir := mustTPMDevice(t) + km := mustKMS(t, uri.New(Scheme, url.Values{ + "backend": []string{"tpmkms"}, + "device": []string{sock}, + "storage-directory": []string{dir}, + }).String()) + + return km, stpm +} + +func TestKMS_Type_tpm(t *testing.T) { + kms1, stpm := mustTPMKMS(t) + assert.Equal(t, apiv1.TPMKMS, kms1.Type()) + + kms2, err := NewWithTPM(t.Context(), stpm) + require.NoError(t, err) + assert.Equal(t, apiv1.TPMKMS, kms2.Type()) + +} + +func TestKMS_Close_tpm(t *testing.T) { + kms1, stpm := mustTPMKMS(t) + assert.NoError(t, kms1.Close()) + + kms2, err := NewWithTPM(t.Context(), stpm) + require.NoError(t, err) + assert.NoError(t, kms2.Close()) +} + +func TestKMS_GetPublicKey_tpm(t *testing.T) { + ctx := t.Context() + kms1, stpm := mustTPMKMS(t) + kms2, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "RSA", + Size: 2048, + }) + require.NoError(t, err) + + keySigner, err := key.Signer(ctx) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + want crypto.PublicKey + assertion assert.ErrorAssertionFunc + }{ + {"ok key", kms1, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=key-1", + }}, keySigner.Public(), assert.NoError}, + {"ok ak", kms1, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=ak-1;hw=true", + }}, ak.Public(), assert.NoError}, + {"ok key with tpm", kms2, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=key-1", + }}, keySigner.Public(), assert.NoError}, + {"ok ak with tpm", kms2, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=ak-1;hw=true", + }}, ak.Public(), assert.NoError}, + {"fail missing key", kms1, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=key-2", + }}, nil, assert.Error}, + {"fail missing ak", kms2, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=ak-2;hw=true", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.GetPublicKey(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_CreateKey_tpm(t *testing.T) { + ctx := t.Context() + kms1, stpm := mustTPMKMS(t) + kms2, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(t *testing.T, got *apiv1.CreateKeyResponse) + assertion assert.ErrorAssertionFunc + }{ + {"ok", kms1, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=key-1", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + key, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + + require.NotNil(t, got) + require.NotNil(t, got.CreateSignerRequest.Signer) + + assert.Equal(t, signer.Public(), got.CreateSignerRequest.Signer.Public()) + got.CreateSignerRequest.Signer = signer + + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: "kms:name=key-1", + PublicKey: signer.Public(), + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: "kms:name=key-1", + }, + }) + }, assert.NoError}, + {"ok ak", kms1, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=ak-1;hw=true", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + key, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: "kms:hw=true;name=ak-1", + PublicKey: key.Public(), + }) + }, assert.NoError}, + {"ok with tpm", kms2, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=key-2", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + key, err := stpm.GetKey(ctx, "key-2") + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + + require.NotNil(t, got) + require.NotNil(t, got.CreateSignerRequest.Signer) + + assert.Equal(t, signer.Public(), got.CreateSignerRequest.Signer.Public()) + got.CreateSignerRequest.Signer = signer + + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: "kms:name=key-2", + PublicKey: signer.Public(), + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: "kms:name=key-2", + }, + }) + }, assert.NoError}, + {"ok ak with tpm", kms2, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=ak-2;hw=true", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + key, err := stpm.GetAK(ctx, "ak-2") + require.NoError(t, err) + + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: "kms:hw=true;name=ak-2", + PublicKey: key.Public(), + }) + }, assert.NoError}, + {"fail key already exists", kms1, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=key-2", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail ak already exists", kms2, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=ak-1;hw=true", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateKey(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_CreateSigner_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "RSA", + Size: 2048, + }) + require.NoError(t, err) + + signer, err := key.Signer(ctx) + require.NoError(t, err) + + _, err = stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(*testing.T, crypto.Signer) + assertion assert.ErrorAssertionFunc + }{ + {"ok key", km, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=key-1", + }}, func(t *testing.T, got crypto.Signer) { + require.NotNil(t, got) + assert.Equal(t, signer.Public(), got.Public()) + }, assert.NoError}, + {"ok key with signer", km, args{&apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: "kms:name=key1", + }}, func(t *testing.T, got crypto.Signer) { + assert.Equal(t, signer, got) + }, assert.NoError}, + {"fail missing", km, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=key-2", + }}, func(t *testing.T, got crypto.Signer) { + assert.Nil(t, got) + }, assert.Error}, + {"fail with ak", km, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=ak-1;hw=true", + }}, func(t *testing.T, got crypto.Signer) { + assert.Nil(t, got) + }, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateSigner(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_DeleteKey_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + _, err = stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + _, err = stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + type args struct { + req *apiv1.DeleteKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok key", km, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=key-1", + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, keyErr := stpm.GetKey(ctx, "key-1") + return assert.NoError(t, err) && assert.Error(t, keyErr) + }}, + {"ok ak", km, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=ak-1;hw=true", + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, akErr := stpm.GetAK(ctx, "ak-1") + return assert.NoError(t, err) && assert.Error(t, akErr) + }}, + {"fail missing key", km, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=key-2", + }}, assert.Error}, + {"fail missing ak", km, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=ak-2;hw=true", + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteKey(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificate_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + _, err = stpm.CreateKey(ctx, "key-2", tpm.CreateKeyConfig{ + Algorithm: "RSA", + Size: 2048, + }) + require.NoError(t, err) + + _, err = stpm.CreateAK(ctx, "ak-2") + require.NoError(t, err) + + keyChain := mustCertificateWithKey(t, key.Public()) + require.NoError(t, key.SetCertificateChain(ctx, keyChain)) + + akChain := mustCertificateWithKey(t, ak.Public()) + require.NoError(t, ak.SetCertificateChain(ctx, akChain)) + + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=key-1", + }}, keyChain[0], assert.NoError}, + {"ok ak", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=ak-1;hw=true", + }}, akChain[0], assert.NoError}, + {"fail no certificate", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=key-2", + }}, nil, assert.Error}, + {"fail no ak certificate", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=ak-2;hw=true", + }}, nil, assert.Error}, + {"fail missing", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=missing-key", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificate(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificate_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + keyChain1 := mustCertificateWithKey(t, key.Public()) + keyChain2 := mustCertificateWithKey(t, key.Public()) + akChain1 := mustCertificateWithKey(t, ak.Public()) + akChain2 := mustCertificateWithKey(t, ak.Public()) + + type args struct { + req *apiv1.StoreCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=key-1", + Certificate: keyChain1[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Equal(t, keyChain1[0], k.Certificate()) + }}, + {"ok overwrite", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=key-1", + Certificate: keyChain2[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Equal(t, keyChain2[0], k.Certificate()) + }}, + {"ok ak", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=ak-1;hw=true", + Certificate: akChain1[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Equal(t, akChain1[0], k.Certificate()) + }}, + {"ok ak overwrite", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=ak-1;hw=true", + Certificate: akChain2[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Equal(t, akChain2[0], k.Certificate()) + }}, + {"fail missing", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=missing-key", + Certificate: keyChain1[0], + }}, assert.Error}, + {"fail key not match", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=key-1", + Certificate: akChain1[0], + }}, assert.Error}, + {"fail ak key not match", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=ak-1;hw=true", + Certificate: keyChain1[0], + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.StoreCertificate(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificateChain_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + _, err = stpm.CreateKey(ctx, "key-2", tpm.CreateKeyConfig{ + Algorithm: "RSA", + Size: 2048, + }) + require.NoError(t, err) + + _, err = stpm.CreateAK(ctx, "ak-2") + require.NoError(t, err) + + keyChain := mustCertificateWithKey(t, key.Public()) + require.NoError(t, key.SetCertificateChain(ctx, keyChain)) + + akChain := mustCertificateWithKey(t, ak.Public()) + require.NoError(t, ak.SetCertificateChain(ctx, akChain)) + + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + want []*x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=key-1", + }}, keyChain, assert.NoError}, + {"ok ak", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=ak-1;hw=true", + }}, akChain, assert.NoError}, + {"fail no chain", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=key-2", + }}, nil, assert.Error}, + {"fail no ak chain", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=ak-2;hw=true", + }}, nil, assert.Error}, + {"fail missing", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=missing-key", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificateChain(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificateChain_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + keyChain1 := mustCertificateWithKey(t, key.Public()) + keyChain2 := mustCertificateWithKey(t, key.Public()) + akChain1 := mustCertificateWithKey(t, ak.Public()) + akChain2 := mustCertificateWithKey(t, ak.Public()) + + type args struct { + req *apiv1.StoreCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=key-1", + CertificateChain: keyChain1, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Equal(t, keyChain1[0], k.Certificate()) && + assert.Equal(t, keyChain1, k.CertificateChain()) + }}, + {"ok overwrite", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=key-1", + CertificateChain: keyChain2, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Equal(t, keyChain2[0], k.Certificate()) && + assert.Equal(t, keyChain2, k.CertificateChain()) + }}, + {"ok ak", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=ak-1;hw=true", + CertificateChain: akChain1, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Equal(t, akChain1[0], k.Certificate()) && + assert.Equal(t, akChain1, k.CertificateChain()) + }}, + {"ok ak overwrite", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=ak-1;hw=true", + CertificateChain: akChain2, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Equal(t, akChain2[0], k.Certificate()) && + assert.Equal(t, akChain2, k.CertificateChain()) + }}, + {"fail missing", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=missing-key", + CertificateChain: keyChain1, + }}, assert.Error}, + {"fail key not match", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=key-1", + CertificateChain: akChain1, + }}, assert.Error}, + {"fail ak key not match", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=ak-1;hw=true", + CertificateChain: keyChain1, + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.StoreCertificateChain(tt.args.req)) + }) + } +} + +func TestKMS_DeleteCertificate_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + keyChain := mustCertificateWithKey(t, key.Public()) + require.NoError(t, key.SetCertificateChain(ctx, keyChain)) + + akChain := mustCertificateWithKey(t, ak.Public()) + require.NoError(t, ak.SetCertificateChain(ctx, akChain)) + + type args struct { + req *apiv1.DeleteCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=key-1", + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Nil(t, k.Certificate()) && assert.Nil(t, k.CertificateChain()) + }}, + {"ok ak", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=ak-1;hw=true", + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Nil(t, k.Certificate()) && assert.Nil(t, k.CertificateChain()) + }}, + {"ok delete again", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=key-1", + }}, assert.NoError}, + {"ok delete again ak", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=ak-1;hw=true", + }}, assert.NoError}, + {"fail missing", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=missing-ak;hw=true", + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteCertificate(tt.args.req)) + }) + } +} + +func TestKMS_SearchKeys_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key1, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + key2, err := stpm.CreateKey(ctx, "key-2", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak1, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + ak2, err := stpm.CreateAK(ctx, "ak-2") + require.NoError(t, err) + + type args struct { + req *apiv1.SearchKeysRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.SearchKeysResponse + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.SearchKeysRequest{ + Query: "kms:", + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + {Name: "kms:hw=true;name=ak-1", PublicKey: ak1.Public()}, + {Name: "kms:hw=true;name=ak-2", PublicKey: ak2.Public()}, + {Name: "kms:name=key-1", PublicKey: key1.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-1"}}, + {Name: "kms:name=key-2", PublicKey: key2.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-2"}}, + }, + }, assert.NoError}, + {"ok keys", km, args{&apiv1.SearchKeysRequest{ + Query: "kms:hw=false", + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + {Name: "kms:name=key-1", PublicKey: key1.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-1"}}, + {Name: "kms:name=key-2", PublicKey: key2.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-2"}}, + }, + }, assert.NoError}, + {"ok aks", km, args{&apiv1.SearchKeysRequest{ + Query: "kms:hw=true", + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + {Name: "kms:hw=true;name=ak-1", PublicKey: ak1.Public()}, + {Name: "kms:hw=true;name=ak-2", PublicKey: ak2.Public()}, + }, + }, assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.SearchKeys(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go new file mode 100644 index 00000000..884208fe --- /dev/null +++ b/kms/platform/kms_windows_test.go @@ -0,0 +1,17 @@ +//go:build windows + +package platform + +import ( + "testing" +) + +func mustPlatformKMS(t *testing.T) *KMS { + return &KMS{} +} + +// SkipTest is a method implemented on tests that allow skipping the test on +// this platform. +func (k *KMS) SkipTests() bool { + return true +} diff --git a/tpm/key.go b/tpm/key.go index 8cf61aa6..1b61ae2f 100644 --- a/tpm/key.go +++ b/tpm/key.go @@ -9,6 +9,7 @@ import ( "fmt" "time" + "github.com/google/go-tpm/legacy/tpm2" "github.com/smallstep/go-attestation/attest" internalkey "go.step.sm/crypto/tpm/internal/key" @@ -63,6 +64,25 @@ func (k *Key) WasAttestedBy(ak *AK) bool { return k.attestedBy == ak.name } +// Public returns the Key public key. This is backed +// by a call to the TPM, so it can fail. If it fails, +// nil is returned. +func (k *Key) Public() crypto.PublicKey { + blobs, err := k.Blobs(context.Background()) + if err != nil { + return nil + } + pub, err := tpm2.DecodePublic(blobs.public) + if err != nil { + return nil + } + publicKey, err := pub.Key() + if err != nil { + return nil + } + return publicKey +} + // Certificate returns the certificate for the Key, if set. // Will return nil in case no AK certificate is available. func (k *Key) Certificate() *x509.Certificate { From 9a2198fe04dd54e39d9247406590db1383a0baea Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 20 Feb 2026 15:46:44 -0800 Subject: [PATCH 09/27] Add platform tests on macos --- kms/platform/kms_test.go | 381 +++++++++++++++++++++++++++++++++++++-- tpm/key.go | 20 +- 2 files changed, 375 insertions(+), 26 deletions(-) diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index 6dcba2c3..d7c9e8bc 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -8,9 +8,11 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "fmt" "net/url" "os" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/assert" @@ -22,6 +24,30 @@ import ( "go.step.sm/crypto/randutil" ) +var ( + platformKeyName string + platformCertName string + platformMissingName string +) + +func TestMain(m *testing.M) { + suffix, err := randutil.Alphanumeric(8) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + platformKeyName = "kms:name=test-" + suffix + platformCertName = "kms:name=test-" + suffix + platformMissingName = "kms:name=test-missing-" + suffix + + if runtime.GOOS == "darwin" { + platformKeyName += ";tag=com.smallstep.test." + suffix + } + + os.Exit(m.Run()) +} + func shouldSkipNow(t *testing.T, km *KMS) { t.Helper() @@ -110,6 +136,11 @@ func mustCertificateWithKey(t *testing.T, key crypto.PublicKey) []*x509.Certific ca, err := minica.New() require.NoError(t, err) + // skipped platform + if key == nil { + return []*x509.Certificate{ca.Intermediate} + } + cert, err := ca.Sign(&x509.Certificate{ KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, @@ -123,6 +154,107 @@ func mustCertificateWithKey(t *testing.T, key crypto.PublicKey) []*x509.Certific } } +func mustSuffix(t *testing.T) string { + t.Helper() + suffix, err := randutil.Alphanumeric(8) + require.NoError(t, err) + return suffix +} + +type createOptions struct { + name string + noCleanup bool +} + +type createFuncOption func(*createOptions) + +func withName(s string) createFuncOption { + return func(co *createOptions) { + co.name = s + } +} + +func withNoCleanup() createFuncOption { + return func(co *createOptions) { + co.noCleanup = true + } +} + +func mustCreatePlatformKey(t *testing.T, km *KMS, opts ...createFuncOption) *apiv1.CreateKeyResponse { + t.Helper() + + o := new(createOptions) + o.name = platformKeyName + for _, fn := range opts { + fn(o) + } + + if km.SkipTests() { + return &apiv1.CreateKeyResponse{} + } + + resp, err := km.CreateKey(&apiv1.CreateKeyRequest{ + Name: o.name, + }) + require.NoError(t, err) + + if !o.noCleanup { + t.Cleanup(func() { + assert.NoError(t, km.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: resp.Name, + })) + }) + } + + return resp +} + +func mustCreatePlatformCertificate(t *testing.T, km *KMS, opts ...createFuncOption) []*x509.Certificate { + t.Helper() + + o := new(createOptions) + o.name = platformCertName + for _, fn := range opts { + fn(o) + } + + ca, err := minica.New() + require.NoError(t, err) + + if km.SkipTests() { + return []*x509.Certificate{ + ca.Intermediate, + } + } + + key := mustCreatePlatformKey(t, km, opts...) + cert, err := ca.Sign(&x509.Certificate{ + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + PublicKey: key.PublicKey, + DNSNames: []string{"example.com"}, + }) + require.NoError(t, err) + + require.NoError(t, km.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{ + Name: o.name, + CertificateChain: []*x509.Certificate{ + cert, ca.Intermediate, + }, + })) + if !o.noCleanup { + t.Cleanup(func() { + assert.NoError(t, km.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: o.name, + })) + }) + } + + return []*x509.Certificate{ + cert, ca.Intermediate, + } +} + func TestKMS_Type(t *testing.T) { softKMS := mustKMS(t, "kms:backend=softkms") assert.Equal(t, apiv1.SoftKMS, softKMS.Type()) @@ -142,6 +274,9 @@ func TestKMS_GetPublicKey(t *testing.T) { signer := mustSigner(t, privateKeyPath) softKMS := mustKMS(t, "kms:backend=softkms") + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS) + type args struct { req *apiv1.GetPublicKeyRequest } @@ -152,6 +287,18 @@ func TestKMS_GetPublicKey(t *testing.T) { want crypto.PublicKey assertion assert.ErrorAssertionFunc }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.GetPublicKeyRequest{ + Name: platformKey.Name, + }}, platformKey.PublicKey, assert.NoError}, + {"fail platform missing", platformKMS, args{&apiv1.GetPublicKeyRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + {"fail platform name", platformKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:something=test", + }}, nil, assert.Error}, + + // SoftKMS {"ok SoftKMS", softKMS, args{&apiv1.GetPublicKeyRequest{ Name: "kms:name=" + privateKeyPath, }}, signer.Public(), assert.NoError}, @@ -167,7 +314,7 @@ func TestKMS_GetPublicKey(t *testing.T) { {"fail SoftKMS missing", softKMS, args{&apiv1.GetPublicKeyRequest{ Name: "kms:" + filepath.Join(dir, "notfound.key"), }}, nil, assert.Error}, - {"fail parseURI", softKMS, args{&apiv1.GetPublicKeyRequest{ + {"fail transform", softKMS, args{&apiv1.GetPublicKeyRequest{ Name: "softkms:" + privateKeyPath, }}, nil, assert.Error}, } @@ -185,11 +332,10 @@ func TestKMS_GetPublicKey(t *testing.T) { func TestKMS_CreateKey(t *testing.T) { dir := t.TempDir() privateKeyPath := filepath.Join(dir, "private.key") - platformKMS := mustPlatformKMS(t) softKMS := mustKMS(t, "kms:backend=softkms") - suffix, err := randutil.Alphanumeric(8) - require.NoError(t, err) + suffix := mustSuffix(t) + platformKMS := mustPlatformKMS(t) type args struct { req *apiv1.CreateKeyRequest @@ -201,7 +347,8 @@ func TestKMS_CreateKey(t *testing.T) { equal func(t *testing.T, got *apiv1.CreateKeyResponse) assertion assert.ErrorAssertionFunc }{ - {"ok default", platformKMS, args{&apiv1.CreateKeyRequest{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.CreateKeyRequest{ Name: "kms:name=test1-" + suffix, }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { require.NotNil(t, got) @@ -218,10 +365,9 @@ func TestKMS_CreateKey(t *testing.T) { assert.Equal(t, elliptic.P256(), got.PublicKey.(*ecdsa.PublicKey).Curve) } }, assert.NoError}, - {"ok rsa", platformKMS, args{&apiv1.CreateKeyRequest{ + {"ok platform ECDSA", platformKMS, args{&apiv1.CreateKeyRequest{ Name: "kms:name=test2-" + suffix, - SignatureAlgorithm: apiv1.SHA256WithRSA, - Bits: 2048, + SignatureAlgorithm: apiv1.ECDSAWithSHA384, }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { require.NotNil(t, got) @@ -233,10 +379,37 @@ func TestKMS_CreateKey(t *testing.T) { assert.Regexp(t, "^kms:.*name=test2-.*$", got.Name) assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, elliptic.P384(), got.PublicKey.(*ecdsa.PublicKey).Curve) + } + }, assert.NoError}, + {"ok platform RSA", platformKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test3-" + suffix, + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test3-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=test3-.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) if assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) } }, assert.NoError}, + {"fail platform algorithm", platformKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:test4-" + suffix, + SignatureAlgorithm: apiv1.SignatureAlgorithm(100), + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + + // SoftKMS {"ok softKMS", softKMS, args{&apiv1.CreateKeyRequest{ Name: "kms:name=" + privateKeyPath, }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { @@ -292,7 +465,7 @@ func TestKMS_CreateKey(t *testing.T) { }, }) }, assert.NoError}, - {"fail createKey", softKMS, args{&apiv1.CreateKeyRequest{ + {"fail softKMS createKey", softKMS, args{&apiv1.CreateKeyRequest{ Name: "kms:" + privateKeyPath, SignatureAlgorithm: apiv1.SignatureAlgorithm(100), }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { @@ -345,6 +518,14 @@ func TestKMS_CreateSigner(t *testing.T) { require.NoError(t, err) signer := mustReadSigner(t, privateKeyPath) + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS) + + assertNil := func(t *testing.T, got crypto.Signer) { + t.Helper() + assert.Nil(t, got) + } + type args struct { req *apiv1.CreateSignerRequest } @@ -352,28 +533,44 @@ func TestKMS_CreateSigner(t *testing.T) { name string kms *KMS args args - want crypto.Signer + equal func(*testing.T, crypto.Signer) assertion assert.ErrorAssertionFunc }{ + // PlatformKMS + {"ok platform", platformKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: platformKey.Name, + }}, func(t *testing.T, s crypto.Signer) { + require.NotNil(t, s) + assert.Equal(t, platformKey.PublicKey, s.Public()) + }, assert.NoError}, + {"fail platform missing", platformKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: platformMissingName, + }}, assertNil, assert.Error}, + + // SoftKMS {"ok softKMS", softKMS, args{&apiv1.CreateSignerRequest{ SigningKey: "kms:name=" + url.QueryEscape(privateKeyPath), - }}, signer, assert.NoError}, + }}, func(t *testing.T, s crypto.Signer) { + assert.Equal(t, signer, s) + }, assert.NoError}, {"ok softKMS with signer", softKMS, args{&apiv1.CreateSignerRequest{ Signer: resp.CreateSignerRequest.Signer, SigningKey: resp.CreateSignerRequest.SigningKey, - }}, signer, assert.NoError}, + }}, func(t *testing.T, s crypto.Signer) { + assert.Equal(t, signer, s) + }, assert.NoError}, {"fail missing", softKMS, args{&apiv1.CreateSignerRequest{ SigningKey: "kms:name=" + url.QueryEscape(filepath.Join(dir, "missing.key")), - }}, nil, assert.Error}, + }}, assertNil, assert.Error}, {"fail parseURI", softKMS, args{&apiv1.CreateSignerRequest{ SigningKey: privateKeyPath, - }}, nil, assert.Error}, + }}, assertNil, assert.Error}, {"fail signingKey", softKMS, args{&apiv1.CreateSignerRequest{ SigningKey: "", - }}, nil, assert.Error}, + }}, assertNil, assert.Error}, {"fail empty uri", softKMS, args{&apiv1.CreateSignerRequest{ SigningKey: "kms:", - }}, nil, assert.Error}, + }}, assertNil, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -381,7 +578,7 @@ func TestKMS_CreateSigner(t *testing.T) { got, err := tt.kms.CreateSigner(tt.args.req) tt.assertion(t, err) - assert.Equal(t, tt.want, got) + tt.equal(t, got) }) } } @@ -401,6 +598,9 @@ func TestKMS_DeleteKey(t *testing.T) { }) require.NoError(t, err) + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS, withNoCleanup()) + type args struct { req *apiv1.DeleteKeyRequest } @@ -410,6 +610,23 @@ func TestKMS_DeleteKey(t *testing.T) { args args assertion assert.ErrorAssertionFunc }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.DeleteKeyRequest{ + Name: platformKey.Name, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, getErr := platformKMS.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: platformKey.Name, + }) + return assert.NoError(t, err) && assert.Error(t, getErr) + }}, + {"fail platform deleted", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: platformKey.Name, + }}, assert.Error}, + {"fail platform missing", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: platformMissingName, + }}, assert.Error}, + + // SoftKMS {"ok softKMS", softKMS, args{&apiv1.DeleteKeyRequest{ Name: "kms:name=" + url.QueryEscape(keyPath1), }}, func(tt assert.TestingT, err error, i ...interface{}) bool { @@ -454,6 +671,9 @@ func TestKMS_LoadCertificate(t *testing.T) { Certificate: chain[0], })) + platformKMS := mustPlatformKMS(t) + platformChain := mustCreatePlatformCertificate(t, platformKMS) + type args struct { req *apiv1.LoadCertificateRequest } @@ -464,6 +684,15 @@ func TestKMS_LoadCertificate(t *testing.T) { want *x509.Certificate assertion assert.ErrorAssertionFunc }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.LoadCertificateRequest{ + Name: platformCertName, + }}, platformChain[0], assert.NoError}, + {"fail platform missing", platformKMS, args{&apiv1.LoadCertificateRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + + // SoftKMS {"ok softKMS", softKMS, args{&apiv1.LoadCertificateRequest{ Name: "kms:" + certPath, }}, chain[0], assert.NoError}, @@ -496,6 +725,10 @@ func TestKMS_StoreCertificate(t *testing.T) { chain := mustCertificate(t, "") softKMS := mustKMS(t, "kms:backend=softkms") + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS) + platformChain := mustCertificateWithKey(t, platformKey.PublicKey) + type args struct { req *apiv1.StoreCertificateRequest } @@ -505,6 +738,28 @@ func TestKMS_StoreCertificate(t *testing.T) { args args assertion assert.ErrorAssertionFunc }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName, + Certificate: platformChain[0], + }}, assert.NoError}, + {"ok platform no key", platformKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName + "-other", + Certificate: chain[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + }) + return assert.NoError(t, err) + }}, + {"fail platform bad certificate", platformKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName, + Certificate: &x509.Certificate{}, + }}, assert.Error}, + + // SoftKMS {"ok softKMS", softKMS, args{&apiv1.StoreCertificateRequest{ Name: "kms:name=" + filepath.Join(dir, "cert.crt"), Certificate: chain[0], @@ -558,6 +813,9 @@ func TestKMS_LoadCertificateChain(t *testing.T) { chain := mustCertificate(t, chainPath) softKMS := mustKMS(t, "kms:backend=softkms") + platformKMS := mustPlatformKMS(t) + platformChain := mustCreatePlatformCertificate(t, platformKMS) + type args struct { req *apiv1.LoadCertificateChainRequest } @@ -568,6 +826,15 @@ func TestKMS_LoadCertificateChain(t *testing.T) { want []*x509.Certificate assertion assert.ErrorAssertionFunc }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: platformCertName, + }}, platformChain, assert.NoError}, + {"fail platform missing", platformKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + + // SoftKMS {"ok softKMS", softKMS, args{&apiv1.LoadCertificateChainRequest{ Name: "kms:name=" + chainPath, }}, chain, assert.NoError}, @@ -600,6 +867,10 @@ func TestKMS_StoreCertificateChain(t *testing.T) { chain := mustCertificate(t, "") softKMS := mustKMS(t, "kms:backend=softkms") + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS) + platformChain := mustCertificateWithKey(t, platformKey.PublicKey) + type args struct { req *apiv1.StoreCertificateChainRequest } @@ -609,6 +880,28 @@ func TestKMS_StoreCertificateChain(t *testing.T) { args args assertion assert.ErrorAssertionFunc }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: platformChain, + }}, assert.NoError}, + {"ok platform no key", platformKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName + "-other", + CertificateChain: chain, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + }) + return assert.NoError(t, err) + }}, + {"fail platform bad chain", platformKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: []*x509.Certificate{}, + }}, assert.Error}, + + // SoftKMS {"ok softKMS", softKMS, args{&apiv1.StoreCertificateChainRequest{ Name: "kms:name=" + filepath.Join(dir, "chain.crt"), CertificateChain: chain, @@ -653,6 +946,9 @@ func TestKMS_DeleteCertificate(t *testing.T) { _ = mustCertificate(t, filepath.Join(dir, "chain.crt")) softKMS := mustKMS(t, "kms:backend=softkms") + platformKMS := mustPlatformKMS(t) + _ = mustCreatePlatformCertificate(t, platformKMS, withNoCleanup()) + type args struct { req *apiv1.DeleteCertificateRequest } @@ -662,6 +958,19 @@ func TestKMS_DeleteCertificate(t *testing.T) { args args assertion assert.ErrorAssertionFunc }{ + {"ok platform", platformKMS, args{&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, loadErr := platformKMS.LoadCertificate(&apiv1.LoadCertificateRequest{ + Name: platformCertName, + }) + return assert.NoError(t, err) && assert.Error(t, loadErr) + }}, + {"fail platform missing", platformKMS, args{&apiv1.DeleteCertificateRequest{ + Name: platformMissingName, + }}, assert.Error}, + + // SoftKMS {"ok softKMS", softKMS, args{&apiv1.DeleteCertificateRequest{ Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "chain.crt")), }}, func(tt assert.TestingT, err error, i ...interface{}) bool { @@ -694,6 +1003,26 @@ func TestKMS_SearchKeys(t *testing.T) { dir := t.TempDir() softKMS := mustKMS(t, "kms:backend=softkms") + suffix := mustSuffix(t) + platformKMS := mustPlatformKMS(t) + + platformKeys := make([]*apiv1.CreateKeyResponse, 4) + for i := range platformKeys { + name := fmt.Sprintf("kms:name=search-test-%d-%s", i, suffix) + if runtime.GOOS == "darwin" { + name += ";tag=com.smallstep.test." + suffix + } + platformKeys[i] = mustCreatePlatformKey(t, platformKMS, withName(name)) + } + + makeResult := func(r *apiv1.CreateKeyResponse) apiv1.SearchKeyResult { + return apiv1.SearchKeyResult{ + Name: r.Name, + PublicKey: r.PublicKey, + CreateSignerRequest: r.CreateSignerRequest, + } + } + type args struct { req *apiv1.SearchKeysRequest } @@ -704,6 +1033,24 @@ func TestKMS_SearchKeys(t *testing.T) { want *apiv1.SearchKeysResponse assertion assert.ErrorAssertionFunc }{ + // PlatformKMS + {"ok platform", platformKMS, args{&apiv1.SearchKeysRequest{ + Query: "kms:tag=com.smallstep.test." + suffix, + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + makeResult(platformKeys[0]), makeResult(platformKeys[1]), + makeResult(platformKeys[2]), makeResult(platformKeys[3]), + }, + }, assert.NoError}, + {"ok platform with name", platformKMS, args{&apiv1.SearchKeysRequest{ + Query: fmt.Sprintf("kms:name=search-test-%d-%s;tag=com.smallstep.test.%s", 2, suffix, suffix), + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + makeResult(platformKeys[2]), + }, + }, assert.NoError}, + + // SoftKMS {"fail softKMS", softKMS, args{&apiv1.SearchKeysRequest{ Query: "kms:name=" + url.QueryEscape(dir), }}, nil, assert.Error}, diff --git a/tpm/key.go b/tpm/key.go index 1b61ae2f..b8c267b3 100644 --- a/tpm/key.go +++ b/tpm/key.go @@ -9,7 +9,6 @@ import ( "fmt" "time" - "github.com/google/go-tpm/legacy/tpm2" "github.com/smallstep/go-attestation/attest" internalkey "go.step.sm/crypto/tpm/internal/key" @@ -68,19 +67,22 @@ func (k *Key) WasAttestedBy(ak *AK) bool { // by a call to the TPM, so it can fail. If it fails, // nil is returned. func (k *Key) Public() crypto.PublicKey { - blobs, err := k.Blobs(context.Background()) - if err != nil { - return nil - } - pub, err := tpm2.DecodePublic(blobs.public) - if err != nil { + var ( + err error + ctx = context.Background() + ) + if err = k.tpm.open(ctx); err != nil { return nil } - publicKey, err := pub.Key() + defer closeTPM(context.Background(), k.tpm, &err) + + key, err := k.tpm.attestTPM.LoadKey(k.data) if err != nil { return nil } - return publicKey + defer key.Close() + + return key.Public() } // Certificate returns the certificate for the Key, if set. From 1ee9695042f674d5e94915e329828706f958c6b4 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 25 Feb 2026 19:52:00 -0800 Subject: [PATCH 10/27] Add method to get a big.Int from the URI. This commit adds the GetBigInt method to the URI. This method can be useful to decode serial numbers. --- kms/uri/uri.go | 84 +++++++++++++++++++++++++++++++++++++++++---- kms/uri/uri_test.go | 38 ++++++++++++++++++-- 2 files changed, 114 insertions(+), 8 deletions(-) diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 43ff2efe..860c4b0e 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/hex" "fmt" + "math/big" "net/url" "os" "strconv" @@ -152,6 +153,38 @@ func (u *URI) GetInt(key string) *int64 { return nil } +// GetBigInt returns the first [*big.Int] value in the URI with the given key. +// It returns nil if the field is not present. It parses as a hexadecimal +// string if the value starts with 0x (0x12), 0X (0X1A), contains a colon +// (00:01), or contains only hex characters with at least one letter A-F +// (e.g. "0A01"); otherwise it parses as a base-10 number. +func (u *URI) GetBigInt(key string) (*big.Int, error) { + v := u.Get(key) + if v == "" { + return nil, nil + } + + if hx, ok := hexString(v); ok { + if hx == "" { + return nil, fmt.Errorf("value %q is not a valid hexadecimal number", v) + } + + b, err := hex.DecodeString(hx) + if err != nil { + return nil, err + } + + return new(big.Int).SetBytes(b), nil + } + + bi, ok := new(big.Int).SetString(v, 10) + if !ok { + return nil, fmt.Errorf("value %q is not a valid number", v) + } + + return bi, nil +} + // GetEncoded returns the first value in the uri with the given key, it will // return empty nil if that field is not present or is empty. If the return // value is hex encoded it will decode it and return it. @@ -160,8 +193,8 @@ func (u *URI) GetEncoded(key string) []byte { if v == "" { return nil } - if len(v)%2 == 0 { - if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil { + if hx, ok := hexString(v); ok { + if b, err := hex.DecodeString(hx); err == nil { return b } } @@ -177,12 +210,12 @@ func (u *URI) GetHexEncoded(key string) ([]byte, error) { return nil, nil } - b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")) - if err != nil { - return nil, fmt.Errorf("failed decoding %q: %w", v, err) + hx, ok := hexString(v) + if !ok || hx == "" { + return nil, fmt.Errorf("value %q is not a valid hexadecimal number", v) } - return b, nil + return hex.DecodeString(hx) } // Pin returns the pin encoded in the url. It will read the pin from the @@ -218,6 +251,45 @@ func (u *URI) Read(key string) ([]byte, error) { return readFile(path) } +// hexString returns a clean hexadecimal string and a boolean indicating if s +// can be an hexadecimal string. If s starts with 0x (0x12), 0X (0X1A), or +// contains colons (01:1A) it will remove them and return true if it only +// contains valid hexadecimal characters. It will also true if the string +// contains at least one letter A-F (010A). It will also prefix the string with +// 0 if the length is an odd number. +func hexString(s string) (string, bool) { + hx := strings.TrimPrefix(s, "0x") + hx = strings.TrimPrefix(hx, "0X") + hx = strings.ReplaceAll(hx, ":", "") + changed := (len(s) != len(hx)) + + if len(hx)%2 != 0 { + hx = "0" + hx + } + + valid, hasLetter := isValidHexString(hx) + return hx, valid && (changed || hasLetter) +} + +// isValidHexString returns two booleans, the first indicating s contains only +// hexadecimal characters and the second if at least one letter (a-f or A-F), +// indicating it should be treated as hex. +func isValidHexString(s string) (bool, bool) { + var hasLetter bool + for _, c := range s { + switch { + case c >= '0' && c <= '9': + case c >= 'a' && c <= 'f': + hasLetter = true + case c >= 'A' && c <= 'F': + hasLetter = true + default: + return false, false + } + } + return true, hasLetter +} + func readFile(path string) ([]byte, error) { u, err := url.Parse(path) if err == nil && (u.Scheme == "" || u.Scheme == "file") { diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index 03542303..44de960d 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -2,6 +2,7 @@ package uri import ( "errors" + "math/big" "net/url" "os" "path/filepath" @@ -263,7 +264,7 @@ func TestURI_GetEncoded(t *testing.T) { {"ok in query percent", mustParse(t, "yubikey:slot-id=9a?foo=%9a"), args{"foo"}, []byte{0x9a}}, {"ok missing", mustParse(t, "yubikey:slot-id=9a"), args{"foo"}, nil}, {"ok missing query", mustParse(t, "yubikey:slot-id=9a?bar=zar"), args{"foo"}, nil}, - {"ok no hex", mustParse(t, "yubikey:slot-id=09a?bar=zar"), args{"slot-id"}, []byte{'0', '9', 'a'}}, + {"ok no hex", mustParse(t, "yubikey:slot-id=09z?bar=zar"), args{"slot-id"}, []byte{'0', '9', 'z'}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -369,6 +370,39 @@ func TestURI_GetInt(t *testing.T) { } } +func TestURI_GetBigInt(t *testing.T) { + type args struct { + key string + } + tests := []struct { + name string + uri *URI + args args + want *big.Int + assertion assert.ErrorAssertionFunc + }{ + {"ok empty", mustParse(t, "mackms:serial="), args{"serial"}, nil, assert.NoError}, + {"ok missing", mustParse(t, "mackms:label=123456"), args{"serial"}, nil, assert.NoError}, + {"ok number", mustParse(t, "mackms:serial=123456"), args{"serial"}, big.NewInt(123456), assert.NoError}, + {"ok hex with 0x", mustParse(t, "mackms:serial=0x123456"), args{"serial"}, big.NewInt(1193046), assert.NoError}, + {"ok hex with 0X", mustParse(t, "mackms:serial=0X123456"), args{"serial"}, big.NewInt(1193046), assert.NoError}, + {"ok hex with colon", mustParse(t, "mackms:serial=12:34:56"), args{"serial"}, big.NewInt(1193046), assert.NoError}, + {"ok hex odd length", mustParse(t, "mackms:serial=0x1"), args{"serial"}, big.NewInt(1), assert.NoError}, + {"fail hex empty", mustParse(t, "mackms:serial=0x"), args{"serial"}, nil, assert.Error}, + {"ok hex with letters", mustParse(t, "mackms:serial=0A01"), args{"serial"}, big.NewInt(0x0A01), assert.NoError}, + {"ok hex with letters only", mustParse(t, "mackms:serial=12345a"), args{"serial"}, big.NewInt(0x12345a), assert.NoError}, + {"fail hex", mustParse(t, "mackms:serial=0x12345g"), args{"serial"}, nil, assert.Error}, + {"fail number", mustParse(t, "mackms:serial=12345G"), args{"serial"}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.uri.GetBigInt(tt.args.key) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + func TestURI_GetHexEncoded(t *testing.T) { type args struct { key string @@ -384,7 +418,7 @@ func TestURI_GetHexEncoded(t *testing.T) { {"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, {"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, {"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil, false}, - {"fail odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil, true}, + {"ok odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, []byte{0x00, 0x9a}, false}, {"fail invalid hex", mustParse(t, "capi:sha1=9z?bar=zar"), args{"sha1"}, nil, true}, } for _, tt := range tests { From ee22d8d838da33c2deae742d8465092da9553b3a Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 25 Feb 2026 21:02:22 -0800 Subject: [PATCH 11/27] Add option to load a certificate by the key name --- kms/capi/capi.go | 130 ++++++++++++++++--------------------- kms/capi/ncrypt_windows.go | 26 ++++++++ 2 files changed, 83 insertions(+), 73 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 3a7ea50d..22555195 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -32,6 +32,7 @@ import ( "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" "go.step.sm/crypto/randutil" + "go.step.sm/crypto/x509util" ) // Scheme is the scheme used in uris, the string "capi". @@ -88,7 +89,7 @@ type uriAttributes struct { intermediateStoreName string keyID []byte subjectCN string - serialNumber string + serialNumber *big.Int issuerName string keySpec string skipFindCertificateKey bool @@ -115,6 +116,11 @@ func parseURI(rawuri string) (*uriAttributes, error) { } } + serialNumber, err := u.GetBigInt(SerialNumberArg) + if err != nil { + return nil, fmt.Errorf("failed getting %s from URI %q: %w", SerialNumberArg, rawuri, err) + } + return &uriAttributes{ containerName: u.Get(ContainerNameArg), hash: hashValue, @@ -124,7 +130,7 @@ func parseURI(rawuri string) (*uriAttributes, error) { intermediateStoreName: cmp.Or(u.Get(IntermediateStoreNameArg), CAStore), keyID: keyIDValue, subjectCN: u.Get(SubjectCNArg), - serialNumber: u.Get(SerialNumberArg), + serialNumber: serialNumber, issuerName: u.Get(IssuerNameArg), keySpec: u.Get(KeySpec), skipFindCertificateKey: u.GetBool(SkipFindCertificateKey), @@ -414,25 +420,10 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%x not found", HashArg, u.hash)} } case len(u.keyID) > 0: - searchData := CERT_ID_KEYIDORHASH{ - idChoice: CERT_ID_KEY_IDENTIFIER, - KeyIDOrHash: CRYPTOAPI_BLOB{ - len: uint32(len(u.keyID)), - data: uintptr(unsafe.Pointer(&u.keyID[0])), - }, - } - handle, err = findCertificateInStore(st, - encodingX509ASN|encodingPKCS7, - 0, - findCertID, - uintptr(unsafe.Pointer(&searchData)), nil) - if err != nil { - return nil, fmt.Errorf("findCertificateInStore failed: %w", err) - } - if handle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%x not found", KeyIDArg, u.keyID)} + if handle, err = findCertificateBySubjectKeyID(st, u.keyID); err != nil { + return nil, err } - case u.issuerName != "" && (u.serialNumber != "" || u.subjectCN != ""): + case u.issuerName != "" && (u.serialNumber != nil || u.subjectCN != ""): var prevCert *windows.CertContext for { handle, err = findCertificateInStore(st, @@ -454,27 +445,11 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) } switch { - case len(u.serialNumber) > 0: + case u.serialNumber != nil: // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number - var bi *big.Int - if strings.HasPrefix(u.serialNumber, "0x") { - serialBytes, err := hex.DecodeString(strings.TrimPrefix(u.serialNumber, "0x")) - if err != nil { - return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) - } - - bi = new(big.Int).SetBytes(serialBytes) - } else { - bi := new(big.Int) - bi, ok := bi.SetString(u.serialNumber, 10) - if !ok { - return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) - } - } - - if x509Cert.SerialNumber.Cmp(bi) == 0 { + if x509Cert.SerialNumber.Cmp(u.serialNumber) == 0 { return handle, nil } case len(u.subjectCN) > 0: @@ -485,6 +460,22 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) prevCert = handle } + case u.containerName != "": + key, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: uri.New(Scheme, url.Values{ + ContainerNameArg: []string{u.containerName}, + }).String(), + }) + if err != nil { + return nil, err + } + keyID, err := x509util.GenerateSubjectKeyID(key) + if err != nil { + return nil, fmt.Errorf("error generating SubjectKeyID: %w", err) + } + if handle, err = findCertificateBySubjectKeyID(st, keyID); err != nil { + return nil, err + } default: return nil, fmt.Errorf("%q, %q, or %q and one of %q or %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg, SubjectCNArg) } @@ -964,50 +955,22 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } return nil case len(u.keyID) > 0: - searchData := CERT_ID_KEYIDORHASH{ - idChoice: CERT_ID_KEY_IDENTIFIER, - KeyIDOrHash: CRYPTOAPI_BLOB{ - len: uint32(len(u.keyID)), - data: uintptr(unsafe.Pointer(&u.keyID[0])), - }, - } - certHandle, err = findCertificateInStore(st, - encodingX509ASN|encodingPKCS7, - 0, - findCertID, - uintptr(unsafe.Pointer(&searchData)), nil) + certHandle, err = findCertificateBySubjectKeyID(st, u.keyID) if err != nil { - return fmt.Errorf("findCertificateInStore failed: %w", err) - } - if certHandle == nil { - return nil + return err } - if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { return fmt.Errorf("failed removing certificate: %w", err) } return nil - case u.issuerName != "" && u.serialNumber != "": + case u.issuerName != "" && u.serialNumber != nil: // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number - var serialBytes []byte - if strings.HasPrefix(u.serialNumber, "0x") { - u.serialNumber = strings.TrimPrefix(u.serialNumber, "0x") - u.serialNumber = strings.TrimPrefix(u.serialNumber, "00") // Comparison fails if leading 00 is not removed - serialBytes, err = hex.DecodeString(u.serialNumber) - if err != nil { - return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) - } - } else { - bi := new(big.Int) - bi, ok := bi.SetString(u.serialNumber, 10) - if !ok { - return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) - } - serialBytes = bi.Bytes() - } - var prevCert *windows.CertContext + var ( + prevCert *windows.CertContext + serialBytes = u.serialNumber.Bytes() + ) for { certHandle, err = findCertificateInStore(st, encodingX509ASN|encodingPKCS7, @@ -1036,6 +999,27 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } prevCert = certHandle } + case u.containerName != "": + key, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: uri.New(Scheme, url.Values{ + ContainerNameArg: []string{u.containerName}, + }).String(), + }) + if err != nil { + return err + } + keyID, err := x509util.GenerateSubjectKeyID(key) + if err != nil { + return fmt.Errorf("error generating SubjectKeyID: %w", err) + } + certHandle, err = findCertificateBySubjectKeyID(st, keyID) + if err != nil { + return err + } + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + return nil default: return fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) } diff --git a/kms/capi/ncrypt_windows.go b/kms/capi/ncrypt_windows.go index 9366f08b..714f3ca2 100644 --- a/kms/capi/ncrypt_windows.go +++ b/kms/capi/ncrypt_windows.go @@ -12,6 +12,7 @@ import ( "fmt" "unsafe" + "go.step.sm/crypto/kms/apiv1" "golang.org/x/sys/windows" ) @@ -518,6 +519,31 @@ func nCryptExportKey(kh uintptr, blobType string) ([]byte, error) { return buf, nil } +func findCertificateBySubjectKeyID(store windows.Handle, keyID []byte) (*windows.CertContext, error) { + searchData := CERT_ID_KEYIDORHASH{ + idChoice: CERT_ID_KEY_IDENTIFIER, + KeyIDOrHash: CRYPTOAPI_BLOB{ + len: uint32(len(keyID)), + data: uintptr(unsafe.Pointer(&keyID[0])), + }, + } + + handle, err := findCertificateInStore(store, + encodingX509ASN|encodingPKCS7, + 0, + findCertID, + uintptr(unsafe.Pointer(&searchData)), nil) + if err != nil { + return nil, fmt.Errorf("findCertificateInStore failed: %w", err) + } + + if handle == nil { + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%x not found", KeyIDArg, keyID)} + } + + return handle, nil +} + func findCertificateInStore(store windows.Handle, enc, findFlags, findType uint32, para uintptr, prev *windows.CertContext) (*windows.CertContext, error) { h, _, err := procCertFindCertificateInStore.Call( uintptr(store), From 3513f8cd3244dea1e859a418a6cccf7cdd549239 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 25 Feb 2026 21:03:13 -0800 Subject: [PATCH 12/27] Use new GetBigInt to parse the serial --- kms/mackms/mackms.go | 14 ++++++-------- kms/mackms/mackms_test.go | 1 + 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/kms/mackms/mackms.go b/kms/mackms/mackms.go index 703d9d20..90570dfc 100644 --- a/kms/mackms/mackms.go +++ b/kms/mackms/mackms.go @@ -1276,17 +1276,15 @@ func parseCertURI(rawuri string, useDataProtectionKeychain, requireValue bool) ( } // With regular values, uris look like this: - // mackms:label=my-cert;serial=01020A0B... + // mackms:label=my-cert;serial=0x01020A0B... label := u.Get("label") keychain := u.Get("keychain") - serial := u.GetEncoded("serial") - if requireValue && label == "" && len(serial) == 0 { - return nil, fmt.Errorf("error parsing %q: label or serial are required", rawuri) + serialNumber, err := u.GetBigInt("serial") + if err != nil { + return nil, fmt.Errorf("error parsing %q: %w", rawuri, err) } - - var serialNumber *big.Int - if len(serial) > 0 { - serialNumber = new(big.Int).SetBytes(serial) + if requireValue && label == "" && serialNumber == nil { + return nil, fmt.Errorf("error parsing %q: label or serial are required", rawuri) } return &certAttributes{ diff --git a/kms/mackms/mackms_test.go b/kms/mackms/mackms_test.go index dd627982..0959e2fd 100644 --- a/kms/mackms/mackms_test.go +++ b/kms/mackms/mackms_test.go @@ -849,6 +849,7 @@ func TestMacKMS_LoadCertificate(t *testing.T) { {"fail uri", &MacKMS{}, args{&apiv1.LoadCertificateRequest{Name: "mackms:"}}, nil, assert.Error}, {"fail missing label", &MacKMS{}, args{&apiv1.LoadCertificateRequest{Name: "mackms:label=missing-" + suffix}}, nil, assert.Error}, {"fail missing serial", &MacKMS{}, args{&apiv1.LoadCertificateRequest{Name: "mackms:serial=010a020b030c"}}, nil, assert.Error}, + {"fail bad serial", &MacKMS{}, args{&apiv1.LoadCertificateRequest{Name: "mackms:serial=010a020b030z"}}, nil, assert.Error}, {"fail with keychain", &MacKMS{}, args{&apiv1.LoadCertificateRequest{ Name: "mackms:keychain=dataProtection;label=" + label, }}, nil, assert.Error}, From b30491a5dfe27293339c5260e0ae7c9704ddc89e Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 25 Feb 2026 21:04:17 -0800 Subject: [PATCH 13/27] Expose GenerateSubjectKeyID --- x509util/certificate.go | 2 +- x509util/certificate_test.go | 4 ++-- x509util/utils.go | 4 ++-- x509util/utils_test.go | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/x509util/certificate.go b/x509util/certificate.go index fa80b06f..a2456c6a 100644 --- a/x509util/certificate.go +++ b/x509util/certificate.go @@ -216,7 +216,7 @@ func CreateCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, } } if template.SubjectKeyId == nil { - if template.SubjectKeyId, err = generateSubjectKeyID(pub); err != nil { + if template.SubjectKeyId, err = GenerateSubjectKeyID(pub); err != nil { return nil, err } } diff --git a/x509util/certificate_test.go b/x509util/certificate_test.go index d49b8d48..a7df6960 100644 --- a/x509util/certificate_test.go +++ b/x509util/certificate_test.go @@ -81,7 +81,7 @@ func createIssuerCertificate(t *testing.T, commonName string) (*x509.Certificate if err != nil { t.Fatal(err) } - subjectKeyID, err := generateSubjectKeyID(pub) + subjectKeyID, err := GenerateSubjectKeyID(pub) if err != nil { t.Fatal(err) } @@ -814,7 +814,7 @@ func TestCreateCertificate(t *testing.T) { return sn } mustSubjectKeyID := func(pub crypto.PublicKey) []byte { - b, err := generateSubjectKeyID(pub) + b, err := GenerateSubjectKeyID(pub) if err != nil { t.Fatal(err) } diff --git a/x509util/utils.go b/x509util/utils.go index 69ddbebf..825f8d7a 100644 --- a/x509util/utils.go +++ b/x509util/utils.go @@ -110,7 +110,7 @@ type subjectPublicKeyInfo struct { SubjectPublicKey asn1.BitString } -// generateSubjectKeyID generates the key identifier according the the RFC 5280 +// GenerateSubjectKeyID generates the key identifier according the the RFC 5280 // section 4.2.1.2. // // The keyIdentifier is composed of the 160-bit SHA-1 hash of the value of the @@ -119,7 +119,7 @@ type subjectPublicKeyInfo struct { // // If FIPS 140-3 mode is enabled, instead of SHA-1, it will use the leftmost // 160-bits of the SHA-256 hash according to RFC 7093 section 2. -func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { +func GenerateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { b, err := x509.MarshalPKIXPublicKey(pub) if err != nil { return nil, errors.Wrap(err, "error marshaling public key") diff --git a/x509util/utils_test.go b/x509util/utils_test.go index a5a1c1bd..ca8a5ca3 100644 --- a/x509util/utils_test.go +++ b/x509util/utils_test.go @@ -172,7 +172,7 @@ func Test_generateSubjectKeyID(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := generateSubjectKeyID(tt.args.pub) + got, err := GenerateSubjectKeyID(tt.args.pub) if (err != nil) != tt.wantErr { t.Errorf("generateSubjectKeyID() error = %v, wantErr %v", err, tt.wantErr) return @@ -204,7 +204,7 @@ func Test_generateSubjectKeyID_fips(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := generateSubjectKeyID(tt.args.pub) + got, err := GenerateSubjectKeyID(tt.args.pub) tt.assertion(t, err) assert.Equal(t, tt.want, got) }) From ead241cea33c0a1b3dd9ff13d861d79d3a8932c9 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 25 Feb 2026 21:05:02 -0800 Subject: [PATCH 14/27] Add tests for windows --- kms/platform/kms_other_test.go | 14 +- kms/platform/kms_softkms.go | 15 +- kms/platform/kms_test.go | 76 ++++- kms/platform/kms_windows.go | 6 +- kms/platform/kms_windows_test.go | 486 ++++++++++++++++++++++++++++++- 5 files changed, 576 insertions(+), 21 deletions(-) diff --git a/kms/platform/kms_other_test.go b/kms/platform/kms_other_test.go index 48eddbca..b3e703b3 100644 --- a/kms/platform/kms_other_test.go +++ b/kms/platform/kms_other_test.go @@ -3,15 +3,25 @@ package platform import ( + "net/url" "testing" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" ) func mustPlatformKMS(t *testing.T) *KMS { - return &KMS{} + if !isTPMAvailable() { + return &KMS{} + } + + return mustKMS(t, uri.New(Scheme, url.Values{ + "storage-directory": []string{t.TempDir()}, + }).String()) } // SkipTest is a method implemented on tests that allow skipping the test on // this platform. func (k *KMS) SkipTests() bool { - return true + return k.Type() == apiv1.DefaultKMS } diff --git a/kms/platform/kms_softkms.go b/kms/platform/kms_softkms.go index 4e6af8c8..4af3639c 100644 --- a/kms/platform/kms_softkms.go +++ b/kms/platform/kms_softkms.go @@ -107,9 +107,6 @@ func (k *softKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { func filename(s string) string { if u, err := uri.ParseWithScheme(softkms.Scheme, s); err == nil { - if f := u.Get("path"); f != "" { - return f - } switch { case u.Path != "": return u.Path @@ -121,10 +118,18 @@ func filename(s string) string { } func transformToSoftKMS(u *kmsURI) string { - if u.name != "" { + switch { + case u.uri.Has("name"): return uri.NewOpaque(softkms.Scheme, u.name).String() + case u.uri.Has("path"): + return uri.NewOpaque(softkms.Scheme, u.uri.Get("path")).String() + case u.uri.Path != "": + return uri.NewOpaque(softkms.Scheme, u.uri.Path).String() + case u.uri.Opaque != "": + return uri.NewOpaque(softkms.Scheme, u.uri.Opaque).String() + default: + return uri.NewOpaque(softkms.Scheme, "").String() } - return uri.NewOpaque(softkms.Scheme, u.uri.Path).String() } func transformFromSoftKMS(rawuri string) (string, error) { diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index d7c9e8bc..a067024f 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -7,6 +7,7 @@ import ( "crypto/elliptic" "crypto/rsa" "crypto/x509" + "encoding/hex" "encoding/pem" "fmt" "net/url" @@ -19,9 +20,11 @@ import ( "github.com/stretchr/testify/require" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" + "go.step.sm/crypto/tpm" ) var ( @@ -48,6 +51,14 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +func isTPMAvailable() bool { + t, err := tpm.New() + if err != nil { + return false + } + return t.Available() == nil +} + func shouldSkipNow(t *testing.T, km *KMS) { t.Helper() @@ -162,8 +173,9 @@ func mustSuffix(t *testing.T) string { } type createOptions struct { - name string - noCleanup bool + name string + noCleanup bool + noCleanupCertificate bool } type createFuncOption func(*createOptions) @@ -177,6 +189,13 @@ func withName(s string) createFuncOption { func withNoCleanup() createFuncOption { return func(co *createOptions) { co.noCleanup = true + co.noCleanupCertificate = true + } +} + +func withNoCleanupCertificate() createFuncOption { + return func(co *createOptions) { + co.noCleanupCertificate = true } } @@ -242,7 +261,7 @@ func mustCreatePlatformCertificate(t *testing.T, km *KMS, opts ...createFuncOpti cert, ca.Intermediate, }, })) - if !o.noCleanup { + if !o.noCleanupCertificate { t.Cleanup(func() { assert.NoError(t, km.DeleteCertificate(&apiv1.DeleteCertificateRequest{ Name: o.name, @@ -250,6 +269,17 @@ func mustCreatePlatformCertificate(t *testing.T, km *KMS, opts ...createFuncOpti }) } + // Always delete the intermediate on macOS + if typ := km.Type(); typ == apiv1.MacKMS { + t.Cleanup(func() { + assert.NoError(t, km.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "serial": []string{ca.Intermediate.SerialNumber.String()}, + }).String(), + })) + }) + } + return []*x509.Certificate{ cert, ca.Intermediate, } @@ -361,7 +391,10 @@ func TestKMS_CreateKey(t *testing.T) { assert.Regexp(t, "^kms:.*name=test1-.*$", got.Name) assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) - if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + + if platformKMS.Type() == apiv1.TPMKMS && assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) + } else if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { assert.Equal(t, elliptic.P256(), got.PublicKey.(*ecdsa.PublicKey).Curve) } }, assert.NoError}, @@ -379,6 +412,7 @@ func TestKMS_CreateKey(t *testing.T) { assert.Regexp(t, "^kms:.*name=test2-.*$", got.Name) assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { assert.Equal(t, elliptic.P384(), got.PublicKey.(*ecdsa.PublicKey).Curve) } @@ -538,7 +572,7 @@ func TestKMS_CreateSigner(t *testing.T) { }{ // PlatformKMS {"ok platform", platformKMS, args{&apiv1.CreateSignerRequest{ - SigningKey: platformKey.Name, + SigningKey: platformKeyName, }}, func(t *testing.T, s crypto.Signer) { require.NotNil(t, s) assert.Equal(t, platformKey.PublicKey, s.Public()) @@ -619,10 +653,10 @@ func TestKMS_DeleteKey(t *testing.T) { }) return assert.NoError(t, err) && assert.Error(t, getErr) }}, - {"fail platform deleted", softKMS, args{&apiv1.DeleteKeyRequest{ + {"fail platform deleted", platformKMS, args{&apiv1.DeleteKeyRequest{ Name: platformKey.Name, }}, assert.Error}, - {"fail platform missing", softKMS, args{&apiv1.DeleteKeyRequest{ + {"fail platform missing", platformKMS, args{&apiv1.DeleteKeyRequest{ Name: platformMissingName, }}, assert.Error}, @@ -747,6 +781,11 @@ func TestKMS_StoreCertificate(t *testing.T) { Name: platformCertName + "-other", Certificate: chain[0], }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + // Storing a certificate with no key is not supported on TPMKMS. + if platformKMS.Type() == apiv1.TPMKMS { + return assert.Error(t, err) + } + t.Cleanup(func() { assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ Name: platformCertName, @@ -889,10 +928,23 @@ func TestKMS_StoreCertificateChain(t *testing.T) { Name: platformCertName + "-other", CertificateChain: chain, }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + // Storing a certificate with no key is not supported on TPMKMS. + if platformKMS.Type() == apiv1.TPMKMS { + return assert.Error(t, err) + } + t.Cleanup(func() { assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ Name: platformCertName, })) + + if typ := platformKMS.Type(); typ == apiv1.MacKMS { + assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "serial": []string{hex.EncodeToString(platformChain[1].SerialNumber.Bytes())}, + }).String(), + })) + } }) return assert.NoError(t, err) }}, @@ -947,7 +999,7 @@ func TestKMS_DeleteCertificate(t *testing.T) { softKMS := mustKMS(t, "kms:backend=softkms") platformKMS := mustPlatformKMS(t) - _ = mustCreatePlatformCertificate(t, platformKMS, withNoCleanup()) + _ = mustCreatePlatformCertificate(t, platformKMS, withNoCleanupCertificate()) type args struct { req *apiv1.DeleteCertificateRequest @@ -1017,9 +1069,11 @@ func TestKMS_SearchKeys(t *testing.T) { makeResult := func(r *apiv1.CreateKeyResponse) apiv1.SearchKeyResult { return apiv1.SearchKeyResult{ - Name: r.Name, - PublicKey: r.PublicKey, - CreateSignerRequest: r.CreateSignerRequest, + Name: r.Name, + PublicKey: r.PublicKey, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: r.Name, + }, } } diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index cc73bdbc..e25dd83b 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -56,6 +56,10 @@ func transformToCapiKMS(u *kmsURI) string { uv.Set("key", u.name) } + // When storing certificate skip key validation. + // This avoid a prompt looking for an SmartCard. + uv.Set("skip-find-certificate-key", "true") + // Add custom extra values that might be CAPI specific. maps.Copy(uv, u.extraValues) @@ -73,7 +77,7 @@ func transformFromCapiKMS(rawuri string) (string, error) { } for k, v := range u.Values { - if k != "name" { + if k != "key" { uv[k] = v } } diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go index 884208fe..e8ee5050 100644 --- a/kms/platform/kms_windows_test.go +++ b/kms/platform/kms_windows_test.go @@ -3,15 +3,497 @@ package platform import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/hex" + "fmt" + "net/url" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" ) func mustPlatformKMS(t *testing.T) *KMS { - return &KMS{} + t.Helper() + + if !isTPMAvailable() { + return &KMS{} + } + + return mustKMS(t, uri.New(Scheme, url.Values{ + "storage-directory": []string{t.TempDir()}, + }).String()) } // SkipTest is a method implemented on tests that allow skipping the test on // this platform. func (k *KMS) SkipTests() bool { - return true + return k.Type() == apiv1.DefaultKMS +} + +func mustCAPIKMS(t *testing.T) *KMS { + return mustKMS(t, "kms:backend=capi") +} + +func TestKMS_Type_capi(t *testing.T) { + km := mustCAPIKMS(t) + assert.Equal(t, apiv1.CAPIKMS, km.Type()) +} + +func TestKMS_GetPublicKey_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + want crypto.PublicKey + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.GetPublicKeyRequest{ + Name: capiKey.Name, + }}, capiKey.PublicKey, assert.NoError}, + {"fail capi missing", capiKMS, args{&apiv1.GetPublicKeyRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + {"fail capi name", capiKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:something=test", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.GetPublicKey(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_CreateKey_capi(t *testing.T) { + suffix := mustSuffix(t) + capiKMS := mustCAPIKMS(t) + + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(t *testing.T, got *apiv1.CreateKeyResponse) + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test1-" + suffix, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test1-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=.*;provider=.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + + if capiKMS.Type() == apiv1.TPMKMS && assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) + } else if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, elliptic.P256(), got.PublicKey.(*ecdsa.PublicKey).Curve) + } + }, assert.NoError}, + {"ok capi ECDSA", capiKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test2-" + suffix, + SignatureAlgorithm: apiv1.ECDSAWithSHA384, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test2-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=.*;provider=.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + + if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, elliptic.P384(), got.PublicKey.(*ecdsa.PublicKey).Curve) + } + }, assert.NoError}, + {"ok capi RSA", capiKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test3-" + suffix, + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test3-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=.*;provider=.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + if assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) + } + }, assert.NoError}, + {"fail capi algorithm", capiKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:test4-" + suffix, + SignatureAlgorithm: apiv1.SignatureAlgorithm(100), + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateKey(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_CreateSigner_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + + assertNil := func(t *testing.T, got crypto.Signer) { + t.Helper() + assert.Nil(t, got) + } + + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(*testing.T, crypto.Signer) + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: platformKeyName, + }}, func(t *testing.T, s crypto.Signer) { + require.NotNil(t, s) + assert.Equal(t, capiKey.PublicKey, s.Public()) + }, assert.NoError}, + {"fail capi missing", capiKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: platformMissingName, + }}, assertNil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateSigner(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_DeleteKey_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS, withNoCleanup()) + + type args struct { + req *apiv1.DeleteKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.DeleteKeyRequest{ + Name: capiKey.Name, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, getErr := capiKMS.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: capiKey.Name, + }) + return assert.NoError(t, err) && assert.Error(t, getErr) + }}, + {"fail capi deleted", capiKMS, args{&apiv1.DeleteKeyRequest{ + Name: capiKey.Name, + }}, assert.Error}, + {"fail capi missing", capiKMS, args{&apiv1.DeleteKeyRequest{ + Name: platformMissingName, + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteKey(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificate_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + capiChain := mustCertificateWithKey(t, capiKey.PublicKey) + require.NoError(t, capiKMS.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: capiChain, + })) + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{capiChain[1].Issuer.CommonName}, + "serial": []string{capiChain[1].SerialNumber.String()}, + }).String(), + })) + }) + + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.LoadCertificateRequest{ + Name: platformCertName, + }}, capiChain[0], assert.NoError}, + {"ok capi issuer and serial", capiKMS, args{&apiv1.LoadCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{capiChain[0].Issuer.CommonName}, + "serial": []string{capiChain[0].SerialNumber.String()}, + }).String(), + }}, capiChain[0], assert.NoError}, + {"fail capi missing", capiKMS, args{&apiv1.LoadCertificateRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificate(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificate_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + capiChain := mustCertificateWithKey(t, capiKey.PublicKey) + chainNoKey := mustCertificate(t, "") + + type args struct { + req *apiv1.StoreCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName, + Certificate: capiChain[0], + }}, assert.NoError}, + {"ok capi no key", capiKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName + "-other", + Certificate: chainNoKey[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + }) + return assert.NoError(t, err) + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.StoreCertificate(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificateChain_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + capiChain := mustCertificateWithKey(t, capiKey.PublicKey) + require.NoError(t, capiKMS.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: capiChain, + })) + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{capiChain[1].Issuer.CommonName}, + "serial": []string{capiChain[1].SerialNumber.String()}, + }).String(), + })) + }) + + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + want []*x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: platformCertName, + }}, capiChain, assert.NoError}, + {"ok capi issuer and serial", capiKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{capiChain[0].Issuer.CommonName}, + "serial": []string{capiChain[0].SerialNumber.String()}, + }).String(), + }}, capiChain, assert.NoError}, + {"fail capi missing", capiKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificateChain(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificateChain_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + capiChain := mustCertificateWithKey(t, capiKey.PublicKey) + chainNoKey := mustCertificate(t, "") + + type args struct { + req *apiv1.StoreCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: capiChain, + }}, assert.NoError}, + {"ok capi no key", capiKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName + "-other", + CertificateChain: chainNoKey, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + // Storing a certificate with no key is not supported on TPMKMS. + if capiKMS.Type() == apiv1.TPMKMS { + return assert.Error(t, err) + } + + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + + if capiKMS.Type() == apiv1.MacKMS { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "serial": []string{hex.EncodeToString(capiChain[1].SerialNumber.Bytes())}, + }).String(), + })) + } + }) + return assert.NoError(t, err) + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.StoreCertificateChain(tt.args.req)) + }) + } +} + +func TestKMS_DeleteCertificate_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + _ = mustCreatePlatformCertificate(t, capiKMS, withNoCleanupCertificate()) + + type args struct { + req *apiv1.DeleteCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, loadErr := capiKMS.LoadCertificate(&apiv1.LoadCertificateRequest{ + Name: platformCertName, + }) + return assert.NoError(t, err) && assert.Error(t, loadErr) + }}, + {"fail platform missing", capiKMS, args{&apiv1.DeleteCertificateRequest{ + Name: platformMissingName, + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteCertificate(tt.args.req)) + }) + } +} + +func TestKMS_SearchKeys_capi(t *testing.T) { + suffix := mustSuffix(t) + capiKMS := mustCAPIKMS(t) + + platformKeys := make([]*apiv1.CreateKeyResponse, 4) + for i := range platformKeys { + name := fmt.Sprintf("kms:name=search-test-%d-%s", i, suffix) + platformKeys[i] = mustCreatePlatformKey(t, capiKMS, withName(name)) + } + + type args struct { + req *apiv1.SearchKeysRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.SearchKeysResponse + assertion assert.ErrorAssertionFunc + }{ + {"fail capi", capiKMS, args{&apiv1.SearchKeysRequest{ + Query: "kms:", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.SearchKeys(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } } From 3b94424e9f6c6da5b36f1eddfa83dc154f3120ea Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 26 Feb 2026 11:24:31 -0800 Subject: [PATCH 15/27] Format imports --- Makefile | 2 +- internal/templates/funcmap.go | 1 + jose/encrypt.go | 1 + jose/generate.go | 1 + jose/parse.go | 1 + jose/types.go | 1 + jose/types_test.go | 1 + jose/validate.go | 3 ++- jose/x25519.go | 1 + keyutil/key.go | 3 ++- kms/awskms/awskms.go | 1 + kms/awskms/awskms_test.go | 1 + kms/awskms/no_awskms.go | 1 + kms/awskms/signer.go | 1 + kms/awskms/signer_test.go | 1 + kms/azurekms/key_vault.go | 1 + kms/azurekms/key_vault_test.go | 3 ++- kms/azurekms/no_azurekms.go | 1 + kms/azurekms/signer_test.go | 5 +++-- kms/azurekms/utils.go | 1 + kms/azurekms/utils_test.go | 1 + kms/capi/capi_no_windows.go | 1 + kms/capi/ncrypt_windows.go | 3 ++- kms/capi/no_capi.go | 1 + kms/cloudkms/cloudkms.go | 3 ++- kms/cloudkms/cloudkms_test.go | 7 ++++--- kms/cloudkms/decrypter_test.go | 3 ++- kms/cloudkms/no_cloudkms.go | 1 + kms/cloudkms/signer.go | 1 + kms/cloudkms/signer_test.go | 1 + kms/kms.go | 1 + kms/mackms/mackms_test.go | 1 + kms/mackms/signer_test.go | 1 + kms/pkcs11/pkcs11_no_cgo.go | 1 + kms/pkcs11/pkcs11_test.go | 3 ++- kms/pkcs11/setup_test.go | 1 + kms/platform/kms_test.go | 1 + kms/platform/kms_windows_test.go | 1 + kms/softkms/softkms.go | 1 + kms/softkms/softkms_test.go | 1 + kms/sshagentkms/no_sshagentkms.go | 1 + kms/sshagentkms/sshagentkms.go | 1 + kms/sshagentkms/sshagentkms_test.go | 5 +++-- kms/tpmkms/no_tpmkms.go | 1 + kms/tpmkms/tpmkms_test.go | 1 + kms/uri/uri.go | 1 + nssdb/keys.go | 3 ++- nssdb/keys_test.go | 3 ++- pemutil/ssh.go | 3 ++- sshutil/fingerprint_test.go | 3 ++- sshutil/sshutil_test.go | 3 ++- tpm/ak_test.go | 1 + tpm/caps.go | 1 + tpm/info_test.go | 1 + tpm/key_test.go | 1 + tpm/rand/rand_simulator_test.go | 1 + tpm/tss2/simulator_test.go | 1 + x509util/extensions.go | 1 + 58 files changed, 77 insertions(+), 20 deletions(-) diff --git a/Makefile b/Makefile index bec0b10a..968e8b18 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,7 @@ race: ######################################### fmt: - $Q goimports -l -w $(SRC) + $Q goimports --local go.step.sm/crypto -l -w $(SRC) lint: golint govulncheck diff --git a/internal/templates/funcmap.go b/internal/templates/funcmap.go index 45c13759..76b9989e 100644 --- a/internal/templates/funcmap.go +++ b/internal/templates/funcmap.go @@ -7,6 +7,7 @@ import ( "time" "github.com/Masterminds/sprig/v3" + "go.step.sm/crypto/jose" ) diff --git a/jose/encrypt.go b/jose/encrypt.go index 9b61a5f4..c412291e 100644 --- a/jose/encrypt.go +++ b/jose/encrypt.go @@ -4,6 +4,7 @@ import ( "encoding/json" "github.com/pkg/errors" + "go.step.sm/crypto/randutil" ) diff --git a/jose/generate.go b/jose/generate.go index 4bdc6c44..73329b0f 100644 --- a/jose/generate.go +++ b/jose/generate.go @@ -9,6 +9,7 @@ import ( "encoding/base64" "github.com/pkg/errors" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x25519" diff --git a/jose/parse.go b/jose/parse.go index 1236cfaf..b3832393 100644 --- a/jose/parse.go +++ b/jose/parse.go @@ -17,6 +17,7 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x25519" ) diff --git a/jose/types.go b/jose/types.go index 0ff48092..b83749d4 100644 --- a/jose/types.go +++ b/jose/types.go @@ -11,6 +11,7 @@ import ( jose "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/cryptosigner" "github.com/go-jose/go-jose/v3/jwt" + "go.step.sm/crypto/x25519" ) diff --git a/jose/types_test.go b/jose/types_test.go index 2c06cde6..961ee9ab 100644 --- a/jose/types_test.go +++ b/jose/types_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/x25519" ) diff --git a/jose/validate.go b/jose/validate.go index 6a904167..b8bf99a9 100644 --- a/jose/validate.go +++ b/jose/validate.go @@ -12,8 +12,9 @@ import ( "os" "github.com/pkg/errors" - "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/keyutil" ) // ValidateSSHPOP validates the given SSH certificate and key for use in an diff --git a/jose/x25519.go b/jose/x25519.go index 25e90e8a..b63d47a8 100644 --- a/jose/x25519.go +++ b/jose/x25519.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/pkg/errors" + "go.step.sm/crypto/x25519" ) diff --git a/keyutil/key.go b/keyutil/key.go index 171cdf3f..a8ec53d8 100644 --- a/keyutil/key.go +++ b/keyutil/key.go @@ -14,8 +14,9 @@ import ( "sync/atomic" "github.com/pkg/errors" - "go.step.sm/crypto/x25519" "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/x25519" ) var ( diff --git a/kms/awskms/awskms.go b/kms/awskms/awskms.go index 5ea34caa..782dea3b 100644 --- a/kms/awskms/awskms.go +++ b/kms/awskms/awskms.go @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" "go.step.sm/crypto/pemutil" diff --git a/kms/awskms/awskms_test.go b/kms/awskms/awskms_test.go index b3763a86..fc411a98 100644 --- a/kms/awskms/awskms_test.go +++ b/kms/awskms/awskms_test.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" ) diff --git a/kms/awskms/no_awskms.go b/kms/awskms/no_awskms.go index eff7215b..d2db445a 100644 --- a/kms/awskms/no_awskms.go +++ b/kms/awskms/no_awskms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/awskms/signer.go b/kms/awskms/signer.go index d24af3da..3ec8935f 100644 --- a/kms/awskms/signer.go +++ b/kms/awskms/signer.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/pkg/errors" + "go.step.sm/crypto/pemutil" ) diff --git a/kms/awskms/signer_test.go b/kms/awskms/signer_test.go index 6d7edea0..bc002762 100644 --- a/kms/awskms/signer_test.go +++ b/kms/awskms/signer_test.go @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" + "go.step.sm/crypto/pemutil" ) diff --git a/kms/azurekms/key_vault.go b/kms/azurekms/key_vault.go index 5e718751..384a4b41 100644 --- a/kms/azurekms/key_vault.go +++ b/kms/azurekms/key_vault.go @@ -14,6 +14,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" ) diff --git a/kms/azurekms/key_vault_test.go b/kms/azurekms/key_vault_test.go index 77db69a5..aa51b169 100644 --- a/kms/azurekms/key_vault_test.go +++ b/kms/azurekms/key_vault_test.go @@ -15,10 +15,11 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" "github.com/go-jose/go-jose/v3" + "go.uber.org/mock/gomock" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/azurekms/internal/mock" - "go.uber.org/mock/gomock" ) var errTest = fmt.Errorf("test error") diff --git a/kms/azurekms/no_azurekms.go b/kms/azurekms/no_azurekms.go index 7e3eea6c..b1403804 100644 --- a/kms/azurekms/no_azurekms.go +++ b/kms/azurekms/no_azurekms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/azurekms/signer_test.go b/kms/azurekms/signer_test.go index 4d894230..28a1bc6c 100644 --- a/kms/azurekms/signer_test.go +++ b/kms/azurekms/signer_test.go @@ -12,11 +12,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" - "go.step.sm/crypto/keyutil" - "go.step.sm/crypto/kms/apiv1" "go.uber.org/mock/gomock" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" + + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/kms/apiv1" ) type FuncMatcher func(x interface{}) bool diff --git a/kms/azurekms/utils.go b/kms/azurekms/utils.go index 9de8f05c..293f6c06 100644 --- a/kms/azurekms/utils.go +++ b/kms/azurekms/utils.go @@ -17,6 +17,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" ) diff --git a/kms/azurekms/utils_test.go b/kms/azurekms/utils_test.go index 63953646..6df6629a 100644 --- a/kms/azurekms/utils_test.go +++ b/kms/azurekms/utils_test.go @@ -13,6 +13,7 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/capi/capi_no_windows.go b/kms/capi/capi_no_windows.go index be7e896b..b32563c2 100644 --- a/kms/capi/capi_no_windows.go +++ b/kms/capi/capi_no_windows.go @@ -6,6 +6,7 @@ import ( "context" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/capi/ncrypt_windows.go b/kms/capi/ncrypt_windows.go index 714f3ca2..0126d99c 100644 --- a/kms/capi/ncrypt_windows.go +++ b/kms/capi/ncrypt_windows.go @@ -12,8 +12,9 @@ import ( "fmt" "unsafe" - "go.step.sm/crypto/kms/apiv1" "golang.org/x/sys/windows" + + "go.step.sm/crypto/kms/apiv1" ) const ( diff --git a/kms/capi/no_capi.go b/kms/capi/no_capi.go index 1eec688f..16d6468d 100644 --- a/kms/capi/no_capi.go +++ b/kms/capi/no_capi.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/cloudkms/cloudkms.go b/kms/cloudkms/cloudkms.go index 0092fed9..adae96f5 100644 --- a/kms/cloudkms/cloudkms.go +++ b/kms/cloudkms/cloudkms.go @@ -18,10 +18,11 @@ import ( "cloud.google.com/go/kms/apiv1/kmspb" gax "github.com/googleapis/gax-go/v2" "github.com/pkg/errors" + "google.golang.org/api/option" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" "go.step.sm/crypto/pemutil" - "google.golang.org/api/option" ) // Scheme is the scheme used in uris, the string "cloudkms". diff --git a/kms/cloudkms/cloudkms_test.go b/kms/cloudkms/cloudkms_test.go index e32edb02..b935c1b9 100644 --- a/kms/cloudkms/cloudkms_test.go +++ b/kms/cloudkms/cloudkms_test.go @@ -14,13 +14,14 @@ import ( gax "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.step.sm/crypto/kms/apiv1" - "go.step.sm/crypto/kms/uri" - "go.step.sm/crypto/pemutil" "google.golang.org/api/option" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/pemutil" ) func TestParent(t *testing.T) { diff --git a/kms/cloudkms/decrypter_test.go b/kms/cloudkms/decrypter_test.go index a756cbae..815698fd 100644 --- a/kms/cloudkms/decrypter_test.go +++ b/kms/cloudkms/decrypter_test.go @@ -14,9 +14,10 @@ import ( gax "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" - "google.golang.org/protobuf/types/known/wrapperspb" ) func TestCloudKMS_CreateDecrypter(t *testing.T) { diff --git a/kms/cloudkms/no_cloudkms.go b/kms/cloudkms/no_cloudkms.go index 31dcc57c..ac4df003 100644 --- a/kms/cloudkms/no_cloudkms.go +++ b/kms/cloudkms/no_cloudkms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/cloudkms/signer.go b/kms/cloudkms/signer.go index b6f666b1..16fb712d 100644 --- a/kms/cloudkms/signer.go +++ b/kms/cloudkms/signer.go @@ -9,6 +9,7 @@ import ( "cloud.google.com/go/kms/apiv1/kmspb" "github.com/pkg/errors" + "go.step.sm/crypto/pemutil" ) diff --git a/kms/cloudkms/signer_test.go b/kms/cloudkms/signer_test.go index 0b391b04..15664b4c 100644 --- a/kms/cloudkms/signer_test.go +++ b/kms/cloudkms/signer_test.go @@ -15,6 +15,7 @@ import ( gax "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/pemutil" ) diff --git a/kms/kms.go b/kms/kms.go index c5ad7042..b530a1be 100644 --- a/kms/kms.go +++ b/kms/kms.go @@ -4,6 +4,7 @@ import ( "context" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" // Enable default implementation diff --git a/kms/mackms/mackms_test.go b/kms/mackms/mackms_test.go index 0959e2fd..adca47ec 100644 --- a/kms/mackms/mackms_test.go +++ b/kms/mackms/mackms_test.go @@ -40,6 +40,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + cf "go.step.sm/crypto/internal/darwin/corefoundation" "go.step.sm/crypto/internal/darwin/security" "go.step.sm/crypto/kms/apiv1" diff --git a/kms/mackms/signer_test.go b/kms/mackms/signer_test.go index 863ba346..c7966c87 100644 --- a/kms/mackms/signer_test.go +++ b/kms/mackms/signer_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/pkcs11/pkcs11_no_cgo.go b/kms/pkcs11/pkcs11_no_cgo.go index f1899f47..c063cf20 100644 --- a/kms/pkcs11/pkcs11_no_cgo.go +++ b/kms/pkcs11/pkcs11_no_cgo.go @@ -9,6 +9,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/pkcs11/pkcs11_test.go b/kms/pkcs11/pkcs11_test.go index 1596eff1..608dd7c6 100644 --- a/kms/pkcs11/pkcs11_test.go +++ b/kms/pkcs11/pkcs11_test.go @@ -22,9 +22,10 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.step.sm/crypto/kms/apiv1" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" + + "go.step.sm/crypto/kms/apiv1" ) func TestNew(t *testing.T) { diff --git a/kms/pkcs11/setup_test.go b/kms/pkcs11/setup_test.go index 0f437c4a..6075ba7a 100644 --- a/kms/pkcs11/setup_test.go +++ b/kms/pkcs11/setup_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index a067024f..158da2b8 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go index e8ee5050..76c76056 100644 --- a/kms/platform/kms_windows_test.go +++ b/kms/platform/kms_windows_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" ) diff --git a/kms/softkms/softkms.go b/kms/softkms/softkms.go index fb8ad78a..997c7838 100644 --- a/kms/softkms/softkms.go +++ b/kms/softkms/softkms.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/pkg/errors" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" diff --git a/kms/softkms/softkms_test.go b/kms/softkms/softkms_test.go index dc86ee67..6e5f8299 100644 --- a/kms/softkms/softkms_test.go +++ b/kms/softkms/softkms_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x25519" diff --git a/kms/sshagentkms/no_sshagentkms.go b/kms/sshagentkms/no_sshagentkms.go index adb94c66..cc3e29f2 100644 --- a/kms/sshagentkms/no_sshagentkms.go +++ b/kms/sshagentkms/no_sshagentkms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/sshagentkms/sshagentkms.go b/kms/sshagentkms/sshagentkms.go index 87645c37..acd11dea 100644 --- a/kms/sshagentkms/sshagentkms.go +++ b/kms/sshagentkms/sshagentkms.go @@ -19,6 +19,7 @@ import ( "golang.org/x/crypto/ssh/agent" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/sshutil" diff --git a/kms/sshagentkms/sshagentkms_test.go b/kms/sshagentkms/sshagentkms_test.go index 96621613..9d1656e7 100644 --- a/kms/sshagentkms/sshagentkms_test.go +++ b/kms/sshagentkms/sshagentkms_test.go @@ -19,11 +19,12 @@ import ( "strings" "testing" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" ) // Some helpers with inspiration from crypto/ssh/agent/client_test.go diff --git a/kms/tpmkms/no_tpmkms.go b/kms/tpmkms/no_tpmkms.go index a08640b0..f7eaa5e5 100644 --- a/kms/tpmkms/no_tpmkms.go +++ b/kms/tpmkms/no_tpmkms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/tpmkms/tpmkms_test.go b/kms/tpmkms/tpmkms_test.go index 55c44e93..449bc79f 100644 --- a/kms/tpmkms/tpmkms_test.go +++ b/kms/tpmkms/tpmkms_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/tpm" "go.step.sm/crypto/tpm/tss2" diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 860c4b0e..f8922765 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -12,6 +12,7 @@ import ( "unicode" "github.com/pkg/errors" + "go.step.sm/crypto/internal/termutil" ) diff --git a/nssdb/keys.go b/nssdb/keys.go index d08133fe..bb5864a9 100644 --- a/nssdb/keys.go +++ b/nssdb/keys.go @@ -10,8 +10,9 @@ import ( "fmt" "math/big" - asn1utils "go.step.sm/crypto/internal/utils/asn1" "golang.org/x/crypto/cryptobyte" + + asn1utils "go.step.sm/crypto/internal/utils/asn1" ) // ASN.1 encoded OID for secp256r1 (1.2.840.10045.3.1.7), the only supported curve. diff --git a/nssdb/keys_test.go b/nssdb/keys_test.go index b507dabd..1c45d849 100644 --- a/nssdb/keys_test.go +++ b/nssdb/keys_test.go @@ -11,8 +11,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.step.sm/crypto/pemutil" "golang.org/x/crypto/cryptobyte" + + "go.step.sm/crypto/pemutil" ) func TestEcParams(t *testing.T) { diff --git a/pemutil/ssh.go b/pemutil/ssh.go index 00698dae..d8c56be7 100644 --- a/pemutil/ssh.go +++ b/pemutil/ssh.go @@ -17,9 +17,10 @@ import ( "math/big" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + bcryptpbkdf "go.step.sm/crypto/internal/bcrypt_pbkdf" "go.step.sm/crypto/randutil" - "golang.org/x/crypto/ssh" ) const ( diff --git a/sshutil/fingerprint_test.go b/sshutil/fingerprint_test.go index 1edf1630..f573bdf9 100644 --- a/sshutil/fingerprint_test.go +++ b/sshutil/fingerprint_test.go @@ -12,8 +12,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.step.sm/crypto/internal/emoji" "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/internal/emoji" ) func generateCertificate(t *testing.T) ssh.PublicKey { diff --git a/sshutil/sshutil_test.go b/sshutil/sshutil_test.go index dc40c03f..64350214 100644 --- a/sshutil/sshutil_test.go +++ b/sshutil/sshutil_test.go @@ -8,9 +8,10 @@ import ( "testing" "github.com/stretchr/testify/require" - "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" + + "go.step.sm/crypto/keyutil" ) type skKey struct { diff --git a/tpm/ak_test.go b/tpm/ak_test.go index 71370dae..ff862918 100644 --- a/tpm/ak_test.go +++ b/tpm/ak_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" "go.step.sm/crypto/x509util" diff --git a/tpm/caps.go b/tpm/caps.go index 9d6ca7da..539a9dc8 100644 --- a/tpm/caps.go +++ b/tpm/caps.go @@ -6,6 +6,7 @@ import ( "slices" "github.com/google/go-tpm/legacy/tpm2" + "go.step.sm/crypto/tpm/algorithm" ) diff --git a/tpm/info_test.go b/tpm/info_test.go index 6af4545b..2eac1c74 100644 --- a/tpm/info_test.go +++ b/tpm/info_test.go @@ -6,6 +6,7 @@ import ( "github.com/smallstep/go-attestation/attest" "github.com/stretchr/testify/require" + "go.step.sm/crypto/tpm/manufacturer" ) diff --git a/tpm/key_test.go b/tpm/key_test.go index 3139e3d6..9635eb1e 100644 --- a/tpm/key_test.go +++ b/tpm/key_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" "go.step.sm/crypto/x509util" diff --git a/tpm/rand/rand_simulator_test.go b/tpm/rand/rand_simulator_test.go index cf79becf..5286e0fa 100644 --- a/tpm/rand/rand_simulator_test.go +++ b/tpm/rand/rand_simulator_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/tpm" "go.step.sm/crypto/tpm/simulator" ) diff --git a/tpm/tss2/simulator_test.go b/tpm/tss2/simulator_test.go index 6028934b..b40ec920 100644 --- a/tpm/tss2/simulator_test.go +++ b/tpm/tss2/simulator_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "go.step.sm/crypto/tpm/simulator" ) diff --git a/x509util/extensions.go b/x509util/extensions.go index a89a7238..95f37c88 100644 --- a/x509util/extensions.go +++ b/x509util/extensions.go @@ -16,6 +16,7 @@ import ( "time" "github.com/pkg/errors" + asn1utils "go.step.sm/crypto/internal/utils/asn1" ) From 49f23310771bb98b415f235e74251da1494830c3 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 26 Feb 2026 14:34:58 -0800 Subject: [PATCH 16/27] Add suggestions from code review --- kms/apiv1/options.go | 2 +- kms/mackms/mackms.go | 2 +- kms/platform/kms.go | 54 ++++++++++++++++---------------- kms/platform/kms_other_test.go | 3 +- kms/platform/kms_test.go | 9 ------ kms/platform/kms_windows.go | 10 +++--- kms/platform/kms_windows_test.go | 3 +- kms/uri/uri.go | 6 ++-- 8 files changed, 41 insertions(+), 48 deletions(-) diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index b557b812..5733fb4c 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -174,7 +174,7 @@ const ( // MacKMS is the KMS implementation using macOS Keychain and Secure Enclave. MacKMS Type = "mackms" // PlatformKMS is the KMS implementation that uses TPMKMS on Windows and - // Linux and MacKMS on macOS.. + // Linux and MacKMS on macOS. PlatformKMS Type = "kms" ) diff --git a/kms/mackms/mackms.go b/kms/mackms/mackms.go index 90570dfc..c2c28003 100644 --- a/kms/mackms/mackms.go +++ b/kms/mackms/mackms.go @@ -1284,7 +1284,7 @@ func parseCertURI(rawuri string, useDataProtectionKeychain, requireValue bool) ( return nil, fmt.Errorf("error parsing %q: %w", rawuri, err) } if requireValue && label == "" && serialNumber == nil { - return nil, fmt.Errorf("error parsing %q: label or serial are required", rawuri) + return nil, fmt.Errorf("error parsing %q: label or serial is required", rawuri) } return &certAttributes{ diff --git a/kms/platform/kms.go b/kms/platform/kms.go index 68a8e538..aa8ec3ec 100644 --- a/kms/platform/kms.go +++ b/kms/platform/kms.go @@ -102,9 +102,9 @@ func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, return nil, err } - req = clone(req) - req.Name = name - resp, err := k.backend.CreateKey(req) + r := clone(req) + r.Name = name + resp, err := k.backend.CreateKey(r) if err != nil { return nil, err } @@ -122,9 +122,9 @@ func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error return nil, err } - req = clone(req) - req.SigningKey = signingKey - return k.backend.CreateSigner(req) + r := clone(req) + r.SigningKey = signingKey + return k.backend.CreateSigner(r) } func (k *KMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { @@ -133,9 +133,9 @@ func (k *KMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { return err } - req = clone(req) - req.Name = name - return k.backend.DeleteKey(req) + r := clone(req) + r.Name = name + return k.backend.DeleteKey(r) } func (k *KMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { @@ -144,9 +144,9 @@ func (k *KMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certific return nil, err } - req = clone(req) - req.Name = name - return k.backend.LoadCertificate(req) + r := clone(req) + r.Name = name + return k.backend.LoadCertificate(r) } func (k *KMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { @@ -155,9 +155,9 @@ func (k *KMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { return err } - req = clone(req) - req.Name = name - return k.backend.StoreCertificate(req) + r := clone(req) + r.Name = name + return k.backend.StoreCertificate(r) } func (k *KMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { @@ -166,9 +166,9 @@ func (k *KMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x return nil, err } - req = clone(req) - req.Name = name - return k.backend.LoadCertificateChain(req) + r := clone(req) + r.Name = name + return k.backend.LoadCertificateChain(r) } func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { @@ -177,9 +177,9 @@ func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) err return err } - req = clone(req) - req.Name = name - return k.backend.StoreCertificateChain(req) + r := clone(req) + r.Name = name + return k.backend.StoreCertificateChain(r) } func (k *KMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { @@ -188,9 +188,9 @@ func (k *KMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return err } - req = clone(req) - req.Name = name - return k.backend.DeleteCertificate(req) + r := clone(req) + r.Name = name + return k.backend.DeleteCertificate(r) } func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { @@ -200,9 +200,9 @@ func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysRespons return nil, err } - req = clone(req) - req.Query = query - resp, err := km.SearchKeys(req) + r := clone(req) + r.Query = query + resp, err := km.SearchKeys(r) if err != nil { return nil, err } diff --git a/kms/platform/kms_other_test.go b/kms/platform/kms_other_test.go index b3e703b3..1c83ddb1 100644 --- a/kms/platform/kms_other_test.go +++ b/kms/platform/kms_other_test.go @@ -8,10 +8,11 @@ import ( "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/tpm/available" ) func mustPlatformKMS(t *testing.T) *KMS { - if !isTPMAvailable() { + if available.Check() != nil { return &KMS{} } diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index 158da2b8..a3d64999 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -25,7 +25,6 @@ import ( "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" - "go.step.sm/crypto/tpm" ) var ( @@ -52,14 +51,6 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func isTPMAvailable() bool { - t, err := tpm.New() - if err != nil { - return false - } - return t.Available() == nil -} - func shouldSkipNow(t *testing.T, km *KMS) { t.Helper() diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index e25dd83b..9d2dbf44 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -25,7 +25,7 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { switch u.backend { case apiv1.CAPIKMS: - opts.URI = transformToCapiKMS(u) + opts.URI = transformToCAPIKMS(u) return newCAPIKMS(ctx, opts) case apiv1.SoftKMS: return newSoftKMS(ctx, opts) @@ -45,12 +45,12 @@ func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { return &KMS{ typ: apiv1.CAPIKMS, backend: km, - transformToURI: transformToCapiKMS, - transformFromURI: transformFromCapiKMS, + transformToURI: transformToCAPIKMS, + transformFromURI: transformFromCAPIKMS, }, nil } -func transformToCapiKMS(u *kmsURI) string { +func transformToCAPIKMS(u *kmsURI) string { uv := url.Values{} if u.name != "" { uv.Set("key", u.name) @@ -66,7 +66,7 @@ func transformToCapiKMS(u *kmsURI) string { return uri.New(capi.Scheme, uv).String() } -func transformFromCapiKMS(rawuri string) (string, error) { +func transformFromCAPIKMS(rawuri string) (string, error) { u, err := uri.ParseWithScheme(capi.Scheme, rawuri) if err != nil { return "", err diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go index 76c76056..340d6851 100644 --- a/kms/platform/kms_windows_test.go +++ b/kms/platform/kms_windows_test.go @@ -18,12 +18,13 @@ import ( "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/tpm/available" ) func mustPlatformKMS(t *testing.T) *KMS { t.Helper() - if !isTPMAvailable() { + if available.Check() != nil { return &KMS{} } diff --git a/kms/uri/uri.go b/kms/uri/uri.go index f8922765..2361880b 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -255,9 +255,9 @@ func (u *URI) Read(key string) ([]byte, error) { // hexString returns a clean hexadecimal string and a boolean indicating if s // can be an hexadecimal string. If s starts with 0x (0x12), 0X (0X1A), or // contains colons (01:1A) it will remove them and return true if it only -// contains valid hexadecimal characters. It will also true if the string -// contains at least one letter A-F (010A). It will also prefix the string with -// 0 if the length is an odd number. +// contains valid hexadecimal characters. It will also return true if the string +// contains at least one letter A-F (010A). It will prefix the string with 0 if +// the length is an odd number. func hexString(s string) (string, bool) { hx := strings.TrimPrefix(s, "0x") hx = strings.TrimPrefix(hx, "0X") From 498ebf383e77039176b9629665f52da95c3c45b2 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 26 Feb 2026 15:10:29 -0800 Subject: [PATCH 17/27] Fix linter errors --- go.mod | 12 ++++++------ go.sum | 20 ++++++++++---------- kms/uri/uri.go | 2 +- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 78e1d101..1a3579a2 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.step.sm/crypto -go 1.24.0 +go 1.25.0 require ( cloud.google.com/go/kms v1.25.0 @@ -23,8 +23,8 @@ require ( github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca github.com/stretchr/testify v1.11.1 go.uber.org/mock v0.6.0 - golang.org/x/crypto v0.47.0 - golang.org/x/net v0.49.0 + golang.org/x/crypto v0.48.0 + golang.org/x/net v0.51.0 golang.org/x/sys v0.41.0 golang.org/x/term v0.40.0 google.golang.org/api v0.264.0 @@ -92,12 +92,12 @@ require ( go.opentelemetry.io/otel/metric v1.39.0 // indirect go.opentelemetry.io/otel/trace v1.39.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.31.0 // indirect + golang.org/x/mod v0.32.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sync v0.19.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/text v0.34.0 // indirect golang.org/x/time v0.14.0 // indirect - golang.org/x/tools v0.40.0 // indirect + golang.org/x/tools v0.41.0 // indirect google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect diff --git a/go.sum b/go.sum index 1591e832..797dd66e 100644 --- a/go.sum +++ b/go.sum @@ -934,8 +934,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -976,8 +976,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1033,8 +1033,8 @@ golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1173,8 +1173,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1256,8 +1256,8 @@ golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 2361880b..8bd72db3 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -162,7 +162,7 @@ func (u *URI) GetInt(key string) *int64 { func (u *URI) GetBigInt(key string) (*big.Int, error) { v := u.Get(key) if v == "" { - return nil, nil + return nil, nil //nolint:nilnil // return nil value } if hx, ok := hexString(v); ok { From af42cc2cd4a08b3e3d7e686eb58762e0975ce850 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 26 Feb 2026 16:18:14 -0800 Subject: [PATCH 18/27] fix GetEncoded logic --- kms/pkcs11/pkcs11_test.go | 1 + kms/uri/uri.go | 26 +++++++++++++++----------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/kms/pkcs11/pkcs11_test.go b/kms/pkcs11/pkcs11_test.go index 608dd7c6..d2d9b379 100644 --- a/kms/pkcs11/pkcs11_test.go +++ b/kms/pkcs11/pkcs11_test.go @@ -214,6 +214,7 @@ func TestNew_config(t *testing.T) { }) } } + func TestPKCS11_GetPublicKey(t *testing.T) { k := setupPKCS11(t) type args struct { diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 8bd72db3..5cd571e2 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -165,7 +165,7 @@ func (u *URI) GetBigInt(key string) (*big.Int, error) { return nil, nil //nolint:nilnil // return nil value } - if hx, ok := hexString(v); ok { + if hx, ok, isHex := hexString(v); ok && isHex { if hx == "" { return nil, fmt.Errorf("value %q is not a valid hexadecimal number", v) } @@ -194,7 +194,7 @@ func (u *URI) GetEncoded(key string) []byte { if v == "" { return nil } - if hx, ok := hexString(v); ok { + if hx, ok, _ := hexString(v); ok { if b, err := hex.DecodeString(hx); err == nil { return b } @@ -211,7 +211,7 @@ func (u *URI) GetHexEncoded(key string) ([]byte, error) { return nil, nil } - hx, ok := hexString(v) + hx, ok, _ := hexString(v) if !ok || hx == "" { return nil, fmt.Errorf("value %q is not a valid hexadecimal number", v) } @@ -252,13 +252,13 @@ func (u *URI) Read(key string) ([]byte, error) { return readFile(path) } -// hexString returns a clean hexadecimal string and a boolean indicating if s -// can be an hexadecimal string. If s starts with 0x (0x12), 0X (0X1A), or -// contains colons (01:1A) it will remove them and return true if it only -// contains valid hexadecimal characters. It will also return true if the string -// contains at least one letter A-F (010A). It will prefix the string with 0 if -// the length is an odd number. -func hexString(s string) (string, bool) { +// hexString returns a clean hexadecimal string, a boolean indicating if s is a +// valid hexadecimal string, and a boolean indicating if s was explicitly +// identifiable as hexadecimal. If s starts with 0x (0x12), 0X (0X1A), or +// contains colons (01:1A) it will remove them. The third boolean is true if s +// had a 0x/0X prefix, contained colons, or contained at least one letter A-F +// (010A). The string is prefixed with 0 if its length is odd. +func hexString(s string) (string, bool, bool) { hx := strings.TrimPrefix(s, "0x") hx = strings.TrimPrefix(hx, "0X") hx = strings.ReplaceAll(hx, ":", "") @@ -269,7 +269,11 @@ func hexString(s string) (string, bool) { } valid, hasLetter := isValidHexString(hx) - return hx, valid && (changed || hasLetter) + if !valid { + return "", false, false + } + + return hx, valid, (changed || hasLetter) } // isValidHexString returns two booleans, the first indicating s contains only From f1076d3315ab6caf8bafc783cb5937211a3db690 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 26 Feb 2026 19:02:23 -0800 Subject: [PATCH 19/27] Add custom attestation using the attestation client This commit adds support for custom attestation using the attestation client and adds CreateAttestation tests. --- kms/apiv1/requests.go | 13 ++++ kms/apiv1/requests_test.go | 23 ++++++- kms/platform/kms.go | 50 ++++++++++++++ kms/platform/kms_test.go | 97 +++++++++++++++++++++++++++ kms/platform/kms_tpm.go | 8 --- kms/platform/kms_tpmsimulator_test.go | 90 +++++++++++++++++++++++++ 6 files changed, 272 insertions(+), 9 deletions(-) diff --git a/kms/apiv1/requests.go b/kms/apiv1/requests.go index 3955f6c0..f4bf1162 100644 --- a/kms/apiv1/requests.go +++ b/kms/apiv1/requests.go @@ -265,6 +265,19 @@ type AttestationClient interface { Attest(context.Context) ([]*x509.Certificate, error) } +type attestSignerCtx struct{} + +// NewAttestSignerContext creates a new context with the given signer. +func NewAttestSignerContext(ctx context.Context, signer crypto.Signer) context.Context { + return context.WithValue(ctx, attestSignerCtx{}, signer) +} + +// AttestSignerFromContext returns the signer from the context. +func AttestSignerFromContext(ctx context.Context) (crypto.Signer, bool) { + signer, ok := ctx.Value(attestSignerCtx{}).(crypto.Signer) + return signer, ok +} + // CertificationParameters encapsulates the inputs for certifying an application key. // Only TPM 2.0 is supported at this point. // diff --git a/kms/apiv1/requests_test.go b/kms/apiv1/requests_test.go index b378e631..199713ac 100644 --- a/kms/apiv1/requests_test.go +++ b/kms/apiv1/requests_test.go @@ -1,6 +1,13 @@ package apiv1 -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.step.sm/crypto/keyutil" +) func TestProtectionLevel_String(t *testing.T) { tests := []struct { @@ -49,3 +56,17 @@ func TestSignatureAlgorithm_String(t *testing.T) { }) } } + +func TestNewAttestSignerContext(t *testing.T) { + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + ctx := NewAttestSignerContext(t.Context(), signer) + got, ok := AttestSignerFromContext(ctx) + assert.Equal(t, signer, got) + assert.True(t, ok) + + got, ok = AttestSignerFromContext(t.Context()) + assert.Nil(t, got) + assert.False(t, ok) +} diff --git a/kms/platform/kms.go b/kms/platform/kms.go index aa8ec3ec..124b0757 100644 --- a/kms/platform/kms.go +++ b/kms/platform/kms.go @@ -4,6 +4,7 @@ import ( "context" "crypto" "crypto/x509" + "errors" "net/url" "strings" @@ -193,6 +194,55 @@ func (k *KMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return k.backend.DeleteCertificate(r) } +func (k *KMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.CreateAttestationResponse, error) { + if req.Name == "" { + return nil, errors.New("createAttestationRequest 'name' cannot be empty") + } + + name, err := k.transform(req.Name) + if err != nil { + return nil, err + } + + // Attestation implemented by the backend + if km, ok := k.backend.(apiv1.Attester); ok { + r := clone(req) + r.Name = name + + return km.CreateAttestation(r) + } + + // Attestation using a custom attestation client. + if req.AttestationClient != nil { + signer, err := k.backend.CreateSigner(&apiv1.CreateSignerRequest{ + SigningKey: name, + }) + if err != nil { + return nil, err + } + + ctx := apiv1.NewAttestSignerContext(context.Background(), signer) + chain, err := req.AttestationClient.Attest(ctx) + if err != nil { + return nil, err + } + + var permanentIdentifier string + if len(chain[0].URIs) > 0 { + permanentIdentifier = chain[0].URIs[0].String() + } + + return &apiv1.CreateAttestationResponse{ + Certificate: chain[0], + CertificateChain: chain, + PublicKey: signer.Public(), + PermanentIdentifier: permanentIdentifier, + }, nil + } + + return nil, apiv1.NotImplementedError{} +} + func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { if km, ok := k.backend.(apiv1.SearchableKeyManager); ok { query, err := k.transform(req.Query) diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index a3d64999..5bed1477 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -2,13 +2,18 @@ package platform import ( "bytes" + "context" "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" + "crypto/sha256" "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" "encoding/hex" "encoding/pem" + "errors" "fmt" "net/url" "os" @@ -277,6 +282,38 @@ func mustCreatePlatformCertificate(t *testing.T, km *KMS, opts ...createFuncOpti } } +func mustPermanentIdentifier(t *testing.T, pub crypto.PublicKey) *url.URL { + t.Helper() + + b, err := x509.MarshalPKIXPublicKey(pub) + require.NoError(t, err) + + keyID := sha256.Sum256(b) + return &url.URL{ + Scheme: "urn", + Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID[:]), + } +} + +type attestationClient struct { + chain []*x509.Certificate + err error +} + +func mustAttestationClient(chain []*x509.Certificate, err error) *attestationClient { + return &attestationClient{ + chain: chain, + err: err, + } +} + +func (c *attestationClient) Attest(ctx context.Context) ([]*x509.Certificate, error) { + if _, ok := apiv1.AttestSignerFromContext(ctx); !ok { + return nil, fmt.Errorf("signer is not in context") + } + return c.chain, c.err +} + func TestKMS_Type(t *testing.T) { softKMS := mustKMS(t, "kms:backend=softkms") assert.Equal(t, apiv1.SoftKMS, softKMS.Type()) @@ -1043,6 +1080,66 @@ func TestKMS_DeleteCertificate(t *testing.T) { } } +func TestKMS_CreateAttestation(t *testing.T) { + dir := t.TempDir() + privateKeyPath := filepath.Join(dir, "private.key") + signer := mustSigner(t, privateKeyPath) + attester := mustSigner(t, "attester.key") + permanentIdentifier := mustPermanentIdentifier(t, attester.Public()) + + ca, err := minica.New() + require.NoError(t, err) + cert, err := ca.Sign(&x509.Certificate{ + Subject: pkix.Name{ + CommonName: "attestation certificate", + }, + URIs: []*url.URL{permanentIdentifier}, + PublicKey: signer.Public(), + }) + require.NoError(t, err) + + softKMS := mustKMS(t, "kms:backend=softkms") + okClient := mustAttestationClient([]*x509.Certificate{cert, ca.Intermediate}, nil) + failClient := mustAttestationClient(nil, errors.New("attestation failed")) + + type args struct { + req *apiv1.CreateAttestationRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.CreateAttestationResponse + assertion assert.ErrorAssertionFunc + }{ + {"ok custom attestation", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "kms:" + privateKeyPath, + AttestationClient: okClient, + }}, &apiv1.CreateAttestationResponse{ + Certificate: cert, + CertificateChain: []*x509.Certificate{cert, ca.Intermediate}, + PublicKey: signer.Public(), + PermanentIdentifier: permanentIdentifier.String(), + }, assert.NoError}, + {"fail custom attestation", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "kms:" + privateKeyPath, + AttestationClient: failClient, + }}, nil, assert.Error}, + {"fail no client", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "kms:" + privateKeyPath, + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.CreateAttestation(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + func TestKMS_SearchKeys(t *testing.T) { dir := t.TempDir() softKMS := mustKMS(t, "kms:backend=softkms") diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go index c8580662..abc6c7a7 100644 --- a/kms/platform/kms_tpm.go +++ b/kms/platform/kms_tpm.go @@ -51,14 +51,6 @@ func NewWithTPM(ctx context.Context, t *tpm.TPM, opts ...tpmkms.Option) (*KMS, e }, nil } -func (k *KMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.CreateAttestationResponse, error) { - if km, ok := k.backend.(apiv1.Attester); ok { - return km.CreateAttestation(req) - } - - return nil, apiv1.NotImplementedError{} -} - func transformToTPMKMS(u *kmsURI) string { uv := url.Values{} if u.name != "" { diff --git a/kms/platform/kms_tpmsimulator_test.go b/kms/platform/kms_tpmsimulator_test.go index d7acab3e..ead78c7e 100644 --- a/kms/platform/kms_tpmsimulator_test.go +++ b/kms/platform/kms_tpmsimulator_test.go @@ -6,6 +6,7 @@ import ( "context" "crypto" "crypto/x509" + "crypto/x509/pkix" "net" "net/url" "path/filepath" @@ -16,6 +17,7 @@ import ( "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/minica" "go.step.sm/crypto/tpm" "go.step.sm/crypto/tpm/simulator" "go.step.sm/crypto/tpm/storage" @@ -745,6 +747,94 @@ func TestKMS_DeleteCertificate_tpm(t *testing.T) { } } +func TestKMS_CreateAttestation_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + eks, err := stpm.GetEKs(ctx) + require.NoError(t, err) + require.NotEmpty(t, eks) + ekKeyURL := mustPermanentIdentifier(t, eks[0].Public()) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + key, err := stpm.AttestKey(ctx, "ak-1", "key-1", tpm.AttestKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + QualifyingData: []byte{1, 2, 3, 4}, + }) + require.NoError(t, err) + keyParams, err := key.CertificationParameters(ctx) + require.NoError(t, err) + keySigner, err := key.Signer(ctx) + require.NoError(t, err) + + _, err = stpm.CreateKey(ctx, "key-2", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ca, err := minica.New() + require.NoError(t, err) + + akCert, err := ca.Sign(&x509.Certificate{ + Subject: pkix.Name{ + CommonName: "ak-1", + }, + URIs: []*url.URL{ekKeyURL}, + PublicKey: ak.Public(), + }) + require.NoError(t, err) + require.NoError(t, ak.SetCertificateChain(ctx, []*x509.Certificate{ + akCert, ca.Intermediate, + })) + + type args struct { + req *apiv1.CreateAttestationRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.CreateAttestationResponse + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.CreateAttestationRequest{ + Name: "kms:name=key-1", + }}, &apiv1.CreateAttestationResponse{ + Certificate: akCert, + CertificateChain: []*x509.Certificate{akCert, ca.Intermediate}, + PublicKey: keySigner.Public(), + CertificationParameters: &apiv1.CertificationParameters{ + Public: keyParams.Public, + CreateData: keyParams.CreateData, + CreateAttestation: keyParams.CreateAttestation, + CreateSignature: keyParams.CreateSignature, + }, + PermanentIdentifier: ekKeyURL.String(), + }, assert.NoError}, + {"fail not attested key", km, args{&apiv1.CreateAttestationRequest{ + Name: "kms:name=key-2", + }}, nil, assert.Error}, + {"fail missing key", km, args{&apiv1.CreateAttestationRequest{ + Name: "kms:name=key-3", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.CreateAttestation(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + func TestKMS_SearchKeys_tpm(t *testing.T) { ctx := t.Context() stpm := mustTPM(t) From efdd0bf6a26a1c0a6f04ab9e8d9e4109e18c741a Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 27 Feb 2026 15:39:26 -0800 Subject: [PATCH 20/27] Add helper uri.Values This commit adds the helper uri.Values which returns url.Values merging the ones in the opaque and query strings. --- kms/uri/uri.go | 18 ++++++++++++++++++ kms/uri/uri_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 5cd571e2..899adbd6 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/hex" "fmt" + "maps" "math/big" "net/url" "os" @@ -252,6 +253,23 @@ func (u *URI) Read(key string) ([]byte, error) { return readFile(path) } +// Values returns the [url.Values] merging the values in the opaque and query +// string of the given [*URI]. +func Values(u *URI) url.Values { + uv := url.Values{} + maps.Copy(uv, u.Values) + for k, v := range u.URL.Query() { + if !uv.Has(k) { + uv[k] = v + continue + } + for _, s := range v { + uv.Add(k, s) + } + } + return uv +} + // hexString returns a clean hexadecimal string, a boolean indicating if s is a // valid hexadecimal string, and a boolean indicating if s was explicitly // identifiable as hexadecimal. If s starts with 0x (0x12), 0X (0X1A), or diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index 44de960d..394c5914 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -478,3 +478,28 @@ func TestURI_Read(t *testing.T) { }) } } + +func TestValues(t *testing.T) { + tests := []struct { + name string + rawuri string + want url.Values + }{ + {"empty", "kms:", url.Values{}}, + {"with opaque values", "kms:foo=bar;baz=qux;foo=zar", url.Values{ + "foo": []string{"bar", "zar"}, "baz": []string{"qux"}, + }}, + {"with query values", "kms:name=value?foo=bar&baz=qux&foo=zar", url.Values{ + "name": []string{"value"}, "foo": []string{"bar", "zar"}, "baz": []string{"qux"}, + }}, + {"with mixed values", "kms:name=value;foo=bar?baz=qux&foo=zar", url.Values{ + "name": []string{"value"}, "foo": []string{"bar", "zar"}, "baz": []string{"qux"}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := mustParse(t, tt.rawuri) + assert.Equal(t, tt.want, Values(u)) + }) + } +} From 3131fc69ae301b4545b928bd025aa855d1cb67fa Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 27 Feb 2026 16:06:34 -0800 Subject: [PATCH 21/27] Remove hw to ak on tpmkms and add tests --- kms/platform/kms.go | 10 ++++- kms/platform/kms_darwin.go | 11 +++-- kms/platform/kms_darwin_test.go | 58 +++++++++++++++++++++++++- kms/platform/kms_softkms.go | 57 ++++++++++--------------- kms/platform/kms_softkms_test.go | 49 ++++++++++++++++++++++ kms/platform/kms_test.go | 8 +++- kms/platform/kms_tpm.go | 18 +++----- kms/platform/kms_tpm_test.go | 59 ++++++++++++++++++++++++++ kms/platform/kms_tpmsimulator_test.go | 60 +++++++++++++-------------- kms/platform/kms_windows.go | 21 ++++++++-- kms/platform/kms_windows_test.go | 52 +++++++++++++++++++++++ 11 files changed, 313 insertions(+), 90 deletions(-) create mode 100644 kms/platform/kms_softkms_test.go create mode 100644 kms/platform/kms_tpm_test.go diff --git a/kms/platform/kms.go b/kms/platform/kms.go index 124b0757..c2cc3990 100644 --- a/kms/platform/kms.go +++ b/kms/platform/kms.go @@ -34,6 +34,12 @@ type kmsURI struct { extraValues url.Values } +func isDefaultKey(k string) bool { + return k == nameKey || + k == hwKey || + k == backendKey +} + func parseURI(rawuri string) (*kmsURI, error) { u, err := uri.ParseWithScheme(Scheme, rawuri) if err != nil { @@ -41,8 +47,8 @@ func parseURI(rawuri string) (*kmsURI, error) { } extraValues := make(url.Values) - for k, v := range u.Values { - if k != nameKey && k != hwKey && k != backendKey { + for k, v := range uri.Values(u) { + if !isDefaultKey(k) { extraValues[k] = v } } diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go index 469dbd77..bc966213 100644 --- a/kms/platform/kms_darwin.go +++ b/kms/platform/kms_darwin.go @@ -57,7 +57,9 @@ func transformToMacKMS(u *kmsURI) string { } if u.hw { uv.Set("se", "true") - uv.Set("keychain", "dataProtection") + if !u.uri.Has("keychain") { + uv.Set("keychain", "dataProtection") + } } else if u.uri.Has("hw") { uv.Set("se", "false") } @@ -74,14 +76,15 @@ func transformFromMacKMS(rawuri string) (string, error) { return "", err } - uv := url.Values{ - "name": []string{u.Get("label")}, + uv := url.Values{} + if u.Has("label") { + uv.Set("name", u.Get("label")) } if u.GetBool("se") { uv.Set("hw", "true") } - for k, v := range u.Values { + for k, v := range uri.Values(u) { if k != "label" && k != "se" { uv[k] = v } diff --git a/kms/platform/kms_darwin_test.go b/kms/platform/kms_darwin_test.go index e9cb7681..111f342a 100644 --- a/kms/platform/kms_darwin_test.go +++ b/kms/platform/kms_darwin_test.go @@ -1,6 +1,10 @@ package platform -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func mustPlatformKMS(t *testing.T) *KMS { t.Helper() @@ -13,3 +17,55 @@ func mustPlatformKMS(t *testing.T) *KMS { func (k *KMS) SkipTests() bool { return false } + +func Test_transformToMacKMS(t *testing.T) { + type args struct { + u *kmsURI + } + tests := []struct { + name string + rawuri string + want string + }{ + {"scheme", "kms:", "mackms:"}, + {"with name", "kms:name=foo", "mackms:label=foo"}, + {"with hw", "kms:name=foo;hw=true", "mackms:keychain=dataProtection;label=foo;se=true"}, + {"with hw on query", "kms:name=foo?hw=true", "mackms:keychain=dataProtection;label=foo;se=true"}, + {"with hw and keychain", "kms:name=foo;hw=true;keychain=my", "mackms:keychain=my;label=foo;se=true"}, + {"with extrasValues", "kms:name=foo;keychain=my?foo=bar&baz=qux", "mackms:baz=qux;foo=bar;keychain=my;label=foo"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := mustParseURI(t, tt.rawuri) + assert.Equal(t, tt.want, transformToMacKMS(u)) + }) + } +} + +func Test_transformFromMacKMS(t *testing.T) { + type args struct { + rawuri string + } + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "mackms:", "kms:", assert.NoError}, + {"with label", "mackms:label=foo", "kms:name=foo", assert.NoError}, + {"with se", "mackms:label=foo;se=true", "kms:hw=true;name=foo", assert.NoError}, + {"with se on query", "mackms:label=foo?se=true", "kms:hw=true;name=foo", assert.NoError}, + {"with keychain", "mackms:label=foo;se=true;keychain=dataProtection", "kms:hw=true;keychain=dataProtection;name=foo", assert.NoError}, + {"with keychain on query", "mackms:label=foo?keychain=dataProtection&foo=bar", "kms:foo=bar;keychain=dataProtection;name=foo", assert.NoError}, + {"fail empty", "", "", assert.Error}, + {"fail scheme", "kms:", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformFromMacKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_softkms.go b/kms/platform/kms_softkms.go index 4af3639c..f6178c2c 100644 --- a/kms/platform/kms_softkms.go +++ b/kms/platform/kms_softkms.go @@ -33,8 +33,7 @@ type softKMS struct { } func (k *softKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { - name := filename(req.Name) - if name == "" { + if req.Name == "" { return nil, fmt.Errorf("createKeyRequest 'name' cannot be empty") } @@ -43,7 +42,7 @@ func (k *softKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon return nil, err } - if _, err := pemutil.Serialize(resp.PrivateKey, pemutil.ToFile(name, 0o600)); err != nil { + if _, err := pemutil.Serialize(resp.PrivateKey, pemutil.ToFile(req.Name, 0o600)); err != nil { return nil, err } @@ -51,33 +50,30 @@ func (k *softKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon } func (k *softKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { - name := filename(req.Name) - if name == "" { + if req.Name == "" { return fmt.Errorf("deleteKeyRequest 'name' cannot be empty") } - return os.Remove(name) + return os.Remove(req.Name) } func (k *softKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { - name := filename(req.Name) switch { - case name == "": + case req.Name == "": return fmt.Errorf("storeCertificateRequest 'name' cannot be empty") case req.Certificate == nil: return fmt.Errorf("storeCertificateRequest 'certificate' cannot be empty") } - return os.WriteFile(name, pem.EncodeToMemory(&pem.Block{ + return os.WriteFile(req.Name, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: req.Certificate.Raw, }), 0o600) } func (k *softKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { - name := filename(req.Name) switch { - case name == "": + case req.Name == "": return fmt.Errorf("storeCertificateChainRequest 'name' cannot be empty") case len(req.CertificateChain) == 0: return fmt.Errorf("storeCertificateChainRequest 'certificateChain' cannot be empty") @@ -93,47 +89,36 @@ func (k *softKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) } } - return os.WriteFile(name, buf.Bytes(), 0o600) + return os.WriteFile(req.Name, buf.Bytes(), 0o600) } func (k *softKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { - name := filename(req.Name) - if name == "" { + if req.Name == "" { return fmt.Errorf("deleteCertificateRequest 'name' cannot be empty") } - return os.Remove(name) -} - -func filename(s string) string { - if u, err := uri.ParseWithScheme(softkms.Scheme, s); err == nil { - switch { - case u.Path != "": - return u.Path - default: - return u.Opaque - } - } - return s + return os.Remove(req.Name) } func transformToSoftKMS(u *kmsURI) string { switch { case u.uri.Has("name"): - return uri.NewOpaque(softkms.Scheme, u.name).String() + return u.name case u.uri.Has("path"): - return uri.NewOpaque(softkms.Scheme, u.uri.Get("path")).String() + return u.uri.Get("path") case u.uri.Path != "": - return uri.NewOpaque(softkms.Scheme, u.uri.Path).String() + return u.uri.Path case u.uri.Opaque != "": - return uri.NewOpaque(softkms.Scheme, u.uri.Opaque).String() + return u.uri.Opaque default: - return uri.NewOpaque(softkms.Scheme, "").String() + return "" } } -func transformFromSoftKMS(rawuri string) (string, error) { - return uri.New(Scheme, url.Values{ - "name": []string{rawuri}, - }).String(), nil +func transformFromSoftKMS(path string) (string, error) { + uv := url.Values{} + if path != "" { + uv.Set(nameKey, path) + } + return uri.New(Scheme, uv).String(), nil } diff --git a/kms/platform/kms_softkms_test.go b/kms/platform/kms_softkms_test.go new file mode 100644 index 00000000..e53facb8 --- /dev/null +++ b/kms/platform/kms_softkms_test.go @@ -0,0 +1,49 @@ +package platform + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_transformToSoftKMS(t *testing.T) { + type args struct { + u *kmsURI + } + tests := []struct { + name string + rawuri string + want string + }{ + {"scheme", "kms:", ""}, + {"with name", "kms:name=path/to/file.crt", "path/to/file.crt"}, + {"with encoded", "kms:name=%2Fpath%2Fto%2Ffile.key", "/path/to/file.key"}, + {"with path", "kms:path=/path/to/file.key", "/path/to/file.key"}, + {"with opaque", "kms:path/to/file.key", "path/to/file.key"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := mustParseURI(t, tt.rawuri) + assert.Equal(t, tt.want, transformToSoftKMS(u)) + }) + } +} + +func Test_transformFromSoftKMS(t *testing.T) { + tests := []struct { + name string + path string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "", "kms:", assert.NoError}, + {"with path", "/path/to/file", "kms:name=%2Fpath%2Fto%2Ffile", assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformFromSoftKMS(tt.path) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index 5bed1477..f10d1b92 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -64,6 +64,12 @@ func shouldSkipNow(t *testing.T, km *KMS) { } } +func mustParseURI(t *testing.T, rawuri string) *kmsURI { + u, err := parseURI(rawuri) + require.NoError(t, err) + return u +} + func mustKMS(t *testing.T, rawuri string) *KMS { t.Helper() @@ -1084,7 +1090,7 @@ func TestKMS_CreateAttestation(t *testing.T) { dir := t.TempDir() privateKeyPath := filepath.Join(dir, "private.key") signer := mustSigner(t, privateKeyPath) - attester := mustSigner(t, "attester.key") + attester := mustSigner(t, filepath.Join(dir, "attester.key")) permanentIdentifier := mustPermanentIdentifier(t, attester.Public()) ca, err := minica.New() diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go index abc6c7a7..520fb36f 100644 --- a/kms/platform/kms_tpm.go +++ b/kms/platform/kms_tpm.go @@ -56,13 +56,9 @@ func transformToTPMKMS(u *kmsURI) string { if u.name != "" { uv.Set("name", u.name) } - if u.hw { - uv.Set("ak", "true") - } else if u.uri.Has("hw") { - uv.Set("ak", "false") - } // Add custom extra values that might be tpmkms specific. + // There is not need to set "hw". maps.Copy(uv, u.extraValues) return uri.New(tpmkms.Scheme, uv).String() @@ -74,15 +70,13 @@ func transformFromTPMKMS(rawuri string) (string, error) { return "", err } - uv := url.Values{ - "name": []string{u.Get("name")}, - } - if u.GetBool("ak") { - uv.Set("hw", "true") + uv := url.Values{} + if u.Has("name") { + uv.Set(nameKey, u.Get("name")) } - for k, v := range u.Values { - if k != "name" && k != "ak" { + for k, v := range uri.Values(u) { + if k != nameKey { uv[k] = v } } diff --git a/kms/platform/kms_tpm_test.go b/kms/platform/kms_tpm_test.go new file mode 100644 index 00000000..58e6e287 --- /dev/null +++ b/kms/platform/kms_tpm_test.go @@ -0,0 +1,59 @@ +package platform + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_transformToTPMKMS(t *testing.T) { + type args struct { + u *kmsURI + } + tests := []struct { + name string + rawuri string + want string + }{ + {"scheme", "kms:", "tpmkms:"}, + {"with name", "kms:name=foo", "tpmkms:name=foo"}, + {"with ak", "kms:name=foo;ak=true", "tpmkms:ak=true;name=foo"}, + {"with ak in query", "kms:name=foo?ak=true", "tpmkms:ak=true;name=foo"}, + {"with ak false", "kms:ak=false", "tpmkms:ak=false"}, + {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", "tpmkms:baz=qux;foo=bar;name=foo"}, + {"without hw", "kms:name=foo;hw=true", "tpmkms:name=foo"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := mustParseURI(t, tt.rawuri) + assert.Equal(t, tt.want, transformToTPMKMS(u)) + }) + } +} + +func Test_transformFromTPMKMS(t *testing.T) { + type args struct { + rawuri string + } + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "tpmkms:", "kms:", assert.NoError}, + {"with label", "tpmkms:name=foo", "kms:name=foo", assert.NoError}, + {"with ak", "tpmkms:name=foo;ak=true", "kms:ak=true;name=foo", assert.NoError}, + {"with ak on query", "tpmkms:name=foo?ak=true", "kms:ak=true;name=foo", assert.NoError}, + {"with ak false", "tpmkms:ak=false;name=foo", "kms:ak=false;name=foo", assert.NoError}, + {"fail empty", "", "", assert.Error}, + {"fail scheme", "kms:", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformFromTPMKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_tpmsimulator_test.go b/kms/platform/kms_tpmsimulator_test.go index ead78c7e..878936a8 100644 --- a/kms/platform/kms_tpmsimulator_test.go +++ b/kms/platform/kms_tpmsimulator_test.go @@ -152,19 +152,19 @@ func TestKMS_GetPublicKey_tpm(t *testing.T) { Name: "kms:name=key-1", }}, keySigner.Public(), assert.NoError}, {"ok ak", kms1, args{&apiv1.GetPublicKeyRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, ak.Public(), assert.NoError}, {"ok key with tpm", kms2, args{&apiv1.GetPublicKeyRequest{ Name: "kms:name=key-1", }}, keySigner.Public(), assert.NoError}, {"ok ak with tpm", kms2, args{&apiv1.GetPublicKeyRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, ak.Public(), assert.NoError}, {"fail missing key", kms1, args{&apiv1.GetPublicKeyRequest{ Name: "kms:name=key-2", }}, nil, assert.Error}, {"fail missing ak", kms2, args{&apiv1.GetPublicKeyRequest{ - Name: "kms:name=ak-2;hw=true", + Name: "kms:name=ak-2;ak=true", }}, nil, assert.Error}, } for _, tt := range tests { @@ -216,13 +216,13 @@ func TestKMS_CreateKey_tpm(t *testing.T) { }) }, assert.NoError}, {"ok ak", kms1, args{&apiv1.CreateKeyRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { key, err := stpm.GetAK(ctx, "ak-1") require.NoError(t, err) assert.Equal(t, got, &apiv1.CreateKeyResponse{ - Name: "kms:hw=true;name=ak-1", + Name: "kms:ak=true;name=ak-1", PublicKey: key.Public(), }) }, assert.NoError}, @@ -250,13 +250,13 @@ func TestKMS_CreateKey_tpm(t *testing.T) { }) }, assert.NoError}, {"ok ak with tpm", kms2, args{&apiv1.CreateKeyRequest{ - Name: "kms:name=ak-2;hw=true", + Name: "kms:name=ak-2;ak=true", }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { key, err := stpm.GetAK(ctx, "ak-2") require.NoError(t, err) assert.Equal(t, got, &apiv1.CreateKeyResponse{ - Name: "kms:hw=true;name=ak-2", + Name: "kms:ak=true;name=ak-2", PublicKey: key.Public(), }) }, assert.NoError}, @@ -266,7 +266,7 @@ func TestKMS_CreateKey_tpm(t *testing.T) { assert.Nil(t, got) }, assert.Error}, {"fail ak already exists", kms2, args{&apiv1.CreateKeyRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { assert.Nil(t, got) }, assert.Error}, @@ -326,7 +326,7 @@ func TestKMS_CreateSigner_tpm(t *testing.T) { assert.Nil(t, got) }, assert.Error}, {"fail with ak", km, args{&apiv1.CreateSignerRequest{ - SigningKey: "kms:name=ak-1;hw=true", + SigningKey: "kms:name=ak-1;ak=true", }}, func(t *testing.T, got crypto.Signer) { assert.Nil(t, got) }, assert.Error}, @@ -371,7 +371,7 @@ func TestKMS_DeleteKey_tpm(t *testing.T) { return assert.NoError(t, err) && assert.Error(t, keyErr) }}, {"ok ak", km, args{&apiv1.DeleteKeyRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, func(tt assert.TestingT, err error, i ...interface{}) bool { _, akErr := stpm.GetAK(ctx, "ak-1") return assert.NoError(t, err) && assert.Error(t, akErr) @@ -380,7 +380,7 @@ func TestKMS_DeleteKey_tpm(t *testing.T) { Name: "kms:name=key-2", }}, assert.Error}, {"fail missing ak", km, args{&apiv1.DeleteKeyRequest{ - Name: "kms:name=ak-2;hw=true", + Name: "kms:name=ak-2;ak=true", }}, assert.Error}, } for _, tt := range tests { @@ -434,13 +434,13 @@ func TestKMS_LoadCertificate_tpm(t *testing.T) { Name: "kms:name=key-1", }}, keyChain[0], assert.NoError}, {"ok ak", km, args{&apiv1.LoadCertificateRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, akChain[0], assert.NoError}, {"fail no certificate", km, args{&apiv1.LoadCertificateRequest{ Name: "kms:name=key-2", }}, nil, assert.Error}, {"fail no ak certificate", km, args{&apiv1.LoadCertificateRequest{ - Name: "kms:name=ak-2;hw=true", + Name: "kms:name=ak-2;ak=true", }}, nil, assert.Error}, {"fail missing", km, args{&apiv1.LoadCertificateRequest{ Name: "kms:name=missing-key", @@ -501,7 +501,7 @@ func TestKMS_StoreCertificate_tpm(t *testing.T) { return assert.Equal(t, keyChain2[0], k.Certificate()) }}, {"ok ak", km, args{&apiv1.StoreCertificateRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", Certificate: akChain1[0], }}, func(tt assert.TestingT, err error, i ...interface{}) bool { k, err := stpm.GetAK(ctx, "ak-1") @@ -509,7 +509,7 @@ func TestKMS_StoreCertificate_tpm(t *testing.T) { return assert.Equal(t, akChain1[0], k.Certificate()) }}, {"ok ak overwrite", km, args{&apiv1.StoreCertificateRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", Certificate: akChain2[0], }}, func(tt assert.TestingT, err error, i ...interface{}) bool { k, err := stpm.GetAK(ctx, "ak-1") @@ -525,7 +525,7 @@ func TestKMS_StoreCertificate_tpm(t *testing.T) { Certificate: akChain1[0], }}, assert.Error}, {"fail ak key not match", km, args{&apiv1.StoreCertificateRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", Certificate: keyChain1[0], }}, assert.Error}, } @@ -580,13 +580,13 @@ func TestKMS_LoadCertificateChain_tpm(t *testing.T) { Name: "kms:name=key-1", }}, keyChain, assert.NoError}, {"ok ak", km, args{&apiv1.LoadCertificateChainRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, akChain, assert.NoError}, {"fail no chain", km, args{&apiv1.LoadCertificateChainRequest{ Name: "kms:name=key-2", }}, nil, assert.Error}, {"fail no ak chain", km, args{&apiv1.LoadCertificateChainRequest{ - Name: "kms:name=ak-2;hw=true", + Name: "kms:name=ak-2;ak=true", }}, nil, assert.Error}, {"fail missing", km, args{&apiv1.LoadCertificateChainRequest{ Name: "kms:name=missing-key", @@ -649,7 +649,7 @@ func TestKMS_StoreCertificateChain_tpm(t *testing.T) { assert.Equal(t, keyChain2, k.CertificateChain()) }}, {"ok ak", km, args{&apiv1.StoreCertificateChainRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", CertificateChain: akChain1, }}, func(tt assert.TestingT, err error, i ...interface{}) bool { k, err := stpm.GetAK(ctx, "ak-1") @@ -658,7 +658,7 @@ func TestKMS_StoreCertificateChain_tpm(t *testing.T) { assert.Equal(t, akChain1, k.CertificateChain()) }}, {"ok ak overwrite", km, args{&apiv1.StoreCertificateChainRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", CertificateChain: akChain2, }}, func(tt assert.TestingT, err error, i ...interface{}) bool { k, err := stpm.GetAK(ctx, "ak-1") @@ -675,7 +675,7 @@ func TestKMS_StoreCertificateChain_tpm(t *testing.T) { CertificateChain: akChain1, }}, assert.Error}, {"fail ak key not match", km, args{&apiv1.StoreCertificateChainRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", CertificateChain: keyChain1, }}, assert.Error}, } @@ -724,7 +724,7 @@ func TestKMS_DeleteCertificate_tpm(t *testing.T) { return assert.Nil(t, k.Certificate()) && assert.Nil(t, k.CertificateChain()) }}, {"ok ak", km, args{&apiv1.DeleteCertificateRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, func(tt assert.TestingT, err error, i ...interface{}) bool { k, err := stpm.GetAK(ctx, "ak-1") require.NoError(t, err) @@ -734,10 +734,10 @@ func TestKMS_DeleteCertificate_tpm(t *testing.T) { Name: "kms:name=key-1", }}, assert.NoError}, {"ok delete again ak", km, args{&apiv1.DeleteCertificateRequest{ - Name: "kms:name=ak-1;hw=true", + Name: "kms:name=ak-1;ak=true", }}, assert.NoError}, {"fail missing", km, args{&apiv1.DeleteCertificateRequest{ - Name: "kms:name=missing-ak;hw=true", + Name: "kms:name=missing-ak;ak=true", }}, assert.Error}, } for _, tt := range tests { @@ -873,14 +873,14 @@ func TestKMS_SearchKeys_tpm(t *testing.T) { Query: "kms:", }}, &apiv1.SearchKeysResponse{ Results: []apiv1.SearchKeyResult{ - {Name: "kms:hw=true;name=ak-1", PublicKey: ak1.Public()}, - {Name: "kms:hw=true;name=ak-2", PublicKey: ak2.Public()}, + {Name: "kms:ak=true;name=ak-1", PublicKey: ak1.Public()}, + {Name: "kms:ak=true;name=ak-2", PublicKey: ak2.Public()}, {Name: "kms:name=key-1", PublicKey: key1.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-1"}}, {Name: "kms:name=key-2", PublicKey: key2.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-2"}}, }, }, assert.NoError}, {"ok keys", km, args{&apiv1.SearchKeysRequest{ - Query: "kms:hw=false", + Query: "kms:ak=false", }}, &apiv1.SearchKeysResponse{ Results: []apiv1.SearchKeyResult{ {Name: "kms:name=key-1", PublicKey: key1.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-1"}}, @@ -888,11 +888,11 @@ func TestKMS_SearchKeys_tpm(t *testing.T) { }, }, assert.NoError}, {"ok aks", km, args{&apiv1.SearchKeysRequest{ - Query: "kms:hw=true", + Query: "kms:ak=true", }}, &apiv1.SearchKeysResponse{ Results: []apiv1.SearchKeyResult{ - {Name: "kms:hw=true;name=ak-1", PublicKey: ak1.Public()}, - {Name: "kms:hw=true;name=ak-2", PublicKey: ak2.Public()}, + {Name: "kms:ak=true;name=ak-1", PublicKey: ak1.Public()}, + {Name: "kms:ak=true;name=ak-2", PublicKey: ak2.Public()}, }, }, assert.NoError}, } diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index 9d2dbf44..5958dc4c 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -13,6 +13,8 @@ import ( "go.step.sm/crypto/kms/uri" ) +const tpmProvider = "Microsoft Platform Crypto Provider" + func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { if opts.URI == "" { return newTPMKMS(ctx, opts) @@ -58,7 +60,14 @@ func transformToCAPIKMS(u *kmsURI) string { // When storing certificate skip key validation. // This avoid a prompt looking for an SmartCard. - uv.Set("skip-find-certificate-key", "true") + if !u.extraValues.Has("skip-find-certificate-key") { + uv.Set("skip-find-certificate-key", "true") + } + + // Set provider "Microsoft Platform Crypto Provider" to use the TPM. + if u.hw && !u.extraValues.Has("provider") { + uv.Set("provider", tpmProvider) + } // Add custom extra values that might be CAPI specific. maps.Copy(uv, u.extraValues) @@ -72,11 +81,15 @@ func transformFromCAPIKMS(rawuri string) (string, error) { return "", err } - uv := url.Values{ - "name": []string{u.Get("key")}, + uv := url.Values{} + if u.Has("key") { + uv.Set("name", u.Get("key")) + } + if u.Get("provider") == tpmProvider { + uv.Set("hw", "true") } - for k, v := range u.Values { + for k, v := range uri.Values(u) { if k != "key" { uv[k] = v } diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go index 340d6851..63baae3f 100644 --- a/kms/platform/kms_windows_test.go +++ b/kms/platform/kms_windows_test.go @@ -499,3 +499,55 @@ func TestKMS_SearchKeys_capi(t *testing.T) { }) } } + +func Test_transformToCAPIKMS(t *testing.T) { + type args struct { + u *kmsURI + } + tests := []struct { + name string + rawuri string + want string + }{ + {"scheme", "kms:", "capi:skip-find-certificate-key=true"}, + {"with name", "kms:name=foo", "capi:key=foo;skip-find-certificate-key=true"}, + {"with hw", "kms:name=foo;hw=true", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider;skip-find-certificate-key=true"}, + {"with hw on query", "kms:name=foo?hw=true", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider;skip-find-certificate-key=true"}, + {"with skip-find-certificate-key", "kms:name=foo;skip-find-certificate-key=false", "capi:key=foo;skip-find-certificate-key=false"}, + {"with provider", "kms:name=foo;hw=true;provider=my", "capi:key=foo;provider=my;skip-find-certificate-key=true"}, + {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", "capi:baz=qux;foo=bar;key=foo;skip-find-certificate-key=true"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := mustParseURI(t, tt.rawuri) + assert.Equal(t, tt.want, transformToCAPIKMS(u)) + }) + } +} + +func Test_transformFromCAPIKMS(t *testing.T) { + type args struct { + rawuri string + } + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "capi:", "kms:", assert.NoError}, + {"with key", "capi:key=foo", "kms:name=foo", assert.NoError}, + {"with provider", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider", "kms:hw=true;name=foo;provider=Microsoft+Platform+Crypto+Provider", assert.NoError}, + {"with provider on query", "capi:key=foo?provider=my", "kms:name=foo;provider=my", assert.NoError}, + {"with others", "capi:key=foo;serial=1234;issuer=My+CA", "kms:issuer=My+CA;name=foo;serial=1234", assert.NoError}, + {"fail empty", "", "", assert.Error}, + {"fail scheme", "kms:", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformFromCAPIKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} From ef3550ddfd5ee9c93cffada78137067b85681339 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 27 Feb 2026 16:11:54 -0800 Subject: [PATCH 22/27] fix linter errors --- kms/platform/kms_darwin_test.go | 6 ------ kms/platform/kms_softkms_test.go | 3 --- kms/platform/kms_tpm_test.go | 6 ------ kms/platform/kms_windows_test.go | 6 ------ 4 files changed, 21 deletions(-) diff --git a/kms/platform/kms_darwin_test.go b/kms/platform/kms_darwin_test.go index 111f342a..fff6672c 100644 --- a/kms/platform/kms_darwin_test.go +++ b/kms/platform/kms_darwin_test.go @@ -19,9 +19,6 @@ func (k *KMS) SkipTests() bool { } func Test_transformToMacKMS(t *testing.T) { - type args struct { - u *kmsURI - } tests := []struct { name string rawuri string @@ -43,9 +40,6 @@ func Test_transformToMacKMS(t *testing.T) { } func Test_transformFromMacKMS(t *testing.T) { - type args struct { - rawuri string - } tests := []struct { name string rawuri string diff --git a/kms/platform/kms_softkms_test.go b/kms/platform/kms_softkms_test.go index e53facb8..e491e35b 100644 --- a/kms/platform/kms_softkms_test.go +++ b/kms/platform/kms_softkms_test.go @@ -7,9 +7,6 @@ import ( ) func Test_transformToSoftKMS(t *testing.T) { - type args struct { - u *kmsURI - } tests := []struct { name string rawuri string diff --git a/kms/platform/kms_tpm_test.go b/kms/platform/kms_tpm_test.go index 58e6e287..31a2304e 100644 --- a/kms/platform/kms_tpm_test.go +++ b/kms/platform/kms_tpm_test.go @@ -7,9 +7,6 @@ import ( ) func Test_transformToTPMKMS(t *testing.T) { - type args struct { - u *kmsURI - } tests := []struct { name string rawuri string @@ -32,9 +29,6 @@ func Test_transformToTPMKMS(t *testing.T) { } func Test_transformFromTPMKMS(t *testing.T) { - type args struct { - rawuri string - } tests := []struct { name string rawuri string diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go index 63baae3f..1b508fdf 100644 --- a/kms/platform/kms_windows_test.go +++ b/kms/platform/kms_windows_test.go @@ -501,9 +501,6 @@ func TestKMS_SearchKeys_capi(t *testing.T) { } func Test_transformToCAPIKMS(t *testing.T) { - type args struct { - u *kmsURI - } tests := []struct { name string rawuri string @@ -526,9 +523,6 @@ func Test_transformToCAPIKMS(t *testing.T) { } func Test_transformFromCAPIKMS(t *testing.T) { - type args struct { - rawuri string - } tests := []struct { name string rawuri string From 83c660feeac82c04b1986a6ca39cb35b498bf619 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 Mar 2026 12:12:29 -0800 Subject: [PATCH 23/27] Fail hw with softkms --- kms/platform/kms.go | 34 ++++++++++++-------------------- kms/platform/kms_darwin.go | 18 ++++++++++++++--- kms/platform/kms_darwin_test.go | 25 ++++++++++++----------- kms/platform/kms_softkms.go | 21 ++++++++++++++------ kms/platform/kms_softkms_test.go | 24 ++++++++++++---------- kms/platform/kms_test.go | 11 +++++++++++ kms/platform/kms_tpm.go | 23 +++++++++++---------- kms/platform/kms_tpm_test.go | 27 ++++++++++++++----------- kms/platform/kms_windows.go | 18 ++++++++++++++--- kms/platform/kms_windows_test.go | 27 ++++++++++++++----------- 10 files changed, 140 insertions(+), 88 deletions(-) diff --git a/kms/platform/kms.go b/kms/platform/kms.go index c2cc3990..33698b32 100644 --- a/kms/platform/kms.go +++ b/kms/platform/kms.go @@ -77,7 +77,7 @@ var _ apiv1.CertificateChainManager = (*KMS)(nil) type KMS struct { typ apiv1.Type backend extendedKeyManager - transformToURI func(*kmsURI) string + transformToURI func(string) (string, error) transformFromURI func(string) (string, error) } @@ -94,17 +94,18 @@ func (k *KMS) Close() error { } func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return nil, err } + return k.backend.GetPublicKey(&apiv1.GetPublicKeyRequest{ Name: name, }) } func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return nil, err } @@ -124,7 +125,7 @@ func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error return req.Signer, nil } - signingKey, err := k.transform(req.SigningKey) + signingKey, err := k.transformToURI(req.SigningKey) if err != nil { return nil, err } @@ -135,7 +136,7 @@ func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error } func (k *KMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return err } @@ -146,7 +147,7 @@ func (k *KMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { } func (k *KMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return nil, err } @@ -157,7 +158,7 @@ func (k *KMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certific } func (k *KMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return err } @@ -168,7 +169,7 @@ func (k *KMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { } func (k *KMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return nil, err } @@ -179,7 +180,7 @@ func (k *KMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x } func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return err } @@ -190,7 +191,7 @@ func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) err } func (k *KMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return err } @@ -205,7 +206,7 @@ func (k *KMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.Cre return nil, errors.New("createAttestationRequest 'name' cannot be empty") } - name, err := k.transform(req.Name) + name, err := k.transformToURI(req.Name) if err != nil { return nil, err } @@ -251,7 +252,7 @@ func (k *KMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.Cre func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { if km, ok := k.backend.(apiv1.SearchableKeyManager); ok { - query, err := k.transform(req.Query) + query, err := k.transformToURI(req.Query) if err != nil { return nil, err } @@ -269,15 +270,6 @@ func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysRespons return nil, apiv1.NotImplementedError{} } -func (k *KMS) transform(rawuri string) (string, error) { - u, err := parseURI(rawuri) - if err != nil { - return "", err - } - - return k.transformToURI(u), nil -} - func (k *KMS) patchCreateKeyResponse(resp *apiv1.CreateKeyResponse) (*apiv1.CreateKeyResponse, error) { name, err := k.transformFromURI(resp.Name) if err != nil { diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go index bc966213..e43864cb 100644 --- a/kms/platform/kms_darwin.go +++ b/kms/platform/kms_darwin.go @@ -29,7 +29,6 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { case apiv1.SoftKMS: return newSoftKMS(ctx, opts) case apiv1.DefaultKMS, apiv1.MacKMS: - opts.URI = transformToMacKMS(u) return newMacKMS(ctx, opts) default: return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend) @@ -37,6 +36,14 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { } func newMacKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI != "" { + u, err := transformToMacKMS(opts.URI) + if err != nil { + return nil, fmt.Errorf("error parsing uri: %w", err) + } + opts.URI = u + } + km, err := mackms.New(ctx, opts) if err != nil { return nil, err @@ -50,7 +57,12 @@ func newMacKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { }, nil } -func transformToMacKMS(u *kmsURI) string { +func transformToMacKMS(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + uv := url.Values{} if u.name != "" { uv.Set("label", u.name) @@ -67,7 +79,7 @@ func transformToMacKMS(u *kmsURI) string { // Add custom extra values that might be mackms specific. maps.Copy(uv, u.extraValues) - return uri.New(mackms.Scheme, uv).String() + return uri.New(mackms.Scheme, uv).String(), nil } func transformFromMacKMS(rawuri string) (string, error) { diff --git a/kms/platform/kms_darwin_test.go b/kms/platform/kms_darwin_test.go index fff6672c..2f491470 100644 --- a/kms/platform/kms_darwin_test.go +++ b/kms/platform/kms_darwin_test.go @@ -20,21 +20,24 @@ func (k *KMS) SkipTests() bool { func Test_transformToMacKMS(t *testing.T) { tests := []struct { - name string - rawuri string - want string + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc }{ - {"scheme", "kms:", "mackms:"}, - {"with name", "kms:name=foo", "mackms:label=foo"}, - {"with hw", "kms:name=foo;hw=true", "mackms:keychain=dataProtection;label=foo;se=true"}, - {"with hw on query", "kms:name=foo?hw=true", "mackms:keychain=dataProtection;label=foo;se=true"}, - {"with hw and keychain", "kms:name=foo;hw=true;keychain=my", "mackms:keychain=my;label=foo;se=true"}, - {"with extrasValues", "kms:name=foo;keychain=my?foo=bar&baz=qux", "mackms:baz=qux;foo=bar;keychain=my;label=foo"}, + {"scheme", "kms:", "mackms:", assert.NoError}, + {"with name", "kms:name=foo", "mackms:label=foo", assert.NoError}, + {"with hw", "kms:name=foo;hw=true", "mackms:keychain=dataProtection;label=foo;se=true", assert.NoError}, + {"with hw on query", "kms:name=foo?hw=true", "mackms:keychain=dataProtection;label=foo;se=true", assert.NoError}, + {"with hw and keychain", "kms:name=foo;hw=true;keychain=my", "mackms:keychain=my;label=foo;se=true", assert.NoError}, + {"with extrasValues", "kms:name=foo;keychain=my?foo=bar&baz=qux", "mackms:baz=qux;foo=bar;keychain=my;label=foo", assert.NoError}, + {"fail parse", "softkms:name=foo", "", assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - u := mustParseURI(t, tt.rawuri) - assert.Equal(t, tt.want, transformToMacKMS(u)) + got, err := transformToMacKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) }) } } diff --git a/kms/platform/kms_softkms.go b/kms/platform/kms_softkms.go index f6178c2c..f3a76a7f 100644 --- a/kms/platform/kms_softkms.go +++ b/kms/platform/kms_softkms.go @@ -100,18 +100,27 @@ func (k *softKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return os.Remove(req.Name) } -func transformToSoftKMS(u *kmsURI) string { +func transformToSoftKMS(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + + if u.hw { + return "", fmt.Errorf("error parsing uri: hw is not supported") + } + switch { case u.uri.Has("name"): - return u.name + return u.name, nil case u.uri.Has("path"): - return u.uri.Get("path") + return u.uri.Get("path"), nil case u.uri.Path != "": - return u.uri.Path + return u.uri.Path, nil case u.uri.Opaque != "": - return u.uri.Opaque + return u.uri.Opaque, nil default: - return "" + return "", nil } } diff --git a/kms/platform/kms_softkms_test.go b/kms/platform/kms_softkms_test.go index e491e35b..8d47537b 100644 --- a/kms/platform/kms_softkms_test.go +++ b/kms/platform/kms_softkms_test.go @@ -8,20 +8,24 @@ import ( func Test_transformToSoftKMS(t *testing.T) { tests := []struct { - name string - rawuri string - want string + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc }{ - {"scheme", "kms:", ""}, - {"with name", "kms:name=path/to/file.crt", "path/to/file.crt"}, - {"with encoded", "kms:name=%2Fpath%2Fto%2Ffile.key", "/path/to/file.key"}, - {"with path", "kms:path=/path/to/file.key", "/path/to/file.key"}, - {"with opaque", "kms:path/to/file.key", "path/to/file.key"}, + {"scheme", "kms:", "", assert.NoError}, + {"with name", "kms:name=path/to/file.crt", "path/to/file.crt", assert.NoError}, + {"with encoded", "kms:name=%2Fpath%2Fto%2Ffile.key", "/path/to/file.key", assert.NoError}, + {"with path", "kms:path=/path/to/file.key", "/path/to/file.key", assert.NoError}, + {"with opaque", "kms:path/to/file.key", "path/to/file.key", assert.NoError}, + {"fail parse", "mackms:", "", assert.Error}, + {"fail hw", "kms:name=file.key;hw=true", "", assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - u := mustParseURI(t, tt.rawuri) - assert.Equal(t, tt.want, transformToSoftKMS(u)) + got, err := transformToSoftKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) }) } } diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index f10d1b92..9f09b01a 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -1127,6 +1127,10 @@ func TestKMS_CreateAttestation(t *testing.T) { PublicKey: signer.Public(), PermanentIdentifier: permanentIdentifier.String(), }, assert.NoError}, + {"fail missing key", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "kms:" + platformMissingName, + AttestationClient: okClient, + }}, nil, assert.Error}, {"fail custom attestation", softKMS, args{&apiv1.CreateAttestationRequest{ Name: "kms:" + privateKeyPath, AttestationClient: failClient, @@ -1134,6 +1138,10 @@ func TestKMS_CreateAttestation(t *testing.T) { {"fail no client", softKMS, args{&apiv1.CreateAttestationRequest{ Name: "kms:" + privateKeyPath, }}, nil, assert.Error}, + {"fail no name", softKMS, args{&apiv1.CreateAttestationRequest{}}, nil, assert.Error}, + {"fail parse", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "tpmkms:", + }}, nil, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1198,6 +1206,9 @@ func TestKMS_SearchKeys(t *testing.T) { makeResult(platformKeys[2]), }, }, assert.NoError}, + {"fail parse", platformKMS, args{&apiv1.SearchKeysRequest{ + Query: "name=", + }}, nil, assert.Error}, // SoftKMS {"fail softKMS", softKMS, args{&apiv1.SearchKeysRequest{ diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go index 520fb36f..ff7b3ff7 100644 --- a/kms/platform/kms_tpm.go +++ b/kms/platform/kms_tpm.go @@ -14,16 +14,14 @@ import ( var _ apiv1.Attester = (*KMS)(nil) func newTPMKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { - if opts.URI == "" { - return newTPMKMS(ctx, opts) - } - - u, err := parseURI(opts.URI) - if err != nil { - return nil, err + if opts.URI != "" { + u, err := transformToTPMKMS(opts.URI) + if err != nil { + return nil, err + } + opts.URI = u } - opts.URI = transformToTPMKMS(u) km, err := tpmkms.New(ctx, opts) if err != nil { return nil, err @@ -51,7 +49,12 @@ func NewWithTPM(ctx context.Context, t *tpm.TPM, opts ...tpmkms.Option) (*KMS, e }, nil } -func transformToTPMKMS(u *kmsURI) string { +func transformToTPMKMS(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + uv := url.Values{} if u.name != "" { uv.Set("name", u.name) @@ -61,7 +64,7 @@ func transformToTPMKMS(u *kmsURI) string { // There is not need to set "hw". maps.Copy(uv, u.extraValues) - return uri.New(tpmkms.Scheme, uv).String() + return uri.New(tpmkms.Scheme, uv).String(), nil } func transformFromTPMKMS(rawuri string) (string, error) { diff --git a/kms/platform/kms_tpm_test.go b/kms/platform/kms_tpm_test.go index 31a2304e..b631823e 100644 --- a/kms/platform/kms_tpm_test.go +++ b/kms/platform/kms_tpm_test.go @@ -8,22 +8,25 @@ import ( func Test_transformToTPMKMS(t *testing.T) { tests := []struct { - name string - rawuri string - want string + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc }{ - {"scheme", "kms:", "tpmkms:"}, - {"with name", "kms:name=foo", "tpmkms:name=foo"}, - {"with ak", "kms:name=foo;ak=true", "tpmkms:ak=true;name=foo"}, - {"with ak in query", "kms:name=foo?ak=true", "tpmkms:ak=true;name=foo"}, - {"with ak false", "kms:ak=false", "tpmkms:ak=false"}, - {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", "tpmkms:baz=qux;foo=bar;name=foo"}, - {"without hw", "kms:name=foo;hw=true", "tpmkms:name=foo"}, + {"scheme", "kms:", "tpmkms:", assert.NoError}, + {"with name", "kms:name=foo", "tpmkms:name=foo", assert.NoError}, + {"with ak", "kms:name=foo;ak=true", "tpmkms:ak=true;name=foo", assert.NoError}, + {"with ak in query", "kms:name=foo?ak=true", "tpmkms:ak=true;name=foo", assert.NoError}, + {"with ak false", "kms:ak=false", "tpmkms:ak=false", assert.NoError}, + {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", "tpmkms:baz=qux;foo=bar;name=foo", assert.NoError}, + {"without hw", "kms:name=foo;hw=true", "tpmkms:name=foo", assert.NoError}, + {"fail parse", "mackms:name=foo", "", assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - u := mustParseURI(t, tt.rawuri) - assert.Equal(t, tt.want, transformToTPMKMS(u)) + got, err := transformToTPMKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) }) } } diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index 5958dc4c..ecf942aa 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -27,7 +27,6 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { switch u.backend { case apiv1.CAPIKMS: - opts.URI = transformToCAPIKMS(u) return newCAPIKMS(ctx, opts) case apiv1.SoftKMS: return newSoftKMS(ctx, opts) @@ -39,6 +38,14 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { } func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI != "" { + u, err := transformToCAPIKMS(opts.URI) + if err != nil { + return nil, err + } + opts.URI = u + } + km, err := capi.New(ctx, opts) if err != nil { return nil, err @@ -52,7 +59,12 @@ func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { }, nil } -func transformToCAPIKMS(u *kmsURI) string { +func transformToCAPIKMS(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + uv := url.Values{} if u.name != "" { uv.Set("key", u.name) @@ -72,7 +84,7 @@ func transformToCAPIKMS(u *kmsURI) string { // Add custom extra values that might be CAPI specific. maps.Copy(uv, u.extraValues) - return uri.New(capi.Scheme, uv).String() + return uri.New(capi.Scheme, uv).String(), nil } func transformFromCAPIKMS(rawuri string) (string, error) { diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go index 1b508fdf..79b159ae 100644 --- a/kms/platform/kms_windows_test.go +++ b/kms/platform/kms_windows_test.go @@ -502,22 +502,25 @@ func TestKMS_SearchKeys_capi(t *testing.T) { func Test_transformToCAPIKMS(t *testing.T) { tests := []struct { - name string - rawuri string - want string + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc }{ - {"scheme", "kms:", "capi:skip-find-certificate-key=true"}, - {"with name", "kms:name=foo", "capi:key=foo;skip-find-certificate-key=true"}, - {"with hw", "kms:name=foo;hw=true", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider;skip-find-certificate-key=true"}, - {"with hw on query", "kms:name=foo?hw=true", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider;skip-find-certificate-key=true"}, - {"with skip-find-certificate-key", "kms:name=foo;skip-find-certificate-key=false", "capi:key=foo;skip-find-certificate-key=false"}, - {"with provider", "kms:name=foo;hw=true;provider=my", "capi:key=foo;provider=my;skip-find-certificate-key=true"}, - {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", "capi:baz=qux;foo=bar;key=foo;skip-find-certificate-key=true"}, + {"scheme", "kms:", "capi:skip-find-certificate-key=true", assert.NoError}, + {"with name", "kms:name=foo", "capi:key=foo;skip-find-certificate-key=true", assert.NoError}, + {"with hw", "kms:name=foo;hw=true", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider;skip-find-certificate-key=true", assert.NoError}, + {"with hw on query", "kms:name=foo?hw=true", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider;skip-find-certificate-key=true", assert.NoError}, + {"with skip-find-certificate-key", "kms:name=foo;skip-find-certificate-key=false", "capi:key=foo;skip-find-certificate-key=false", assert.NoError}, + {"with provider", "kms:name=foo;hw=true;provider=my", "capi:key=foo;provider=my;skip-find-certificate-key=true", assert.NoError}, + {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", "capi:baz=qux;foo=bar;key=foo;skip-find-certificate-key=true", assert.NoError}, + {"fail parse", "capikms:name=foo", "", assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - u := mustParseURI(t, tt.rawuri) - assert.Equal(t, tt.want, transformToCAPIKMS(u)) + got, err := transformToCAPIKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) }) } } From d330b6c984ce886fa969119d0a0afc7dce3f1f53 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 Mar 2026 12:19:11 -0800 Subject: [PATCH 24/27] Fix linter error --- kms/platform/kms_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index 9f09b01a..3c8064a8 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -64,12 +64,6 @@ func shouldSkipNow(t *testing.T, km *KMS) { } } -func mustParseURI(t *testing.T, rawuri string) *kmsURI { - u, err := parseURI(rawuri) - require.NoError(t, err) - return u -} - func mustKMS(t *testing.T, rawuri string) *KMS { t.Helper() From a44abd75aa6b216e2e693a6c1b358b548fc98ee5 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 Mar 2026 18:35:08 -0800 Subject: [PATCH 25/27] Add uri Set method --- kms/uri/uri.go | 5 +++++ kms/uri/uri_test.go | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 899adbd6..8dafca4b 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -220,6 +220,11 @@ func (u *URI) GetHexEncoded(key string) ([]byte, error) { return hex.DecodeString(hx) } +// Set sets the key to value. It replaces any existing values. +func (u *URI) Set(key, value string) { + u.Values.Set(key, value) +} + // Pin returns the pin encoded in the url. It will read the pin from the // pin-value or the pin-source attributes. func (u *URI) Pin() string { diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index 394c5914..9b63bdd6 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -435,6 +435,19 @@ func TestURI_GetHexEncoded(t *testing.T) { } } +func TestURI_Set(t *testing.T) { + u := mustParse(t, "kms:name=foo") + assert.Equal(t, "", u.Get("key")) + + u.Set("key", "bar") + assert.Equal(t, "bar", u.Get("key")) + assert.Equal(t, "kms:key=bar;name=foo", u.String()) + + u.Set("key", "zar") + assert.Equal(t, "zar", u.Get("key")) + assert.Equal(t, "kms:key=zar;name=foo", u.String()) +} + func TestURI_Read(t *testing.T) { // Read does not trim the contents of the file expected := []byte("trim-this-pin \n") From 9395e7eec268b21327f511333df7b0d8f283704e Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 Mar 2026 18:52:37 -0800 Subject: [PATCH 26/27] Enable CNG by default --- kms/platform/kms_test.go | 12 ++++++------ kms/platform/kms_tpm.go | 7 +++++++ kms/platform/kms_tpm_test.go | 28 +++++++++++++++++++++------- kms/platform/kms_windows.go | 12 ++++++++++++ kms/tpmkms/tpmkms.go | 20 +++++++++++++++++++- 5 files changed, 65 insertions(+), 14 deletions(-) diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go index 3c8064a8..98aa0304 100644 --- a/kms/platform/kms_test.go +++ b/kms/platform/kms_test.go @@ -811,7 +811,7 @@ func TestKMS_StoreCertificate(t *testing.T) { Certificate: chain[0], }}, func(tt assert.TestingT, err error, i ...interface{}) bool { // Storing a certificate with no key is not supported on TPMKMS. - if platformKMS.Type() == apiv1.TPMKMS { + if platformKMS.Type() == apiv1.TPMKMS && runtime.GOOS != "windows" { return assert.Error(t, err) } @@ -822,9 +822,9 @@ func TestKMS_StoreCertificate(t *testing.T) { }) return assert.NoError(t, err) }}, - {"fail platform bad certificate", platformKMS, args{&apiv1.StoreCertificateRequest{ + {"fail platform no certificate", platformKMS, args{&apiv1.StoreCertificateRequest{ Name: platformCertName, - Certificate: &x509.Certificate{}, + Certificate: nil, }}, assert.Error}, // SoftKMS @@ -958,7 +958,7 @@ func TestKMS_StoreCertificateChain(t *testing.T) { CertificateChain: chain, }}, func(tt assert.TestingT, err error, i ...interface{}) bool { // Storing a certificate with no key is not supported on TPMKMS. - if platformKMS.Type() == apiv1.TPMKMS { + if platformKMS.Type() == apiv1.TPMKMS && runtime.GOOS != "windows" { return assert.Error(t, err) } @@ -967,9 +967,10 @@ func TestKMS_StoreCertificateChain(t *testing.T) { Name: platformCertName, })) - if typ := platformKMS.Type(); typ == apiv1.MacKMS { + if typ := platformKMS.Type(); typ == apiv1.MacKMS || (typ == apiv1.TPMKMS && runtime.GOOS == "windows") { assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ Name: uri.New(Scheme, url.Values{ + "issuer": []string{platformChain[1].Issuer.CommonName}, // for windows only "serial": []string{hex.EncodeToString(platformChain[1].SerialNumber.Bytes())}, }).String(), })) @@ -1016,7 +1017,6 @@ func TestKMS_StoreCertificateChain(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { shouldSkipNow(t, tt.kms) - tt.assertion(t, tt.kms.StoreCertificateChain(tt.args.req)) }) } diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go index ff7b3ff7..658d3ba3 100644 --- a/kms/platform/kms_tpm.go +++ b/kms/platform/kms_tpm.go @@ -4,6 +4,7 @@ import ( "context" "maps" "net/url" + "runtime" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/tpmkms" @@ -60,6 +61,12 @@ func transformToTPMKMS(rawuri string) (string, error) { uv.Set("name", u.name) } + // When storing a certificate on windows, skip key validation. This avoid a + // prompt looking for an SmartCard. + if runtime.GOOS == "windows" && !u.extraValues.Has("skip-find-certificate-key") { + uv.Set("skip-find-certificate-key", "true") + } + // Add custom extra values that might be tpmkms specific. // There is not need to set "hw". maps.Copy(uv, u.extraValues) diff --git a/kms/platform/kms_tpm_test.go b/kms/platform/kms_tpm_test.go index b631823e..ae47d84e 100644 --- a/kms/platform/kms_tpm_test.go +++ b/kms/platform/kms_tpm_test.go @@ -1,25 +1,39 @@ package platform import ( + "runtime" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/uri" ) func Test_transformToTPMKMS(t *testing.T) { + fixWindowsURI := func(s string) string { + if runtime.GOOS != "windows" { + return s + } + + u, err := uri.Parse(s) + require.NoError(t, err) + u.Set("skip-find-certificate-key", "true") + return u.String() + } + tests := []struct { name string rawuri string want string assertion assert.ErrorAssertionFunc }{ - {"scheme", "kms:", "tpmkms:", assert.NoError}, - {"with name", "kms:name=foo", "tpmkms:name=foo", assert.NoError}, - {"with ak", "kms:name=foo;ak=true", "tpmkms:ak=true;name=foo", assert.NoError}, - {"with ak in query", "kms:name=foo?ak=true", "tpmkms:ak=true;name=foo", assert.NoError}, - {"with ak false", "kms:ak=false", "tpmkms:ak=false", assert.NoError}, - {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", "tpmkms:baz=qux;foo=bar;name=foo", assert.NoError}, - {"without hw", "kms:name=foo;hw=true", "tpmkms:name=foo", assert.NoError}, + {"scheme", "kms:", fixWindowsURI("tpmkms:"), assert.NoError}, + {"with name", "kms:name=foo", fixWindowsURI("tpmkms:name=foo"), assert.NoError}, + {"with ak", "kms:name=foo;ak=true", fixWindowsURI("tpmkms:ak=true;name=foo"), assert.NoError}, + {"with ak in query", "kms:name=foo?ak=true", fixWindowsURI("tpmkms:ak=true;name=foo"), assert.NoError}, + {"with ak false", "kms:ak=false", fixWindowsURI("tpmkms:ak=false"), assert.NoError}, + {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", fixWindowsURI("tpmkms:baz=qux;foo=bar;name=foo"), assert.NoError}, + {"without hw", "kms:name=foo;hw=true", fixWindowsURI("tpmkms:name=foo"), assert.NoError}, {"fail parse", "mackms:name=foo", "", assert.Error}, } for _, tt := range tests { diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go index ecf942aa..05b99548 100644 --- a/kms/platform/kms_windows.go +++ b/kms/platform/kms_windows.go @@ -17,6 +17,7 @@ const tpmProvider = "Microsoft Platform Crypto Provider" func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { if opts.URI == "" { + opts.URI = withEnableCNG(nil) return newTPMKMS(ctx, opts) } @@ -31,6 +32,7 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { case apiv1.SoftKMS: return newSoftKMS(ctx, opts) case apiv1.DefaultKMS, apiv1.TPMKMS: + opts.URI = withEnableCNG(u.uri) return newTPMKMS(ctx, opts) default: return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend) @@ -59,6 +61,16 @@ func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { }, nil } +func withEnableCNG(u *uri.URI) string { + if u == nil { + return "kms:enable-cng=true" + } + if !u.Has("enable-cng") { + u.Set("enable-cng", "true") + } + return u.String() +} + func transformToCAPIKMS(rawuri string) (string, error) { u, err := parseURI(rawuri) if err != nil { diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index 370a2848..d3644749 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -29,6 +29,7 @@ import ( "go.step.sm/crypto/tpm/attestation" "go.step.sm/crypto/tpm/storage" "go.step.sm/crypto/tpm/tss2" + "go.step.sm/crypto/x509util" ) func init() { @@ -1059,8 +1060,14 @@ func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteC uv.Set("key-id", o.keyID) case o.sha1 != "": uv.Set("sha1", o.sha1) + case o.name != "": + keyID, err := k.getSubjectKeyID(req.Name) + if err != nil { + return fmt.Errorf("error getting key-id: %w", err) + } + uv.Set("key-id", hex.EncodeToString(keyID)) default: - return errors.New(`at least one of "serial", "key-id" or "sha1" is expected to be set`) + return errors.New(`at least one of "serial", "key-id", "sha1" or "name" is expected to be set`) } dk, ok := k.windowsCertificateManager.(deletingCertificateManager) @@ -1071,12 +1078,23 @@ func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteC if err := dk.DeleteCertificate(&apiv1.DeleteCertificateRequest{ Name: uri.New("capi", uv).String(), }); err != nil { + fmt.Println(uri.New("capi", uv).String()) return fmt.Errorf("failed deleting certificate using Windows platform cryptography provider: %w", err) } return nil } +func (k *TPMKMS) getSubjectKeyID(name string) ([]byte, error) { + key, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: name, + }) + if err != nil { + return nil, err + } + return x509util.GenerateSubjectKeyID(key) +} + // attestationClient is a wrapper for [attestation.Client], containing // all of the required references to perform attestation against the // Smallstep Attestation CA. From cef679f5ae652aca235c8b34dd6fd513543f7818 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 3 Mar 2026 12:25:25 -0800 Subject: [PATCH 27/27] Address comments on code review --- kms/apiv1/requests.go | 10 ++++++++++ kms/platform/kms_darwin.go | 3 ++- kms/platform/kms_darwin_test.go | 2 ++ kms/platform/kms_tpmsimulator_test.go | 3 +-- kms/tpmkms/tpmkms.go | 1 - 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/kms/apiv1/requests.go b/kms/apiv1/requests.go index f4bf1162..6fb5b00f 100644 --- a/kms/apiv1/requests.go +++ b/kms/apiv1/requests.go @@ -268,11 +268,21 @@ type AttestationClient interface { type attestSignerCtx struct{} // NewAttestSignerContext creates a new context with the given signer. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. func NewAttestSignerContext(ctx context.Context, signer crypto.Signer) context.Context { return context.WithValue(ctx, attestSignerCtx{}, signer) } // AttestSignerFromContext returns the signer from the context. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. func AttestSignerFromContext(ctx context.Context) (crypto.Signer, bool) { signer, ok := ctx.Value(attestSignerCtx{}).(crypto.Signer) return signer, ok diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go index e43864cb..01bd6744 100644 --- a/kms/platform/kms_darwin.go +++ b/kms/platform/kms_darwin.go @@ -5,6 +5,7 @@ import ( "fmt" "maps" "net/url" + "strings" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/mackms" @@ -72,7 +73,7 @@ func transformToMacKMS(rawuri string) (string, error) { if !u.uri.Has("keychain") { uv.Set("keychain", "dataProtection") } - } else if u.uri.Has("hw") { + } else if strings.EqualFold(u.uri.Get("hw"), "false") { uv.Set("se", "false") } diff --git a/kms/platform/kms_darwin_test.go b/kms/platform/kms_darwin_test.go index 2f491470..b4443e3b 100644 --- a/kms/platform/kms_darwin_test.go +++ b/kms/platform/kms_darwin_test.go @@ -28,8 +28,10 @@ func Test_transformToMacKMS(t *testing.T) { {"scheme", "kms:", "mackms:", assert.NoError}, {"with name", "kms:name=foo", "mackms:label=foo", assert.NoError}, {"with hw", "kms:name=foo;hw=true", "mackms:keychain=dataProtection;label=foo;se=true", assert.NoError}, + {"with hw false", "kms:name=foo;hw=false", "mackms:label=foo;se=false", assert.NoError}, {"with hw on query", "kms:name=foo?hw=true", "mackms:keychain=dataProtection;label=foo;se=true", assert.NoError}, {"with hw and keychain", "kms:name=foo;hw=true;keychain=my", "mackms:keychain=my;label=foo;se=true", assert.NoError}, + {"with hw other", "kms:name=foo;hw=other", "mackms:label=foo", assert.NoError}, {"with extrasValues", "kms:name=foo;keychain=my?foo=bar&baz=qux", "mackms:baz=qux;foo=bar;keychain=my;label=foo", assert.NoError}, {"fail parse", "softkms:name=foo", "", assert.Error}, } diff --git a/kms/platform/kms_tpmsimulator_test.go b/kms/platform/kms_tpmsimulator_test.go index 878936a8..f599a345 100644 --- a/kms/platform/kms_tpmsimulator_test.go +++ b/kms/platform/kms_tpmsimulator_test.go @@ -3,7 +3,6 @@ package platform import ( - "context" "crypto" "crypto/x509" "crypto/x509/pkix" @@ -57,7 +56,7 @@ func mustTPMDevice(t *testing.T) (*tpm.TPM, string, string) { listener := &net.ListenConfig{} socket := filepath.Join(dir, "tpm.sock") - ln, err := listener.Listen(context.TODO(), "unix", socket) + ln, err := listener.Listen(t.Context(), "unix", socket) require.NoError(t, err) go func() { diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index d3644749..f766e4e0 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -1078,7 +1078,6 @@ func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteC if err := dk.DeleteCertificate(&apiv1.DeleteCertificateRequest{ Name: uri.New("capi", uv).String(), }); err != nil { - fmt.Println(uri.New("capi", uv).String()) return fmt.Errorf("failed deleting certificate using Windows platform cryptography provider: %w", err) }