@@ -126,6 +126,43 @@ def test_dht_single_node():
     node.shutdown()
 
 
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_negative_caching(n_peers=10):
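+    # disable local caching on every peer so that any entry found below comes from an explicit store
+    # (either declare_experts or the "no data" markers written by negative caching)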
+    dht_kwargs = {"cache_locally": False}
+
+    peers = [hivemind.DHT(start=True, **dht_kwargs)]
+    initial_peers = peers[0].get_visible_maddrs()
+    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
+
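+    # declare experts under ffn.1.* and ffn.3.* only, so every other first-dimension prefix stays empty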
+    writer_peer = random.choice(peers)
+    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], get_dht_time() + 30).values())
+
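+    # bootstrap a fresh peer from the multiaddrs of up to 3 random existing peers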
+    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
+    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
+    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
+    # fetch prefixes via the peer with negative caching; this caches "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
+    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2
+
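+    # the beam search above queried prefixes ffn.0. through ffn.5., so a fresh node should find entries for all of them:
+    # real data for ffn.1. and ffn.3., "no data" markers for the rest; ffn.6. through ffn.9. were never queried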
+    node = await DHTNode.create(initial_peers=neighbors)
+    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
+    for i in range(6):
+        assert fetched[i] is not None, f"node should have cached ffn.{i}."
+    for i in range(6, len(fetched)):
+        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
+
+    await node.shutdown()
+    neg_caching_peer.shutdown()
+    for peer in peers:
+        peer.shutdown()
+
+
 def test_uid_patterns():
     valid_experts = [
         "expert.1",
@@ -188,34 +225,3 @@ def test_uid_patterns():
         assert not is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
     for pfx in invalid_prefixes:
         assert not is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_negative_caching(n_peers=10):
-    dht_kwargs = {"cache_locally": False}
-
-    peers = [hivemind.DHT(start=True, **dht_kwargs)]
-    initial_peers = peers[0].get_visible_maddrs()
-    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
-
-    writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], get_dht_time() + 30).values())
-
-    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
-    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
-    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
-    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
-    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2
-
-    node = await DHTNode.create(initial_peers=neighbors)
-    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
-    for i in range(6):
-        assert fetched[i] is not None, f"node should have cached ffn.{i}."
-    for i in range(6, len(fetched)):
-        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
-
-    await node.shutdown()
-    neg_caching_peer.shutdown()
-    for peer in peers:
-        peer.shutdown()