diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8e0ef76 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +# IDEs +.idea +.vscode \ No newline at end of file diff --git a/jwe/jwe_config.go b/jwe/jwe_config.go index 7250856..bd18741 100644 --- a/jwe/jwe_config.go +++ b/jwe/jwe_config.go @@ -6,15 +6,16 @@ import ( ) type JWEConfig struct { - decryptionKey *rsa.PrivateKey - certificate *x509.Certificate - encryptedValueFieldName string - encryptionKey *rsa.PublicKey - cek []byte - iv []byte - encryptionPaths map[string]string - decryptionPaths map[string]string - encryptionKeyFingerprint string + decryptionKey *rsa.PrivateKey + certificate *x509.Certificate + encryptedValueFieldName string + encryptionKey *rsa.PublicKey + cek []byte + iv []byte + encryptionPaths map[string]string + decryptionPaths map[string]string + encryptionKeyFingerprint string + enableAuthTagVerification bool } func (config *JWEConfig) GetDecryptionKey() *rsa.PrivateKey { diff --git a/jwe/jwe_config_builder.go b/jwe/jwe_config_builder.go index 0aedf9a..dee537d 100644 --- a/jwe/jwe_config_builder.go +++ b/jwe/jwe_config_builder.go @@ -8,15 +8,16 @@ import ( ) type JWEConfigBuilder struct { - decryptionKey *rsa.PrivateKey - certificate *x509.Certificate - encryptedValueFieldName string - encryptionKey *rsa.PublicKey - cek []byte - iv []byte - encryptionPaths map[string]string - decryptionPaths map[string]string - encryptionKeyFingerprint string + decryptionKey *rsa.PrivateKey + certificate *x509.Certificate + encryptedValueFieldName string + encryptionKey *rsa.PublicKey + cek []byte + iv []byte + encryptionPaths map[string]string + decryptionPaths map[string]string + encryptionKeyFingerprint string + enableAuthTagVerification bool } func NewJWEConfigBuilder() *JWEConfigBuilder { @@ -74,6 +75,11 @@ func (cb *JWEConfigBuilder) WithEncryptedValueFieldName(encryptedValueFieldName return cb } +func (cb *JWEConfigBuilder) WithAuthTagVerificationEnabled(enableAuthTagVerification bool) *JWEConfigBuilder { + cb.enableAuthTagVerification = enableAuthTagVerification + return cb +} + func (cb *JWEConfigBuilder) computeKeyFingerprint() { derEncoded, err := x509.MarshalPKIXPublicKey(cb.encryptionKey) if err != nil { @@ -91,14 +97,15 @@ func (cb *JWEConfigBuilder) Build() *JWEConfig { } return &JWEConfig{ - decryptionKey: cb.decryptionKey, - certificate: cb.certificate, - encryptedValueFieldName: cb.encryptedValueFieldName, - encryptionKey: cb.encryptionKey, - cek: cb.cek, - iv: cb.iv, - encryptionPaths: cb.encryptionPaths, - decryptionPaths: cb.decryptionPaths, - encryptionKeyFingerprint: cb.encryptionKeyFingerprint, + decryptionKey: cb.decryptionKey, + certificate: cb.certificate, + encryptedValueFieldName: cb.encryptedValueFieldName, + encryptionKey: cb.encryptionKey, + cek: cb.cek, + iv: cb.iv, + encryptionPaths: cb.encryptionPaths, + decryptionPaths: cb.decryptionPaths, + encryptionKeyFingerprint: cb.encryptionKeyFingerprint, + enableAuthTagVerification: cb.enableAuthTagVerification, } } diff --git a/jwe/jwe_object.go b/jwe/jwe_object.go index 4f7662f..fea2f7b 100644 --- a/jwe/jwe_object.go +++ b/jwe/jwe_object.go @@ -2,9 +2,12 @@ package jwe import ( "crypto" + "crypto/hmac" "crypto/rand" "crypto/rsa" "crypto/sha256" + "crypto/subtle" + "encoding/binary" "errors" "strings" @@ -15,7 +18,7 @@ import ( const ( A128CBC_HS256 = "A128CBC-HS256" A256GCM = "A256GCM" - A128GCM = "A128GCM" + A128GCM = "A128GCM" A192GCM = "A192GCM" ) @@ -84,6 +87,11 @@ func (jweObject JWEObject) Decrypt(config JWEConfig) (string, error) { } return string(plainText), nil case A128CBC_HS256: + if config.enableAuthTagVerification { + if err := verifyCbcAuthTag(cek, nonce, cipherText, authTag, aad); err != nil { + return "", err + } + } plainText, err := aes_encryption.AesCbcDecrypt(cipherText, cek[16:], nonce, authTag) if err != nil { return "", err @@ -95,6 +103,35 @@ func (jweObject JWEObject) Decrypt(config JWEConfig) (string, error) { } } +func verifyCbcAuthTag(cek, nonce, cipherText, authTag, aad []byte) error { + if len(cek) != 32 { + return errors.New("invalid cek length for A128CBC-HS256") + } + macKeyLen := len(cek) / 2 + macKey := cek[:macKeyLen] + + if len(authTag) != 16 { + return errors.New("invalid authentication tag length") + } + + al := make([]byte, 8) + binary.BigEndian.PutUint64(al, uint64(len(aad)*8)) + + mac := hmac.New(sha256.New, macKey) + mac.Write(aad) + mac.Write(nonce) + mac.Write(cipherText) + mac.Write(al) + expected := mac.Sum(nil) + + expectedAuthTag := expected[:len(authTag)] + if subtle.ConstantTimeCompare(authTag, expectedAuthTag) != 1 { + return errors.New("invalid authentication tag") + } + + return nil +} + func (jweObject JWEObject) Serialize() string { return strings.Join([]string{jweObject.Aad, jweObject.EncryptedKey, jweObject.Iv, jweObject.CipherText, jweObject.AuthTag}, ".") } diff --git a/jwe/jwe_object_test.go b/jwe/jwe_object_test.go index 5a28efc..2814346 100644 --- a/jwe/jwe_object_test.go +++ b/jwe/jwe_object_test.go @@ -1,6 +1,9 @@ package jwe_test import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" "testing" "github.com/mastercard/client-encryption-go/jwe" @@ -12,7 +15,7 @@ const ( encryptedPayload256Gcm = "eyJraWQiOiI3NjFiMDAzYzFlYWRlM2E1NDkwZTUwMDBkMzc4ODdiYWE1ZTZlYzBlMjI2YzA3NzA2ZTU5OTQ1MWZjMDMyYTc5IiwiY3R5IjoiYXBwbGljYXRpb25cL2pzb24iLCJlbmMiOiJBMjU2R0NNIiwiYWxnIjoiUlNBLU9BRVAtMjU2In0.8c6vxeZOUBS8A9SXYUSrRnfl1ht9xxciB7TAEv84etZhQQ2civQKso-htpa2DWFBSUm-UYlxb6XtXNXZxuWu-A0WXjwi1K5ZAACc8KUoYnqPldEtC9Q2bhbQgc_qZF_GxeKrOZfuXc9oi45xfVysF_db4RZ6VkLvY2YpPeDGEMX_nLEjzqKaDz_2m0Ae_nknr0p_Nu0m5UJgMzZGR4Sk1DJWa9x-WJLEyo4w_nRDThOjHJshOHaOU6qR5rdEAZr_dwqnTHrjX9Qm9N9gflPGMaJNVa4mvpsjz6LJzjaW3nJ2yCoirbaeJyCrful6cCiwMWMaDMuiBDPKa2ovVTy0Sw.w0Nkjxl0T9HHNu4R.suRZaYu6Ui05Z3-vsw.akknMr3Dl4L0VVTGPUszcA" encryptedPayload128Gcm = "eyJlbmMiOiJBMTI4R0NNIiwiYWxnIjoiUlNBLU9BRVAtMjU2In0.WtvYljbsjdEv-Ttxx1p6PgyIrOsLpj1FMF9NQNhJUAHlKchAo5QImgEgIdgJE7HC2KfpNcHiQVqKKZq_y201FVzpicDkNzlPJr5kIH4Lq-oC5iP0agWeou9yK5vIxFRP__F_B8HSuojBJ3gDYT_KdYffUIHkm_UysNj4PW2RIRlafJ6RKYanVzk74EoKZRG7MIr3pTU6LIkeQUW41qYG8hz6DbGBOh79Nkmq7Oceg0ZwCn1_MruerP-b15SGFkuvOshStT5JJp7OOq82gNAOkMl4fylEj2-vADjP7VSK8GlqrA7u9Tn-a4Q28oy0GOKr1Z-HJgn_CElknwkUTYsWbg.PKl6_kvZ4_4MjmjW.AH6pGFkn7J49hBQcwg.zdyD73TcuveImOy4CRnVpw" encryptedPayload192Gcm = "eyJlbmMiOiJBMTkyR0NNIiwiYWxnIjoiUlNBLU9BRVAtMjU2In0.FWC8PVaZoR2TRKwKO4syhSJReezVIvtkxU_yKh4qODNvlVr8t8ttvySJ-AjM8xdI6vNyIg9jBMWASG4cE49jT9FYuQ72fP4R-Td4vX8wpB8GonQj40yLqZyfRLDrMgPR20RcQDW2ThzLXsgI55B5l5fpwQ9Nhmx8irGifrFWOcJ_k1dUSBdlsHsYxkjRKMENu5x4H6h12gGZ21aZSPtwAj9msMYnKLdiUbdGmGG_P8a6gPzc9ih20McxZk8fHzXKujjukr_1p5OO4o1N4d3qa-YI8Sns2fPtf7xPHnwi1wipmCC6ThFLU80r3173RXcpyZkF8Y3UacOS9y1f8eUfVQ.JRE7kZLN4Im1Rtdb.eW_lJ-U330n0QHqZnQ._r5xYVvMCrvICwLz4chjdw" - encryptedPayloadCbc = "eyJraWQiOiI3NjFiMDAzYzFlYWRlM2E1NDkwZTUwMDBkMzc4ODdiYWE1ZTZlYzBlMjI2YzA3NzA2ZTU5OTQ1MWZjMDMyYTc5IiwiY3R5IjoiYXBwbGljYXRpb25cL2pzb24iLCJlbmMiOiJBMTI4Q0JDLUhTMjU2IiwiYWxnIjoiUlNBLU9BRVAtMjU2In0.5bsamlChk0HR3Nqg2UPJ2Fw4Y0MvC2pwWzNv84jYGkOXyqp1iwQSgETGaplIa7JyLg1ZWOqwNHEx3N7gsN4nzwAnVgz0eta6SsoQUE9YQ-5jek0COslUkoqIQjlQYJnYur7pqttDibj87fcw13G2agle5fL99j1QgFPjNPYqH88DMv481XGFa8O3VfJhW93m73KD2gvE5GasOPOkFK9wjKXc9lMGSgSArp3Awbc_oS2Cho_SbsvuEQwkhnQc2JKT3IaSWu8yK7edNGwD6OZJLhMJzWJlY30dUt2Eqe1r6kMT0IDRl7jHJnVIr2Qpe56CyeZ9V0aC5RH1mI5dYk4kHg.yI0CS3NdBrz9CCW2jwBSDw.6zr2pOSmAGdlJG0gbH53Eg.UFgf3-P9UjgMocEu7QA_vQ" + encryptedPayloadCbc = "eyJraWQiOiI3NjFiMDAzYzFlYWRlM2E1NDkwZTUwMDBkMzc4ODdiYWE1ZTZlYzBlMjI2YzA3NzA2ZTU5OTQ1MWZjMDMyYTc5IiwiY3R5IjoiYXBwbGljYXRpb25cL2pzb24iLCJlbmMiOiJBMTI4Q0JDLUhTMjU2IiwiYWxnIjoiUlNBLU9BRVAtMjU2In0.5bsamlChk0HR3Nqg2UPJ2Fw4Y0MvC2pwWzNv84jYGkOXyqp1iwQSgETGaplIa7JyLg1ZWOqwNHEx3N7gsN4nzwAnVgz0eta6SsoQUE9YQ-5jek0COslUkoqIQjlQYJnYur7pqttDibj87fcw13G2agle5fL99j1QgFPjNPYqH88DMv481XGFa8O3VfJhW93m73KD2gvE5GasOPOkFK9wjKXc9lMGSgSArp3Awbc_oS2Cho_SbsvuEQwkhnQc2JKT3IaSWu8yK7edNGwD6OZJLhMJzWJlY30dUt2Eqe1r6kMT0IDRl7jHJnVIr2Qpe56CyeZ9V0aC5RH1mI5dYk4kHg.yI0CS3NdBrz9CCW2jwBSDw.6zr2pOSmAGdlJG0gbH53Eg.UFgf3-P9UjgMocEu7QA_vQ" ) func TestJWEObject(t *testing.T) { @@ -132,3 +135,110 @@ func TestDecrypt_ShouldReturnDecryptedPayload_WhenPayloadIsCbcEncrypted(t *testi assert.Nil(t, err) assert.Equal(t, "bar", decryptedPayload) } + +func TestDecrypt_ShouldReturnError_WhenCekLengthIsInvalidAndVerificationEnabled(t *testing.T) { + jweObject, err := jwe.ParseJWEObject(encryptedPayloadCbc) + assert.Nil(t, err) + + decryptionKeyPath := "../testdata/keys/pkcs8/test_key_pkcs8-2048.der" + certificatePath := "../testdata/certificates/test_certificate-2048.der" + + decryptionKey, err := utils.LoadUnencryptedDecryptionKey(decryptionKeyPath) + assert.Nil(t, err) + certificate, err := utils.LoadEncryptionCertificate(certificatePath) + assert.Nil(t, err) + + // Replace the encrypted key with one that decrypts to a short CEK (16 bytes) + shortCek := make([]byte, 16) + encryptedShortCek, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, certificate.PublicKey.(*rsa.PublicKey), shortCek, nil) + assert.Nil(t, err) + jweObject.EncryptedKey = utils.Base64UrlEncode(encryptedShortCek) + + cb := jwe.NewJWEConfigBuilder() + jweConfig := cb.WithDecryptionKey(decryptionKey). + WithCertificate(certificate). + WithAuthTagVerificationEnabled(true). + Build() + + decryptedPayload, err := jweObject.Decrypt(*jweConfig) + assert.Empty(t, decryptedPayload) + assert.NotNil(t, err) + assert.EqualError(t, err, "invalid cek length for A128CBC-HS256") +} + +func TestDecrypt_ShouldReturnError_WhenAuthTagLengthIsInvalidAndVerificationEnabled(t *testing.T) { + jweObject, err := jwe.ParseJWEObject(encryptedPayloadCbc) + assert.Nil(t, err) + + // Make the auth tag decode to an invalid length (not 16 bytes) + jweObject.AuthTag = utils.Base64UrlEncode([]byte("short-tag")) + + decryptionKeyPath := "../testdata/keys/pkcs8/test_key_pkcs8-2048.der" + certificatePath := "../testdata/certificates/test_certificate-2048.der" + + decryptionKey, err := utils.LoadUnencryptedDecryptionKey(decryptionKeyPath) + assert.Nil(t, err) + certificate, err := utils.LoadEncryptionCertificate(certificatePath) + assert.Nil(t, err) + + cb := jwe.NewJWEConfigBuilder() + jweConfig := cb.WithDecryptionKey(decryptionKey). + WithCertificate(certificate). + WithAuthTagVerificationEnabled(true). + Build() + + decryptedPayload, err := jweObject.Decrypt(*jweConfig) + assert.Empty(t, decryptedPayload) + assert.NotNil(t, err) + assert.EqualError(t, err, "invalid authentication tag length") +} + +func TestDecrypt_ShouldReturnError_WhenAuthTagMacIsInvalidAndVerificationEnabled(t *testing.T) { + jweObject, err := jwe.ParseJWEObject(encryptedPayloadCbc) + assert.Nil(t, err) + + // Preserve tag length (16 bytes decoded) but change value so MAC verification fails + jweObject.AuthTag = utils.Base64UrlEncode(make([]byte, 16)) + + decryptionKeyPath := "../testdata/keys/pkcs8/test_key_pkcs8-2048.der" + certificatePath := "../testdata/certificates/test_certificate-2048.der" + + decryptionKey, err := utils.LoadUnencryptedDecryptionKey(decryptionKeyPath) + assert.Nil(t, err) + certificate, err := utils.LoadEncryptionCertificate(certificatePath) + assert.Nil(t, err) + + cb := jwe.NewJWEConfigBuilder() + jweConfig := cb.WithDecryptionKey(decryptionKey). + WithCertificate(certificate). + WithAuthTagVerificationEnabled(true). + Build() + + decryptedPayload, err := jweObject.Decrypt(*jweConfig) + assert.Empty(t, decryptedPayload) + assert.NotNil(t, err) + assert.EqualError(t, err, "invalid authentication tag") +} + +func TestDecrypt_ShouldReturnDecryptedPayload_WhenVerificationEnabledAndAuthTagIsValid(t *testing.T) { + jweObject, err := jwe.ParseJWEObject(encryptedPayloadCbc) + assert.Nil(t, err) + + decryptionKeyPath := "../testdata/keys/pkcs8/test_key_pkcs8-2048.der" + certificatePath := "../testdata/certificates/test_certificate-2048.der" + + decryptionKey, err := utils.LoadUnencryptedDecryptionKey(decryptionKeyPath) + assert.Nil(t, err) + certificate, err := utils.LoadEncryptionCertificate(certificatePath) + assert.Nil(t, err) + + cb := jwe.NewJWEConfigBuilder() + jweConfig := cb.WithDecryptionKey(decryptionKey). + WithCertificate(certificate). + WithAuthTagVerificationEnabled(true). + Build() + + decryptedPayload, err := jweObject.Decrypt(*jweConfig) + assert.Nil(t, err) + assert.Equal(t, "bar", decryptedPayload) +}