Skip to content

Commit 045d532

Browse files
author
Spencer Miller
committed
add type hinting
1 parent 5ae962d commit 045d532

File tree

2 files changed

+38
-35
lines changed

2 files changed

+38
-35
lines changed

coti/crypto_utils.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from cryptography.hazmat.primitives import serialization
55
from cryptography.hazmat.primitives.asymmetric import padding
66
from cryptography.hazmat.primitives.asymmetric import rsa
7-
from eth_account import Account
87
from eth_keys import keys
98
from .types import ItString, ItUint
109

@@ -15,17 +14,17 @@
1514
key_size = 32
1615

1716

18-
def encrypt(key, plaintext):
17+
def encrypt(user_aes_key: bytes, plaintext: int):
1918
# Ensure plaintext is smaller than 128 bits (16 bytes)
2019
if len(plaintext) > block_size:
2120
raise ValueError("Plaintext size must be 128 bits or smaller.")
2221

2322
# Ensure key size is 128 bits (16 bytes)
24-
if len(key) != block_size:
23+
if len(user_aes_key) != block_size:
2524
raise ValueError("Key size must be 128 bits.")
2625

2726
# Create a new AES cipher block using the provided key
28-
cipher = AES.new(key, AES.MODE_ECB)
27+
cipher = AES.new(user_aes_key, AES.MODE_ECB)
2928

3029
# Generate a random value 'r' of the same length as the block size
3130
r = get_random_bytes(block_size)
@@ -42,20 +41,20 @@ def encrypt(key, plaintext):
4241
return ciphertext, r
4342

4443

45-
def decrypt(key, r, ciphertext):
44+
def decrypt(user_aes_key: bytes, r: bytes, ciphertext: bytes):
4645
if len(ciphertext) != block_size:
4746
raise ValueError("Ciphertext size must be 128 bits.")
4847

4948
# Ensure key size is 128 bits (16 bytes)
50-
if len(key) != block_size:
49+
if len(user_aes_key) != block_size:
5150
raise ValueError("Key size must be 128 bits.")
5251

5352
# Ensure random size is 128 bits (16 bytes)
5453
if len(r) != block_size:
5554
raise ValueError("Random size must be 128 bits.")
5655

5756
# Create a new AES cipher block using the provided key
58-
cipher = AES.new(key, AES.MODE_ECB)
57+
cipher = AES.new(user_aes_key, AES.MODE_ECB)
5958

6059
# Encrypt the random value 'r' using AES in ECB mode
6160
encrypted_r = cipher.encrypt(r)
@@ -73,14 +72,14 @@ def generate_aes_key():
7372
return key
7473

7574

76-
def sign_input_text(sender: Account, addr: str, function_selector: str, ct, key):
75+
def sign_input_text(sender_address: str, contract_address: str, function_selector: str, ct, key):
7776
function_selector_bytes = bytes.fromhex(function_selector[2:])
7877

7978
# Ensure all input sizes are the correct length
80-
if len(sender) != address_size:
81-
raise ValueError(f"Invalid sender address length: {len(sender)} bytes, must be {address_size} bytes")
82-
if len(addr) != address_size:
83-
raise ValueError(f"Invalid contract address length: {len(addr)} bytes, must be {address_size} bytes")
79+
if len(sender_address) != address_size:
80+
raise ValueError(f"Invalid sender address length: {len(sender_address)} bytes, must be {address_size} bytes")
81+
if len(contract_address) != address_size:
82+
raise ValueError(f"Invalid contract address length: {len(contract_address)} bytes, must be {address_size} bytes")
8483
if len(function_selector_bytes) != function_selector_size:
8584
raise ValueError(f"Invalid signature size: {len(function_selector_bytes)} bytes, must be {function_selector_size} bytes")
8685
if len(ct) != ct_size:
@@ -90,7 +89,7 @@ def sign_input_text(sender: Account, addr: str, function_selector: str, ct, key)
9089
raise ValueError(f"Invalid key length: {len(key)} bytes, must be {key_size} bytes")
9190

9291
# Create the message to be signed by appending all inputs
93-
message = sender + addr + function_selector_bytes + ct
92+
message = sender_address + contract_address + function_selector_bytes + ct
9493

9594
return sign(message, key)
9695

@@ -103,19 +102,16 @@ def sign(message, key):
103102
return signature
104103

105104

