diff --git a/Makefile b/Makefile index bec0b10a..968e8b18 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,7 @@ race: ######################################### fmt: - $Q goimports -l -w $(SRC) + $Q goimports --local go.step.sm/crypto -l -w $(SRC) lint: golint govulncheck diff --git a/go.mod b/go.mod index c2d6ceaa..ba1d5000 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.step.sm/crypto -go 1.24.0 +go 1.25.0 require ( cloud.google.com/go/kms v1.25.0 @@ -24,7 +24,7 @@ require ( github.com/stretchr/testify v1.11.1 go.uber.org/mock v0.6.0 golang.org/x/crypto v0.48.0 - golang.org/x/net v0.50.0 + golang.org/x/net v0.51.0 golang.org/x/sys v0.41.0 golang.org/x/term v0.40.0 google.golang.org/api v0.264.0 diff --git a/go.sum b/go.sum index 9fa9c9ba..c7d98e9f 100644 --- a/go.sum +++ b/go.sum @@ -1033,8 +1033,8 @@ golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= -golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= diff --git a/internal/templates/funcmap.go b/internal/templates/funcmap.go index 45c13759..76b9989e 100644 --- a/internal/templates/funcmap.go +++ b/internal/templates/funcmap.go @@ -7,6 +7,7 @@ import ( "time" "github.com/Masterminds/sprig/v3" + "go.step.sm/crypto/jose" ) diff --git a/jose/encrypt.go b/jose/encrypt.go index 9b61a5f4..c412291e 100644 --- a/jose/encrypt.go +++ b/jose/encrypt.go @@ -4,6 +4,7 @@ import ( "encoding/json" "github.com/pkg/errors" + "go.step.sm/crypto/randutil" ) diff --git a/jose/generate.go b/jose/generate.go index 4bdc6c44..73329b0f 100644 --- a/jose/generate.go +++ b/jose/generate.go @@ -9,6 +9,7 @@ import ( "encoding/base64" "github.com/pkg/errors" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x25519" diff --git a/jose/parse.go b/jose/parse.go index 1236cfaf..b3832393 100644 --- a/jose/parse.go +++ b/jose/parse.go @@ -17,6 +17,7 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x25519" ) diff --git a/jose/types.go b/jose/types.go index 0ff48092..b83749d4 100644 --- a/jose/types.go +++ b/jose/types.go @@ -11,6 +11,7 @@ import ( jose "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/cryptosigner" "github.com/go-jose/go-jose/v3/jwt" + "go.step.sm/crypto/x25519" ) diff --git a/jose/types_test.go b/jose/types_test.go index 2c06cde6..961ee9ab 100644 --- a/jose/types_test.go +++ b/jose/types_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/x25519" ) diff --git a/jose/validate.go b/jose/validate.go index 6a904167..b8bf99a9 100644 --- a/jose/validate.go +++ b/jose/validate.go @@ -12,8 +12,9 @@ import ( "os" "github.com/pkg/errors" - "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/keyutil" ) // ValidateSSHPOP validates the given SSH certificate and key for use in an diff --git a/jose/x25519.go b/jose/x25519.go index 25e90e8a..b63d47a8 100644 --- a/jose/x25519.go +++ b/jose/x25519.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/pkg/errors" + "go.step.sm/crypto/x25519" ) diff --git a/keyutil/key.go b/keyutil/key.go index 171cdf3f..a8ec53d8 100644 --- a/keyutil/key.go +++ b/keyutil/key.go @@ -14,8 +14,9 @@ import ( "sync/atomic" "github.com/pkg/errors" - "go.step.sm/crypto/x25519" "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/x25519" ) var ( diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index 3b50b942..5733fb4c 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -18,6 +18,17 @@ type KeyManager interface { Close() error } +// KeyDeleter is an optional interface for KMS implementations that support +// deleting keys. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +type KeyDeleter interface { + DeleteKey(req *DeleteKeyRequest) error +} + // SearchableKeyManager is an optional interface for KMS implementations // that support searching for keys based on certain attributes. // @@ -54,6 +65,17 @@ type CertificateChainManager interface { StoreCertificateChain(req *StoreCertificateChainRequest) error } +// CertificateDeleter is an optional interface for KMS implementations that +// support deleting certificates. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +type CertificateDeleter interface { + DeleteCertificate(req *DeleteCertificateRequest) error +} + // NameValidator is an interface that KeyManager can implement to validate a // given name or URI. type NameValidator interface { @@ -151,6 +173,9 @@ const ( TPMKMS Type = "tpmkms" // MacKMS is the KMS implementation using macOS Keychain and Secure Enclave. MacKMS Type = "mackms" + // PlatformKMS is the KMS implementation that uses TPMKMS on Windows and + // Linux and MacKMS on macOS. + PlatformKMS Type = "kms" ) // TypeOf returns the type of of the given uri. @@ -181,7 +206,7 @@ func (t Type) Validate() error { return nil case YubiKey, PKCS11, TPMKMS: // Hardware based kms. return nil - case SSHAgentKMS, CAPIKMS, MacKMS: // Others + case SSHAgentKMS, CAPIKMS, MacKMS, PlatformKMS: // Others return nil } diff --git a/kms/apiv1/requests.go b/kms/apiv1/requests.go index 3955f6c0..6fb5b00f 100644 --- a/kms/apiv1/requests.go +++ b/kms/apiv1/requests.go @@ -265,6 +265,29 @@ type AttestationClient interface { Attest(context.Context) ([]*x509.Certificate, error) } +type attestSignerCtx struct{} + +// NewAttestSignerContext creates a new context with the given signer. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +func NewAttestSignerContext(ctx context.Context, signer crypto.Signer) context.Context { + return context.WithValue(ctx, attestSignerCtx{}, signer) +} + +// AttestSignerFromContext returns the signer from the context. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +func AttestSignerFromContext(ctx context.Context) (crypto.Signer, bool) { + signer, ok := ctx.Value(attestSignerCtx{}).(crypto.Signer) + return signer, ok +} + // CertificationParameters encapsulates the inputs for certifying an application key. // Only TPM 2.0 is supported at this point. // diff --git a/kms/apiv1/requests_test.go b/kms/apiv1/requests_test.go index b378e631..199713ac 100644 --- a/kms/apiv1/requests_test.go +++ b/kms/apiv1/requests_test.go @@ -1,6 +1,13 @@ package apiv1 -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.step.sm/crypto/keyutil" +) func TestProtectionLevel_String(t *testing.T) { tests := []struct { @@ -49,3 +56,17 @@ func TestSignatureAlgorithm_String(t *testing.T) { }) } } + +func TestNewAttestSignerContext(t *testing.T) { + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + ctx := NewAttestSignerContext(t.Context(), signer) + got, ok := AttestSignerFromContext(ctx) + assert.Equal(t, signer, got) + assert.True(t, ok) + + got, ok = AttestSignerFromContext(t.Context()) + assert.Nil(t, got) + assert.False(t, ok) +} diff --git a/kms/awskms/awskms.go b/kms/awskms/awskms.go index 5ea34caa..782dea3b 100644 --- a/kms/awskms/awskms.go +++ b/kms/awskms/awskms.go @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" "go.step.sm/crypto/pemutil" diff --git a/kms/awskms/awskms_test.go b/kms/awskms/awskms_test.go index b3763a86..fc411a98 100644 --- a/kms/awskms/awskms_test.go +++ b/kms/awskms/awskms_test.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" ) diff --git a/kms/awskms/no_awskms.go b/kms/awskms/no_awskms.go index eff7215b..d2db445a 100644 --- a/kms/awskms/no_awskms.go +++ b/kms/awskms/no_awskms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/awskms/signer.go b/kms/awskms/signer.go index d24af3da..3ec8935f 100644 --- a/kms/awskms/signer.go +++ b/kms/awskms/signer.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/pkg/errors" + "go.step.sm/crypto/pemutil" ) diff --git a/kms/awskms/signer_test.go b/kms/awskms/signer_test.go index 6d7edea0..bc002762 100644 --- a/kms/awskms/signer_test.go +++ b/kms/awskms/signer_test.go @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" + "go.step.sm/crypto/pemutil" ) diff --git a/kms/azurekms/key_vault.go b/kms/azurekms/key_vault.go index 5e718751..384a4b41 100644 --- a/kms/azurekms/key_vault.go +++ b/kms/azurekms/key_vault.go @@ -14,6 +14,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" ) diff --git a/kms/azurekms/key_vault_test.go b/kms/azurekms/key_vault_test.go index 77db69a5..aa51b169 100644 --- a/kms/azurekms/key_vault_test.go +++ b/kms/azurekms/key_vault_test.go @@ -15,10 +15,11 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" "github.com/go-jose/go-jose/v3" + "go.uber.org/mock/gomock" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/azurekms/internal/mock" - "go.uber.org/mock/gomock" ) var errTest = fmt.Errorf("test error") diff --git a/kms/azurekms/no_azurekms.go b/kms/azurekms/no_azurekms.go index 7e3eea6c..b1403804 100644 --- a/kms/azurekms/no_azurekms.go +++ b/kms/azurekms/no_azurekms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/azurekms/signer_test.go b/kms/azurekms/signer_test.go index 4d894230..28a1bc6c 100644 --- a/kms/azurekms/signer_test.go +++ b/kms/azurekms/signer_test.go @@ -12,11 +12,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" - "go.step.sm/crypto/keyutil" - "go.step.sm/crypto/kms/apiv1" "go.uber.org/mock/gomock" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" + + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/kms/apiv1" ) type FuncMatcher func(x interface{}) bool diff --git a/kms/azurekms/utils.go b/kms/azurekms/utils.go index 9de8f05c..293f6c06 100644 --- a/kms/azurekms/utils.go +++ b/kms/azurekms/utils.go @@ -17,6 +17,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" ) diff --git a/kms/azurekms/utils_test.go b/kms/azurekms/utils_test.go index 63953646..6df6629a 100644 --- a/kms/azurekms/utils_test.go +++ b/kms/azurekms/utils_test.go @@ -13,6 +13,7 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 3a7ea50d..22555195 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -32,6 +32,7 @@ import ( "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" "go.step.sm/crypto/randutil" + "go.step.sm/crypto/x509util" ) // Scheme is the scheme used in uris, the string "capi". @@ -88,7 +89,7 @@ type uriAttributes struct { intermediateStoreName string keyID []byte subjectCN string - serialNumber string + serialNumber *big.Int issuerName string keySpec string skipFindCertificateKey bool @@ -115,6 +116,11 @@ func parseURI(rawuri string) (*uriAttributes, error) { } } + serialNumber, err := u.GetBigInt(SerialNumberArg) + if err != nil { + return nil, fmt.Errorf("failed getting %s from URI %q: %w", SerialNumberArg, rawuri, err) + } + return &uriAttributes{ containerName: u.Get(ContainerNameArg), hash: hashValue, @@ -124,7 +130,7 @@ func parseURI(rawuri string) (*uriAttributes, error) { intermediateStoreName: cmp.Or(u.Get(IntermediateStoreNameArg), CAStore), keyID: keyIDValue, subjectCN: u.Get(SubjectCNArg), - serialNumber: u.Get(SerialNumberArg), + serialNumber: serialNumber, issuerName: u.Get(IssuerNameArg), keySpec: u.Get(KeySpec), skipFindCertificateKey: u.GetBool(SkipFindCertificateKey), @@ -414,25 +420,10 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%x not found", HashArg, u.hash)} } case len(u.keyID) > 0: - searchData := CERT_ID_KEYIDORHASH{ - idChoice: CERT_ID_KEY_IDENTIFIER, - KeyIDOrHash: CRYPTOAPI_BLOB{ - len: uint32(len(u.keyID)), - data: uintptr(unsafe.Pointer(&u.keyID[0])), - }, - } - handle, err = findCertificateInStore(st, - encodingX509ASN|encodingPKCS7, - 0, - findCertID, - uintptr(unsafe.Pointer(&searchData)), nil) - if err != nil { - return nil, fmt.Errorf("findCertificateInStore failed: %w", err) - } - if handle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%x not found", KeyIDArg, u.keyID)} + if handle, err = findCertificateBySubjectKeyID(st, u.keyID); err != nil { + return nil, err } - case u.issuerName != "" && (u.serialNumber != "" || u.subjectCN != ""): + case u.issuerName != "" && (u.serialNumber != nil || u.subjectCN != ""): var prevCert *windows.CertContext for { handle, err = findCertificateInStore(st, @@ -454,27 +445,11 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) } switch { - case len(u.serialNumber) > 0: + case u.serialNumber != nil: // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number - var bi *big.Int - if strings.HasPrefix(u.serialNumber, "0x") { - serialBytes, err := hex.DecodeString(strings.TrimPrefix(u.serialNumber, "0x")) - if err != nil { - return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) - } - - bi = new(big.Int).SetBytes(serialBytes) - } else { - bi := new(big.Int) - bi, ok := bi.SetString(u.serialNumber, 10) - if !ok { - return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) - } - } - - if x509Cert.SerialNumber.Cmp(bi) == 0 { + if x509Cert.SerialNumber.Cmp(u.serialNumber) == 0 { return handle, nil } case len(u.subjectCN) > 0: @@ -485,6 +460,22 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) prevCert = handle } + case u.containerName != "": + key, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: uri.New(Scheme, url.Values{ + ContainerNameArg: []string{u.containerName}, + }).String(), + }) + if err != nil { + return nil, err + } + keyID, err := x509util.GenerateSubjectKeyID(key) + if err != nil { + return nil, fmt.Errorf("error generating SubjectKeyID: %w", err) + } + if handle, err = findCertificateBySubjectKeyID(st, keyID); err != nil { + return nil, err + } default: return nil, fmt.Errorf("%q, %q, or %q and one of %q or %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg, SubjectCNArg) } @@ -964,50 +955,22 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } return nil case len(u.keyID) > 0: - searchData := CERT_ID_KEYIDORHASH{ - idChoice: CERT_ID_KEY_IDENTIFIER, - KeyIDOrHash: CRYPTOAPI_BLOB{ - len: uint32(len(u.keyID)), - data: uintptr(unsafe.Pointer(&u.keyID[0])), - }, - } - certHandle, err = findCertificateInStore(st, - encodingX509ASN|encodingPKCS7, - 0, - findCertID, - uintptr(unsafe.Pointer(&searchData)), nil) + certHandle, err = findCertificateBySubjectKeyID(st, u.keyID) if err != nil { - return fmt.Errorf("findCertificateInStore failed: %w", err) - } - if certHandle == nil { - return nil + return err } - if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { return fmt.Errorf("failed removing certificate: %w", err) } return nil - case u.issuerName != "" && u.serialNumber != "": + case u.issuerName != "" && u.serialNumber != nil: // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number - var serialBytes []byte - if strings.HasPrefix(u.serialNumber, "0x") { - u.serialNumber = strings.TrimPrefix(u.serialNumber, "0x") - u.serialNumber = strings.TrimPrefix(u.serialNumber, "00") // Comparison fails if leading 00 is not removed - serialBytes, err = hex.DecodeString(u.serialNumber) - if err != nil { - return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) - } - } else { - bi := new(big.Int) - bi, ok := bi.SetString(u.serialNumber, 10) - if !ok { - return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) - } - serialBytes = bi.Bytes() - } - var prevCert *windows.CertContext + var ( + prevCert *windows.CertContext + serialBytes = u.serialNumber.Bytes() + ) for { certHandle, err = findCertificateInStore(st, encodingX509ASN|encodingPKCS7, @@ -1036,6 +999,27 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } prevCert = certHandle } + case u.containerName != "": + key, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: uri.New(Scheme, url.Values{ + ContainerNameArg: []string{u.containerName}, + }).String(), + }) + if err != nil { + return err + } + keyID, err := x509util.GenerateSubjectKeyID(key) + if err != nil { + return fmt.Errorf("error generating SubjectKeyID: %w", err) + } + certHandle, err = findCertificateBySubjectKeyID(st, keyID) + if err != nil { + return err + } + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + return nil default: return fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) } diff --git a/kms/capi/capi_no_windows.go b/kms/capi/capi_no_windows.go index be7e896b..b32563c2 100644 --- a/kms/capi/capi_no_windows.go +++ b/kms/capi/capi_no_windows.go @@ -6,6 +6,7 @@ import ( "context" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/capi/ncrypt_windows.go b/kms/capi/ncrypt_windows.go index 9366f08b..0126d99c 100644 --- a/kms/capi/ncrypt_windows.go +++ b/kms/capi/ncrypt_windows.go @@ -13,6 +13,8 @@ import ( "unsafe" "golang.org/x/sys/windows" + + "go.step.sm/crypto/kms/apiv1" ) const ( @@ -518,6 +520,31 @@ func nCryptExportKey(kh uintptr, blobType string) ([]byte, error) { return buf, nil } +func findCertificateBySubjectKeyID(store windows.Handle, keyID []byte) (*windows.CertContext, error) { + searchData := CERT_ID_KEYIDORHASH{ + idChoice: CERT_ID_KEY_IDENTIFIER, + KeyIDOrHash: CRYPTOAPI_BLOB{ + len: uint32(len(keyID)), + data: uintptr(unsafe.Pointer(&keyID[0])), + }, + } + + handle, err := findCertificateInStore(store, + encodingX509ASN|encodingPKCS7, + 0, + findCertID, + uintptr(unsafe.Pointer(&searchData)), nil) + if err != nil { + return nil, fmt.Errorf("findCertificateInStore failed: %w", err) + } + + if handle == nil { + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%x not found", KeyIDArg, keyID)} + } + + return handle, nil +} + func findCertificateInStore(store windows.Handle, enc, findFlags, findType uint32, para uintptr, prev *windows.CertContext) (*windows.CertContext, error) { h, _, err := procCertFindCertificateInStore.Call( uintptr(store), diff --git a/kms/capi/no_capi.go b/kms/capi/no_capi.go index 1eec688f..16d6468d 100644 --- a/kms/capi/no_capi.go +++ b/kms/capi/no_capi.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/cloudkms/cloudkms.go b/kms/cloudkms/cloudkms.go index 0092fed9..adae96f5 100644 --- a/kms/cloudkms/cloudkms.go +++ b/kms/cloudkms/cloudkms.go @@ -18,10 +18,11 @@ import ( "cloud.google.com/go/kms/apiv1/kmspb" gax "github.com/googleapis/gax-go/v2" "github.com/pkg/errors" + "google.golang.org/api/option" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" "go.step.sm/crypto/pemutil" - "google.golang.org/api/option" ) // Scheme is the scheme used in uris, the string "cloudkms". diff --git a/kms/cloudkms/cloudkms_test.go b/kms/cloudkms/cloudkms_test.go index e32edb02..b935c1b9 100644 --- a/kms/cloudkms/cloudkms_test.go +++ b/kms/cloudkms/cloudkms_test.go @@ -14,13 +14,14 @@ import ( gax "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.step.sm/crypto/kms/apiv1" - "go.step.sm/crypto/kms/uri" - "go.step.sm/crypto/pemutil" "google.golang.org/api/option" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/pemutil" ) func TestParent(t *testing.T) { diff --git a/kms/cloudkms/decrypter_test.go b/kms/cloudkms/decrypter_test.go index a756cbae..815698fd 100644 --- a/kms/cloudkms/decrypter_test.go +++ b/kms/cloudkms/decrypter_test.go @@ -14,9 +14,10 @@ import ( gax "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" - "google.golang.org/protobuf/types/known/wrapperspb" ) func TestCloudKMS_CreateDecrypter(t *testing.T) { diff --git a/kms/cloudkms/no_cloudkms.go b/kms/cloudkms/no_cloudkms.go index 31dcc57c..ac4df003 100644 --- a/kms/cloudkms/no_cloudkms.go +++ b/kms/cloudkms/no_cloudkms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/cloudkms/signer.go b/kms/cloudkms/signer.go index b6f666b1..16fb712d 100644 --- a/kms/cloudkms/signer.go +++ b/kms/cloudkms/signer.go @@ -9,6 +9,7 @@ import ( "cloud.google.com/go/kms/apiv1/kmspb" "github.com/pkg/errors" + "go.step.sm/crypto/pemutil" ) diff --git a/kms/cloudkms/signer_test.go b/kms/cloudkms/signer_test.go index 0b391b04..15664b4c 100644 --- a/kms/cloudkms/signer_test.go +++ b/kms/cloudkms/signer_test.go @@ -15,6 +15,7 @@ import ( gax "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/pemutil" ) diff --git a/kms/kms.go b/kms/kms.go index c5ad7042..b530a1be 100644 --- a/kms/kms.go +++ b/kms/kms.go @@ -4,6 +4,7 @@ import ( "context" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" // Enable default implementation diff --git a/kms/mackms/mackms.go b/kms/mackms/mackms.go index 703d9d20..c2c28003 100644 --- a/kms/mackms/mackms.go +++ b/kms/mackms/mackms.go @@ -1276,17 +1276,15 @@ func parseCertURI(rawuri string, useDataProtectionKeychain, requireValue bool) ( } // With regular values, uris look like this: - // mackms:label=my-cert;serial=01020A0B... + // mackms:label=my-cert;serial=0x01020A0B... label := u.Get("label") keychain := u.Get("keychain") - serial := u.GetEncoded("serial") - if requireValue && label == "" && len(serial) == 0 { - return nil, fmt.Errorf("error parsing %q: label or serial are required", rawuri) + serialNumber, err := u.GetBigInt("serial") + if err != nil { + return nil, fmt.Errorf("error parsing %q: %w", rawuri, err) } - - var serialNumber *big.Int - if len(serial) > 0 { - serialNumber = new(big.Int).SetBytes(serial) + if requireValue && label == "" && serialNumber == nil { + return nil, fmt.Errorf("error parsing %q: label or serial is required", rawuri) } return &certAttributes{ diff --git a/kms/mackms/mackms_test.go b/kms/mackms/mackms_test.go index dd627982..adca47ec 100644 --- a/kms/mackms/mackms_test.go +++ b/kms/mackms/mackms_test.go @@ -40,6 +40,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + cf "go.step.sm/crypto/internal/darwin/corefoundation" "go.step.sm/crypto/internal/darwin/security" "go.step.sm/crypto/kms/apiv1" @@ -849,6 +850,7 @@ func TestMacKMS_LoadCertificate(t *testing.T) { {"fail uri", &MacKMS{}, args{&apiv1.LoadCertificateRequest{Name: "mackms:"}}, nil, assert.Error}, {"fail missing label", &MacKMS{}, args{&apiv1.LoadCertificateRequest{Name: "mackms:label=missing-" + suffix}}, nil, assert.Error}, {"fail missing serial", &MacKMS{}, args{&apiv1.LoadCertificateRequest{Name: "mackms:serial=010a020b030c"}}, nil, assert.Error}, + {"fail bad serial", &MacKMS{}, args{&apiv1.LoadCertificateRequest{Name: "mackms:serial=010a020b030z"}}, nil, assert.Error}, {"fail with keychain", &MacKMS{}, args{&apiv1.LoadCertificateRequest{ Name: "mackms:keychain=dataProtection;label=" + label, }}, nil, assert.Error}, diff --git a/kms/mackms/signer_test.go b/kms/mackms/signer_test.go index 863ba346..c7966c87 100644 --- a/kms/mackms/signer_test.go +++ b/kms/mackms/signer_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/pkcs11/pkcs11_no_cgo.go b/kms/pkcs11/pkcs11_no_cgo.go index f1899f47..c063cf20 100644 --- a/kms/pkcs11/pkcs11_no_cgo.go +++ b/kms/pkcs11/pkcs11_no_cgo.go @@ -9,6 +9,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/pkcs11/pkcs11_test.go b/kms/pkcs11/pkcs11_test.go index 1596eff1..d2d9b379 100644 --- a/kms/pkcs11/pkcs11_test.go +++ b/kms/pkcs11/pkcs11_test.go @@ -22,9 +22,10 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.step.sm/crypto/kms/apiv1" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" + + "go.step.sm/crypto/kms/apiv1" ) func TestNew(t *testing.T) { @@ -213,6 +214,7 @@ func TestNew_config(t *testing.T) { }) } } + func TestPKCS11_GetPublicKey(t *testing.T) { k := setupPKCS11(t) type args struct { diff --git a/kms/pkcs11/setup_test.go b/kms/pkcs11/setup_test.go index 0f437c4a..6075ba7a 100644 --- a/kms/pkcs11/setup_test.go +++ b/kms/pkcs11/setup_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/platform/kms.go b/kms/platform/kms.go new file mode 100644 index 00000000..33698b32 --- /dev/null +++ b/kms/platform/kms.go @@ -0,0 +1,306 @@ +package platform + +import ( + "context" + "crypto" + "crypto/x509" + "errors" + "net/url" + "strings" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" +) + +const Scheme = "kms" + +func init() { + apiv1.Register(apiv1.PlatformKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) { + return New(ctx, opts) + }) +} + +const ( + backendKey = "backend" + nameKey = "name" + hwKey = "hw" +) + +type kmsURI struct { + uri *uri.URI + backend apiv1.Type + name string + hw bool + extraValues url.Values +} + +func isDefaultKey(k string) bool { + return k == nameKey || + k == hwKey || + k == backendKey +} + +func parseURI(rawuri string) (*kmsURI, error) { + u, err := uri.ParseWithScheme(Scheme, rawuri) + if err != nil { + return nil, err + } + + extraValues := make(url.Values) + for k, v := range uri.Values(u) { + if !isDefaultKey(k) { + extraValues[k] = v + } + } + + return &kmsURI{ + uri: u, + backend: apiv1.Type(strings.ToLower(u.Get(backendKey))), + name: u.Get(nameKey), + hw: u.GetBool(hwKey), + extraValues: extraValues, + }, nil +} + +type extendedKeyManager interface { + apiv1.KeyManager + apiv1.KeyDeleter + apiv1.CertificateManager + apiv1.CertificateChainManager + apiv1.CertificateDeleter +} + +var _ apiv1.KeyManager = (*KMS)(nil) +var _ apiv1.CertificateManager = (*KMS)(nil) +var _ apiv1.CertificateChainManager = (*KMS)(nil) + +type KMS struct { + typ apiv1.Type + backend extendedKeyManager + transformToURI func(string) (string, error) + transformFromURI func(string) (string, error) +} + +func New(ctx context.Context, opts apiv1.Options) (*KMS, error) { + return newKMS(ctx, opts) +} + +func (k *KMS) Type() apiv1.Type { + return k.typ +} + +func (k *KMS) Close() error { + return k.backend.Close() +} + +func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { + name, err := k.transformToURI(req.Name) + if err != nil { + return nil, err + } + + return k.backend.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: name, + }) +} + +func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + name, err := k.transformToURI(req.Name) + if err != nil { + return nil, err + } + + r := clone(req) + r.Name = name + resp, err := k.backend.CreateKey(r) + if err != nil { + return nil, err + } + + return k.patchCreateKeyResponse(resp) +} + +func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + if req.Signer != nil { + return req.Signer, nil + } + + signingKey, err := k.transformToURI(req.SigningKey) + if err != nil { + return nil, err + } + + r := clone(req) + r.SigningKey = signingKey + return k.backend.CreateSigner(r) +} + +func (k *KMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { + name, err := k.transformToURI(req.Name) + if err != nil { + return err + } + + r := clone(req) + r.Name = name + return k.backend.DeleteKey(r) +} + +func (k *KMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { + name, err := k.transformToURI(req.Name) + if err != nil { + return nil, err + } + + r := clone(req) + r.Name = name + return k.backend.LoadCertificate(r) +} + +func (k *KMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { + name, err := k.transformToURI(req.Name) + if err != nil { + return err + } + + r := clone(req) + r.Name = name + return k.backend.StoreCertificate(r) +} + +func (k *KMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { + name, err := k.transformToURI(req.Name) + if err != nil { + return nil, err + } + + r := clone(req) + r.Name = name + return k.backend.LoadCertificateChain(r) +} + +func (k *KMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { + name, err := k.transformToURI(req.Name) + if err != nil { + return err + } + + r := clone(req) + r.Name = name + return k.backend.StoreCertificateChain(r) +} + +func (k *KMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { + name, err := k.transformToURI(req.Name) + if err != nil { + return err + } + + r := clone(req) + r.Name = name + return k.backend.DeleteCertificate(r) +} + +func (k *KMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.CreateAttestationResponse, error) { + if req.Name == "" { + return nil, errors.New("createAttestationRequest 'name' cannot be empty") + } + + name, err := k.transformToURI(req.Name) + if err != nil { + return nil, err + } + + // Attestation implemented by the backend + if km, ok := k.backend.(apiv1.Attester); ok { + r := clone(req) + r.Name = name + + return km.CreateAttestation(r) + } + + // Attestation using a custom attestation client. + if req.AttestationClient != nil { + signer, err := k.backend.CreateSigner(&apiv1.CreateSignerRequest{ + SigningKey: name, + }) + if err != nil { + return nil, err + } + + ctx := apiv1.NewAttestSignerContext(context.Background(), signer) + chain, err := req.AttestationClient.Attest(ctx) + if err != nil { + return nil, err + } + + var permanentIdentifier string + if len(chain[0].URIs) > 0 { + permanentIdentifier = chain[0].URIs[0].String() + } + + return &apiv1.CreateAttestationResponse{ + Certificate: chain[0], + CertificateChain: chain, + PublicKey: signer.Public(), + PermanentIdentifier: permanentIdentifier, + }, nil + } + + return nil, apiv1.NotImplementedError{} +} + +func (k *KMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { + if km, ok := k.backend.(apiv1.SearchableKeyManager); ok { + query, err := k.transformToURI(req.Query) + if err != nil { + return nil, err + } + + r := clone(req) + r.Query = query + resp, err := km.SearchKeys(r) + if err != nil { + return nil, err + } + + return k.patchSearchKeysResponse(resp) + } + + return nil, apiv1.NotImplementedError{} +} + +func (k *KMS) patchCreateKeyResponse(resp *apiv1.CreateKeyResponse) (*apiv1.CreateKeyResponse, error) { + name, err := k.transformFromURI(resp.Name) + if err != nil { + return nil, err + } + + resp.Name = name + if resp.CreateSignerRequest.SigningKey != "" { + resp.CreateSignerRequest.SigningKey = name + } + + return resp, nil +} + +func (k *KMS) patchSearchKeysResponse(resp *apiv1.SearchKeysResponse) (*apiv1.SearchKeysResponse, error) { + for i := range resp.Results { + name, err := k.transformFromURI(resp.Results[i].Name) + if err != nil { + return nil, err + } + + resp.Results[i].Name = name + if resp.Results[i].CreateSignerRequest.SigningKey != "" { + resp.Results[i].CreateSignerRequest.SigningKey = name + } + } + + return resp, nil +} + +func clone[T any](v *T) *T { + c := *v + return &c +} diff --git a/kms/platform/kms_darwin.go b/kms/platform/kms_darwin.go new file mode 100644 index 00000000..01bd6744 --- /dev/null +++ b/kms/platform/kms_darwin.go @@ -0,0 +1,107 @@ +package platform + +import ( + "context" + "fmt" + "maps" + "net/url" + "strings" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/mackms" + "go.step.sm/crypto/kms/uri" +) + +var _ apiv1.SearchableKeyManager = (*KMS)(nil) + +func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI == "" { + return newMacKMS(ctx, opts) + } + + u, err := parseURI(opts.URI) + if err != nil { + return nil, err + } + + switch u.backend { + case apiv1.TPMKMS: + return newTPMKMS(ctx, opts) + case apiv1.SoftKMS: + return newSoftKMS(ctx, opts) + case apiv1.DefaultKMS, apiv1.MacKMS: + return newMacKMS(ctx, opts) + default: + return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend) + } +} + +func newMacKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI != "" { + u, err := transformToMacKMS(opts.URI) + if err != nil { + return nil, fmt.Errorf("error parsing uri: %w", err) + } + opts.URI = u + } + + km, err := mackms.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + typ: apiv1.MacKMS, + backend: km, + transformToURI: transformToMacKMS, + transformFromURI: transformFromMacKMS, + }, nil +} + +func transformToMacKMS(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + + uv := url.Values{} + if u.name != "" { + uv.Set("label", u.name) + } + if u.hw { + uv.Set("se", "true") + if !u.uri.Has("keychain") { + uv.Set("keychain", "dataProtection") + } + } else if strings.EqualFold(u.uri.Get("hw"), "false") { + uv.Set("se", "false") + } + + // Add custom extra values that might be mackms specific. + maps.Copy(uv, u.extraValues) + + return uri.New(mackms.Scheme, uv).String(), nil +} + +func transformFromMacKMS(rawuri string) (string, error) { + u, err := uri.ParseWithScheme(mackms.Scheme, rawuri) + if err != nil { + return "", err + } + + uv := url.Values{} + if u.Has("label") { + uv.Set("name", u.Get("label")) + } + if u.GetBool("se") { + uv.Set("hw", "true") + } + + for k, v := range uri.Values(u) { + if k != "label" && k != "se" { + uv[k] = v + } + } + + return uri.New(Scheme, uv).String(), nil +} diff --git a/kms/platform/kms_darwin_test.go b/kms/platform/kms_darwin_test.go new file mode 100644 index 00000000..b4443e3b --- /dev/null +++ b/kms/platform/kms_darwin_test.go @@ -0,0 +1,70 @@ +package platform + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func mustPlatformKMS(t *testing.T) *KMS { + t.Helper() + + return mustKMS(t, "kms:") +} + +// SkipTest is a method implemented on tests that allow skipping the test on +// this platform. +func (k *KMS) SkipTests() bool { + return false +} + +func Test_transformToMacKMS(t *testing.T) { + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "kms:", "mackms:", assert.NoError}, + {"with name", "kms:name=foo", "mackms:label=foo", assert.NoError}, + {"with hw", "kms:name=foo;hw=true", "mackms:keychain=dataProtection;label=foo;se=true", assert.NoError}, + {"with hw false", "kms:name=foo;hw=false", "mackms:label=foo;se=false", assert.NoError}, + {"with hw on query", "kms:name=foo?hw=true", "mackms:keychain=dataProtection;label=foo;se=true", assert.NoError}, + {"with hw and keychain", "kms:name=foo;hw=true;keychain=my", "mackms:keychain=my;label=foo;se=true", assert.NoError}, + {"with hw other", "kms:name=foo;hw=other", "mackms:label=foo", assert.NoError}, + {"with extrasValues", "kms:name=foo;keychain=my?foo=bar&baz=qux", "mackms:baz=qux;foo=bar;keychain=my;label=foo", assert.NoError}, + {"fail parse", "softkms:name=foo", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformToMacKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_transformFromMacKMS(t *testing.T) { + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "mackms:", "kms:", assert.NoError}, + {"with label", "mackms:label=foo", "kms:name=foo", assert.NoError}, + {"with se", "mackms:label=foo;se=true", "kms:hw=true;name=foo", assert.NoError}, + {"with se on query", "mackms:label=foo?se=true", "kms:hw=true;name=foo", assert.NoError}, + {"with keychain", "mackms:label=foo;se=true;keychain=dataProtection", "kms:hw=true;keychain=dataProtection;name=foo", assert.NoError}, + {"with keychain on query", "mackms:label=foo?keychain=dataProtection&foo=bar", "kms:foo=bar;keychain=dataProtection;name=foo", assert.NoError}, + {"fail empty", "", "", assert.Error}, + {"fail scheme", "kms:", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformFromMacKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_other.go b/kms/platform/kms_other.go new file mode 100644 index 00000000..9901783f --- /dev/null +++ b/kms/platform/kms_other.go @@ -0,0 +1,30 @@ +//go:build !darwin && !windows + +package platform + +import ( + "context" + "fmt" + + "go.step.sm/crypto/kms/apiv1" +) + +func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI == "" { + return newTPMKMS(ctx, opts) + } + + u, err := parseURI(opts.URI) + if err != nil { + return nil, err + } + + switch u.backend { + case apiv1.SoftKMS: + return newSoftKMS(ctx, opts) + case apiv1.DefaultKMS, apiv1.TPMKMS: + return newTPMKMS(ctx, opts) + default: + return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend) + } +} diff --git a/kms/platform/kms_other_test.go b/kms/platform/kms_other_test.go new file mode 100644 index 00000000..1c83ddb1 --- /dev/null +++ b/kms/platform/kms_other_test.go @@ -0,0 +1,28 @@ +//go:build !darwin && !windows + +package platform + +import ( + "net/url" + "testing" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/tpm/available" +) + +func mustPlatformKMS(t *testing.T) *KMS { + if available.Check() != nil { + return &KMS{} + } + + return mustKMS(t, uri.New(Scheme, url.Values{ + "storage-directory": []string{t.TempDir()}, + }).String()) +} + +// SkipTest is a method implemented on tests that allow skipping the test on +// this platform. +func (k *KMS) SkipTests() bool { + return k.Type() == apiv1.DefaultKMS +} diff --git a/kms/platform/kms_softkms.go b/kms/platform/kms_softkms.go new file mode 100644 index 00000000..f3a76a7f --- /dev/null +++ b/kms/platform/kms_softkms.go @@ -0,0 +1,133 @@ +package platform + +import ( + "bytes" + "context" + "encoding/pem" + "fmt" + "net/url" + "os" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/softkms" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/pemutil" +) + +func newSoftKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + km, err := softkms.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + typ: apiv1.SoftKMS, + backend: &softKMS{SoftKMS: km}, + transformToURI: transformToSoftKMS, + transformFromURI: transformFromSoftKMS, + }, nil +} + +type softKMS struct { + *softkms.SoftKMS +} + +func (k *softKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + if req.Name == "" { + return nil, fmt.Errorf("createKeyRequest 'name' cannot be empty") + } + + resp, err := k.SoftKMS.CreateKey(req) + if err != nil { + return nil, err + } + + if _, err := pemutil.Serialize(resp.PrivateKey, pemutil.ToFile(req.Name, 0o600)); err != nil { + return nil, err + } + + return resp, nil +} + +func (k *softKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { + if req.Name == "" { + return fmt.Errorf("deleteKeyRequest 'name' cannot be empty") + } + + return os.Remove(req.Name) +} + +func (k *softKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { + switch { + case req.Name == "": + return fmt.Errorf("storeCertificateRequest 'name' cannot be empty") + case req.Certificate == nil: + return fmt.Errorf("storeCertificateRequest 'certificate' cannot be empty") + } + + return os.WriteFile(req.Name, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: req.Certificate.Raw, + }), 0o600) +} + +func (k *softKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { + switch { + case req.Name == "": + return fmt.Errorf("storeCertificateChainRequest 'name' cannot be empty") + case len(req.CertificateChain) == 0: + return fmt.Errorf("storeCertificateChainRequest 'certificateChain' cannot be empty") + } + + var buf bytes.Buffer + for _, crt := range req.CertificateChain { + if err := pem.Encode(&buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: crt.Raw, + }); err != nil { + return err + } + } + + return os.WriteFile(req.Name, buf.Bytes(), 0o600) +} + +func (k *softKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { + if req.Name == "" { + return fmt.Errorf("deleteCertificateRequest 'name' cannot be empty") + } + + return os.Remove(req.Name) +} + +func transformToSoftKMS(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + + if u.hw { + return "", fmt.Errorf("error parsing uri: hw is not supported") + } + + switch { + case u.uri.Has("name"): + return u.name, nil + case u.uri.Has("path"): + return u.uri.Get("path"), nil + case u.uri.Path != "": + return u.uri.Path, nil + case u.uri.Opaque != "": + return u.uri.Opaque, nil + default: + return "", nil + } +} + +func transformFromSoftKMS(path string) (string, error) { + uv := url.Values{} + if path != "" { + uv.Set(nameKey, path) + } + return uri.New(Scheme, uv).String(), nil +} diff --git a/kms/platform/kms_softkms_test.go b/kms/platform/kms_softkms_test.go new file mode 100644 index 00000000..8d47537b --- /dev/null +++ b/kms/platform/kms_softkms_test.go @@ -0,0 +1,50 @@ +package platform + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_transformToSoftKMS(t *testing.T) { + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "kms:", "", assert.NoError}, + {"with name", "kms:name=path/to/file.crt", "path/to/file.crt", assert.NoError}, + {"with encoded", "kms:name=%2Fpath%2Fto%2Ffile.key", "/path/to/file.key", assert.NoError}, + {"with path", "kms:path=/path/to/file.key", "/path/to/file.key", assert.NoError}, + {"with opaque", "kms:path/to/file.key", "path/to/file.key", assert.NoError}, + {"fail parse", "mackms:", "", assert.Error}, + {"fail hw", "kms:name=file.key;hw=true", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformToSoftKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_transformFromSoftKMS(t *testing.T) { + tests := []struct { + name string + path string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "", "kms:", assert.NoError}, + {"with path", "/path/to/file", "kms:name=%2Fpath%2Fto%2Ffile", assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformFromSoftKMS(tt.path) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_test.go b/kms/platform/kms_test.go new file mode 100644 index 00000000..98aa0304 --- /dev/null +++ b/kms/platform/kms_test.go @@ -0,0 +1,1221 @@ +package platform + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "errors" + "fmt" + "net/url" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/randutil" +) + +var ( + platformKeyName string + platformCertName string + platformMissingName string +) + +func TestMain(m *testing.M) { + suffix, err := randutil.Alphanumeric(8) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + platformKeyName = "kms:name=test-" + suffix + platformCertName = "kms:name=test-" + suffix + platformMissingName = "kms:name=test-missing-" + suffix + + if runtime.GOOS == "darwin" { + platformKeyName += ";tag=com.smallstep.test." + suffix + } + + os.Exit(m.Run()) +} + +func shouldSkipNow(t *testing.T, km *KMS) { + t.Helper() + + if km.Type() != apiv1.SoftKMS && km.SkipTests() { + t.SkipNow() + } +} + +func mustKMS(t *testing.T, rawuri string) *KMS { + t.Helper() + + km, err := New(t.Context(), apiv1.Options{ + URI: rawuri, + }) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, km.Close()) + }) + return km +} + +func mustSigner(t *testing.T, path string) crypto.Signer { + t.Helper() + + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + _, err = pemutil.Serialize(signer, pemutil.ToFile(path, 0o600)) + require.NoError(t, err) + + return signer +} + +func mustReadSigner(t *testing.T, path string) crypto.Signer { + t.Helper() + + k, err := pemutil.Read(path) + require.NoError(t, err) + + signer, ok := k.(crypto.Signer) + require.True(t, ok) + + return signer +} + +func mustCertificate(t *testing.T, path string) []*x509.Certificate { + t.Helper() + + ca, err := minica.New() + require.NoError(t, err) + + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + cert, err := ca.Sign(&x509.Certificate{ + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + PublicKey: signer.Public(), + DNSNames: []string{"example.com"}, + }) + require.NoError(t, err) + + if path != "" { + var buf bytes.Buffer + require.NoError(t, pem.Encode(&buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })) + require.NoError(t, pem.Encode(&buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: ca.Intermediate.Raw, + })) + + require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o600)) + } + + return []*x509.Certificate{ + cert, ca.Intermediate, + } +} + +func mustCertificateWithKey(t *testing.T, key crypto.PublicKey) []*x509.Certificate { + t.Helper() + + ca, err := minica.New() + require.NoError(t, err) + + // skipped platform + if key == nil { + return []*x509.Certificate{ca.Intermediate} + } + + cert, err := ca.Sign(&x509.Certificate{ + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + PublicKey: key, + DNSNames: []string{"example.com"}, + }) + require.NoError(t, err) + + return []*x509.Certificate{ + cert, ca.Intermediate, + } +} + +func mustSuffix(t *testing.T) string { + t.Helper() + suffix, err := randutil.Alphanumeric(8) + require.NoError(t, err) + return suffix +} + +type createOptions struct { + name string + noCleanup bool + noCleanupCertificate bool +} + +type createFuncOption func(*createOptions) + +func withName(s string) createFuncOption { + return func(co *createOptions) { + co.name = s + } +} + +func withNoCleanup() createFuncOption { + return func(co *createOptions) { + co.noCleanup = true + co.noCleanupCertificate = true + } +} + +func withNoCleanupCertificate() createFuncOption { + return func(co *createOptions) { + co.noCleanupCertificate = true + } +} + +func mustCreatePlatformKey(t *testing.T, km *KMS, opts ...createFuncOption) *apiv1.CreateKeyResponse { + t.Helper() + + o := new(createOptions) + o.name = platformKeyName + for _, fn := range opts { + fn(o) + } + + if km.SkipTests() { + return &apiv1.CreateKeyResponse{} + } + + resp, err := km.CreateKey(&apiv1.CreateKeyRequest{ + Name: o.name, + }) + require.NoError(t, err) + + if !o.noCleanup { + t.Cleanup(func() { + assert.NoError(t, km.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: resp.Name, + })) + }) + } + + return resp +} + +func mustCreatePlatformCertificate(t *testing.T, km *KMS, opts ...createFuncOption) []*x509.Certificate { + t.Helper() + + o := new(createOptions) + o.name = platformCertName + for _, fn := range opts { + fn(o) + } + + ca, err := minica.New() + require.NoError(t, err) + + if km.SkipTests() { + return []*x509.Certificate{ + ca.Intermediate, + } + } + + key := mustCreatePlatformKey(t, km, opts...) + cert, err := ca.Sign(&x509.Certificate{ + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + PublicKey: key.PublicKey, + DNSNames: []string{"example.com"}, + }) + require.NoError(t, err) + + require.NoError(t, km.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{ + Name: o.name, + CertificateChain: []*x509.Certificate{ + cert, ca.Intermediate, + }, + })) + if !o.noCleanupCertificate { + t.Cleanup(func() { + assert.NoError(t, km.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: o.name, + })) + }) + } + + // Always delete the intermediate on macOS + if typ := km.Type(); typ == apiv1.MacKMS { + t.Cleanup(func() { + assert.NoError(t, km.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "serial": []string{ca.Intermediate.SerialNumber.String()}, + }).String(), + })) + }) + } + + return []*x509.Certificate{ + cert, ca.Intermediate, + } +} + +func mustPermanentIdentifier(t *testing.T, pub crypto.PublicKey) *url.URL { + t.Helper() + + b, err := x509.MarshalPKIXPublicKey(pub) + require.NoError(t, err) + + keyID := sha256.Sum256(b) + return &url.URL{ + Scheme: "urn", + Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID[:]), + } +} + +type attestationClient struct { + chain []*x509.Certificate + err error +} + +func mustAttestationClient(chain []*x509.Certificate, err error) *attestationClient { + return &attestationClient{ + chain: chain, + err: err, + } +} + +func (c *attestationClient) Attest(ctx context.Context) ([]*x509.Certificate, error) { + if _, ok := apiv1.AttestSignerFromContext(ctx); !ok { + return nil, fmt.Errorf("signer is not in context") + } + return c.chain, c.err +} + +func TestKMS_Type(t *testing.T) { + softKMS := mustKMS(t, "kms:backend=softkms") + assert.Equal(t, apiv1.SoftKMS, softKMS.Type()) +} + +func TestKMS_Close(t *testing.T) { + softKMS, err := New(t.Context(), apiv1.Options{ + URI: "kms:backend=softkms", + }) + require.NoError(t, err) + assert.NoError(t, softKMS.Close()) +} + +func TestKMS_GetPublicKey(t *testing.T) { + dir := t.TempDir() + privateKeyPath := filepath.Join(dir, "private.key") + signer := mustSigner(t, privateKeyPath) + softKMS := mustKMS(t, "kms:backend=softkms") + + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS) + + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + want crypto.PublicKey + assertion assert.ErrorAssertionFunc + }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.GetPublicKeyRequest{ + Name: platformKey.Name, + }}, platformKey.PublicKey, assert.NoError}, + {"fail platform missing", platformKMS, args{&apiv1.GetPublicKeyRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + {"fail platform name", platformKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:something=test", + }}, nil, assert.Error}, + + // SoftKMS + {"ok SoftKMS", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=" + privateKeyPath, + }}, signer.Public(), assert.NoError}, + {"ok SoftKMS escape", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=" + url.QueryEscape(privateKeyPath), + }}, signer.Public(), assert.NoError}, + {"ok SoftKMS path", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:" + privateKeyPath, + }}, signer.Public(), assert.NoError}, + {"fail empty name", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "", + }}, nil, assert.Error}, + {"fail SoftKMS missing", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:" + filepath.Join(dir, "notfound.key"), + }}, nil, assert.Error}, + {"fail transform", softKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "softkms:" + privateKeyPath, + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.GetPublicKey(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_CreateKey(t *testing.T) { + dir := t.TempDir() + privateKeyPath := filepath.Join(dir, "private.key") + softKMS := mustKMS(t, "kms:backend=softkms") + + suffix := mustSuffix(t) + platformKMS := mustPlatformKMS(t) + + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(t *testing.T, got *apiv1.CreateKeyResponse) + assertion assert.ErrorAssertionFunc + }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test1-" + suffix, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test1-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=test1-.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + + if platformKMS.Type() == apiv1.TPMKMS && assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) + } else if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, elliptic.P256(), got.PublicKey.(*ecdsa.PublicKey).Curve) + } + }, assert.NoError}, + {"ok platform ECDSA", platformKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test2-" + suffix, + SignatureAlgorithm: apiv1.ECDSAWithSHA384, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test2-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=test2-.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + + if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, elliptic.P384(), got.PublicKey.(*ecdsa.PublicKey).Curve) + } + }, assert.NoError}, + {"ok platform RSA", platformKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test3-" + suffix, + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test3-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=test3-.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + if assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) + } + }, assert.NoError}, + {"fail platform algorithm", platformKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:test4-" + suffix, + SignatureAlgorithm: apiv1.SignatureAlgorithm(100), + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + + // SoftKMS + {"ok softKMS", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=" + privateKeyPath, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + signer := mustReadSigner(t, privateKeyPath) + assert.IsType(t, &ecdsa.PrivateKey{}, signer) + name := "kms:name=" + url.QueryEscape(privateKeyPath) + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: name, + PublicKey: signer.Public(), + PrivateKey: signer, + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: name, + }, + }) + }, assert.NoError}, + {"ok softKMS escape", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=" + url.QueryEscape(privateKeyPath), + SignatureAlgorithm: apiv1.SHA256WithRSA, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + signer := mustReadSigner(t, privateKeyPath) + if assert.IsType(t, &rsa.PrivateKey{}, signer) { + assert.Equal(t, 3072/8, signer.(*rsa.PrivateKey).Size()) + } + name := "kms:name=" + url.QueryEscape(privateKeyPath) + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: name, + PublicKey: signer.Public(), + PrivateKey: signer, + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: name, + }, + }) + }, assert.NoError}, + {"ok softKMS path", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:" + privateKeyPath, + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + signer := mustReadSigner(t, privateKeyPath) + name := "kms:name=" + url.QueryEscape(privateKeyPath) + if assert.IsType(t, &rsa.PrivateKey{}, signer) { + assert.Equal(t, 2048/8, signer.(*rsa.PrivateKey).Size()) + } + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: name, + PublicKey: signer.Public(), + PrivateKey: signer, + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: name, + }, + }) + }, assert.NoError}, + {"fail softKMS createKey", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:" + privateKeyPath, + SignatureAlgorithm: apiv1.SignatureAlgorithm(100), + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "softkms:" + privateKeyPath, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail empty name", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail empty uri name", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail empty uri path", softKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:path=", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.CreateKey(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_CreateSigner(t *testing.T) { + dir := t.TempDir() + softKMS := mustKMS(t, "kms:backend=softkms") + privateKeyPath := filepath.Join(dir, "private.key") + resp, err := softKMS.CreateKey(&apiv1.CreateKeyRequest{ + Name: "kms:name=" + url.QueryEscape(privateKeyPath), + }) + require.NoError(t, err) + signer := mustReadSigner(t, privateKeyPath) + + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS) + + assertNil := func(t *testing.T, got crypto.Signer) { + t.Helper() + assert.Nil(t, got) + } + + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(*testing.T, crypto.Signer) + assertion assert.ErrorAssertionFunc + }{ + // PlatformKMS + {"ok platform", platformKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: platformKeyName, + }}, func(t *testing.T, s crypto.Signer) { + require.NotNil(t, s) + assert.Equal(t, platformKey.PublicKey, s.Public()) + }, assert.NoError}, + {"fail platform missing", platformKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: platformMissingName, + }}, assertNil, assert.Error}, + + // SoftKMS + {"ok softKMS", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=" + url.QueryEscape(privateKeyPath), + }}, func(t *testing.T, s crypto.Signer) { + assert.Equal(t, signer, s) + }, assert.NoError}, + {"ok softKMS with signer", softKMS, args{&apiv1.CreateSignerRequest{ + Signer: resp.CreateSignerRequest.Signer, + SigningKey: resp.CreateSignerRequest.SigningKey, + }}, func(t *testing.T, s crypto.Signer) { + assert.Equal(t, signer, s) + }, assert.NoError}, + {"fail missing", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=" + url.QueryEscape(filepath.Join(dir, "missing.key")), + }}, assertNil, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: privateKeyPath, + }}, assertNil, assert.Error}, + {"fail signingKey", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: "", + }}, assertNil, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:", + }}, assertNil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.CreateSigner(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_DeleteKey(t *testing.T) { + dir := t.TempDir() + softKMS := mustKMS(t, "kms:backend=softkms") + keyPath1 := filepath.Join(dir, "key1.key") + _, err := softKMS.CreateKey(&apiv1.CreateKeyRequest{ + Name: "kms:name=" + url.QueryEscape(keyPath1), + }) + require.NoError(t, err) + + keyPath2 := filepath.Join(dir, "key2.key") + _, err = softKMS.CreateKey(&apiv1.CreateKeyRequest{ + Name: "kms:name=" + keyPath2, + }) + require.NoError(t, err) + + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS, withNoCleanup()) + + type args struct { + req *apiv1.DeleteKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.DeleteKeyRequest{ + Name: platformKey.Name, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, getErr := platformKMS.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: platformKey.Name, + }) + return assert.NoError(t, err) && assert.Error(t, getErr) + }}, + {"fail platform deleted", platformKMS, args{&apiv1.DeleteKeyRequest{ + Name: platformKey.Name, + }}, assert.Error}, + {"fail platform missing", platformKMS, args{&apiv1.DeleteKeyRequest{ + Name: platformMissingName, + }}, assert.Error}, + + // SoftKMS + {"ok softKMS", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=" + url.QueryEscape(keyPath1), + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.NoFileExists(t, keyPath1) + }}, + {"fail missing", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "missing.key")), + }}, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: keyPath2, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.Error(t, err) && + assert.FileExists(t, keyPath2) + }}, + {"fail name", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: "", + }}, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.DeleteKeyRequest{ + Name: "kms:", + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + tt.assertion(t, tt.kms.DeleteKey(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificate(t *testing.T) { + dir := t.TempDir() + + chainPath := filepath.Join(dir, "chain.crt") + chain := mustCertificate(t, chainPath) + + certPath := filepath.Join(dir, "certificate.crt") + softKMS := mustKMS(t, "kms:backend=softkms") + require.NoError(t, softKMS.StoreCertificate(&apiv1.StoreCertificateRequest{ + Name: "kms:name=" + url.QueryEscape(certPath), + Certificate: chain[0], + })) + + platformKMS := mustPlatformKMS(t) + platformChain := mustCreatePlatformCertificate(t, platformKMS) + + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.LoadCertificateRequest{ + Name: platformCertName, + }}, platformChain[0], assert.NoError}, + {"fail platform missing", platformKMS, args{&apiv1.LoadCertificateRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + + // SoftKMS + {"ok softKMS", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "kms:" + certPath, + }}, chain[0], assert.NoError}, + {"ok softKMS from chain", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=" + url.QueryEscape(chainPath), + }}, chain[0], assert.NoError}, + {"fail missing", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=" + filepath.Join(dir, "missing.crt"), + }}, nil, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "foo:name=" + certPath, + }}, nil, assert.Error}, + {"fail name", softKMS, args{&apiv1.LoadCertificateRequest{ + Name: "", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.LoadCertificate(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificate(t *testing.T) { + dir := t.TempDir() + chain := mustCertificate(t, "") + softKMS := mustKMS(t, "kms:backend=softkms") + + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS) + platformChain := mustCertificateWithKey(t, platformKey.PublicKey) + + type args struct { + req *apiv1.StoreCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName, + Certificate: platformChain[0], + }}, assert.NoError}, + {"ok platform no key", platformKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName + "-other", + Certificate: chain[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + // Storing a certificate with no key is not supported on TPMKMS. + if platformKMS.Type() == apiv1.TPMKMS && runtime.GOOS != "windows" { + return assert.Error(t, err) + } + + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + }) + return assert.NoError(t, err) + }}, + {"fail platform no certificate", platformKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName, + Certificate: nil, + }}, assert.Error}, + + // SoftKMS + {"ok softKMS", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=" + filepath.Join(dir, "cert.crt"), + Certificate: chain[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.FileExists(t, filepath.Join(dir, "cert.crt")) + }}, + {"ok softKMS simple", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:" + filepath.Join(dir, "intermediate.crt"), + Certificate: chain[1], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.FileExists(t, filepath.Join(dir, "intermediate.crt")) + }}, + {"ok softKMS overwrite", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:" + filepath.Join(dir, "cert.crt"), + Certificate: chain[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.FileExists(t, filepath.Join(dir, "cert.crt")) + }}, + {"fail parseURI", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "foo:" + filepath.Join(dir, "fail.crt"), + Certificate: chain[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.Error(t, err) && + assert.NoFileExists(t, filepath.Join(dir, "fail.crt")) + }}, + {"fail name", softKMS, args{&apiv1.StoreCertificateRequest{ + Certificate: chain[0], + }}, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:", + Certificate: chain[0], + }}, assert.Error}, + {"fail certificate", softKMS, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=" + filepath.Join(dir, "cert.crt"), + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + tt.assertion(t, tt.kms.StoreCertificate(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificateChain(t *testing.T) { + dir := t.TempDir() + chainPath := filepath.Join(dir, "chain.crt") + chain := mustCertificate(t, chainPath) + softKMS := mustKMS(t, "kms:backend=softkms") + + platformKMS := mustPlatformKMS(t) + platformChain := mustCreatePlatformCertificate(t, platformKMS) + + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + want []*x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: platformCertName, + }}, platformChain, assert.NoError}, + {"fail platform missing", platformKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + + // SoftKMS + {"ok softKMS", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=" + chainPath, + }}, chain, assert.NoError}, + {"fail parseURI", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "foo:name=" + chainPath, + }}, nil, assert.Error}, + {"fail missing", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=" + filepath.Join(dir, "missing.crt"), + }}, nil, assert.Error}, + {"fail parseuri", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "softkms:name=" + chainPath, + }}, nil, assert.Error}, + {"fail name", softKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: "", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.LoadCertificateChain(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificateChain(t *testing.T) { + dir := t.TempDir() + chain := mustCertificate(t, "") + softKMS := mustKMS(t, "kms:backend=softkms") + + platformKMS := mustPlatformKMS(t) + platformKey := mustCreatePlatformKey(t, platformKMS) + platformChain := mustCertificateWithKey(t, platformKey.PublicKey) + + type args struct { + req *apiv1.StoreCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + // Platform KMS + {"ok platform", platformKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: platformChain, + }}, assert.NoError}, + {"ok platform no key", platformKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName + "-other", + CertificateChain: chain, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + // Storing a certificate with no key is not supported on TPMKMS. + if platformKMS.Type() == apiv1.TPMKMS && runtime.GOOS != "windows" { + return assert.Error(t, err) + } + + t.Cleanup(func() { + assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + + if typ := platformKMS.Type(); typ == apiv1.MacKMS || (typ == apiv1.TPMKMS && runtime.GOOS == "windows") { + assert.NoError(t, platformKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{platformChain[1].Issuer.CommonName}, // for windows only + "serial": []string{hex.EncodeToString(platformChain[1].SerialNumber.Bytes())}, + }).String(), + })) + } + }) + return assert.NoError(t, err) + }}, + {"fail platform bad chain", platformKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: []*x509.Certificate{}, + }}, assert.Error}, + + // SoftKMS + {"ok softKMS", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=" + filepath.Join(dir, "chain.crt"), + CertificateChain: chain, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && assert.FileExists(t, filepath.Join(dir, "chain.crt")) + }}, + {"ok softKMS escape", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "leaf.crt")), + CertificateChain: chain[:1], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && assert.FileExists(t, filepath.Join(dir, "leaf.crt")) + }}, + {"fail parseURI", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "foo:name=" + filepath.Join(dir, "other.crt"), + CertificateChain: chain, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.Error(t, err) && assert.NoFileExists(t, filepath.Join(dir, "other.crt")) + }}, + {"fail name", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "", + CertificateChain: chain, + }}, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:", + CertificateChain: chain, + }}, assert.Error}, + {"fail certificateChain", softKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=" + filepath.Join(dir, "other.crt"), + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + tt.assertion(t, tt.kms.StoreCertificateChain(tt.args.req)) + }) + } +} + +func TestKMS_DeleteCertificate(t *testing.T) { + dir := t.TempDir() + _ = mustCertificate(t, filepath.Join(dir, "chain.crt")) + softKMS := mustKMS(t, "kms:backend=softkms") + + platformKMS := mustPlatformKMS(t) + _ = mustCreatePlatformCertificate(t, platformKMS, withNoCleanupCertificate()) + + type args struct { + req *apiv1.DeleteCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok platform", platformKMS, args{&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, loadErr := platformKMS.LoadCertificate(&apiv1.LoadCertificateRequest{ + Name: platformCertName, + }) + return assert.NoError(t, err) && assert.Error(t, loadErr) + }}, + {"fail platform missing", platformKMS, args{&apiv1.DeleteCertificateRequest{ + Name: platformMissingName, + }}, assert.Error}, + + // SoftKMS + {"ok softKMS", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "chain.crt")), + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) && + assert.NoFileExists(t, filepath.Join(dir, "chain.crt")) + }}, + {"fail missing", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=" + url.QueryEscape(filepath.Join(dir, "chain.crt")), + }}, assert.Error}, + {"fail parseURI", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "foo", + }}, assert.Error}, + {"fail name", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "", + }}, assert.Error}, + {"fail empty uri", softKMS, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:", + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + tt.assertion(t, tt.kms.DeleteCertificate(tt.args.req)) + }) + } +} + +func TestKMS_CreateAttestation(t *testing.T) { + dir := t.TempDir() + privateKeyPath := filepath.Join(dir, "private.key") + signer := mustSigner(t, privateKeyPath) + attester := mustSigner(t, filepath.Join(dir, "attester.key")) + permanentIdentifier := mustPermanentIdentifier(t, attester.Public()) + + ca, err := minica.New() + require.NoError(t, err) + cert, err := ca.Sign(&x509.Certificate{ + Subject: pkix.Name{ + CommonName: "attestation certificate", + }, + URIs: []*url.URL{permanentIdentifier}, + PublicKey: signer.Public(), + }) + require.NoError(t, err) + + softKMS := mustKMS(t, "kms:backend=softkms") + okClient := mustAttestationClient([]*x509.Certificate{cert, ca.Intermediate}, nil) + failClient := mustAttestationClient(nil, errors.New("attestation failed")) + + type args struct { + req *apiv1.CreateAttestationRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.CreateAttestationResponse + assertion assert.ErrorAssertionFunc + }{ + {"ok custom attestation", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "kms:" + privateKeyPath, + AttestationClient: okClient, + }}, &apiv1.CreateAttestationResponse{ + Certificate: cert, + CertificateChain: []*x509.Certificate{cert, ca.Intermediate}, + PublicKey: signer.Public(), + PermanentIdentifier: permanentIdentifier.String(), + }, assert.NoError}, + {"fail missing key", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "kms:" + platformMissingName, + AttestationClient: okClient, + }}, nil, assert.Error}, + {"fail custom attestation", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "kms:" + privateKeyPath, + AttestationClient: failClient, + }}, nil, assert.Error}, + {"fail no client", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "kms:" + privateKeyPath, + }}, nil, assert.Error}, + {"fail no name", softKMS, args{&apiv1.CreateAttestationRequest{}}, nil, assert.Error}, + {"fail parse", softKMS, args{&apiv1.CreateAttestationRequest{ + Name: "tpmkms:", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.CreateAttestation(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_SearchKeys(t *testing.T) { + dir := t.TempDir() + softKMS := mustKMS(t, "kms:backend=softkms") + + suffix := mustSuffix(t) + platformKMS := mustPlatformKMS(t) + + platformKeys := make([]*apiv1.CreateKeyResponse, 4) + for i := range platformKeys { + name := fmt.Sprintf("kms:name=search-test-%d-%s", i, suffix) + if runtime.GOOS == "darwin" { + name += ";tag=com.smallstep.test." + suffix + } + platformKeys[i] = mustCreatePlatformKey(t, platformKMS, withName(name)) + } + + makeResult := func(r *apiv1.CreateKeyResponse) apiv1.SearchKeyResult { + return apiv1.SearchKeyResult{ + Name: r.Name, + PublicKey: r.PublicKey, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: r.Name, + }, + } + } + + type args struct { + req *apiv1.SearchKeysRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.SearchKeysResponse + assertion assert.ErrorAssertionFunc + }{ + // PlatformKMS + {"ok platform", platformKMS, args{&apiv1.SearchKeysRequest{ + Query: "kms:tag=com.smallstep.test." + suffix, + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + makeResult(platformKeys[0]), makeResult(platformKeys[1]), + makeResult(platformKeys[2]), makeResult(platformKeys[3]), + }, + }, assert.NoError}, + {"ok platform with name", platformKMS, args{&apiv1.SearchKeysRequest{ + Query: fmt.Sprintf("kms:name=search-test-%d-%s;tag=com.smallstep.test.%s", 2, suffix, suffix), + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + makeResult(platformKeys[2]), + }, + }, assert.NoError}, + {"fail parse", platformKMS, args{&apiv1.SearchKeysRequest{ + Query: "name=", + }}, nil, assert.Error}, + + // SoftKMS + {"fail softKMS", softKMS, args{&apiv1.SearchKeysRequest{ + Query: "kms:name=" + url.QueryEscape(dir), + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.SearchKeys(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_tpm.go b/kms/platform/kms_tpm.go new file mode 100644 index 00000000..658d3ba3 --- /dev/null +++ b/kms/platform/kms_tpm.go @@ -0,0 +1,95 @@ +package platform + +import ( + "context" + "maps" + "net/url" + "runtime" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/tpmkms" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/tpm" +) + +var _ apiv1.Attester = (*KMS)(nil) + +func newTPMKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI != "" { + u, err := transformToTPMKMS(opts.URI) + if err != nil { + return nil, err + } + opts.URI = u + } + + km, err := tpmkms.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + typ: apiv1.TPMKMS, + backend: km, + transformToURI: transformToTPMKMS, + transformFromURI: transformFromTPMKMS, + }, nil +} + +func NewWithTPM(ctx context.Context, t *tpm.TPM, opts ...tpmkms.Option) (*KMS, error) { + km, err := tpmkms.NewWithTPM(ctx, t, opts...) + if err != nil { + return nil, err + } + + return &KMS{ + typ: apiv1.TPMKMS, + backend: km, + transformToURI: transformToTPMKMS, + transformFromURI: transformFromTPMKMS, + }, nil +} + +func transformToTPMKMS(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + + uv := url.Values{} + if u.name != "" { + uv.Set("name", u.name) + } + + // When storing a certificate on windows, skip key validation. This avoid a + // prompt looking for an SmartCard. + if runtime.GOOS == "windows" && !u.extraValues.Has("skip-find-certificate-key") { + uv.Set("skip-find-certificate-key", "true") + } + + // Add custom extra values that might be tpmkms specific. + // There is not need to set "hw". + maps.Copy(uv, u.extraValues) + + return uri.New(tpmkms.Scheme, uv).String(), nil +} + +func transformFromTPMKMS(rawuri string) (string, error) { + u, err := uri.ParseWithScheme(tpmkms.Scheme, rawuri) + if err != nil { + return "", err + } + + uv := url.Values{} + if u.Has("name") { + uv.Set(nameKey, u.Get("name")) + } + + for k, v := range uri.Values(u) { + if k != nameKey { + uv[k] = v + } + } + + return uri.New(Scheme, uv).String(), nil +} diff --git a/kms/platform/kms_tpm_test.go b/kms/platform/kms_tpm_test.go new file mode 100644 index 00000000..ae47d84e --- /dev/null +++ b/kms/platform/kms_tpm_test.go @@ -0,0 +1,70 @@ +package platform + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/uri" +) + +func Test_transformToTPMKMS(t *testing.T) { + fixWindowsURI := func(s string) string { + if runtime.GOOS != "windows" { + return s + } + + u, err := uri.Parse(s) + require.NoError(t, err) + u.Set("skip-find-certificate-key", "true") + return u.String() + } + + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "kms:", fixWindowsURI("tpmkms:"), assert.NoError}, + {"with name", "kms:name=foo", fixWindowsURI("tpmkms:name=foo"), assert.NoError}, + {"with ak", "kms:name=foo;ak=true", fixWindowsURI("tpmkms:ak=true;name=foo"), assert.NoError}, + {"with ak in query", "kms:name=foo?ak=true", fixWindowsURI("tpmkms:ak=true;name=foo"), assert.NoError}, + {"with ak false", "kms:ak=false", fixWindowsURI("tpmkms:ak=false"), assert.NoError}, + {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", fixWindowsURI("tpmkms:baz=qux;foo=bar;name=foo"), assert.NoError}, + {"without hw", "kms:name=foo;hw=true", fixWindowsURI("tpmkms:name=foo"), assert.NoError}, + {"fail parse", "mackms:name=foo", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformToTPMKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_transformFromTPMKMS(t *testing.T) { + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "tpmkms:", "kms:", assert.NoError}, + {"with label", "tpmkms:name=foo", "kms:name=foo", assert.NoError}, + {"with ak", "tpmkms:name=foo;ak=true", "kms:ak=true;name=foo", assert.NoError}, + {"with ak on query", "tpmkms:name=foo?ak=true", "kms:ak=true;name=foo", assert.NoError}, + {"with ak false", "tpmkms:ak=false;name=foo", "kms:ak=false;name=foo", assert.NoError}, + {"fail empty", "", "", assert.Error}, + {"fail scheme", "kms:", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformFromTPMKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_tpmsimulator_test.go b/kms/platform/kms_tpmsimulator_test.go new file mode 100644 index 00000000..f599a345 --- /dev/null +++ b/kms/platform/kms_tpmsimulator_test.go @@ -0,0 +1,905 @@ +//go:build tpmsimulator + +package platform + +import ( + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "net" + "net/url" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/tpm" + "go.step.sm/crypto/tpm/simulator" + "go.step.sm/crypto/tpm/storage" +) + +func mustTPM(t *testing.T) *tpm.TPM { + t.Helper() + + sim, err := simulator.New() + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, sim.Close()) + }) + require.NoError(t, sim.Open()) + + dir := t.TempDir() + + stpm, err := tpm.New(tpm.WithSimulator(sim), tpm.WithStore(storage.NewDirstore(dir))) + require.NoError(t, err) + + return stpm +} + +func mustTPMDevice(t *testing.T) (*tpm.TPM, string, string) { + t.Helper() + + sim, err := simulator.New() + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, sim.Close()) + }) + require.NoError(t, sim.Open()) + + dir := t.TempDir() + stpm, err := tpm.New(tpm.WithSimulator(sim), tpm.WithStore(storage.NewDirstore(dir))) + require.NoError(t, err) + + listener := &net.ListenConfig{} + socket := filepath.Join(dir, "tpm.sock") + ln, err := listener.Listen(t.Context(), "unix", socket) + require.NoError(t, err) + + go func() { + for { + conn, err := ln.Accept() + require.NoError(t, err) + + go func(conn net.Conn) { + defer conn.Close() + + readBuf := make([]byte, 4096) + n, err := conn.Read(readBuf) + require.NoError(t, err) + + _, err = sim.Write(readBuf[:n]) + require.NoError(t, err) + + writeBuf := make([]byte, 4096) + nr, err := sim.Read(writeBuf) + require.NoError(t, err) + + _, err = conn.Write(writeBuf[:nr]) + require.NoError(t, err) + }(conn) + } + }() + + return stpm, socket, dir +} + +func mustTPMKMS(t *testing.T) (*KMS, *tpm.TPM) { + t.Helper() + + stpm, sock, dir := mustTPMDevice(t) + km := mustKMS(t, uri.New(Scheme, url.Values{ + "backend": []string{"tpmkms"}, + "device": []string{sock}, + "storage-directory": []string{dir}, + }).String()) + + return km, stpm +} + +func TestKMS_Type_tpm(t *testing.T) { + kms1, stpm := mustTPMKMS(t) + assert.Equal(t, apiv1.TPMKMS, kms1.Type()) + + kms2, err := NewWithTPM(t.Context(), stpm) + require.NoError(t, err) + assert.Equal(t, apiv1.TPMKMS, kms2.Type()) + +} + +func TestKMS_Close_tpm(t *testing.T) { + kms1, stpm := mustTPMKMS(t) + assert.NoError(t, kms1.Close()) + + kms2, err := NewWithTPM(t.Context(), stpm) + require.NoError(t, err) + assert.NoError(t, kms2.Close()) +} + +func TestKMS_GetPublicKey_tpm(t *testing.T) { + ctx := t.Context() + kms1, stpm := mustTPMKMS(t) + kms2, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "RSA", + Size: 2048, + }) + require.NoError(t, err) + + keySigner, err := key.Signer(ctx) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + want crypto.PublicKey + assertion assert.ErrorAssertionFunc + }{ + {"ok key", kms1, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=key-1", + }}, keySigner.Public(), assert.NoError}, + {"ok ak", kms1, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=ak-1;ak=true", + }}, ak.Public(), assert.NoError}, + {"ok key with tpm", kms2, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=key-1", + }}, keySigner.Public(), assert.NoError}, + {"ok ak with tpm", kms2, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=ak-1;ak=true", + }}, ak.Public(), assert.NoError}, + {"fail missing key", kms1, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=key-2", + }}, nil, assert.Error}, + {"fail missing ak", kms2, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:name=ak-2;ak=true", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.GetPublicKey(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_CreateKey_tpm(t *testing.T) { + ctx := t.Context() + kms1, stpm := mustTPMKMS(t) + kms2, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(t *testing.T, got *apiv1.CreateKeyResponse) + assertion assert.ErrorAssertionFunc + }{ + {"ok", kms1, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=key-1", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + key, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + + require.NotNil(t, got) + require.NotNil(t, got.CreateSignerRequest.Signer) + + assert.Equal(t, signer.Public(), got.CreateSignerRequest.Signer.Public()) + got.CreateSignerRequest.Signer = signer + + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: "kms:name=key-1", + PublicKey: signer.Public(), + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: "kms:name=key-1", + }, + }) + }, assert.NoError}, + {"ok ak", kms1, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=ak-1;ak=true", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + key, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: "kms:ak=true;name=ak-1", + PublicKey: key.Public(), + }) + }, assert.NoError}, + {"ok with tpm", kms2, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=key-2", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + key, err := stpm.GetKey(ctx, "key-2") + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + + require.NotNil(t, got) + require.NotNil(t, got.CreateSignerRequest.Signer) + + assert.Equal(t, signer.Public(), got.CreateSignerRequest.Signer.Public()) + got.CreateSignerRequest.Signer = signer + + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: "kms:name=key-2", + PublicKey: signer.Public(), + CreateSignerRequest: apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: "kms:name=key-2", + }, + }) + }, assert.NoError}, + {"ok ak with tpm", kms2, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=ak-2;ak=true", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + key, err := stpm.GetAK(ctx, "ak-2") + require.NoError(t, err) + + assert.Equal(t, got, &apiv1.CreateKeyResponse{ + Name: "kms:ak=true;name=ak-2", + PublicKey: key.Public(), + }) + }, assert.NoError}, + {"fail key already exists", kms1, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=key-2", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + {"fail ak already exists", kms2, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=ak-1;ak=true", + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateKey(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_CreateSigner_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "RSA", + Size: 2048, + }) + require.NoError(t, err) + + signer, err := key.Signer(ctx) + require.NoError(t, err) + + _, err = stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(*testing.T, crypto.Signer) + assertion assert.ErrorAssertionFunc + }{ + {"ok key", km, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=key-1", + }}, func(t *testing.T, got crypto.Signer) { + require.NotNil(t, got) + assert.Equal(t, signer.Public(), got.Public()) + }, assert.NoError}, + {"ok key with signer", km, args{&apiv1.CreateSignerRequest{ + Signer: signer, + SigningKey: "kms:name=key1", + }}, func(t *testing.T, got crypto.Signer) { + assert.Equal(t, signer, got) + }, assert.NoError}, + {"fail missing", km, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=key-2", + }}, func(t *testing.T, got crypto.Signer) { + assert.Nil(t, got) + }, assert.Error}, + {"fail with ak", km, args{&apiv1.CreateSignerRequest{ + SigningKey: "kms:name=ak-1;ak=true", + }}, func(t *testing.T, got crypto.Signer) { + assert.Nil(t, got) + }, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateSigner(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_DeleteKey_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + _, err = stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + _, err = stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + type args struct { + req *apiv1.DeleteKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok key", km, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=key-1", + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, keyErr := stpm.GetKey(ctx, "key-1") + return assert.NoError(t, err) && assert.Error(t, keyErr) + }}, + {"ok ak", km, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=ak-1;ak=true", + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, akErr := stpm.GetAK(ctx, "ak-1") + return assert.NoError(t, err) && assert.Error(t, akErr) + }}, + {"fail missing key", km, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=key-2", + }}, assert.Error}, + {"fail missing ak", km, args{&apiv1.DeleteKeyRequest{ + Name: "kms:name=ak-2;ak=true", + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteKey(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificate_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + _, err = stpm.CreateKey(ctx, "key-2", tpm.CreateKeyConfig{ + Algorithm: "RSA", + Size: 2048, + }) + require.NoError(t, err) + + _, err = stpm.CreateAK(ctx, "ak-2") + require.NoError(t, err) + + keyChain := mustCertificateWithKey(t, key.Public()) + require.NoError(t, key.SetCertificateChain(ctx, keyChain)) + + akChain := mustCertificateWithKey(t, ak.Public()) + require.NoError(t, ak.SetCertificateChain(ctx, akChain)) + + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=key-1", + }}, keyChain[0], assert.NoError}, + {"ok ak", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=ak-1;ak=true", + }}, akChain[0], assert.NoError}, + {"fail no certificate", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=key-2", + }}, nil, assert.Error}, + {"fail no ak certificate", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=ak-2;ak=true", + }}, nil, assert.Error}, + {"fail missing", km, args{&apiv1.LoadCertificateRequest{ + Name: "kms:name=missing-key", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificate(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificate_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + keyChain1 := mustCertificateWithKey(t, key.Public()) + keyChain2 := mustCertificateWithKey(t, key.Public()) + akChain1 := mustCertificateWithKey(t, ak.Public()) + akChain2 := mustCertificateWithKey(t, ak.Public()) + + type args struct { + req *apiv1.StoreCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=key-1", + Certificate: keyChain1[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Equal(t, keyChain1[0], k.Certificate()) + }}, + {"ok overwrite", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=key-1", + Certificate: keyChain2[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Equal(t, keyChain2[0], k.Certificate()) + }}, + {"ok ak", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=ak-1;ak=true", + Certificate: akChain1[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Equal(t, akChain1[0], k.Certificate()) + }}, + {"ok ak overwrite", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=ak-1;ak=true", + Certificate: akChain2[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Equal(t, akChain2[0], k.Certificate()) + }}, + {"fail missing", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=missing-key", + Certificate: keyChain1[0], + }}, assert.Error}, + {"fail key not match", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=key-1", + Certificate: akChain1[0], + }}, assert.Error}, + {"fail ak key not match", km, args{&apiv1.StoreCertificateRequest{ + Name: "kms:name=ak-1;ak=true", + Certificate: keyChain1[0], + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.StoreCertificate(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificateChain_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + _, err = stpm.CreateKey(ctx, "key-2", tpm.CreateKeyConfig{ + Algorithm: "RSA", + Size: 2048, + }) + require.NoError(t, err) + + _, err = stpm.CreateAK(ctx, "ak-2") + require.NoError(t, err) + + keyChain := mustCertificateWithKey(t, key.Public()) + require.NoError(t, key.SetCertificateChain(ctx, keyChain)) + + akChain := mustCertificateWithKey(t, ak.Public()) + require.NoError(t, ak.SetCertificateChain(ctx, akChain)) + + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + want []*x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=key-1", + }}, keyChain, assert.NoError}, + {"ok ak", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=ak-1;ak=true", + }}, akChain, assert.NoError}, + {"fail no chain", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=key-2", + }}, nil, assert.Error}, + {"fail no ak chain", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=ak-2;ak=true", + }}, nil, assert.Error}, + {"fail missing", km, args{&apiv1.LoadCertificateChainRequest{ + Name: "kms:name=missing-key", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificateChain(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificateChain_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + keyChain1 := mustCertificateWithKey(t, key.Public()) + keyChain2 := mustCertificateWithKey(t, key.Public()) + akChain1 := mustCertificateWithKey(t, ak.Public()) + akChain2 := mustCertificateWithKey(t, ak.Public()) + + type args struct { + req *apiv1.StoreCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=key-1", + CertificateChain: keyChain1, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Equal(t, keyChain1[0], k.Certificate()) && + assert.Equal(t, keyChain1, k.CertificateChain()) + }}, + {"ok overwrite", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=key-1", + CertificateChain: keyChain2, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Equal(t, keyChain2[0], k.Certificate()) && + assert.Equal(t, keyChain2, k.CertificateChain()) + }}, + {"ok ak", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=ak-1;ak=true", + CertificateChain: akChain1, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Equal(t, akChain1[0], k.Certificate()) && + assert.Equal(t, akChain1, k.CertificateChain()) + }}, + {"ok ak overwrite", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=ak-1;ak=true", + CertificateChain: akChain2, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Equal(t, akChain2[0], k.Certificate()) && + assert.Equal(t, akChain2, k.CertificateChain()) + }}, + {"fail missing", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=missing-key", + CertificateChain: keyChain1, + }}, assert.Error}, + {"fail key not match", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=key-1", + CertificateChain: akChain1, + }}, assert.Error}, + {"fail ak key not match", km, args{&apiv1.StoreCertificateChainRequest{ + Name: "kms:name=ak-1;ak=true", + CertificateChain: keyChain1, + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.StoreCertificateChain(tt.args.req)) + }) + } +} + +func TestKMS_DeleteCertificate_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + keyChain := mustCertificateWithKey(t, key.Public()) + require.NoError(t, key.SetCertificateChain(ctx, keyChain)) + + akChain := mustCertificateWithKey(t, ak.Public()) + require.NoError(t, ak.SetCertificateChain(ctx, akChain)) + + type args struct { + req *apiv1.DeleteCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=key-1", + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetKey(ctx, "key-1") + require.NoError(t, err) + return assert.Nil(t, k.Certificate()) && assert.Nil(t, k.CertificateChain()) + }}, + {"ok ak", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=ak-1;ak=true", + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + k, err := stpm.GetAK(ctx, "ak-1") + require.NoError(t, err) + return assert.Nil(t, k.Certificate()) && assert.Nil(t, k.CertificateChain()) + }}, + {"ok delete again", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=key-1", + }}, assert.NoError}, + {"ok delete again ak", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=ak-1;ak=true", + }}, assert.NoError}, + {"fail missing", km, args{&apiv1.DeleteCertificateRequest{ + Name: "kms:name=missing-ak;ak=true", + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteCertificate(tt.args.req)) + }) + } +} + +func TestKMS_CreateAttestation_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + eks, err := stpm.GetEKs(ctx) + require.NoError(t, err) + require.NotEmpty(t, eks) + ekKeyURL := mustPermanentIdentifier(t, eks[0].Public()) + + ak, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + key, err := stpm.AttestKey(ctx, "ak-1", "key-1", tpm.AttestKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + QualifyingData: []byte{1, 2, 3, 4}, + }) + require.NoError(t, err) + keyParams, err := key.CertificationParameters(ctx) + require.NoError(t, err) + keySigner, err := key.Signer(ctx) + require.NoError(t, err) + + _, err = stpm.CreateKey(ctx, "key-2", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ca, err := minica.New() + require.NoError(t, err) + + akCert, err := ca.Sign(&x509.Certificate{ + Subject: pkix.Name{ + CommonName: "ak-1", + }, + URIs: []*url.URL{ekKeyURL}, + PublicKey: ak.Public(), + }) + require.NoError(t, err) + require.NoError(t, ak.SetCertificateChain(ctx, []*x509.Certificate{ + akCert, ca.Intermediate, + })) + + type args struct { + req *apiv1.CreateAttestationRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.CreateAttestationResponse + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.CreateAttestationRequest{ + Name: "kms:name=key-1", + }}, &apiv1.CreateAttestationResponse{ + Certificate: akCert, + CertificateChain: []*x509.Certificate{akCert, ca.Intermediate}, + PublicKey: keySigner.Public(), + CertificationParameters: &apiv1.CertificationParameters{ + Public: keyParams.Public, + CreateData: keyParams.CreateData, + CreateAttestation: keyParams.CreateAttestation, + CreateSignature: keyParams.CreateSignature, + }, + PermanentIdentifier: ekKeyURL.String(), + }, assert.NoError}, + {"fail not attested key", km, args{&apiv1.CreateAttestationRequest{ + Name: "kms:name=key-2", + }}, nil, assert.Error}, + {"fail missing key", km, args{&apiv1.CreateAttestationRequest{ + Name: "kms:name=key-3", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkipNow(t, tt.kms) + + got, err := tt.kms.CreateAttestation(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_SearchKeys_tpm(t *testing.T) { + ctx := t.Context() + stpm := mustTPM(t) + km, err := NewWithTPM(ctx, stpm) + require.NoError(t, err) + + key1, err := stpm.CreateKey(ctx, "key-1", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + key2, err := stpm.CreateKey(ctx, "key-2", tpm.CreateKeyConfig{ + Algorithm: "ECDSA", + Size: 256, + }) + require.NoError(t, err) + + ak1, err := stpm.CreateAK(ctx, "ak-1") + require.NoError(t, err) + + ak2, err := stpm.CreateAK(ctx, "ak-2") + require.NoError(t, err) + + type args struct { + req *apiv1.SearchKeysRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.SearchKeysResponse + assertion assert.ErrorAssertionFunc + }{ + {"ok", km, args{&apiv1.SearchKeysRequest{ + Query: "kms:", + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + {Name: "kms:ak=true;name=ak-1", PublicKey: ak1.Public()}, + {Name: "kms:ak=true;name=ak-2", PublicKey: ak2.Public()}, + {Name: "kms:name=key-1", PublicKey: key1.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-1"}}, + {Name: "kms:name=key-2", PublicKey: key2.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-2"}}, + }, + }, assert.NoError}, + {"ok keys", km, args{&apiv1.SearchKeysRequest{ + Query: "kms:ak=false", + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + {Name: "kms:name=key-1", PublicKey: key1.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-1"}}, + {Name: "kms:name=key-2", PublicKey: key2.Public(), CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: "kms:name=key-2"}}, + }, + }, assert.NoError}, + {"ok aks", km, args{&apiv1.SearchKeysRequest{ + Query: "kms:ak=true", + }}, &apiv1.SearchKeysResponse{ + Results: []apiv1.SearchKeyResult{ + {Name: "kms:ak=true;name=ak-1", PublicKey: ak1.Public()}, + {Name: "kms:ak=true;name=ak-2", PublicKey: ak2.Public()}, + }, + }, assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.SearchKeys(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/platform/kms_windows.go b/kms/platform/kms_windows.go new file mode 100644 index 00000000..05b99548 --- /dev/null +++ b/kms/platform/kms_windows.go @@ -0,0 +1,123 @@ +//go:build windows + +package platform + +import ( + "context" + "fmt" + "maps" + "net/url" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/capi" + "go.step.sm/crypto/kms/uri" +) + +const tpmProvider = "Microsoft Platform Crypto Provider" + +func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI == "" { + opts.URI = withEnableCNG(nil) + return newTPMKMS(ctx, opts) + } + + u, err := parseURI(opts.URI) + if err != nil { + return nil, err + } + + switch u.backend { + case apiv1.CAPIKMS: + return newCAPIKMS(ctx, opts) + case apiv1.SoftKMS: + return newSoftKMS(ctx, opts) + case apiv1.DefaultKMS, apiv1.TPMKMS: + opts.URI = withEnableCNG(u.uri) + return newTPMKMS(ctx, opts) + default: + return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend) + } +} + +func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) { + if opts.URI != "" { + u, err := transformToCAPIKMS(opts.URI) + if err != nil { + return nil, err + } + opts.URI = u + } + + km, err := capi.New(ctx, opts) + if err != nil { + return nil, err + } + + return &KMS{ + typ: apiv1.CAPIKMS, + backend: km, + transformToURI: transformToCAPIKMS, + transformFromURI: transformFromCAPIKMS, + }, nil +} + +func withEnableCNG(u *uri.URI) string { + if u == nil { + return "kms:enable-cng=true" + } + if !u.Has("enable-cng") { + u.Set("enable-cng", "true") + } + return u.String() +} + +func transformToCAPIKMS(rawuri string) (string, error) { + u, err := parseURI(rawuri) + if err != nil { + return "", err + } + + uv := url.Values{} + if u.name != "" { + uv.Set("key", u.name) + } + + // When storing certificate skip key validation. + // This avoid a prompt looking for an SmartCard. + if !u.extraValues.Has("skip-find-certificate-key") { + uv.Set("skip-find-certificate-key", "true") + } + + // Set provider "Microsoft Platform Crypto Provider" to use the TPM. + if u.hw && !u.extraValues.Has("provider") { + uv.Set("provider", tpmProvider) + } + + // Add custom extra values that might be CAPI specific. + maps.Copy(uv, u.extraValues) + + return uri.New(capi.Scheme, uv).String(), nil +} + +func transformFromCAPIKMS(rawuri string) (string, error) { + u, err := uri.ParseWithScheme(capi.Scheme, rawuri) + if err != nil { + return "", err + } + + uv := url.Values{} + if u.Has("key") { + uv.Set("name", u.Get("key")) + } + if u.Get("provider") == tpmProvider { + uv.Set("hw", "true") + } + + for k, v := range uri.Values(u) { + if k != "key" { + uv[k] = v + } + } + + return uri.New(Scheme, uv).String(), nil +} diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go new file mode 100644 index 00000000..79b159ae --- /dev/null +++ b/kms/platform/kms_windows_test.go @@ -0,0 +1,550 @@ +//go:build windows + +package platform + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/hex" + "fmt" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/tpm/available" +) + +func mustPlatformKMS(t *testing.T) *KMS { + t.Helper() + + if available.Check() != nil { + return &KMS{} + } + + return mustKMS(t, uri.New(Scheme, url.Values{ + "storage-directory": []string{t.TempDir()}, + }).String()) +} + +// SkipTest is a method implemented on tests that allow skipping the test on +// this platform. +func (k *KMS) SkipTests() bool { + return k.Type() == apiv1.DefaultKMS +} + +func mustCAPIKMS(t *testing.T) *KMS { + return mustKMS(t, "kms:backend=capi") +} + +func TestKMS_Type_capi(t *testing.T) { + km := mustCAPIKMS(t) + assert.Equal(t, apiv1.CAPIKMS, km.Type()) +} + +func TestKMS_GetPublicKey_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + want crypto.PublicKey + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.GetPublicKeyRequest{ + Name: capiKey.Name, + }}, capiKey.PublicKey, assert.NoError}, + {"fail capi missing", capiKMS, args{&apiv1.GetPublicKeyRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + {"fail capi name", capiKMS, args{&apiv1.GetPublicKeyRequest{ + Name: "kms:something=test", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.GetPublicKey(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_CreateKey_capi(t *testing.T) { + suffix := mustSuffix(t) + capiKMS := mustCAPIKMS(t) + + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(t *testing.T, got *apiv1.CreateKeyResponse) + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test1-" + suffix, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test1-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=.*;provider=.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + + if capiKMS.Type() == apiv1.TPMKMS && assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) + } else if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, elliptic.P256(), got.PublicKey.(*ecdsa.PublicKey).Curve) + } + }, assert.NoError}, + {"ok capi ECDSA", capiKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test2-" + suffix, + SignatureAlgorithm: apiv1.ECDSAWithSHA384, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test2-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=.*;provider=.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + + if assert.IsType(t, &ecdsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, elliptic.P384(), got.PublicKey.(*ecdsa.PublicKey).Curve) + } + }, assert.NoError}, + {"ok capi RSA", capiKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:name=test3-" + suffix, + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + require.NotNil(t, got) + + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteKey(&apiv1.DeleteKeyRequest{ + Name: "kms:name=test3-" + suffix, + })) + }) + + assert.Regexp(t, "^kms:.*name=.*;provider=.*$", got.Name) + assert.Equal(t, got.Name, got.CreateSignerRequest.SigningKey) + if assert.IsType(t, &rsa.PublicKey{}, got.PublicKey) { + assert.Equal(t, 256, got.PublicKey.(*rsa.PublicKey).Size()) + } + }, assert.NoError}, + {"fail capi algorithm", capiKMS, args{&apiv1.CreateKeyRequest{ + Name: "kms:test4-" + suffix, + SignatureAlgorithm: apiv1.SignatureAlgorithm(100), + }}, func(t *testing.T, got *apiv1.CreateKeyResponse) { + assert.Nil(t, got) + }, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateKey(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_CreateSigner_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + + assertNil := func(t *testing.T, got crypto.Signer) { + t.Helper() + assert.Nil(t, got) + } + + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + kms *KMS + args args + equal func(*testing.T, crypto.Signer) + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: platformKeyName, + }}, func(t *testing.T, s crypto.Signer) { + require.NotNil(t, s) + assert.Equal(t, capiKey.PublicKey, s.Public()) + }, assert.NoError}, + {"fail capi missing", capiKMS, args{&apiv1.CreateSignerRequest{ + SigningKey: platformMissingName, + }}, assertNil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.CreateSigner(tt.args.req) + tt.assertion(t, err) + tt.equal(t, got) + }) + } +} + +func TestKMS_DeleteKey_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS, withNoCleanup()) + + type args struct { + req *apiv1.DeleteKeyRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.DeleteKeyRequest{ + Name: capiKey.Name, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, getErr := capiKMS.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: capiKey.Name, + }) + return assert.NoError(t, err) && assert.Error(t, getErr) + }}, + {"fail capi deleted", capiKMS, args{&apiv1.DeleteKeyRequest{ + Name: capiKey.Name, + }}, assert.Error}, + {"fail capi missing", capiKMS, args{&apiv1.DeleteKeyRequest{ + Name: platformMissingName, + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteKey(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificate_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + capiChain := mustCertificateWithKey(t, capiKey.PublicKey) + require.NoError(t, capiKMS.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: capiChain, + })) + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{capiChain[1].Issuer.CommonName}, + "serial": []string{capiChain[1].SerialNumber.String()}, + }).String(), + })) + }) + + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.LoadCertificateRequest{ + Name: platformCertName, + }}, capiChain[0], assert.NoError}, + {"ok capi issuer and serial", capiKMS, args{&apiv1.LoadCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{capiChain[0].Issuer.CommonName}, + "serial": []string{capiChain[0].SerialNumber.String()}, + }).String(), + }}, capiChain[0], assert.NoError}, + {"fail capi missing", capiKMS, args{&apiv1.LoadCertificateRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificate(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificate_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + capiChain := mustCertificateWithKey(t, capiKey.PublicKey) + chainNoKey := mustCertificate(t, "") + + type args struct { + req *apiv1.StoreCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName, + Certificate: capiChain[0], + }}, assert.NoError}, + {"ok capi no key", capiKMS, args{&apiv1.StoreCertificateRequest{ + Name: platformCertName + "-other", + Certificate: chainNoKey[0], + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + }) + return assert.NoError(t, err) + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.StoreCertificate(tt.args.req)) + }) + } +} + +func TestKMS_LoadCertificateChain_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + capiChain := mustCertificateWithKey(t, capiKey.PublicKey) + require.NoError(t, capiKMS.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: capiChain, + })) + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{capiChain[1].Issuer.CommonName}, + "serial": []string{capiChain[1].SerialNumber.String()}, + }).String(), + })) + }) + + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + want []*x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: platformCertName, + }}, capiChain, assert.NoError}, + {"ok capi issuer and serial", capiKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: uri.New(Scheme, url.Values{ + "issuer": []string{capiChain[0].Issuer.CommonName}, + "serial": []string{capiChain[0].SerialNumber.String()}, + }).String(), + }}, capiChain, assert.NoError}, + {"fail capi missing", capiKMS, args{&apiv1.LoadCertificateChainRequest{ + Name: platformMissingName, + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.LoadCertificateChain(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestKMS_StoreCertificateChain_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + capiKey := mustCreatePlatformKey(t, capiKMS) + capiChain := mustCertificateWithKey(t, capiKey.PublicKey) + chainNoKey := mustCertificate(t, "") + + type args struct { + req *apiv1.StoreCertificateChainRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName, + CertificateChain: capiChain, + }}, assert.NoError}, + {"ok capi no key", capiKMS, args{&apiv1.StoreCertificateChainRequest{ + Name: platformCertName + "-other", + CertificateChain: chainNoKey, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + // Storing a certificate with no key is not supported on TPMKMS. + if capiKMS.Type() == apiv1.TPMKMS { + return assert.Error(t, err) + } + + t.Cleanup(func() { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + })) + + if capiKMS.Type() == apiv1.MacKMS { + assert.NoError(t, capiKMS.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New(Scheme, url.Values{ + "serial": []string{hex.EncodeToString(capiChain[1].SerialNumber.Bytes())}, + }).String(), + })) + } + }) + return assert.NoError(t, err) + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.StoreCertificateChain(tt.args.req)) + }) + } +} + +func TestKMS_DeleteCertificate_capi(t *testing.T) { + capiKMS := mustCAPIKMS(t) + _ = mustCreatePlatformCertificate(t, capiKMS, withNoCleanupCertificate()) + + type args struct { + req *apiv1.DeleteCertificateRequest + } + tests := []struct { + name string + kms *KMS + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok capi", capiKMS, args{&apiv1.DeleteCertificateRequest{ + Name: platformCertName, + }}, func(tt assert.TestingT, err error, i ...interface{}) bool { + _, loadErr := capiKMS.LoadCertificate(&apiv1.LoadCertificateRequest{ + Name: platformCertName, + }) + return assert.NoError(t, err) && assert.Error(t, loadErr) + }}, + {"fail platform missing", capiKMS, args{&apiv1.DeleteCertificateRequest{ + Name: platformMissingName, + }}, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, tt.kms.DeleteCertificate(tt.args.req)) + }) + } +} + +func TestKMS_SearchKeys_capi(t *testing.T) { + suffix := mustSuffix(t) + capiKMS := mustCAPIKMS(t) + + platformKeys := make([]*apiv1.CreateKeyResponse, 4) + for i := range platformKeys { + name := fmt.Sprintf("kms:name=search-test-%d-%s", i, suffix) + platformKeys[i] = mustCreatePlatformKey(t, capiKMS, withName(name)) + } + + type args struct { + req *apiv1.SearchKeysRequest + } + tests := []struct { + name string + kms *KMS + args args + want *apiv1.SearchKeysResponse + assertion assert.ErrorAssertionFunc + }{ + {"fail capi", capiKMS, args{&apiv1.SearchKeysRequest{ + Query: "kms:", + }}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kms.SearchKeys(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_transformToCAPIKMS(t *testing.T) { + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "kms:", "capi:skip-find-certificate-key=true", assert.NoError}, + {"with name", "kms:name=foo", "capi:key=foo;skip-find-certificate-key=true", assert.NoError}, + {"with hw", "kms:name=foo;hw=true", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider;skip-find-certificate-key=true", assert.NoError}, + {"with hw on query", "kms:name=foo?hw=true", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider;skip-find-certificate-key=true", assert.NoError}, + {"with skip-find-certificate-key", "kms:name=foo;skip-find-certificate-key=false", "capi:key=foo;skip-find-certificate-key=false", assert.NoError}, + {"with provider", "kms:name=foo;hw=true;provider=my", "capi:key=foo;provider=my;skip-find-certificate-key=true", assert.NoError}, + {"with extrasValues", "kms:name=foo;foo=bar?baz=qux", "capi:baz=qux;foo=bar;key=foo;skip-find-certificate-key=true", assert.NoError}, + {"fail parse", "capikms:name=foo", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformToCAPIKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_transformFromCAPIKMS(t *testing.T) { + tests := []struct { + name string + rawuri string + want string + assertion assert.ErrorAssertionFunc + }{ + {"scheme", "capi:", "kms:", assert.NoError}, + {"with key", "capi:key=foo", "kms:name=foo", assert.NoError}, + {"with provider", "capi:key=foo;provider=Microsoft+Platform+Crypto+Provider", "kms:hw=true;name=foo;provider=Microsoft+Platform+Crypto+Provider", assert.NoError}, + {"with provider on query", "capi:key=foo?provider=my", "kms:name=foo;provider=my", assert.NoError}, + {"with others", "capi:key=foo;serial=1234;issuer=My+CA", "kms:issuer=My+CA;name=foo;serial=1234", assert.NoError}, + {"fail empty", "", "", assert.Error}, + {"fail scheme", "kms:", "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := transformFromCAPIKMS(tt.rawuri) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/softkms/softkms.go b/kms/softkms/softkms.go index 46a73ebe..997c7838 100644 --- a/kms/softkms/softkms.go +++ b/kms/softkms/softkms.go @@ -7,8 +7,10 @@ import ( "crypto/ed25519" "crypto/rsa" "crypto/x509" + "fmt" "github.com/pkg/errors" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" @@ -190,6 +192,31 @@ func (k *SoftKMS) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Dec } } +// LoadCertificate returns a x509.Certificate from the file passed in the +// request name. +func (k *SoftKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { + if req.Name == "" { + return nil, fmt.Errorf("loadCertificateRequest 'name' cannot be empty") + } + + bundle, err := pemutil.ReadCertificateBundle(filename(req.Name)) + if err != nil { + return nil, err + } + + return bundle[0], nil +} + +// LoadCertificateChain returns a slice of x509.Certificate from the file passed +// in the request name. +func (k *SoftKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { + if req.Name == "" { + return nil, fmt.Errorf("loadCertificateChainRequest 'name' cannot be empty") + } + + return pemutil.ReadCertificateBundle(filename(req.Name)) +} + func filename(s string) string { if u, err := uri.ParseWithScheme(Scheme, s); err == nil { if f := u.Get("path"); f != "" { diff --git a/kms/softkms/softkms_test.go b/kms/softkms/softkms_test.go index 11aa4ddd..6e5f8299 100644 --- a/kms/softkms/softkms_test.go +++ b/kms/softkms/softkms_test.go @@ -17,6 +17,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" @@ -439,3 +440,69 @@ func Test_filename(t *testing.T) { }) } } + +func TestSoftKMS_LoadCertificate(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/cert.crt") + require.NoError(t, err) + + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + k *SoftKMS + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "testdata/cert.crt"}}, cert, assert.NoError}, + {"ok uri", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "testdata/cert.crt"}}, cert, assert.NoError}, + {"ok uri with path", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "softkms:path=testdata/cert.crt"}}, cert, assert.NoError}, + {"fail empty", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{}}, nil, assert.Error}, + {"fail missing", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "testdata/missing.crt"}}, nil, assert.Error}, + {"fail not a certificate", &SoftKMS{}, args{&apiv1.LoadCertificateRequest{Name: "testdata/cert.key"}}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &SoftKMS{} + got, err := k.LoadCertificate(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestSoftKMS_LoadCertificateChain(t *testing.T) { + chain, err := pemutil.ReadCertificateBundle("testdata/chain.crt") + require.NoError(t, err) + + cert, err := pemutil.ReadCertificate("testdata/cert.crt") + require.NoError(t, err) + + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + k *SoftKMS + args args + want []*x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "testdata/chain.crt"}}, chain, assert.NoError}, + {"ok uri", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "testdata/chain.crt"}}, chain, assert.NoError}, + {"ok uri with path", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "softkms:path=testdata/chain.crt"}}, chain, assert.NoError}, + {"ok cert", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "softkms:testdata/cert.crt"}}, []*x509.Certificate{cert}, assert.NoError}, + {"fail empty", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{}}, nil, assert.Error}, + {"fail missing", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "testdata/missing.crt"}}, nil, assert.Error}, + {"fail not a certificate", &SoftKMS{}, args{&apiv1.LoadCertificateChainRequest{Name: "testdata/cert.key"}}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &SoftKMS{} + got, err := k.LoadCertificateChain(tt.args.req) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/softkms/testdata/chain.crt b/kms/softkms/testdata/chain.crt new file mode 100644 index 00000000..4289f3e3 --- /dev/null +++ b/kms/softkms/testdata/chain.crt @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIB0TCCAXegAwIBAgIQLnp4754BEk4JSoK7md7K4DAKBggqhkjOPQQDAjAkMSIw +IAYDVQQDExlTbWFsbHN0ZXAgSW50ZXJtZWRpYXRlIENBMB4XDTI2MDEzMDAxMzM0 +M1oXDTI2MDEzMTAxMzM0M1owHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29t +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIPeCsTFXBmkeJlX4dC6jo+2oe3US +m3Yt1LfJRrV2cBNB5t+OQqKZNajirkBCv/IqUBFJILtrNAZ8tSNwTtCTeKOBkTCB +jjAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMC +MB0GA1UdDgQWBBRv+K/kt+Ho6CDVpfvl3YRnbqxBjzAfBgNVHSMEGDAWgBQOlOB3 +dUUKfp7vXBdzBsMJREGiZTAdBgNVHREEFjAUghJ0ZXN0LnNtYWxsc3RlcC5jb20w +CgYIKoZIzj0EAwIDSAAwRQIhAKzVQCITpTCITRhQ/VWrUHqhYuD+Q/a7yUrM1gix +ZJV3AiAgRBrPNWYAIK24hjVWE21OPaJnSZ7Q7VRJNE0/vQzrZQ== +-----END CERTIFICATE----- diff --git a/kms/softkms/testdata/chain.key b/kms/softkms/testdata/chain.key new file mode 100644 index 00000000..b9b8e11c --- /dev/null +++ b/kms/softkms/testdata/chain.key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIKDTRUVuCrQVHjAlUjaDkmfEayOiIkLk1SZPU7MPMhyxoAoGCCqGSM49 +AwEHoUQDQgAEIPeCsTFXBmkeJlX4dC6jo+2oe3USm3Yt1LfJRrV2cBNB5t+OQqKZ +NajirkBCv/IqUBFJILtrNAZ8tSNwTtCTeA== +-----END EC PRIVATE KEY----- diff --git a/kms/sshagentkms/no_sshagentkms.go b/kms/sshagentkms/no_sshagentkms.go index adb94c66..cc3e29f2 100644 --- a/kms/sshagentkms/no_sshagentkms.go +++ b/kms/sshagentkms/no_sshagentkms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/sshagentkms/sshagentkms.go b/kms/sshagentkms/sshagentkms.go index 87645c37..acd11dea 100644 --- a/kms/sshagentkms/sshagentkms.go +++ b/kms/sshagentkms/sshagentkms.go @@ -19,6 +19,7 @@ import ( "golang.org/x/crypto/ssh/agent" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/sshutil" diff --git a/kms/sshagentkms/sshagentkms_test.go b/kms/sshagentkms/sshagentkms_test.go index 96621613..9d1656e7 100644 --- a/kms/sshagentkms/sshagentkms_test.go +++ b/kms/sshagentkms/sshagentkms_test.go @@ -19,11 +19,12 @@ import ( "strings" "testing" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" ) // Some helpers with inspiration from crypto/ssh/agent/client_test.go diff --git a/kms/tpmkms/no_tpmkms.go b/kms/tpmkms/no_tpmkms.go index a08640b0..f7eaa5e5 100644 --- a/kms/tpmkms/no_tpmkms.go +++ b/kms/tpmkms/no_tpmkms.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/pkg/errors" + "go.step.sm/crypto/kms/apiv1" ) diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index 370a2848..f766e4e0 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -29,6 +29,7 @@ import ( "go.step.sm/crypto/tpm/attestation" "go.step.sm/crypto/tpm/storage" "go.step.sm/crypto/tpm/tss2" + "go.step.sm/crypto/x509util" ) func init() { @@ -1059,8 +1060,14 @@ func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteC uv.Set("key-id", o.keyID) case o.sha1 != "": uv.Set("sha1", o.sha1) + case o.name != "": + keyID, err := k.getSubjectKeyID(req.Name) + if err != nil { + return fmt.Errorf("error getting key-id: %w", err) + } + uv.Set("key-id", hex.EncodeToString(keyID)) default: - return errors.New(`at least one of "serial", "key-id" or "sha1" is expected to be set`) + return errors.New(`at least one of "serial", "key-id", "sha1" or "name" is expected to be set`) } dk, ok := k.windowsCertificateManager.(deletingCertificateManager) @@ -1077,6 +1084,16 @@ func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteC return nil } +func (k *TPMKMS) getSubjectKeyID(name string) ([]byte, error) { + key, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: name, + }) + if err != nil { + return nil, err + } + return x509util.GenerateSubjectKeyID(key) +} + // attestationClient is a wrapper for [attestation.Client], containing // all of the required references to perform attestation against the // Smallstep Attestation CA. diff --git a/kms/tpmkms/tpmkms_test.go b/kms/tpmkms/tpmkms_test.go index 55c44e93..449bc79f 100644 --- a/kms/tpmkms/tpmkms_test.go +++ b/kms/tpmkms/tpmkms_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/tpm" "go.step.sm/crypto/tpm/tss2" diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 43ff2efe..8dafca4b 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/hex" "fmt" + "maps" + "math/big" "net/url" "os" "strconv" @@ -11,6 +13,7 @@ import ( "unicode" "github.com/pkg/errors" + "go.step.sm/crypto/internal/termutil" ) @@ -152,6 +155,38 @@ func (u *URI) GetInt(key string) *int64 { return nil } +// GetBigInt returns the first [*big.Int] value in the URI with the given key. +// It returns nil if the field is not present. It parses as a hexadecimal +// string if the value starts with 0x (0x12), 0X (0X1A), contains a colon +// (00:01), or contains only hex characters with at least one letter A-F +// (e.g. "0A01"); otherwise it parses as a base-10 number. +func (u *URI) GetBigInt(key string) (*big.Int, error) { + v := u.Get(key) + if v == "" { + return nil, nil //nolint:nilnil // return nil value + } + + if hx, ok, isHex := hexString(v); ok && isHex { + if hx == "" { + return nil, fmt.Errorf("value %q is not a valid hexadecimal number", v) + } + + b, err := hex.DecodeString(hx) + if err != nil { + return nil, err + } + + return new(big.Int).SetBytes(b), nil + } + + bi, ok := new(big.Int).SetString(v, 10) + if !ok { + return nil, fmt.Errorf("value %q is not a valid number", v) + } + + return bi, nil +} + // GetEncoded returns the first value in the uri with the given key, it will // return empty nil if that field is not present or is empty. If the return // value is hex encoded it will decode it and return it. @@ -160,8 +195,8 @@ func (u *URI) GetEncoded(key string) []byte { if v == "" { return nil } - if len(v)%2 == 0 { - if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil { + if hx, ok, _ := hexString(v); ok { + if b, err := hex.DecodeString(hx); err == nil { return b } } @@ -177,12 +212,17 @@ func (u *URI) GetHexEncoded(key string) ([]byte, error) { return nil, nil } - b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")) - if err != nil { - return nil, fmt.Errorf("failed decoding %q: %w", v, err) + hx, ok, _ := hexString(v) + if !ok || hx == "" { + return nil, fmt.Errorf("value %q is not a valid hexadecimal number", v) } - return b, nil + return hex.DecodeString(hx) +} + +// Set sets the key to value. It replaces any existing values. +func (u *URI) Set(key, value string) { + u.Values.Set(key, value) } // Pin returns the pin encoded in the url. It will read the pin from the @@ -218,6 +258,66 @@ func (u *URI) Read(key string) ([]byte, error) { return readFile(path) } +// Values returns the [url.Values] merging the values in the opaque and query +// string of the given [*URI]. +func Values(u *URI) url.Values { + uv := url.Values{} + maps.Copy(uv, u.Values) + for k, v := range u.URL.Query() { + if !uv.Has(k) { + uv[k] = v + continue + } + for _, s := range v { + uv.Add(k, s) + } + } + return uv +} + +// hexString returns a clean hexadecimal string, a boolean indicating if s is a +// valid hexadecimal string, and a boolean indicating if s was explicitly +// identifiable as hexadecimal. If s starts with 0x (0x12), 0X (0X1A), or +// contains colons (01:1A) it will remove them. The third boolean is true if s +// had a 0x/0X prefix, contained colons, or contained at least one letter A-F +// (010A). The string is prefixed with 0 if its length is odd. +func hexString(s string) (string, bool, bool) { + hx := strings.TrimPrefix(s, "0x") + hx = strings.TrimPrefix(hx, "0X") + hx = strings.ReplaceAll(hx, ":", "") + changed := (len(s) != len(hx)) + + if len(hx)%2 != 0 { + hx = "0" + hx + } + + valid, hasLetter := isValidHexString(hx) + if !valid { + return "", false, false + } + + return hx, valid, (changed || hasLetter) +} + +// isValidHexString returns two booleans, the first indicating s contains only +// hexadecimal characters and the second if at least one letter (a-f or A-F), +// indicating it should be treated as hex. +func isValidHexString(s string) (bool, bool) { + var hasLetter bool + for _, c := range s { + switch { + case c >= '0' && c <= '9': + case c >= 'a' && c <= 'f': + hasLetter = true + case c >= 'A' && c <= 'F': + hasLetter = true + default: + return false, false + } + } + return true, hasLetter +} + func readFile(path string) ([]byte, error) { u, err := url.Parse(path) if err == nil && (u.Scheme == "" || u.Scheme == "file") { diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index 03542303..9b63bdd6 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -2,6 +2,7 @@ package uri import ( "errors" + "math/big" "net/url" "os" "path/filepath" @@ -263,7 +264,7 @@ func TestURI_GetEncoded(t *testing.T) { {"ok in query percent", mustParse(t, "yubikey:slot-id=9a?foo=%9a"), args{"foo"}, []byte{0x9a}}, {"ok missing", mustParse(t, "yubikey:slot-id=9a"), args{"foo"}, nil}, {"ok missing query", mustParse(t, "yubikey:slot-id=9a?bar=zar"), args{"foo"}, nil}, - {"ok no hex", mustParse(t, "yubikey:slot-id=09a?bar=zar"), args{"slot-id"}, []byte{'0', '9', 'a'}}, + {"ok no hex", mustParse(t, "yubikey:slot-id=09z?bar=zar"), args{"slot-id"}, []byte{'0', '9', 'z'}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -369,6 +370,39 @@ func TestURI_GetInt(t *testing.T) { } } +func TestURI_GetBigInt(t *testing.T) { + type args struct { + key string + } + tests := []struct { + name string + uri *URI + args args + want *big.Int + assertion assert.ErrorAssertionFunc + }{ + {"ok empty", mustParse(t, "mackms:serial="), args{"serial"}, nil, assert.NoError}, + {"ok missing", mustParse(t, "mackms:label=123456"), args{"serial"}, nil, assert.NoError}, + {"ok number", mustParse(t, "mackms:serial=123456"), args{"serial"}, big.NewInt(123456), assert.NoError}, + {"ok hex with 0x", mustParse(t, "mackms:serial=0x123456"), args{"serial"}, big.NewInt(1193046), assert.NoError}, + {"ok hex with 0X", mustParse(t, "mackms:serial=0X123456"), args{"serial"}, big.NewInt(1193046), assert.NoError}, + {"ok hex with colon", mustParse(t, "mackms:serial=12:34:56"), args{"serial"}, big.NewInt(1193046), assert.NoError}, + {"ok hex odd length", mustParse(t, "mackms:serial=0x1"), args{"serial"}, big.NewInt(1), assert.NoError}, + {"fail hex empty", mustParse(t, "mackms:serial=0x"), args{"serial"}, nil, assert.Error}, + {"ok hex with letters", mustParse(t, "mackms:serial=0A01"), args{"serial"}, big.NewInt(0x0A01), assert.NoError}, + {"ok hex with letters only", mustParse(t, "mackms:serial=12345a"), args{"serial"}, big.NewInt(0x12345a), assert.NoError}, + {"fail hex", mustParse(t, "mackms:serial=0x12345g"), args{"serial"}, nil, assert.Error}, + {"fail number", mustParse(t, "mackms:serial=12345G"), args{"serial"}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.uri.GetBigInt(tt.args.key) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + func TestURI_GetHexEncoded(t *testing.T) { type args struct { key string @@ -384,7 +418,7 @@ func TestURI_GetHexEncoded(t *testing.T) { {"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, {"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, {"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil, false}, - {"fail odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil, true}, + {"ok odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, []byte{0x00, 0x9a}, false}, {"fail invalid hex", mustParse(t, "capi:sha1=9z?bar=zar"), args{"sha1"}, nil, true}, } for _, tt := range tests { @@ -401,6 +435,19 @@ func TestURI_GetHexEncoded(t *testing.T) { } } +func TestURI_Set(t *testing.T) { + u := mustParse(t, "kms:name=foo") + assert.Equal(t, "", u.Get("key")) + + u.Set("key", "bar") + assert.Equal(t, "bar", u.Get("key")) + assert.Equal(t, "kms:key=bar;name=foo", u.String()) + + u.Set("key", "zar") + assert.Equal(t, "zar", u.Get("key")) + assert.Equal(t, "kms:key=zar;name=foo", u.String()) +} + func TestURI_Read(t *testing.T) { // Read does not trim the contents of the file expected := []byte("trim-this-pin \n") @@ -444,3 +491,28 @@ func TestURI_Read(t *testing.T) { }) } } + +func TestValues(t *testing.T) { + tests := []struct { + name string + rawuri string + want url.Values + }{ + {"empty", "kms:", url.Values{}}, + {"with opaque values", "kms:foo=bar;baz=qux;foo=zar", url.Values{ + "foo": []string{"bar", "zar"}, "baz": []string{"qux"}, + }}, + {"with query values", "kms:name=value?foo=bar&baz=qux&foo=zar", url.Values{ + "name": []string{"value"}, "foo": []string{"bar", "zar"}, "baz": []string{"qux"}, + }}, + {"with mixed values", "kms:name=value;foo=bar?baz=qux&foo=zar", url.Values{ + "name": []string{"value"}, "foo": []string{"bar", "zar"}, "baz": []string{"qux"}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := mustParse(t, tt.rawuri) + assert.Equal(t, tt.want, Values(u)) + }) + } +} diff --git a/nssdb/keys.go b/nssdb/keys.go index d08133fe..bb5864a9 100644 --- a/nssdb/keys.go +++ b/nssdb/keys.go @@ -10,8 +10,9 @@ import ( "fmt" "math/big" - asn1utils "go.step.sm/crypto/internal/utils/asn1" "golang.org/x/crypto/cryptobyte" + + asn1utils "go.step.sm/crypto/internal/utils/asn1" ) // ASN.1 encoded OID for secp256r1 (1.2.840.10045.3.1.7), the only supported curve. diff --git a/nssdb/keys_test.go b/nssdb/keys_test.go index b507dabd..1c45d849 100644 --- a/nssdb/keys_test.go +++ b/nssdb/keys_test.go @@ -11,8 +11,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.step.sm/crypto/pemutil" "golang.org/x/crypto/cryptobyte" + + "go.step.sm/crypto/pemutil" ) func TestEcParams(t *testing.T) { diff --git a/pemutil/ssh.go b/pemutil/ssh.go index 00698dae..d8c56be7 100644 --- a/pemutil/ssh.go +++ b/pemutil/ssh.go @@ -17,9 +17,10 @@ import ( "math/big" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + bcryptpbkdf "go.step.sm/crypto/internal/bcrypt_pbkdf" "go.step.sm/crypto/randutil" - "golang.org/x/crypto/ssh" ) const ( diff --git a/sshutil/fingerprint_test.go b/sshutil/fingerprint_test.go index 1edf1630..f573bdf9 100644 --- a/sshutil/fingerprint_test.go +++ b/sshutil/fingerprint_test.go @@ -12,8 +12,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.step.sm/crypto/internal/emoji" "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/internal/emoji" ) func generateCertificate(t *testing.T) ssh.PublicKey { diff --git a/sshutil/sshutil_test.go b/sshutil/sshutil_test.go index dc40c03f..64350214 100644 --- a/sshutil/sshutil_test.go +++ b/sshutil/sshutil_test.go @@ -8,9 +8,10 @@ import ( "testing" "github.com/stretchr/testify/require" - "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" + + "go.step.sm/crypto/keyutil" ) type skKey struct { diff --git a/tpm/ak_test.go b/tpm/ak_test.go index 71370dae..ff862918 100644 --- a/tpm/ak_test.go +++ b/tpm/ak_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" "go.step.sm/crypto/x509util" diff --git a/tpm/caps.go b/tpm/caps.go index 9d6ca7da..539a9dc8 100644 --- a/tpm/caps.go +++ b/tpm/caps.go @@ -6,6 +6,7 @@ import ( "slices" "github.com/google/go-tpm/legacy/tpm2" + "go.step.sm/crypto/tpm/algorithm" ) diff --git a/tpm/info_test.go b/tpm/info_test.go index 6af4545b..2eac1c74 100644 --- a/tpm/info_test.go +++ b/tpm/info_test.go @@ -6,6 +6,7 @@ import ( "github.com/smallstep/go-attestation/attest" "github.com/stretchr/testify/require" + "go.step.sm/crypto/tpm/manufacturer" ) diff --git a/tpm/key_test.go b/tpm/key_test.go index 3139e3d6..9635eb1e 100644 --- a/tpm/key_test.go +++ b/tpm/key_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" "go.step.sm/crypto/x509util" diff --git a/tpm/rand/rand_simulator_test.go b/tpm/rand/rand_simulator_test.go index cf79becf..5286e0fa 100644 --- a/tpm/rand/rand_simulator_test.go +++ b/tpm/rand/rand_simulator_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/tpm" "go.step.sm/crypto/tpm/simulator" ) diff --git a/tpm/tss2/simulator_test.go b/tpm/tss2/simulator_test.go index 6028934b..b40ec920 100644 --- a/tpm/tss2/simulator_test.go +++ b/tpm/tss2/simulator_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "go.step.sm/crypto/tpm/simulator" ) diff --git a/x509util/certificate.go b/x509util/certificate.go index fa80b06f..a2456c6a 100644 --- a/x509util/certificate.go +++ b/x509util/certificate.go @@ -216,7 +216,7 @@ func CreateCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, } } if template.SubjectKeyId == nil { - if template.SubjectKeyId, err = generateSubjectKeyID(pub); err != nil { + if template.SubjectKeyId, err = GenerateSubjectKeyID(pub); err != nil { return nil, err } } diff --git a/x509util/certificate_test.go b/x509util/certificate_test.go index d49b8d48..a7df6960 100644 --- a/x509util/certificate_test.go +++ b/x509util/certificate_test.go @@ -81,7 +81,7 @@ func createIssuerCertificate(t *testing.T, commonName string) (*x509.Certificate if err != nil { t.Fatal(err) } - subjectKeyID, err := generateSubjectKeyID(pub) + subjectKeyID, err := GenerateSubjectKeyID(pub) if err != nil { t.Fatal(err) } @@ -814,7 +814,7 @@ func TestCreateCertificate(t *testing.T) { return sn } mustSubjectKeyID := func(pub crypto.PublicKey) []byte { - b, err := generateSubjectKeyID(pub) + b, err := GenerateSubjectKeyID(pub) if err != nil { t.Fatal(err) } diff --git a/x509util/extensions.go b/x509util/extensions.go index a89a7238..95f37c88 100644 --- a/x509util/extensions.go +++ b/x509util/extensions.go @@ -16,6 +16,7 @@ import ( "time" "github.com/pkg/errors" + asn1utils "go.step.sm/crypto/internal/utils/asn1" ) diff --git a/x509util/utils.go b/x509util/utils.go index 69ddbebf..825f8d7a 100644 --- a/x509util/utils.go +++ b/x509util/utils.go @@ -110,7 +110,7 @@ type subjectPublicKeyInfo struct { SubjectPublicKey asn1.BitString } -// generateSubjectKeyID generates the key identifier according the the RFC 5280 +// GenerateSubjectKeyID generates the key identifier according the the RFC 5280 // section 4.2.1.2. // // The keyIdentifier is composed of the 160-bit SHA-1 hash of the value of the @@ -119,7 +119,7 @@ type subjectPublicKeyInfo struct { // // If FIPS 140-3 mode is enabled, instead of SHA-1, it will use the leftmost // 160-bits of the SHA-256 hash according to RFC 7093 section 2. -func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { +func GenerateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { b, err := x509.MarshalPKIXPublicKey(pub) if err != nil { return nil, errors.Wrap(err, "error marshaling public key") diff --git a/x509util/utils_test.go b/x509util/utils_test.go index a5a1c1bd..ca8a5ca3 100644 --- a/x509util/utils_test.go +++ b/x509util/utils_test.go @@ -172,7 +172,7 @@ func Test_generateSubjectKeyID(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := generateSubjectKeyID(tt.args.pub) + got, err := GenerateSubjectKeyID(tt.args.pub) if (err != nil) != tt.wantErr { t.Errorf("generateSubjectKeyID() error = %v, wantErr %v", err, tt.wantErr) return @@ -204,7 +204,7 @@ func Test_generateSubjectKeyID_fips(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := generateSubjectKeyID(tt.args.pub) + got, err := GenerateSubjectKeyID(tt.args.pub) tt.assertion(t, err) assert.Equal(t, tt.want, got) })