diff --git a/authcode.go b/authcode.go index 83456b8..8ef47c5 100644 --- a/authcode.go +++ b/authcode.go @@ -176,7 +176,7 @@ func (f FederationLeaf) GetAuthorizationURL( if f.RequestURIGenerator != nil && opMetadata.RequestURIParameterSupported && len(resolved.TrustChain) > 0 { ownResolved, err := DefaultMetadataResolver.ResolveResponsePayload( apimodel.ResolveRequest{ - Subject: f.EntityID, + Subject: f.EntityID(), TrustAnchor: []string{resolved.TrustAnchor}, EntityTypes: []string{oidfedconst.EntityTypeOpenIDRelyingParty}, }, @@ -209,7 +209,7 @@ func (f FederationLeaf) GetAuthorizationURL( } else { q.Set("request", string(requestObject)) } - q.Set("client_id", f.EntityID) + q.Set("client_id", f.EntityID()) q.Set("response_type", "code") q.Set("redirect_uri", redirectURI) q.Set("scope", scope) @@ -233,7 +233,7 @@ func (f FederationLeaf) CodeExchange( params.Set("grant_type", "authorization_code") params.Set("code", code) params.Set("redirect_uri", redirectURI) - params.Set("client_id", f.EntityID) + params.Set("client_id", f.EntityID()) clientAssertion, err := f.oidcROProducer.ClientAssertion( opMetadata.TokenEndpoint, diff --git a/cache/cache.go b/cache/cache.go index 4c1bdff..28fd4ac 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -2,7 +2,6 @@ package cache import ( "encoding/base64" - "log" "strings" "time" @@ -31,7 +30,7 @@ type cacheWrapper struct { func newCacheWrapper(defaultExpiration time.Duration) cacheWrapper { c := gocache.NewCache().WithDefaultTTL(defaultExpiration) if err := c.StartJanitor(); err != nil { - log.Fatal(err) // skipcq: RVV-A0003 + internal.WithError(err).Error("Cache: failed to start janitor; proceeding without background cleanup") } return cacheWrapper{ c, diff --git a/explicit.go b/explicit.go index 45ae082..dc45681 100644 --- a/explicit.go +++ b/explicit.go @@ -128,14 +128,17 @@ func (f FederationLeaf) DoExplicitClientRegistration(op string) ( if opMetadata == nil || opMetadata.FederationRegistrationEndpoint == "" { return nil, nil, errors.New("op does not have a federation registration endpoint") } - entityConfigurationData := f.EntityConfigurationPayload() + entityConfigurationData, err := f.EntityConfigurationPayload() + if err != nil { + return nil, nil, errors.Wrap(err, "could not get entity configuration payload") + } AdjustRPMetadataToOP(entityConfigurationData.Metadata.RelyingParty, opMetadata) entityConfigurationData.Audience = op var headers jws.Headers if len(resolved.TrustChain) > 0 { ownResolved, err := DefaultMetadataResolver.ResolveResponsePayload( apimodel.ResolveRequest{ - Subject: f.EntityID, + Subject: f.EntityID(), TrustAnchor: []string{resolved.TrustAnchor}, EntityTypes: []string{oidfedconst.EntityTypeOpenIDRelyingParty}, }, @@ -147,7 +150,7 @@ func (f FederationLeaf) DoExplicitClientRegistration(op string) ( _ = headers.Set("peer_trust_chain", resolved.TrustChain) } } - entityConfiguration, err := f.EntityStatementSigner.JWTWithHeaders(entityConfigurationData, headers) + entityConfiguration, err := f.SignEntityStatementWithHeaders(*entityConfigurationData, headers) if err != nil { return nil, nil, err } @@ -180,7 +183,7 @@ func (f FederationLeaf) DoExplicitClientRegistration(op string) ( if err != nil { return nil, nil, errors.Wrap(err, "could not parse explicit registration response") } - if res.Audience != f.EntityID { + if res.Audience != f.EntityID() { return nil, nil, errors.New("explicit client registration: OP returned unexpected audience") } diff --git a/federation.go b/federation.go index b9db5ed..574d5b9 100644 --- a/federation.go +++ b/federation.go @@ -3,6 +3,7 @@ package oidfed import ( "time" + "github.com/lestrrat-go/jwx/v3/jws" "github.com/pkg/errors" "github.com/go-oidfed/lib/apimodel" @@ -12,20 +13,18 @@ import ( "github.com/go-oidfed/lib/unixtime" ) -// FederationEntity is a type for an entity participating in federations. -// It holds all relevant information about the federation entity and can be used to create -// an EntityConfiguration about it -type FederationEntity struct { - EntityID string - Metadata *Metadata - MetadataUpdater func(*Metadata) - AuthorityHints []string - ConfigurationLifetime time.Duration - *jwx.EntityStatementSigner - TrustMarks []*EntityConfigurationTrustMarkConfig - TrustMarkIssuers AllowedTrustMarkIssuers - TrustMarkOwners TrustMarkOwners - Extra map[string]any +// FederationEntity defines the common behavior for federation entities, +// implemented by both StaticFederationEntity and DynamicFederationEntity. +type FederationEntity interface { + EntityID() string + // EntityConfigurationPayload returns the payload for the entity configuration + EntityConfigurationPayload() (*EntityStatementPayload, error) + // EntityConfigurationJWT returns the signed entity configuration as a JWT + EntityConfigurationJWT() ([]byte, error) + // SignEntityStatement signs the provided entity configuration payload + SignEntityStatement(payload EntityStatementPayload) ([]byte, error) + // SignEntityStatementWithHeaders signs the provided entity configuration payload and adds the passed jws.Headers + SignEntityStatementWithHeaders(payload EntityStatementPayload, headers jws.Headers) ([]byte, error) } // RequestURIGenerator is a function that takes a request object and returns a request_uri at which the passed @@ -41,16 +40,183 @@ type FederationLeaf struct { RequestURIGenerator RequestURIGenerator } -// NewFederationEntity creates a new FederationEntity with the passed properties +// DynamicFederationEntity mirrors FederationEntity but exposes all properties +// (except EntityID) as functions of time, enabling time-dependent values. +type DynamicFederationEntity struct { + ID string + Metadata func() (*Metadata, error) + AuthorityHints func() ([]string, error) + ConfigurationLifetime func() (time.Duration, error) + EntityStatementSigner func() (*jwx.EntityStatementSigner, error) + TrustMarks func() ([]*EntityConfigurationTrustMarkConfig, error) + TrustMarkIssuers func() (AllowedTrustMarkIssuers, error) + TrustMarkOwners func() (TrustMarkOwners, error) + Extra func() (map[string]any, []string, error) +} + +// EntityID returns the entity ID of the DynamicFederationEntity +func (f DynamicFederationEntity) EntityID() string { + return f.ID +} + +// EntityConfigurationPayload returns an EntityStatementPayload for this DynamicFederationEntity +// resolving all dynamic properties at time.Now(). +func (f DynamicFederationEntity) EntityConfigurationPayload() (*EntityStatementPayload, error) { + now := time.Now() + + var err error + // Resolve dynamic fields + metadata := (*Metadata)(nil) + if f.Metadata != nil { + metadata, err = f.Metadata() + if err != nil { + return nil, err + } + } + + var authorityHints []string + if f.AuthorityHints != nil { + authorityHints, err = f.AuthorityHints() + if err != nil { + return nil, err + } + } + + lifetime := time.Duration(0) + if f.ConfigurationLifetime != nil { + lifetime, err = f.ConfigurationLifetime() + if err != nil { + return nil, err + } + } + if lifetime <= 0 { + lifetime = defaultEntityConfigurationLifetime + } + + signer := (*jwx.EntityStatementSigner)(nil) + if f.EntityStatementSigner != nil { + signer, err = f.EntityStatementSigner() + if err != nil { + return nil, err + } + } + + var tms []TrustMarkInfo + if f.TrustMarks != nil { + trustMarkConfigs, err := f.TrustMarks() + if err != nil { + return nil, err + } + tms = make([]TrustMarkInfo, 0, len(trustMarkConfigs)) + for _, tmc := range trustMarkConfigs { + tm, err := tmc.TrustMarkJWT() + if err != nil { + internal.Log(err.Error()) + continue + } + tms = append( + tms, TrustMarkInfo{ + TrustMarkType: tmc.TrustMarkType, + TrustMarkJWT: tm, + }, + ) + } + } + + var trustMarkIssuers AllowedTrustMarkIssuers + if f.TrustMarkIssuers != nil { + trustMarkIssuers, err = f.TrustMarkIssuers() + if err != nil { + return nil, err + } + } + + var trustMarkOwners TrustMarkOwners + if f.TrustMarkOwners != nil { + trustMarkOwners, err = f.TrustMarkOwners() + if err != nil { + return nil, err + } + } + + var extra map[string]any + var crits []string + if f.Extra != nil { + extra, crits, err = f.Extra() + if err != nil { + return nil, err + } + } + + if metadata != nil { + metadata.ApplyInformationalClaimsToFederationEntity() + } + + var jwks jwx.JWKS + if signer != nil { + jwks, err = signer.JWKS() + if err != nil { + return nil, err + } + } + + return &EntityStatementPayload{ + Issuer: f.ID, + Subject: f.ID, + IssuedAt: unixtime.Unixtime{Time: now}, + ExpiresAt: unixtime.Unixtime{Time: now.Add(lifetime)}, + JWKS: jwks, + AuthorityHints: authorityHints, + Metadata: metadata, + TrustMarks: tms, + TrustMarkIssuers: trustMarkIssuers, + TrustMarkOwners: trustMarkOwners, + CriticalExtensions: crits, + Extra: extra, + }, nil +} + +// EntityConfigurationJWT creates and returns the signed jwt for the dynamic entity configuration +func (f DynamicFederationEntity) EntityConfigurationJWT() ([]byte, error) { + payload, err := f.EntityConfigurationPayload() + if err != nil { + return nil, err + } + return f.SignEntityStatement(*payload) +} + +// SignEntityStatement creates a signed JWT for the given EntityStatementPayload +func (f DynamicFederationEntity) SignEntityStatement(payload EntityStatementPayload) ([]byte, error) { + return f.SignEntityStatementWithHeaders(payload, nil) +} + +// SignEntityStatementWithHeaders creates a signed JWT for the given EntityStatementPayload and jws.Headers +func (f DynamicFederationEntity) SignEntityStatementWithHeaders( + payload EntityStatementPayload, headers jws.Headers, +) ([]byte, error) { + if f.EntityStatementSigner == nil { + return nil, errors.New("no signer function configured") + } + signer, err := f.EntityStatementSigner() + if signer == nil { + return nil, errors.New("no signer available at current time") + } + if err != nil { + return nil, err + } + return signer.JWTWithHeaders(payload, headers) +} + +// NewFederationEntity creates a new StaticFederationEntity with the passed properties func NewFederationEntity( entityID string, authorityHints []string, metadata *Metadata, signer *jwx.EntityStatementSigner, configurationLifetime time.Duration, extra map[string]any, -) (*FederationEntity, error) { +) (*StaticFederationEntity, error) { if configurationLifetime <= 0 { configurationLifetime = defaultEntityConfigurationLifetime } - return &FederationEntity{ - EntityID: entityID, + return &StaticFederationEntity{ + ID: entityID, Metadata: metadata, AuthorityHints: authorityHints, EntityStatementSigner: signer, @@ -78,54 +244,78 @@ func NewFederationLeaf( }, nil } -// EntityConfigurationPayload returns an EntityStatementPayload for this FederationEntity -func (f FederationEntity) EntityConfigurationPayload() *EntityStatementPayload { - now := time.Now() - var tms []TrustMarkInfo - for _, tmc := range f.TrustMarks { - tm, err := tmc.TrustMarkJWT() - if err != nil { - internal.Log(err.Error()) - continue - } - tms = append( - tms, TrustMarkInfo{ - TrustMarkType: tmc.TrustMarkType, - TrustMarkJWT: tm, - }, - ) - } - if f.MetadataUpdater != nil { - f.MetadataUpdater(f.Metadata) - } - f.Metadata.ApplyInformationalClaimsToFederationEntity() - return &EntityStatementPayload{ - Issuer: f.EntityID, - Subject: f.EntityID, - IssuedAt: unixtime.Unixtime{Time: now}, - ExpiresAt: unixtime.Unixtime{Time: now.Add(f.ConfigurationLifetime)}, - JWKS: f.EntityStatementSigner.JWKS(), - AuthorityHints: f.AuthorityHints, - Metadata: f.Metadata, - TrustMarks: tms, - TrustMarkIssuers: f.TrustMarkIssuers, - TrustMarkOwners: f.TrustMarkOwners, - Extra: f.Extra, - } +// StaticFederationEntity is a type for an entity participating in federations. +// It holds all relevant information about the federation entity and can be used to create +// an EntityConfiguration about it +type StaticFederationEntity struct { + ID string + Metadata *Metadata + AuthorityHints []string + ConfigurationLifetime time.Duration + *jwx.EntityStatementSigner + TrustMarks []*EntityConfigurationTrustMarkConfig + TrustMarkIssuers AllowedTrustMarkIssuers + TrustMarkOwners TrustMarkOwners + Extra map[string]any + CriticalClaims []string +} + +// EntityID returns the entity ID of the StaticFederationEntity +func (f StaticFederationEntity) EntityID() string { + return f.ID +} + +// EntityConfigurationPayload returns an EntityStatementPayload for this +// StaticFederationEntity +func (f StaticFederationEntity) EntityConfigurationPayload() (*EntityStatementPayload, error) { + return DynamicFederationEntity{ + ID: f.ID, + Metadata: func() (*Metadata, error) { + return f.Metadata, nil + }, + AuthorityHints: func() ([]string, error) { + return f.AuthorityHints, nil + }, + ConfigurationLifetime: func() (time.Duration, error) { return f.ConfigurationLifetime, nil }, + EntityStatementSigner: func() (*jwx.EntityStatementSigner, error) { + return f.EntityStatementSigner, nil + }, + TrustMarks: func() ([]*EntityConfigurationTrustMarkConfig, error) { + return f.TrustMarks, nil + }, + TrustMarkIssuers: func() (AllowedTrustMarkIssuers, error) { return f.TrustMarkIssuers, nil }, + TrustMarkOwners: func() (TrustMarkOwners, error) { + return f.TrustMarkOwners, nil + }, + Extra: func() (map[string]any, []string, error) { return f.Extra, f.CriticalClaims, nil }, + }.EntityConfigurationPayload() } // EntityConfigurationJWT creates and returns the signed jwt as a []byte for // the entity's entity configuration -func (f FederationEntity) EntityConfigurationJWT() ([]byte, error) { - return f.EntityStatementSigner.JWT(f.EntityConfigurationPayload()) +func (f StaticFederationEntity) EntityConfigurationJWT() ([]byte, error) { + payload, err := f.EntityConfigurationPayload() + if err != nil { + return nil, err + } + return f.SignEntityStatement(*payload) } // SignEntityStatement creates a signed JWT for the given EntityStatementPayload; this function is intended to be // used on TA/IA -func (f FederationEntity) SignEntityStatement(payload EntityStatementPayload) ([]byte, error) { +func (f StaticFederationEntity) SignEntityStatement(payload EntityStatementPayload) ([]byte, error) { return f.EntityStatementSigner.JWT(payload) } +// SignEntityStatementWithHeaders creates a signed JWT for the given +// EntityStatementPayload; this function is intended to be +// used on TA/IA +func (f StaticFederationEntity) SignEntityStatementWithHeaders( + payload EntityStatementPayload, headers jws.Headers, +) ([]byte, error) { + return f.EntityStatementSigner.JWTWithHeaders(payload, headers) +} + // RequestObjectProducer returns the entity's RequestObjectProducer func (f FederationLeaf) RequestObjectProducer() *RequestObjectProducer { return f.oidcROProducer diff --git a/go.mod b/go.mod index 0aee90f..8d413b0 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/go-oidfed/lib go 1.25.1 require ( + github.com/ThalesGroup/crypto11 v1.2.6 github.com/TwiN/gocache/v2 v2.4.0 github.com/adam-hanna/arrayOperations v1.0.1 github.com/coreos/go-oidc/v3 v3.17.0 @@ -49,10 +50,12 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/miekg/pkcs11 v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/segmentio/asm v1.2.1 // indirect + github.com/thales-e-security/pool v0.0.2 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/fastjson v1.6.4 // indirect diff --git a/go.sum b/go.sum index 29b2261..c0bd745 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/ThalesGroup/crypto11 v1.2.6 h1:KixeJpVw3Y9gLSsz393XHh/Pez7q+KBXit4TQebmOz4= +github.com/ThalesGroup/crypto11 v1.2.6/go.mod h1:Grol7G+6zQdI94hGq+j702L1QFHSlJA5lBLl8uWAhG0= github.com/TwiN/gocache/v2 v2.4.0 h1:BZ/TqvhipDQE23MFFTjC0MiI1qZ7GEVtSdOFVVXyr18= github.com/TwiN/gocache/v2 v2.4.0/go.mod h1:Cl1c0qNlQlXzJhTpAARVqpQDSuGDM5RhtzPYAM1x17g= github.com/adam-hanna/arrayOperations v1.0.1 h1:iAot3I2p4yKrFk8eRhEkuHj0ttOrfFJMWAo7Is/rHwk= @@ -79,6 +81,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/maxatome/go-testdeep v1.14.0 h1:rRlLv1+kI8eOI3OaBXZwb3O7xY3exRzdW5QyX48g9wI= github.com/maxatome/go-testdeep v1.14.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= +github.com/miekg/pkcs11 v1.1.1 h1:Ugu9pdy6vAYku5DEpVWVFPYnzV+bxB+iRdbuFSu7TvU= +github.com/miekg/pkcs11 v1.1.1/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -103,6 +107,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gtvVDbmPg= +github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= diff --git a/internal/logging.go b/internal/logging.go index 915b38f..d4b0a3c 100644 --- a/internal/logging.go +++ b/internal/logging.go @@ -107,3 +107,10 @@ func WithField(key string, value any) *logrus.Entry { // skipcq: RVV-B0001 func WithFields(fields logrus.Fields) *logrus.Entry { // skipcq: RVV-B0001 return logger.WithFields(fields) } + +// Fields is an alias exported to allow callers to construct structured fields +// without importing logrus directly when using the internal logger helpers. +// This keeps compatibility with existing call sites like +// `log.WithFields(log.Fields{...})` even after switching imports to this +// package. +type Fields = logrus.Fields diff --git a/jwx/fileio.go b/jwx/fileio.go index 0cb0c28..1bf6676 100644 --- a/jwx/fileio.go +++ b/jwx/fileio.go @@ -12,8 +12,8 @@ import ( "github.com/zachmann/go-utils/fileutils" ) -// readSignerFromFile loads the private key from the passed keyfile -func readSignerFromFile(keyfile string, alg jwa.SignatureAlgorithm) (crypto.Signer, error) { +// ReadSignerFromFile loads the private key from the passed keyfile +func ReadSignerFromFile(keyfile string, alg jwa.SignatureAlgorithm) (crypto.Signer, error) { keyFileContent, err := fileutils.ReadFile(keyfile) if err != nil { return nil, err @@ -44,7 +44,7 @@ func readSignerFromFile(keyfile string, alg jwa.SignatureAlgorithm) (crypto.Sign return sk, nil } -func writeSignerToFile(sk crypto.Signer, filePath string) error { +func WriteSignerToFile(sk crypto.Signer, filePath string) error { pemData := exportPrivateKeyAsPem(sk) err := errors.WithStack(os.WriteFile(filePath, pemData, 0600)) return err diff --git a/jwx/jwk.go b/jwx/jwk.go index 7a806dd..e392fc6 100644 --- a/jwx/jwk.go +++ b/jwx/jwk.go @@ -9,13 +9,10 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" - "time" "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/pkg/errors" - - "github.com/go-oidfed/lib/unixtime" ) // generatePrivateKey generates a cryptographic private key with the passed properties @@ -94,58 +91,35 @@ func exportEDDSAPrivateKeyAsPem(privkey ed25519.PrivateKey) []byte { return privkeyPem } -type keyLifetimeConf struct { - NowIssued bool - Expires bool - Lifetime time.Duration - Nbf *unixtime.Unixtime -} - -func signerToPublicJWK(sk crypto.Signer, alg jwa.SignatureAlgorithm, lifetimeConf keyLifetimeConf) (jwk.Key, error) { - pk, err := jwk.PublicKeyOf(sk.Public()) +func SignerToPublicJWK(sk crypto.Signer, alg jwa.SignatureAlgorithm) ( + pk jwk.Key, kid string, err error, +) { + pk, err = jwk.PublicKeyOf(sk.Public()) if err != nil { - return nil, err - } - if err = jwk.AssignKeyID(pk); err != nil { - return nil, errors.WithStack(err) - } - if err = pk.Set(jwk.KeyUsageKey, jwk.ForSignature); err != nil { - return nil, errors.WithStack(err) + return } - if err = pk.Set(jwk.AlgorithmKey, alg); err != nil { - return nil, errors.WithStack(err) + if err = errors.WithStack(jwk.AssignKeyID(pk)); err != nil { + return } - now := unixtime.Now() - if lifetimeConf.NowIssued { - if err = pk.Set("iat", now); err != nil { - return nil, errors.WithStack(err) - } + kid, _ = pk.KeyID() + if err = errors.WithStack(pk.Set(jwk.KeyUsageKey, jwk.ForSignature)); err != nil { + return } - if lifetimeConf.Expires { - exp := unixtime.Unixtime{Time: now.Add(lifetimeConf.Lifetime)} - if lifetimeConf.Nbf != nil && lifetimeConf.Nbf.After(now.Time) { - if err = errors.WithStack(pk.Set("nbf", lifetimeConf.Nbf)); err != nil { - return nil, err - } - exp = unixtime.Unixtime{Time: lifetimeConf.Nbf.Add(lifetimeConf.Lifetime)} - } - if err = errors.WithStack(pk.Set("exp", exp)); err != nil { - return nil, err - } + if err = errors.WithStack(pk.Set(jwk.AlgorithmKey, alg)); err != nil { + return } - return pk, nil + return } -// generatePrivateKey generates a cryptographic private key with the passed +// GenerateKeyPair generates a cryptographic private key with the passed // properties and returns the corresponding public key as a jwk.Key -func generateKeyPair(alg jwa.SignatureAlgorithm, rsaKeyLen int, lifetimeConf keyLifetimeConf) ( - sk crypto.Signer, pk jwk.Key, err error, +func GenerateKeyPair(alg jwa.SignatureAlgorithm, rsaKeyLen int) ( + sk crypto.Signer, pk jwk.Key, kid string, err error, ) { sk, err = generatePrivateKey(alg, rsaKeyLen) if err != nil { return } - lifetimeConf.NowIssued = true - pk, err = signerToPublicJWK(sk, alg, lifetimeConf) + pk, kid, err = SignerToPublicJWK(sk, alg) return } diff --git a/jwx/jwtsigning.go b/jwx/jwtsigning.go index 9a3c445..7f553d0 100644 --- a/jwx/jwtsigning.go +++ b/jwx/jwtsigning.go @@ -28,14 +28,14 @@ type VersatileSigner interface { // DefaultSigner returns a crypto.Signer and the corresponding jwa.SignatureAlgorithm DefaultSigner() (crypto.Signer, jwa.SignatureAlgorithm) // JWKS returns the jwks.JWKS containing all public keys of this VersatileSigner - JWKS() JWKS + JWKS() (JWKS, error) } // JWTSigner is an interface that can give signed jwts type JWTSigner interface { - JWT(i any) (jwt []byte, err error) - JWTWithHeaders(i any, headers jws.Headers) (jwt []byte, err error) - JWKS() JWKS + JWT(i any, alg ...jwa.SignatureAlgorithm) (jwt []byte, err error) + JWTWithHeaders(i any, headers jws.Headers, alg ...jwa.SignatureAlgorithm) (jwt []byte, err error) + JWKS() (JWKS, error) } // GeneralJWTSigner is a general jwt signer with no specific typ @@ -85,7 +85,7 @@ func (s GeneralJWTSigner) JWTWithHeaders(i any, headers jws.Headers, headerType } // JWKS returns the jwks.JWKS used with this signer -func (s *GeneralJWTSigner) JWKS() JWKS { +func (s *GeneralJWTSigner) JWKS() (JWKS, error) { return s.signer.JWKS() } @@ -244,18 +244,24 @@ func SignPayload(payload []byte, signingAlg jwa.SignatureAlgorithm, key crypto.S []byte, error, ) { - k, err := jwk.Import(key) - if err != nil { - return nil, err - } - if err = jwk.AssignKeyID(k); err != nil { - return nil, err + type signerWithKID interface{ KID() string } + var keyID string + if s, ok := key.(signerWithKID); ok { + keyID = s.KID() + } else { + k, err := jwk.PublicKeyOf(key.Public()) + if err != nil { + return nil, err + } + if err = jwk.AssignKeyID(k); err != nil { + return nil, err + } + keyID, _ = k.KeyID() } if headers == nil { headers = jws.NewHeaders() } - keyID, _ := k.KeyID() - if err = headers.Set(jws.KeyIDKey, keyID); err != nil { + if err := headers.Set(jws.KeyIDKey, keyID); err != nil { return nil, err } return jws.Sign(payload, jws.WithKey(signingAlg, key, jws.WithProtectedHeaders(headers))) diff --git a/jwx/keymanagement/kms/file.go b/jwx/keymanagement/kms/file.go new file mode 100644 index 0000000..c8683b1 --- /dev/null +++ b/jwx/keymanagement/kms/file.go @@ -0,0 +1,54 @@ +package kms + +import ( + "crypto" + + "github.com/lestrrat-go/jwx/v3/jwa" + + "github.com/go-oidfed/lib/jwx" +) + +// SingleSigningKeyFile implements BasicKeyManagementSystem for a single +// configured signing key stored in a PEM file on disk. It exposes a signer +// and its algorithm without any rotation behavior. +type SingleSigningKeyFile struct { + // Alg is the single configured signature algorithm for this signer. + Alg jwa.SignatureAlgorithm + // Path is the filesystem path to the PEM-encoded private key. + Path string + + signer crypto.Signer +} + +// Load loads the signer from the configured file path. +func (s *SingleSigningKeyFile) Load() error { + signer, err := jwx.ReadSignerFromFile(s.Path, s.Alg) + if err != nil { + return err + } + s.signer = signer + return nil +} + +// GetDefault returns the configured signer and its algorithm. +func (s *SingleSigningKeyFile) GetDefault() (crypto.Signer, jwa.SignatureAlgorithm) { + if s.signer == nil { + return nil, jwa.SignatureAlgorithm{} + } + return s.signer, s.Alg +} + +// GetForAlgs returns the signer if the requested algorithms include the +// configured algorithm; otherwise nil. +func (s *SingleSigningKeyFile) GetForAlgs(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { + if s.signer == nil { + return nil, jwa.SignatureAlgorithm{} + } + want := s.Alg.String() + for _, a := range algs { + if a == want { + return s.signer, s.Alg + } + } + return nil, jwa.SignatureAlgorithm{} +} diff --git a/jwx/keymanagement/kms/filesystem.go b/jwx/keymanagement/kms/filesystem.go new file mode 100644 index 0000000..27d1214 --- /dev/null +++ b/jwx/keymanagement/kms/filesystem.go @@ -0,0 +1,592 @@ +package kms + +import ( + "cmp" + "crypto" + "fmt" + "slices" + "sync" + "time" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/pkg/errors" + + log "github.com/go-oidfed/lib/internal" + + "github.com/go-oidfed/lib/jwx" + "github.com/go-oidfed/lib/jwx/keymanagement/public" + "github.com/go-oidfed/lib/unixtime" +) + +// NewSingleAlgFilesystemKMS constructs a FilesystemKMS configured for a single +// signature algorithm, sharing the given PublicKeyStorage. +func NewSingleAlgFilesystemKMS( + alg jwa.SignatureAlgorithm, + config FilesystemKMSConfig, pks public.PublicKeyStorage, +) KeyManagementSystem { + config.Algs = []jwa.SignatureAlgorithm{alg} + config.DefaultAlg = alg + return &FilesystemKMS{ + FilesystemKMSConfig: config, + PKs: pks, + } +} + +// NewFilesystemKMSAndPublicKeyStorage creates a new FilesystemKMS and PublicKeyStorage +// backed by the same directory. +func NewFilesystemKMSAndPublicKeyStorage(config FilesystemKMSConfig) (KeyManagementSystem, error) { + pks := &public.FilesystemPublicKeyStorage{ + Dir: config.Dir, + TypeID: config.TypeID, + } + if err := pks.Load(); err != nil { + return nil, err + } + return &FilesystemKMS{ + FilesystemKMSConfig: config, + PKs: pks, + }, nil +} + +// FilesystemKMSConfig is the configuration for a FilesystemKMS. +type FilesystemKMSConfig struct { + KMSConfig + Dir string + TypeID string +} + +// FilesystemKMS implements KeyManagementSystem using PEM files for private keys +// on disk and a PublicKeyStorage for public key metadata. +type FilesystemKMS struct { + FilesystemKMSConfig + + // signers is a map of all loaded signers, where the key is the kid + signers map[string]crypto.Signer + + PKs public.PublicKeyStorage + + // automatic rotation control + rotationStop chan struct{} + rotationWG sync.WaitGroup +} + +func (kms *FilesystemKMS) keyFilePath(kid string) string { + return fmt.Sprintf("%s/%s.pem", kms.Dir, kid) +} + +// GetDefault returns a crypto.Signer and the corresponding jwa.SignatureAlgorithm +func (kms *FilesystemKMS) GetDefault() (crypto.Signer, jwa.SignatureAlgorithm) { + if len(kms.Algs) == 0 { + return nil, jwa.SignatureAlgorithm{} + } + var algs []string + if kms.DefaultAlg.String() != "" { + algs = []string{kms.DefaultAlg.String()} + } + for _, a := range kms.Algs { + algs = append(algs, a.String()) + } + return kms.GetForAlgs(algs...) +} + +// GetForAlgs takes a list of acceptable signature algorithms and returns a +// usable crypto.Signer or nil as well as the corresponding +// jwa.SignatureAlgorithm +func (kms *FilesystemKMS) GetForAlgs(algs ...string) ( + crypto.Signer, + jwa.SignatureAlgorithm, +) { + activePKs, err := kms.PKs.GetActive() + if err != nil { + log.WithError(err).Error("FilesystemKMS: failed to get active public keys") + return nil, jwa.SignatureAlgorithm{} + } + pksByAlg := activePKs.ByAlg() + for _, alg := range kms.Algs { + if !slices.Contains(algs, alg.String()) { + continue + } + algPKs, ok := pksByAlg[alg] + if !ok || len(algPKs) == 0 { + continue + } + pk := algPKs[0] + if len(algPKs) > 1 { + maxExp := unixtime.Now() + maxExpWithNbf := maxExp + noExpIndex := -1 + maxExpIndex := -1 + maxExpWithNbfIndex := -1 + nbfTreshold := time.Now().Add(-kms.KeyRotation.Overlap.Duration() / 2) + for i, it := range algPKs { + if it.ExpiresAt == nil { + noExpIndex = i + continue + } + if it.NotBefore != nil && it.NotBefore.Before(nbfTreshold) && it.ExpiresAt.After( + maxExpWithNbf.Time, + ) { + maxExpWithNbf = *it.ExpiresAt + maxExpWithNbfIndex = i + + } else if maxExpIndex == -1 && it.ExpiresAt.After(maxExp.Time) { + maxExp = *it.ExpiresAt + maxExpIndex = i + } + } + if maxExpWithNbfIndex != -1 { + pk = algPKs[maxExpWithNbfIndex] + } else if maxExpIndex != -1 { + pk = algPKs[maxExpIndex] + } else { + pk = algPKs[noExpIndex] + } + } + signer, ok := kms.signers[pk.KID] + if !ok { + continue + } + return signer, alg + } + return nil, jwa.SignatureAlgorithm{} +} + +// Load loads the private keys from disk and generates missing keys when +// configured to do so. +func (kms *FilesystemKMS) Load() error { + log.Debugf("FilesystemKMS: loading keys from '%s'", kms.Dir) + if kms.signers == nil { + kms.signers = make(map[string]crypto.Signer) + } + + log.Debug("FilesystemKMS: loading active pks") + activePKs, err := kms.PKs.GetActive() + if err != nil { + return err + } + log.Debugf("FilesystemKMS: found %d active keys in pk storage", len(activePKs)) + // initialize map before use + var loadedAlgs map[jwa.SignatureAlgorithm]struct{} + loadedAlgs = make(map[jwa.SignatureAlgorithm]struct{}) + for _, activePK := range activePKs { + kid := activePK.KID + kalg, _ := activePK.Key.Algorithm() + alg := kalg.(jwa.SignatureAlgorithm) + signer, err := jwx.ReadSignerFromFile(kms.keyFilePath(kid), alg) + if err != nil { + log.WithError(err).WithField("kid", kid).Warn("FilesystemKMS: could not load signing key") + } else { + kms.signers[kid] = signer + loadedAlgs[alg] = struct{}{} + } + } + log.Debugf("FilesystemKMS: loaded %d active keys", len(kms.signers)) + + log.Debug("FilesystemKMS: Checking that all signing algs have a valid key") + for _, alg := range kms.Algs { + if _, ok := loadedAlgs[alg]; ok { + log.WithField("alg", alg.String()).Debug("FilesystemKMS: key for alg already found") + continue + } + log.WithField("alg", alg.String()).Debug("FilesystemKMS: key for alg is missing") + if !kms.GenerateKeys { + log.Info("FilesystemKMS: key generation disabled") + return errors.Errorf( + "no existing signing key for alg '%s'. Assure the file exists and has the correct format or enable key generation.", + alg, + ) + } + log.Info("FilesystemKMS: generating new signing key") + if _, err = kms.generateNewSigner(alg, nbfModeNow); err != nil { + return err + } + } + return nil +} + +// NewFilesystemKMSFromBasic creates a new FilesystemKMS initialized from an existing +// BasicKeyManagementSystem and persists private keys for the configured algorithms +// into the filesystem at the configured directory. +func NewFilesystemKMSFromBasic( + src BasicKeyManagementSystem, + config FilesystemKMSConfig, + pks public.PublicKeyStorage, +) (KeyManagementSystem, error) { + kms := &FilesystemKMS{ + FilesystemKMSConfig: config, + PKs: pks, + signers: make(map[string]crypto.Signer), + } + + // Ensure target PK storage is loaded + if err := pks.Load(); err != nil { + return nil, err + } + + // For each configured algorithm, obtain a signer from the source and persist it + for _, alg := range config.Algs { + signer, usedAlg := src.GetForAlgs(alg.String()) + if signer == nil || usedAlg.String() == "" { + continue + } + pk, kid, err := jwx.SignerToPublicJWK(signer, usedAlg) + if err != nil { + return nil, err + } + // Write private key to new location + if err = jwx.WriteSignerToFile(signer, kms.keyFilePath(kid)); err != nil { + return nil, err + } + // Register signer locally + kms.signers[kid] = signer + // Add public key metadata to PK storage if missing + existing, err := pks.Get(kid) + if err != nil { + return nil, err + } + if existing == nil { + now := unixtime.Now() + pke := public.PublicKeyEntry{ + KID: kid, + Key: public.JWKKey{Key: pk}, + IssuedAt: &now, + NotBefore: &now, + } + if err = pks.Add(pke); err != nil { + return nil, err + } + } + } + + // Finalize by loading any remaining keys normally + if err := kms.Load(); err != nil { + log.WithError(err).Warn("NewFilesystemKMSFromBasic: Load encountered issues after migration") + } + return kms, nil +} + +func (kms *FilesystemKMS) generateNewSigner( + alg jwa.SignatureAlgorithm, + mode nbfMode, +) (*public.PublicKeyEntry, error) { + sk, pk, kid, err := jwx.GenerateKeyPair(alg, kms.RSAKeyLen) + if err != nil { + return nil, err + } + now := unixtime.Now() + var nbf *unixtime.Unixtime + switch mode { + case nbfModeNow: + nbf = &now + case nbfModeNext: + lifetime, err := kms.KeyRotation.EntityConfigurationLifetimeFunc() + if err != nil { + return nil, errors.Wrap(err, "failed to get entity configuration lifetime") + } + nbf = &unixtime.Unixtime{Time: now.Add(lifetime)} + default: + return nil, errors.New("invalid nbf mode") + } + var exp *unixtime.Unixtime + if kms.KeyRotation.Enabled { + exp = &unixtime.Unixtime{Time: nbf.Add(kms.KeyRotation.Interval.Duration())} + } + pke := public.PublicKeyEntry{ + KID: kid, + Key: public.JWKKey{Key: pk}, + IssuedAt: &now, + NotBefore: nbf, + UpdateablePublicKeyMetadata: public.UpdateablePublicKeyMetadata{ + ExpiresAt: exp, + }, + } + if err = kms.PKs.Add(pke); err != nil { + return nil, err + } + if err = jwx.WriteSignerToFile(sk, kms.keyFilePath(kid)); err != nil { + return nil, err + } + kms.signers[kid] = sk + return &pke, nil +} + +func (kms *FilesystemKMS) rotateKeys(kids []string, revoked bool, reason string) error { + log.WithFields( + log.Fields{ + "kids": kids, + "revoked": revoked, + }, + ).Info("FilesystemKMS: rotation: start") + ks := make([]*public.PublicKeyEntry, len(kids)) + var signingAlg jwa.SignatureAlgorithm + // Track latest expiration across keys to decide nbf mode for new key + latestExp := time.Time{} + for i, kid := range kids { + k, err := kms.PKs.Get(kid) + if err != nil { + return err + } + alg, _ := k.Key.Algorithm() + if signingAlg.String() == "" { + signingAlg = alg.(jwa.SignatureAlgorithm) + } else { + if signingAlg.String() != alg.String() { + return errors.New("all keys must be of the same algorithm") + } + } + ks[i] = k + if k.ExpiresAt != nil && !k.ExpiresAt.IsZero() && (latestExp.IsZero() || k.ExpiresAt.After(latestExp)) { + latestExp = k.ExpiresAt.Time + } + } + mode := nbfModeNext + if revoked { + mode = nbfModeNow + } + if mode == nbfModeNext { + if lifetime, err := kms.KeyRotation.EntityConfigurationLifetimeFunc(); err == nil { + if time.Now().Add(lifetime).After(latestExp) { + mode = nbfModeNow + } + } + } + pk, err := kms.generateNewSigner(signingAlg, mode) + if err != nil { + return err + } + log.WithFields( + log.Fields{ + "alg": signingAlg.String(), + "mode": fmt.Sprintf("%v", mode), + "new_kid": pk.KID, + }, + ).Info("FilesystemKMS: rotation: generated new key") + newExpForOldKey := &unixtime.Unixtime{Time: pk.NotBefore.Add(kms.KeyRotation.Overlap.Duration())} + for _, k := range ks { + if revoked { + now := unixtime.Now() + k.RevokedAt = &now + k.Reason = reason + } + if k.ExpiresAt != nil && k.ExpiresAt.IsZero() || newExpForOldKey.Before(k.ExpiresAt.Time) || newExpForOldKey.After(k.ExpiresAt.Time) { + k.ExpiresAt = newExpForOldKey + } + if err = kms.PKs.Update(k.KID, k.UpdateablePublicKeyMetadata); err != nil { + return err + } + } + log.WithFields( + log.Fields{ + "alg": signingAlg.String(), + "new_kid": pk.KID, + }, + ).Info("FilesystemKMS: rotation: completed") + return nil +} + +// RotateKey rotates a single key, optionally marking it revoked and recording a reason. +func (kms *FilesystemKMS) RotateKey(kid string, revoked bool, reason string) error { + log.WithFields( + log.Fields{ + "kid": kid, + "revoked": revoked, + }, + ).Info("FilesystemKMS: rotate key") + return kms.rotateKeys([]string{kid}, revoked, reason) +} + +// RotateAllKeys rotates all active keys per configured algorithm, optionally +// marking them revoked and recording a reason. +func (kms *FilesystemKMS) RotateAllKeys(revoked bool, reason string) error { + // Get all currently active public keys + activePKs, err := kms.PKs.GetActive() + if err != nil { + return err + } + + // Group active keys by algorithm + pksByAlg := activePKs.ByAlg() + + // Iterate over configured algorithms only + for _, alg := range kms.Algs { + algPKs, ok := pksByAlg[alg] + if !ok || len(algPKs) == 0 { + // Nothing to rotate for this algorithm; create a new key + if _, err = kms.generateNewSigner(alg, nbfModeNow); err != nil { + return err + } + log.WithField( + "alg", alg.String(), + ).Info("FilesystemKMS: rotation: seeded new key for alg with no active keys") + } + + kids := make([]string, len(algPKs)) + for i, pk := range algPKs { + kids[i] = pk.KID + } + log.WithField("alg", alg.String()).Info("FilesystemKMS: rotation: processing alg") + if err = kms.rotateKeys(kids, revoked, reason); err != nil { + return err + } + } + return nil +} + +// StartAutomaticRotation starts a background loop that monitors key expiration +// thresholds and rotates keys ahead of time based on the configured overlap. +func (kms *FilesystemKMS) StartAutomaticRotation() error { + if !kms.KeyRotation.Enabled { + return nil + } + // ensure only one rotation loop runs + if kms.rotationStop != nil { + return nil + } + log.Info("FilesystemKMS: automatic rotation: starting") + kms.rotationStop = make(chan struct{}) + kms.rotationWG.Add(1) + go func() { + defer kms.rotationWG.Done() + for { + nextSleep, didRotate := kms.rotationStep(time.Now()) + if didRotate { + select { + case <-kms.rotationStop: + return + default: + } + continue + } + if nextSleep <= 0 { + nextSleep = time.Second + } + timer := time.NewTimer(nextSleep) + select { + case <-kms.rotationStop: + if !timer.Stop() { + <-timer.C + } + return + case <-timer.C: + // loop + } + } + }() + return nil +} + +// rotationStep performs one evaluation/rotation cycle and returns the next sleep +// interval and whether any rotation or seeding occurred (didRotate). +func (kms *FilesystemKMS) rotationStep(now time.Time) (time.Duration, bool) { + nextSleep := kms.KeyRotation.Overlap.Duration() / 2 + const minSleep = time.Second + if nextSleep <= 0 { + nextSleep = minSleep + } + didRotate := false + + activePKs, err := kms.PKs.GetActive() + if err != nil { + log.WithError(err).Error("FilesystemKMS: automatic rotation: failed to get active public keys") + return nextSleep, false + } + pksByAlg := activePKs.ByAlg() + for _, alg := range kms.Algs { + sleepCandidate, rotated := kms.rotationEvaluationForAlg(pksByAlg, alg, now, minSleep) + if rotated { + didRotate = true + } + if sleepCandidate > 0 && sleepCandidate < nextSleep { + nextSleep = sleepCandidate + } + } + return nextSleep, didRotate +} + +// rotationEvaluationForAlg evaluates rotation needs for a single algorithm. +// It returns a candidate sleep duration until the next action point and whether +// any rotation or seeding occurred. +func (kms *FilesystemKMS) rotationEvaluationForAlg( + pksByAlg map[jwa.SignatureAlgorithm]public.PublicKeyEntryList, + alg jwa.SignatureAlgorithm, + now time.Time, + minSleep time.Duration, +) (time.Duration, bool) { + algPKs, ok := pksByAlg[alg] + if !ok || len(algPKs) == 0 { + earliestNbf, hasFuture, vErr := earliestFutureNbfForAlg(kms.PKs, alg, now) + if vErr != nil { + log.WithError(vErr).Error("FilesystemKMS: automatic rotation: failed to get valid public keys for future check") + return 0, false + } + if hasFuture { + wait := time.Until(earliestNbf) + if wait < minSleep { + wait = minSleep + } + return wait, false + } + if _, err := kms.generateNewSigner(alg, nbfModeNow); err != nil { + log.WithError(err).Error("FilesystemKMS: automatic rotation: failed to seed key for alg") + return minSleep, false + } + return 0, true + } + + current := slices.MaxFunc( + algPKs, func(a, b public.PublicKeyEntry) int { + return cmp.Compare(a.ExpiresAt.Unix(), b.ExpiresAt.Unix()) + }, + ) + + lifetime := time.Duration(0) + if kms.KeyRotation.EntityConfigurationLifetimeFunc != nil { + if lt, lerr := kms.KeyRotation.EntityConfigurationLifetimeFunc(); lerr == nil { + lifetime = lt + } else { + log.WithError(lerr).Warn("FilesystemKMS: automatic rotation: failed to get lifetime; using 0") + } + } + threshold := current.ExpiresAt.Time.Add(-kms.KeyRotation.Overlap.Duration()).Add(-lifetime) + if !threshold.After(now) { + kids := make([]string, len(algPKs)) + for i, pk := range algPKs { + kids[i] = pk.KID + } + if earliestNbf, hasFuture, vErr := earliestFutureNbfForAlg(kms.PKs, alg, now); vErr == nil && hasFuture { + shortenExpirationUntilFuture( + kms.PKs, algPKs, earliestNbf, kms.KeyRotation.Overlap.Duration(), "FilesystemKMS", + ) + wait := time.Until(earliestNbf) + if wait < minSleep { + wait = minSleep + } + return wait, false + } + if err := kms.rotateKeys(kids, false, ""); err != nil { + log.WithError(err).Error("FilesystemKMS: automatic rotation: rotate failed") + return minSleep, false + } + return 0, true + } + wait := time.Until(threshold) + if wait < minSleep { + wait = minSleep + } + return wait, false +} + +// earliestFutureNbfForAlg returns the earliest NotBefore among valid, non-revoked +// keys for the given algorithm, that are in the future relative to now. +// Removed local earliestFutureNbfForAlg and shortenExpirationUntilFuture in favor of shared helpers. + +// StopAutomaticRotation stops the background rotation loop and waits for it to exit. +func (kms *FilesystemKMS) StopAutomaticRotation() { + if kms.rotationStop == nil { + return + } + close(kms.rotationStop) + kms.rotationWG.Wait() + log.Info("FilesystemKMS: automatic rotation: stopped") + kms.rotationStop = nil +} diff --git a/jwx/keymanagement/kms/kms.go b/jwx/keymanagement/kms/kms.go new file mode 100644 index 0000000..21282b9 --- /dev/null +++ b/jwx/keymanagement/kms/kms.go @@ -0,0 +1,152 @@ +package kms + +import ( + "crypto" + "time" + + "github.com/lestrrat-go/jwx/v3/jwa" + log "github.com/sirupsen/logrus" + "github.com/zachmann/go-utils/duration" + + "github.com/go-oidfed/lib/jwx" + "github.com/go-oidfed/lib/jwx/keymanagement/public" + "github.com/go-oidfed/lib/unixtime" +) + +type nbfMode int + +const ( + nbfModeNow nbfMode = iota + nbfModeNext +) + +// BasicKeyManagementSystem provides methods to load keys and retrieve +// a default or algorithm-specific signer. +type BasicKeyManagementSystem interface { + Load() error + GetForAlgs(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) + GetDefault() (crypto.Signer, jwa.SignatureAlgorithm) +} + +// KeyManagementSystem extends BasicKeyManagementSystem with rotation and +// automatic rotation controls. +type KeyManagementSystem interface { + BasicKeyManagementSystem + RotateKey(kid string, revoked bool, reason string) error + RotateAllKeys(revoked bool, reason string) error + StartAutomaticRotation() error + StopAutomaticRotation() +} + +// KMSConfig contains base configuration for a KeyManagementSystem, including +// algorithms, key length and rotation behavior. +type KMSConfig struct { + GenerateKeys bool + Algs []jwa.SignatureAlgorithm + DefaultAlg jwa.SignatureAlgorithm + RSAKeyLen int + KeyRotation KeyRotationConfig +} + +// KeyRotationConfig is a type holding configuration for key rollover / key rotation +type KeyRotationConfig struct { + Enabled bool `yaml:"enabled"` + Interval duration.DurationOption `yaml:"interval"` + Overlap duration.DurationOption `yaml:"overlap"` + EntityConfigurationLifetimeFunc func() (time.Duration, error) `yaml:"-"` +} + +type kmsAsVersatileSigner struct { + kms BasicKeyManagementSystem + jwksFnc func() (jwx.JWKS, error) +} + +func (k kmsAsVersatileSigner) Signer(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { + return k.kms.GetForAlgs(algs...) +} + +func (k kmsAsVersatileSigner) DefaultSigner() (crypto.Signer, jwa.SignatureAlgorithm) { + return k.kms.GetDefault() +} + +func (k kmsAsVersatileSigner) JWKS() (jwx.JWKS, error) { + return k.jwksFnc() +} + +// KMSToVersatileSignerWithJWKSFunc returns a VersatileSigner that uses the passed +// BasicKeyManagementSystem to load keys and returns the JWKS from the passed function. +func KMSToVersatileSignerWithJWKSFunc( + kms BasicKeyManagementSystem, jwksFnc func() (jwx.JWKS, error), +) jwx.VersatileSigner { + return kmsAsVersatileSigner{ + kms: kms, + jwksFnc: jwksFnc, + } +} + +// KMSToVersatileSignerWithPKStorage returns a VersatileSigner that uses the passed +// BasicKeyManagementSystem to load keys and returns the JWKS from the passed public.PublicKeyStorage. +func KMSToVersatileSignerWithPKStorage( + kms BasicKeyManagementSystem, pkStorage public.PublicKeyStorage, +) jwx.VersatileSigner { + return kmsAsVersatileSigner{ + kms: kms, + jwksFnc: func() (jwx.JWKS, error) { + list, err := pkStorage.GetValid() + if err != nil { + return jwx.JWKS{}, err + } + return list.JWKS() + }, + } +} + +// Shared rotation helpers to reduce duplication across KMS implementations. + +// earliestFutureNbfForAlg returns the earliest NotBefore among valid, non-revoked +// keys for the given algorithm, that are in the future relative to now. +func earliestFutureNbfForAlg(pkStorage public.PublicKeyStorage, alg jwa.SignatureAlgorithm, now time.Time) ( + time.Time, bool, error, +) { + validPKs, vErr := pkStorage.GetValid() + if vErr != nil { + return time.Time{}, false, vErr + } + earliestNbf := time.Time{} + for _, pk := range validPKs { + algI, set := pk.Key.Algorithm() + if !set { + continue + } + a, ok := algI.(jwa.SignatureAlgorithm) + if !ok || a.String() != alg.String() { + continue + } + if pk.RevokedAt != nil && !pk.RevokedAt.IsZero() && pk.RevokedAt.Before(now) { + continue + } + if pk.NotBefore != nil && !pk.NotBefore.IsZero() && pk.NotBefore.After(now) { + if earliestNbf.IsZero() || pk.NotBefore.Before(earliestNbf) { + earliestNbf = pk.NotBefore.Time + } + } + } + return earliestNbf, !earliestNbf.IsZero(), nil +} + +// shortenExpirationUntilFuture updates the expiration of current active keys so that +// they extend until the future key's NotBefore plus overlap. +func shortenExpirationUntilFuture( + pkStorage public.PublicKeyStorage, algPKs []public.PublicKeyEntry, earliestNbf time.Time, overlap time.Duration, + logPrefix string, +) { + newExpForOldKey := &unixtime.Unixtime{Time: earliestNbf.Add(overlap)} + for _, k := range algPKs { + if k.ExpiresAt != nil && k.ExpiresAt.IsZero() || newExpForOldKey.Before(k.ExpiresAt.Time) { + k.ExpiresAt = newExpForOldKey + if uErr := pkStorage.Update(k.KID, k.UpdateablePublicKeyMetadata); uErr != nil { + log.WithError(uErr).Error(logPrefix + ": automatic rotation: failed to update old key exp") + } + } + } +} diff --git a/jwx/keymanagement/kms/legacy.go b/jwx/keymanagement/kms/legacy.go new file mode 100644 index 0000000..9003c98 --- /dev/null +++ b/jwx/keymanagement/kms/legacy.go @@ -0,0 +1,59 @@ +package kms + +import ( + "crypto" + "fmt" + "slices" + + "github.com/lestrrat-go/jwx/v3/jwa" + + "github.com/go-oidfed/lib/jwx" +) + +// LegacyFilesystemKMS provides a read-only BasicKeyManagementSystem backed by legacy +// key file layout (_.pem). It enables migration by exposing GetForAlgs +// based on the legacy files, without supporting rotation. +type LegacyFilesystemKMS struct { + Dir string + TypeID string + Algs []jwa.SignatureAlgorithm + signers map[string]crypto.Signer // kid -> signer +} + +func (l *LegacyFilesystemKMS) legacyKeyFilePath(alg jwa.SignatureAlgorithm) string { + return fmt.Sprintf("%s/%s_%s.pem", l.Dir, l.TypeID, alg.String()) +} + +func (l *LegacyFilesystemKMS) Load() error { + l.signers = make(map[string]crypto.Signer) + for _, alg := range l.Algs { + signer, err := jwx.ReadSignerFromFile(l.legacyKeyFilePath(alg), alg) + if err != nil { + continue + } + _, kid, err := jwx.SignerToPublicJWK(signer, alg) + if err != nil { + continue + } + l.signers[kid] = signer + } + return nil +} + +func (l *LegacyFilesystemKMS) GetForAlgs(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { + for _, alg := range l.Algs { + for _, signer := range l.signers { + if slices.Contains(algs, alg.String()) { + return signer, alg + } + } + } + return nil, jwa.SignatureAlgorithm{} +} + +func (l *LegacyFilesystemKMS) GetDefault() (crypto.Signer, jwa.SignatureAlgorithm) { + if len(l.Algs) == 0 { + return nil, jwa.SignatureAlgorithm{} + } + return l.GetForAlgs(l.Algs[0].String()) +} diff --git a/jwx/keymanagement/kms/pkcs11.go b/jwx/keymanagement/kms/pkcs11.go new file mode 100644 index 0000000..d31ff94 --- /dev/null +++ b/jwx/keymanagement/kms/pkcs11.go @@ -0,0 +1,764 @@ +package kms + +import ( + "cmp" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "fmt" + "io" + "slices" + "sync" + "time" + + "github.com/ThalesGroup/crypto11" + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/pkg/errors" + + log "github.com/go-oidfed/lib/internal" + + "github.com/go-oidfed/lib/jwx" + "github.com/go-oidfed/lib/jwx/keymanagement/public" + "github.com/go-oidfed/lib/unixtime" +) + +// NewSingleAlgPKCS11KMS constructs a PKCS#11-backed KMS for a single algorithm. +func NewSingleAlgPKCS11KMS( + alg jwa.SignatureAlgorithm, + config PKCS11KMSConfig, pks public.PublicKeyStorage, +) KeyManagementSystem { + config.Algs = []jwa.SignatureAlgorithm{alg} + config.DefaultAlg = alg + return &PKCS11KMS{ + PKCS11KMSConfig: config, + PKs: pks, + } +} + +// PKCS11KMSConfig contains configuration for the PKCS#11 KMS. +// Keys are created and looked up inside the HSM using labels derived from the KID. +// If LabelPrefix is set, labels are LabelPrefix+"_"+KID; else if TypeID is set, TypeID+"_"+KID; otherwise KID. +type PKCS11KMSConfig struct { + KMSConfig + + // TypeID is a logical namespace for this KMS (used in labels if LabelPrefix is empty) + TypeID string + + // ModulePath is the path to the PKCS#11 module (crypto11.Config.Path) + ModulePath string + // TokenLabel selects the token by label (crypto11.Config.TokenLabel) + TokenLabel string + // TokenSerial selects the token by serial (crypto11.Config.TokenSerial) + TokenSerial string + // Pin is the user PIN for the token (crypto11.Config.Pin) + Pin string + + // Optional prefix for object labels inside HSM + LabelPrefix string + + // ExtraLabels are HSM object labels to load into this KMS even if + // they are not present yet in the PublicKeyStorage. + ExtraLabels []string +} + +// PKCS11KMS implements KeyManagementSystem using a PKCS#11 HSM. +type PKCS11KMS struct { + PKCS11KMSConfig + + ctx *crypto11.Context + + // signers is a map of all loaded signers, keyed by kid + signers map[string]crypto.Signer + + PKs public.PublicKeyStorage + + // automatic rotation control + rotationStop chan struct{} + rotationWG sync.WaitGroup +} + +// labeledSigner wraps a crypto.Signer and carries a stable KID (e.g., HSM label). +type labeledSigner struct { + s crypto.Signer + kid string +} + +// Public returns the public key associated with this signer. +func (l *labeledSigner) Public() crypto.PublicKey { return l.s.Public() } + +// Sign signs digest with the private key associated with this signer. +func (l *labeledSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + return l.s.Sign(rand, digest, opts) +} + +// KID returns the key ID associated with this signer. +func (l *labeledSigner) KID() string { return l.kid } + +// keyLabel constructs the HSM object label from kid and configured prefixes. +func (kms *PKCS11KMS) keyLabel(kid string) string { + prefix := kms.LabelPrefix + if prefix == "" { + prefix = kms.TypeID + } + if prefix == "" { + return kid + } + return fmt.Sprintf("%s_%s", prefix, kid) +} + +// GetDefault returns a crypto.Signer and the corresponding jwa.SignatureAlgorithm +func (kms *PKCS11KMS) GetDefault() (crypto.Signer, jwa.SignatureAlgorithm) { + if len(kms.Algs) == 0 { + return nil, jwa.SignatureAlgorithm{} + } + var algs []string + if kms.DefaultAlg.String() != "" { + algs = []string{kms.DefaultAlg.String()} + } + for _, a := range kms.Algs { + algs = append(algs, a.String()) + } + return kms.GetForAlgs(algs...) +} + +// GetForAlgs returns a signer for the first acceptable algorithm found among active keys. +func (kms *PKCS11KMS) GetForAlgs(algs ...string) ( + crypto.Signer, + jwa.SignatureAlgorithm, +) { + activePKs, err := kms.PKs.GetActive() + if err != nil { + log.WithError(err).Error("pkcs#11 KMS: failed to get active public keys") + return nil, jwa.SignatureAlgorithm{} + } + pksByAlg := activePKs.ByAlg() + for _, alg := range kms.Algs { + if !slices.Contains(algs, alg.String()) { + continue + } + algPKs, ok := pksByAlg[alg] + if !ok || len(algPKs) == 0 { + continue + } + pk := algPKs[0] + if len(algPKs) > 1 { + maxExp := unixtime.Now() + maxExpWithNbf := maxExp + maxExpIndex := -1 + maxExpWithNbfIndex := -1 + noExpIndex := -1 + nbfTreshold := time.Now().Add(-kms.KeyRotation.Overlap.Duration() / 2) + for i, it := range algPKs { + if it.ExpiresAt == nil { + noExpIndex = i + continue + } + if it.NotBefore != nil && it.NotBefore.Before(nbfTreshold) && it.ExpiresAt.After( + maxExpWithNbf.Time, + ) { + maxExpWithNbf = *it.ExpiresAt + maxExpWithNbfIndex = i + + } else if maxExpIndex == -1 && it.ExpiresAt.After(maxExp.Time) { + maxExp = *it.ExpiresAt + maxExpIndex = i + } + } + if maxExpWithNbfIndex != -1 { + pk = algPKs[maxExpWithNbfIndex] + } else if maxExpIndex != -1 { + pk = algPKs[maxExpIndex] + } else { + pk = algPKs[noExpIndex] + } + } + signer, ok := kms.signers[pk.KID] + if !ok { + continue + } + return signer, alg + } + return nil, jwa.SignatureAlgorithm{} +} + +// Load initializes the PKCS#11 context, loads HSM-resident signers for active keys, +// and generates any missing keys if enabled. +func (kms *PKCS11KMS) Load() error { + if kms.signers == nil { + kms.signers = make(map[string]crypto.Signer) + } + if kms.ctx == nil { + cfg := &crypto11.Config{ + Path: kms.ModulePath, + TokenLabel: kms.TokenLabel, + TokenSerial: kms.TokenSerial, + Pin: kms.Pin, + } + ctx, err := crypto11.Configure(cfg) + if err != nil { + return errors.Wrap(err, "pkcs11 kms: configure crypto11") + } + kms.ctx = ctx + } + + activePKs, err := kms.PKs.GetActive() + if err != nil { + return err + } + + loadedAlgs := make(map[jwa.SignatureAlgorithm]struct{}) + for _, activePK := range activePKs { + kid := activePK.KID + kalg, _ := activePK.Key.Algorithm() + alg, ok := kalg.(jwa.SignatureAlgorithm) + if !ok { + continue + } + if !kms.algorithmSupported(alg) { + continue + } + // Load signer by label (kid-derived) + signer, err := kms.findKeyByKID(kid) + if err != nil { + return err + } + if signer == nil { + continue + } + kms.signers[kid] = signer + // Wrap with label so downstream signing can set kid header from HSM label + kms.signers[kid] = &labeledSigner{ + s: signer, + kid: kid, + } + loadedAlgs[alg] = struct{}{} + } + + for _, alg := range kms.Algs { + if _, ok := loadedAlgs[alg]; ok { + continue + } + // Not available; create if allowed + if !kms.GenerateKeys { + return errors.Errorf( + "no existing HSM signing key for alg '%s'. Enable key generation or provision keys", + alg, + ) + } + if _, err = kms.generateNewSigner(alg, nbfModeNow); err != nil { + return err + } + } + // Load extra labels explicitly requested via config + for _, label := range kms.ExtraLabels { + signer, ferr := kms.ctx.FindKeyPair(nil, []byte(label)) + if ferr != nil { + log.WithError(ferr).WithField("label", label).Warn("pkcs#11 KMS: failed to find extra label") + continue + } + if signer == nil { + continue + } + alg, algErr := kms.algForSigner(signer) + if algErr != nil { + log.WithError(algErr).WithField( + "label", label, + ).Warn("pkcs#11 KMS: could not determine algorithm for extra label") + continue + } + pk, _, pkErr := jwx.SignerToPublicJWK(signer, alg) + if pkErr != nil { + log.WithError(pkErr).WithField( + "label", label, + ).Warn("pkcs#11 KMS: failed to derive public JWK for extra label") + continue + } + _ = pk.Set(jwk.KeyIDKey, label) + existing, gerr := kms.PKs.Get(label) + if gerr != nil { + log.WithError(gerr).WithField( + "label", label, + ).Warn("pkcs#11 KMS: failed to query public storage for extra label") + continue + } + if existing == nil { + now := unixtime.Now() + var exp *unixtime.Unixtime + if kms.KeyRotation.Enabled { + exp = &unixtime.Unixtime{Time: now.Add(kms.KeyRotation.Interval.Duration())} + } + pke := public.PublicKeyEntry{ + KID: label, + Key: public.JWKKey{Key: pk}, + IssuedAt: &now, + NotBefore: &now, + UpdateablePublicKeyMetadata: public.UpdateablePublicKeyMetadata{ + ExpiresAt: exp, + }, + } + if aerr := kms.PKs.Add(pke); aerr != nil { + log.WithError(aerr).WithField( + "label", label, + ).Warn("pkcs#11 KMS: failed to add extra label to public storage") + continue + } + } + // Wrap with label so downstream signing can set kid header from HSM label + kms.signers[label] = &labeledSigner{ + s: signer, + kid: label, + } + } + return nil +} + +// algorithmSupported reports whether the given algorithm is among the configured ones. +func (kms *PKCS11KMS) algorithmSupported(alg jwa.SignatureAlgorithm) bool { + for _, a := range kms.Algs { + if a.String() == alg.String() { + return true + } + } + return false +} + +// algForSigner determines a configured jwa.SignatureAlgorithm suitable for the signer’s key type. +func (kms *PKCS11KMS) algForSigner(signer crypto.Signer) (jwa.SignatureAlgorithm, error) { + pub := signer.Public() + switch t := pub.(type) { + case *rsa.PublicKey: + for _, a := range kms.Algs { + switch a { + case jwa.RS256(), jwa.RS384(), jwa.RS512(), jwa.PS256(), jwa.PS384(), jwa.PS512(): + return a, nil + } + } + return jwa.SignatureAlgorithm{}, errors.New("no RSA algorithms configured for loaded HSM key") + case *ecdsa.PublicKey: + var want jwa.SignatureAlgorithm + switch t.Curve { + case elliptic.P256(): + want = jwa.ES256() + case elliptic.P384(): + want = jwa.ES384() + case elliptic.P521(): + want = jwa.ES512() + default: + return jwa.SignatureAlgorithm{}, errors.New("unsupported ECDSA curve") + } + for _, a := range kms.Algs { + if a.String() == want.String() { + return a, nil + } + } + return jwa.SignatureAlgorithm{}, errors.Errorf("algorithm %s not configured for loaded ECDSA HSM key", want) + case ed25519.PublicKey: + for _, a := range kms.Algs { + if a.String() == jwa.EdDSA().String() { + return a, nil + } + } + return jwa.SignatureAlgorithm{}, errors.New("EdDSA not configured for loaded HSM key") + default: + return jwa.SignatureAlgorithm{}, errors.New("unknown HSM key type") + } +} + +// findKeyByKID locates a key pair by label derived from the kid (with optional prefix). +func (kms *PKCS11KMS) findKeyByKID(kid string) (crypto.Signer, error) { + if kms.ctx == nil { + return nil, errors.New("pkcs11 kms: context not initialized") + } + // First try exact kid as label, then fallback to prefixed label to support legacy configs + signer, err := kms.ctx.FindKeyPair(nil, []byte(kid)) + if err != nil { + return nil, errors.Wrap(err, "pkcs11 kms: find key by label") + } + if signer == nil { + signer, err = kms.ctx.FindKeyPair(nil, []byte(kms.keyLabel(kid))) + if err != nil { + return nil, errors.Wrap(err, "pkcs11 kms: fallback find key by prefixed label") + } + } + return signer, nil +} + +// generateNewSigner creates a new key pair inside the HSM for the given algorithm +// and registers its public part in PublicKeyStorage. +func (kms *PKCS11KMS) generateNewSigner( + alg jwa.SignatureAlgorithm, + mode nbfMode, +) (*public.PublicKeyEntry, error) { + if kms.ctx == nil { + return nil, errors.New("pkcs11 kms: context not initialized") + } + + u, err := uuid.NewV7() + if err != nil { + return nil, errors.Wrap(err, "could not generate uuid") + } + kid := u.String() + label := kms.keyLabel(kid) + signer, err := kms.generateKeyInHSM(alg, kid, label) + if err != nil { + return nil, err + } + + pk, _, err := jwx.SignerToPublicJWK(signer, alg) + if err != nil { + return nil, err + } + // Override the kid to match the HSM label + _ = pk.Set(jwk.KeyIDKey, label) + + now := unixtime.Now() + var nbf *unixtime.Unixtime + switch mode { + case nbfModeNow: + nbf = &now + case nbfModeNext: + lifetime, err := kms.KeyRotation.EntityConfigurationLifetimeFunc() + if err != nil { + return nil, errors.Wrap(err, "failed to get entity configuration lifetime") + } + nbf = &unixtime.Unixtime{Time: now.Add(lifetime)} + default: + return nil, errors.New("invalid nbf mode") + } + var exp *unixtime.Unixtime + if kms.KeyRotation.Enabled { + exp = &unixtime.Unixtime{Time: nbf.Add(kms.KeyRotation.Interval.Duration())} + } + pke := public.PublicKeyEntry{ + KID: label, + Key: public.JWKKey{Key: pk}, + IssuedAt: &now, + NotBefore: nbf, + UpdateablePublicKeyMetadata: public.UpdateablePublicKeyMetadata{ + ExpiresAt: exp, + }, + } + if err = kms.PKs.Add(pke); err != nil { + return nil, err + } + // Wrap with label so downstream signing can set kid header from HSM label + kms.signers[label] = &labeledSigner{ + s: signer, + kid: label, + } + return &pke, nil +} + +func (kms *PKCS11KMS) generateKeyInHSM(alg jwa.SignatureAlgorithm, kid, label string) (crypto.Signer, error) { + var signer crypto.Signer + var err error + switch alg { + case jwa.RS256(), jwa.RS384(), jwa.RS512(), jwa.PS256(), jwa.PS384(), jwa.PS512(): + signer, err = kms.ctx.GenerateRSAKeyPairWithLabel([]byte(kid), []byte(label), kms.RSAKeyLen) + case jwa.ES256(): + signer, err = kms.ctx.GenerateECDSAKeyPairWithLabel([]byte(kid), []byte(label), elliptic.P256()) + case jwa.ES384(): + signer, err = kms.ctx.GenerateECDSAKeyPairWithLabel([]byte(kid), []byte(label), elliptic.P384()) + case jwa.ES512(): + signer, err = kms.ctx.GenerateECDSAKeyPairWithLabel([]byte(kid), []byte(label), elliptic.P521()) + default: + return nil, errors.New("unknown signing algorithm: " + alg.String()) + } + if err != nil { + return nil, errors.WithStack(err) + } + return signer, nil +} + +// Rotation functions mirror FilesystemKMS logic, operating purely on public key metadata +// and leveraging HSM-backed key generation. +func (kms *PKCS11KMS) rotateKeys(kids []string, revoked bool, reason string) error { + log.WithFields( + log.Fields{ + "kids": kids, + "revoked": revoked, + }, + ).Info("pkcs#11 KMS: rotation: start") + ks := make([]*public.PublicKeyEntry, len(kids)) + var signingAlg jwa.SignatureAlgorithm + latestExp := time.Time{} + for i, kid := range kids { + k, err := kms.PKs.Get(kid) + if err != nil { + return err + } + alg, _ := k.Key.Algorithm() + if signingAlg.String() == "" { + signingAlg = alg.(jwa.SignatureAlgorithm) + } else { + if signingAlg.String() != alg.String() { + return errors.New("all keys must be of the same algorithm") + } + } + ks[i] = k + if k.ExpiresAt != nil && !k.ExpiresAt.IsZero() && (latestExp.IsZero() || k.ExpiresAt.After(latestExp)) { + latestExp = k.ExpiresAt.Time + } + } + mode := nbfModeNext + if revoked { + mode = nbfModeNow + } + // Avoid gaps: if the computed future NotBefore would be after latest current expiration, + // activate the new key immediately. + if mode == nbfModeNext { + if lifetime, err := kms.KeyRotation.EntityConfigurationLifetimeFunc(); err == nil { + if time.Now().Add(lifetime).After(latestExp) { + mode = nbfModeNow + } + } + } + pk, err := kms.generateNewSigner(signingAlg, mode) + if err != nil { + return err + } + log.WithFields( + log.Fields{ + "alg": signingAlg.String(), + "mode": fmt.Sprintf("%v", mode), + "new_kid": pk.KID, + }, + ).Info("pkcs#11 KMS: rotation: generated new key") + newExpForOldKey := &unixtime.Unixtime{Time: pk.NotBefore.Add(kms.KeyRotation.Overlap.Duration())} + for _, k := range ks { + if revoked { + now := unixtime.Now() + k.RevokedAt = &now + k.Reason = reason + } + // Ensure continuous coverage by setting old expiration to new.nbf + overlap + if k.ExpiresAt != nil && k.ExpiresAt.IsZero() || newExpForOldKey.Before(k.ExpiresAt.Time) || newExpForOldKey.After(k.ExpiresAt.Time) { + k.ExpiresAt = newExpForOldKey + } + if err = kms.PKs.Update(k.KID, k.UpdateablePublicKeyMetadata); err != nil { + return err + } + } + log.WithFields( + log.Fields{ + "alg": signingAlg.String(), + "new_kid": pk.KID, + }, + ).Info("pkcs#11 KMS: rotation: completed") + return nil +} + +// RotateKey rotates a single key, optionally marking it revoked and recording a reason. +func (kms *PKCS11KMS) RotateKey(kid string, revoked bool, reason string) error { + log.WithFields( + log.Fields{ + "kid": kid, + "revoked": revoked, + }, + ).Info("pkcs#11 KMS: rotate key") + return kms.rotateKeys([]string{kid}, revoked, reason) +} + +// RotateAllKeys rotates all active keys per configured algorithm, optionally revoking them. +func (kms *PKCS11KMS) RotateAllKeys(revoked bool, reason string) error { + activePKs, err := kms.PKs.GetActive() + if err != nil { + return err + } + pksByAlg := activePKs.ByAlg() + for _, alg := range kms.Algs { + algPKs, ok := pksByAlg[alg] + if !ok || len(algPKs) == 0 { + if _, err = kms.generateNewSigner(alg, nbfModeNow); err != nil { + return err + } + log.WithField("alg", alg.String()).Info("pkcs#11 KMS: rotation: seeded new key for alg with no active keys") + } + kids := make([]string, len(algPKs)) + for i, pk := range algPKs { + kids[i] = pk.KID + } + log.WithField("alg", alg.String()).Info("pkcs#11 KMS: rotation: processing alg") + if err = kms.rotateKeys(kids, revoked, reason); err != nil { + return err + } + } + return nil +} + +// StartAutomaticRotation launches a background loop to rotate keys ahead of expiration. +func (kms *PKCS11KMS) StartAutomaticRotation() error { + if !kms.KeyRotation.Enabled { + return nil + } + // ensure only one rotation loop runs + if kms.rotationStop != nil { + return nil + } + log.Info("pkcs#11 KMS: automatic rotation: starting") + kms.rotationStop = make(chan struct{}) + kms.rotationWG.Add(1) + go func() { + defer kms.rotationWG.Done() + for { + nextSleep, didRotate := kms.rotationStep(time.Now()) + // If we rotated, loop again immediately unless asked to stop. + if didRotate { + select { + case <-kms.rotationStop: + return + default: + } + continue + } + // Sleep until the next threshold or future key activation. + if nextSleep <= 0 { + nextSleep = time.Second + } + timer := time.NewTimer(nextSleep) + select { + case <-kms.rotationStop: + if !timer.Stop() { + <-timer.C + } + return + case <-timer.C: + // loop + } + } + }() + return nil +} + +// rotationStep performs one evaluation/rotation cycle and returns the next sleep +// interval and whether any rotation or seeding occurred (didRotate). +func (kms *PKCS11KMS) rotationStep(now time.Time) (time.Duration, bool) { + // default sleep if we cannot compute anything meaningful + nextSleep := kms.KeyRotation.Overlap.Duration() / 2 + // clamp to a minimum to avoid busy loops when overlap is zero + const minSleep = time.Second + if nextSleep <= 0 { + nextSleep = minSleep + } + didRotate := false + + activePKs, err := kms.PKs.GetActive() + if err != nil { + log.WithError(err).Error("pkcs#11 KMS: automatic rotation: failed to get active public keys") + return nextSleep, false + } + pksByAlg := activePKs.ByAlg() + // iterate only configured algorithms + for _, alg := range kms.Algs { + sleepCandidate, rotated := kms.rotationEvaluationForAlg(pksByAlg, alg, now, minSleep) + if rotated { + didRotate = true + } + if sleepCandidate > 0 && sleepCandidate < nextSleep { + nextSleep = sleepCandidate + } + } + + return nextSleep, didRotate +} + +// rotationEvaluationForAlg evaluates rotation needs for a single algorithm. +// It returns a candidate sleep duration until the next action point and whether +// any rotation or seeding occurred. +func (kms *PKCS11KMS) rotationEvaluationForAlg( + pksByAlg map[jwa.SignatureAlgorithm]public.PublicKeyEntryList, + alg jwa.SignatureAlgorithm, + now time.Time, + minSleep time.Duration, +) (time.Duration, bool) { + algPKs, ok := pksByAlg[alg] + if !ok || len(algPKs) == 0 { + // No active keys: before seeding, check if a valid future key already exists. + earliestNbf, hasFuture, vErr := earliestFutureNbfForAlg(kms.PKs, alg, now) + if vErr != nil { + log.WithError(vErr).Error("pkcs#11 KMS: automatic rotation: failed to get valid public keys for future check") + return 0, false + } + if hasFuture { + // sleep until future key becomes active + wait := time.Until(earliestNbf) + if wait < minSleep { + wait = minSleep + } + return wait, false + } + // no active and no future key; seed immediately + if _, err := kms.generateNewSigner(alg, nbfModeNow); err != nil { + log.WithError(err).Error("pkcs#11 KMS: automatic rotation: failed to seed key for alg") + // ensure we don't spin; retry soon-ish + return minSleep, false + } + // re-evaluate immediately to include the new key in active set + return 0, true + } + + // pick the key with latest expiration as the current signer for this alg + current := slices.MaxFunc( + algPKs, func(a, b public.PublicKeyEntry) int { + return cmp.Compare(a.ExpiresAt.Unix(), b.ExpiresAt.Unix()) + }, + ) + + // Trigger early enough to accommodate nbf = now + lifetime + lifetime := time.Duration(0) + if kms.KeyRotation.EntityConfigurationLifetimeFunc != nil { + if lt, lerr := kms.KeyRotation.EntityConfigurationLifetimeFunc(); lerr == nil { + lifetime = lt + } else { + log.WithError(lerr).Warn("pkcs#11 KMS: automatic rotation: failed to get lifetime; using 0") + } + } + threshold := current.ExpiresAt.Time.Add(-kms.KeyRotation.Overlap.Duration()).Add(-lifetime) + if !threshold.After(now) { + kids := make([]string, len(algPKs)) + for i, pk := range algPKs { + kids[i] = pk.KID + } + // If there is already a future key, do not generate another; only shorten old exp + if earliestNbf, hasFuture, vErr := earliestFutureNbfForAlg(kms.PKs, alg, now); vErr == nil && hasFuture { + shortenExpirationUntilFuture( + kms.PKs, algPKs, earliestNbf, kms.KeyRotation.Overlap.Duration(), "pkcs#11 KMS", + ) + wait := time.Until(earliestNbf) + if wait < minSleep { + wait = minSleep + } + return wait, false + } + if err := kms.rotateKeys(kids, false, ""); err != nil { + log.WithError(err).Error("pkcs#11 KMS: automatic rotation: rotate failed") + return minSleep, false + } + return 0, true + } + // schedule rotation when threshold is reached + wait := time.Until(threshold) + if wait < minSleep { + wait = minSleep + } + return wait, false +} + +// earliestFutureNbfForAlg returns the earliest NotBefore among valid, non-revoked +// keys for the given algorithm, that are in the future relative to now. +// Removed local earliestFutureNbfForAlg and shortenExpirationUntilFuture in favor of shared helpers. + +// StopAutomaticRotation stops the background rotation loop. +func (kms *PKCS11KMS) StopAutomaticRotation() { + if kms.rotationStop == nil { + return + } + close(kms.rotationStop) + kms.rotationWG.Wait() + log.Info("pkcs#11 KMS: automatic rotation: stopped") + kms.rotationStop = nil +} diff --git a/jwx/keymanagement/public/filesystem.go b/jwx/keymanagement/public/filesystem.go new file mode 100644 index 0000000..d04879b --- /dev/null +++ b/jwx/keymanagement/public/filesystem.go @@ -0,0 +1,382 @@ +package public + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" + + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/pkg/errors" + + log "github.com/go-oidfed/lib/internal" + + "github.com/go-oidfed/lib/unixtime" +) + +// FilesystemPublicKeyStorage implements PublicKeyStorage backed by a JSON file on disk. +// It persists a collection of PublicKeyEntry records under a type-specific path. +type FilesystemPublicKeyStorage struct { + Dir string + TypeID string + + mu sync.RWMutex + entries map[string]PublicKeyEntry // keyed by KID +} + +func (fs *FilesystemPublicKeyStorage) storageFilePath() string { + return filepath.Join(fs.Dir, fs.TypeID+"_public.json") +} + +// Load loads public keys from disk. If no native storage file exists, it attempts +// to import from a legacy LegacyPKCollection persisted via keys.jwks and history files. +func (fs *FilesystemPublicKeyStorage) Load() error { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.entries == nil { + fs.entries = make(map[string]PublicKeyEntry) + } + + data, err := os.ReadFile(fs.storageFilePath()) + if err != nil { + log.WithError(err).WithField( + "filepath", fs.storageFilePath(), + ).Warn("FilesystemPublicKeyStorage: could not read storage file") + return nil + } + if len(data) == 0 { + return nil + } + var disk map[string]PublicKeyEntry + if err = json.Unmarshal(data, &disk); err != nil { + return errors.WithStack(err) + } + fs.entries = disk + return nil +} + +// GetAll returns all keys in the storage, including revoked and expired keys. +func (fs *FilesystemPublicKeyStorage) GetAll() (PublicKeyEntryList, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + return fs.snapshot(), nil +} + +// GetRevoked returns all revoked keys in the storage. +func (fs *FilesystemPublicKeyStorage) GetRevoked() (PublicKeyEntryList, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + now := time.Now() + var out PublicKeyEntryList + for _, e := range fs.entries { + if e.RevokedAt != nil && !e.RevokedAt.IsZero() && e.RevokedAt.Before(now) { + out = append(out, e) + } + } + return out, nil +} + +// GetExpired returns all expired keys in the storage. +func (fs *FilesystemPublicKeyStorage) GetExpired() (PublicKeyEntryList, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + now := time.Now() + var out PublicKeyEntryList + for _, e := range fs.entries { + if e.ExpiresAt != nil && !e.ExpiresAt.IsZero() && e.ExpiresAt.Before(now) { + out = append(out, e) + } + } + return out, nil +} + +// GetHistorical returns revoked and expired keys. +func (fs *FilesystemPublicKeyStorage) GetHistorical() (PublicKeyEntryList, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + now := time.Now() + var out PublicKeyEntryList + for _, e := range fs.entries { + if e.ExpiresAt != nil && !e.ExpiresAt.IsZero() && e.ExpiresAt.Before(now) { + out = append(out, e) + } else if e.RevokedAt != nil && !e.RevokedAt.IsZero() && e.RevokedAt.Before(now) { + out = append(out, e) + } + } + return out, nil + +} + +// GetActive returns keys that are currently usable (not revoked and within nbf/exp window). +func (fs *FilesystemPublicKeyStorage) GetActive() (PublicKeyEntryList, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + now := time.Now() + var out PublicKeyEntryList + for _, e := range fs.entries { + if e.RevokedAt != nil && !e.RevokedAt.IsZero() && e.RevokedAt.Before(now) { + continue + } + if e.NotBefore != nil && !e.NotBefore.IsZero() && now.Before(e.NotBefore.Time) { + continue + } + if e.ExpiresAt != nil && !e.ExpiresAt.IsZero() && e.ExpiresAt.Before(now) { + continue + } + out = append(out, e) + } + return out, nil +} + +// GetValid returns keys that are valid now or in the future (not revoked or expired). +func (fs *FilesystemPublicKeyStorage) GetValid() (PublicKeyEntryList, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + now := time.Now() + var out PublicKeyEntryList + for _, e := range fs.entries { + if e.RevokedAt != nil && !e.RevokedAt.IsZero() && e.RevokedAt.Before(now) { + continue + } + if e.ExpiresAt != nil && !e.ExpiresAt.IsZero() && e.ExpiresAt.Before(now) { + continue + } + out = append(out, e) + } + return out, nil +} + +// Add adds a new key to the storage if the KID is unused. +func (fs *FilesystemPublicKeyStorage) Add(key PublicKeyEntry) error { + fs.mu.Lock() + defer fs.mu.Unlock() + if key.KID == "" { + // Attempt to derive KID from key if missing + var kid string + _ = key.Key.Get("kid", &kid) + key.KID = kid + } + if key.KID == "" { + return errors.New("missing kid for public key entry") + } + if _, exists := fs.entries[key.KID]; exists { + return nil + } + fs.entries[key.KID] = key + return fs.persist() +} + +// AddAll adds multiple keys to the storage. +func (fs *FilesystemPublicKeyStorage) AddAll(keys []PublicKeyEntry) error { + fs.mu.Lock() + defer fs.mu.Unlock() + for _, key := range keys { + if key.KID == "" { + var kid string + _ = key.Key.Get("kid", &kid) + key.KID = kid + } + if key.KID == "" { + continue + } + if _, exists := fs.entries[key.KID]; exists { + continue + } + fs.entries[key.KID] = key + } + return fs.persist() +} + +// Update updates the editable metadata for a given key. +func (fs *FilesystemPublicKeyStorage) Update(kid string, data UpdateablePublicKeyMetadata) error { + fs.mu.Lock() + defer fs.mu.Unlock() + e, ok := fs.entries[kid] + if !ok { + return errors.Errorf("unknown kid '%s'", kid) + } + e.UpdateablePublicKeyMetadata = data + fs.entries[kid] = e + return fs.persist() +} + +// Clear removes all keys from the storage. +func (fs *FilesystemPublicKeyStorage) Clear() error { + fs.mu.Lock() + defer fs.mu.Unlock() + fs.entries = make(map[string]PublicKeyEntry) + return fs.persist() +} + +// Delete removes a key by kid. +func (fs *FilesystemPublicKeyStorage) Delete(kid string) error { + fs.mu.Lock() + defer fs.mu.Unlock() + delete(fs.entries, kid) + return fs.persist() +} + +// Revoke marks the given key as revoked with the passed reason. +func (fs *FilesystemPublicKeyStorage) Revoke(kid, reason string) error { + k, err := fs.Get(kid) + if err != nil { + return err + } + if k == nil { + return nil + } + now := unixtime.Now() + k.RevokedAt = &now + k.Reason = reason + return fs.Update(kid, k.UpdateablePublicKeyMetadata) + +} + +// Get fetches a key by kid. +func (fs *FilesystemPublicKeyStorage) Get(kid string) (*PublicKeyEntry, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + e, ok := fs.entries[kid] + if !ok { + return nil, nil + } + // Clone the jwk.Key to avoid external mutation + if e.Key.Key != nil { + if k, err := e.Key.Clone(); err == nil { + e.Key.Key = k + } + } + return &e, nil +} + +// snapshot returns a copy of entries as a list. +func (fs *FilesystemPublicKeyStorage) snapshot() PublicKeyEntryList { + out := make(PublicKeyEntryList, 0, len(fs.entries)) + for _, e := range fs.entries { + out = append(out, e) + } + return out +} + +// persist writes the storage to disk. +func (fs *FilesystemPublicKeyStorage) persist() error { + if err := os.MkdirAll(fs.Dir, 0o700); err != nil { + return errors.WithStack(err) + } + data, err := json.Marshal(fs.entries) + if err != nil { + return errors.WithStack(err) + } + return errors.WithStack(os.WriteFile(fs.storageFilePath(), data, 0o600)) +} + +// NewFilesystemPublicKeyStorageFromLegacy creates a new FilesystemPublicKeyStorage and +// populates it from a LegacyPKCollection stored in the legacy files within the given dir. +// This helper is intended to aid migration and does not retain legacy aggregation behavior. +func NewFilesystemPublicKeyStorageFromLegacy(dir, typeID string) (*FilesystemPublicKeyStorage, error) { + fs := &FilesystemPublicKeyStorage{ + Dir: dir, + TypeID: typeID, + } + if err := fs.Load(); err != nil { + return nil, err + } + // Load legacy aggregated storage and import entries for the given typeID + var agg aggregatedPublicKeyStorage + if err := agg.Load(dir); err != nil { + return nil, err + } + coll, ok := agg[typeID] + if !ok || coll == nil { + return fs, nil + } + // Import all JWKS sets and history + importSet := func(set jwk.Set) { + if set == nil || set.Len() == 0 { + return + } + for i := range set.Len() { + k, _ := set.Key(i) + var kid string + _ = k.Get("kid", &kid) + if kid == "" { + continue + } + cloned, cerr := k.Clone() + if cerr != nil { + continue + } + // Extract timing metadata if present + var iat, nbf, exp unixtime.Unixtime + _ = k.Get("iat", &iat) + _ = k.Get("nbf", &nbf) + _ = k.Get("exp", &exp) + entry := PublicKeyEntry{ + KID: kid, + Key: JWKKey{cloned}, + } + if !iat.IsZero() && iat.Unix() != 0 { + entry.IssuedAt = &iat + } + if !nbf.IsZero() && nbf.Unix() != 0 { + entry.NotBefore = &nbf + } + if !exp.IsZero() && exp.Unix() != 0 { + entry.ExpiresAt = &exp + } + // Last write wins for duplicate kids + fs.entries[kid] = entry + } + } + for _, set := range coll.jwks { + importSet(set.Set) + } + if coll.history.Set != nil { + importSet(coll.history.Set) + } + if err := fs.persist(); err != nil { + return nil, err + } + return fs, nil +} + +// NewFilesystemPublicKeyStorageFromStorage creates a new FilesystemPublicKeyStorage +// and populates it from the passed PublicKeyStorage implementation. +func NewFilesystemPublicKeyStorageFromStorage(dir, typeID string, src PublicKeyStorage) ( + *FilesystemPublicKeyStorage, error, +) { + fs := &FilesystemPublicKeyStorage{ + Dir: dir, + TypeID: typeID, + entries: make(map[string]PublicKeyEntry), + } + // Load source if necessary + if err := src.Load(); err != nil { + return nil, err + } + list, err := src.GetAll() + if err != nil { + return nil, err + } + for _, e := range list { + if e.KID == "" && e.Key.Key != nil { + var kid string + _ = e.Key.Get("kid", &kid) + e.KID = kid + } + if e.KID == "" || e.Key.Key == nil { + continue + } + if k, cerr := e.Key.Clone(); cerr == nil { + e.Key.Key = k + } else { + continue + } + fs.entries[e.KID] = e + } + if err = fs.persist(); err != nil { + return nil, err + } + return fs, nil +} diff --git a/jwx/keymanagement/public/legacy.go b/jwx/keymanagement/public/legacy.go new file mode 100644 index 0000000..ed48310 --- /dev/null +++ b/jwx/keymanagement/public/legacy.go @@ -0,0 +1,367 @@ +package public + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + "slices" + "time" + + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/pkg/errors" + "github.com/zachmann/go-utils/fileutils" + + log "github.com/go-oidfed/lib/internal" + + "github.com/go-oidfed/lib/jwx" + "github.com/go-oidfed/lib/unixtime" +) + +type jwksSlice []jwx.JWKS + +var zeroJWKS jwx.JWKS + +// LegacyPKCollection is a collection of public keys, used for signing. +// Deprecated: Only provided for backwards compatibility to provide an easy +// upgrade path from this old implementation to the new one. +type LegacyPKCollection struct { + // jwksSlice stores the public key JWKS; the order matters! + // [0] the current JWKS (currently used for signing) + // [1] the next JWKS (will be used next for signing) + // [2...n] previous JWKS, where n is the oldest + jwks jwksSlice + NumberOfOldKeysKeptInJWKS int + KeepHistory bool + history jwx.JWKS +} + +// LegacyPublicKeyStorage wraps LegacyPKCollection and exposes the PublicKeyStorage +// interface to enable migration to the new FilesystemPublicKeyStorage. +// This should only be used for migrations +type LegacyPublicKeyStorage struct { + Dir string + TypeID string + coll *LegacyPKCollection +} + +func (l *LegacyPublicKeyStorage) Load() error { + var agg aggregatedPublicKeyStorage + if err := agg.Load(l.Dir); err != nil { + return err + } + c, ok := agg[l.TypeID] + if !ok { + l.coll = &LegacyPKCollection{} + return nil + } + l.coll = c + return nil +} + +func (l *LegacyPublicKeyStorage) GetAll() (PublicKeyEntryList, error) { + return l.collectAll(true, true), nil +} + +func (l *LegacyPublicKeyStorage) GetRevoked() (PublicKeyEntryList, error) { + // Legacy format does not track revoked explicitly; return empty + return PublicKeyEntryList{}, nil +} + +func (l *LegacyPublicKeyStorage) GetExpired() (PublicKeyEntryList, error) { + now := time.Now() + out := l.collectAll(false, true) + var expired PublicKeyEntryList + for _, e := range out { + if e.ExpiresAt != nil && !e.ExpiresAt.IsZero() && e.ExpiresAt.Before(now) { + expired = append(expired, e) + } + } + return expired, nil +} + +func (l *LegacyPublicKeyStorage) GetHistorical() (PublicKeyEntryList, error) { + now := time.Now() + out := l.collectAll(false, true) + var hist PublicKeyEntryList + for _, e := range out { + if e.ExpiresAt != nil && !e.ExpiresAt.IsZero() && e.ExpiresAt.Before(now) { + hist = append(hist, e) + } + } + return hist, nil +} + +func (l *LegacyPublicKeyStorage) GetActive() (PublicKeyEntryList, error) { + now := time.Now() + out := l.collectAll(false, false) + var active PublicKeyEntryList + for _, e := range out { + if e.NotBefore != nil && !e.NotBefore.IsZero() && now.Before(e.NotBefore.Time) { + continue + } + if e.ExpiresAt != nil && !e.ExpiresAt.IsZero() && e.ExpiresAt.Before(now) { + continue + } + active = append(active, e) + } + return active, nil +} + +func (l *LegacyPublicKeyStorage) GetValid() (PublicKeyEntryList, error) { + // In legacy, valid = all keys not expired + now := time.Now() + list := l.collectAll(true, false) + var valid PublicKeyEntryList + for _, e := range list { + if e.ExpiresAt != nil && !e.ExpiresAt.IsZero() && e.ExpiresAt.Before(now) { + continue + } + valid = append(valid, e) + } + return valid, nil +} + +func (l *LegacyPublicKeyStorage) Add(key PublicKeyEntry) error { return errors.New("unsupported") } +func (l *LegacyPublicKeyStorage) AddAll(keys []PublicKeyEntry) error { + return errors.New("unsupported") +} +func (l *LegacyPublicKeyStorage) Update(kid string, data UpdateablePublicKeyMetadata) error { + return errors.New("unsupported") +} +func (l *LegacyPublicKeyStorage) Delete(kid string) error { return errors.New("unsupported") } +func (l *LegacyPublicKeyStorage) Revoke(kid, reason string) error { return errors.New("unsupported") } + +func (l *LegacyPublicKeyStorage) Get(kid string) (*PublicKeyEntry, error) { + list := l.collectAll(true, true) + for _, e := range list { + if e.KID == kid { + return &e, nil + } + } + return nil, nil +} + +// collectAll flattens legacy JWKS (current, next, olds) and history into PublicKeyEntryList +func (l *LegacyPublicKeyStorage) collectAll(includeNext, includeOlds bool) PublicKeyEntryList { + if l.coll == nil { + return PublicKeyEntryList{} + } + var sets []jwx.JWKS + if len(l.coll.jwks) > 0 { + // current + sets = append(sets, l.coll.jwks[0]) + } + if includeNext && len(l.coll.jwks) > 1 { + sets = append(sets, l.coll.jwks[1]) + } + if includeOlds && len(l.coll.jwks) > 2 { + sets = append(sets, l.coll.jwks[2:]...) + } + if l.coll.history.Set != nil { + sets = append(sets, l.coll.history) + } + var out PublicKeyEntryList + for _, s := range sets { + for i := range s.Len() { + k, _ := s.Key(i) + var kid string + _ = k.Get("kid", &kid) + if kid == "" { + continue + } + cloned, cerr := k.Clone() + if cerr != nil { + continue + } + var iatF, nbfF, expF float64 + _ = k.Get("iat", &iatF) + _ = k.Get("nbf", &nbfF) + _ = k.Get("exp", &expF) + var iat, nbf, exp *unixtime.Unixtime + if iatF != 0 { + sec, dec := math.Modf(iatF) + iat = &unixtime.Unixtime{Time: time.Unix(int64(sec), int64(dec*(1e9)))} + } + if nbfF != 0 { + sec, dec := math.Modf(nbfF) + nbf = &unixtime.Unixtime{Time: time.Unix(int64(sec), int64(dec*(1e9)))} + } + if expF != 0 { + sec, dec := math.Modf(expF) + exp = &unixtime.Unixtime{Time: time.Unix(int64(sec), int64(dec*(1e9)))} + } + out = append( + out, PublicKeyEntry{ + KID: kid, + Key: JWKKey{cloned}, + IssuedAt: iat, + NotBefore: nbf, + UpdateablePublicKeyMetadata: UpdateablePublicKeyMetadata{ExpiresAt: exp}, + }, + ) + } + } + return out +} + +// MarshalJSON implements the json.Marshaler interface +func (pks LegacyPKCollection) MarshalJSON() ([]byte, error) { + return json.Marshal(pks.jwks) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (pks *LegacyPKCollection) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &pks.jwks) +} + +func (pks *LegacyPKCollection) setCurrentJWKS(current jwx.JWKS) { + if len(pks.jwks) == 0 { + pks.jwks = append(pks.jwks, current) + return + } + pks.jwks[0] = current +} + +func (pks *LegacyPKCollection) addCurrentJWK(current jwk.Key) { + if len(pks.jwks) == 0 { + set := jwx.NewJWKS() + _ = set.AddKey(current) + pks.jwks = jwksSlice{set} + return + } + _ = pks.jwks[0].AddKey(current) +} + +func (pks *LegacyPKCollection) setNextJWKS(next jwx.JWKS) { + if len(pks.jwks) == 0 { + log.Error("error setting next JWKS in LegacyPKCollection: no current JWKS set") + pks.jwks = append(pks.jwks, next) + } + if len(pks.jwks) == 1 { + pks.jwks = append(pks.jwks, next) + return + } + pks.jwks[1] = next +} + +func (pks *LegacyPKCollection) addNextJWK(next jwk.Key) { + if len(pks.jwks) == 0 { + log.Error("error setting next JWKS in LegacyPKCollection: no current JWKS set") + set := jwx.NewJWKS() + _ = set.AddKey(next) + pks.jwks = jwksSlice{set} + } + if len(pks.jwks) == 1 { + set := jwx.NewJWKS() + _ = set.AddKey(next) + pks.jwks = append(pks.jwks, set) + return + } + _ = pks.jwks[1].AddKey(next) +} + +func (pks *LegacyPKCollection) pushOldJWKS(old jwx.JWKS) jwx.JWKS { + l := len(pks.jwks) + if l < 2 { + pks.jwks = append(pks.jwks, old) + return zeroJWKS + } + if l == 2 { + pks.jwks = append(pks.jwks, old) + } else { + pks.jwks = slices.Insert(pks.jwks, 2, old) + } + if l-2 >= pks.NumberOfOldKeysKeptInJWKS { + poped := pks.jwks[len(pks.jwks)-1] + pks.jwks = pks.jwks[:len(pks.jwks)-1] + if pks.KeepHistory { + if pks.history.Set == nil { + pks.history = poped + } else { + for i := range poped.Len() { + k, _ := poped.Key(i) + _ = pks.history.AddKey(k) + } + } + } + return poped + } + return zeroJWKS +} + +// rotate rotates the JWKS, the passed JWKS will be set as the next JWKS, +// the previously next JWKS becomes the current JWKS, the previous current JWKS becomes the first old JWKS, +// and all old JWKS are shifted, while the oldest JWKS ( +// if it exceeds the number of old JWKS kept) is removed from the collection and returned. +func (pks *LegacyPKCollection) rotate(next jwx.JWKS) jwx.JWKS { + if len(pks.jwks) == 0 { + pks.jwks = append(pks.jwks, next) + return zeroJWKS + } + previouslyCurrent := pks.jwks[0] + old := pks.pushOldJWKS(previouslyCurrent) + previouslyNext := pks.jwks[1] + pks.setCurrentJWKS(previouslyNext) + pks.setNextJWKS(next) + return old +} + +type aggregatedPublicKeyStorage map[string]*LegacyPKCollection + +// Load loads the public keys from disk +func (pks *aggregatedPublicKeyStorage) Load(dir string) error { + data, err := fileutils.ReadFile(jwksFilePath(dir)) + if err != nil { + log.Warn(err.Error()) + return nil + } + if len(data) == 0 { + return nil + } + if err = errors.WithStack(json.Unmarshal(data, pks)); err != nil { + return err + } + for typeID, collection := range *pks { + data, err = fileutils.ReadFile(jwksHistoryFilePath(dir, typeID)) + if err != nil { + continue + } + if err = errors.WithStack(json.Unmarshal(data, &collection.history)); err != nil { + return err + } + } + return nil +} + +// Save saves the public keys to disk +func (pks aggregatedPublicKeyStorage) Save(dir string) error { + data, err := json.Marshal(pks) + if err != nil { + return errors.WithStack(err) + } + if err = os.WriteFile(jwksFilePath(dir), data, 0600); err != nil { + return errors.WithStack(err) + } + for typeID, collection := range pks { + if collection.history.Set == nil || collection.history.Len() == 0 { + continue + } + data, err = json.Marshal(collection.history) + if err != nil { + return errors.WithStack(err) + } + if err = os.WriteFile(jwksHistoryFilePath(dir, typeID), data, 0600); err != nil { + return errors.WithStack(err) + } + } + return nil +} + +func jwksFilePath(dir string) string { + return filepath.Join(dir, "keys.jwks") +} +func jwksHistoryFilePath(dir, typeID string) string { + return filepath.Join(dir, fmt.Sprintf("%s_history.jwks", typeID)) +} diff --git a/jwx/keymanagement/public/publicKeyStorage.go b/jwx/keymanagement/public/publicKeyStorage.go new file mode 100644 index 0000000..ebbc745 --- /dev/null +++ b/jwx/keymanagement/public/publicKeyStorage.go @@ -0,0 +1,185 @@ +package public + +import ( + "bytes" + "encoding/json" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/pkg/errors" + + "github.com/go-oidfed/lib/jwx" + "github.com/go-oidfed/lib/unixtime" +) + +// PublicKeyStorage defines operations for storing and retrieving public keys +// and their associated validity and revocation metadata. +type PublicKeyStorage interface { + // Load initializes the PublicKeyStorage and loads the public keys (if necessary) + Load() error + // GetAll returns all keys in the storage, including revoked and expired keys. + GetAll() (PublicKeyEntryList, error) + // GetRevoked returns all revoked keys in the storage. + GetRevoked() (PublicKeyEntryList, error) + // GetExpired returns all expired keys in the storage. + GetExpired() (PublicKeyEntryList, error) + // GetHistorical returns all keys in the storage that can no longer be + // used, i.e. revoked and expired keys. + GetHistorical() (PublicKeyEntryList, error) + // GetActive returns all active keys in the storage, + // i.e., keys that can be used currently, + // i.e., keys that are not revoked and where the current time is between + // nbf and exp. + GetActive() (PublicKeyEntryList, error) + // GetValid returns all valid keys in the storage, + // i.e., keys that can be used currently or in the future, i.e., keys that are not expired or revoked. + GetValid() (PublicKeyEntryList, error) + // Add adds a new key to the storage; the keyID is used to identify the + // key; MUST only add the key if the keyID is not already in use, + // in that case do nothing and MUST NOT return an error. + Add(key PublicKeyEntry) error + // AddAll adds all the passed PublicKeyEntry to the storage + AddAll(key []PublicKeyEntry) error + // Update updates an existing PublicKeyEntry + Update(kid string, data UpdateablePublicKeyMetadata) error + // Delete deletes a PublicKeyEntry + Delete(kid string) error + // Revoke revokes an PublicKeyEntry with the passed reason + Revoke(kid, reason string) error + // Get returns a PublicKeyEntry + Get(kid string) (*PublicKeyEntry, error) +} + +// PublicKeyEntryList is a list of PublicKeyEntry +type PublicKeyEntryList []PublicKeyEntry + +// JWKS converts the list into a JWKS, cloning each JWK and annotating it with +// standard fields (iat, nbf, exp) and optional revocation information. +func (pks PublicKeyEntryList) JWKS() (jwx.JWKS, error) { + jwks := jwx.NewJWKS() + for _, pk := range pks { + k, err := pk.JWK() + if err != nil { + return jwx.JWKS{}, err + } + _ = jwks.AddKey(k) + } + return jwks, nil +} + +// Filter returns a new list containing entries for which the provided filter +// function returns true. +func (pks PublicKeyEntryList) Filter(filter func(entry PublicKeyEntry) bool) ( + filtered PublicKeyEntryList, +) { + for _, pk := range pks { + if filter(pk) { + filtered = append(filtered, pk) + } + } + return +} + +// ByAlg groups entries by their JWK signature algorithm and returns a map +// keyed by jwa.SignatureAlgorithm. +func (pks PublicKeyEntryList) ByAlg() map[jwa.SignatureAlgorithm]PublicKeyEntryList { + m := make(map[jwa.SignatureAlgorithm]PublicKeyEntryList) + for _, pk := range pks { + alg, set := pk.Key.Algorithm() + if !set { + continue + } + signatureAlg, ok := alg.(jwa.SignatureAlgorithm) + if !ok { + continue + } + if _, ok = m[signatureAlg]; !ok { + m[signatureAlg] = make(PublicKeyEntryList, 0) + } + m[signatureAlg] = append(m[signatureAlg], pk) + } + return m +} + +// JWKKey is a wrapper around jwk.Key that implements the json.Unmarshaler +// interface, so that it can be used as a field in a PublicKeyEntry. +type JWKKey struct { + jwk.Key +} + +// MarshalJSON implements the json.Marshaler interface +func (k JWKKey) MarshalJSON() ([]byte, error) { + return json.Marshal((k.Key)) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (k *JWKKey) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, []byte("null")) { + return nil + } + key, err := jwk.ParseKey(data) + if err != nil { + return errors.Wrap(err, "failed to parse jwk") + } + k.Key = key + return nil +} + +// PublicKeyEntry holds a public JWK alongside issuance, validity and revocation +// metadata used to determine whether the key is usable. +type PublicKeyEntry struct { + KID string `gorm:"primaryKey;column:kid" json:"kid"` + Key JWKKey `gorm:"serializer:json" json:"key"` + IssuedAt *unixtime.Unixtime `json:"iat,omitempty"` + NotBefore *unixtime.Unixtime `json:"nbf,omitempty"` + UpdateablePublicKeyMetadata +} + +// UpdateablePublicKeyMetadata contains fields that can be updated after +// creation, such as expiration and revocation information. +type UpdateablePublicKeyMetadata struct { + ExpiresAt *unixtime.Unixtime `json:"exp,omitempty"` + RevokedAt *unixtime.Unixtime `json:"revoked_at,omitempty"` + Reason string `json:"reason,omitempty"` +} + +// JWK returns a cloned jwk.Key annotated with standard JWT fields (iat, nbf, +// exp) and optional revocation information. +func (pk PublicKeyEntry) JWK() (jwk.Key, error) { + key, err := pk.Key.Clone() + if err != nil { + return nil, errors.Wrap(err, "failed to clone key") + } + if iat := pk.IssuedAt; iat != nil && !iat.IsZero() { + err = key.Set("iat", iat) + if err != nil { + return nil, errors.Wrap(err, "failed to set iat") + } + } + if nbf := pk.NotBefore; nbf != nil && !nbf.IsZero() { + err = key.Set("nbf", nbf) + if err != nil { + return nil, errors.Wrap(err, "failed to set nbf") + } + } + if exp := pk.ExpiresAt; exp != nil && !exp.IsZero() { + err = key.Set("exp", exp) + if err != nil { + return nil, errors.Wrap(err, "failed to set exp") + } + } + if rvk := pk.RevokedAt; rvk != nil && !rvk.IsZero() && rvk.Unix() != 0 { + revoked := struct { + RevokedAt unixtime.Unixtime `json:"revoked_at"` + Reason string `json:"reason,omitempty"` + }{ + RevokedAt: *rvk, + Reason: pk.Reason, + } + err = key.Set("revoked", revoked) + if err != nil { + return nil, errors.Wrap(err, "failed to set revoked") + } + } + return key, nil +} diff --git a/jwx/keystorage.go b/jwx/keystorage.go deleted file mode 100644 index e41ba47..0000000 --- a/jwx/keystorage.go +++ /dev/null @@ -1,242 +0,0 @@ -package jwx - -import ( - "crypto" - "sync" - - "github.com/go-oidfed/lib/internal" - "github.com/lestrrat-go/jwx/v3/jwa" - "github.com/pkg/errors" - "github.com/zachmann/go-utils/duration" -) - -// Constants for TypeIDs used in a KeyStorage -const ( - KeyStorageTypeFederation = "federation" - KeyStorageTypeOIDC = "oidc" -) - -// NewKeyStorage creates a new KeyStorage for the passed KeyStorageConfigs at the passed directory -func NewKeyStorage(keyDir string, conf map[string]KeyStorageConfig) (*KeyStorage, error) { - ks := &KeyStorage{ - public: make(aggregatedPublicKeyStorage), - private: make(map[string]privateKeyStorage), - keyDir: keyDir, - } - for t, cfg := range conf { - var sks privateKeyStorage - if cfg.Algorithm != "" { - // single alg - alg, ok := jwa.LookupSignatureAlgorithm(cfg.Algorithm) - if !ok { - return nil, errors.Errorf("unknown algorithm '%s'", cfg.Algorithm) - } - sks = &privateKeyStorageSingleAlg{ - typeID: t, - alg: alg, - rollover: cfg.RolloverConf, - rsaKeyLen: cfg.RSAKeyLen, - keyDir: keyDir, - } - } else { - // multi alg - var algs []jwa.SignatureAlgorithm - if len(algs) == 0 { - algs = supportedAlgs - } else { - for _, a := range cfg.Algorithms { - alg, ok := jwa.LookupSignatureAlgorithm(a) - if !ok { - return nil, errors.Errorf("unknown algorithm '%s'", a) - } - algs = append(algs, alg) - } - } - sksma := &privateKeyStorageMultiAlg{ - typeID: t, - algs: algs, - rollover: cfg.RolloverConf, - rsaKeyLen: cfg.RSAKeyLen, - keyDir: keyDir, - } - if a := cfg.DefaultAlgorithm; a != "" { - defaultAlg, ok := jwa.LookupSignatureAlgorithm(a) - if !ok { - return nil, errors.Errorf("unknown default algorithm '%s'", a) - } - sksma.defaultAlg = defaultAlg - } - sks = sksma - } - ks.private[t] = sks - ks.public[t] = &pkCollection{ - NumberOfOldKeysKeptInJWKS: cfg.RolloverConf.NumberOfOldKeysKeptInJWKS, - KeepHistory: cfg.RolloverConf.KeepHistory, - } - } - return ks, nil -} - -// KeyStorage manages public and private signing keys for multiple typeIds (e.g. federation and oidc), -// it handles loading and writing keys to disk and can also handle key rotation. -type KeyStorage struct { - public aggregatedPublicKeyStorage - private map[string]privateKeyStorage - keyDir string -} - -// KeyStorageConfig is a type holding the configuration for keys for a protocol. -// If Algorithm is set, this implies that a single singing algorithm is supported, -// otherwise multiple algorithms are supported, even if Algorithms is not set ( -// since in that case all supported algorithms should be supported) -type KeyStorageConfig struct { - Algorithm string `yaml:"alg"` - Algorithms []string `yaml:"algs"` - DefaultAlgorithm string `yaml:"default_alg"` - RSAKeyLen int `yaml:"rsa_key_len"` - RolloverConf RolloverConf `yaml:"automatic_key_rollover"` -} - -// RolloverConf is a type holding configuration for key rollover / key rotation -type RolloverConf struct { - Enabled bool `yaml:"enabled"` - Interval duration.DurationOption `yaml:"interval"` - NumberOfOldKeysKeptInJWKS int `yaml:"old_keys_kept_in_jwks"` - KeepHistory bool `yaml:"keep_history"` -} - -// JWKS returns the jwks.JWKS containing all public keys for the passed storageType -func (ks KeyStorage) JWKS(storageType string) JWKS { - sets, ok := ks.public[storageType] - if !ok || sets == nil || len(sets.jwks) == 0 { - return JWKS{} - } - final := NewJWKS() - for _, set := range sets.jwks { - for i := range set.Len() { - k, _ := set.Key(i) - if err := final.AddKey(k); err != nil { - internal.Error(err.Error()) - } - } - } - return final -} - -// History returns the jwks history for the passed storageType -func (ks KeyStorage) History(storageType string) JWKS { - pks, ok := ks.public[storageType] - if !ok { - return zeroJWKS - } - return pks.history -} - -// Signer takes a list of acceptable signature algorithms and returns a -// usable crypto.Signer or nil as well as the corresponding -// jwa.SignatureAlgorithm for the passed storageType -func (ks KeyStorage) Signer(storageType string, algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { - sks, ok := ks.private[storageType] - if !ok { - return nil, jwa.SignatureAlgorithm{} - } - return sks.GetForAlgs(algs...) -} - -// DefaultSigner returns a crypto.Signer and the corresponding jwa.SignatureAlgorithm for the passed storageType -func (ks KeyStorage) DefaultSigner(storageType string) (crypto.Signer, jwa.SignatureAlgorithm) { - sks, ok := ks.private[storageType] - if !ok { - return nil, jwa.SignatureAlgorithm{} - } - return sks.GetDefault() -} - -// FederationJWKS returns the jwks.JWKS containing all public keys for the KeyStorageTypeFederation storageTypeID -func (ks KeyStorage) FederationJWKS() JWKS { - return ks.JWKS(KeyStorageTypeFederation) -} - -// OIDCJWKS returns the jwks.JWKS containing all public keys for the KeyStorageTypeOIDC storageTypeID -func (ks KeyStorage) OIDCJWKS() JWKS { - return ks.JWKS(KeyStorageTypeOIDC) -} - -// FederationSigner returns the crypto.Signer and the corresponding jwa.SignatureAlgorithm -// for the KeyStorageTypeFederation storageTypeID -func (ks KeyStorage) FederationSigner() (crypto.Signer, jwa.SignatureAlgorithm) { - return ks.DefaultSigner(KeyStorageTypeFederation) -} - -// Load loads the KeyStorage from disk and if enabled schedules key rotation. -func (ks *KeyStorage) Load() error { - if err := ks.public.Load(ks.keyDir); err != nil { - return err - } - - var mutex sync.Mutex - - for typeID, sks := range ks.private { - pks, found := ks.public[typeID] - if !found { - pks = &pkCollection{} - ks.public[typeID] = pks - } - if err := sks.Load( - pks, func() error { - mutex.Lock() - defer mutex.Unlock() - return ks.Save() - }, - ); err != nil { - return err - } - } - return nil -} - -// Save saves the KeyStorage to disk -func (ks KeyStorage) Save() error { - return ks.public.Save(ks.keyDir) -} - -// SubStorage returns a VersatileSigner for the passed storageTypeID -func (ks *KeyStorage) SubStorage(typeID string) VersatileSigner { - return substorage{ - ks: ks, - typeID: typeID, - } -} - -// Federation returns a VersatileSigner for the KeyStorageTypeFederation -func (ks *KeyStorage) Federation() VersatileSigner { - return ks.SubStorage(KeyStorageTypeFederation) -} - -// OIDC returns a VersatileSigner for the KeyStorageTypeOIDC -func (ks *KeyStorage) OIDC() VersatileSigner { - return ks.SubStorage(KeyStorageTypeOIDC) -} - -// substorage is a type related to a KeyStorage and implements the VersatileSigner interface for a storageTypeID -type substorage struct { - ks *KeyStorage - typeID string -} - -// Signer takes a list of acceptable signature algorithms and returns a -// usable crypto.Signer or nil as well as the corresponding -// jwa.SignatureAlgorithm -func (s substorage) Signer(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { - return s.ks.Signer(s.typeID, algs...) -} - -// DefaultSigner returns a crypto.Signer and the corresponding jwa.SignatureAlgorithm -func (s substorage) DefaultSigner() (crypto.Signer, jwa.SignatureAlgorithm) { - return s.ks.DefaultSigner(s.typeID) -} - -// JWKS returns the jwks.JWKS containing all public keys of this VersatileSigner -func (s substorage) JWKS() JWKS { - return s.ks.JWKS(s.typeID) -} diff --git a/jwx/privateKeyStorage.go b/jwx/privateKeyStorage.go deleted file mode 100644 index b17a7dc..0000000 --- a/jwx/privateKeyStorage.go +++ /dev/null @@ -1,38 +0,0 @@ -package jwx - -import ( - "crypto" - - "github.com/lestrrat-go/jwx/v3/jwa" -) - -type privateKeyStorage interface { - Load(pks *pkCollection, pksOnChange func() error) error - GetForAlgs(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) - GetDefault() (crypto.Signer, jwa.SignatureAlgorithm) - GenerateNewKeys(pks *pkCollection, pksOnChange func() error) error - initKeyRotation(pks *pkCollection, pksOnChange func() error) -} - -func generateStoreAndSetNextPrivateKey( - pks *pkCollection, alg jwa.SignatureAlgorithm, rsaKeyLen int, lifetimeConf keyLifetimeConf, filePath string, - newPKSet bool, -) (crypto.Signer, error) { - skFuture, pkFuture, err := generateKeyPair( - alg, rsaKeyLen, lifetimeConf, - ) - if err != nil { - return nil, err - } - if err = writeSignerToFile(skFuture, filePath); err != nil { - return nil, err - } - if newPKSet { - pkSet := NewJWKS() - _ = pkSet.AddKey(pkFuture) - pks.setNextJWKS(pkSet) - } else { - pks.addNextJWK(pkFuture) - } - return skFuture, nil -} diff --git a/jwx/privateKeyStorageMultiAlg.go b/jwx/privateKeyStorageMultiAlg.go deleted file mode 100644 index e16df32..0000000 --- a/jwx/privateKeyStorageMultiAlg.go +++ /dev/null @@ -1,206 +0,0 @@ -package jwx - -import ( - "crypto" - "fmt" - "os" - "slices" - "time" - - "github.com/go-oidfed/lib/internal" - "github.com/lestrrat-go/jwx/v3/jwa" - "github.com/pkg/errors" - "github.com/zachmann/go-utils/fileutils" - - "github.com/go-oidfed/lib/unixtime" -) - -type privateKeyStorageMultiAlg struct { - typeID string - signers map[jwa.SignatureAlgorithm]crypto.Signer - algs []jwa.SignatureAlgorithm - defaultAlg jwa.SignatureAlgorithm - keyDir string - rsaKeyLen int - rollover RolloverConf -} - -func (sks privateKeyStorageMultiAlg) keyFilePath(alg jwa.SignatureAlgorithm, future bool) string { - var f string - if future { - f = "f" - } - return fmt.Sprintf("%s/%s_%s%s.pem", sks.keyDir, sks.typeID, alg.String(), f) -} - -// GetDefault returns a crypto.Signer and the corresponding jwa.SignatureAlgorithm -func (sks privateKeyStorageMultiAlg) GetDefault() (crypto.Signer, jwa.SignatureAlgorithm) { - if len(sks.algs) == 0 { - return nil, jwa.SignatureAlgorithm{} - } - defaultAlg := sks.defaultAlg - if defaultAlg.String() == "" { - defaultAlg = sks.algs[0] - } - return sks.signers[defaultAlg], defaultAlg -} - -// GetForAlgs takes a list of acceptable signature algorithms and returns a -// usable crypto.Signer or nil as well as the corresponding -// jwa.SignatureAlgorithm -func (sks privateKeyStorageMultiAlg) GetForAlgs(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { - for _, alg := range sks.algs { - if slices.Contains(algs, alg.String()) { - return sks.signers[alg], alg - } - } - return nil, jwa.SignatureAlgorithm{} -} - -func (sks *privateKeyStorageMultiAlg) initKeyRotation(pks *pkCollection, pksOnChange func() error) { - if !sks.rollover.Enabled { - return - } - go func() { - for { - sleepDuration := time.Until(pks.jwks[0].MinimalExpirationTime().Time.Add(-5 * time.Second)) - if sleepDuration > 0 { - time.Sleep(sleepDuration) - } - if err := sks.GenerateNewKeys(pks, pksOnChange); err != nil { - internal.Error(err) - } - } - }() -} - -// Load loads the private keys from disk and if necessary generates missing keys -func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func() error) error { - addPublicKeysToJWKS := false - if sks.signers == nil { - sks.signers = make(map[jwa.SignatureAlgorithm]crypto.Signer) - } - if len(pks.jwks) == 0 { - pks.jwks = []JWKS{NewJWKS()} - addPublicKeysToJWKS = true - } - pksChanged := false - - for _, alg := range sks.algs { - signer, changed, err := sks.loadOrGenerateSigner(alg, pks, addPublicKeysToJWKS) - if err != nil { - return err - } - pksChanged = pksChanged || changed - sks.signers[alg] = signer - - // Ensure the next key file exists for rollover - if !fileutils.FileExists(sks.keyFilePath(alg, true)) { - _, err = generateStoreAndSetNextPrivateKey( - pks, alg, sks.rsaKeyLen, keyLifetimeConf{ - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - Nbf: &unixtime.Unixtime{Time: pks.jwks[0].MinimalExpirationTime().Add(-10 * time.Second)}, - }, sks.keyFilePath(alg, true), false, - ) - pksChanged = true - if err != nil { - return err - } - } - } - - if addPublicKeysToJWKS || pksChanged { - if err := pksOnChange(); err != nil { - return err - } - } - sks.initKeyRotation(pks, pksOnChange) - return nil -} - -// loadOrGenerateSigner loads a signer from disk or generates a new one if it doesn't exist. -// If addPublicKeysToJWKS is true, it also adds the public key to the pkCollection. -func (sks *privateKeyStorageMultiAlg) loadOrGenerateSigner( - alg jwa.SignatureAlgorithm, pks *pkCollection, addPublicKeysToJWKS bool, -) (crypto.Signer, bool, error) { - filePath := sks.keyFilePath(alg, false) - signer, err := readSignerFromFile(filePath, alg) - if err != nil { - // Could not load key, generating a new one for this alg - sk, pk, err := generateKeyPair( - alg, - sks.rsaKeyLen, - keyLifetimeConf{ - NowIssued: true, - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - }, - ) - if err != nil { - return nil, false, err - } - if err = writeSignerToFile(sk, filePath); err != nil { - return nil, false, err - } - pks.addCurrentJWK(pk) - return sk, true, nil - } - if addPublicKeysToJWKS { - pk, err := signerToPublicJWK( - signer, alg, keyLifetimeConf{ - NowIssued: false, - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - }, - ) - if err != nil { - return nil, false, err - } - pks.addCurrentJWK(pk) - } - return signer, addPublicKeysToJWKS, nil -} - -// GenerateNewKeys generates a new set of keys -func (sks *privateKeyStorageMultiAlg) GenerateNewKeys(pks *pkCollection, pksOnChange func() error) error { - futureKeys := NewJWKS() - for _, alg := range sks.algs { - skNext, err := readSignerFromFile(sks.keyFilePath(alg, true), alg) - if err != nil { - // if the next sk file does not yet exist, generate it - skNext, err = generateStoreAndSetNextPrivateKey( - pks, alg, sks.rsaKeyLen, keyLifetimeConf{ - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - }, sks.keyFilePath(alg, true), false, - ) - if err != nil { - return err - } - } - sks.signers[alg] = skNext - if err = errors.WithStack(os.Rename(sks.keyFilePath(alg, true), sks.keyFilePath(alg, false))); err != nil { - return err - } - - sk, pk, err := generateKeyPair( - alg, sks.rsaKeyLen, keyLifetimeConf{ - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - Nbf: &unixtime.Unixtime{Time: pks.jwks[1].MinimalExpirationTime().Add(-10 * time.Second)}, - }, - ) - if err != nil { - return err - } - if err = writeSignerToFile(sk, sks.keyFilePath(alg, false)); err != nil { - return err - } - if err = futureKeys.AddKey(pk); err != nil { - return errors.WithStack(err) - } - } - pks.rotate(futureKeys) - return pksOnChange() -} diff --git a/jwx/privateKeyStorageSingleAlg.go b/jwx/privateKeyStorageSingleAlg.go deleted file mode 100644 index 4e27d32..0000000 --- a/jwx/privateKeyStorageSingleAlg.go +++ /dev/null @@ -1,163 +0,0 @@ -package jwx - -import ( - "crypto" - "fmt" - "os" - "slices" - "time" - - "github.com/go-oidfed/lib/internal" - "github.com/lestrrat-go/jwx/v3/jwa" - "github.com/pkg/errors" - "github.com/zachmann/go-utils/fileutils" - - "github.com/go-oidfed/lib/unixtime" -) - -type privateKeyStorageSingleAlg struct { - typeID string - signer crypto.Signer - alg jwa.SignatureAlgorithm - keyDir string - rsaKeyLen int - rollover RolloverConf -} - -func (sks privateKeyStorageSingleAlg) keyFilePath(future bool) string { - var f string - if future { - f = "f" - } - return fmt.Sprintf("%s/%s_%s%s.pem", sks.keyDir, sks.typeID, sks.alg.String(), f) -} - -// GetDefault returns a crypto.Signer and the corresponding jwa.SignatureAlgorithm -func (sks privateKeyStorageSingleAlg) GetDefault() (crypto.Signer, jwa.SignatureAlgorithm) { - return sks.signer, sks.alg -} - -// GetForAlgs takes a list of acceptable signature algorithms and returns a -// usable crypto.Signer or nil as well as the corresponding -// jwa.SignatureAlgorithm -func (sks privateKeyStorageSingleAlg) GetForAlgs(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { - if slices.Contains(algs, sks.alg.String()) { - return sks.GetDefault() - } - return nil, jwa.SignatureAlgorithm{} - -} - -func (sks *privateKeyStorageSingleAlg) initKeyRotation(pks *pkCollection, pksOnChange func() error) { - if !sks.rollover.Enabled { - return - } - go func() { - for { - sleepDuration := time.Until(pks.jwks[0].MinimalExpirationTime().Time.Add(-5 * time.Second)) - if sleepDuration > 0 { - time.Sleep(sleepDuration) - } - if err := sks.GenerateNewKeys(pks, pksOnChange); err != nil { - internal.Error(err) - } - } - }() -} - -// Load loads the key from disk or generates a new one if the key does not exist on disk -func (sks *privateKeyStorageSingleAlg) Load(pks *pkCollection, pksOnChange func() error) error { - signer, err := readSignerFromFile(sks.keyFilePath(false), sks.alg) - if err != nil { - internal.Warn(err) - if err = sks.GenerateNewKeys(pks, pksOnChange); err != nil { - return err - } - sks.initKeyRotation(pks, pksOnChange) - return nil - } - sks.signer = signer - - if len(pks.jwks) == 0 { - // This is only for the case that there is no keys.jwks yet, - // but there was a private key file. - set := NewJWKS() - pk, err := signerToPublicJWK( - signer, sks.alg, keyLifetimeConf{ - NowIssued: false, - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - }, - ) - if err != nil { - return err - } - if err = set.AddKey(pk); err != nil { - return errors.WithStack(err) - } - pks.jwks = []JWKS{set} - if err = pksOnChange(); err != nil { - return err - } - } - - if !fileutils.FileExists(sks.keyFilePath(true)) { - _, err = generateStoreAndSetNextPrivateKey( - pks, sks.alg, sks.rsaKeyLen, keyLifetimeConf{ - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - Nbf: &unixtime.Unixtime{Time: pks.jwks[0].MinimalExpirationTime().Add(-10 * time.Second)}, - }, sks.keyFilePath(true), true, - ) - if err != nil { - return err - } - if err = pksOnChange(); err != nil { - return err - } - } - - sks.initKeyRotation(pks, pksOnChange) - return nil -} - -// GenerateNewKeys generates a new key -func (sks *privateKeyStorageSingleAlg) GenerateNewKeys(pks *pkCollection, pksOnChange func() error) error { - skNext, err := readSignerFromFile(sks.keyFilePath(true), sks.alg) - if err != nil { - skNext, err = generateStoreAndSetNextPrivateKey( - pks, sks.alg, sks.rsaKeyLen, keyLifetimeConf{ - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - }, sks.keyFilePath(true), true, - ) - if err != nil { - return err - } - } - - sks.signer = skNext - if err = errors.WithStack(os.Rename(sks.keyFilePath(true), sks.keyFilePath(false))); err != nil { - return err - } - - newKeys := NewJWKS() - sk, pk, err := generateKeyPair( - sks.alg, sks.rsaKeyLen, keyLifetimeConf{ - Expires: sks.rollover.Enabled, - Lifetime: sks.rollover.Interval.Duration(), - Nbf: &unixtime.Unixtime{Time: pks.jwks[1].MinimalExpirationTime().Add(-10 * time.Second)}, - }, - ) - if err != nil { - return err - } - if err = writeSignerToFile(sk, sks.keyFilePath(true)); err != nil { - return err - } - if err = newKeys.AddKey(pk); err != nil { - return errors.WithStack(err) - } - pks.rotate(newKeys) - return pksOnChange() -} diff --git a/jwx/publicKeyStorage.go b/jwx/publicKeyStorage.go deleted file mode 100644 index c151aea..0000000 --- a/jwx/publicKeyStorage.go +++ /dev/null @@ -1,187 +0,0 @@ -package jwx - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - "slices" - - "github.com/go-oidfed/lib/internal" - "github.com/lestrrat-go/jwx/v3/jwk" - "github.com/pkg/errors" - "github.com/zachmann/go-utils/fileutils" -) - -type jwksSlice []JWKS - -type pkCollection struct { - // jwksSlice stores the public key JWKS; the order matters! - // [0] the current JWKS (currently used for signing) - // [1] the next JWKS (will be used next for signing) - // [2...n] previous JWKS, where n is the oldest - jwks jwksSlice - NumberOfOldKeysKeptInJWKS int - KeepHistory bool - history JWKS -} - -// MarshalJSON implements the json.Marshaler interface -func (pks pkCollection) MarshalJSON() ([]byte, error) { - return json.Marshal(pks.jwks) -} - -// UnmarshalJSON implements the json.Unmarshaler interface -func (pks *pkCollection) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &pks.jwks) -} - -func (pks *pkCollection) setCurrentJWKS(current JWKS) { - if len(pks.jwks) == 0 { - pks.jwks = append(pks.jwks, current) - return - } - pks.jwks[0] = current -} - -func (pks *pkCollection) addCurrentJWK(current jwk.Key) { - if len(pks.jwks) == 0 { - set := NewJWKS() - _ = set.AddKey(current) - pks.jwks = jwksSlice{set} - return - } - _ = pks.jwks[0].AddKey(current) -} - -func (pks *pkCollection) setNextJWKS(next JWKS) { - if len(pks.jwks) == 0 { - internal.Error("jwx: error setting next JWKS in pkCollection: no current JWKS set") - pks.jwks = append(pks.jwks, next) - } - if len(pks.jwks) == 1 { - pks.jwks = append(pks.jwks, next) - return - } - pks.jwks[1] = next -} - -func (pks *pkCollection) addNextJWK(next jwk.Key) { - if len(pks.jwks) == 0 { - internal.Error("jwx: error setting next JWKS in pkCollection: no current JWKS set") - set := NewJWKS() - _ = set.AddKey(next) - pks.jwks = jwksSlice{set} - } - if len(pks.jwks) == 1 { - set := NewJWKS() - _ = set.AddKey(next) - pks.jwks = append(pks.jwks, set) - return - } - _ = pks.jwks[1].AddKey(next) -} - -func (pks *pkCollection) pushOldJWKS(old JWKS) JWKS { - l := len(pks.jwks) - if l < 2 { - pks.jwks = append(pks.jwks, old) - return zeroJWKS - } - if l == 2 { - pks.jwks = append(pks.jwks, old) - } else { - pks.jwks = slices.Insert(pks.jwks, 2, old) - } - if l-2 >= pks.NumberOfOldKeysKeptInJWKS { - poped := pks.jwks[len(pks.jwks)-1] - pks.jwks = pks.jwks[:len(pks.jwks)-1] - if pks.KeepHistory { - if pks.history.Set == nil { - pks.history = poped - } else { - for i := range poped.Len() { - k, _ := poped.Key(i) - _ = pks.history.AddKey(k) - } - } - } - return poped - } - return zeroJWKS -} - -// rotate rotates the JWKS, the passed JWKS will be set as the next JWKS, -// the previously next JWKS becomes the current JWKS, the previous current JWKS becomes the first old JWKS, -// and all old JWKS are shifted, while the oldest JWKS ( -// if it exceeds the number of old JWKS kept) is removed from the collection and returned. -func (pks *pkCollection) rotate(next JWKS) JWKS { - if len(pks.jwks) == 0 { - pks.jwks = append(pks.jwks, next) - return zeroJWKS - } - previouslyCurrent := pks.jwks[0] - old := pks.pushOldJWKS(previouslyCurrent) - previouslyNext := pks.jwks[1] - pks.setCurrentJWKS(previouslyNext) - pks.setNextJWKS(next) - return old -} - -type aggregatedPublicKeyStorage map[string]*pkCollection - -// Load loads the public keys from disk -func (pks *aggregatedPublicKeyStorage) Load(dir string) error { - data, err := fileutils.ReadFile(jwksFilePath(dir)) - if err != nil { - internal.Warn(err.Error()) - return nil - } - if len(data) == 0 { - return nil - } - if err = errors.WithStack(json.Unmarshal(data, pks)); err != nil { - return err - } - for typeID, collection := range *pks { - data, err = fileutils.ReadFile(jwksHistoryFilePath(dir, typeID)) - if err != nil { - continue - } - if err = errors.WithStack(json.Unmarshal(data, &collection.history)); err != nil { - return err - } - } - return nil -} - -// Save saves the public keys to disk -func (pks aggregatedPublicKeyStorage) Save(dir string) error { - data, err := json.Marshal(pks) - if err != nil { - return errors.WithStack(err) - } - if err = os.WriteFile(jwksFilePath(dir), data, 0600); err != nil { - return errors.WithStack(err) - } - for typeID, collection := range pks { - if collection.history.Set == nil || collection.history.Len() == 0 { - continue - } - data, err = json.Marshal(collection.history) - if err != nil { - return errors.WithStack(err) - } - if err = os.WriteFile(jwksHistoryFilePath(dir, typeID), data, 0600); err != nil { - return errors.WithStack(err) - } - } - return nil -} - -func jwksFilePath(dir string) string { - return filepath.Join(dir, "keys.jwks") -} -func jwksHistoryFilePath(dir, typeID string) string { - return filepath.Join(dir, fmt.Sprintf("%s_history.jwks", typeID)) -} diff --git a/jwx/singleKey.go b/jwx/singleKey.go index ec7a719..186d778 100644 --- a/jwx/singleKey.go +++ b/jwx/singleKey.go @@ -7,16 +7,16 @@ import ( "github.com/lestrrat-go/jwx/v3/jwa" ) -// SingleKeyStorage is a type implementing the oidfed.VersatileSigner interface but only +// SingleKeySigner is a type implementing the oidfed.VersatileSigner interface but only // uses a single key / alg -type SingleKeyStorage struct { +type SingleKeySigner struct { sk crypto.Signer alg jwa.SignatureAlgorithm } -// NewSingleKeyVersatileSigner creates a new SingleKeyStorage -func NewSingleKeyVersatileSigner(sk crypto.Signer, alg jwa.SignatureAlgorithm) SingleKeyStorage { - return SingleKeyStorage{ +// NewSingleKeyVersatileSigner creates a new SingleKeySigner +func NewSingleKeyVersatileSigner(sk crypto.Signer, alg jwa.SignatureAlgorithm) SingleKeySigner { + return SingleKeySigner{ sk: sk, alg: alg, } @@ -25,7 +25,7 @@ func NewSingleKeyVersatileSigner(sk crypto.Signer, alg jwa.SignatureAlgorithm) S // Signer takes a list of acceptable signature algorithms and returns a // usable crypto.Signer or nil as well as the corresponding // jwa.SignatureAlgorithm -func (s SingleKeyStorage) Signer(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { +func (s SingleKeySigner) Signer(algs ...string) (crypto.Signer, jwa.SignatureAlgorithm) { if slices.Contains(algs, s.alg.String()) { return s.sk, s.alg } @@ -33,12 +33,11 @@ func (s SingleKeyStorage) Signer(algs ...string) (crypto.Signer, jwa.SignatureAl } // DefaultSigner returns a crypto.Signer and the corresponding jwa.SignatureAlgorithm -func (s SingleKeyStorage) DefaultSigner() (crypto.Signer, jwa.SignatureAlgorithm) { +func (s SingleKeySigner) DefaultSigner() (crypto.Signer, jwa.SignatureAlgorithm) { return s.sk, s.alg } // JWKS returns the jwks.JWKS containing all public keys of this VersatileSigner -func (s SingleKeyStorage) JWKS() JWKS { - jwks, _ := KeyToJWKS(s.sk.Public(), jwa.ES512()) - return jwks +func (s SingleKeySigner) JWKS() (JWKS, error) { + return KeyToJWKS(s.sk.Public(), jwa.ES512()) } diff --git a/mock_tm.go b/mock_tm.go index f8fc67e..0c543c1 100644 --- a/mock_tm.go +++ b/mock_tm.go @@ -70,9 +70,13 @@ func newMockTrustMarkIssuer(entityID string, trustMarkSpecs []TrustMarkSpec) *mo jwx.NewSingleKeyVersatileSigner(sk, jwa.ES512()), ), trustMarkSpecs, ) + jwks, err := tmi.JWKS() + if err != nil { + panic(err) + } mock := &mockTMI{ TrustMarkIssuer: *tmi, - jwks: tmi.JWKS(), + jwks: jwks, } mockEntityConfiguration(mock.EntityID, mock) return mock diff --git a/periodic_entity_collector.go b/periodic_entity_collector.go index 9a8f91d..e468b20 100644 --- a/periodic_entity_collector.go +++ b/periodic_entity_collector.go @@ -196,7 +196,7 @@ func (p *PeriodicEntityCollector) CollectEntities(req apimodel.EntityCollectionR NextEntityID: nextEntityID, } if err = cache.Set(cacheRequestKey, res, p.Interval); err != nil { - internal.Errorf("PeriodicEntityCollector cache set error: %v", err) + internal.WithError(err).Error("PeriodicEntityCollector cache set error") } return &res, nil } @@ -227,7 +227,7 @@ func preparePaginatedResponses( } cacheRequestKey := cache.Key(periodicCacheSubsystem, cacheSubSubSystemRequests, reqHash) if err = cache.Set(cacheRequestKey, res, interval); err != nil { - internal.Errorf("PeriodicEntityCollector cache set error: %v", err) + internal.WithError(err).Error("PeriodicEntityCollector cache set error") } entities = others } @@ -251,7 +251,7 @@ func (p *PeriodicEntityCollector) runOnce() { defer p.cacheMutex.Unlock() if err := cache.Clear(periodicCacheSubsystem); err != nil { - internal.Errorf("PeriodicEntityCollector cache clear error: %v", err) + internal.WithError(err).Error("PeriodicEntityCollector cache clear error") } // Worker pool pattern with a buffered semaphore channel. @@ -287,7 +287,7 @@ func (p *PeriodicEntityCollector) runOnce() { }, p.Interval, ); err != nil { - internal.Errorf("PeriodicEntityCollector cache set error: %v", err) + internal.WithError(err).Error("PeriodicEntityCollector cache set error") } // Notify handler for proactive resolve generation. diff --git a/trustmark_test.go b/trustmark_test.go index f907115..3c9f490 100644 --- a/trustmark_test.go +++ b/trustmark_test.go @@ -84,16 +84,18 @@ var tmo = newMockTrustMarkOwner( }, ) +var tmoJWKS, _ = tmo.JWKS() + var taWithTmo = newMockAuthority( "https://trustmark.ta.com", EntityStatementPayload{ TrustMarkOwners: map[string]TrustMarkOwnerSpec{ "https://trustmarks.org/tm-delegated": { ID: "https://tmo.example.eu", - JWKS: tmo.JWKS(), + JWKS: tmoJWKS, }, "https://trustmarks.org/test": { ID: "https://tmo.example.eu", - JWKS: tmo.JWKS(), + JWKS: tmoJWKS, }, "https://trustmarks.org/other": { ID: "https://other.owner.org", @@ -223,7 +225,12 @@ func TestTrustMarkOwner_DelegationJWT(t *testing.T) { return } } - if err = delegation.VerifyExternal(tmo.JWKS()); err != nil { + jwks, err := tmo.JWKS() + if err != nil { + t.Errorf("could not get JWKS: %v", err) + return + } + if err = delegation.VerifyExternal(jwks); err != nil { t.Errorf("error verifying issued delegation jwt: %v", err) return } @@ -238,9 +245,9 @@ func TestDelegationJWT_VerifyExternal(t *testing.T) { []byte(`{"keys":[{"alg":"ES512","crv":"P-521","kid":"bjQ4ZO1kfWr-cxi-_tU9bKTWwG6XoUwnSW6M5food_U","kty":"EC","use":"sig","x":"AKj5_1MgsEFKCSNN4UyDqQP2wanr9ZD1Q1eBUGJ1BJej8MTQnRkDPRY_35Ctae8bxoj2fxZMufXnWAuVxERelwzL","y":"AObqfUE1k0YIlO1qe-5D8CcTWxZn6OIXC3s_cPrug69sM580aCtug7vEdaBcfNY8RGTwUV1hMxqvOTsQsROrrXG2"}]}`), &correctJWKS, ); err != nil { - t.Error(err) + panic(err) } - wrongKey := tmo.JWKS() + wrongKey := tmoJWKS tests := []struct { name string jwks jwx.JWKS diff --git a/unixtime/unixtime.go b/unixtime/unixtime.go index 03fae07..b87cf72 100644 --- a/unixtime/unixtime.go +++ b/unixtime/unixtime.go @@ -1,6 +1,7 @@ package unixtime import ( + "database/sql/driver" "encoding/json" "math" "time" @@ -14,6 +15,29 @@ type Unixtime struct { time.Time } +// Scan implements the sql.Scanner interface. +func (u *Unixtime) Scan(src any) error { + if src == nil { + u.Time = time.Time{} + return nil + } + t, ok := src.(time.Time) + if !ok { + return errors.Errorf("cannot scan Unixtime from %T (expected time.Time)", src) + } + u.Time = t + return nil +} + +// Value implements the driver.Valuer interface. +func (u Unixtime) Value() (driver.Value, error) { + // Delegate to time.Time driver handling; use NULL for zero value. + if u.IsZero() { + return nil, nil + } + return u.Time, nil +} + // UnmarshalJSON implements the json.Unmarshaler interface. func (u *Unixtime) UnmarshalJSON(src []byte) error { var f float64