Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions daisy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@


class Server(ServerObservee):
def __init__(self, stop_event=None):
def __init__(self, stop_event=None, host=None):
super().__init__()

if stop_event is None:
self.stop_event = Event()
else:
self.stop_event = stop_event

self.tcp_server = TCPServer()
self.tcp_server = TCPServer(host=host)
self.hostname, self.port = self.tcp_server.address

logger.debug("Started server listening at %s:%s", self.hostname, self.port)
Expand Down
77 changes: 71 additions & 6 deletions daisy/tcp/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TCPServer(tornado.tcpserver.TCPServer, IOLooper):
How many times to try to find an empty random port.
"""

def __init__(self, max_port_tries=1000):
def __init__(self, host=None, max_port_tries=1000):
tornado.tcpserver.TCPServer.__init__(self)
IOLooper.__init__(self)

Expand All @@ -41,7 +41,7 @@ def __init__(self, max_port_tries=1000):
% self.max_port_tries
)

self.address = self._get_address()
self.address = self._get_address(host)

def get_message(self, timeout=None):
"""Get a message that was sent to this server.
Expand Down Expand Up @@ -125,20 +125,85 @@ def _check_for_errors(self):
except queue.Empty:
return

def _get_address(self):
"""Get the host and port of the tcp server"""
def _get_address(self, host=None):
"""Get the host and port of the tcp server.

Args:

host (str, optional):

If given, use this as the server host address. If not given,
auto-detect the host by finding the default route IP, and
validate that TCP connections to it work. Falls back to
``127.0.0.1`` if auto-detection fails or the detected address
is not connectable (e.g., blocked by a macOS firewall).
"""

sock = self._sockets[list(self._sockets.keys())[0]]
port = sock.getsockname()[1]

if host is not None:
return (host, port)

# Auto-detect: find the IP of the default route interface
ip = None
outside_sock = None
try:
outside_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
outside_sock.connect(("8.8.8.8", port))
ip = outside_sock.getsockname()[0]
except Exception:
logger.error("Could not detect own IP address, returning bogus IP")
return "8.8.8.8"
pass
finally:
if outside_sock:
outside_sock.close()

# Validate the detected IP with a loopback TCP test. On macOS the
# application firewall blocks incoming connections on non-loopback
# interfaces by default, causing recv on accepted sockets to fail
# with ENOTCONN even though the connection appeared to succeed.
if ip is not None and ip != "127.0.0.1":
if not self._validate_address(ip):
logger.warning(
"Auto-detected address %s failed connectivity check, "
"falling back to 127.0.0.1",
ip,
)
ip = None

if ip is None:
ip = "127.0.0.1"

return (ip, port)

@staticmethod
def _validate_address(ip):
"""Validate that a TCP server on the given IP can accept connections
and recv data. Returns True if the address is usable."""

server_sock = None
client_sock = None
conn = None
try:
server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_sock.settimeout(1)
server_sock.bind((ip, 0))
server_sock.listen(1)
test_port = server_sock.getsockname()[1]

client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client_sock.settimeout(1)
client_sock.connect((ip, test_port))
client_sock.sendall(b"test")

conn, _ = server_sock.accept()
conn.settimeout(1)
data = conn.recv(4)
return data == b"test"
except OSError:
return False
finally:
for s in (conn, client_sock, server_sock):
if s is not None:
s.close()
Loading