Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions openwpm/socket_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,44 @@ def send(self, msg):
raise RuntimeError("socket connection broken")
totalsent = totalsent + sent

def receive(self, timeout: float = 5.0) -> Any:
"""Receive a single message from the server.

Uses the same wire format as send() (4-byte length + 1-byte
serialization type + payload). Returns the deserialized message.

Parameters
----------
timeout : float
Socket timeout in seconds. Returns None if no data arrives
within the timeout.
"""
old_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout)
try:
header = self._recv_exactly(5)
if header is None:
return None
msglen, serialization = struct.unpack(">Lc", header)
payload = self._recv_exactly(msglen)
if payload is None:
return None
return _parse(serialization, payload)
except socket.timeout:
return None
finally:
self.sock.settimeout(old_timeout)

def _recv_exactly(self, n: int) -> Any:
"""Receive exactly n bytes from the socket."""
data = b""
while len(data) < n:
chunk = self.sock.recv(n - len(data))
if not chunk:
return None
data += chunk
return data

def close(self):
self.sock.close()

Expand All @@ -184,6 +222,18 @@ async def get_message_from_reader(reader: asyncio.StreamReader) -> Any:
return _parse(serialization, msg)


async def send_to_writer(writer: asyncio.StreamWriter, msg: Any) -> None:
"""Send a JSON-serialized message to an asyncio StreamWriter.

Uses the same wire format as ClientSocket.send() so that
ClientSocket can receive responses using the same protocol.
"""
encoded = json.dumps(msg).encode("utf-8")
header = struct.pack(">Lc", len(encoded), b"j")
writer.write(header + encoded)
await writer.drain()


def _parse(serialization: bytes, msg: bytes) -> Any:
if serialization == b"n":
return msg
Expand Down
62 changes: 57 additions & 5 deletions openwpm/storage/storage_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openwpm.utilities.multiprocess_utils import Process

from ..config import BrowserParamsInternal, ManagerParamsInternal
from ..socket_interface import ClientSocket, get_message_from_reader
from ..socket_interface import ClientSocket, get_message_from_reader, send_to_writer
from ..types import BrowserId, VisitId
from .storage_providers import (
StructuredStorageProvider,
Expand Down Expand Up @@ -101,15 +101,15 @@ async def _handler(
await writer.wait_closed()

async def handler(
self, reader: asyncio.StreamReader, _: asyncio.StreamWriter
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""Created for every new connection to the Server"""
client_name = await get_message_from_reader(reader)
self.logger.info(f"Initializing new handler for {client_name}")
while True:
try:
record: Tuple[str, Any] = await get_message_from_reader(reader)
except IncompleteReadError:
except (IncompleteReadError, OSError):
self.logger.info(
f"Terminating handler for {client_name}, because the underlying socket closed"
)
Expand Down Expand Up @@ -153,7 +153,7 @@ async def handler(
visit_id = VisitId(data["visit_id"])

if record_type == RECORD_TYPE_META:
await self._handle_meta(visit_id, data)
await self._handle_meta(visit_id, data, writer)
continue

table_name = TableName(record_type)
Expand All @@ -174,7 +174,12 @@ async def store_record(
)
)

async def _handle_meta(self, visit_id: VisitId, data: Dict[str, Any]) -> None:
async def _handle_meta(
self,
visit_id: VisitId,
data: Dict[str, Any],
writer: asyncio.StreamWriter,
) -> None:
"""
Messages for the table RECORD_TYPE_SPECIAL are meta information
communicated to the storage controller
Expand All @@ -192,6 +197,19 @@ async def _handle_meta(self, visit_id: VisitId, data: Dict[str, Any]) -> None:
success: bool = data["success"]
completion_token = await self.finalize_visit_id(visit_id, success)
self.finalize_tasks.append((visit_id, completion_token, success))
# Send ack back only if the client requested it.
# Writing to a closed connection poisons the asyncio transport,
# preventing any further reads on the same connection.
if data.get("want_ack"):
try:
await send_to_writer(
writer,
{"action": "finalize_ack", "visit_id": visit_id},
)
except Exception:
self.logger.debug(
"Failed to send finalize ack for visit_id %d", visit_id
)
else:
raise ValueError("Unexpected action: %s", action)

Expand Down Expand Up @@ -413,6 +431,40 @@ def finalize_visit_id(self, visit_id: VisitId, success: bool) -> None:
)
)

def finalize_visit_id_with_ack(
self, visit_id: VisitId, success: bool, timeout: float = 10.0
) -> bool:
"""Send finalize and wait for acknowledgment from StorageController.

Returns True if ack was received, False on timeout.
Falls back gracefully - the finalize is still sent even if
the ack is not received.
"""
self.socket.send(
(
RECORD_TYPE_META,
{
"action": ACTION_TYPE_FINALIZE,
"visit_id": visit_id,
"success": success,
"want_ack": True,
},
)
)
ack = self.socket.receive(timeout=timeout)
if (
ack is not None
and isinstance(ack, dict)
and ack.get("action") == "finalize_ack"
):
return True
self.logger.debug(
"Did not receive finalize ack for visit_id %d (got: %r)",
visit_id,
ack,
)
return False

def close(self) -> None:
self.socket.close()

Expand Down
8 changes: 6 additions & 2 deletions openwpm/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,12 @@ def store_record(
self.sock.store_record(table, visit_id, data)

def finalize_visit_id(self, visit_id: VisitId, success: bool) -> None:
"""Signal that all data for a visit_id has been sent."""
self.sock.finalize_visit_id(visit_id, success)
"""Signal that all data for a visit_id has been sent.

Waits for acknowledgment from StorageController to confirm
the data has been processed. Falls back gracefully on timeout.
"""
self.sock.finalize_visit_id_with_ack(visit_id, success)

def _check_failure_status(self) -> None:
"""Check the status of command failures. Raise exceptions as necessary
Expand Down
Loading