Skip to content
Open
3 changes: 3 additions & 0 deletions src/aioice/stun.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ async def run(self) -> tuple[Message, tuple[str, int]]:
self.__timeout_handle.cancel()

def __retry(self) -> None:
if self.__future.done():
return

if self.__tries >= self.__tries_max:
self.__future.set_exception(TransactionTimeout())
return
Expand Down
38 changes: 38 additions & 0 deletions tests/test_stun.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import asyncio
import unittest
from binascii import unhexlify
from collections import OrderedDict
from typing import Optional

from utils import detect_exceptions_in_loop

from aioice import stun

Expand Down Expand Up @@ -260,3 +264,37 @@ def send_stun(
message_method=stun.Method.BINDING, message_class=stun.Class.RESPONSE
)
transaction.response_received(response, ("127.0.0.1", 1234))

@asynctest
async def test_message_resolved_before_timeout(self) -> None:
socket_address = ("127.0.0.1", 1234)
request = stun.Message(
message_method=stun.Method.BINDING, message_class=stun.Class.REQUEST
)
expected_response = stun.Message(
message_method=stun.Method.BINDING,
message_class=stun.Class.RESPONSE,
)

class RespondImmediatelyProtocol:
_transaction: Optional[stun.Transaction] = None

def set_transaction(self, new_transaction: stun.Transaction) -> None:
self._transaction = new_transaction

def send_stun(
self, message: stun.Message, address: tuple[str, int]
) -> None:
asyncio.get_running_loop().call_soon(
self._transaction.response_received, expected_response, address
)

with detect_exceptions_in_loop():
protocol = RespondImmediatelyProtocol()
transaction = stun.Transaction(request, socket_address, protocol, 0)
protocol.set_transaction(transaction)

response_message, response_address = await transaction.run()

self.assertEqual(response_message, expected_response)
self.assertEqual(response_address, socket_address)
46 changes: 44 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import logging
import os
import sys
from asyncio import AbstractEventLoop
from collections.abc import Callable, Coroutine
from contextlib import contextmanager

if sys.version_info >= (3, 10):
from typing import ParamSpec
from typing import ParamSpec, Any, Iterator
else:
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Any, Iterator

from aioice import ice

Expand Down Expand Up @@ -51,3 +53,43 @@ def read_message(name: str) -> bytes:

if os.environ.get("AIOICE_DEBUG"):
logging.basicConfig(level=logging.DEBUG)


class CollectExceptionsHandler:
_exceptions: list[Exception] = []

def handle_exception(
self, _loop: AbstractEventLoop, context: dict[str, Any]
) -> None:
exception = context.get("exception")

if exception and isinstance(exception, Exception):
self._exceptions.append(exception)

@property
def exceptions(self) -> list[Exception]:
return self._exceptions


@contextmanager
def new_collect_exceptions_handler() -> Iterator[CollectExceptionsHandler]:
handler = CollectExceptionsHandler()
loop = asyncio.get_event_loop()
original_handler = loop.get_exception_handler()
loop.set_exception_handler(handler.handle_exception)

try:
yield handler
finally:
loop.set_exception_handler(original_handler)


@contextmanager
def detect_exceptions_in_loop() -> Iterator[None]:
with new_collect_exceptions_handler() as handler:
yield None

if handler.exceptions:
raise Exception(
f"Found {len(handler.exceptions)} exceptions on loop."
) from handler.exceptions[0]