From 64af21fe8ecea8ab48f24a3dc827f2059d0c423e Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Wed, 11 Feb 2026 18:44:32 +0100 Subject: [PATCH 1/8] Replace zmq transport exceptions with transport agnostic errors --- src/mktl/begin.py | 4 ++-- src/mktl/config.py | 4 ++-- src/mktl/daemon.py | 8 ++++---- src/mktl/item.py | 4 ++-- src/mktl/transport/__init__.py | 7 +++++++ src/mktl/transport/base.py | 18 ++++++++++++++++++ src/mktl/transport/zmq/publish.py | 12 ++++++++++-- src/mktl/transport/zmq/request.py | 16 ++++++++++++---- 8 files changed, 57 insertions(+), 16 deletions(-) diff --git a/src/mktl/begin.py b/src/mktl/begin.py index e0fb905f..95d5b0e3 100644 --- a/src/mktl/begin.py +++ b/src/mktl/begin.py @@ -3,10 +3,10 @@ """ import threading -import zmq from . import config from . import protocol +from .transport import TransportError from .store import Store @@ -205,7 +205,7 @@ def refresh(configuration): try: client.send(request) - except zmq.ZMQError: + except TransportError: # No response from this daemon; move on to the next entry in # the provenance. If no daemons respond the client will have # to rely on the local disk cache. diff --git a/src/mktl/config.py b/src/mktl/config.py index 24fb8c7b..0c425826 100644 --- a/src/mktl/config.py +++ b/src/mktl/config.py @@ -5,7 +5,6 @@ import threading import time import uuid -import zmq # Importing pint is expensive, representing something like 30% of the # user runtime for a simple mKTL command. It will be imported on a @@ -19,6 +18,7 @@ from . import json from . import protocol +from .transport import TransportError _cache = dict() @@ -1196,7 +1196,7 @@ def announce(config, uuid, override=False): for address,port in brokers: try: payload = protocol.request.send(address, port, message) - except zmq.error.ZMQError: + except TransportError: continue error = payload.error diff --git a/src/mktl/daemon.py b/src/mktl/daemon.py index e50a5f4b..2122319b 100644 --- a/src/mktl/daemon.py +++ b/src/mktl/daemon.py @@ -9,7 +9,6 @@ import sys import threading import time -import zmq from . import begin from . import config @@ -18,6 +17,7 @@ from . import poll from . import protocol from . import store +from .transport import TransportError, TransportPortError class Daemon: @@ -86,14 +86,14 @@ def __init__(self, store, alias, override=False, options=None): try: self.pub = protocol.publish.Server(port=pub, avoid=avoid) - except zmq.error.ZMQError: + except TransportPortError: self.pub = protocol.publish.Server(port=None, avoid=avoid) avoid = _used_ports() try: self.rep = RequestServer(self, port=rep, avoid=avoid) - except zmq.error.ZMQError: + except TransportPortError: self.rep = RequestServer(self, port=None, avoid=avoid) _save_port(store, self.uuid, self.rep.port, self.pub.port) @@ -421,7 +421,7 @@ def _test_port(self, store, port): try: payload = protocol.request.send(hostname, port, request) - except zmq.ZMQError: + except TransportError: # Not running; perfect. return diff --git a/src/mktl/item.py b/src/mktl/item.py index 6cf8119c..93b8de25 100644 --- a/src/mktl/item.py +++ b/src/mktl/item.py @@ -3,7 +3,6 @@ import threading import time import traceback -import zmq try: import numpy @@ -13,6 +12,7 @@ from . import protocol from . import poll from . import weakref +from .transport import TransportError class Item: @@ -779,7 +779,7 @@ def subscribe(self, prime=True): if prime == True: try: self.get(refresh=True) - except (zmq.ZMQError, RuntimeError): + except (TransportError, RuntimeError): # Connection errors and remote errors on priming reads are # thrown away; an error here means the remote daemon is not # available to respond to requests, but despite that error diff --git a/src/mktl/transport/__init__.py b/src/mktl/transport/__init__.py index 6f40511f..f024f924 100644 --- a/src/mktl/transport/__init__.py +++ b/src/mktl/transport/__init__.py @@ -3,3 +3,10 @@ Each transport is responsible for mapping protocol :class:`mktl.protocol.Message` objects to/from a wire representation. """ + +from .base import ( + TransportError, + TransportTimeout, + TransportConnectionError, + TransportPortError, +) diff --git a/src/mktl/transport/base.py b/src/mktl/transport/base.py index 90b5532e..455493d3 100644 --- a/src/mktl/transport/base.py +++ b/src/mktl/transport/base.py @@ -11,6 +11,24 @@ from ..protocol import Message +# Transport agnostic exceptions + +class TransportError(Exception): + """Base class for all transport-layer errors.""" + + +class TransportTimeout(TransportError): + """A request did not receive a timely response.""" + + +class TransportConnectionError(TransportError): + """The transport could not establish or maintain a connection.""" + + +class TransportPortError(TransportError): + """No suitable port could be bound or connected.""" + + class Transport(ABC): @abstractmethod def send(self, msg: Message) -> None: diff --git a/src/mktl/transport/zmq/publish.py b/src/mktl/transport/zmq/publish.py index e806b152..d310ccb1 100644 --- a/src/mktl/transport/zmq/publish.py +++ b/src/mktl/transport/zmq/publish.py @@ -12,6 +12,7 @@ import zmq from ...protocol import Message, Publish +from ...transport import TransportPortError from .framing import from_pub_frames, to_pub_frames minimum_port = 10139 @@ -63,9 +64,16 @@ def __init__(self, port: Optional[int] = None, avoid: Optional[set] = None): except zmq.ZMQError: continue if self.port is None: - raise zmq.ZMQError("no PUB port available") + raise TransportPortError( + f"no ports available in range {minimum_port}:{maximum_port}" + ) else: - self.socket.bind(f"tcp://*:{self.port}") + try: + self.socket.bind(f"tcp://*:{self.port}") + except zmq.ZMQError as exc: + raise TransportPortError( + f"port already in use: {self.port}" + ) from exc # Internal queue for thread-safe sends try: diff --git a/src/mktl/transport/zmq/request.py b/src/mktl/transport/zmq/request.py index edf8b874..d0c77488 100644 --- a/src/mktl/transport/zmq/request.py +++ b/src/mktl/transport/zmq/request.py @@ -26,6 +26,7 @@ from ...protocol import Message, Payload, Request from ...protocol.fields import ACK, REP +from ...transport import TransportTimeout, TransportPortError from .framing import from_request_frames, to_request_frames @@ -138,7 +139,7 @@ def send(self, request: Request) -> PendingRequest: ack = pending.wait_ack(self.timeout) if not ack: - raise zmq.ZMQError( + raise TransportTimeout( f"{request.msg_type} @ {self.address}:{self.port}: no ACK in {self.timeout:.2f} sec" ) return pending @@ -160,7 +161,12 @@ def __init__(self, address: Optional[str] = None, port: Optional[int] = None, av if self.port is None: self.port = self._bind_any() else: - self.socket.bind(f"tcp://{self.address}:{self.port}") + try: + self.socket.bind(f"tcp://{self.address}:{self.port}") + except zmq.ZMQError as exc: + raise TransportPortError( + f"port already in use: {self.port}" + ) from exc # Response queue for thread-safe sending try: @@ -188,7 +194,9 @@ def _bind_any(self) -> int: return port except zmq.ZMQError: continue - raise RuntimeError("no available ports") + raise TransportPortError( + f"no ports available in range {minimum_port}:{maximum_port}" + ) # --- request handling hooks --- def req_handler(self, request: Request) -> Optional[Payload]: @@ -291,7 +299,7 @@ def send(address: str, port: int, message: Request) -> Payload: pending = c.send(message) response = pending.wait(timeout=60) if response is None: - raise zmq.ZMQError("no response received") + raise TransportTimeout("no response received") if response.payload is None: return Payload(value=None) return response.payload From 2b15cf786ebf76ee53c28b73642e9be90b8c4d77 Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Wed, 11 Feb 2026 20:02:09 +0100 Subject: [PATCH 2/8] lift shared codec.py to transport folder and fix protocol imports --- src/mktl/transport/base.py | 2 +- src/mktl/transport/codec.py | 33 +++++++++++++++++++++++++++++++ src/mktl/transport/zmq/codec.py | 2 +- src/mktl/transport/zmq/framing.py | 4 ++-- src/mktl/transport/zmq/publish.py | 3 ++- src/mktl/transport/zmq/request.py | 3 ++- 6 files changed, 41 insertions(+), 6 deletions(-) create mode 100644 src/mktl/transport/codec.py diff --git a/src/mktl/transport/base.py b/src/mktl/transport/base.py index 455493d3..5c0e16a9 100644 --- a/src/mktl/transport/base.py +++ b/src/mktl/transport/base.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod -from ..protocol import Message +from ..protocol.message import Message # Transport agnostic exceptions diff --git a/src/mktl/transport/codec.py b/src/mktl/transport/codec.py new file mode 100644 index 00000000..c195f743 --- /dev/null +++ b/src/mktl/transport/codec.py @@ -0,0 +1,33 @@ +"""Transport codec for protocol Payload.""" + +from __future__ import annotations + +from typing import Optional, Tuple + +from .. import json +from ..protocol.message import Payload + + +def encode_payload(payload: Optional[Payload]) -> Tuple[bytes, bytes]: + """Return (json_bytes, bulk_bytes).""" + + if payload is None: + return b"", b"" + + bulk = payload.bulk or b"" + j = json.dumps(payload.to_dict()) + return j, bulk + + +def decode_payload(payload_bytes: bytes, bulk_bytes: bytes) -> Optional[Payload]: + if payload_bytes in (b"", None): + return None + + d = json.loads(payload_bytes) + bulk = bulk_bytes if bulk_bytes not in (b"", None) else None + try: + return Payload.from_dict(d, bulk=bulk) + except Exception: + # Preserve non-conforming payloads for advanced users. + # Return as raw dict-like payload embedded in Payload.value. + return Payload(value=d, bulk=bulk) diff --git a/src/mktl/transport/zmq/codec.py b/src/mktl/transport/zmq/codec.py index 5412ade9..82988cbd 100644 --- a/src/mktl/transport/zmq/codec.py +++ b/src/mktl/transport/zmq/codec.py @@ -8,7 +8,7 @@ from typing import Optional, Tuple from ... import json -from ...protocol import Payload +from ...protocol.message import Payload def encode_payload(payload: Optional[Payload]) -> Tuple[bytes, bytes]: diff --git a/src/mktl/transport/zmq/framing.py b/src/mktl/transport/zmq/framing.py index 24f28d6f..2bad10c6 100644 --- a/src/mktl/transport/zmq/framing.py +++ b/src/mktl/transport/zmq/framing.py @@ -11,8 +11,8 @@ from typing import Iterable, Optional, Sequence, Tuple -from ...protocol import Message, Payload, PROTOCOL_VERSION -from .codec import encode_payload, decode_payload +from ...protocol.message import Message, Payload, PROTOCOL_VERSION +from ..codec import encode_payload, decode_payload _VERSION_BYTES = PROTOCOL_VERSION.encode() diff --git a/src/mktl/transport/zmq/publish.py b/src/mktl/transport/zmq/publish.py index d310ccb1..7aa04a96 100644 --- a/src/mktl/transport/zmq/publish.py +++ b/src/mktl/transport/zmq/publish.py @@ -11,7 +11,8 @@ import zmq -from ...protocol import Message, Publish +from ...protocol.message import Message +from ...protocol.publish import Publish from ...transport import TransportPortError from .framing import from_pub_frames, to_pub_frames diff --git a/src/mktl/transport/zmq/request.py b/src/mktl/transport/zmq/request.py index d0c77488..965aef5f 100644 --- a/src/mktl/transport/zmq/request.py +++ b/src/mktl/transport/zmq/request.py @@ -24,7 +24,8 @@ import zmq -from ...protocol import Message, Payload, Request +from ...protocol.message import Message, Payload +from ...protocol.request import Request from ...protocol.fields import ACK, REP from ...transport import TransportTimeout, TransportPortError from .framing import from_request_frames, to_request_frames From 4d3228dff390a1f511e391a2b74c910b72119ee4 Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Wed, 11 Feb 2026 23:23:15 +0100 Subject: [PATCH 3/8] add open/close/send/recv to to base transport --- src/mktl/transport/base.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/mktl/transport/base.py b/src/mktl/transport/base.py index 5c0e16a9..b993b112 100644 --- a/src/mktl/transport/base.py +++ b/src/mktl/transport/base.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Optional from ..protocol.message import Message @@ -30,10 +31,25 @@ class TransportPortError(TransportError): class Transport(ABC): + """Minimal contract for a wire-level transport.""" + + @abstractmethod + def open(self) -> None: + """Establish the underlying connection/socket.""" + + @abstractmethod + def close(self) -> None: + """Tear down the underlying connection/socket.""" + @abstractmethod def send(self, msg: Message) -> None: - """Send a protocol message.""" + """Send a protocol Message.""" @abstractmethod - def recv(self) -> Message: - """Receive the next protocol message.""" + def recv(self, timeout: Optional[float] = None) -> Message: + """Receive the next protocol Message.""" + + @property + def is_open(self) -> bool: + """Whether the transport is currently connected.""" + return False From 22665e60290e648b9dcfbd213dd126271ef58819 Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Thu, 12 Feb 2026 02:32:55 +0100 Subject: [PATCH 4/8] plug in session class with zmq transport and upper layer --- src/mktl/begin.py | 15 +- src/mktl/config.py | 3 +- src/mktl/daemon.py | 11 +- src/mktl/item.py | 5 +- src/mktl/protocol/factory.py | 11 +- src/mktl/transport/__init__.py | 17 +- src/mktl/transport/session.py | 169 ++++++++++++ src/mktl/transport/zmq/__init__.py | 0 src/mktl/transport/zmq/codec.py | 36 --- src/mktl/transport/zmq/message.py | 410 ----------------------------- src/mktl/transport/zmq/publish.py | 6 +- src/mktl/transport/zmq/request.py | 114 +------- 12 files changed, 220 insertions(+), 577 deletions(-) create mode 100644 src/mktl/transport/session.py create mode 100644 src/mktl/transport/zmq/__init__.py delete mode 100644 src/mktl/transport/zmq/codec.py delete mode 100644 src/mktl/transport/zmq/message.py diff --git a/src/mktl/begin.py b/src/mktl/begin.py index 95d5b0e3..30c46482 100644 --- a/src/mktl/begin.py +++ b/src/mktl/begin.py @@ -6,6 +6,7 @@ from . import config from . import protocol +from . import transport from .transport import TransportError from .store import Store @@ -47,13 +48,13 @@ def discover(*targets): # Hacking the timeout for discovery, this is not expected to throw # errors with minimal delay. - old_timeout = protocol.request.Client.timeout - protocol.request.Client.timeout = 0.5 + old_timeout = transport.request.Client.timeout + transport.request.Client.timeout = 0.5 for address,port in brokers: request = protocol.message.Request('HASH') try: - payload = protocol.request.send(address, port, request) + payload = transport.request.send(address, port, request) except: continue @@ -61,7 +62,7 @@ def discover(*targets): for store in hashes.keys(): request = protocol.message.Request('CONFIG', store) - payload = protocol.request.send(address, port, request) + payload = transport.request.send(address, port, request) blocks = payload.value @@ -71,7 +72,7 @@ def discover(*targets): configuration.update(block) - protocol.request.Client.timeout = old_timeout + transport.request.Client.timeout = old_timeout @@ -140,7 +141,7 @@ def get(store, key=None): hostname,port = brokers[0] message = protocol.message.Request('CONFIG', store) - payload = protocol.request.send(hostname, port, message) + payload = transport.request.send(hostname, port, message) blocks = payload.value @@ -200,7 +201,7 @@ def refresh(configuration): hostname = stratum['hostname'] rep = stratum['rep'] - client = protocol.request.client(hostname, rep) + client = transport.request.client(hostname, rep) request = protocol.message.Request('HASH', store) try: diff --git a/src/mktl/config.py b/src/mktl/config.py index 0c425826..daa11d3b 100644 --- a/src/mktl/config.py +++ b/src/mktl/config.py @@ -18,6 +18,7 @@ from . import json from . import protocol +from . import transport from .transport import TransportError @@ -1195,7 +1196,7 @@ def announce(config, uuid, override=False): for address,port in brokers: try: - payload = protocol.request.send(address, port, message) + payload = transport.request.send(address, port, message) except TransportError: continue diff --git a/src/mktl/daemon.py b/src/mktl/daemon.py index 2122319b..c24ee1e6 100644 --- a/src/mktl/daemon.py +++ b/src/mktl/daemon.py @@ -17,6 +17,7 @@ from . import poll from . import protocol from . import store +from . import transport from .transport import TransportError, TransportPortError @@ -85,9 +86,9 @@ def __init__(self, store, alias, override=False, options=None): self._test_port(store, rep) try: - self.pub = protocol.publish.Server(port=pub, avoid=avoid) + self.pub = transport.publish.Server(port=pub, avoid=avoid) except TransportPortError: - self.pub = protocol.publish.Server(port=None, avoid=avoid) + self.pub = transport.publish.Server(port=None, avoid=avoid) avoid = _used_ports() @@ -420,7 +421,7 @@ def _test_port(self, store, port): request = protocol.message.Request('CONFIG', store) try: - payload = protocol.request.send(hostname, port, request) + payload = transport.request.send(hostname, port, request) except TransportError: # Not running; perfect. return @@ -446,10 +447,10 @@ def _test_port(self, store, port): -class RequestServer(protocol.request.Server): +class RequestServer(transport.request.Server): def __init__(self, daemon, *args, **kwargs): - protocol.request.Server.__init__(self, *args, **kwargs) + transport.request.Server.__init__(self, *args, **kwargs) self.daemon = daemon diff --git a/src/mktl/item.py b/src/mktl/item.py index 93b8de25..2a849c49 100644 --- a/src/mktl/item.py +++ b/src/mktl/item.py @@ -11,6 +11,7 @@ from . import protocol from . import poll +from . import transport from . import weakref from .transport import TransportError @@ -95,8 +96,8 @@ def __init__(self, store, key, subscribe=True, authoritative=False, pub=None): # configuration that doesn't contain a provenance. raise RuntimeError('cannot find daemon for ' + self.full_key) - self.sub = protocol.publish.client(hostname, pub) - self.req = protocol.request.client(hostname, rep) + self.sub = transport.publish.client(hostname, pub) + self.req = transport.request.client(hostname, rep) try: settable = self.config['settable'] diff --git a/src/mktl/protocol/factory.py b/src/mktl/protocol/factory.py index 3925ac1a..148fba4c 100644 --- a/src/mktl/protocol/factory.py +++ b/src/mktl/protocol/factory.py @@ -7,8 +7,8 @@ from .discover import Discover from .publish import Publish from .request import Request -from .fields import GET -from .message import Payload +from .fields import ACK, GET, REP +from .message import Message, Payload def payload(value: Any, **kwargs) -> Payload: @@ -31,3 +31,10 @@ def discover(value: Any = None, *, payload: Optional[Payload] = None, **payload_ if payload is None and (value is not None or payload_kwargs): payload = Payload(value=value, **payload_kwargs) return Discover(payload=payload) + + +def fast_ack(msg: Message) -> Message: + """Create an ACK confirming receipt of a message.""" + ack = Message(msg_type=ACK, target=msg.target, msg_id=msg.msg_id) + ack.meta.update(msg.meta) + return ack diff --git a/src/mktl/transport/__init__.py b/src/mktl/transport/__init__.py index f024f924..4c533539 100644 --- a/src/mktl/transport/__init__.py +++ b/src/mktl/transport/__init__.py @@ -1,8 +1,6 @@ -"""Transport implementations. +"""Transport layer implementations.""" -Each transport is responsible for mapping protocol :class:`mktl.protocol.Message` -objects to/from a wire representation. -""" +import os from .base import ( TransportError, @@ -10,3 +8,14 @@ TransportConnectionError, TransportPortError, ) + +_BACKEND = os.environ.get("MKTL_TRANSPORT", "zmq") + +if _BACKEND == "zmq": + from .zmq import request + from .zmq import publish +# elif _BACKEND == "rabbitmq": +# from .rabbitmq import request +# from .rabbitmq import publish +else: + raise ImportError(f"unknown MKTL_TRANSPORT backend: {_BACKEND!r}") diff --git a/src/mktl/transport/session.py b/src/mktl/transport/session.py new file mode 100644 index 00000000..587d0e9d --- /dev/null +++ b/src/mktl/transport/session.py @@ -0,0 +1,169 @@ +"""Transport-agnostic session layer.""" + +from __future__ import annotations + +import sys +import threading +import traceback +from typing import Dict, Optional + +from ..protocol.factory import fast_ack +from ..protocol.fields import ACK, REP +from ..protocol.message import Message, Payload +from ..protocol.request import Request +from .base import Transport, TransportTimeout + + +class PendingRequest: + """Client-side helper that provides ACK/REP synchronization.""" + + def __init__(self, req: Request): + self.req = req + self.response: Optional[Message] = None + self.ack_event = threading.Event() + self.rep_event = threading.Event() + + @property + def id(self) -> bytes: + return self.req.msg_id + + def wait_ack(self, timeout: Optional[float]) -> bool: + return self.ack_event.wait(timeout) + + def wait(self, timeout: Optional[float] = 60) -> Optional[Message]: + self.rep_event.wait(timeout) + return self.response + + def _complete_ack(self) -> None: + self.ack_event.set() + + def _complete(self, response: Message) -> None: + self.response = response + self.ack_event.set() + self.rep_event.set() + + +class RequestSession: + """Client-side request/response pattern logic.""" + + timeout = 0.1 + + def __init__(self, transport: Transport): + self.transport = transport + self._pending: Dict[bytes, PendingRequest] = {} + + def _handle_incoming(self, msg: Message) -> None: + """Correlate incoming ACK/REP to a PendingRequest.""" + pending = self._pending.get(msg.msg_id) + if pending is None: + return + + if msg.msg_type == ACK: + pending._complete_ack() + return + + # REP (or error REP on version mismatch) + pending._complete(msg) + self._pending.pop(msg.msg_id, None) + + def send(self, request: Request) -> PendingRequest: + pending = PendingRequest(request) + self._pending[pending.id] = pending + self.transport.send(request) + + ack = pending.wait_ack(self.timeout) + if not ack: + self._pending.pop(pending.id, None) + raise TransportTimeout( + f"{request.msg_type}: no ACK in {self.timeout:.2f} sec" + ) + return pending + + +class RequestServer: + """Server-side request handler.""" + + def __init__(self, transport: Transport): + self.transport = transport + + # --- request handling hooks --- + def req_handler(self, request: Request) -> Optional[Payload]: + """Override in subclasses. + + Return: + - Payload -> will be wrapped into a REP + - None -> no immediate REP (handler is responsible for later response) + """ + + # Default: no-op + self.req_ack(request) + return None + + def req_ack(self, request: Request) -> None: + self.send(fast_ack(request)) + + def send(self, response: Message) -> None: + self.transport.send(response) + + # --- internal --- + def _req_incoming(self, msg: Message) -> None: + """Dispatch to req_handler, build REP, send.""" + # Convert generic Message to typed Request (validates op) + try: + req = Request(msg_type=msg.msg_type, target=msg.target, payload=msg.payload, msg_id=msg.msg_id) + except Exception: + # If invalid, treat as opaque request + req = Request(msg_type=msg.msg_type, target=msg.target, payload=msg.payload, msg_id=msg.msg_id) + + req.meta.update(msg.meta) + + payload: Optional[Payload] = None + error: Optional[dict] = None + + try: + payload = self.req_handler(req) + except Exception: + e_class, e_instance, _tb = sys.exc_info() + error = { + "type": getattr(e_class, "__name__", "Exception"), + "text": str(e_instance), + "debug": traceback.format_exc(), + } + + if payload is None and error is None: + return + + if payload is None: + payload = Payload(value=None) + if error is not None: + payload.error = error + + rep = Message(msg_type=REP, target=req.target, payload=payload, msg_id=req.msg_id) + rep.meta.update(req.meta) + self.send(rep) + + +class PublishSession: + """Server-side publish pattern logic.""" + + port = None + + def __init__(self, transport: Transport): + self.transport = transport + + def send(self, msg: Message) -> None: + self.transport.send(msg) + + +class SubscribeSession: + """Client-side subscribe pattern logic.""" + + def __init__(self, transport: Transport): + self.transport = transport + + def subscribe(self, topic: str) -> None: + if hasattr(self.transport, 'subscribe'): + self.transport.subscribe(topic) + + def recv(self) -> Message: + return self.transport.recv() diff --git a/src/mktl/transport/zmq/__init__.py b/src/mktl/transport/zmq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mktl/transport/zmq/codec.py b/src/mktl/transport/zmq/codec.py deleted file mode 100644 index 82988cbd..00000000 --- a/src/mktl/transport/zmq/codec.py +++ /dev/null @@ -1,36 +0,0 @@ -"""ZMQ transport codec for protocol Payload. - -This is a JSON codec compatible with the historical mKTL wire format. -""" - -from __future__ import annotations - -from typing import Optional, Tuple - -from ... import json -from ...protocol.message import Payload - - -def encode_payload(payload: Optional[Payload]) -> Tuple[bytes, bytes]: - """Return (json_bytes, bulk_bytes).""" - - if payload is None: - return b"", b"" - - bulk = payload.bulk or b"" - j = json.dumps(payload.to_dict()) - return j, bulk - - -def decode_payload(payload_bytes: bytes, bulk_bytes: bytes) -> Optional[Payload]: - if payload_bytes in (b"", None): - return None - - d = json.loads(payload_bytes) - bulk = bulk_bytes if bulk_bytes not in (b"", None) else None - try: - return Payload.from_dict(d, bulk=bulk) - except Exception: - # Preserve non-conforming payloads for advanced users. - # Return as raw dict-like payload embedded in Payload.value. - return Payload(value=d, bulk=bulk) diff --git a/src/mktl/transport/zmq/message.py b/src/mktl/transport/zmq/message.py deleted file mode 100644 index 0d07c40d..00000000 --- a/src/mktl/transport/zmq/message.py +++ /dev/null @@ -1,410 +0,0 @@ -""" A class representation of an mKTL message, including subclasses for - specific messages. -""" - -import itertools -import threading -import time as timemodule - -from ... import json - - -# This is the version of the mKTL on-the-wire protocol implemented here. -# See the protocol specification for a full description; the version is -# identified by a single byte. - -version = b'a' - - -class Message: - """ The :class:`Message` provides a very thin encapsulation of what it - means to be a message in an mKTL context. This class will be used - to represent mKTL messages that do not result in a response. - - The fields are largely in order of how they are represented on the - wire: the message *type*, the key/store *target* for the request, - the *payload* of the message (a :class:`Payload` instance), - and an identification - number unique to this correspondence. The identification number is - the one field that is out-of-order compared to the multipart sequence - on the wire; this is because some message types (publish messages, - in particular) do not have an identification number, and it is - automatically generated for request messages. Rather than force the - caller to pass an explicit None, the id is left as the last field, - so that the arguments for all :class:`Message` instances can have - a similar structure. - - :ivar payload: The :class:`Payload`, if any, for the message. - :ivar valid_types: A set of valid strings for the message type. - :ivar timestamp: A UNIX epoch timestamp for the message send time. - """ - - valid_types = set(('ACK', 'REP')) - - def __init__(self, type, target=None, payload=None, id=None): - - if type in self.valid_types: - pass - else: - raise ValueError('invalid request type: ' + type) - - # There are some message types where the id is allowed to be None; - # for example, publish messages do not have or need an identification - # number or a prefix. - - self.id = id - self.type = type - self.payload = payload - self.prefix = None - self.target = target - self.timestamp = timemodule.time() - - self._parts = None - - - def __iter__(self): - self._finalize() - return iter(self._parts) - - - def __repr__(self): - self._finalize() - return repr(self._parts) - - - def _finalize(self): - """ Take the contents of this :class:`Message`, interpet them as - bytes, and prepare the tuple that will be used for the multipart - transmission on the wire. - """ - - if self._parts: - # Once finalized, always finalized. - return - - id = self.id - type = self.type - target = self.target - payload = self.payload - - # It is legal to create a Message with None as the id-- this happens - # all the time when a Message is used as a container-- but trying to - # send such a message is not permitted. - - if id is None: - raise RuntimeError('messages must have an id to be put on the wire') - - try: - id.decode - except AttributeError: - id = '%08x' % (id) - id = id.encode() - - type = type.encode() - - if target == None or target == '': - target = b'' - else: - try: - target = target.encode() - except AttributeError: - # Assume it is already bytes. - pass - - if payload is None or payload == '': - bulk = b'' - payload = b'' - else: - bulk = payload.bulk - if bulk is None: - bulk = b'' - payload = payload.encapsulate() - - if self.prefix: - parts = self.prefix + (version, id, type, target, payload, bulk) - else: - parts = (version, id, type, target, payload, bulk) - - self._parts = parts - - -# end of class Message - - - -class Broadcast(Message): - """ A :class:`Broadcast` is a minor variant of a :class:`Message`, - with a change to format the multipart tuple in a PUB/SUB specific - fashion. - """ - - valid_types = set(('PUB',)) - - def _finalize(self): - - if self._parts: - # Once finalized, always finalized. - return - - target = self.target - payload = self.payload - - # The PUB/SUB topic has a trailing dot to prevent leading - # substring matches from picking up extra keys. - - target = target + '.' - target = target.encode() - - if payload is None or payload == '': - bulk = b'' - payload = b'' - else: - bulk = payload.bulk - if bulk is None: - bulk = b'' - payload = payload.encapsulate() - - # The prefix is ignored for broadcast messages; it should not be set. - self._parts = (target, version, payload, bulk) - - -# end of class Broadcast - - - -class Request(Message): - """ A :class:`Request` has a little extra functionality, focusing on - local caching of response values and signaling that a request is - complete. This is the class that will be used on the client side - when a server is expected to provide a response, such as returning - a requested value, or signaling that a set operation is complete. - - :ivar response: The response (as a :class:`Message`) to this request - """ - - valid_types = set(('CONFIG', 'GET', 'HASH', 'SET')) - - def __init__(self, type, target=None, payload=None, id=None): - - # Requests are generally initiated without an id number, but they're - # required to have one. The expectation is that requests will have an - # id number that is locally unique, so that the request/response - # handler can correctly tie an incoming response to the request that - # generated it. - - # Long story short, for nearly all Request instances the id argument - # will be None, and we are expected to auto-generate a locally unique - # identification number. - - if id is None: - id = _id_next() - - Message.__init__(self, type, target, payload, id) - - self.response = None - - self.ack_event = threading.Event() - self.rep_event = threading.Event() - - - def __repr__(self): - self._finalize() - request = 'REQ: ' + repr(self._parts) - - if self.response is None: - response = 'REP: None' - else: - response = 'REP: ' + repr(tuple(self.response)) - - return request + ', ' + response - - - def _complete_ack(self): - """ The request, if any, has been acknowledged; signal any callers blocking via :func:`wait_ack` to proceed. - """ - - self.ack_event.set() - - - def _complete(self, response): - """ Locally store the response and signal any callers blocking via - :func:`wait` to proceed. - """ - - self.response = response - self.ack_event.set() - self.rep_event.set() - - - def poll(self): - """ Return True if the request is complete, otherwise return False. - """ - - return self.rep_event.is_set() - - - def wait_ack(self, timeout): - """ Block until the request has been acknowledged. This is a wrapper to - a :class:`threading.Event` instance; if the event has occurred it - will return True, otherwise it returns False after the requested - *timeout*. If the *timeout* argument is None it will block - indefinitely. - """ - - return self.ack_event.wait(timeout) - - - def wait(self, timeout=60): - """ Block until the request has been handled. The response to the - request is always returned; the response will be None if the - original request is still pending. - """ - - self.rep_event.wait(timeout) - return self.response - - -# end of class Request - - - -class Payload: - """ This is a lightweight class to properly encapsulate a Python-native - value for later inclusion in a :class:`Message` instance. All attributes - of this class, except for the :attr:`bulk` attribute, or any attrbiutes - set to None, will be added to as a JSON dictionary via - :func:`encapsulate`, which is called when a :class:`Message` instance - finalizes itself before generating its final on-the-wire representation. - - No interpretation of the :class:`Payload` contents is performed, any - interpretation (such as converting a numpy array to a form suitable - for enapsulation) must occur before the :class:`Payload` is - instantiated. - - :ivar bulk: A bulk data value, in bytes, or None - :ivar omit: A set of fields to omit from encapsulation - """ - - omit = set(('bulk', 'omit')) - - def __init__(self, value, time=None, error=None, bulk=None, shape=None, dtype=None, refresh=None, **kwargs): - """ Arbitrary keyword arguments are allowed when creating a - :class:`Payload` instance, beyond the canonical set; when - included, these additional keyword arguments will be assigned - directly as attributes for later encapsulation. Any values - assigned in this fashion must be serializable as JSON. - """ - - # The use of 'time' as a keyword argument is what's motivating the - # weird import of the time module in this file. We want the keyword - # arguments to be aligned with the fields in the JSON description - # of a payload: value, time, error, etc. - - if time is None: - time = timemodule.time() - - if refresh is False: - refresh = None - - # We expect the canonical arguments to all be set all the time, - # even if their value is None. That's why they're not rolled up - # into the kwargs catch-all. - - self.bulk = bulk - self.dtype = dtype - self.error = error - self.refresh = refresh - self.shape = shape - self.time = time - self.value = value - - if not kwargs: - # This is the average case. Faster to check this one condition - # and return than to drop out of the next two conditions. - return - - if 'omit' in kwargs: - raise ValueError("cannot assign 'omit' to a Payload") - - # Allow additional arbitrary fields in the payload. We are assuming - # the caller knows what they are doing, and that these additional - # fields can be serialized as JSON. - - for key,value in kwargs.items(): - setattr(self, key, value) - - - def __repr__(self): - return self.encapsulate().decode() - - - def encapsulate(self): - """ Add all non-omitted local attributes to a dictionary, and return - the JSON encoding of that dictionary. For example, if the .value - and .time attributes of a :class:`Payload` are assigned, the caller - will receive a value like:: - - b'{"value": 12, "time": 1761100609.234571}' - """ - - # The output from this method was initially cached, but there's never - # a situation where this method is called twice for a given Payload, - # so the caching was removed. - - payload = dict() - - # All local attributes get put into the encapsulated payload, - # except for those included in the omit set. - - for key,value in vars(self).items(): - if key in self.omit: - continue - - # Do not include attributes that are just 'None'. This may be - # premature optimization, but it seems silly to put a bunch of - # extra bytes on the wire when it conveys no additional information. - - # It's faster to check the key against 'value' repeatedly than - # to build a separate set of includes and only assign those - # key/value pairs to the payload. - - if value is not None or key == 'value': - payload[key] = value - - payload = json.dumps(payload) - return payload - - -# end of class Payload - - -_id_min = 0 -_id_max = 0xFFFFFFFF -_id_lock = threading.Lock() -_id_ticker = itertools.count(_id_min) - - -def _id_next(): - """ Return the next request identification number for subroutines to - use when constructing a message. - """ - - global _id_ticker - _id_lock.acquire() - id = next(_id_ticker) - - if id >= _id_max: - _id_ticker = itertools.count(_id_min) - - if id > _id_max: - # This shouldn't happen, but here we are... - id = next(_id_ticker) - - _id_lock.release() - - id = '%08x' % (id) - id = id.encode() - return id - - -# vim: set expandtab tabstop=8 softtabstop=4 shiftwidth=4 autoindent: diff --git a/src/mktl/transport/zmq/publish.py b/src/mktl/transport/zmq/publish.py index 7aa04a96..97c5aa11 100644 --- a/src/mktl/transport/zmq/publish.py +++ b/src/mktl/transport/zmq/publish.py @@ -3,7 +3,6 @@ from __future__ import annotations import atexit -import itertools import queue import threading import traceback @@ -14,6 +13,7 @@ from ...protocol.message import Message from ...protocol.publish import Publish from ...transport import TransportPortError +from ..session import PublishSession, SubscribeSession from .framing import from_pub_frames, to_pub_frames minimum_port = 10139 @@ -21,7 +21,7 @@ zmq_context = zmq.Context() -class Client: +class Client(SubscribeSession): """SUB client.""" def __init__(self, address: str, port: int): @@ -45,7 +45,7 @@ def recv(self) -> Message: return from_pub_frames(parts) -class Server: +class Server(PublishSession): """PUB server.""" def __init__(self, port: Optional[int] = None, avoid: Optional[set] = None): diff --git a/src/mktl/transport/zmq/request.py b/src/mktl/transport/zmq/request.py index 965aef5f..4708a11f 100644 --- a/src/mktl/transport/zmq/request.py +++ b/src/mktl/transport/zmq/request.py @@ -13,21 +13,17 @@ import atexit import concurrent.futures -import itertools import queue import socket as pysocket -import sys import threading -import time -import traceback from typing import Dict, Optional, Tuple import zmq from ...protocol.message import Message, Payload from ...protocol.request import Request -from ...protocol.fields import ACK, REP from ...transport import TransportTimeout, TransportPortError +from ..session import RequestSession, RequestServer, PendingRequest from .framing import from_request_frames, to_request_frames @@ -36,36 +32,7 @@ zmq_context = zmq.Context() -class PendingRequest: - """Client-side helper that provides ACK/REP synchronization.""" - - def __init__(self, req: Request): - self.req = req - self.response: Optional[Message] = None - self.ack_event = threading.Event() - self.rep_event = threading.Event() - - @property - def id(self) -> bytes: - return self.req.msg_id - - def wait_ack(self, timeout: Optional[float]) -> bool: - return self.ack_event.wait(timeout) - - def wait(self, timeout: Optional[float] = 60) -> Optional[Message]: - self.rep_event.wait(timeout) - return self.response - - def _complete_ack(self) -> None: - self.ack_event.set() - - def _complete(self, response: Message) -> None: - self.response = response - self.ack_event.set() - self.rep_event.set() - - -class Client: +class Client(RequestSession): """Issue requests via a ZeroMQ DEALER socket and receive responses.""" timeout = 0.1 @@ -97,20 +64,6 @@ def __init__(self, address: str, port: int): self._thread = threading.Thread(target=self.run, daemon=True) self._thread.start() - def _handle_incoming(self, parts: Tuple[bytes, ...]) -> None: - msg = from_request_frames(parts) - pending = self._pending.get(msg.msg_id) - if pending is None: - return - - if msg.msg_type == ACK: - pending._complete_ack() - return - - # REP (or error REP on version mismatch) - pending._complete(msg) - self._pending.pop(msg.msg_id, None) - def _handle_outgoing(self) -> None: # Clear one signal and send one request. self._signal_rx.recv(flags=zmq.NOBLOCK) @@ -131,7 +84,8 @@ def run(self) -> None: self._handle_outgoing() elif active == self.socket: parts = tuple(self.socket.recv_multipart()) - self._handle_incoming(parts) + msg = from_request_frames(parts) + self._handle_incoming(msg) def send(self, request: Request) -> PendingRequest: pending = PendingRequest(request) @@ -146,7 +100,7 @@ def send(self, request: Request) -> PendingRequest: return pending -class Server: +class Server(RequestServer): """Receive requests via a ZeroMQ ROUTER socket, respond to them.""" port = None # auto @@ -199,24 +153,6 @@ def _bind_any(self) -> int: f"no ports available in range {minimum_port}:{maximum_port}" ) - # --- request handling hooks --- - def req_handler(self, request: Request) -> Optional[Payload]: - """Override in subclasses. - - Return: - - Payload -> will be wrapped into a REP - - None -> no immediate REP (handler is responsible for later response) - """ - - # Default: no-op - self.req_ack(request) - return None - - def req_ack(self, request: Request) -> None: - ack = Message(msg_type=ACK, target=request.target, msg_id=request.msg_id) - ack.meta["zmq_prefix"] = request.meta.get("zmq_prefix", ()) - self.send(ack) - def send(self, response: Message) -> None: self._responses.put(response) self._signal_tx.send(b"") @@ -228,43 +164,6 @@ def _rep_outgoing(self) -> None: frames = to_request_frames(response, include_prefix=True) self.socket.send_multipart(frames) - def _req_incoming(self, parts: Tuple[bytes, ...]) -> None: - msg = from_request_frames(parts) - - # Convert generic Message to typed Request (validates op) - try: - req = Request(msg_type=msg.msg_type, target=msg.target, payload=msg.payload, msg_id=msg.msg_id) - except Exception: - # If invalid, treat as opaque request - req = Request(msg_type=msg.msg_type, target=msg.target, payload=msg.payload, msg_id=msg.msg_id) - - req.meta.update(msg.meta) - - payload: Optional[Payload] = None - error: Optional[dict] = None - - try: - payload = self.req_handler(req) - except Exception: - e_class, e_instance, _tb = sys.exc_info() - error = { - "type": getattr(e_class, "__name__", "Exception"), - "text": str(e_instance), - "debug": traceback.format_exc(), - } - - if payload is None and error is None: - return - - if payload is None: - payload = Payload(value=None) - if error is not None: - payload.error = error - - rep = Message(msg_type=REP, target=req.target, payload=payload, msg_id=req.msg_id) - rep.meta["zmq_prefix"] = req.meta.get("zmq_prefix", ()) - self.send(rep) - def run(self) -> None: poller = zmq.Poller() poller.register(self.socket, zmq.POLLIN) @@ -276,7 +175,8 @@ def run(self) -> None: self._rep_outgoing() elif active == self.socket: parts = tuple(self.socket.recv_multipart()) - self.workers.submit(self._req_incoming, parts) + msg = from_request_frames(parts) + self.workers.submit(self._req_incoming, msg) # --- convenience helpers (API-compatible-ish) --- From 31747f4dea163879309bf850cafd0ed80e0e08fd Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Thu, 12 Feb 2026 17:59:16 +0100 Subject: [PATCH 5/8] use protocol.wire instead of zmq framing / codec --- src/mktl/transport/codec.py | 19 ---- src/mktl/transport/zmq/framing.py | 144 ------------------------------ src/mktl/transport/zmq/publish.py | 10 +-- src/mktl/transport/zmq/request.py | 17 ++-- 4 files changed, 13 insertions(+), 177 deletions(-) delete mode 100644 src/mktl/transport/codec.py delete mode 100644 src/mktl/transport/zmq/framing.py diff --git a/src/mktl/transport/codec.py b/src/mktl/transport/codec.py deleted file mode 100644 index 4f31e8b7..00000000 --- a/src/mktl/transport/codec.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Transport codec for protocol payloads.""" - -from __future__ import annotations - -from typing import Dict, Any, Optional - -from .. import json - - -def encode_payload(payload: Optional[Dict[str, Any]]) -> bytes: - if not payload: - return b"" - return json.dumps(payload) - - -def decode_payload(payload_bytes: bytes) -> Dict[str, Any]: - if payload_bytes in (b"", None): - return {} - return json.loads(payload_bytes) diff --git a/src/mktl/transport/zmq/framing.py b/src/mktl/transport/zmq/framing.py deleted file mode 100644 index b0497850..00000000 --- a/src/mktl/transport/zmq/framing.py +++ /dev/null @@ -1,144 +0,0 @@ -"""ZMQ multipart framing for protocol messages. - -Request/Response (DEALER<->ROUTER) - (optional routing prefix...), version, transid, type, key, payload_json, binary - -Publish (PUB/SUB) - topic_with_trailing_dot, version, payload_json, binary -""" - -from __future__ import annotations - -from typing import Sequence, Tuple - -from ...protocol.message import Message, Envelope, MsgType -from ..codec import encode_payload, decode_payload - - -_VERSION = "a" -_VERSION_BYTES = _VERSION.encode() - - -def to_request_frames(msg: Message, *, include_prefix: bool = False) -> Tuple[bytes, ...]: - - env = msg.env - prefix: Tuple[bytes, ...] = tuple(env.meta.get("zmq_prefix", ())) - if not include_prefix: - prefix = () - - payload_bytes = encode_payload(env.payload) - binary = msg.binary or b"" - transid = env.transid.encode() - key = (env.key or "").encode() - - parts = ( - _VERSION_BYTES, - transid, - env.type.value.encode(), - key, - payload_bytes, - binary, - ) - return prefix + parts - - -def from_request_frames(parts: Sequence[bytes]) -> Message: - """Decode ROUTER/DEALER parts into a protocol Message. - - If a ROUTER identity prefix is present, it is stored in env.meta['zmq_prefix']. - """ - - if not parts: - raise ValueError("empty message") - - # ROUTER sockets prepend identity frames. We expect either: - # [version, transid, type, key, payload, binary] - # or - # [ident, version, transid, type, key, payload, binary] - if parts[0] == _VERSION_BYTES: - prefix: Tuple[bytes, ...] = () - start = 0 - else: - prefix = (parts[0],) - start = 1 - - their_version = parts[start] - if their_version != _VERSION_BYTES: - meta = {"zmq_prefix": prefix} if prefix else {} - env = Envelope( - type=MsgType.REP, - sourceid="", - transid=parts[start + 1].decode(), - payload={"error": { - "type": "RuntimeError", - "text": f"mKTL protocol {their_version!r}, expected {_VERSION_BYTES!r}", - }}, - meta=meta, - ) - return Message(env=env) - - transid = parts[start + 1].decode() - msg_type = parts[start + 2].decode() - key_bytes = parts[start + 3] - key = key_bytes.decode() if key_bytes not in (b"", None) else None - payload_bytes = parts[start + 4] - binary_bytes = parts[start + 5] if len(parts) > start + 5 else b"" - - payload = decode_payload(payload_bytes) - meta = {"zmq_prefix": prefix} if prefix else {} - - env = Envelope( - type=MsgType(msg_type), - sourceid=prefix[0].decode() if prefix else "", - transid=transid, - key=key, - payload=payload, - meta=meta, - ) - - return Message(env=env, binary=binary_bytes if binary_bytes else None) - - -def to_pub_frames(msg: Message) -> Tuple[bytes, ...]: - """Encode a publish message for PUB/SUB sockets.""" - - env = msg.env - topic = (env.key or "") + "." # trailing dot to prevent prefix matches - topic_b = topic.encode() - payload_bytes = encode_payload(env.payload) - binary = msg.binary or b"" - return (topic_b, _VERSION_BYTES, payload_bytes, binary) - - -def from_pub_frames(parts: Sequence[bytes]) -> Message: - if len(parts) < 4: - raise ValueError("invalid PUB message") - - topic = parts[0].decode() - if topic.endswith("."): - topic = topic[:-1] - - their_version = parts[1] - if their_version != _VERSION_BYTES: - env = Envelope( - type=MsgType.PUB, - sourceid="", - key=topic, - payload={"error": { - "type": "RuntimeError", - "text": f"mKTL protocol {their_version!r}, expected {_VERSION_BYTES!r}", - }}, - ) - return Message(env=env) - - payload = decode_payload(parts[2]) - binary_bytes = parts[3] if len(parts) > 3 else b"" - - env = Envelope( - type=MsgType.PUB, - sourceid="", - key=topic, - payload=payload, - ) - - return Message(env=env, binary=binary_bytes if binary_bytes else None) diff --git a/src/mktl/transport/zmq/publish.py b/src/mktl/transport/zmq/publish.py index 74014194..0f988042 100644 --- a/src/mktl/transport/zmq/publish.py +++ b/src/mktl/transport/zmq/publish.py @@ -11,9 +11,9 @@ import zmq from ...protocol.message import Message +from ...protocol.wire import pack_frame, unpack_frame from ...transport import TransportPortError from ..session import PublishSession, SubscribeSession -from .framing import from_pub_frames, to_pub_frames minimum_port = 10139 maximum_port = 13679 @@ -40,8 +40,8 @@ def subscribe(self, topic: str) -> None: self.socket.setsockopt(zmq.SUBSCRIBE, (topic + ".").encode()) def recv(self) -> Message: - parts = self.socket.recv_multipart() - return from_pub_frames(parts) + _topic, wire_bytes = self.socket.recv_multipart() + return unpack_frame(wire_bytes) class Server(PublishSession): @@ -98,8 +98,8 @@ def send(self, msg: Message) -> None: def _send_one(self) -> None: self._sig_rx.recv(flags=zmq.NOBLOCK) msg = self._queue.get(block=False) - frames = to_pub_frames(msg) - self.socket.send_multipart(frames) + topic = ((msg.env.key or "") + ".").encode() + self.socket.send_multipart([topic, pack_frame(msg)]) def run(self) -> None: poller = zmq.Poller() diff --git a/src/mktl/transport/zmq/request.py b/src/mktl/transport/zmq/request.py index 5020bfa3..7bbbb234 100644 --- a/src/mktl/transport/zmq/request.py +++ b/src/mktl/transport/zmq/request.py @@ -21,9 +21,9 @@ import zmq from ...protocol.message import Message +from ...protocol.wire import pack_frame, unpack_frame from ...transport import TransportTimeout, TransportPortError from ..session import RequestSession, RequestServer, PendingRequest -from .framing import from_request_frames, to_request_frames minimum_port = 10079 @@ -68,9 +68,8 @@ def _handle_outgoing(self) -> None: self._signal_rx.recv(flags=zmq.NOBLOCK) pending: PendingRequest = self._outbox.get(block=False) - frames = to_request_frames(pending.req) self._pending[pending.id] = pending - self.socket.send_multipart(frames) + self.socket.send(pack_frame(pending.req)) def run(self) -> None: poller = zmq.Poller() @@ -82,8 +81,7 @@ def run(self) -> None: if active == self._signal_rx: self._handle_outgoing() elif active == self.socket: - parts = tuple(self.socket.recv_multipart()) - msg = from_request_frames(parts) + msg = unpack_frame(self.socket.recv()) self._handle_incoming(msg) def send(self, msg: Message) -> PendingRequest: @@ -160,8 +158,8 @@ def send(self, response: Message) -> None: def _rep_outgoing(self) -> None: self._signal_rx.recv(flags=zmq.NOBLOCK) response: Message = self._responses.get(block=False) - frames = to_request_frames(response, include_prefix=True) - self.socket.send_multipart(frames) + identity = response.env.meta.get("zmq_prefix", (b"",))[0] + self.socket.send_multipart([identity, pack_frame(response)]) def run(self) -> None: poller = zmq.Poller() @@ -173,8 +171,9 @@ def run(self) -> None: if active == self._signal_rx: self._rep_outgoing() elif active == self.socket: - parts = tuple(self.socket.recv_multipart()) - msg = from_request_frames(parts) + identity, wire_bytes = self.socket.recv_multipart() + msg = unpack_frame(wire_bytes) + msg.env.meta["zmq_prefix"] = (identity,) self.workers.submit(self._req_incoming, msg) From 076992b1a4b37142adaea49bf1f9bc51b67aacba Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Thu, 12 Feb 2026 18:28:54 +0100 Subject: [PATCH 6/8] add rabbit transport layer --- pyproject.toml | 3 + src/mktl/transport/__init__.py | 12 +- src/mktl/transport/rabbitmq/__init__.py | 1 + src/mktl/transport/rabbitmq/publish.py | 181 +++++++++++++++++++ src/mktl/transport/rabbitmq/request.py | 224 ++++++++++++++++++++++++ src/mktl/transport/session.py | 2 - src/mktl/transport/zmq/request.py | 3 - 7 files changed, 418 insertions(+), 8 deletions(-) create mode 100644 src/mktl/transport/rabbitmq/__init__.py create mode 100644 src/mktl/transport/rabbitmq/publish.py create mode 100644 src/mktl/transport/rabbitmq/request.py diff --git a/pyproject.toml b/pyproject.toml index eede4096..b83011d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,13 +12,16 @@ pyzmq = ">=4.0" msgspec = { version = ">=0.18", optional = true } numpy = { version = ">=1.6", optional = true } orjson = { version = ">=3.0", optional = true } +pika = { version = ">=1.3", optional = true } pint = { version = ">=0.17", optional = true } [tool.poetry.extras] msgspec = ["msgspec"] numpy = ["numpy"] orjson = ["orjson"] +pika = ["pika"] pint = ["pint"] +rabbitmq = ["pika"] [tool.poetry.urls] repository = "https://github.com/KeckObservatory/mKTL" diff --git a/src/mktl/transport/__init__.py b/src/mktl/transport/__init__.py index 4c533539..dba18e4f 100644 --- a/src/mktl/transport/__init__.py +++ b/src/mktl/transport/__init__.py @@ -14,8 +14,14 @@ if _BACKEND == "zmq": from .zmq import request from .zmq import publish -# elif _BACKEND == "rabbitmq": -# from .rabbitmq import request -# from .rabbitmq import publish +elif _BACKEND == "rabbitmq": + try: + from .rabbitmq import request + from .rabbitmq import publish + except ImportError: + raise ImportError( + "MKTL_TRANSPORT='rabbitmq' requires pika: " + "pip install mKTL[rabbitmq]" + ) else: raise ImportError(f"unknown MKTL_TRANSPORT backend: {_BACKEND!r}") diff --git a/src/mktl/transport/rabbitmq/__init__.py b/src/mktl/transport/rabbitmq/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/mktl/transport/rabbitmq/__init__.py @@ -0,0 +1 @@ + diff --git a/src/mktl/transport/rabbitmq/publish.py b/src/mktl/transport/rabbitmq/publish.py new file mode 100644 index 00000000..c709c18a --- /dev/null +++ b/src/mktl/transport/rabbitmq/publish.py @@ -0,0 +1,181 @@ +"""RabbitMQ publish/subscribe transport.""" + +from __future__ import annotations + +import atexit +import os +import queue +import threading +import traceback +from typing import Dict, Optional, Tuple + +import pika + +from ...protocol.message import Message +from ...protocol.wire import pack_frame, unpack_frame +from ..session import PublishSession, SubscribeSession + + +_EXCHANGE = "mktl.pub" +_BROKER_HOST = os.environ.get("MKTL_AMQP_HOST", "localhost") +_BROKER_PORT = int(os.environ.get("MKTL_AMQP_PORT", "5672")) + + +def _broker_params() -> pika.ConnectionParameters: + return pika.ConnectionParameters( + host=_BROKER_HOST, + port=_BROKER_PORT, + heartbeat=600, + blocked_connection_timeout=300, + ) + + +class Client(SubscribeSession): + """SUB client backed by a RabbitMQ topic exchange.""" + + def __init__(self, address: str, port: int): + self.address = address + self.port = int(port) + + self._inbox: queue.Queue = queue.Queue() + self._bindings: list = [] + self._ready = threading.Event() + self._connection = None + + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + self._ready.wait(timeout=10) + + def subscribe(self, topic: str) -> None: + routing_key = topic + ".#" + self._bindings.append(routing_key) + if self._ready.is_set(): + self._connection.add_callback_threadsafe( + lambda rk=routing_key: self._channel.queue_bind( + exchange=_EXCHANGE, + queue=self._queue_name, + routing_key=rk, + ) + ) + + def recv(self) -> Message: + body = self._inbox.get() + return unpack_frame(body) + + def _run(self) -> None: + self._connection = pika.BlockingConnection(_broker_params()) + self._channel = self._connection.channel() + + self._channel.exchange_declare( + exchange=_EXCHANGE, exchange_type="topic", durable=False + ) + + result = self._channel.queue_declare(queue="", exclusive=True) + self._queue_name = result.method.queue + + # Subscribe to everything by default (same as ZMQ SUB with b"") + self._channel.queue_bind( + exchange=_EXCHANGE, + queue=self._queue_name, + routing_key="#", + ) + + # Apply any bindings requested before the channel was ready + for rk in self._bindings: + self._channel.queue_bind( + exchange=_EXCHANGE, + queue=self._queue_name, + routing_key=rk, + ) + + self._channel.basic_consume( + queue=self._queue_name, + on_message_callback=self._on_message, + auto_ack=True, + ) + + self._ready.set() + self._channel.start_consuming() + + def _on_message(self, _ch, _method, _properties, body: bytes) -> None: + self._inbox.put(body) + + +class Server(PublishSession): + """PUB server backed by a RabbitMQ topic exchange.""" + + def __init__( + self, + port: Optional[int] = None, + avoid: Optional[set] = None, + ): + self.port = int(port) if port is not None else 0 + + self._queue: queue.Queue = queue.Queue() + self._ready = threading.Event() + self._connection = None + self.shutdown = False + + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + self._ready.wait(timeout=10) + + def send(self, msg: Message) -> None: + self._queue.put(msg) + self._connection.add_callback_threadsafe(self._flush) + + def _run(self) -> None: + self._connection = pika.BlockingConnection(_broker_params()) + self._channel = self._connection.channel() + + self._channel.exchange_declare( + exchange=_EXCHANGE, exchange_type="topic", durable=False + ) + + self._ready.set() + # Keep the connection alive by consuming from an unused queue + self._connection.process_data_events(time_limit=None) + + def _flush(self) -> None: + """Drain all queued outgoing messages (called on the connection + thread via add_callback_threadsafe).""" + while True: + try: + msg: Message = self._queue.get_nowait() + except queue.Empty: + break + + routing_key = (msg.env.key or "") + "." + try: + self._channel.basic_publish( + exchange=_EXCHANGE, + routing_key=routing_key, + body=pack_frame(msg), + ) + except Exception: + traceback.print_exc() + + +_client_cache: Dict[Tuple[str, int], Client] = {} +_client_lock = threading.Lock() + + +def client(address: str, port: int) -> Client: + key = (address, int(port)) + with _client_lock: + c = _client_cache.get(key) + if c is None: + c = Client(address, int(port)) + _client_cache[key] = c + return c + + +def _cleanup() -> None: + for c in _client_cache.values(): + try: + c._connection.close() + except Exception: + pass + + +atexit.register(_cleanup) diff --git a/src/mktl/transport/rabbitmq/request.py b/src/mktl/transport/rabbitmq/request.py new file mode 100644 index 00000000..b0bcbb58 --- /dev/null +++ b/src/mktl/transport/rabbitmq/request.py @@ -0,0 +1,224 @@ +"""RabbitMQ request/response transport.""" + +from __future__ import annotations + +import atexit +import concurrent.futures +import os +import queue +import socket as pysocket +import threading +from typing import Dict, Optional, Tuple + +import pika + +from ...protocol.message import Message +from ...protocol.wire import pack_frame, unpack_frame +from ...transport import TransportTimeout +from ..session import RequestSession, RequestServer, PendingRequest + + +_BROKER_HOST = os.environ.get("MKTL_AMQP_HOST", "localhost") +_BROKER_PORT = int(os.environ.get("MKTL_AMQP_PORT", "5672")) + + +def _broker_params() -> pika.ConnectionParameters: + return pika.ConnectionParameters( + host=_BROKER_HOST, + port=_BROKER_PORT, + heartbeat=600, + blocked_connection_timeout=300, + ) + + +class Client(RequestSession): + """Issue requests via RabbitMQ and receive responses on an exclusive + reply queue.""" + + timeout = 0.1 + + def __init__(self, address: str, port: int): + self.port = int(port) + self.address = address + + self._pending: Dict[str, PendingRequest] = {} + self._outbox: queue.Queue = queue.Queue() + self._ready = threading.Event() + self._connection = None + + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + self._ready.wait(timeout=10) + + def send(self, msg: Message) -> PendingRequest: + if self._connection is None or not self._ready.is_set(): + raise TransportTimeout( + f"not connected to AMQP broker at {_BROKER_HOST}:{_BROKER_PORT}" + ) + + pending = PendingRequest(msg) + self._outbox.put(pending) + self._connection.add_callback_threadsafe(self._flush_outbox) + + ack = pending.wait_ack(self.timeout) + if not ack: + raise TransportTimeout( + f"{msg.env.type} @ {self.address}:{self.port}: " + f"no ACK in {self.timeout:.2f} sec" + ) + return pending + + def _run(self) -> None: + self._connection = pika.BlockingConnection(_broker_params()) + self._channel = self._connection.channel() + + # Exclusive auto-delete reply queue + result = self._channel.queue_declare(queue="", exclusive=True) + self._reply_queue = result.method.queue + + self._channel.basic_consume( + queue=self._reply_queue, + on_message_callback=self._on_response, + auto_ack=True, + ) + + self._ready.set() + self._channel.start_consuming() + + def _on_response(self, _ch, _method, properties, body: bytes) -> None: + msg = unpack_frame(body) + self._handle_incoming(msg) + + def _flush_outbox(self) -> None: + """Drain all queued outgoing requests (called on the connection + thread via add_callback_threadsafe).""" + while True: + try: + pending: PendingRequest = self._outbox.get_nowait() + except queue.Empty: + break + + server_queue = _server_queue_name(self.address, self.port) + self._pending[pending.id] = pending + self._channel.basic_publish( + exchange="", + routing_key=server_queue, + properties=pika.BasicProperties( + reply_to=self._reply_queue, + correlation_id=pending.id, + ), + body=pack_frame(pending.req), + ) + + +class Server(RequestServer): + """Receive requests from a RabbitMQ queue, respond to them.""" + + port = None + hostname = None + + def __init__( + self, + address: Optional[str] = None, + port: Optional[int] = None, + avoid: Optional[set] = None, + ): + self.hostname = address or pysocket.getfqdn() + self.port = int(port) if port is not None else 0 + + self._queue_name = _server_queue_name(self.hostname, self.port) + + self._responses: queue.Queue = queue.Queue() + self._ready = threading.Event() + self._connection = None + self.shutdown = False + self.workers = concurrent.futures.ThreadPoolExecutor(max_workers=8) + + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + self._ready.wait(timeout=10) + + def send(self, response: Message) -> None: + self._responses.put(response) + self._connection.add_callback_threadsafe(self._flush_responses) + + def _run(self) -> None: + self._connection = pika.BlockingConnection(_broker_params()) + self._channel = self._connection.channel() + + self._channel.queue_declare(queue=self._queue_name, durable=False) + self._channel.basic_qos(prefetch_count=1) + self._channel.basic_consume( + queue=self._queue_name, + on_message_callback=self._on_request, + ) + + self._ready.set() + self._channel.start_consuming() + + def _on_request(self, ch, method, properties, body: bytes) -> None: + ch.basic_ack(delivery_tag=method.delivery_tag) + msg = unpack_frame(body) + msg.env.meta["_reply_to"] = properties.reply_to + msg.env.meta["_correlation_id"] = properties.correlation_id + self.workers.submit(self._req_incoming, msg) + + def _flush_responses(self) -> None: + """Drain all queued outgoing responses (called on the connection + thread via add_callback_threadsafe).""" + while True: + try: + response: Message = self._responses.get_nowait() + except queue.Empty: + break + + reply_to = response.env.meta.get("_reply_to", "") + correlation_id = response.env.meta.get( + "_correlation_id", response.env.transid + ) + self._channel.basic_publish( + exchange="", + routing_key=reply_to, + properties=pika.BasicProperties( + correlation_id=correlation_id, + ), + body=pack_frame(response), + ) + + +_client_cache: Dict[Tuple[str, int], Client] = {} +_client_lock = threading.Lock() + + +def client(address: str, port: int) -> Client: + key = (address, int(port)) + with _client_lock: + c = _client_cache.get(key) + if c is None: + c = Client(address, int(port)) + _client_cache[key] = c + return c + + +def send(address: str, port: int, message: Message) -> Message: + c = client(address, port) + pending = c.send(message) + response = pending.wait(timeout=60) + if response is None: + raise TransportTimeout("no response received") + return response + + +def _server_queue_name(address: str, port: int) -> str: + return f"mktl.req.{address}.{port}" + + +def _cleanup() -> None: + for c in _client_cache.values(): + try: + c._connection.close() + except Exception: + pass + + +atexit.register(_cleanup) diff --git a/src/mktl/transport/session.py b/src/mktl/transport/session.py index 2f36d7bb..85620f00 100644 --- a/src/mktl/transport/session.py +++ b/src/mktl/transport/session.py @@ -87,7 +87,6 @@ def __init__(self, transport: Transport, node_id: str = ""): self.transport = transport self.node_id = node_id - # --- request handling hooks --- def req_handler(self, msg: Message) -> Optional[dict]: """Override in subclasses. @@ -114,7 +113,6 @@ def req_ack(self, msg: Message) -> None: def send(self, response: Message) -> None: self.transport.send(response) - # --- internal --- def _req_incoming(self, msg: Message) -> None: if self._on_receive is not None: self._on_receive(msg) diff --git a/src/mktl/transport/zmq/request.py b/src/mktl/transport/zmq/request.py index 7bbbb234..1f6dc215 100644 --- a/src/mktl/transport/zmq/request.py +++ b/src/mktl/transport/zmq/request.py @@ -154,7 +154,6 @@ def send(self, response: Message) -> None: self._responses.put(response) self._signal_tx.send(b"") - # --- internal --- def _rep_outgoing(self) -> None: self._signal_rx.recv(flags=zmq.NOBLOCK) response: Message = self._responses.get(block=False) @@ -177,8 +176,6 @@ def run(self) -> None: self.workers.submit(self._req_incoming, msg) -# --- convenience helpers (API-compatible-ish) --- - _client_cache: Dict[Tuple[str, int], Client] = {} _client_lock = threading.Lock() From 41b9f8e8a0870681623edf7c4fd5b7a4792702cf Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Thu, 12 Feb 2026 21:35:46 +0100 Subject: [PATCH 7/8] fix pub/sub subribe to all --- src/mktl/transport/rabbitmq/publish.py | 32 ++++++++++++++++---------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/mktl/transport/rabbitmq/publish.py b/src/mktl/transport/rabbitmq/publish.py index c709c18a..afa3ef26 100644 --- a/src/mktl/transport/rabbitmq/publish.py +++ b/src/mktl/transport/rabbitmq/publish.py @@ -39,6 +39,7 @@ def __init__(self, address: str, port: int): self._inbox: queue.Queue = queue.Queue() self._bindings: list = [] + self._has_wildcard = False self._ready = threading.Event() self._connection = None @@ -51,13 +52,24 @@ def subscribe(self, topic: str) -> None: self._bindings.append(routing_key) if self._ready.is_set(): self._connection.add_callback_threadsafe( - lambda rk=routing_key: self._channel.queue_bind( - exchange=_EXCHANGE, - queue=self._queue_name, - routing_key=rk, - ) + lambda rk=routing_key: self._bind_topic(rk) ) + def _bind_topic(self, routing_key: str) -> None: + """Add a topic binding, removing the default wildcard if present.""" + if self._has_wildcard: + self._channel.queue_unbind( + exchange=_EXCHANGE, + queue=self._queue_name, + routing_key="#", + ) + self._has_wildcard = False + self._channel.queue_bind( + exchange=_EXCHANGE, + queue=self._queue_name, + routing_key=routing_key, + ) + def recv(self) -> Message: body = self._inbox.get() return unpack_frame(body) @@ -73,20 +85,16 @@ def _run(self) -> None: result = self._channel.queue_declare(queue="", exclusive=True) self._queue_name = result.method.queue - # Subscribe to everything by default (same as ZMQ SUB with b"") self._channel.queue_bind( exchange=_EXCHANGE, queue=self._queue_name, routing_key="#", ) + self._has_wildcard = True - # Apply any bindings requested before the channel was ready + # Apply any bindings requested before the channel was ready. for rk in self._bindings: - self._channel.queue_bind( - exchange=_EXCHANGE, - queue=self._queue_name, - routing_key=rk, - ) + self._bind_topic(rk) self._channel.basic_consume( queue=self._queue_name, From fa6b1859ba02760e99505328105e9747db77c5fc Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Thu, 12 Feb 2026 22:20:24 +0100 Subject: [PATCH 8/8] cleanup --- src/mktl/protocol/protocol.py | 63 ++++++++--------------------------- 1 file changed, 14 insertions(+), 49 deletions(-) diff --git a/src/mktl/protocol/protocol.py b/src/mktl/protocol/protocol.py index a2e9337d..f742ec73 100644 --- a/src/mktl/protocol/protocol.py +++ b/src/mktl/protocol/protocol.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Callable, Dict, Any +from typing import Optional, Callable, Any from .builder import MessageBuilder from .message import Message @@ -13,11 +13,7 @@ class Protocol: def __init__(self, node_id: str, session): self.builder = MessageBuilder(node_id) self.session = session - self._req_handlers: Dict[str, Callable[[Message], Optional[dict]]] = {} - self._evt_handlers: Dict[str, Callable[[Message], None]] = {} - - if hasattr(session, '_on_receive'): - session._on_receive = self.dispatch + self._handlers = {} # Request APIs @@ -69,28 +65,18 @@ def publish(self, topic: str, payload: Any): self.session.send(msg) - def on( - self, - key: str, - handler: Callable[[Message], Optional[dict]], - ) -> None: - self._req_handlers[key] = handler - - def listen( - self, - topic: str, - handler: Callable[[Message], None], - ) -> None: - self._evt_handlers[topic] = handler - def serve( self, key: str, callback: Callable[[dict, dict], Optional[dict]], ) -> None: - """callback(payload, ctx) -> dict | None""" + """ + Register a key handler. - def _wrap(msg: Message) -> Optional[dict]: + callback(payload, ctx) -> dict | None + """ + + def _wrap(msg: Message): payload = msg.env.payload @@ -104,40 +90,19 @@ def _wrap(msg: Message) -> Optional[dict]: } try: - return callback(payload, ctx) + result = callback(payload, ctx) except Exception as ex: - return {"ok": False, "error": str(ex)} - - self.on(key, _wrap) - - - def dispatch(self, msg: Message) -> None: - env = msg.env - - if request.is_request(msg): - handler = self._req_handlers.get(env.key) - if handler is None: - return - - if hasattr(self.session, 'req_ack'): - self.session.req_ack(msg) - - result = handler(msg) + result = {"ok": False, "error": str(ex)} + # Build response automatically if result is not None: resp = ( self.builder - .rep(env.transid) - .to(env.sourceid) + .rep(msg.env.transid) + .to(msg.env.sourceid) .payload(result) - .meta(dict(env.meta)) .build() ) self.session.send(resp) - return - if publish.is_publish(msg): - handler = self._evt_handlers.get(env.key) - if handler is not None: - handler(msg) - return + self._handlers[key] = _wrap