Skip to content

Commit e48ecbc

Browse files
authored
Merge pull request #235 from PeterJCLaw/async-iterator
FakeAIOKafkaConsumer async iterator
2 parents aa501c0 + 94f761c commit e48ecbc

File tree

2 files changed

+73
-10
lines changed

2 files changed

+73
-10
lines changed

mockafka/aiokafka/aiokafka_consumer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,24 @@ async def getmany(
244244

245245
return dict(result)
246246

247+
def __aiter__(self):
248+
if self._is_closed:
249+
raise ConsumerStoppedError()
250+
return self
251+
252+
async def __anext__(self) -> ConsumerRecord[bytes, bytes]:
253+
while True:
254+
try:
255+
result = await self.getone()
256+
if result is None:
257+
# Follow the lead of `getone`, though note that we should
258+
# address this as part of any fix to
259+
# https://github.com/alm0ra/mockafka-py/issues/117
260+
raise StopAsyncIteration
261+
return result
262+
except ConsumerStoppedError:
263+
raise StopAsyncIteration from None
264+
247265
async def __aenter__(self) -> Self:
248266
await self.start()
249267
return self

tests/test_aiokafka/test_aiokafka_consumer.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import sys
34
import itertools
45
from unittest import IsolatedAsyncioTestCase
56

@@ -17,6 +18,13 @@
1718
)
1819
from mockafka.kafka_store import KafkaStore
1920

21+
if sys.version_info < (3, 10):
22+
def aiter(async_iterable): # noqa: A001
23+
return async_iterable.__aiter__()
24+
25+
async def anext(async_iterable): # noqa: A001
26+
return await async_iterable.__anext__()
27+
2028

2129
@pytest.mark.asyncio
2230
class TestAIOKAFKAFakeConsumer(IsolatedAsyncioTestCase):
@@ -40,7 +48,7 @@ def topic(self):
4048
def create_topic(self):
4149
self.kafka.create_partition(topic=self.test_topic, partitions=16)
4250

43-
async def produce_message(self):
51+
async def produce_two_messages(self):
4452
await self.producer.send(
4553
topic=self.test_topic, partition=0, key=b"test", value=b"test"
4654
)
@@ -51,6 +59,40 @@ async def produce_message(self):
5159
async def test_consume(self):
5260
await self.test_poll_with_commit()
5361

62+
async def test_async_iterator(self):
63+
self.create_topic()
64+
await self.produce_two_messages()
65+
self.consumer.subscribe(topics=[self.test_topic])
66+
await self.consumer.start()
67+
68+
iterator = aiter(self.consumer)
69+
message = await anext(iterator)
70+
self.assertEqual(message.value, b"test")
71+
72+
message = await anext(iterator)
73+
self.assertEqual(message.value, b"test1")
74+
75+
# Technically at this point aiokafka's consumer would block
76+
# indefinitely, however since that's not useful in tests we instead stop
77+
# iterating.
78+
with pytest.raises(StopAsyncIteration):
79+
await anext(iterator)
80+
81+
async def test_async_iterator_closed_early(self):
82+
self.create_topic()
83+
await self.produce_two_messages()
84+
self.consumer.subscribe(topics=[self.test_topic])
85+
await self.consumer.start()
86+
87+
iterator = aiter(self.consumer)
88+
message = await anext(iterator)
89+
self.assertEqual(message.value, b"test")
90+
91+
await self.consumer.stop()
92+
93+
with pytest.raises(StopAsyncIteration):
94+
await anext(iterator)
95+
5496
async def test_start(self):
5597
# check consumer store is empty
5698
await self.consumer.start()
@@ -69,7 +111,7 @@ async def test_start(self):
69111

70112
async def test_poll_without_commit(self):
71113
self.create_topic()
72-
await self.produce_message()
114+
await self.produce_two_messages()
73115
self.consumer.subscribe(topics=[self.test_topic])
74116
await self.consumer.start()
75117

@@ -83,7 +125,7 @@ async def test_poll_without_commit(self):
83125

84126
async def test_partition_specific_poll_without_commit(self):
85127
self.create_topic()
86-
await self.produce_message()
128+
await self.produce_two_messages()
87129
self.consumer.subscribe(topics=[self.test_topic])
88130
await self.consumer.start()
89131

@@ -99,7 +141,7 @@ async def test_partition_specific_poll_without_commit(self):
99141

100142
async def test_poll_with_commit(self):
101143
self.create_topic()
102-
await self.produce_message()
144+
await self.produce_two_messages()
103145
self.consumer.subscribe(topics=[self.test_topic])
104146
await self.consumer.start()
105147

@@ -116,7 +158,7 @@ async def test_poll_with_commit(self):
116158

117159
async def test_getmany_without_commit(self):
118160
self.create_topic()
119-
await self.produce_message()
161+
await self.produce_two_messages()
120162
await self.producer.send(
121163
topic=self.test_topic, partition=2, key=b"test2", value=b"test2"
122164
)
@@ -145,7 +187,7 @@ async def test_getmany_without_commit(self):
145187

146188
async def test_getmany_with_limit_without_commit(self):
147189
self.create_topic()
148-
await self.produce_message()
190+
await self.produce_two_messages()
149191
await self.producer.send(
150192
topic=self.test_topic, partition=0, key=b"test2", value=b"test2"
151193
)
@@ -182,7 +224,7 @@ async def test_getmany_with_limit_without_commit(self):
182224

183225
async def test_getmany_specific_poll_without_commit(self):
184226
self.create_topic()
185-
await self.produce_message()
227+
await self.produce_two_messages()
186228
await self.producer.send(
187229
topic=self.test_topic, partition=1, key=b"test2", value=b"test2"
188230
)
@@ -210,7 +252,7 @@ async def test_getmany_specific_poll_without_commit(self):
210252

211253
async def test_getmany_with_commit(self):
212254
self.create_topic()
213-
await self.produce_message()
255+
await self.produce_two_messages()
214256
await self.producer.send(
215257
topic=self.test_topic, partition=2, key=b"test2", value=b"test2"
216258
)
@@ -287,7 +329,7 @@ async def test_lifecycle(self):
287329

288330
self.assertEqual(self.consumer.subscribed_topic, topics)
289331

290-
await self.produce_message()
332+
await self.produce_two_messages()
291333

292334
messages = {
293335
tp: self.summarise(msgs)
@@ -336,7 +378,7 @@ async def test_context_manager(self):
336378

337379
async with self.consumer as consumer:
338380
self.assertEqual(self.consumer, consumer)
339-
await self.produce_message()
381+
await self.produce_two_messages()
340382

341383
messages = {
342384
tp: self.summarise(msgs)
@@ -373,3 +415,6 @@ async def test_consumer_is_stopped(self):
373415
self.consumer.subscribe(topics=topics)
374416
with self.assertRaises(ConsumerStoppedError):
375417
await self.consumer.getone()
418+
419+
with self.assertRaises(ConsumerStoppedError):
420+
aiter(self.consumer)

0 commit comments

Comments
 (0)