106-
def build_input_text(plaintext: int, user_aes_key, sender: Account, contract, function_selector: str, signing_key) -> ItUint:
107-
sender_address_bytes = bytes.fromhex(sender.address[2:])
108-
contract_address_bytes = bytes.fromhex(contract.address[2:])
109-
105+
def build_input_text(plaintext: int, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItUint:
110106
# Convert the integer to a byte slice with size aligned to 8.
111107
plaintext_bytes = plaintext.to_bytes((plaintext.bit_length() + 7) // 8, 'big')
112108

113109
# Encrypt the plaintext with the user's AES key
114-
ciphertext, r = encrypt(user_aes_key, plaintext_bytes)
110+
ciphertext, r = encrypt(bytes.fromhex(user_aes_key), plaintext_bytes)
115111
ct = ciphertext + r
116112

117113
# Sign the message
118-
signature = sign_input_text(sender_address_bytes, contract_address_bytes, function_selector, ct, signing_key)
114+
signature = sign_input_text(bytes.fromhex(sender_address[2:]), bytes.fromhex(contract_address[2:]), function_selector, ct, signing_key)
119115

120116
# Convert the ct to an integer
121117
int_cipher_text = int.from_bytes(ct, byteorder='big')
@@ -126,7 +122,7 @@ def build_input_text(plaintext: int, user_aes_key, sender: Account, contract, fu
126122
}
127123

128124

129-
def build_string_input_text(plaintext, user_aes_key, sender, contract, function_selector, signing_key) -> ItString:
125+
def build_string_input_text(plaintext: int, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItString:
130126
input_text = {
131127
'ciphertext': {
132128
'value': []
@@ -144,8 +140,8 @@ def build_string_input_text(plaintext, user_aes_key, sender, contract, function_
144140
it_int = build_input_text(
145141
int.from_bytes(byte_arr, 'big'),
146142
user_aes_key,
147-
sender,
148-
contract,
143+
sender_address,
144+
contract_address,
149145
function_selector,
150146
signing_key
151147
)
@@ -156,7 +152,7 @@ def build_string_input_text(plaintext, user_aes_key, sender, contract, function_
156152
return input_text
157153

158154

159-
def decrypt_uint(ciphertext, user_key) -> int:
155+
def decrypt_uint(ciphertext: int, user_aes_key: str) -> int:
160156
# Convert ct to bytes (big-endian)
161157
byte_array = ciphertext.to_bytes(32, byteorder='big')
162158

@@ -165,15 +161,15 @@ def decrypt_uint(ciphertext, user_key) -> int:
165161
r = byte_array[block_size:]
166162

167163
# Decrypt the cipher
168-
decrypted_message = decrypt(user_key, r, cipher)
164+
decrypted_message = decrypt(bytes.fromhex(user_aes_key), r, cipher)
169165

170166
# Print the decrypted cipher
171167
decrypted_uint = int.from_bytes(decrypted_message, 'big')
172168

173169
return decrypted_uint
174170

175171

176-
def decrypt_string(ciphertext, user_key) -> str:
172+
def decrypt_string(ciphertext: int, user_aes_key: str) -> str:
177173
if 'value' in ciphertext or hasattr(ciphertext, 'value'): # format when reading ciphertext from an event
178174
__ciphertext = ciphertext['value']
179175
elif isinstance(ciphertext, tuple): # format when reading ciphertext from state variable
@@ -184,7 +180,7 @@ def decrypt_string(ciphertext, user_key) -> str:
184180
decrypted_string = ""
185181

186182
for value in __ciphertext:
187-
decrypted = decrypt_uint(value, user_key)
183+
decrypted = decrypt_uint(value, user_aes_key)
188184

189185
byte_length = (decrypted.bit_length() + 7) // 8 # calculate the byte length
190186

@@ -221,7 +217,7 @@ def generate_rsa_keypair():
221217
return private_key_bytes, public_key_bytes
222218

223219

224-
def decrypt_rsa(private_key_bytes, ciphertext):
220+
def decrypt_rsa(private_key_bytes: bytes, ciphertext: bytes):
225221
# Load private key
226222
private_key = serialization.load_der_private_key(private_key_bytes, password=None)
227223
# Decrypt ciphertext
@@ -237,7 +233,7 @@ def decrypt_rsa(private_key_bytes, ciphertext):
237233

238234
#This function recovers a user's key by decrypting two encrypted key shares with the given private key,
239235
#and then XORing the two key shares together.
240-
def recover_user_key(private_key_bytes, encrypted_key_share0, encrypted_key_share1):
236+
def recover_user_key(private_key_bytes: bytes, encrypted_key_share0: bytes, encrypted_key_share1: bytes):
241237
key_share0 = decrypt_rsa(private_key_bytes, encrypted_key_share0)
242238
key_share1 = decrypt_rsa(private_key_bytes, encrypted_key_share1)
243239

coti/types.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1-
from typing import Dict, List, TypeAlias, Union
1+
from typing import List, TypeAlias, TypedDict
22

3-
ItBool: TypeAlias = Dict[str, Union[int, bytes]] # { 'ciphertext': int, 'signature': bytes }
3+
CtBool: TypeAlias = int
44

5-
ItUint: TypeAlias = Dict[str, Union[int, bytes]] # { 'ciphertext': int, 'signature': bytes }
5+
CtUint: TypeAlias = int
66

7-
ItString: TypeAlias = Dict[str, Union[Dict[str, List[int]], List[bytes]]] # { 'ciphertext': { 'value': List[int] }, 'signature': List[bytes] }
7+
class CtString(TypedDict):
8+
value: List[int]
89

9-
CtBool: TypeAlias = int
10+
class ItBool(TypedDict):
11+
ciphertext: int
12+
signature: bytes
1013

11-
CtUint: TypeAlias = int
14+
class ItUint(TypedDict):
15+
ciphertext: int
16+
signature: bytes
1217

13-
CtString: TypeAlias = Dict[str, List[int]] # { 'value': List[int] }
18+
class ItString(TypedDict):
19+
ciphertext: CtString
20+
signature: List[bytes]

0 commit comments

Comments
 (0)