From 0630c0f760666100494cb4af2d3714db9b74b38e Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Tue, 8 Jul 2025 16:49:31 +0800 Subject: [PATCH] Support message router --- pulsar/__init__.py | 13 ++++++++++++- src/config.cc | 19 ++++++++++++++++++- tests/pulsar_test.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/pulsar/__init__.py b/pulsar/__init__.py index 8802493..7dda4cd 100644 --- a/pulsar/__init__.py +++ b/pulsar/__init__.py @@ -43,7 +43,7 @@ """ import logging -from typing import List, Tuple, Optional, Union +from typing import Callable, List, Tuple, Optional, Union import _pulsar @@ -54,6 +54,7 @@ from pulsar.__about__ import __version__ from pulsar.exceptions import * +from pulsar.schema.schema import BytesSchema from pulsar.tableview import TableView from pulsar.functions.function import Function @@ -246,6 +247,7 @@ def schema_version(self): @staticmethod def _wrap(_message): self = Message() + self._schema = BytesSchema() self._message = _message return self @@ -696,6 +698,7 @@ def create_producer(self, topic, encryption_key=None, crypto_key_reader: Union[None, CryptoKeyReader] = None, access_mode: ProducerAccessMode = ProducerAccessMode.Shared, + message_router: Callable[[Message, int], int]=None, ): """ Create a new producer on a given topic. @@ -811,6 +814,10 @@ def create_producer(self, topic, * WaitForExclusive: Producer creation is pending until it can acquire exclusive access. * ExclusiveWithFencing: Acquire exclusive access for the producer. Any existing producer will be removed and invalidated immediately. + message_router: optional + A custom message router function that takes a `Message` and the number of partitions + and returns the partition index to which the message should be routed. If not provided, + the default routing policy defined by `message_routing_mode` will be used. """ _check_type(str, topic, 'topic') _check_type_or_none(str, producer_name, 'producer_name') @@ -848,6 +855,10 @@ def create_producer(self, topic, conf.chunking_enabled(chunking_enabled) conf.lazy_start_partitioned_producers(lazy_start_partitioned_producers) conf.access_mode(access_mode) + if message_router is not None: + underlying_router = lambda msg, num_partitions: int(message_router(Message._wrap(msg), num_partitions)) + conf.message_router(underlying_router) + if producer_name: conf.producer_name(producer_name) if initial_sequence_id: diff --git a/src/config.cc b/src/config.cc index 06822b4..a83c0c0 100644 --- a/src/config.cc +++ b/src/config.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include namespace py = pybind11; @@ -104,6 +105,19 @@ class HIDDEN LoggerWrapperFactory : public LoggerFactory, public CaptivePythonOb } }; +using MessageRouterFunc = std::function; +class HIDDEN MessageRouter : public pulsar::MessageRoutingPolicy { + public: + explicit MessageRouter(MessageRouterFunc func) : func_(std::move(func)) {} + + int getPartition(const Message& msg, const TopicMetadata& topicMetadata) final { + return func_(msg, topicMetadata.getNumPartitions()); + } + + private: + MessageRouterFunc func_; +}; + static ClientConfiguration& ClientConfiguration_setLogger(ClientConfiguration& conf, py::object logger) { conf.setLogger(new LoggerWrapperFactory(logger)); return conf; @@ -235,7 +249,10 @@ void export_config(py::module_& m) { .def("encryption_key", &ProducerConfiguration::addEncryptionKey, return_value_policy::reference) .def("crypto_key_reader", &ProducerConfiguration::setCryptoKeyReader, return_value_policy::reference) .def("access_mode", &ProducerConfiguration::setAccessMode, return_value_policy::reference) - .def("access_mode", &ProducerConfiguration::getAccessMode, return_value_policy::copy); + .def("access_mode", &ProducerConfiguration::getAccessMode, return_value_policy::copy) + .def("message_router", [](ProducerConfiguration& config, MessageRouterFunc func) { + config.setMessageRouter(std::make_shared(std::move(func))); + }); class_(m, "BatchReceivePolicy") .def(init()) diff --git a/tests/pulsar_test.py b/tests/pulsar_test.py index 4e1c5fb..d24faf3 100755 --- a/tests/pulsar_test.py +++ b/tests/pulsar_test.py @@ -2019,5 +2019,37 @@ def test_deserialize_msg_id_with_topic(self): self.assertEqual(msg.value(), b'msg-3') client.close() + def test_message_router(self): + topic_name = "public/default/test_message_router" + str(time.time()) + url1 = self.adminUrl + "/admin/v2/persistent/" + topic_name + "/partitions" + doHttpPut(url1, "5") + client = Client(self.serviceUrl) + def router(msg: pulsar.Message, num_partitions: int): + s = msg.value().decode('utf-8') + if s.startswith("hello-"): + return 10 % num_partitions + else: + return 11 % num_partitions + producer = client.create_producer(topic_name, message_router=router) + producer.send(b"hello-0") + producer.send(b"hello-1") + producer.send(b"world-0") + producer.send(b"world-1") + consumer = client.subscribe(topic_name, 'sub', + initial_position=InitialPosition.Earliest) + partition_to_values = dict() + for _ in range(4): + msg = consumer.receive(TM) + partition = msg.message_id().partition() + if partition in partition_to_values: + partition_to_values[partition].append(msg.value().decode('utf-8')) + else: + partition_to_values[partition] = [msg.value().decode('utf-8')] + self.assertEqual(partition_to_values[0], ["hello-0", "hello-1"]) + self.assertEqual(partition_to_values[1], ["world-0", "world-1"]) + + client.close() + + if __name__ == "__main__": main()