Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/flyte/_internal/controllers/remote/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,34 @@ class ControllerClient:
A client for the Controller API.
"""

def __init__(self, channel: grpc.aio.Channel):
def __init__(self, channel: grpc.aio.Channel, endpoint: str | None = None):
self._channel = channel
self._endpoint = endpoint
self._state_service = state_service_pb2_grpc.StateServiceStub(channel=channel)
self._queue_service = queue_service_pb2_grpc.QueueServiceStub(channel=channel)

@classmethod
async def for_endpoint(cls, endpoint: str, insecure: bool = False, **kwargs) -> ControllerClient:
return cls(await create_channel(endpoint, None, insecure=insecure, **kwargs))
return cls(await create_channel(endpoint, None, insecure=insecure, **kwargs), endpoint=endpoint)

@classmethod
async def for_api_key(cls, api_key: str, insecure: bool = False, **kwargs) -> ControllerClient:
return cls(await create_channel(None, api_key, insecure=insecure, **kwargs))

@property
def endpoint(self) -> str | None:
"""
The endpoint this client is connected to.
"""
return self._endpoint

@property
def channel(self) -> grpc.aio.Channel:
"""
The underlying gRPC channel.
"""
return self._channel

@property
def state_service(self) -> StateService:
"""
Expand Down
27 changes: 22 additions & 5 deletions src/flyte/_internal/controllers/remote/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,12 @@ async def _bg_worker_pool(self):
self._running = True
logger.debug("Waiting for Service Client to be ready")
client_set = await self._client_coro
self._client_set = client_set
self._state_service: StateService = client_set.state_service
self._queue_service: QueueService = client_set.queue_service
# Log the endpoint for network debugging
endpoint = getattr(client_set, 'endpoint', None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could expose this on the protocol (python protocol) instead of the getattr

logger.info(f"Controller connected to endpoint: {endpoint}")
self._resource_log_task = asyncio.create_task(self._bg_log_stats())
# We will wait for this to signal that the thread is ready
# Signal the main thread that we're ready
Expand Down Expand Up @@ -297,14 +301,19 @@ async def _bg_cancel_action(self, action: Action):
action.mark_cancelled()
if started:
async with self._rate_limiter:
logger.info(f"Cancelling action: {action.name}")
endpoint = getattr(self._client_set, 'endpoint', 'unknown')
logger.info(f"Cancelling action: {action.name}, endpoint={endpoint}")
try:
await self._queue_service.AbortQueuedAction(
queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
wait_for_ready=True,
)
logger.info(f"Successfully cancelled action: {action.name}")
logger.info(f"Successfully cancelled action: {action.name}, endpoint={endpoint}")
except grpc.aio.AioRpcError as e:
logger.debug(
f"AbortQueuedAction RPC error: code={e.code()}, details={e.details()}, "
f"endpoint={endpoint}, action={action.name}"
)
if e.code() in [
grpc.StatusCode.NOT_FOUND,
grpc.StatusCode.FAILED_PRECONDITION,
Expand Down Expand Up @@ -352,7 +361,11 @@ async def _bg_launch(self, action: Action):
elif action.type == "trace":
trace = action.trace

logger.debug(f"Attempting to launch action: {action.name}")
endpoint = getattr(self._client_set, 'endpoint', 'unknown')
logger.debug(
f"Attempting to launch action: {action.name}, endpoint={endpoint}, "
f"timeout={self._enqueue_timeout}s"
)
try:
await self._queue_service.EnqueueAction(
queue_service_pb2.EnqueueActionRequest(
Expand All @@ -368,8 +381,12 @@ async def _bg_launch(self, action: Action):
wait_for_ready=True,
timeout=self._enqueue_timeout,
)
logger.info(f"Successfully launched action: {action.name}")
logger.info(f"Successfully launched action: {action.name}, endpoint={endpoint}")
except grpc.aio.AioRpcError as e:
logger.debug(
f"EnqueueAction RPC error: code={e.code()}, details={e.details()}, "
f"endpoint={endpoint}, action={action.name}, trailing_metadata={e.trailing_metadata()}"
)
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
logger.info(f"Action {action.name} already exists, continuing to monitor.")
return
Expand All @@ -384,7 +401,7 @@ async def _bg_launch(self, action: Action):
# For all other errors, we will retry with backoff
logger.error(
f"Failed to launch action: {action.name}, Code: {e.code()}, "
f"Details {e.details()} backing off..."
f"Details: {e.details()}, endpoint={endpoint}, backing off..."
)
logger.debug(f"Action details: {action}")
raise flyte.errors.SlowDownError(f"Failed to launch action: {e.details()}") from e
Expand Down
14 changes: 14 additions & 0 deletions src/flyte/remote/_client/auth/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ async def create_channel(
insecure_kwargs["options"] = kw_opts
if compression:
insecure_kwargs["compression"] = compression
logger.debug(f"Creating insecure gRPC channel to endpoint: {endpoint}")
unauthenticated_channel = grpc.aio.insecure_channel(endpoint, **insecure_kwargs)
else:
# Only create SSL credentials if not provided and also only when using secure channel.
Expand All @@ -145,6 +146,11 @@ async def create_channel(
ssl_credentials = grpc.ssl_channel_credentials(st_cert)
else:
ssl_credentials = grpc.ssl_channel_credentials()
logger.debug(
f"Creating secure gRPC channel to endpoint: {endpoint}, "
f"insecure_skip_verify={insecure_skip_verify}, "
f"ca_cert_file_path={ca_cert_file_path}"
)
unauthenticated_channel = grpc.aio.secure_channel(
target=endpoint,
credentials=ssl_credentials,
Expand Down Expand Up @@ -202,8 +208,16 @@ async def create_channel(
insecure_kwargs["options"] = kw_opts
if compression:
insecure_kwargs["compression"] = compression
logger.info(
f"gRPC channel ready: endpoint={endpoint}, secure=False, "
f"interceptors={len(interceptors)}"
)
return grpc.aio.insecure_channel(endpoint, interceptors=interceptors, **insecure_kwargs)

logger.info(
f"gRPC channel ready: endpoint={endpoint}, secure=True, "
f"interceptors={len(interceptors)}, grpc_options={grpc_options}"
)
return grpc.aio.secure_channel(
target=endpoint,
credentials=ssl_credentials,
Expand Down
Loading