Skip to content

Commit d8d41e4

Browse files
committed
add __eq__ for HashAlgorithm and padding instances
1 parent d1bcb3e commit d8d41e4

File tree

10 files changed

+338
-7
lines changed

10 files changed

+338
-7
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ Changelog
6868
:class:`~cryptography.hazmat.primitives.ciphers.aead.AESSIV`, and
6969
:class:`~cryptography.hazmat.primitives.ciphers.aead.ChaCha20Poly1305` to
7070
allow encrypting directly into a pre-allocated buffer.
71+
* Builtin hash classes and instances of classes in
72+
:mod:`~cryptography.hazmat.primitives.asymmetric.padding` can now be compared with `==`
7173

7274
.. _v46-0-3:
7375

docs/hazmat/primitives/asymmetric/cloudhsm.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ if you only need a subset of functionality.
8888
... Maps the cryptography padding and algorithm to the corresponding KMS signing algorithm.
8989
... This is specific to your implementation.
9090
... """
91-
... if isinstance(padding, PKCS1v15) and isinstance(algorithm, hashes.SHA256):
91+
... if padding == PKCS1v15() and algorithm == hashes.SHA256():
9292
... return b"RSA_PKCS1_V1_5_SHA_256"
9393
... else:
9494
... raise NotImplementedError()

docs/x509/reference.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ Loading Certificate Revocation Lists
248248
>>> from cryptography import x509
249249
>>> from cryptography.hazmat.primitives import hashes
250250
>>> crl = x509.load_pem_x509_crl(pem_crl_data)
251-
>>> isinstance(crl.signature_hash_algorithm, hashes.SHA256)
251+
>>> crl.signature_hash_algorithm == hashes.SHA256()
252252
True
253253

254254
.. function:: load_der_x509_crl(data)
@@ -287,7 +287,7 @@ Loading Certificate Signing Requests
287287
>>> from cryptography import x509
288288
>>> from cryptography.hazmat.primitives import hashes
289289
>>> csr = x509.load_pem_x509_csr(pem_req_data)
290-
>>> isinstance(csr.signature_hash_algorithm, hashes.SHA256)
290+
>>> csr.signature_hash_algorithm == hashes.SHA256()
291291
True
292292

293293
.. function:: load_der_x509_csr(data)
@@ -477,7 +477,7 @@ X.509 Certificate Object
477477
.. doctest::
478478

479479
>>> from cryptography.hazmat.primitives import hashes
480-
>>> isinstance(cert.signature_hash_algorithm, hashes.SHA256)
480+
>>> cert.signature_hash_algorithm == hashes.SHA256()
481481
True
482482

483483
.. attribute:: signature_algorithm_oid
@@ -716,7 +716,7 @@ X.509 CRL (Certificate Revocation List) Object
716716
.. doctest::
717717

718718
>>> from cryptography.hazmat.primitives import hashes
719-
>>> isinstance(crl.signature_hash_algorithm, hashes.SHA256)
719+
>>> crl.signature_hash_algorithm == hashes.SHA256()
720720
True
721721

722722
.. attribute:: signature_algorithm_oid
@@ -1119,7 +1119,7 @@ X.509 CSR (Certificate Signing Request) Object
11191119
.. doctest::
11201120

11211121
>>> from cryptography.hazmat.primitives import hashes
1122-
>>> isinstance(csr.signature_hash_algorithm, hashes.SHA256)
1122+
>>> csr.signature_hash_algorithm == hashes.SHA256()
11231123
True
11241124

11251125
.. attribute:: signature_algorithm_oid

src/cryptography/hazmat/primitives/asymmetric/padding.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import abc
8+
import typing
89

910
from cryptography.hazmat.primitives import hashes
1011
from cryptography.hazmat.primitives._asymmetric import (
@@ -16,6 +17,9 @@
1617
class PKCS1v15(AsymmetricPadding):
1718
name = "EMSA-PKCS1-v1_5"
1819

20+
def __eq__(self, other: typing.Any) -> bool:
21+
return isinstance(other, PKCS1v15)
22+
1923

2024
class _MaxLength:
2125
"Sentinel value for `MAX_LENGTH`."
@@ -56,6 +60,17 @@ def __init__(
5660

5761
self._salt_length = salt_length
5862

63+
def __eq__(self, other: typing.Any) -> bool:
64+
if not isinstance(other, PSS):
65+
return False
66+
67+
if isinstance(self._salt_length, int):
68+
eq_salt_length = self._salt_length == other._salt_length
69+
else:
70+
eq_salt_length = self._salt_length is other._salt_length
71+
72+
return eq_salt_length and self._mgf == other._mgf
73+
5974
@property
6075
def mgf(self) -> MGF:
6176
return self._mgf
@@ -77,6 +92,14 @@ def __init__(
7792
self._algorithm = algorithm
7893
self._label = label
7994

95+
def __eq__(self, other: typing.Any) -> bool:
96+
return (
97+
isinstance(other, OAEP)
98+
and self._mgf == other._mgf
99+
and self._algorithm == other._algorithm
100+
and self._label == other._label
101+
)
102+
80103
@property
81104
def algorithm(self) -> hashes.HashAlgorithm:
82105
return self._algorithm
@@ -89,6 +112,13 @@ def mgf(self) -> MGF:
89112
class MGF(metaclass=abc.ABCMeta):
90113
_algorithm: hashes.HashAlgorithm
91114

115+
@abc.abstractmethod
116+
def __eq__(self, other: typing.Any) -> bool:
117+
"""
118+
Implement equality checking.
119+
"""
120+
...
121+
92122

93123
class MGF1(MGF):
94124
def __init__(self, algorithm: hashes.HashAlgorithm):
@@ -97,6 +127,9 @@ def __init__(self, algorithm: hashes.HashAlgorithm):
97127

98128
self._algorithm = algorithm
99129

130+
def __eq__(self, other: typing.Any) -> bool:
131+
return isinstance(other, MGF1) and self._algorithm == other._algorithm
132+
100133

101134
def calculate_max_pss_salt_length(
102135
key: rsa.RSAPrivateKey | rsa.RSAPublicKey,

src/cryptography/hazmat/primitives/hashes.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import abc
8+
import typing
89

910
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
1011
from cryptography.utils import Buffer
@@ -103,66 +104,99 @@ class SHA1(HashAlgorithm):
103104
digest_size = 20
104105
block_size = 64
105106

107+
def __eq__(self, other: typing.Any) -> bool:
108+
return isinstance(other, SHA1)
109+
106110

107111
class SHA512_224(HashAlgorithm): # noqa: N801
108112
name = "sha512-224"
109113
digest_size = 28
110114
block_size = 128
111115

116+
def __eq__(self, other: typing.Any) -> bool:
117+
return isinstance(other, SHA512_224)
118+
112119

113120
class SHA512_256(HashAlgorithm): # noqa: N801
114121
name = "sha512-256"
115122
digest_size = 32
116123
block_size = 128
117124

125+
def __eq__(self, other: typing.Any) -> bool:
126+
return isinstance(other, SHA512_256)
127+
118128

119129
class SHA224(HashAlgorithm):
120130
name = "sha224"
121131
digest_size = 28
122132
block_size = 64
123133

134+
def __eq__(self, other: typing.Any) -> bool:
135+
return isinstance(other, SHA224)
136+
124137

125138
class SHA256(HashAlgorithm):
126139
name = "sha256"
127140
digest_size = 32
128141
block_size = 64
129142

143+
def __eq__(self, other: typing.Any) -> bool:
144+
return isinstance(other, SHA256)
145+
130146

131147
class SHA384(HashAlgorithm):
132148
name = "sha384"
133149
digest_size = 48
134150
block_size = 128
135151

152+
def __eq__(self, other: typing.Any) -> bool:
153+
return isinstance(other, SHA384)
154+
136155

137156
class SHA512(HashAlgorithm):
138157
name = "sha512"
139158
digest_size = 64
140159
block_size = 128
141160

161+
def __eq__(self, other: typing.Any) -> bool:
162+
return isinstance(other, SHA512)
163+
142164

143165
class SHA3_224(HashAlgorithm): # noqa: N801
144166
name = "sha3-224"
145167
digest_size = 28
146168
block_size = None
147169

170+
def __eq__(self, other: typing.Any) -> bool:
171+
return isinstance(other, SHA3_224)
172+
148173

149174
class SHA3_256(HashAlgorithm): # noqa: N801
150175
name = "sha3-256"
151176
digest_size = 32
152177
block_size = None
153178

179+
def __eq__(self, other: typing.Any) -> bool:
180+
return isinstance(other, SHA3_256)
181+
154182

155183
class SHA3_384(HashAlgorithm): # noqa: N801
156184
name = "sha3-384"
157185
digest_size = 48
158186
block_size = None
159187

188+
def __eq__(self, other: typing.Any) -> bool:
189+
return isinstance(other, SHA3_384)
190+
160191

161192
class SHA3_512(HashAlgorithm): # noqa: N801
162193
name = "sha3-512"
163194
digest_size = 64
164195
block_size = None
165196

197+
def __eq__(self, other: typing.Any) -> bool:
198+
return isinstance(other, SHA3_512)
199+
166200

167201
class SHAKE128(HashAlgorithm, ExtendableOutputFunction):
168202
name = "shake128"
@@ -177,6 +211,12 @@ def __init__(self, digest_size: int):
177211

178212
self._digest_size = digest_size
179213

214+
def __eq__(self, other: typing.Any) -> bool:
215+
return (
216+
isinstance(other, SHAKE128)
217+
and self._digest_size == other._digest_size
218+
)
219+
180220
@property
181221
def digest_size(self) -> int:
182222
return self._digest_size
@@ -195,6 +235,12 @@ def __init__(self, digest_size: int):
195235

196236
self._digest_size = digest_size
197237

238+
def __eq__(self, other: typing.Any) -> bool:
239+
return (
240+
isinstance(other, SHAKE256)
241+
and self._digest_size == other._digest_size
242+
)
243+
198244
@property
199245
def digest_size(self) -> int:
200246
return self._digest_size
@@ -205,6 +251,9 @@ class MD5(HashAlgorithm):
205251
digest_size = 16
206252
block_size = 64
207253

254+
def __eq__(self, other: typing.Any) -> bool:
255+
return isinstance(other, MD5)
256+
208257

209258
class BLAKE2b(HashAlgorithm):
210259
name = "blake2b"
@@ -218,6 +267,12 @@ def __init__(self, digest_size: int):
218267

219268
self._digest_size = digest_size
220269

270+
def __eq__(self, other: typing.Any) -> bool:
271+
return (
272+
isinstance(other, BLAKE2b)
273+
and self._digest_size == other._digest_size
274+
)
275+
221276
@property
222277
def digest_size(self) -> int:
223278
return self._digest_size
@@ -235,6 +290,12 @@ def __init__(self, digest_size: int):
235290

236291
self._digest_size = digest_size
237292

293+
def __eq__(self, other: typing.Any) -> bool:
294+
return (
295+
isinstance(other, BLAKE2s)
296+
and self._digest_size == other._digest_size
297+
)
298+
238299
@property
239300
def digest_size(self) -> int:
240301
return self._digest_size
@@ -244,3 +305,6 @@ class SM3(HashAlgorithm):
244305
name = "sm3"
245306
digest_size = 32
246307
block_size = 64
308+
309+
def __eq__(self, other: typing.Any) -> bool:
310+
return isinstance(other, SM3)

tests/doubles.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
33
# for complete details.
44

5+
import typing
56

67
from cryptography.hazmat.primitives import hashes, serialization
78
from cryptography.hazmat.primitives.asymmetric import padding
@@ -40,6 +41,12 @@ class DummyHashAlgorithm(hashes.HashAlgorithm):
4041
def __init__(self, digest_size: int = 32) -> None:
4142
self._digest_size = digest_size
4243

44+
def __eq__(self, other: typing.Any) -> bool:
45+
return (
46+
isinstance(self, DummyHashAlgorithm)
47+
and self._digest_size == other._digest_size
48+
)
49+
4350
@property
4451
def digest_size(self) -> int:
4552
return self._digest_size

tests/hazmat/backends/test_openssl.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
import itertools
7+
import typing
78

89
import pytest
910

@@ -32,6 +33,9 @@ class DummyMGF(padding.MGF):
3233
_salt_length = 0
3334
_algorithm = hashes.SHA1()
3435

36+
def __eq__(self, other: typing.Any) -> bool:
37+
return isinstance(other, DummyMGF)
38+
3539

3640
class TestOpenSSL:
3741
def test_backend_exists(self):
@@ -194,6 +198,11 @@ def test_rsa_padding_unsupported_mgf(self):
194198
is False
195199
)
196200

201+
def test_dummy_mgf_eq(self):
202+
"""This test just exists to fix code coverage for the dummy class."""
203+
assert DummyMGF() == DummyMGF()
204+
assert DummyMGF() != padding.MGF1(hashes.SHA256())
205+
197206
def test_unsupported_mgf1_hash_algorithm_md5_decrypt(self, rsa_key_2048):
198207
with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_PADDING):
199208
rsa_key_2048.decrypt(

0 commit comments

Comments
 (0)