diff --git a/examples/client_rsa_mutual_auth.py b/examples/client_rsa_mutual_auth.py index 9206d6d..03c188d 100644 --- a/examples/client_rsa_mutual_auth.py +++ b/examples/client_rsa_mutual_auth.py @@ -19,25 +19,26 @@ from scapy.layers.ssl_tls import * from scapy.all import * +from scapy_ssl_tls import multidigest_pkcs1_15 as Sig_multi_PKCS1_v1_5 def do_tls_mutual_auth(host): - with open(os.path.join(basedir, "tests/integration/keys/scapy-tls-client.crt.der"), "rb") as f: + with open(os.path.join(basedir, "tests/integration/keys/rsa_clnt1.der"), "rb") as f: client_cert = f.read() certificate = TLSCertificate(data=client_cert) - tls_version = TLSVersion.TLS_1_2 + tls_version = TLSVersion.TLS_1_0 with TLSSocket(socket.socket(), client=True) as tls_socket: tls_socket.connect(host) tls_socket.tls_ctx.client_ctx.load_rsa_keys_from_file(os.path.join( - basedir, "tests/integration/keys/scapy-tls-client.key.pem")) + basedir, "tests/integration/keys/rsa_clnt1ky")) client_hello = TLSRecord(version=tls_version) / \ TLSHandshakes(handshakes=[TLSHandshake() / TLSClientHello(version=tls_version, - cipher_suites=[TLSCipherSuite.ECDHE_RSA_WITH_AES_128_CBC_SHA256])]) + cipher_suites=[0x0035])]) server_hello = tls_socket.do_round_trip(client_hello) - # server_hello.show() + server_hello.show() client_cert = TLSRecord(version=tls_version) / \ TLSHandshakes(handshakes=[TLSHandshake() / TLSCertificateList() / @@ -48,12 +49,18 @@ def do_tls_mutual_auth(host): p = TLS.from_records([client_cert, client_key_exchange]) tls_socket.do_round_trip(p, recv=False) - sig = tls_socket.tls_ctx.compute_client_cert_verify(digest=Cryptodome.Hash.SHA256) + sig = tls_socket.tls_ctx.compute_client_cert_verify(digest=Sig_multi_PKCS1_v1_5) #TLS1.0 + #sig = tls_socket.tls_ctx.compute_client_cert_verify(digest=Cryptodome.Hash.SHA256) #TLS1.2 + + #client_cert_verify = TLSRecord(version=tls_version) / \ + # TLSHandshakes(handshakes=[TLSHandshake() / + # TLS12CertificateVerify(alg=TLSSignatureScheme.RSA_PKCS1_SHA256, + # sig=sig)]) #TLS1.2 client_cert_verify = TLSRecord(version=tls_version) / \ TLSHandshakes(handshakes=[TLSHandshake() / - TLSCertificateVerify(alg=TLSSignatureScheme.RSA_PKCS1_SHA256, - sig=sig)]) + TLSCertificateVerify(sig=sig)]) #TLS1.0 + tls_socket.do_round_trip(client_cert_verify, recv=False) client_ccs = TLSRecord(version=tls_version) / TLSChangeCipherSpec() @@ -71,5 +78,5 @@ def do_tls_mutual_auth(host): if len(sys.argv) > 2: server = (sys.argv[1], int(sys.argv[2])) else: - server = ("127.0.0.1", 8443) + server = ("10.102.59.251", 443) do_tls_mutual_auth(server) diff --git a/examples/dtlshandshake_with_cke.py b/examples/dtlshandshake_with_cke.py new file mode 100644 index 0000000..0bdca77 --- /dev/null +++ b/examples/dtlshandshake_with_cke.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import with_statement +from __future__ import print_function +try: + # This import works from the project directory + from scapy_ssl_tls.ssl_tls import * +except ImportError: + # If you installed this package via pip, you just need to execute this + from scapy.layers.ssl_tls import * + +basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) + +tls_version = TLSVersion.DTLS_1_0 +#ciphers = [TLSCipherSuite.RSA_WITH_AES_256_CBC_SHA] +ciphers = [0x0035] +tls_server_names = "abc.com" +tls_session_ticket = "myticket" +extensions=[ + TLSExtension() / + TLSExtServerNameIndication(server_names=TLSServerName(data=tls_server_names)), + TLSExtension() / + TLSExtSessionTicketTLS(data=tls_session_ticket), + ] + + +def dtls_client(server): + with open(os.path.join(basedir, "tests/integration/keys/scapy-tls-client.crt.der"), "rb") as f: + client_cert = f.read() + certificate = TLSCertificate(data=client_cert) + + sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with TLSSocket(sockfd, client=True) as tls_socket: + try: + tls_socket.connect(server) + print("Connected to server: %s" % (server,)) + except socket.timeout: + print("Failed to open connection to server: %s" % (server,), file=sys.stderr) + else: + try: + server_hello, server_kex = tls_socket.do_handshake(tls_version, ciphers, extensions) + server_kex.show() + server_hello.show() + except TLSProtocolError as tpe: + print("Got TLS error: %s" % tpe, file=sys.stderr) + tpe.response.show() + else: + app_data = DTLSRecord(version=tls_version, sequence=1, epoch=1) / TLSPlaintext(data="GET / HTTP/1.1\r\nHOST: 10.102.59.251\r\n\r\n") + tls_socket.sendall(app_data) + resp = tls_socket.recvall() + print("Got response from server") + resp.show() + finally: + print(tls_socket.tls_ctx) + + +if __name__ == "__main__": + if len(sys.argv) > 2: + server = (sys.argv[1], int(sys.argv[2])) + else: + server = ("10.102.59.251", 4433) + dtls_client(server) diff --git a/examples/dtlshandshake_with_clientauth.py b/examples/dtlshandshake_with_clientauth.py new file mode 100644 index 0000000..6efa58a --- /dev/null +++ b/examples/dtlshandshake_with_clientauth.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import with_statement +from __future__ import print_function +try: + # This import works from the project directory + from scapy_ssl_tls.ssl_tls import * +except ImportError: + # If you installed this package via pip, you just need to execute this + from scapy.layers.ssl_tls import * + +from scapy_ssl_tls import multidigest_pkcs1_15 as Sig_multi_PKCS1_v1_5 + +basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) + +version = TLSVersion.DTLS_1_0 +#ciphers = [TLSCipherSuite.RSA_WITH_AES_256_CBC_SHA] +ciphers = [0x0035] +tls_server_names = "abc.com" +tls_session_ticket = "myticket" +extensions=[ + TLSExtension() / + TLSExtServerNameIndication(server_names=TLSServerName(data=tls_server_names)), + TLSExtension() / + TLSExtSessionTicketTLS(data=tls_session_ticket), + ] + + +def dtls_client(server): + with open(os.path.join(basedir, "tests/integration/keys/rsa_clnt1.der"), "rb") as f: + client_cert = f.read() + certificate = TLSCertificate(data=client_cert) + + sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with TLSSocket(sockfd, client=True) as tls_socket: + try: + tls_socket.connect(server) + print("Connected to server: %s" % (server,)) + tls_socket.tls_ctx.client_ctx.load_rsa_keys_from_file(os.path.join( + basedir, "tests/integration/keys/rsa_clnt1ky")) + except socket.timeout: + print("Failed to open connection to server: %s" % (server,), file=sys.stderr) + else: + try: + client_hello = DTLSRecord(version=version, sequence=0) / \ + DTLSHandshake(fragment_offset=0) / \ + DTLSClientHello(version=version, + compression_methods=TLSCompressionMethod.NULL, + cipher_suites=ciphers, + extensions=extensions) + resp1 = tls_do_round_trip(tls_socket, client_hello) + resp1.show() + client_cert = DTLSRecord(version=version, sequence=2) / \ + DTLSHandshake(fragment_offset=0, sequence=1) / \ + DTLSCertificateList() / \ + TLS10Certificate(certificates=certificate) + + client_key_exchange = DTLSRecord(version=version, sequence=3) / \ + DTLSHandshake(fragment_offset=0, sequence=2) / DTLSClientKeyExchange() / \ + tls_socket.tls_ctx.get_client_kex_data() + + p = TLS.from_records([client_cert, client_key_exchange]) + tls_socket.sendall(p) + + sig = tls_socket.tls_ctx.compute_client_cert_verify(digest=Sig_multi_PKCS1_v1_5) + + client_cert_verify = DTLSRecord(version=version, sequence=4) / \ + DTLSHandshake(fragment_offset=0, sequence=3) / \ + DTLSCertificateVerify(sig=sig) + + tls_socket.sendall(client_cert_verify) + + client_ccs = DTLSRecord(version=version, sequence=5) / DTLSChangeCipherSpec() + + tls_socket.sendall(client_ccs) + + client_finished = DTLSRecord(version=version, sequence=0, epoch=1) / \ + DTLSHandshake(fragment_offset=0, sequence=4) / \ + DTLSFinished(data=tls_socket.tls_ctx.get_verify_data()) + + #resp2 = tls_socket.do_round_trip([client_ccs, client_finished], False) + resp2 = tls_do_round_trip(tls_socket, client_finished) + resp2.show() + + except TLSProtocolError as tpe: + print("Got TLS error: %s" % tpe, file=sys.stderr) + tpe.response.show() + else: + app_data = DTLSRecord(version=version, sequence=1, epoch=1) / TLSPlaintext(data="GET / HTTP/1.1\r\nHOST: 10.102.59.251\r\n\r\n") + tls_socket.sendall(app_data) + resp = tls_socket.recvall() + print("Got response from server") + resp.show() + finally: + print(tls_socket.tls_ctx) + + +if __name__ == "__main__": + if len(sys.argv) > 2: + server = (sys.argv[1], int(sys.argv[2])) + else: + server = ("10.102.57.144", 4433) + dtls_client(server) diff --git a/examples/psec1438_dtls.py b/examples/psec1438_dtls.py new file mode 100644 index 0000000..12d7b7c --- /dev/null +++ b/examples/psec1438_dtls.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import with_statement +from __future__ import print_function +import sys +import binascii +from struct import * + +try: + # This import works from the project directory + from scapy_ssl_tls.ssl_tls import * + from scapy_ssl_tls.ssl_tls_crypto import * +except ImportError: + # If you installed this package via pip, you just need to execute this + from scapy.layers.ssl_tls import * + from scapy.layers.ssl_tls_crypto import * + +''' +To run this test you need to +- install scapy +- install scapy-ssl_tls (https://github.com/tintinweb/scapy-ssl_tls) +- move this file to [path to scapy-ssl_tls]/scapy-ssl_tls/examples +''' + +indexpad = 0 +indexmac = 0 +incpad = 0 +incpadlgth = 0 +incmac = 0 +verbose = 5 + + +def modify_padding(crypto_container): + ''' + function modifying the crypto. padding + modify padding byte #indexpad with byte incpad + ''' + padding = crypto_container.padding + if verbose > 10: + print('old pad', binascii.hexlify(bytearray(crypto_container.padding))) + crypto_container.padding = ("%s" + chr(incpad) + "%s") % (padding[:indexpad], padding[indexpad + 1:]) + x = bytearray(crypto_container.padding) + if verbose > 10: + print('iv', binascii.hexlify(bytearray(crypto_container.explicit_iv))) + print('mac', binascii.hexlify(bytearray(crypto_container.mac))) + print('pad', binascii.hexlify(bytearray(crypto_container.padding))) + print('pln', binascii.hexlify(bytearray(crypto_container.padding_len))) + return crypto_container + + +def modify_mac(crypto_container): + ''' + function modifying the crypto. mac + modify mac byte #indexmac with byte #incmac + ''' + # print("--- modify_mac") + mac = crypto_container.mac + if verbose > 10: + print('old mac', binascii.hexlify(bytearray(crypto_container.mac))) + crypto_container.mac = ("%s" + chr(incmac) + "%s") % (mac[:indexmac], mac[indexmac + 1:]) + + if verbose > 10: + print('iv', binascii.hexlify(bytearray(crypto_container.explicit_iv))) + print('mac', binascii.hexlify(bytearray(crypto_container.mac))) + print('pad', binascii.hexlify(bytearray(crypto_container.padding))) + print('pln', binascii.hexlify(bytearray(crypto_container.padding_len))) + return crypto_container + + +def modify_macpad(crypto_container): + ''' + function modifying the crypto. padding and mac + modify padding byte #indexpad with byte #incpad + modify mac byte #indexmac with byte incmac + ''' + # print("--- modify_mac") + padding = crypto_container.padding + mac = crypto_container.mac + if verbose > 10: + print('old pad', binascii.hexlify(bytearray(crypto_container.padding))) + print('old mac', binascii.hexlify(bytearray(crypto_container.mac))) + crypto_container.padding = ("%s" + chr(incpad) + "%s") % (padding[:indexpad], padding[indexpad + 1:]) + crypto_container.mac = ("%s" + chr(incmac) + "%s") % (mac[:indexmac], mac[indexmac + 1:]) + x = bytearray(crypto_container.mac) + if verbose > 10: + print('iv', binascii.hexlify(bytearray(crypto_container.explicit_iv))) + print('mac', binascii.hexlify(bytearray(crypto_container.mac))) + print('pad', binascii.hexlify(bytearray(crypto_container.padding))) + print('pln', binascii.hexlify(bytearray(crypto_container.padding_len))) + return crypto_container + + +def modify_padding_length(crypto_container): + ''' + function modifying the crypto. padding length + modify padding length byte with byte incpadlgth + this test should return the same result as an invalid mac + ''' + # print("--- modify_padding_length") + l = crypto_container.padding_len + if verbose > 10: + print('old pad len', binascii.hexlify(bytearray(crypto_container.padding_len))) + crypto_container.padding_len = chr(incpadlgth) + + if verbose > 10: + print('iv', binascii.hexlify(bytearray(crypto_container.explicit_iv))) + print('mac', binascii.hexlify(bytearray(crypto_container.mac))) + print('pad', binascii.hexlify(bytearray(crypto_container.padding))) + print('pln', binascii.hexlify(bytearray(crypto_container.padding_len))) + return crypto_container + + +def send_application_data(server, cipher_suite, data, hook): + # print("--- send_application_data") + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(server) + tls_socket = TLSSocket(s, client=True) + version = TLSVersion.DTLS_1_0 # TLS_1_0 + dtls_do_handshake(tls_socket, version, cipher_suite) + tls_socket.pre_encrypt_hook = hook + tls_socket.sendall(DTLSRecord(version=version, sequence=1, epoch=1) / TLSPlaintext(data=data)) + + resp = [] + tls_socket._s.settimeout(1) + try: + # we expect to see here the timeout or the RST + data = tls_socket._s.recv(8192) + resp.append(data) + except Exception as e: + # print( "first response", e) + # we expect to get here the TLS alert + data = tls_socket._s.recv(8192) + if data: + resp.append(data) + + # we decode the packet + record = TLS("".join(resp), ctx=tls_socket.tls_ctx, _origin=tls_socket._get_pkt_origin('in')) + + if verbose > 10: + # we show the full packet + record.show() + + return record + + +def align_data_on_block_boundary(data, cipher_suite, pad_char="a"): + ''' + function takin as input the raw data, a padding character and the cipher suite used + and outputing the data padded to fit the cipher suite specifications + ''' + # print("--- align_data_on_block_boundary") + data_len = len(data) + block_len = TLSSecurityParameters.crypto_params[cipher_suite]["cipher"]["type"].block_size + mac_len = TLSSecurityParameters.crypto_params[cipher_suite]["hash"]["type"].digest_size + junk_len = block_len - ((data_len + mac_len) % block_len) + return "%s%s" % (data, pad_char * junk_len) + + +def test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_padding): + ''' + function describing in human language the result of the test + ''' + # print("--- test_all_field_bytes") + + error_msg = "" + try: + resp = send_application_data(server, cipher_suite, block_aligned_request, modify_padding) + if len(resp.records) == 0: + error_msg = "Server is not vulnerable, but implementation does not send a BAD_RECORD_MAC alert" # most likely a RST + elif resp.haslayer(DTLSRecord) and resp[DTLSRecord].content_type == TLSAlertDescription.BAD_RECORD_MAC: + error_msg = "bad_mac" # badmac + elif resp.haslayer(TLSAlert) and resp[TLSAlert].description == TLSAlertDescription.DECRYPT_ERROR: + error_msg = "decrypt_error" + else: + #print(binascii.hexlify(bytearray(resp[Raw].load))) + error_msg = "Server is probably vulnerable\n" # different response, could be sign of a vulnerability or that the packet is correct and the page queried was not found + error_msg += "If application data was displayed above, server is definitely vulnerable" + #resp.show() + if verbose > 10: + resp.show() # show dubious packet + except Exception as e: + error_msg = e # Timeouts will appear here + return error_msg + + +if __name__ == "__main__": + ''' + main function to launch all of the tests + by default we test: + - with and without app data + -- incorrect padding byte + -- incorrect mac byte + -- incorrect padding byte and mac byte + the incorrect values are by default \x00, this can be changed at the top of the file + to test all incorrect bytes, add loops to modify incmac, incpad and incpadlgth from 0 to 255 + ''' + + server = "" + host = None + result = None + if len(sys.argv) == 3: + server = (sys.argv[1], int(sys.argv[2])) + host = sys.argv[1] + elif len(sys.argv) == 4: + server = (sys.argv[1], int(sys.argv[2])) + verbose = int(sys.argv[3]) + host = sys.argv[1] + else: + server = ("10.102.59.251", 4433) + cipher_suite = TLSCipherSuite.RSA_WITH_AES_256_CBC_SHA + # cipher_suite = TLSCipherSuite.RSA_WITH_AES_128_CBC_SHA + #resfile = 'result_' + host + '.txt' + # sys.stdout = open(resfile,'w') + + print("TLS Poodle: testing host", server, end='\t') + + ## Case with App Data + print("\n\n\nTests with APPDATA ----------------------------------------------", end='\t') + request = "GET / HTTP/1.1\r\nHOST: %s\r\n\r\n" % server[0] + block_aligned_request = align_data_on_block_boundary(request, cipher_suite) + ''' + print("\n\n\nTesting correct case", end =" ") + errmsg = test_all_field_bytes(server, cipher_suite, block_aligned_request, None) + result = '\tPASSED\t' + print(errmsg, result, end = '\t') + ''' + indexpad = 0 + print("\n\n\nTesting all padding bytes", end=" ") + # Perform poodle 2 check + for _ in range(0, TLSSecurityParameters.crypto_params[cipher_suite]["cipher"]["type"].block_size - 1): + print("\nModifying padding byte %d" % indexpad, end='\t') + try: + errmsg = test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_padding) + if (errmsg == 'Server is not vulnerable, but implementation does not send a BAD_RECORD_MAC alert'): + result = '\tPASSED\t' + else: + result = '\tFAILED\t' + print(errmsg, result, end='\t') + except Exception as e: + print(e) + indexpad += 1 + + # incpadlgth = 0 + # print("\n\n\nTesting all padding length") + # for i in range(0,256): + # print("\nModifying padding length with byte %d" % indexpadlgth) # modify_padding_length + # print(test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_padding_length)) + # incpadlgth += 1 + + indexmac = 0 + print("\n\n\nTesting all mac bytes", end='\t') + # Perform mac check + for _ in range(0, TLSSecurityParameters.crypto_params[cipher_suite]["hash"]["type"].digest_size - 1): + print("\nModifying mac byte %d" % indexmac, end='\t') + errmsg = test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_mac) + if (errmsg == 'Server is not vulnerable, but implementation does not send a BAD_RECORD_MAC alert'): + result = '\tPASSED\t' + else: + result = '\tFAILED\t' + print(errmsg, result, end='\t') + indexmac += 1 + + indexmac = 0 + indexpad = 0 + print("\n\n\nTesting bad mac bad padding", end='\t') + for _ in range(0, TLSSecurityParameters.crypto_params[cipher_suite]["cipher"]["type"].block_size - 1): + indexmac = 0 + for _ in range(0, TLSSecurityParameters.crypto_params[cipher_suite]["hash"]["type"].digest_size - 1): + print("\n\nModifying pad index %d" % indexpad, " and mac index %d" % indexmac, end='\t') + errmsg = test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_macpad) + if (errmsg == 'Server is not vulnerable, but implementation does not send a BAD_RECORD_MAC alert'): + result = '\tPASSED\t' + else: + result = '\tFAILED\t' + print(errmsg, result, end='\t') + indexmac += 1 + indexpad += 1 + + ## Case without App Data + print("\n\n\nTests WITHOUT APPDATA----------------------------------------------", end='\t') + request = "" + block_aligned_request = request + ''' + print("\n\n\nTesting correct case", end = '\t') + errmsg = test_all_field_bytes(server, cipher_suite, block_aligned_request, None) + result = '\tPASSED\t' + print(errmsg, result, end = '\t') + ''' + indexpad = 0 + print("\n\n\nTesting all padding bytes", end='\t') + # Perform poodle 2 check + for _ in range(0, TLSSecurityParameters.crypto_params[cipher_suite]["cipher"]["type"].block_size - 1): + print("\nModifying padding byte %d" % indexpad, end='\t') + try: + errmsg = test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_padding) + if (errmsg == 'Server is not vulnerable, but implementation does not send a BAD_RECORD_MAC alert'): + result = '\tPASSED\t' + else: + result = '\tFAILED\t' + print(errmsg, result, end='\t') + except Exception as e: + print(e) + indexpad += 1 + + # incpadlgth = 0 + # print("\n\n\nTesting all padding length") + # for i in range(0,256): + # print("\nModifying padding length %d" % indexpadlgth) # modify_padding_length + # print(test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_padding_length)) + # incpadlgth += 1 + + indexmac = 0 + print("\n\n\nTesting all mac bytes", end='\t') + # Perform mac check + for _ in range(0, TLSSecurityParameters.crypto_params[cipher_suite]["hash"]["type"].digest_size - 1): + print("\nModifying mac byte %d" % indexmac, end='\t') + errmsg = test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_mac) + if (errmsg == 'Server is not vulnerable, but implementation does not send a BAD_RECORD_MAC alert'): + result = '\tPASSED\t' + else: + result = '\tFAILED\t' + print(errmsg, result, end='\t') + indexmac += 1 + + indexmac = 0 + indexpad = 0 + print("\n\n\nTesting bad mac bad padding", end='\t') + for _ in range(0, TLSSecurityParameters.crypto_params[cipher_suite]["cipher"]["type"].block_size - 1): + indexmac = 0 + for _ in range(0, TLSSecurityParameters.crypto_params[cipher_suite]["hash"]["type"].digest_size - 1): + print("\n\nModifying pad index %d" % indexpad, " and mac index %d" % indexmac, end='\t') + errmsg = test_all_field_bytes(server, cipher_suite, block_aligned_request, modify_macpad) + if (errmsg == 'Server is not vulnerable, but implementation does not send a BAD_RECORD_MAC alert'): + result = '\tPASSED\t' + else: + result = '\tFAILED\t' + print(errmsg, result, end='\t') + indexmac += 1 + indexpad += 1 + + print("\n\nTest complete", end='\t') + # fp.close() diff --git a/scapy_ssl_tls/multidigest_pkcs1_15.py b/scapy_ssl_tls/multidigest_pkcs1_15.py new file mode 100644 index 0000000..9e43e9b --- /dev/null +++ b/scapy_ssl_tls/multidigest_pkcs1_15.py @@ -0,0 +1,21 @@ +import Cryptodome.Util.number +from Cryptodome.Util.number import ceil_div, bytes_to_long, long_to_bytes + +class multidigest_pkcs1_15: + + def __init__(self, rsa_key): + self._key = rsa_key + + def sign(self, msg_digest): + modBits = Cryptodome.Util.number.size(self._key.n) + k = ceil_div(modBits, 8) # Convert from bits to bytes + + ps = b'\xFF' * (k - len(msg_digest) - 3) + em = b'\x00\x01' + ps + b'\x00' + msg_digest + em_int = bytes_to_long(em) + m_int = self._key._decrypt(em_int) + signature = long_to_bytes(m_int, k) + return signature + +def new(rsa_key): + return multidigest_pkcs1_15(rsa_key) \ No newline at end of file diff --git a/scapy_ssl_tls/ssl_tls.py b/scapy_ssl_tls/ssl_tls.py index abed1f4..651c1c6 100644 --- a/scapy_ssl_tls/ssl_tls.py +++ b/scapy_ssl_tls/ssl_tls.py @@ -775,7 +775,7 @@ def guess_payload_class(self, payload): pkt = self.underlayer # If our underlayer is a handshake, use the tls_ctx to determine # wheat KEX we are currently using - if pkt is not None and pkt.haslayer(TLSHandshake) and hasattr(pkt, "tls_ctx"): + if pkt is not None and (pkt.haslayer(TLSHandshake) or pkt.haslayer(DTLSHandshake)) and hasattr(pkt, "tls_ctx"): if pkt.tls_ctx is not None: kex = pkt.tls_ctx.negotiated.key_exchange return self.kex_payload_table.get(kex, Raw) @@ -918,12 +918,16 @@ def guess_payload_class(self, payload): return TLS10Certificate -class TLSCertificateVerify(PacketNoPayload): +class TLS12CertificateVerify(PacketNoPayload): name = "TLS Certificate Verify" fields_desc = [ShortEnumField("alg", TLSSignatureScheme.RSA_PKCS1_SHA256, TLS_SIGNATURE_SCHEMES), XFieldLenField("sig_length", None, length_of="sig", fmt="H"), # ASN.1 signature element StrLenField("sig", "", length_from=lambda x:x.sig_length)] +class TLSCertificateVerify(PacketNoPayload): + name = "TLS Certificate Verify" + fields_desc = [XFieldLenField("sig_length", None, length_of="sig", fmt="H"), # ASN.1 signature element + StrLenField("sig", "", length_from=lambda x:x.sig_length)] class TLSCertificateType(PacketNoPayload): name = "TLS Certificate Type" @@ -1101,6 +1105,26 @@ class TLSCiphertext(Packet): name = "TLS Ciphertext" fields_desc = [StrField("data", None, fmt="H")] +class DTLSServerDHParams(PacketNoPayload): + name = "DTLS Diffie-Hellman Server Params" + fields_desc = [XFieldLenField("p_length", None, length_of="p", fmt="!H"), + StrLenField("p", '', length_from=lambda x:x.p_length), + XFieldLenField("g_length", None, length_of="g", fmt="!H"), + StrLenField("g", '', length_from=lambda x:x.g_length), + XFieldLenField("ys_length", None, length_of="y_s", fmt="!H"), + StrLenField("y_s", "", length_from=lambda x:x.ys_length), + XFieldLenField("sig_length", None, length_of="sig", fmt="!H"), + StrLenField("sig", '', length_from=lambda x:x.sig_length)] + + +class DTLSServerECDHParams(PacketNoPayload): + name = "DTLS EC Diffie-Hellman Server Params" + fields_desc = [ByteEnumField("curve_type", TLSECCurveTypes.NAMED_CURVE, TLS_EC_CURVE_TYPES), + ShortEnumField("curve_name", TLSSupportedGroup.SECP256R1, TLS_SUPPORTED_GROUPS), + XFieldLenField("p_length", None, length_of="p", fmt="!B"), + StrLenField("p", '', length_from=lambda x:x.p_length), + XFieldLenField("sig_length", None, length_of="sig", fmt="!H"), + StrLenField("sig", '', length_from=lambda x:x.sig_length)] class DTLSRecord(PacketLengthFieldPayload): name = "DTLS Record" @@ -1141,16 +1165,105 @@ class DTLSClientHello(PacketNoPayload): else False), PacketListField("extensions", None, TLSExtension, length_from=lambda x:x.extensions_length)] -SSLv2_CERTIFICATE_TYPES = {0x01: 'x509'} -SSLv2CertificateType = EnumStruct(SSLv2_CERTIFICATE_TYPES) - - class DTLSHelloVerify(PacketNoPayload): name = "DTLS Hello Verify" fields_desc = [XShortEnumField("version", TLSVersion.DTLS_1_0, TLS_VERSIONS), XFieldLenField("cookie_length", None, length_of="cookie", fmt="B"), StrLenField("cookie", '', length_from=lambda x:x.cookie_length)] +class DTLSServerHello(PacketNoPayload): + name = "DTLS Server Hello" + fields_desc = [XShortEnumField("version", TLSVersion.DTLS_1_0, TLS_VERSIONS), + IntField("gmt_unix_time", int(time.time())), + StrFixedLenField("random_bytes", os.urandom(28), 28), + XFieldLenField("session_id_length", None, length_of="session_id", fmt="B"), + StrLenField("session_id", os.urandom(20), length_from=lambda x:x.session_id_length), + XShortEnumField("cipher_suite", TLSCipherSuite.RSA_WITH_AES_128_CBC_SHA, TLS_CIPHER_SUITES), + ByteEnumField("compression_method", TLSCompressionMethod.NULL, TLS_COMPRESSION_METHODS), + StrConditionalField(XFieldLenField("extensions_length", None, length_of="extensions", fmt="H"), + lambda pkt, s, val: True if val or pkt.extensions or (s and struct.unpack("!H", s[:2])[0] == len(s) - 2) else False), + TypedPacketListField("extensions", None, TLSExtension, length_from=lambda x:x.extensions_length, type_="DTLSServerHello")] + +class DTLSCertificateList(Packet): + name = "DTLS Certificate List" + fields_desc = [] + + def guess_payload_class(self, payload): + tls13_cert = TLS13Certificate(payload) + tls10_cert = TLS10Certificate(payload) + certs_len = lambda certs: len(b"".join([str(cert) for cert in certs.certificates])) + if tls13_cert.request_context_length == len(tls13_cert.request_context) and tls13_cert.length == certs_len(tls13_cert): + return TLS13Certificate + elif tls10_cert.length == certs_len(tls10_cert): + return TLS10Certificate + else: + pkt = self.underlayer + # If our underlayer is a handshake, use the tls_ctx to determine + # whether we are using a tls 1.3 cert or an older version + if pkt is not None and pkt.haslayer(TLSHandshake): + if pkt.tls_ctx is not None: + if pkt.tls_ctx.negotiated.version >= TLSVersion.TLS_1_3: + return TLS13Certificate + return TLS10Certificate + +class DTLSServerKeyExchange(TLSKeyExchange): + name = "DTLS Server Key Exchange" + kex_payload_table = {TLSKexNames.DHE: DTLSServerDHParams, + TLSKexNames.ECDHE: DTLSServerECDHParams} + + def guess_payload_class(self, payload): + dh_params = DTLSServerDHParams(payload) + ecdh_params = DTLSServerECDHParams(payload) + # Try to figure out what is the next Key Exchange layer + if dh_params.p_length == len(dh_params.p) and dh_params.g_length == len(dh_params.g) and \ + dh_params.ys_length == len(dh_params.y_s) and dh_params.sig_length == len(dh_params.sig): + return DTLSServerDHParams + elif ecdh_params.p_length == len(ecdh_params.p) and ecdh_params.sig_length == len(ecdh_params.sig): + return DTLSServerECDHParams + # If we don't have a match, fallback to the standard mechanism + else: + return TLSKeyExchange.guess_payload_class(payload) + + +class DTLSServerHelloDone(PacketNoPayload): + name = "DTLS Server Hello Done" + fields_desc = [] + +class DTLSClientKeyExchange(TLSKeyExchange): + name = "DTLS Client Key Exchange" + kex_payload_table = {TLSKexNames.RSA: TLSClientRSAParams, + TLSKexNames.DHE: TLSClientDHParams, + TLSKexNames.ECDHE: TLSClientECDHParams} + + def guess_payload_class(self, payload): + ecdh_params = TLSClientECDHParams(payload) + # Try to figure out what is the next Key Exchange layer. Can only do this for ECDHE, + # since RSA and DHE parse in exactly the same way. + if ecdh_params.length == len(ecdh_params.data) and ecdh_params.data.startswith(b"\x04"): + return TLSClientECDHParams + else: + return TLSKeyExchange.guess_payload_class(self, payload) + +class DTLSCertificateVerify(PacketNoPayload): + name = "DTLS Certificate Verify " + fields_desc = [XFieldLenField("sig_length", None, length_of="sig", fmt="H"), # ASN.1 signature element + StrLenField("sig", "", length_from=lambda x:x.sig_length)] + +class DTLSFinished(PacketNoPayload): + name = "DTLS Handshake Finished" + fields_desc = [StrLenField("data", "", length_from=lambda x:x.underlayer.length)] + +class DTLSChangeCipherSpec(TLSDecryptablePacket): + name = "DTLS ChangeCipherSpec" + fields_desc = [StrField("message", '\x01', fmt="H")] + +class DTLSHelloRequest(Packet): + name = "DTLS Hello Request" + fields_desc = [] + + +SSLv2_CERTIFICATE_TYPES = {0x01: 'x509'} +SSLv2CertificateType = EnumStruct(SSLv2_CERTIFICATE_TYPES) SSLv2_MESSAGE_TYPES = { 0x01: 'client_hello', @@ -1306,7 +1419,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() def do_handshake(self, version, ciphers, extensions=[]): - return tls_do_handshake(self, version, ciphers, extensions) + if version == TLSVersion.DTLS_1_0: + return dtls_do_handshake(self, version, ciphers, extensions) + else: + return tls_do_handshake(self, version, ciphers, extensions) def do_round_trip(self, pkt, recv=True): return tls_do_round_trip(self, pkt, recv) @@ -1335,7 +1451,7 @@ def from_records(cls, records, ctx=None): def pre_dissect(self, raw_bytes): # figure out if we're UDP or TCP - if self.underlayer is not None and self.underlayer.haslayer(UDP): + if ord(raw_bytes[1]) == 0xfe and ord(raw_bytes[2]) == 0xff: self.guessed_next_layer = DTLSRecord elif ord(raw_bytes[0]) & 0x80: self.guessed_next_layer = SSLv2Record @@ -1355,7 +1471,10 @@ def do_dissect(self, raw_bytes): while pos < len(raw_bytes) - record_header_len: payload_len = record(raw_bytes[pos:pos + record_header_len]).length if self.tls_ctx is not None: - payload = record(raw_bytes[pos:pos + record_header_len + payload_len], ctx=self.tls_ctx) + if record == DTLSRecord: + payload = record(raw_bytes[pos:pos + record_header_len + payload_len]) + else: + payload = record(raw_bytes[pos:pos + record_header_len + payload_len], ctx=self.tls_ctx) # Perform inline decryption if required payload = self.do_decrypt_payload(payload) self.tls_ctx.insert(payload, origin=self._origin) @@ -1372,7 +1491,7 @@ def do_dissect(self, raw_bytes): def do_decrypt_payload(self, record): content_type = None encrypted_payload, layer = self._get_encrypted_payload(record) - if encrypted_payload is not None or self.tls_ctx.negotiated.version >= TLSVersion.TLS_1_3: + if encrypted_payload is not None or (self.tls_ctx.negotiated.version >= TLSVersion.TLS_1_3 and self.tls_ctx.negotiated.version != TLSVersion.DTLS_1_0): try: if self.tls_ctx.client: cleartext = self.tls_ctx.server_ctx.crypto_ctx.decrypt(encrypted_payload, @@ -1430,7 +1549,9 @@ def find_padding_start(payload, padding_byte=b"\x00"): cleartext_handler = {TLSPlaintext: lambda pkt, tls_ctx: (TLSContentType.APPLICATION_DATA, pkt[TLSPlaintext].data), TLSChangeCipherSpec: lambda pkt, tls_ctx: (TLSContentType.CHANGE_CIPHER_SPEC, str(pkt[TLSChangeCipherSpec])), TLSAlert: lambda pkt, tls_ctx: (TLSContentType.ALERT, str(pkt[TLSAlert])), #} - TLSHandshakes: lambda pkt, tls_ctx: (TLSContentType.HANDSHAKE, str(pkt[TLSHandshakes]))} + TLSHandshakes: lambda pkt, tls_ctx: (TLSContentType.HANDSHAKE, str(pkt[TLSHandshakes])), + DTLSChangeCipherSpec: lambda pkt, tls_ctx: (TLSContentType.CHANGE_CIPHER_SPEC, str(pkt[DTLSChangeCipherSpec])), + DTLSHandshake: lambda pkt, tls_ctx: (TLSContentType.HANDSHAKE, str(pkt[DTLSHandshake]))} def to_raw(pkt, tls_ctx, include_record=True, compress_hook=None, pre_encrypt_hook=None, encrypt_hook=None): @@ -1468,8 +1589,11 @@ def to_raw(pkt, tls_ctx, include_record=True, compress_hook=None, pre_encrypt_ho ciphertext = ctx.crypto_ctx.encrypt(crypto_container) if include_record: - if tls_ctx.negotiated.version >= TLSVersion.TLS_1_3: + if tls_ctx.negotiated.version >= TLSVersion.TLS_1_3 and tls_ctx.negotiated.version != TLSVersion.DTLS_1_0: tls_ciphertext = TLSRecord(content_type=TLSContentType.APPLICATION_DATA) / ciphertext + elif tls_ctx.negotiated.version == TLSVersion.DTLS_1_0: + tls_ciphertext = DTLSRecord(version=tls_ctx.negotiated.version, content_type=content_type, sequence=0, + epoch=1) / ciphertext else: tls_ciphertext = TLSRecord(version=tls_ctx.negotiated.version, content_type=content_type) / ciphertext else: @@ -1534,6 +1658,32 @@ def tls_do_handshake(tls_socket, version, ciphers, extensions=[]): raise NotImplementedError("Do handshake not implemented for TLS 1.3") +def dtls_do_handshake(tls_socket, version, ciphers, extensions=[]): + if version == TLSVersion.DTLS_1_0: + client_hello = DTLSRecord(version=version, sequence=0) / \ + DTLSHandshake(fragment_offset=0) / \ + DTLSClientHello(version=version, + compression_methods=TLSCompressionMethod.NULL, + cipher_suites=ciphers, + extensions=extensions) + resp1 = tls_do_round_trip(tls_socket, client_hello) + client_key_exchange = DTLSRecord(version=version, sequence=1) / \ + DTLSHandshake(fragment_offset=0, sequence=1) / DTLSClientKeyExchange() / \ + tls_socket.tls_ctx.get_client_kex_data() + + client_ccs = DTLSRecord(version=version, sequence=2) / DTLSChangeCipherSpec() + tls_do_round_trip(tls_socket, TLS.from_records([client_key_exchange, client_ccs]), False) + + client_finished = DTLSRecord(version=version, sequence=0, epoch=1) / \ + DTLSHandshake(fragment_offset=0, sequence=2) / \ + DTLSFinished(data=tls_socket.tls_ctx.get_verify_data()) + + resp2 = tls_do_round_trip(tls_socket, client_finished) + return resp1, resp2 + else: + raise NotImplementedError("Invalid DTLS Version") + + def tls_fragment_payload(pkt, record=None, size=2**14): if size <= 0: raise ValueError("Fragment size must be strictly positive") @@ -1585,6 +1735,7 @@ def tls_draft_version(draft_version): bind_layers(TLSHandshake, TLSSessionTicket, {'type': TLSHandshakeType.NEWSESSIONTICKET}) bind_layers(TLSHandshake, TLSCertificateRequest, {"type": TLSHandshakeType.CERTIFICATE_REQUEST}) bind_layers(TLSHandshake, TLSCertificateVerify, {"type": TLSHandshakeType.CERTIFICATE_VERIFY}) +bind_layers(TLSHandshake, TLS12CertificateVerify, {"type": TLSHandshakeType.CERTIFICATE_VERIFY}) bind_layers(TLSHandshake, TLSEncryptedExtensions, {"type": TLSHandshakeType.ENCRYPTED_EXTENSIONS}) # <--- @@ -1610,7 +1761,16 @@ def tls_draft_version(draft_version): # DTLSRecord bind_layers(DTLSRecord, DTLSHandshake, {'content_type': TLSContentType.HANDSHAKE}) +bind_layers(DTLSHandshake, DTLSHelloRequest, {'type': TLSHandshakeType.HELLO_REQUEST}) bind_layers(DTLSHandshake, DTLSClientHello, {'type': TLSHandshakeType.CLIENT_HELLO}) +bind_layers(DTLSHandshake, DTLSServerHello, {'type': TLSHandshakeType.SERVER_HELLO}) +bind_layers(DTLSHandshake, DTLSServerKeyExchange, {'type': TLSHandshakeType.SERVER_KEY_EXCHANGE}) +bind_layers(DTLSHandshake, DTLSServerHelloDone, {'type': TLSHandshakeType.SERVER_HELLO_DONE}) +bind_layers(DTLSHandshake, DTLSCertificateList, {'type': TLSHandshakeType.CERTIFICATE}) +bind_layers(DTLSHandshake, DTLSClientKeyExchange, {'type': TLSHandshakeType.CLIENT_KEY_EXCHANGE}) +bind_layers(DTLSRecord, DTLSChangeCipherSpec, {'content_type': TLSContentType.CHANGE_CIPHER_SPEC}) +bind_layers(DTLSHandshake, DTLSCertificateVerify, {"type": TLSHandshakeType.CERTIFICATE_VERIFY}) +bind_layers(DTLSHandshake, DTLSFinished, {'type': TLSHandshakeType.FINISHED}) # SSLv2 bind_layers(SSLv2Record, SSLv2ServerHello, {'content_type': SSLv2MessageType.SERVER_HELLO}) diff --git a/scapy_ssl_tls/ssl_tls_crypto.py b/scapy_ssl_tls/ssl_tls_crypto.py index be1b919..626b337 100644 --- a/scapy_ssl_tls/ssl_tls_crypto.py +++ b/scapy_ssl_tls/ssl_tls_crypto.py @@ -20,10 +20,12 @@ import tinyec.registry as ec_reg from collections import namedtuple + from Cryptodome.Cipher import AES, ARC2, ARC4, DES, DES3, PKCS1_v1_5 from Cryptodome.Hash import HMAC, MD5, SHA, SHA256, SHA384 from Cryptodome.PublicKey import DSA, RSA from Cryptodome.Signature import PKCS1_v1_5 as Sig_PKCS1_v1_5 +from scapy_ssl_tls import multidigest_pkcs1_15 as Sig_multi_PKCS1_v1_5 from scapy.packet import Raw # Added this to get all certificate dissection to work OK, without the need to import this in the client script @@ -326,9 +328,9 @@ def __handle_server_hello(self, server_hello): tls.TLS_CIPHER_SUITES.get(self.negotiated.ciphersuite, "UNKNOWN"))) self.negotiated.encryption = (self.cipher_properties["cipher"]["name"], self.cipher_properties["cipher"]["key_len"], self.cipher_properties["cipher"]["mode_name"]) - self.requires_iv = True if tls.TLSVersion.TLS_1_0 < self.negotiated.version < tls.TLSVersion.TLS_1_3 else False + self.requires_iv = True if ((tls.TLSVersion.TLS_1_0 < self.negotiated.version < tls.TLSVersion.TLS_1_3) or self.negotiated.version == tls.TLSVersion.DTLS_1_0) else False - if self.negotiated.version < tls.TLSVersion.TLS_1_3: + if self.negotiated.version < tls.TLSVersion.TLS_1_3 or self.negotiated.version == tls.TLSVersion.DTLS_1_0: self.__handle_tls12_server_hello(server_hello) # TlS 1.3 case. Extract KEX data from KeyShare extension else: @@ -383,6 +385,37 @@ def __handle_server_kex(self, server_kex): else: warnings.warn("Unknown server key exchange") + def __handle_dtls_server_kex(self, server_kex): + # DHE case + if server_kex.haslayer(tls.DTLSServerDHParams): + if isinstance(self.server_ctx.kex_keystore, tlsk.EmptyKexKeystore): + p = tlsk.str_to_int(server_kex[tls.DTLSServerDHParams].p) + g = tlsk.str_to_int(server_kex[tls.DTLSServerDHParams].g) + public = tlsk.str_to_int(server_kex[tls.DTLSServerDHParams].y_s) + self.server_ctx.kex_keystore = tlsk.DHKeyStore(g, p, public) + elif server_kex.haslayer(tls.DTLSServerECDHParams): + if isinstance(self.server_ctx.kex_keystore, tlsk.EmptyKexKeystore): + try: + curve_id = server_kex[tls.DTLSServerECDHParams].curve_name + # TODO: DO NOT assume uncompressed EC points! + point = tlsk.ansi_str_to_point(server_kex[tls.DTLSServerECDHParams].p) + curve_name = tls.TLS_SUPPORTED_GROUPS[curve_id] + # Unknown curve case. Just record raw values, but do nothing with them + except KeyError: + self.server_ctx.kex_keystore = tlsk.ECDHKeyStore(None, point) + warnings.warn("Unknown elliptic curve id: %d. Client KEX calculation is up to you" % curve_id) + # We are on a known curve + else: + try: + curve = ec_reg.get_curve(curve_name) + self.server_ctx.kex_keystore = tlsk.ECDHKeyStore(curve, ec.Point(curve, *point)) + except ValueError: + self.server_ctx.kex_keystore = tlsk.ECDHKeyStore(None, point) + warnings.warn("Unsupported elliptic curve: %s" % curve_name) + else: + warnings.warn("Unknown server key exchange") + + def __handle_client_kex(self, client_kex): # Walk around a bug where tls_ctx is not defined, thus prevents correct parsing # of the TLSKeyExchange by the upper layer. Dodgy, but I don't see anyway around it @@ -467,7 +500,7 @@ def __handle_ccs(self, ccs, origin): self.__ccs_count += 1 def __handle_finished(self, finished): - if self.negotiated.version >= tls.TLSVersion.TLS_1_3: + if self.negotiated.version >= tls.TLSVersion.TLS_1_3 and self.negotiated.version != tls.TLSVersion.DTLS_1_0: ctx = self.client_ctx verify_data = self.derive_client_finished() # This is the first finished in the connection, coming from the server. Transition to traffic secrets @@ -513,23 +546,37 @@ def _process(self, pkt, origin=None): """ fill context """ - if pkt.haslayer(tls.TLSHandshake): + if pkt.haslayer(tls.TLSHandshake) or pkt.haslayer(tls.DTLSHandshake): # requires handshake messages if pkt.haslayer(tls.TLSClientHello): self.__handle_client_hello(pkt[tls.TLSClientHello]) + if pkt.haslayer(tls.DTLSClientHello): + self.__handle_client_hello(pkt[tls.DTLSClientHello]) if pkt.haslayer(tls.TLSServerHello): self.__handle_server_hello(pkt[tls.TLSServerHello]) + if pkt.haslayer(tls.DTLSServerHello): + self.__handle_server_hello(pkt[tls.DTLSServerHello]) if pkt.haslayer(tls.TLSCertificateList): self.__handle_cert_list(pkt[tls.TLSCertificateList]) + if pkt.haslayer(tls.DTLSCertificateList): + self.__handle_cert_list(pkt[tls.DTLSCertificateList]) if pkt.haslayer(tls.TLSServerKeyExchange): self.__handle_server_kex(pkt[tls.TLSServerKeyExchange]) + if pkt.haslayer(tls.DTLSServerKeyExchange): + self.__handle_dtls_server_kex(pkt[tls.DTLSServerKeyExchange]) if pkt.haslayer(tls.TLSClientKeyExchange): self.__handle_client_kex(pkt[tls.TLSClientKeyExchange]) + if pkt.haslayer(tls.DTLSClientKeyExchange): + self.__handle_client_kex(pkt[tls.DTLSClientKeyExchange]) if pkt.haslayer(tls.TLSFinished): self.__handle_finished(pkt[tls.TLSFinished]) + if pkt.haslayer(tls.DTLSFinished): + self.__handle_finished(pkt[tls.DTLSFinished]) self.__handle_session_ticket(pkt) if pkt.haslayer(tls.TLSChangeCipherSpec): self.__handle_ccs(pkt[tls.TLSChangeCipherSpec], origin=origin) + if pkt.haslayer(tls.DTLSChangeCipherSpec): + self.__handle_ccs(pkt[tls.DTLSChangeCipherSpec], origin=origin) def _generate_random_pms(self, version): return "%s%s" % (struct.pack("!H", version), os.urandom(46)) @@ -572,11 +619,20 @@ def get_client_ecdh_pubkey(self, private=None): def get_client_kex_data(self, val=None): if self.negotiated.key_exchange == tls.TLSKexNames.RSA: - return tls.TLSClientKeyExchange() / tls.TLSClientRSAParams(data=self.get_encrypted_pms(val)) + if self.negotiated.version == tls.TLSVersion.DTLS_1_0: + return tls.DTLSClientKeyExchange() / tls.TLSClientRSAParams(data=self.get_encrypted_pms(val)) + else: + return tls.TLSClientKeyExchange() / tls.TLSClientRSAParams(data=self.get_encrypted_pms(val)) elif self.negotiated.key_exchange == tls.TLSKexNames.DHE: - return tls.TLSClientKeyExchange() / tls.TLSClientDHParams(data=self.get_client_dh_pubkey(val)) + if self.negotiated.version == tls.TLSVersion.DTLS_1_0: + return tls.DTLSClientKeyExchange() / tls.TLSClientDHParams(data=self.get_client_dh_pubkey(val)) + else: + return tls.TLSClientKeyExchange() / tls.TLSClientDHParams(data=self.get_client_dh_pubkey(val)) elif self.negotiated.key_exchange == tls.TLSKexNames.ECDHE: - return tls.TLSClientKeyExchange() / tls.TLSClientECDHParams(data=self.get_client_ecdh_pubkey(val)) + if self.negotiated.version == tls.TLSVersion.DTLS_1_0: + return tls.DTLSClientKeyExchange() / tls.TLSClientECDHParams(data=self.get_client_ecdh_pubkey(val)) + else: + return tls.TLSClientKeyExchange() / tls.TLSClientECDHParams(data=self.get_client_ecdh_pubkey(val)) else: raise NotImplementedError("Key exchange unknown or currently not supported") @@ -600,6 +656,10 @@ def _walk_handshake_msgs(self): for handshake in pkt[tls.TLSHandshakes].handshakes: if not handshake.haslayer(tls.TLSHelloRequest): yield handshake + if pkt.haslayer(tls.DTLSHandshake): + for handshake in pkt[tls.DTLSHandshake]: + if not handshake.haslayer(tls.DTLSHelloRequest): + yield handshake def _derive_finished(self, secret, hash_): return HMAC.new(secret, hash_, digestmod=self.prf.digest).digest() @@ -615,7 +675,7 @@ def derive_client_finished(self): return self._derive_finished(self.client_ctx.finished_secret, self.get_handshake_hash(self.prf.digest, tls.TLSFinished, True)) def get_verify_data(self, data=None): - if self.negotiated.version >= tls.TLSVersion.TLS_1_3: + if self.negotiated.version >= tls.TLSVersion.TLS_1_3 and self.negotiated.version != tls.TLSVersion.DTLS_1_0: if self.client: prf_verify_data = self.derive_client_finished() else: @@ -632,6 +692,10 @@ def get_verify_data(self, data=None): # Special case of encrypted handshake. Remove crypto material to compute verify_data verify_data.append("%s%s%s" % (chr(handshake.type), struct.pack(">I", handshake.length)[1:], handshake[tls.TLSFinished].data)) + elif handshake.haslayer(tls.DTLSFinished): + # Special case of encrypted handshake. Remove crypto material to compute verify_data + verify_data.append("%s%s%s" % (chr(handshake.type), struct.pack(">I", handshake.length)[1:], + handshake[tls.DTLSFinished].data)) else: verify_data.append(str(handshake)) else: @@ -667,8 +731,13 @@ def get_client_signed_handshake_hash(self, hash_=SHA256.new(), pre_sign_hook=lam """Legacy way to get the certificate verify hash. Added sig as last parameter to preserve prior use""" if self.client_ctx.asym_keystore.private is None: raise RuntimeError("Missing client private key. Can't sign") - msg_hash = self.get_handshake_digest(hash_) - msg_hash = pre_sign_hook(msg_hash) + + if self.negotiated.version == tls.TLSVersion.TLS_1_2: + msg_hash = self.get_handshake_digest(hash_) + msg_hash = pre_sign_hook(msg_hash) + else: + msg_hash = self.get_handshake_digest(MD5.new()).digest() + self.get_handshake_digest(SHA.new()).digest() + # Will throw exception if we can't sign or if data is larger the modulus return sig.new(self.client_ctx.asym_keystore.private).sign(msg_hash) @@ -687,15 +756,18 @@ def compute_server_cert_verify(self, sig=Sig_PKCS1_v1_5, digest=SHA256, pre_sign return self._compute_cert_verify(self.server_ctx, hash_, sig_label, sig, digest, pre_sign_hook) def compute_client_cert_verify(self, sig=Sig_PKCS1_v1_5, digest=SHA256, pre_sign_hook=lambda x: x): - if self.negotiated.version >= tls.TLSVersion.TLS_1_3: + '''For TLS1.0, TLS1.1 and DTLS1.0, the certificate verify message digest is calculated by joining the MD5 and SHA1 digest of the handshake messages''' + if self.negotiated.version >= tls.TLSVersion.TLS_1_3 and self.negotiated.version != tls.TLSVersion.DTLS_1_0: sig_label = b"TLS 1.3, client CertificateVerify" if self.prf is None: raise RuntimeError("PRF must be initialized prior to computing TLS 1.3 signature") # TODO: calculate handshake hash properly until the second tls.TLSCertificateList for client based certs hash_ = self.get_handshake_hash(self.prf.digest, tls.TLSCertificateList) return self._compute_cert_verify(self.server_ctx, hash_, sig_label, sig, digest, pre_sign_hook) - else: + elif self.negotiated.version == tls.TLSVersion.TLS_1_2: return self.get_client_signed_handshake_hash(digest.new(), pre_sign_hook, sig) + else: + return self.get_client_signed_handshake_hash(None, pre_sign_hook, Sig_multi_PKCS1_v1_5) def set_mode(self, client=None, server=None): self.client = client if client else not server @@ -728,7 +800,7 @@ def __init__(self, tls_version, digest=None): self.digest = digest def get_bytes(self, key, label, random, num_bytes): - if self.tls_version >= tls.TLSVersion.TLS_1_2: + if self.tls_version >= tls.TLSVersion.TLS_1_2 and self.tls_version != tls.TLSVersion.DTLS_1_0: bytes_ = self._get_bytes(self.digest, key, label, random, num_bytes) else: key_len = (len(key) + 1) // 2 @@ -1220,7 +1292,11 @@ def from_data(cls, tls_ctx, ctx, data): return CBCCryptoContainer.from_context(tls_ctx, ctx, crypto_data) def __mac(self): - sequence_ = struct.pack("!Q", self.crypto_data.sequence) + if self.crypto_data.version == tls.TLSVersion.DTLS_1_0: + epoch_ = 1 + sequence_ = struct.pack("!H6B", epoch_,0, 0, 0, 0, 0, 0) + else: + sequence_ = struct.pack("!Q", self.crypto_data.sequence) content_type_ = struct.pack("!B", self.crypto_data.content_type) version_ = struct.pack("!H", self.crypto_data.version) len_ = struct.pack("!H", self.crypto_data.data_len)