Skip to content

Commit 13a23e4

Browse files
committed
Optionally restrict the range of ephemeral ports
1 parent 3cd4bcc commit 13a23e4

File tree

3 files changed

+99
-6
lines changed

3 files changed

+99
-6
lines changed

src/aioice/ice.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import socket
99
import threading
1010
from itertools import count
11-
from typing import Dict, List, Optional, Set, Text, Tuple, Union, cast
11+
from typing import Dict, Iterable, List, Optional, Set, Text, Tuple, Union, cast
1212

1313
import netifaces
1414

1515
from . import mdns, stun, turn
1616
from .candidate import Candidate, candidate_foundation, candidate_priority
17-
from .utils import random_string
17+
from .utils import create_datagram_endpoint, random_string
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -297,6 +297,7 @@ class Connection:
297297
:param use_ipv4: Whether to use IPv4 candidates.
298298
:param use_ipv6: Whether to use IPv6 candidates.
299299
:param transport_policy: Transport policy.
300+
:param ephemeral_ports: Set of allowed ephemeral local ports to bind to.
300301
"""
301302

302303
def __init__(
@@ -312,6 +313,7 @@ def __init__(
312313
use_ipv4: bool = True,
313314
use_ipv6: bool = True,
314315
transport_policy: TransportPolicy = TransportPolicy.ALL,
316+
ephemeral_ports: Optional[Iterable[int]] = None,
315317
) -> None:
316318
self.ice_controlling = ice_controlling
317319
#: Local username, automatically set to a random value.
@@ -357,6 +359,7 @@ def __init__(
357359
self._tie_breaker = secrets.randbits(64)
358360
self._use_ipv4 = use_ipv4
359361
self._use_ipv6 = use_ipv6
362+
self._ephemeral_ports = ephemeral_ports
360363

361364
if (
362365
stun_server is None
@@ -876,16 +879,14 @@ async def get_component_candidates(
876879
self, component: int, addresses: List[str], timeout: int = 5
877880
) -> List[Candidate]:
878881
candidates = []
879-
loop = asyncio.get_event_loop()
880882

881883
# gather host candidates
882884
host_protocols = []
883885
for address in addresses:
884886
# create transport
885887
try:
886-
transport, protocol = await loop.create_datagram_endpoint(
887-
lambda: StunProtocol(self), local_addr=(address, 0)
888-
)
888+
transport, protocol = await create_datagram_endpoint(
889+
lambda: StunProtocol(self), local_address=address, local_ports=self._ephemeral_ports)
889890
sock = transport.get_extra_info("socket")
890891
if sock is not None:
891892
sock.setsockopt(

src/aioice/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import asyncio
12
import os
3+
import random
24
import secrets
35
import string
6+
from typing import Iterable, Optional, Tuple
47

58

69
def random_string(length: int) -> str:
@@ -10,3 +13,35 @@ def random_string(length: int) -> str:
1013

1114
def random_transaction_id() -> bytes:
1215
return os.urandom(12)
16+
17+
18+
async def create_datagram_endpoint(protocol_factory,
19+
remote_addr: Tuple[str, int] = None,
20+
local_address: str = None,
21+
local_ports: Optional[Iterable[int]] = None,
22+
):
23+
"""
24+
Asynchronousley create a datagram endpoint.
25+
26+
:param protocol_factory: Callable returning a protocol instance.
27+
:param remote_addr: Remote address and port.
28+
:param local_address: Local address to bind to.
29+
:param local_ports: Set of allowed local ports to bind to.
30+
"""
31+
if local_ports is not None:
32+
ports = list(local_ports)
33+
random.shuffle(ports)
34+
else:
35+
ports = (0,)
36+
loop = asyncio.get_event_loop()
37+
for port in ports:
38+
try:
39+
transport, protocol = await loop.create_datagram_endpoint(
40+
protocol_factory, remote_addr=remote_addr, local_addr=(local_address, port)
41+
)
42+
return transport, protocol
43+
except OSError as exc:
44+
if port == ports[-1]:
45+
# this was the last port, give up
46+
raise exc
47+
raise ValueError("local_ports must not be empty")

tests/test_ice.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import functools
33
import os
4+
import random
45
import socket
56
import unittest
67
from unittest import mock
@@ -1249,6 +1250,62 @@ async def test_repr(self):
12491250
conn._id = 1
12501251
self.assertEqual(repr(conn), "Connection(1)")
12511252

1253+
@asynctest
1254+
async def test_connection_ephemeral_ports(self):
1255+
addresses = ["127.0.0.1"]
1256+
1257+
# Let the OS pick a random port - should always yield a candidate
1258+
conn1 = ice.Connection(ice_controlling=True)
1259+
c = await conn1.get_component_candidates(0, addresses)
1260+
self.assertTrue(c[0].port >= 1 and c[0].port <= 65535)
1261+
1262+
# Try opening a new connection with the same port - should never yield candidates
1263+
conn2 = ice.Connection(ice_controlling=True, ephemeral_ports=[c[0].port])
1264+
c = await conn2.get_component_candidates(0, addresses)
1265+
self.assertEqual(len(c), 0) # port already in use, no candidates
1266+
await conn1.close()
1267+
1268+
# Empty set of ports - illegal argument
1269+
conn3 = ice.Connection(ice_controlling=True, ephemeral_ports=[])
1270+
with self.assertRaises(ValueError):
1271+
await conn3.get_component_candidates(0, addresses)
1272+
1273+
# Range of 100 ports
1274+
lower = random.randint(1024, 65536 - 100)
1275+
upper = lower + 100
1276+
ports = set(range(lower, upper)) - set([5353])
1277+
1278+
# Exhaust the range of ports - should always yield candidates
1279+
conns = []
1280+
for i in range(0, len(ports)):
1281+
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
1282+
c = await conn.get_component_candidates(i, addresses)
1283+
if c:
1284+
self.assertTrue(c[0].port >= lower and c[0].port < upper)
1285+
conns.append(conn)
1286+
self.assertGreaterEqual(len(conns), len(ports) - 1) # account for at most 1 port in use by another process
1287+
1288+
# Open one more connection from the same range - should never yield candidates
1289+
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
1290+
c = await conn.get_component_candidates(0, addresses)
1291+
self.assertEqual(len(c), 0) # all ports are exhausted, no candidates
1292+
1293+
# Close one connection and try again - should always yield a candidate
1294+
await conns.pop().close()
1295+
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
1296+
c = await conn.get_component_candidates(0, addresses)
1297+
self.assertTrue(c[0].port >= lower and c[0].port < upper)
1298+
await conn.close()
1299+
1300+
# cleanup
1301+
for conn in conns:
1302+
await conn.close()
1303+
1304+
# Bind to wildcard local address - should always yield a candidate
1305+
conn = ice.Connection(ice_controlling=True)
1306+
c = await conn.get_component_candidates(0, [None])
1307+
self.assertTrue(c[0].port >= 1 and c[0].port <= 65535)
1308+
await conn.close()
12521309

12531310
class StunProtocolTest(unittest.TestCase):
12541311
@asynctest

0 commit comments

Comments
 (0)