@@ -305,7 +305,7 @@ class Pool:
305305    """ 
306306
307307    __slots__  =  (
308-         '_queue' , '_loop' , '_minsize' , '_maxsize' ,
308+         '_queue' , '_loop' , '_minsize' , '_maxsize' ,  '_middlewares' , 
309309        '_init' , '_connect_args' , '_connect_kwargs' ,
310310        '_working_addr' , '_working_config' , '_working_params' ,
311311        '_holders' , '_initialized' , '_initializing' , '_closing' ,
@@ -320,6 +320,7 @@ def __init__(self, *connect_args,
320320                 max_inactive_connection_lifetime ,
321321                 setup ,
322322                 init ,
323+                  middlewares ,
323324                 loop ,
324325                 connection_class ,
325326                 ** connect_kwargs ):
@@ -377,6 +378,7 @@ def __init__(self, *connect_args,
377378        self ._closed  =  False 
378379        self ._generation  =  0 
379380        self ._init  =  init 
381+         self ._middlewares  =  middlewares 
380382        self ._connect_args  =  connect_args 
381383        self ._connect_kwargs  =  connect_kwargs 
382384
@@ -469,6 +471,7 @@ async def _get_new_connection(self):
469471                * self ._connect_args ,
470472                loop = self ._loop ,
471473                connection_class = self ._connection_class ,
474+                 middlewares = self ._middlewares ,
472475                ** self ._connect_kwargs )
473476
474477            self ._working_addr  =  con ._addr 
@@ -483,6 +486,7 @@ async def _get_new_connection(self):
483486                addr = self ._working_addr ,
484487                timeout = self ._working_params .connect_timeout ,
485488                config = self ._working_config ,
489+                 middlewares = self ._middlewares ,
486490                params = self ._working_params ,
487491                connection_class = self ._connection_class )
488492
@@ -784,13 +788,35 @@ def __await__(self):
784788        return  self .pool ._acquire (self .timeout ).__await__ ()
785789
786790
791+ def  middleware (f ):
792+     """Decorator for adding a middleware 
793+      
794+     Can be used like such 
795+      
796+     .. code-block:: python 
797+ 
798+         @pool.middleware 
799+         async def my_middleware(query, args, limit, timeout, return_status, *, handler, conn): 
800+             print('do something before') 
801+             result, stmt = await handler(query, args, limit, timeout, return_status) 
802+             print('do something after') 
803+             return result, stmt 
804+              
805+         my_pool = await pool.create_pool(middlewares=[my_middleware]) 
806+     """ 
807+     async  def  middleware_factory (connection , handler ):
808+         return  functools .partial (f , connection = connection , handler = handler )
809+     return  middleware_factory 
810+ 
811+     
787812def  create_pool (dsn = None , * ,
788813                min_size = 10 ,
789814                max_size = 10 ,
790815                max_queries = 50000 ,
791816                max_inactive_connection_lifetime = 300.0 ,
792817                setup = None ,
793818                init = None ,
819+                 middlewares = None ,
794820                loop = None ,
795821                connection_class = connection .Connection ,
796822                ** connect_kwargs ):
@@ -866,6 +892,19 @@ def create_pool(dsn=None, *,
866892        or :meth:`Connection.set_type_codec() <\ 
867893        asyncpg.connection.Connection.set_type_codec>`. 
868894
895+     :param middlewares: 
896+         A list of middleware functions to be middleware just 
897+         before a connection excecutes a statement. 
898+         Syntax of a middleware is as follows: 
899+         async def middleware_factory(connection, handler): 
900+             async def middleware(query, args, limit, timeout, return_status): 
901+                 print('do something before') 
902+                 result, stmt = await handler(query, args, limit, 
903+                                              timeout, return_status) 
904+                 print('do something after') 
905+                 return result, stmt 
906+             return middleware 
907+ 
869908    :param loop: 
870909        An asyncio event loop instance.  If ``None``, the default 
871910        event loop will be used. 
@@ -893,6 +932,8 @@ def create_pool(dsn=None, *,
893932        dsn ,
894933        connection_class = connection_class ,
895934        min_size = min_size , max_size = max_size ,
896-         max_queries = max_queries , loop = loop , setup = setup , init = init ,
935+         max_queries = max_queries , loop = loop , setup = setup ,
936+         middlewares = middlewares , init = init ,
897937        max_inactive_connection_lifetime = max_inactive_connection_lifetime ,
898938        ** connect_kwargs )
939+ 
0 commit comments