55
66from deprecation import deprecated
77from httpx import AsyncClient , Headers , QueryParams , Timeout
8+ from yarl import URL
89
910from ..base_client import BasePostgrestClient
1011from ..constants import (
1314)
1415from ..types import CountMethod
1516from ..version import __version__
16- from .request_builder import AsyncRequestBuilder , AsyncRPCFilterRequestBuilder
17+ from .request_builder import (
18+ AsyncRequestBuilder ,
19+ AsyncRPCFilterRequestBuilder ,
20+ RequestConfig ,
21+ )
1722
1823
1924class AsyncPostgrestClient (BasePostgrestClient ):
@@ -59,52 +64,32 @@ def __init__(
5964 else DEFAULT_POSTGREST_CLIENT_TIMEOUT
6065 )
6166 )
62-
6367 BasePostgrestClient .__init__ (
6468 self ,
65- base_url ,
69+ URL ( base_url ) ,
6670 schema = schema ,
6771 headers = headers ,
6872 timeout = self .timeout ,
6973 verify = self .verify ,
7074 proxy = proxy ,
71- http_client = http_client ,
7275 )
73- self .session : AsyncClient = self .session
74-
75- def create_session (
76- self ,
77- base_url : str ,
78- headers : Dict [str , str ],
79- timeout : Union [int , float , Timeout ],
80- verify : bool = True ,
81- proxy : Optional [str ] = None ,
82- ) -> AsyncClient :
83- http_client = None
84- if isinstance (self .http_client , AsyncClient ):
85- http_client = self .http_client
8676
87- if http_client is not None :
88- http_client .base_url = base_url
89- http_client .headers .update ({** headers })
90- return http_client
91-
92- return AsyncClient (
77+ self .session = http_client or AsyncClient (
9378 base_url = base_url ,
94- headers = headers ,
79+ headers = self . headers ,
9580 timeout = timeout ,
96- verify = verify ,
81+ verify = self . verify ,
9782 proxy = proxy ,
9883 follow_redirects = True ,
9984 http2 = True ,
10085 )
10186
102- def schema (self , schema : str ):
87+ def schema (self , schema : str ) -> AsyncPostgrestClient :
10388 """Switch to another schema."""
10489 return AsyncPostgrestClient (
105- base_url = self .base_url ,
90+ base_url = str ( self .base_url ) ,
10691 schema = schema ,
107- headers = self .headers ,
92+ headers = dict ( self .headers ) ,
10893 timeout = self .timeout ,
10994 verify = self .verify ,
11095 proxy = self .proxy ,
@@ -128,7 +113,9 @@ def from_(self, table: str) -> AsyncRequestBuilder:
128113 Returns:
129114 :class:`AsyncRequestBuilder`
130115 """
131- return AsyncRequestBuilder (self .session , f"/{ table } " )
116+ return AsyncRequestBuilder (
117+ self .session , self .base_url .joinpath (table ), self .headers , self .basic_auth
118+ )
132119
133120 def table (self , table : str ) -> AsyncRequestBuilder :
134121 """Alias to :meth:`from_`."""
@@ -142,7 +129,7 @@ def from_table(self, table: str) -> AsyncRequestBuilder:
142129 def rpc (
143130 self ,
144131 func : str ,
145- params : dict ,
132+ params : dict [ str , str ] ,
146133 count : Optional [CountMethod ] = None ,
147134 head : bool = False ,
148135 get : bool = False ,
@@ -171,17 +158,20 @@ def rpc(
171158 method = "HEAD" if head else "GET" if get else "POST"
172159
173160 headers = Headers ({"Prefer" : f"count={ count } " }) if count else Headers ()
174-
175- if method in ("HEAD" , "GET" ):
176- return AsyncRPCFilterRequestBuilder (
177- self .session ,
178- f"/rpc/{ func } " ,
179- method ,
180- headers ,
181- QueryParams (params ),
182- json = {},
183- )
161+ headers .update (self .headers )
184162 # the params here are params to be sent to the RPC and not the queryparams!
185- return AsyncRPCFilterRequestBuilder (
186- self .session , f"/rpc/{ func } " , method , headers , QueryParams (), json = params
163+ json , http_params = (
164+ ({}, QueryParams (params ))
165+ if method in ("HEAD" , "GET" )
166+ else (params , QueryParams ())
167+ )
168+ request = RequestConfig (
169+ self .session ,
170+ self .base_url .joinpath ("rpc" , func ),
171+ method ,
172+ headers ,
173+ http_params ,
174+ self .basic_auth ,
175+ json ,
187176 )
177+ return AsyncRPCFilterRequestBuilder (request )
0 commit comments