11from __future__ import annotations
22
3+ import sys
34import itertools
45from unittest import IsolatedAsyncioTestCase
56
1718)
1819from mockafka .kafka_store import KafkaStore
1920
21+ if sys .version_info < (3 , 10 ):
22+ def aiter (async_iterable ): # noqa: A001
23+ return async_iterable .__aiter__ ()
24+
25+ async def anext (async_iterable ): # noqa: A001
26+ return await async_iterable .__anext__ ()
27+
2028
2129@pytest .mark .asyncio
2230class TestAIOKAFKAFakeConsumer (IsolatedAsyncioTestCase ):
@@ -40,7 +48,7 @@ def topic(self):
4048 def create_topic (self ):
4149 self .kafka .create_partition (topic = self .test_topic , partitions = 16 )
4250
43- async def produce_message (self ):
51+ async def produce_two_messages (self ):
4452 await self .producer .send (
4553 topic = self .test_topic , partition = 0 , key = b"test" , value = b"test"
4654 )
@@ -51,6 +59,40 @@ async def produce_message(self):
5159 async def test_consume (self ):
5260 await self .test_poll_with_commit ()
5361
62+ async def test_async_iterator (self ):
63+ self .create_topic ()
64+ await self .produce_two_messages ()
65+ self .consumer .subscribe (topics = [self .test_topic ])
66+ await self .consumer .start ()
67+
68+ iterator = aiter (self .consumer )
69+ message = await anext (iterator )
70+ self .assertEqual (message .value , b"test" )
71+
72+ message = await anext (iterator )
73+ self .assertEqual (message .value , b"test1" )
74+
75+ # Technically at this point aiokafka's consumer would block
76+ # indefinitely, however since that's not useful in tests we instead stop
77+ # iterating.
78+ with pytest .raises (StopAsyncIteration ):
79+ await anext (iterator )
80+
81+ async def test_async_iterator_closed_early (self ):
82+ self .create_topic ()
83+ await self .produce_two_messages ()
84+ self .consumer .subscribe (topics = [self .test_topic ])
85+ await self .consumer .start ()
86+
87+ iterator = aiter (self .consumer )
88+ message = await anext (iterator )
89+ self .assertEqual (message .value , b"test" )
90+
91+ await self .consumer .stop ()
92+
93+ with pytest .raises (StopAsyncIteration ):
94+ await anext (iterator )
95+
5496 async def test_start (self ):
5597 # check consumer store is empty
5698 await self .consumer .start ()
@@ -69,7 +111,7 @@ async def test_start(self):
69111
70112 async def test_poll_without_commit (self ):
71113 self .create_topic ()
72- await self .produce_message ()
114+ await self .produce_two_messages ()
73115 self .consumer .subscribe (topics = [self .test_topic ])
74116 await self .consumer .start ()
75117
@@ -83,7 +125,7 @@ async def test_poll_without_commit(self):
83125
84126 async def test_partition_specific_poll_without_commit (self ):
85127 self .create_topic ()
86- await self .produce_message ()
128+ await self .produce_two_messages ()
87129 self .consumer .subscribe (topics = [self .test_topic ])
88130 await self .consumer .start ()
89131
@@ -99,7 +141,7 @@ async def test_partition_specific_poll_without_commit(self):
99141
100142 async def test_poll_with_commit (self ):
101143 self .create_topic ()
102- await self .produce_message ()
144+ await self .produce_two_messages ()
103145 self .consumer .subscribe (topics = [self .test_topic ])
104146 await self .consumer .start ()
105147
@@ -116,7 +158,7 @@ async def test_poll_with_commit(self):
116158
117159 async def test_getmany_without_commit (self ):
118160 self .create_topic ()
119- await self .produce_message ()
161+ await self .produce_two_messages ()
120162 await self .producer .send (
121163 topic = self .test_topic , partition = 2 , key = b"test2" , value = b"test2"
122164 )
@@ -145,7 +187,7 @@ async def test_getmany_without_commit(self):
145187
146188 async def test_getmany_with_limit_without_commit (self ):
147189 self .create_topic ()
148- await self .produce_message ()
190+ await self .produce_two_messages ()
149191 await self .producer .send (
150192 topic = self .test_topic , partition = 0 , key = b"test2" , value = b"test2"
151193 )
@@ -182,7 +224,7 @@ async def test_getmany_with_limit_without_commit(self):
182224
183225 async def test_getmany_specific_poll_without_commit (self ):
184226 self .create_topic ()
185- await self .produce_message ()
227+ await self .produce_two_messages ()
186228 await self .producer .send (
187229 topic = self .test_topic , partition = 1 , key = b"test2" , value = b"test2"
188230 )
@@ -210,7 +252,7 @@ async def test_getmany_specific_poll_without_commit(self):
210252
211253 async def test_getmany_with_commit (self ):
212254 self .create_topic ()
213- await self .produce_message ()
255+ await self .produce_two_messages ()
214256 await self .producer .send (
215257 topic = self .test_topic , partition = 2 , key = b"test2" , value = b"test2"
216258 )
@@ -287,7 +329,7 @@ async def test_lifecycle(self):
287329
288330 self .assertEqual (self .consumer .subscribed_topic , topics )
289331
290- await self .produce_message ()
332+ await self .produce_two_messages ()
291333
292334 messages = {
293335 tp : self .summarise (msgs )
@@ -336,7 +378,7 @@ async def test_context_manager(self):
336378
337379 async with self .consumer as consumer :
338380 self .assertEqual (self .consumer , consumer )
339- await self .produce_message ()
381+ await self .produce_two_messages ()
340382
341383 messages = {
342384 tp : self .summarise (msgs )
@@ -373,3 +415,6 @@ async def test_consumer_is_stopped(self):
373415 self .consumer .subscribe (topics = topics )
374416 with self .assertRaises (ConsumerStoppedError ):
375417 await self .consumer .getone ()
418+
419+ with self .assertRaises (ConsumerStoppedError ):
420+ aiter (self .consumer )
0 commit comments