Skip to content

Commit 36ae008

Browse files
authored
Merge pull request #1 from HireBase-1/fix/ray-worker-async-client
Fix AsyncMongoClient SRV resolution in Ray worker contexts
2 parents 3da6e85 + 2ef4b85 commit 36ae008

1 file changed

Lines changed: 19 additions & 8 deletions

File tree

pymongo/asynchronous/srv_resolver.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from dns import resolver
2727

2828
_IS_SYNC = False
29-
29+
from dns import resolver
30+
from dns import asyncresolver
3031

3132
def _have_dnspython() -> bool:
3233
try:
@@ -48,15 +49,12 @@ def maybe_decode(text: Union[str, bytes]) -> str:
4849
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
4950
async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
5051
if _IS_SYNC:
51-
from dns import resolver
52-
5352
if hasattr(resolver, "resolve"):
5453
# dnspython >= 2
5554
return resolver.resolve(*args, **kwargs)
5655
# dnspython 1.X
5756
return resolver.query(*args, **kwargs)
5857
else:
59-
from dns import asyncresolver
6058

6159
if hasattr(asyncresolver, "resolve"):
6260
# dnspython >= 2
@@ -84,6 +82,11 @@ def __init__(
8482
self.__srv = srv_service_name
8583
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
8684
self.__srv_max_hosts = srv_max_hosts or 0
85+
self.__aresolver = None
86+
if not _IS_SYNC:
87+
self.__aresolver = asyncresolver.Resolver(configure=False) # Clean slate, no /etc/resolv.conf inheritance
88+
self.__aresolver.lifetime = 0
89+
self.__aresolver.nameservers = ['8.8.8.8', '8.8.4.4']
8790
# Validate the fully qualified domain name.
8891
try:
8992
ipaddress.ip_address(fqdn)
@@ -102,7 +105,10 @@ async def get_options(self) -> Optional[str]:
102105
from dns import resolver
103106

104107
try:
105-
results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
108+
if _IS_SYNC:
109+
results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) # Existing dispatcher for sync
110+
else:
111+
results = await self.__aresolver.resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
106112
except (resolver.NoAnswer, resolver.NXDOMAIN):
107113
# No TXT records
108114
return None
@@ -114,9 +120,14 @@ async def get_options(self) -> Optional[str]:
114120

115121
async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
116122
try:
117-
results = await _resolve(
118-
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
119-
)
123+
if _IS_SYNC:
124+
results = await _resolve(
125+
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
126+
)
127+
else:
128+
results = await self.__aresolver.resolve(
129+
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
130+
)
120131
except Exception as exc:
121132
if not encapsulate_errors:
122133
# Raise the original error.

0 commit comments

Comments
 (0)