44from cryptography .hazmat .primitives import serialization
55from cryptography .hazmat .primitives .asymmetric import padding
66from cryptography .hazmat .primitives .asymmetric import rsa
7- from eth_account import Account
87from eth_keys import keys
98from .types import ItString , ItUint
109
1514key_size = 32
1615
1716
18- def encrypt (key , plaintext ):
17+ def encrypt (user_aes_key : bytes , plaintext : int ):
1918 # Ensure plaintext is smaller than 128 bits (16 bytes)
2019 if len (plaintext ) > block_size :
2120 raise ValueError ("Plaintext size must be 128 bits or smaller." )
2221
2322 # Ensure key size is 128 bits (16 bytes)
24- if len (key ) != block_size :
23+ if len (user_aes_key ) != block_size :
2524 raise ValueError ("Key size must be 128 bits." )
2625
2726 # Create a new AES cipher block using the provided key
28- cipher = AES .new (key , AES .MODE_ECB )
27+ cipher = AES .new (user_aes_key , AES .MODE_ECB )
2928
3029 # Generate a random value 'r' of the same length as the block size
3130 r = get_random_bytes (block_size )
@@ -42,20 +41,20 @@ def encrypt(key, plaintext):
4241 return ciphertext , r
4342
4443
45- def decrypt (key , r , ciphertext ):
44+ def decrypt (user_aes_key : bytes , r : bytes , ciphertext : bytes ):
4645 if len (ciphertext ) != block_size :
4746 raise ValueError ("Ciphertext size must be 128 bits." )
4847
4948 # Ensure key size is 128 bits (16 bytes)
50- if len (key ) != block_size :
49+ if len (user_aes_key ) != block_size :
5150 raise ValueError ("Key size must be 128 bits." )
5251
5352 # Ensure random size is 128 bits (16 bytes)
5453 if len (r ) != block_size :
5554 raise ValueError ("Random size must be 128 bits." )
5655
5756 # Create a new AES cipher block using the provided key
58- cipher = AES .new (key , AES .MODE_ECB )
57+ cipher = AES .new (user_aes_key , AES .MODE_ECB )
5958
6059 # Encrypt the random value 'r' using AES in ECB mode
6160 encrypted_r = cipher .encrypt (r )
@@ -73,14 +72,14 @@ def generate_aes_key():
7372 return key
7473
7574
76- def sign_input_text (sender : Account , addr : str , function_selector : str , ct , key ):
75+ def sign_input_text (sender_address : str , contract_address : str , function_selector : str , ct , key ):
7776 function_selector_bytes = bytes .fromhex (function_selector [2 :])
7877
7978 # Ensure all input sizes are the correct length
80- if len (sender ) != address_size :
81- raise ValueError (f"Invalid sender address length: { len (sender )} bytes, must be { address_size } bytes" )
82- if len (addr ) != address_size :
83- raise ValueError (f"Invalid contract address length: { len (addr )} bytes, must be { address_size } bytes" )
79+ if len (sender_address ) != address_size :
80+ raise ValueError (f"Invalid sender address length: { len (sender_address )} bytes, must be { address_size } bytes" )
81+ if len (contract_address ) != address_size :
82+ raise ValueError (f"Invalid contract address length: { len (contract_address )} bytes, must be { address_size } bytes" )
8483 if len (function_selector_bytes ) != function_selector_size :
8584 raise ValueError (f"Invalid signature size: { len (function_selector_bytes )} bytes, must be { function_selector_size } bytes" )
8685 if len (ct ) != ct_size :
@@ -90,7 +89,7 @@ def sign_input_text(sender: Account, addr: str, function_selector: str, ct, key)
9089 raise ValueError (f"Invalid key length: { len (key )} bytes, must be { key_size } bytes" )
9190
9291 # Create the message to be signed by appending all inputs
93- message = sender + addr + function_selector_bytes + ct
92+ message = sender_address + contract_address + function_selector_bytes + ct
9493
9594 return sign (message , key )
9695
@@ -103,19 +102,16 @@ def sign(message, key):
103102 return signature
104103
105104
106- def build_input_text (plaintext : int , user_aes_key , sender : Account , contract , function_selector : str , signing_key ) -> ItUint :
107- sender_address_bytes = bytes .fromhex (sender .address [2 :])
108- contract_address_bytes = bytes .fromhex (contract .address [2 :])
109-
105+ def build_input_text (plaintext : int , user_aes_key : str , sender_address : str , contract_address : str , function_selector : str , signing_key : str ) -> ItUint :
110106 # Convert the integer to a byte slice with size aligned to 8.
111107 plaintext_bytes = plaintext .to_bytes ((plaintext .bit_length () + 7 ) // 8 , 'big' )
112108
113109 # Encrypt the plaintext with the user's AES key
114- ciphertext , r = encrypt (user_aes_key , plaintext_bytes )
110+ ciphertext , r = encrypt (bytes . fromhex ( user_aes_key ) , plaintext_bytes )
115111 ct = ciphertext + r
116112
117113 # Sign the message
118- signature = sign_input_text (sender_address_bytes , contract_address_bytes , function_selector , ct , signing_key )
114+ signature = sign_input_text (bytes . fromhex ( sender_address [ 2 :]), bytes . fromhex ( contract_address [ 2 :]) , function_selector , ct , signing_key )
119115
120116 # Convert the ct to an integer
121117 int_cipher_text = int .from_bytes (ct , byteorder = 'big' )
@@ -126,7 +122,7 @@ def build_input_text(plaintext: int, user_aes_key, sender: Account, contract, fu
126122 }
127123
128124
129- def build_string_input_text (plaintext , user_aes_key , sender , contract , function_selector , signing_key ) -> ItString :
125+ def build_string_input_text (plaintext : int , user_aes_key : str , sender_address : str , contract_address : str , function_selector : str , signing_key : str ) -> ItString :
130126 input_text = {
131127 'ciphertext' : {
132128 'value' : []
@@ -144,8 +140,8 @@ def build_string_input_text(plaintext, user_aes_key, sender, contract, function_
144140 it_int = build_input_text (
145141 int .from_bytes (byte_arr , 'big' ),
146142 user_aes_key ,
147- sender ,
148- contract ,
143+ sender_address ,
144+ contract_address ,
149145 function_selector ,
150146 signing_key
151147 )
@@ -156,7 +152,7 @@ def build_string_input_text(plaintext, user_aes_key, sender, contract, function_
156152 return input_text
157153
158154
159- def decrypt_uint (ciphertext , user_key ) -> int :
155+ def decrypt_uint (ciphertext : int , user_aes_key : str ) -> int :
160156 # Convert ct to bytes (big-endian)
161157 byte_array = ciphertext .to_bytes (32 , byteorder = 'big' )
162158
@@ -165,15 +161,15 @@ def decrypt_uint(ciphertext, user_key) -> int:
165161 r = byte_array [block_size :]
166162
167163 # Decrypt the cipher
168- decrypted_message = decrypt (user_key , r , cipher )
164+ decrypted_message = decrypt (bytes . fromhex ( user_aes_key ) , r , cipher )
169165
170166 # Print the decrypted cipher
171167 decrypted_uint = int .from_bytes (decrypted_message , 'big' )
172168
173169 return decrypted_uint
174170
175171
176- def decrypt_string (ciphertext , user_key ) -> str :
172+ def decrypt_string (ciphertext : int , user_aes_key : str ) -> str :
177173 if 'value' in ciphertext or hasattr (ciphertext , 'value' ): # format when reading ciphertext from an event
178174 __ciphertext = ciphertext ['value' ]
179175 elif isinstance (ciphertext , tuple ): # format when reading ciphertext from state variable
@@ -184,7 +180,7 @@ def decrypt_string(ciphertext, user_key) -> str:
184180 decrypted_string = ""
185181
186182 for value in __ciphertext :
187- decrypted = decrypt_uint (value , user_key )
183+ decrypted = decrypt_uint (value , user_aes_key )
188184
189185 byte_length = (decrypted .bit_length () + 7 ) // 8 # calculate the byte length
190186
@@ -221,7 +217,7 @@ def generate_rsa_keypair():
221217 return private_key_bytes , public_key_bytes
222218
223219
224- def decrypt_rsa (private_key_bytes , ciphertext ):
220+ def decrypt_rsa (private_key_bytes : bytes , ciphertext : bytes ):
225221 # Load private key
226222 private_key = serialization .load_der_private_key (private_key_bytes , password = None )
227223 # Decrypt ciphertext
@@ -237,7 +233,7 @@ def decrypt_rsa(private_key_bytes, ciphertext):
237233
238234#This function recovers a user's key by decrypting two encrypted key shares with the given private key,
239235#and then XORing the two key shares together.
240- def recover_user_key (private_key_bytes , encrypted_key_share0 , encrypted_key_share1 ):
236+ def recover_user_key (private_key_bytes : bytes , encrypted_key_share0 : bytes , encrypted_key_share1 : bytes ):
241237 key_share0 = decrypt_rsa (private_key_bytes , encrypted_key_share0 )
242238 key_share1 = decrypt_rsa (private_key_bytes , encrypted_key_share1 )
243239
0 commit comments