diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..c18dd8d8 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/pyproject.toml b/pyproject.toml index eede4096..b83011d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,13 +12,16 @@ pyzmq = ">=4.0" msgspec = { version = ">=0.18", optional = true } numpy = { version = ">=1.6", optional = true } orjson = { version = ">=3.0", optional = true } +pika = { version = ">=1.3", optional = true } pint = { version = ">=0.17", optional = true } [tool.poetry.extras] msgspec = ["msgspec"] numpy = ["numpy"] orjson = ["orjson"] +pika = ["pika"] pint = ["pint"] +rabbitmq = ["pika"] [tool.poetry.urls] repository = "https://github.com/KeckObservatory/mKTL" diff --git a/src/mktl/__init__.py b/src/mktl/__init__.py index 475e8e2a..c12e9892 100644 --- a/src/mktl/__init__.py +++ b/src/mktl/__init__.py @@ -24,6 +24,5 @@ from .item import Item from .store import Store from .daemon import Daemon -Payload = protocol.message.Payload # vim: set expandtab tabstop=8 softtabstop=4 shiftwidth=4 autoindent: diff --git a/src/mktl/begin.py b/src/mktl/begin.py index e0fb905f..30c46482 100644 --- a/src/mktl/begin.py +++ b/src/mktl/begin.py @@ -3,10 +3,11 @@ """ import threading -import zmq from . import config from . import protocol +from . import transport +from .transport import TransportError from .store import Store @@ -47,13 +48,13 @@ def discover(*targets): # Hacking the timeout for discovery, this is not expected to throw # errors with minimal delay. - old_timeout = protocol.request.Client.timeout - protocol.request.Client.timeout = 0.5 + old_timeout = transport.request.Client.timeout + transport.request.Client.timeout = 0.5 for address,port in brokers: request = protocol.message.Request('HASH') try: - payload = protocol.request.send(address, port, request) + payload = transport.request.send(address, port, request) except: continue @@ -61,7 +62,7 @@ def discover(*targets): for store in hashes.keys(): request = protocol.message.Request('CONFIG', store) - payload = protocol.request.send(address, port, request) + payload = transport.request.send(address, port, request) blocks = payload.value @@ -71,7 +72,7 @@ def discover(*targets): configuration.update(block) - protocol.request.Client.timeout = old_timeout + transport.request.Client.timeout = old_timeout @@ -140,7 +141,7 @@ def get(store, key=None): hostname,port = brokers[0] message = protocol.message.Request('CONFIG', store) - payload = protocol.request.send(hostname, port, message) + payload = transport.request.send(hostname, port, message) blocks = payload.value @@ -200,12 +201,12 @@ def refresh(configuration): hostname = stratum['hostname'] rep = stratum['rep'] - client = protocol.request.client(hostname, rep) + client = transport.request.client(hostname, rep) request = protocol.message.Request('HASH', store) try: client.send(request) - except zmq.ZMQError: + except TransportError: # No response from this daemon; move on to the next entry in # the provenance. If no daemons respond the client will have # to rely on the local disk cache. diff --git a/src/mktl/config.py b/src/mktl/config.py index 24fb8c7b..daa11d3b 100644 --- a/src/mktl/config.py +++ b/src/mktl/config.py @@ -5,7 +5,6 @@ import threading import time import uuid -import zmq # Importing pint is expensive, representing something like 30% of the # user runtime for a simple mKTL command. It will be imported on a @@ -19,6 +18,8 @@ from . import json from . import protocol +from . import transport +from .transport import TransportError _cache = dict() @@ -1195,8 +1196,8 @@ def announce(config, uuid, override=False): for address,port in brokers: try: - payload = protocol.request.send(address, port, message) - except zmq.error.ZMQError: + payload = transport.request.send(address, port, message) + except TransportError: continue error = payload.error diff --git a/src/mktl/daemon.py b/src/mktl/daemon.py index e50a5f4b..c24ee1e6 100644 --- a/src/mktl/daemon.py +++ b/src/mktl/daemon.py @@ -9,7 +9,6 @@ import sys import threading import time -import zmq from . import begin from . import config @@ -18,6 +17,8 @@ from . import poll from . import protocol from . import store +from . import transport +from .transport import TransportError, TransportPortError class Daemon: @@ -85,15 +86,15 @@ def __init__(self, store, alias, override=False, options=None): self._test_port(store, rep) try: - self.pub = protocol.publish.Server(port=pub, avoid=avoid) - except zmq.error.ZMQError: - self.pub = protocol.publish.Server(port=None, avoid=avoid) + self.pub = transport.publish.Server(port=pub, avoid=avoid) + except TransportPortError: + self.pub = transport.publish.Server(port=None, avoid=avoid) avoid = _used_ports() try: self.rep = RequestServer(self, port=rep, avoid=avoid) - except zmq.error.ZMQError: + except TransportPortError: self.rep = RequestServer(self, port=None, avoid=avoid) _save_port(store, self.uuid, self.rep.port, self.pub.port) @@ -420,8 +421,8 @@ def _test_port(self, store, port): request = protocol.message.Request('CONFIG', store) try: - payload = protocol.request.send(hostname, port, request) - except zmq.ZMQError: + payload = transport.request.send(hostname, port, request) + except TransportError: # Not running; perfect. return @@ -446,10 +447,10 @@ def _test_port(self, store, port): -class RequestServer(protocol.request.Server): +class RequestServer(transport.request.Server): def __init__(self, daemon, *args, **kwargs): - protocol.request.Server.__init__(self, *args, **kwargs) + transport.request.Server.__init__(self, *args, **kwargs) self.daemon = daemon diff --git a/src/mktl/item.py b/src/mktl/item.py index 6cf8119c..2a849c49 100644 --- a/src/mktl/item.py +++ b/src/mktl/item.py @@ -3,7 +3,6 @@ import threading import time import traceback -import zmq try: import numpy @@ -12,7 +11,9 @@ from . import protocol from . import poll +from . import transport from . import weakref +from .transport import TransportError class Item: @@ -95,8 +96,8 @@ def __init__(self, store, key, subscribe=True, authoritative=False, pub=None): # configuration that doesn't contain a provenance. raise RuntimeError('cannot find daemon for ' + self.full_key) - self.sub = protocol.publish.client(hostname, pub) - self.req = protocol.request.client(hostname, rep) + self.sub = transport.publish.client(hostname, pub) + self.req = transport.request.client(hostname, rep) try: settable = self.config['settable'] @@ -779,7 +780,7 @@ def subscribe(self, prime=True): if prime == True: try: self.get(refresh=True) - except (zmq.ZMQError, RuntimeError): + except (TransportError, RuntimeError): # Connection errors and remote errors on priming reads are # thrown away; an error here means the remote daemon is not # available to respond to requests, but despite that error diff --git a/src/mktl/transport/__init__.py b/src/mktl/transport/__init__.py index 6f40511f..dba18e4f 100644 --- a/src/mktl/transport/__init__.py +++ b/src/mktl/transport/__init__.py @@ -1,5 +1,27 @@ -"""Transport implementations. +"""Transport layer implementations.""" -Each transport is responsible for mapping protocol :class:`mktl.protocol.Message` -objects to/from a wire representation. -""" +import os + +from .base import ( + TransportError, + TransportTimeout, + TransportConnectionError, + TransportPortError, +) + +_BACKEND = os.environ.get("MKTL_TRANSPORT", "zmq") + +if _BACKEND == "zmq": + from .zmq import request + from .zmq import publish +elif _BACKEND == "rabbitmq": + try: + from .rabbitmq import request + from .rabbitmq import publish + except ImportError: + raise ImportError( + "MKTL_TRANSPORT='rabbitmq' requires pika: " + "pip install mKTL[rabbitmq]" + ) +else: + raise ImportError(f"unknown MKTL_TRANSPORT backend: {_BACKEND!r}") diff --git a/src/mktl/transport/base.py b/src/mktl/transport/base.py index 90b5532e..b993b112 100644 --- a/src/mktl/transport/base.py +++ b/src/mktl/transport/base.py @@ -7,15 +7,49 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Optional -from ..protocol import Message +from ..protocol.message import Message + + +# Transport agnostic exceptions + +class TransportError(Exception): + """Base class for all transport-layer errors.""" + + +class TransportTimeout(TransportError): + """A request did not receive a timely response.""" + + +class TransportConnectionError(TransportError): + """The transport could not establish or maintain a connection.""" + + +class TransportPortError(TransportError): + """No suitable port could be bound or connected.""" class Transport(ABC): + """Minimal contract for a wire-level transport.""" + + @abstractmethod + def open(self) -> None: + """Establish the underlying connection/socket.""" + + @abstractmethod + def close(self) -> None: + """Tear down the underlying connection/socket.""" + @abstractmethod def send(self, msg: Message) -> None: - """Send a protocol message.""" + """Send a protocol Message.""" @abstractmethod - def recv(self) -> Message: - """Receive the next protocol message.""" + def recv(self, timeout: Optional[float] = None) -> Message: + """Receive the next protocol Message.""" + + @property + def is_open(self) -> bool: + """Whether the transport is currently connected.""" + return False diff --git a/src/mktl/transport/rabbitmq/__init__.py b/src/mktl/transport/rabbitmq/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/mktl/transport/rabbitmq/__init__.py @@ -0,0 +1 @@ + diff --git a/src/mktl/transport/rabbitmq/publish.py b/src/mktl/transport/rabbitmq/publish.py new file mode 100644 index 00000000..afa3ef26 --- /dev/null +++ b/src/mktl/transport/rabbitmq/publish.py @@ -0,0 +1,189 @@ +"""RabbitMQ publish/subscribe transport.""" + +from __future__ import annotations + +import atexit +import os +import queue +import threading +import traceback +from typing import Dict, Optional, Tuple + +import pika + +from ...protocol.message import Message +from ...protocol.wire import pack_frame, unpack_frame +from ..session import PublishSession, SubscribeSession + + +_EXCHANGE = "mktl.pub" +_BROKER_HOST = os.environ.get("MKTL_AMQP_HOST", "localhost") +_BROKER_PORT = int(os.environ.get("MKTL_AMQP_PORT", "5672")) + + +def _broker_params() -> pika.ConnectionParameters: + return pika.ConnectionParameters( + host=_BROKER_HOST, + port=_BROKER_PORT, + heartbeat=600, + blocked_connection_timeout=300, + ) + + +class Client(SubscribeSession): + """SUB client backed by a RabbitMQ topic exchange.""" + + def __init__(self, address: str, port: int): + self.address = address + self.port = int(port) + + self._inbox: queue.Queue = queue.Queue() + self._bindings: list = [] + self._has_wildcard = False + self._ready = threading.Event() + self._connection = None + + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + self._ready.wait(timeout=10) + + def subscribe(self, topic: str) -> None: + routing_key = topic + ".#" + self._bindings.append(routing_key) + if self._ready.is_set(): + self._connection.add_callback_threadsafe( + lambda rk=routing_key: self._bind_topic(rk) + ) + + def _bind_topic(self, routing_key: str) -> None: + """Add a topic binding, removing the default wildcard if present.""" + if self._has_wildcard: + self._channel.queue_unbind( + exchange=_EXCHANGE, + queue=self._queue_name, + routing_key="#", + ) + self._has_wildcard = False + self._channel.queue_bind( + exchange=_EXCHANGE, + queue=self._queue_name, + routing_key=routing_key, + ) + + def recv(self) -> Message: + body = self._inbox.get() + return unpack_frame(body) + + def _run(self) -> None: + self._connection = pika.BlockingConnection(_broker_params()) + self._channel = self._connection.channel() + + self._channel.exchange_declare( + exchange=_EXCHANGE, exchange_type="topic", durable=False + ) + + result = self._channel.queue_declare(queue="", exclusive=True) + self._queue_name = result.method.queue + + self._channel.queue_bind( + exchange=_EXCHANGE, + queue=self._queue_name, + routing_key="#", + ) + self._has_wildcard = True + + # Apply any bindings requested before the channel was ready. + for rk in self._bindings: + self._bind_topic(rk) + + self._channel.basic_consume( + queue=self._queue_name, + on_message_callback=self._on_message, + auto_ack=True, + ) + + self._ready.set() + self._channel.start_consuming() + + def _on_message(self, _ch, _method, _properties, body: bytes) -> None: + self._inbox.put(body) + + +class Server(PublishSession): + """PUB server backed by a RabbitMQ topic exchange.""" + + def __init__( + self, + port: Optional[int] = None, + avoid: Optional[set] = None, + ): + self.port = int(port) if port is not None else 0 + + self._queue: queue.Queue = queue.Queue() + self._ready = threading.Event() + self._connection = None + self.shutdown = False + + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + self._ready.wait(timeout=10) + + def send(self, msg: Message) -> None: + self._queue.put(msg) + self._connection.add_callback_threadsafe(self._flush) + + def _run(self) -> None: + self._connection = pika.BlockingConnection(_broker_params()) + self._channel = self._connection.channel() + + self._channel.exchange_declare( + exchange=_EXCHANGE, exchange_type="topic", durable=False + ) + + self._ready.set() + # Keep the connection alive by consuming from an unused queue + self._connection.process_data_events(time_limit=None) + + def _flush(self) -> None: + """Drain all queued outgoing messages (called on the connection + thread via add_callback_threadsafe).""" + while True: + try: + msg: Message = self._queue.get_nowait() + except queue.Empty: + break + + routing_key = (msg.env.key or "") + "." + try: + self._channel.basic_publish( + exchange=_EXCHANGE, + routing_key=routing_key, + body=pack_frame(msg), + ) + except Exception: + traceback.print_exc() + + +_client_cache: Dict[Tuple[str, int], Client] = {} +_client_lock = threading.Lock() + + +def client(address: str, port: int) -> Client: + key = (address, int(port)) + with _client_lock: + c = _client_cache.get(key) + if c is None: + c = Client(address, int(port)) + _client_cache[key] = c + return c + + +def _cleanup() -> None: + for c in _client_cache.values(): + try: + c._connection.close() + except Exception: + pass + + +atexit.register(_cleanup) diff --git a/src/mktl/transport/rabbitmq/request.py b/src/mktl/transport/rabbitmq/request.py new file mode 100644 index 00000000..b0bcbb58 --- /dev/null +++ b/src/mktl/transport/rabbitmq/request.py @@ -0,0 +1,224 @@ +"""RabbitMQ request/response transport.""" + +from __future__ import annotations + +import atexit +import concurrent.futures +import os +import queue +import socket as pysocket +import threading +from typing import Dict, Optional, Tuple + +import pika + +from ...protocol.message import Message +from ...protocol.wire import pack_frame, unpack_frame +from ...transport import TransportTimeout +from ..session import RequestSession, RequestServer, PendingRequest + + +_BROKER_HOST = os.environ.get("MKTL_AMQP_HOST", "localhost") +_BROKER_PORT = int(os.environ.get("MKTL_AMQP_PORT", "5672")) + + +def _broker_params() -> pika.ConnectionParameters: + return pika.ConnectionParameters( + host=_BROKER_HOST, + port=_BROKER_PORT, + heartbeat=600, + blocked_connection_timeout=300, + ) + + +class Client(RequestSession): + """Issue requests via RabbitMQ and receive responses on an exclusive + reply queue.""" + + timeout = 0.1 + + def __init__(self, address: str, port: int): + self.port = int(port) + self.address = address + + self._pending: Dict[str, PendingRequest] = {} + self._outbox: queue.Queue = queue.Queue() + self._ready = threading.Event() + self._connection = None + + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + self._ready.wait(timeout=10) + + def send(self, msg: Message) -> PendingRequest: + if self._connection is None or not self._ready.is_set(): + raise TransportTimeout( + f"not connected to AMQP broker at {_BROKER_HOST}:{_BROKER_PORT}" + ) + + pending = PendingRequest(msg) + self._outbox.put(pending) + self._connection.add_callback_threadsafe(self._flush_outbox) + + ack = pending.wait_ack(self.timeout) + if not ack: + raise TransportTimeout( + f"{msg.env.type} @ {self.address}:{self.port}: " + f"no ACK in {self.timeout:.2f} sec" + ) + return pending + + def _run(self) -> None: + self._connection = pika.BlockingConnection(_broker_params()) + self._channel = self._connection.channel() + + # Exclusive auto-delete reply queue + result = self._channel.queue_declare(queue="", exclusive=True) + self._reply_queue = result.method.queue + + self._channel.basic_consume( + queue=self._reply_queue, + on_message_callback=self._on_response, + auto_ack=True, + ) + + self._ready.set() + self._channel.start_consuming() + + def _on_response(self, _ch, _method, properties, body: bytes) -> None: + msg = unpack_frame(body) + self._handle_incoming(msg) + + def _flush_outbox(self) -> None: + """Drain all queued outgoing requests (called on the connection + thread via add_callback_threadsafe).""" + while True: + try: + pending: PendingRequest = self._outbox.get_nowait() + except queue.Empty: + break + + server_queue = _server_queue_name(self.address, self.port) + self._pending[pending.id] = pending + self._channel.basic_publish( + exchange="", + routing_key=server_queue, + properties=pika.BasicProperties( + reply_to=self._reply_queue, + correlation_id=pending.id, + ), + body=pack_frame(pending.req), + ) + + +class Server(RequestServer): + """Receive requests from a RabbitMQ queue, respond to them.""" + + port = None + hostname = None + + def __init__( + self, + address: Optional[str] = None, + port: Optional[int] = None, + avoid: Optional[set] = None, + ): + self.hostname = address or pysocket.getfqdn() + self.port = int(port) if port is not None else 0 + + self._queue_name = _server_queue_name(self.hostname, self.port) + + self._responses: queue.Queue = queue.Queue() + self._ready = threading.Event() + self._connection = None + self.shutdown = False + self.workers = concurrent.futures.ThreadPoolExecutor(max_workers=8) + + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + self._ready.wait(timeout=10) + + def send(self, response: Message) -> None: + self._responses.put(response) + self._connection.add_callback_threadsafe(self._flush_responses) + + def _run(self) -> None: + self._connection = pika.BlockingConnection(_broker_params()) + self._channel = self._connection.channel() + + self._channel.queue_declare(queue=self._queue_name, durable=False) + self._channel.basic_qos(prefetch_count=1) + self._channel.basic_consume( + queue=self._queue_name, + on_message_callback=self._on_request, + ) + + self._ready.set() + self._channel.start_consuming() + + def _on_request(self, ch, method, properties, body: bytes) -> None: + ch.basic_ack(delivery_tag=method.delivery_tag) + msg = unpack_frame(body) + msg.env.meta["_reply_to"] = properties.reply_to + msg.env.meta["_correlation_id"] = properties.correlation_id + self.workers.submit(self._req_incoming, msg) + + def _flush_responses(self) -> None: + """Drain all queued outgoing responses (called on the connection + thread via add_callback_threadsafe).""" + while True: + try: + response: Message = self._responses.get_nowait() + except queue.Empty: + break + + reply_to = response.env.meta.get("_reply_to", "") + correlation_id = response.env.meta.get( + "_correlation_id", response.env.transid + ) + self._channel.basic_publish( + exchange="", + routing_key=reply_to, + properties=pika.BasicProperties( + correlation_id=correlation_id, + ), + body=pack_frame(response), + ) + + +_client_cache: Dict[Tuple[str, int], Client] = {} +_client_lock = threading.Lock() + + +def client(address: str, port: int) -> Client: + key = (address, int(port)) + with _client_lock: + c = _client_cache.get(key) + if c is None: + c = Client(address, int(port)) + _client_cache[key] = c + return c + + +def send(address: str, port: int, message: Message) -> Message: + c = client(address, port) + pending = c.send(message) + response = pending.wait(timeout=60) + if response is None: + raise TransportTimeout("no response received") + return response + + +def _server_queue_name(address: str, port: int) -> str: + return f"mktl.req.{address}.{port}" + + +def _cleanup() -> None: + for c in _client_cache.values(): + try: + c._connection.close() + except Exception: + pass + + +atexit.register(_cleanup) diff --git a/src/mktl/transport/session.py b/src/mktl/transport/session.py new file mode 100644 index 00000000..85620f00 --- /dev/null +++ b/src/mktl/transport/session.py @@ -0,0 +1,178 @@ +"""Transport-agnostic session layer.""" + +from __future__ import annotations + +import sys +import threading +import traceback +from typing import Dict, Optional + +from ..protocol.message import Message, Envelope, MsgType +from .base import Transport, TransportTimeout + + +class PendingRequest: + """Client-side helper that provides ACK/REP synchronization.""" + + def __init__(self, msg: Message): + self.req = msg + self.response: Optional[Message] = None + self.ack_event = threading.Event() + self.rep_event = threading.Event() + + @property + def id(self) -> str: + return self.req.env.transid + + def wait_ack(self, timeout: Optional[float]) -> bool: + return self.ack_event.wait(timeout) + + def wait(self, timeout: Optional[float] = 60) -> Optional[Message]: + self.rep_event.wait(timeout) + return self.response + + def _complete_ack(self) -> None: + self.ack_event.set() + + def _complete(self, response: Message) -> None: + self.response = response + self.ack_event.set() + self.rep_event.set() + + +class RequestSession: + """Client-side request/response pattern logic.""" + + timeout = 0.1 + + def __init__(self, transport: Transport): + self.transport = transport + self._pending: Dict[str, PendingRequest] = {} + + def _handle_incoming(self, msg: Message) -> None: + """Correlate incoming ACK/REP to a PendingRequest.""" + pending = self._pending.get(msg.env.transid) + if pending is None: + return + + if msg.env.type == MsgType.ACK: + pending._complete_ack() + return + + # REP + pending._complete(msg) + self._pending.pop(msg.env.transid, None) + + def send(self, msg: Message) -> PendingRequest: + pending = PendingRequest(msg) + self._pending[pending.id] = pending + self.transport.send(msg) + + ack = pending.wait_ack(self.timeout) + if not ack: + self._pending.pop(pending.id, None) + raise TransportTimeout( + f"{msg.env.type}: no ACK in {self.timeout:.2f} sec" + ) + return pending + + +class RequestServer: + """Server-side request handler.""" + + node_id = "" + _on_receive = None + + def __init__(self, transport: Transport, node_id: str = ""): + self.transport = transport + self.node_id = node_id + + def req_handler(self, msg: Message) -> Optional[dict]: + """Override in subclasses. + + Return: + - dict -> will be wrapped into a REP + - None -> no immediate REP (handler is responsible) + """ + self.req_ack(msg) + return None + + def req_ack(self, msg: Message) -> None: + ack = Message( + env=Envelope( + type=MsgType.ACK, + sourceid=self.node_id, + transid=msg.env.transid, + destid=msg.env.sourceid, + key=msg.env.key, + meta=dict(msg.env.meta), + ), + ) + self.send(ack) + + def send(self, response: Message) -> None: + self.transport.send(response) + + def _req_incoming(self, msg: Message) -> None: + if self._on_receive is not None: + self._on_receive(msg) + return + + payload: Optional[dict] = None + error: Optional[dict] = None + + try: + payload = self.req_handler(msg) + except Exception: + e_class, e_instance, _tb = sys.exc_info() + error = { + "type": getattr(e_class, "__name__", "Exception"), + "text": str(e_instance), + "debug": traceback.format_exc(), + } + + if payload is None and error is None: + return + + rep_payload = payload if payload is not None else {} + if error is not None: + rep_payload["error"] = error + + rep = Message( + env=Envelope( + type=MsgType.REP, + sourceid=self.node_id, + transid=msg.env.transid, + destid=msg.env.sourceid, + key=msg.env.key, + payload=rep_payload, + meta=dict(msg.env.meta), + ), + ) + self.send(rep) + + +class PublishSession: + """Server-side publish pattern logic.""" + + port = None + + def __init__(self, transport: Transport): + self.transport = transport + + def send(self, msg: Message) -> None: + self.transport.send(msg) + + +class SubscribeSession: + """Client-side subscribe pattern logic.""" + + def __init__(self, transport: Transport): + self.transport = transport + + def subscribe(self, topic: str) -> None: + if hasattr(self.transport, 'subscribe'): + self.transport.subscribe(topic) + + def recv(self) -> Message: + return self.transport.recv() diff --git a/src/mktl/transport/zmq/__init__.py b/src/mktl/transport/zmq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mktl/transport/zmq/framing.py b/src/mktl/transport/zmq/framing.py deleted file mode 100644 index 24f28d6f..00000000 --- a/src/mktl/transport/zmq/framing.py +++ /dev/null @@ -1,118 +0,0 @@ -"""ZMQ multipart framing for protocol messages. - -Request/Response (DEALER<->ROUTER) - (optional routing prefix...), version, id, type, target, payload_json, bulk - -Publish (PUB/SUB) - topic_with_trailing_dot, version, payload_json, bulk -""" - -from __future__ import annotations - -from typing import Iterable, Optional, Sequence, Tuple - -from ...protocol import Message, Payload, PROTOCOL_VERSION -from .codec import encode_payload, decode_payload - - -_VERSION_BYTES = PROTOCOL_VERSION.encode() - - -def _as_bytes_id(msg_id: bytes) -> bytes: - # msg_id is bytes in the new protocol. Keep as-is. - return msg_id - - -def to_request_frames(msg: Message, *, include_prefix: bool = False) -> Tuple[bytes, ...]: - """Encode a protocol Message to ZMQ request/response multipart frames.""" - - prefix: Tuple[bytes, ...] = tuple(msg.meta.get("zmq_prefix", ())) - if prefix and not include_prefix: - prefix = () - - payload_bytes, bulk = encode_payload(msg.payload) - target = (msg.target or "").encode() - parts = ( - _VERSION_BYTES, - _as_bytes_id(msg.msg_id), - msg.msg_type.encode(), - target, - payload_bytes, - bulk, - ) - return prefix + parts - - -def from_request_frames(parts: Sequence[bytes]) -> Message: - """Decode ROUTER/DEALER parts into a protocol Message. - - If a ROUTER identity prefix is present, it is stored as msg.meta['zmq_prefix']. - """ - - if not parts: - raise ValueError("empty message") - - # ROUTER sockets prepend identity frames. We expect either: - # [version, id, type, target, payload, bulk] - # or - # [ident, version, id, type, target, payload, bulk] - if parts[0] == _VERSION_BYTES: - prefix: Tuple[bytes, ...] = () - start = 0 - else: - prefix = (parts[0],) - start = 1 - - their_version = parts[start] - if their_version != _VERSION_BYTES: - # Version mismatch: represent as an error payload. - err = { - "type": "RuntimeError", - "text": f"message is mKTL protocol {their_version!r}, recipient expects {_VERSION_BYTES!r}", - } - payload = Payload(value=None, error=err) - msg = Message(msg_type="REP", target="???", payload=payload, msg_id=parts[start + 1]) - if prefix: - msg.meta["zmq_prefix"] = prefix - return msg - - msg_id = parts[start + 1] - msg_type = parts[start + 2].decode() - target = parts[start + 3].decode() if parts[start + 3] not in (b"", None) else None - payload_bytes = parts[start + 4] - bulk_bytes = parts[start + 5] if len(parts) > start + 5 else b"" - - payload = decode_payload(payload_bytes, bulk_bytes) - msg = Message(msg_type=msg_type, target=target, payload=payload, msg_id=msg_id) - if prefix: - msg.meta["zmq_prefix"] = prefix - return msg - - -def to_pub_frames(msg: Message) -> Tuple[bytes, ...]: - """Encode a publish message for PUB/SUB sockets.""" - - topic = (msg.target or "") + "." # trailing dot to prevent prefix matches - topic_b = topic.encode() - payload_bytes, bulk = encode_payload(msg.payload) - return (topic_b, _VERSION_BYTES, payload_bytes, bulk) - - -def from_pub_frames(parts: Sequence[bytes]) -> Message: - if len(parts) < 4: - raise ValueError("invalid PUB message") - - topic = parts[0].decode() - if topic.endswith("."): - topic = topic[:-1] - their_version = parts[1] - if their_version != _VERSION_BYTES: - err = { - "type": "RuntimeError", - "text": f"message is mKTL protocol {their_version!r}, recipient expects {_VERSION_BYTES!r}", - } - payload = Payload(value=None, error=err) - return Message(msg_type="PUB", target=topic, payload=payload) - - payload = decode_payload(parts[2], parts[3]) - return Message(msg_type="PUB", target=topic, payload=payload) diff --git a/src/mktl/transport/zmq/message.py b/src/mktl/transport/zmq/message.py deleted file mode 100644 index 0d07c40d..00000000 --- a/src/mktl/transport/zmq/message.py +++ /dev/null @@ -1,410 +0,0 @@ -""" A class representation of an mKTL message, including subclasses for - specific messages. -""" - -import itertools -import threading -import time as timemodule - -from ... import json - - -# This is the version of the mKTL on-the-wire protocol implemented here. -# See the protocol specification for a full description; the version is -# identified by a single byte. - -version = b'a' - - -class Message: - """ The :class:`Message` provides a very thin encapsulation of what it - means to be a message in an mKTL context. This class will be used - to represent mKTL messages that do not result in a response. - - The fields are largely in order of how they are represented on the - wire: the message *type*, the key/store *target* for the request, - the *payload* of the message (a :class:`Payload` instance), - and an identification - number unique to this correspondence. The identification number is - the one field that is out-of-order compared to the multipart sequence - on the wire; this is because some message types (publish messages, - in particular) do not have an identification number, and it is - automatically generated for request messages. Rather than force the - caller to pass an explicit None, the id is left as the last field, - so that the arguments for all :class:`Message` instances can have - a similar structure. - - :ivar payload: The :class:`Payload`, if any, for the message. - :ivar valid_types: A set of valid strings for the message type. - :ivar timestamp: A UNIX epoch timestamp for the message send time. - """ - - valid_types = set(('ACK', 'REP')) - - def __init__(self, type, target=None, payload=None, id=None): - - if type in self.valid_types: - pass - else: - raise ValueError('invalid request type: ' + type) - - # There are some message types where the id is allowed to be None; - # for example, publish messages do not have or need an identification - # number or a prefix. - - self.id = id - self.type = type - self.payload = payload - self.prefix = None - self.target = target - self.timestamp = timemodule.time() - - self._parts = None - - - def __iter__(self): - self._finalize() - return iter(self._parts) - - - def __repr__(self): - self._finalize() - return repr(self._parts) - - - def _finalize(self): - """ Take the contents of this :class:`Message`, interpet them as - bytes, and prepare the tuple that will be used for the multipart - transmission on the wire. - """ - - if self._parts: - # Once finalized, always finalized. - return - - id = self.id - type = self.type - target = self.target - payload = self.payload - - # It is legal to create a Message with None as the id-- this happens - # all the time when a Message is used as a container-- but trying to - # send such a message is not permitted. - - if id is None: - raise RuntimeError('messages must have an id to be put on the wire') - - try: - id.decode - except AttributeError: - id = '%08x' % (id) - id = id.encode() - - type = type.encode() - - if target == None or target == '': - target = b'' - else: - try: - target = target.encode() - except AttributeError: - # Assume it is already bytes. - pass - - if payload is None or payload == '': - bulk = b'' - payload = b'' - else: - bulk = payload.bulk - if bulk is None: - bulk = b'' - payload = payload.encapsulate() - - if self.prefix: - parts = self.prefix + (version, id, type, target, payload, bulk) - else: - parts = (version, id, type, target, payload, bulk) - - self._parts = parts - - -# end of class Message - - - -class Broadcast(Message): - """ A :class:`Broadcast` is a minor variant of a :class:`Message`, - with a change to format the multipart tuple in a PUB/SUB specific - fashion. - """ - - valid_types = set(('PUB',)) - - def _finalize(self): - - if self._parts: - # Once finalized, always finalized. - return - - target = self.target - payload = self.payload - - # The PUB/SUB topic has a trailing dot to prevent leading - # substring matches from picking up extra keys. - - target = target + '.' - target = target.encode() - - if payload is None or payload == '': - bulk = b'' - payload = b'' - else: - bulk = payload.bulk - if bulk is None: - bulk = b'' - payload = payload.encapsulate() - - # The prefix is ignored for broadcast messages; it should not be set. - self._parts = (target, version, payload, bulk) - - -# end of class Broadcast - - - -class Request(Message): - """ A :class:`Request` has a little extra functionality, focusing on - local caching of response values and signaling that a request is - complete. This is the class that will be used on the client side - when a server is expected to provide a response, such as returning - a requested value, or signaling that a set operation is complete. - - :ivar response: The response (as a :class:`Message`) to this request - """ - - valid_types = set(('CONFIG', 'GET', 'HASH', 'SET')) - - def __init__(self, type, target=None, payload=None, id=None): - - # Requests are generally initiated without an id number, but they're - # required to have one. The expectation is that requests will have an - # id number that is locally unique, so that the request/response - # handler can correctly tie an incoming response to the request that - # generated it. - - # Long story short, for nearly all Request instances the id argument - # will be None, and we are expected to auto-generate a locally unique - # identification number. - - if id is None: - id = _id_next() - - Message.__init__(self, type, target, payload, id) - - self.response = None - - self.ack_event = threading.Event() - self.rep_event = threading.Event() - - - def __repr__(self): - self._finalize() - request = 'REQ: ' + repr(self._parts) - - if self.response is None: - response = 'REP: None' - else: - response = 'REP: ' + repr(tuple(self.response)) - - return request + ', ' + response - - - def _complete_ack(self): - """ The request, if any, has been acknowledged; signal any callers blocking via :func:`wait_ack` to proceed. - """ - - self.ack_event.set() - - - def _complete(self, response): - """ Locally store the response and signal any callers blocking via - :func:`wait` to proceed. - """ - - self.response = response - self.ack_event.set() - self.rep_event.set() - - - def poll(self): - """ Return True if the request is complete, otherwise return False. - """ - - return self.rep_event.is_set() - - - def wait_ack(self, timeout): - """ Block until the request has been acknowledged. This is a wrapper to - a :class:`threading.Event` instance; if the event has occurred it - will return True, otherwise it returns False after the requested - *timeout*. If the *timeout* argument is None it will block - indefinitely. - """ - - return self.ack_event.wait(timeout) - - - def wait(self, timeout=60): - """ Block until the request has been handled. The response to the - request is always returned; the response will be None if the - original request is still pending. - """ - - self.rep_event.wait(timeout) - return self.response - - -# end of class Request - - - -class Payload: - """ This is a lightweight class to properly encapsulate a Python-native - value for later inclusion in a :class:`Message` instance. All attributes - of this class, except for the :attr:`bulk` attribute, or any attrbiutes - set to None, will be added to as a JSON dictionary via - :func:`encapsulate`, which is called when a :class:`Message` instance - finalizes itself before generating its final on-the-wire representation. - - No interpretation of the :class:`Payload` contents is performed, any - interpretation (such as converting a numpy array to a form suitable - for enapsulation) must occur before the :class:`Payload` is - instantiated. - - :ivar bulk: A bulk data value, in bytes, or None - :ivar omit: A set of fields to omit from encapsulation - """ - - omit = set(('bulk', 'omit')) - - def __init__(self, value, time=None, error=None, bulk=None, shape=None, dtype=None, refresh=None, **kwargs): - """ Arbitrary keyword arguments are allowed when creating a - :class:`Payload` instance, beyond the canonical set; when - included, these additional keyword arguments will be assigned - directly as attributes for later encapsulation. Any values - assigned in this fashion must be serializable as JSON. - """ - - # The use of 'time' as a keyword argument is what's motivating the - # weird import of the time module in this file. We want the keyword - # arguments to be aligned with the fields in the JSON description - # of a payload: value, time, error, etc. - - if time is None: - time = timemodule.time() - - if refresh is False: - refresh = None - - # We expect the canonical arguments to all be set all the time, - # even if their value is None. That's why they're not rolled up - # into the kwargs catch-all. - - self.bulk = bulk - self.dtype = dtype - self.error = error - self.refresh = refresh - self.shape = shape - self.time = time - self.value = value - - if not kwargs: - # This is the average case. Faster to check this one condition - # and return than to drop out of the next two conditions. - return - - if 'omit' in kwargs: - raise ValueError("cannot assign 'omit' to a Payload") - - # Allow additional arbitrary fields in the payload. We are assuming - # the caller knows what they are doing, and that these additional - # fields can be serialized as JSON. - - for key,value in kwargs.items(): - setattr(self, key, value) - - - def __repr__(self): - return self.encapsulate().decode() - - - def encapsulate(self): - """ Add all non-omitted local attributes to a dictionary, and return - the JSON encoding of that dictionary. For example, if the .value - and .time attributes of a :class:`Payload` are assigned, the caller - will receive a value like:: - - b'{"value": 12, "time": 1761100609.234571}' - """ - - # The output from this method was initially cached, but there's never - # a situation where this method is called twice for a given Payload, - # so the caching was removed. - - payload = dict() - - # All local attributes get put into the encapsulated payload, - # except for those included in the omit set. - - for key,value in vars(self).items(): - if key in self.omit: - continue - - # Do not include attributes that are just 'None'. This may be - # premature optimization, but it seems silly to put a bunch of - # extra bytes on the wire when it conveys no additional information. - - # It's faster to check the key against 'value' repeatedly than - # to build a separate set of includes and only assign those - # key/value pairs to the payload. - - if value is not None or key == 'value': - payload[key] = value - - payload = json.dumps(payload) - return payload - - -# end of class Payload - - -_id_min = 0 -_id_max = 0xFFFFFFFF -_id_lock = threading.Lock() -_id_ticker = itertools.count(_id_min) - - -def _id_next(): - """ Return the next request identification number for subroutines to - use when constructing a message. - """ - - global _id_ticker - _id_lock.acquire() - id = next(_id_ticker) - - if id >= _id_max: - _id_ticker = itertools.count(_id_min) - - if id > _id_max: - # This shouldn't happen, but here we are... - id = next(_id_ticker) - - _id_lock.release() - - id = '%08x' % (id) - id = id.encode() - return id - - -# vim: set expandtab tabstop=8 softtabstop=4 shiftwidth=4 autoindent: diff --git a/src/mktl/transport/zmq/publish.py b/src/mktl/transport/zmq/publish.py index e806b152..0f988042 100644 --- a/src/mktl/transport/zmq/publish.py +++ b/src/mktl/transport/zmq/publish.py @@ -3,7 +3,6 @@ from __future__ import annotations import atexit -import itertools import queue import threading import traceback @@ -11,15 +10,17 @@ import zmq -from ...protocol import Message, Publish -from .framing import from_pub_frames, to_pub_frames +from ...protocol.message import Message +from ...protocol.wire import pack_frame, unpack_frame +from ...transport import TransportPortError +from ..session import PublishSession, SubscribeSession minimum_port = 10139 maximum_port = 13679 zmq_context = zmq.Context() -class Client: +class Client(SubscribeSession): """SUB client.""" def __init__(self, address: str, port: int): @@ -39,11 +40,11 @@ def subscribe(self, topic: str) -> None: self.socket.setsockopt(zmq.SUBSCRIBE, (topic + ".").encode()) def recv(self) -> Message: - parts = self.socket.recv_multipart() - return from_pub_frames(parts) + _topic, wire_bytes = self.socket.recv_multipart() + return unpack_frame(wire_bytes) -class Server: +class Server(PublishSession): """PUB server.""" def __init__(self, port: Optional[int] = None, avoid: Optional[set] = None): @@ -63,9 +64,16 @@ def __init__(self, port: Optional[int] = None, avoid: Optional[set] = None): except zmq.ZMQError: continue if self.port is None: - raise zmq.ZMQError("no PUB port available") + raise TransportPortError( + f"no ports available in range {minimum_port}:{maximum_port}" + ) else: - self.socket.bind(f"tcp://*:{self.port}") + try: + self.socket.bind(f"tcp://*:{self.port}") + except zmq.ZMQError as exc: + raise TransportPortError( + f"port already in use: {self.port}" + ) from exc # Internal queue for thread-safe sends try: @@ -83,15 +91,15 @@ def __init__(self, port: Optional[int] = None, avoid: Optional[set] = None): self.thread = threading.Thread(target=self.run, daemon=True) self.thread.start() - def send(self, msg: Publish) -> None: + def send(self, msg: Message) -> None: self._queue.put(msg) self._sig_tx.send(b"") def _send_one(self) -> None: self._sig_rx.recv(flags=zmq.NOBLOCK) msg = self._queue.get(block=False) - frames = to_pub_frames(msg) - self.socket.send_multipart(frames) + topic = ((msg.env.key or "") + ".").encode() + self.socket.send_multipart([topic, pack_frame(msg)]) def run(self) -> None: poller = zmq.Poller() diff --git a/src/mktl/transport/zmq/request.py b/src/mktl/transport/zmq/request.py index edf8b874..1f6dc215 100644 --- a/src/mktl/transport/zmq/request.py +++ b/src/mktl/transport/zmq/request.py @@ -13,20 +13,17 @@ import atexit import concurrent.futures -import itertools import queue import socket as pysocket -import sys import threading -import time -import traceback from typing import Dict, Optional, Tuple import zmq -from ...protocol import Message, Payload, Request -from ...protocol.fields import ACK, REP -from .framing import from_request_frames, to_request_frames +from ...protocol.message import Message +from ...protocol.wire import pack_frame, unpack_frame +from ...transport import TransportTimeout, TransportPortError +from ..session import RequestSession, RequestServer, PendingRequest minimum_port = 10079 @@ -34,36 +31,7 @@ zmq_context = zmq.Context() -class PendingRequest: - """Client-side helper that provides ACK/REP synchronization.""" - - def __init__(self, req: Request): - self.req = req - self.response: Optional[Message] = None - self.ack_event = threading.Event() - self.rep_event = threading.Event() - - @property - def id(self) -> bytes: - return self.req.msg_id - - def wait_ack(self, timeout: Optional[float]) -> bool: - return self.ack_event.wait(timeout) - - def wait(self, timeout: Optional[float] = 60) -> Optional[Message]: - self.rep_event.wait(timeout) - return self.response - - def _complete_ack(self) -> None: - self.ack_event.set() - - def _complete(self, response: Message) -> None: - self.response = response - self.ack_event.set() - self.rep_event.set() - - -class Client: +class Client(RequestSession): """Issue requests via a ZeroMQ DEALER socket and receive responses.""" timeout = 0.1 @@ -91,32 +59,17 @@ def __init__(self, address: str, port: int): self._signal_tx = zmq_context.socket(zmq.PAIR) self._signal_tx.connect(internal) - self._pending: Dict[bytes, PendingRequest] = {} + self._pending: Dict[str, PendingRequest] = {} self._thread = threading.Thread(target=self.run, daemon=True) self._thread.start() - def _handle_incoming(self, parts: Tuple[bytes, ...]) -> None: - msg = from_request_frames(parts) - pending = self._pending.get(msg.msg_id) - if pending is None: - return - - if msg.msg_type == ACK: - pending._complete_ack() - return - - # REP (or error REP on version mismatch) - pending._complete(msg) - self._pending.pop(msg.msg_id, None) - def _handle_outgoing(self) -> None: # Clear one signal and send one request. self._signal_rx.recv(flags=zmq.NOBLOCK) pending: PendingRequest = self._outbox.get(block=False) - frames = to_request_frames(pending.req) self._pending[pending.id] = pending - self.socket.send_multipart(frames) + self.socket.send(pack_frame(pending.req)) def run(self) -> None: poller = zmq.Poller() @@ -128,23 +81,23 @@ def run(self) -> None: if active == self._signal_rx: self._handle_outgoing() elif active == self.socket: - parts = tuple(self.socket.recv_multipart()) - self._handle_incoming(parts) + msg = unpack_frame(self.socket.recv()) + self._handle_incoming(msg) - def send(self, request: Request) -> PendingRequest: - pending = PendingRequest(request) + def send(self, msg: Message) -> PendingRequest: + pending = PendingRequest(msg) self._outbox.put(pending) self._signal_tx.send(b"") ack = pending.wait_ack(self.timeout) if not ack: - raise zmq.ZMQError( - f"{request.msg_type} @ {self.address}:{self.port}: no ACK in {self.timeout:.2f} sec" + raise TransportTimeout( + f"{msg.env.type} @ {self.address}:{self.port}: no ACK in {self.timeout:.2f} sec" ) return pending -class Server: +class Server(RequestServer): """Receive requests via a ZeroMQ ROUTER socket, respond to them.""" port = None # auto @@ -160,7 +113,12 @@ def __init__(self, address: Optional[str] = None, port: Optional[int] = None, av if self.port is None: self.port = self._bind_any() else: - self.socket.bind(f"tcp://{self.address}:{self.port}") + try: + self.socket.bind(f"tcp://{self.address}:{self.port}") + except zmq.ZMQError as exc: + raise TransportPortError( + f"port already in use: {self.port}" + ) from exc # Response queue for thread-safe sending try: @@ -188,73 +146,19 @@ def _bind_any(self) -> int: return port except zmq.ZMQError: continue - raise RuntimeError("no available ports") - - # --- request handling hooks --- - def req_handler(self, request: Request) -> Optional[Payload]: - """Override in subclasses. - - Return: - - Payload -> will be wrapped into a REP - - None -> no immediate REP (handler is responsible for later response) - """ - - # Default: no-op - self.req_ack(request) - return None - - def req_ack(self, request: Request) -> None: - ack = Message(msg_type=ACK, target=request.target, msg_id=request.msg_id) - ack.meta["zmq_prefix"] = request.meta.get("zmq_prefix", ()) - self.send(ack) + raise TransportPortError( + f"no ports available in range {minimum_port}:{maximum_port}" + ) def send(self, response: Message) -> None: self._responses.put(response) self._signal_tx.send(b"") - # --- internal --- def _rep_outgoing(self) -> None: self._signal_rx.recv(flags=zmq.NOBLOCK) response: Message = self._responses.get(block=False) - frames = to_request_frames(response, include_prefix=True) - self.socket.send_multipart(frames) - - def _req_incoming(self, parts: Tuple[bytes, ...]) -> None: - msg = from_request_frames(parts) - - # Convert generic Message to typed Request (validates op) - try: - req = Request(msg_type=msg.msg_type, target=msg.target, payload=msg.payload, msg_id=msg.msg_id) - except Exception: - # If invalid, treat as opaque request - req = Request(msg_type=msg.msg_type, target=msg.target, payload=msg.payload, msg_id=msg.msg_id) - - req.meta.update(msg.meta) - - payload: Optional[Payload] = None - error: Optional[dict] = None - - try: - payload = self.req_handler(req) - except Exception: - e_class, e_instance, _tb = sys.exc_info() - error = { - "type": getattr(e_class, "__name__", "Exception"), - "text": str(e_instance), - "debug": traceback.format_exc(), - } - - if payload is None and error is None: - return - - if payload is None: - payload = Payload(value=None) - if error is not None: - payload.error = error - - rep = Message(msg_type=REP, target=req.target, payload=payload, msg_id=req.msg_id) - rep.meta["zmq_prefix"] = req.meta.get("zmq_prefix", ()) - self.send(rep) + identity = response.env.meta.get("zmq_prefix", (b"",))[0] + self.socket.send_multipart([identity, pack_frame(response)]) def run(self) -> None: poller = zmq.Poller() @@ -266,11 +170,11 @@ def run(self) -> None: if active == self._signal_rx: self._rep_outgoing() elif active == self.socket: - parts = tuple(self.socket.recv_multipart()) - self.workers.submit(self._req_incoming, parts) - + identity, wire_bytes = self.socket.recv_multipart() + msg = unpack_frame(wire_bytes) + msg.env.meta["zmq_prefix"] = (identity,) + self.workers.submit(self._req_incoming, msg) -# --- convenience helpers (API-compatible-ish) --- _client_cache: Dict[Tuple[str, int], Client] = {} _client_lock = threading.Lock() @@ -286,15 +190,13 @@ def client(address: str, port: int) -> Client: return c -def send(address: str, port: int, message: Request) -> Payload: +def send(address: str, port: int, message: Message) -> Message: c = client(address, port) pending = c.send(message) response = pending.wait(timeout=60) if response is None: - raise zmq.ZMQError("no response received") - if response.payload is None: - return Payload(value=None) - return response.payload + raise TransportTimeout("no response received") + return response def _cleanup() -> None: