@@ -42,7 +42,7 @@ class Connection(metaclass=ConnectionMeta):
4242 """
4343
4444 __slots__ = ('_protocol' , '_transport' , '_loop' ,
45- '_top_xact' , '_aborted' ,
45+ '_top_xact' , '_aborted' , '_middlewares'
4646 '_pool_release_ctr' , '_stmt_cache' , '_stmts_to_close' ,
4747 '_listeners' , '_server_version' , '_server_caps' ,
4848 '_intro_query' , '_reset_query' , '_proxy' ,
@@ -53,7 +53,8 @@ class Connection(metaclass=ConnectionMeta):
5353 def __init__ (self , protocol , transport , loop ,
5454 addr : (str , int ) or str ,
5555 config : connect_utils ._ClientConfiguration ,
56- params : connect_utils ._ConnectionParameters ):
56+ params : connect_utils ._ConnectionParameters ,
57+ middlewares = None ):
5758 self ._protocol = protocol
5859 self ._transport = transport
5960 self ._loop = loop
@@ -92,7 +93,7 @@ def __init__(self, protocol, transport, loop,
9293
9394 self ._reset_query = None
9495 self ._proxy = None
95-
96+ self . _middlewares = _middlewares
9697 # Used to serialize operations that might involve anonymous
9798 # statements. Specifically, we want to make the following
9899 # operation atomic:
@@ -1410,8 +1411,12 @@ async def reload_schema_state(self):
14101411
14111412 async def _execute (self , query , args , limit , timeout , return_status = False ):
14121413 with self ._stmt_exclusive_section :
1413- result , _ = await self .__execute (
1414- query , args , limit , timeout , return_status = return_status )
1414+ wrapped = self .__execute
1415+ if self ._middlewares :
1416+ for m in reversed (self ._middlewares ):
1417+ wrapped = await m (self , wrapped )
1418+
1419+ result , _ = await wrapped (query , args , limit , timeout , return_status = return_status )
14151420 return result
14161421
14171422 async def __execute (self , query , args , limit , timeout ,
@@ -1502,6 +1507,7 @@ async def connect(dsn=None, *,
15021507 max_cacheable_statement_size = 1024 * 15 ,
15031508 command_timeout = None ,
15041509 ssl = None ,
1510+ middlewares = None ,
15051511 connection_class = Connection ,
15061512 server_settings = None ):
15071513 r"""A coroutine to establish a connection to a PostgreSQL server.
@@ -1618,6 +1624,10 @@ async def connect(dsn=None, *,
16181624 PostgreSQL documentation for
16191625 a `list of supported options <server settings>`_.
16201626
1627+ :param middlewares:
1628+ An optional list of middleware functions. Refer to documentation
1629+ on create_pool.
1630+
16211631 :param Connection connection_class:
16221632 Class of the returned connection object. Must be a subclass of
16231633 :class:`~asyncpg.connection.Connection`.
@@ -1683,6 +1693,7 @@ async def connect(dsn=None, *,
16831693 ssl = ssl , database = database ,
16841694 server_settings = server_settings ,
16851695 command_timeout = command_timeout ,
1696+ middlewares = middlewares ,
16861697 statement_cache_size = statement_cache_size ,
16871698 max_cached_statement_lifetime = max_cached_statement_lifetime ,
16881699 max_cacheable_statement_size = max_cacheable_statement_size )
0 commit comments