diff --git a/CONFIGURATION.md b/CONFIGURATION.md new file mode 100644 index 0000000..e0077ec --- /dev/null +++ b/CONFIGURATION.md @@ -0,0 +1,69 @@ +# Configuration Management in FastFlight + +FastFlight utilizes a robust configuration system based on [Pydantic Settings](https://docs.pydantic.dev/latest/usage/settings/). This allows for type-validated settings loaded from various sources with clear precedence. + +## Overview + +Configuration for different components of FastFlight (Logging, Flight Server, FastAPI Application, Connection Bouncer) is managed through distinct Pydantic `BaseSettings` models. These models define the expected configuration parameters, their types, and default values. + +## Loading Mechanism + +Settings are loaded with the following order of precedence (highest to lowest): + +1. **CLI Arguments:** Command-line arguments provided to `fastflight` CLI commands (e.g., `--port` for `start-fastapi`) will override any other source for the specific parameters they control. +2. **Environment Variables:** Settings can be provided as environment variables. Each setting group has a specific prefix. +3. **`.env` File:** If a `.env` file is present in the working directory when the application starts and the `python-dotenv` package is installed, environment variables will be loaded from this file. These are then treated as regular environment variables. +4. **Default Values:** If a setting is not found in any of the above sources, the default value defined in the Pydantic model is used. + +## Setting Groups + +### 1. Logging Settings + +Controls the application-wide logging behavior. +**Environment Variable Prefix:** `FASTFLIGHT_LOGGING_` + +| Variable Suffix | Description | Type | Default Value | Example Value | +| :-------------- | :-------------------------------- | :----- | :------------ | :----------------- | +| `LOG_LEVEL` | Minimum logging level to output. | `str` | `"INFO"` | `"DEBUG"`, `"WARN"` | +| `LOG_FORMAT` | Log output format. 
| `str` | `"plain"` | `"json"` | + +### 2. Flight Server Settings + +Controls the behavior of the Arrow Flight server. +**Environment Variable Prefix:** `FASTFLIGHT_SERVER_` + +| Variable Suffix | Description | Type | Default Value | Example Value | +| :--------------------- | :------------------------------------------- | :----- | :------------ | :----------------------------- | +| `HOST` | Host address to bind the server to. | `str` | `"0.0.0.0"` | `"127.0.0.1"` | +| `PORT` | Port to bind the server to. | `int` | `8815` | `9000` | +| `LOG_LEVEL` | Logging level specific to the Flight server. | `str` | `"INFO"` | `"DEBUG"` | +| `AUTH_TOKEN` | Enables token authentication if set. | `str` | `None` | `"your-secret-token"` | +| `TLS_CERT_PATH` | Path to the server's TLS certificate file. | `str` | `None` | `"/path/to/server.crt"` | +| `TLS_KEY_PATH` | Path to the server's TLS private key file. | `str` | `None` | `"/path/to/server.key"` | + +### 3. FastAPI Application Settings + +Controls the behavior of the FastAPI web application. +**Environment Variable Prefix:** `FASTFLIGHT_API_` + +| Variable Suffix | Description | Type | Default Value | Example Value | +| :------------------------ | :----------------------------------------------------------------- | :---------- | :---------------------- | :---------------------------------------- | +| `HOST` | Host address for Uvicorn to bind to. | `str` | `"0.0.0.0"` | `"127.0.0.1"` | +| `PORT` | Port for Uvicorn to bind to. | `int` | `8000` | `8080` | +| `LOG_LEVEL` | Logging level for Uvicorn and FastAPI app. | `str` | `"INFO"` | `"DEBUG"` | +| `FLIGHT_SERVER_LOCATION` | URL for the FastAPI app to connect to the Flight server. | `str` | `"grpc://localhost:8815"` | `"grpc+tls://flight.example.com:443"` | +| `VALID_API_KEYS` | Comma-separated list of valid API keys for client authentication. | `list[str]` | `[]` (empty list) | `"key1,key2,anotherkey"` | +| `SSL_KEYFILE` | Path to the SSL private key file for Uvicorn (HTTPS). 
| `str` | `None` | `"/path/to/api.key"` | +| `SSL_CERTFILE` | Path to the SSL certificate file for Uvicorn (HTTPS). | `str` | `None` | `"/path/to/api.crt"` | +| `METRICS_ENABLED` | Enable (`True`) or disable (`False`) the `/metrics` endpoint. | `bool` | `True` | `False` (or `"false"`, `"0"`) | + +*Note on `VALID_API_KEYS`: An empty string for the environment variable `FASTFLIGHT_API_VALID_API_KEYS` will result in an empty list, effectively disabling API key checks if that's the desired policy (see `SECURITY.md`).* + +### 4. Bouncer Settings + +Controls the default behavior of the `FastFlightBouncer` (client-side Flight connection pool). +**Environment Variable Prefix:** `FASTFLIGHT_BOUNCER_` + +| Variable Suffix | Description | Type | Default Value | Example Value | +| :-------------- | :---------------------------------------- | :---- | :------------ | :------------ | +| `POOL_SIZE` | Default number of connections in the pool. | `int` | `10` | `20` | diff --git a/docker/Dockerfile b/docker/Dockerfile index 40c3117..9a0c448 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,11 +1,98 @@ -FROM python:3.10-buster +# ---- Builder Stage ---- +FROM python:3.10-slim-buster AS builder -ENV PYTHONPATH="/app" +# Set working directory for virtual environment +WORKDIR /opt/venv +# Create virtual environment +RUN python -m venv . +ENV PATH="/opt/venv/bin:$PATH" # Activate venv + +# Install dependencies COPY requirements.txt . 
+# Ensure pip is up-to-date and install wheel for potentially compiled packages +RUN pip install --no-cache-dir --upgrade pip wheel RUN pip install --no-cache-dir -r requirements.txt + +# ---- Final Stage ---- +FROM python:3.10-slim-buster + +# Set environment variables for Python +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 + +# Install curl for health checks and any other essential OS packages +# Also update apt-get and clean up to keep image size down +RUN apt-get update && \ + apt-get install -y curl --no-install-recommends && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Create a non-root user and group +ARG APP_USER=fastflight_user +ARG APP_GROUP=fastflight_user +RUN groupadd -r ${APP_GROUP} && useradd --no-log-init -r -g ${APP_GROUP} ${APP_USER} + +# Copy the virtual environment from the builder stage +COPY --from=builder /opt/venv /opt/venv + +# Set up working directory and copy application code WORKDIR /app -COPY src /app +# Copy the src directory and any other necessary files +COPY --chown=${APP_USER}:${APP_GROUP} src /app/src + +# Set up environment for the application +ENV PATH="/opt/venv/bin:$PATH" # Activate venv for the final stage +ENV PYTHONPATH="/app" # Application source code is in /app/src, so /app is the root for imports + +# Switch to the non-root user +USER ${APP_USER} + +# Expose default ports +EXPOSE 8815 # Default for Flight Server (FASTFLIGHT_SERVER_PORT) +EXPOSE 8000 # Default for FastAPI (FASTFLIGHT_API_PORT) + +# Health check for FastAPI (assumes FastAPI runs on port 8000) +# The CMD should ideally run `start-all` or ensure FastAPI is running. +# If only Flight server runs, this health check might fail or be irrelevant. +# This assumes the default CMD will be changed to run both, or this Dockerfile is for FastAPI. +# For now, adding it with the assumption that the CMD will eventually run the FastAPI server. 
+# TODO: Make health check ports configurable via ARG or tied to ENV if possible, though HEALTHCHECK doesn't directly expand ENV. +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ + CMD curl -f http://localhost:8000/fastflight/health || exit 1 + +# Health check for Flight Server (simple TCP check on port 8815) +# This assumes the Flight server runs on port 8815. +# TODO: Similar to FastAPI, port should be configurable. +# Using a Python snippet for TCP check as netcat might not be available. +# This will be a separate HEALTHCHECK layer if both are running. +# Docker combines HEALTHCHECKs if multiple are defined; the last one is used. +# To check both, a wrapper script is typically needed. +# For this iteration, let's assume we want to check one primary service determined by CMD. +# If CMD is `python /app/src/fastflight/server.py`, then Flight server health check is more relevant. +# If CMD is `python -m src.fastflight.cli start-all`, then both should be healthy. +# A single script for combined health check is better. + +# For now, providing a combined health check script approach. +# Create a health check script first. + +# Let's defer the complex combined health check script for a moment and set a primary CMD. +# The subtask implies refining the Dockerfile for the *Flight server* as per original CMD. +# However, the requirement also asks for FastAPI health check. +# This suggests the image should ideally run both via `start-all`. + +# Default CMD - Changed to run both services using the CLI's start-all command. +# This makes both health checks potentially relevant. +CMD ["python", "-m", "src.fastflight.cli", "start-all"] + +# Given the CMD runs `start-all`, a single HEALTHCHECK instruction +# should ideally verify both services. A wrapper script is best for this. +# For now, I will add a health check script and use it. +# This script will check FastAPI first, then Flight server. 
+ +COPY --chown=${APP_USER}:${APP_GROUP} docker/healthcheck.sh /app/healthcheck.sh +RUN chmod +x /app/healthcheck.sh -CMD ["python", "/app/fastflight/flight_server.py"] +HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \ + CMD /app/healthcheck.sh diff --git a/docker/healthcheck.sh b/docker/healthcheck.sh new file mode 100644 index 0000000..e7e483e --- /dev/null +++ b/docker/healthcheck.sh @@ -0,0 +1,30 @@ +#!/bin/sh +set -eo pipefail + +# Default ports, can be overridden by environment variables if needed by the script +FASTAPI_PORT="${FASTFLIGHT_API_PORT:-8000}" +FLIGHT_PORT="${FASTFLIGHT_SERVER_PORT:-8815}" +FASTAPI_HOST="${FASTFLIGHT_API_HOST:-localhost}" # Healthcheck runs inside the container +FLIGHT_HOST="${FASTFLIGHT_SERVER_HOST:-localhost}" # Healthcheck runs inside the container + +echo "Healthcheck: Checking FastAPI server at http://${FASTAPI_HOST}:${FASTAPI_PORT}/fastflight/health" +if curl -fsS "http://${FASTAPI_HOST}:${FASTAPI_PORT}/fastflight/health" > /dev/null; then + echo "Healthcheck: FastAPI server is healthy." +else + echo "Healthcheck: FastAPI server failed." + exit 1 +fi + +echo "Healthcheck: Checking Flight server TCP connection at ${FLIGHT_HOST}:${FLIGHT_PORT}" +# Use python to do a simple TCP check for the Flight server +# This avoids needing netcat or other tools not guaranteed in slim images. +# The python in /opt/venv/bin should be available. +if /opt/venv/bin/python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.settimeout(5); s.connect(('${FLIGHT_HOST}', ${FLIGHT_PORT})); s.close()"; then + echo "Healthcheck: Flight server TCP connection successful." +else + echo "Healthcheck: Flight server TCP connection failed." + exit 1 +fi + +echo "Healthcheck: All services healthy." 
+exit 0 diff --git a/src/fastflight/cli.py b/src/fastflight/cli.py index 87fcedc..6edede1 100644 --- a/src/fastflight/cli.py +++ b/src/fastflight/cli.py @@ -76,12 +76,23 @@ def start_fastapi( """ import uvicorn - + from fastflight.config import fastapi_settings # Import settings from fastflight.fastapi import create_app typer.echo(f"Starting FastAPI Server at {host}:{port}") app = create_app(list(module_paths), route_prefix=fast_flight_route_prefix, flight_location=flight_location) - uvicorn.run(app, host=host, port=port) + + uvicorn_kwargs = {"host": host, "port": port} + if fastapi_settings.ssl_keyfile and fastapi_settings.ssl_certfile: + uvicorn_kwargs["ssl_keyfile"] = fastapi_settings.ssl_keyfile + uvicorn_kwargs["ssl_certfile"] = fastapi_settings.ssl_certfile + typer.echo(f"FastAPI SSL enabled using key: {fastapi_settings.ssl_keyfile} and cert: {fastapi_settings.ssl_certfile}") + else: + typer.echo("FastAPI SSL disabled (ssl_keyfile or ssl_certfile not configured).") + if fastapi_settings.ssl_keyfile or fastapi_settings.ssl_certfile: + typer.echo("Warning: FastAPI SSL partially configured but not enabled. Both key and cert files are required.") + + uvicorn.run(app, **uvicorn_kwargs) @cli.command() @@ -110,36 +121,139 @@ def start_all( module_paths (list[str]): Module paths to scan for parameter classes (default: ("fastflight.demo_services",)). """ # Create processes - flight_process = multiprocessing.Process(target=start_fast_flight_server, args=(flight_location,)) - api_process = multiprocessing.Process( - target=start_fastapi, args=(api_host, api_port, fast_flight_route_prefix, flight_location, module_paths) - ) + # Note: The previous CLI implementation used global settings for host/port inside the target functions. + # This might need adjustment if the target functions (start_fast_flight_server, start_fastapi) + # are to be purely driven by parameters passed here from start_all's own CLI options. 
+ # For now, assuming the target functions will pick up global settings or their own defaults + # if not overridden by direct parameters. + + # The `start_fast_flight_server` in the previous implementation takes a single `location` string. + # The updated `server.py`'s `main` and `start_instance` now derive location, auth, tls from global settings. + # So, we can call `start_fast_flight_server` without arguments if it's updated to use settings directly, + # or we parse `flight_location` here if that's still its input. + # Given the changes in server.py, `start_fast_flight_server` itself should be simplified or its parameters changed. + # Let's assume for now that `start_fast_flight_server` will rely on the global `flight_server_settings`. + # A direct call to `flight_server.main()` might be cleaner if `start_fast_flight_server` becomes complex. + + # For this diff, I'll keep the structure, assuming `start_fast_flight_server` is adapted or `flight_location` is still primary. + # However, Flight server settings (host, port, token, tls) are now global. + # The `flight_location` parameter for `start_all` might be redundant if server self-configures from global settings. + + from fastflight.server import main as flight_server_main # Direct import of main + from fastflight.fastapi import create_app # For API process + import uvicorn + from fastflight.config import fastapi_settings # For API SSL + + flight_process = multiprocessing.Process(target=flight_server_main) # Flight server will use its own settings + + # API process target needs to be a function that can be pickled by multiprocessing. + # A simple wrapper or ensuring start_fastapi is robust. 
+ def run_api_server_process(): + typer.echo(f"FastAPI (from start_all) will use host: {api_host}, port: {api_port}") + app = create_app( + list(module_paths), + route_prefix=fast_flight_route_prefix, + # flight_location for create_app is now sourced from fastapi_settings internally + ) + + uvicorn_kwargs = {"host": api_host, "port": api_port} + if fastapi_settings.ssl_keyfile and fastapi_settings.ssl_certfile: + uvicorn_kwargs["ssl_keyfile"] = fastapi_settings.ssl_keyfile + uvicorn_kwargs["ssl_certfile"] = fastapi_settings.ssl_certfile + typer.echo(f"FastAPI SSL (from start_all) enabled using key: {fastapi_settings.ssl_keyfile} and cert: {fastapi_settings.ssl_certfile}") + else: + typer.echo("FastAPI SSL (from start_all) disabled (ssl_keyfile or ssl_certfile not configured).") + if fastapi_settings.ssl_keyfile or fastapi_settings.ssl_certfile: + typer.echo("Warning: FastAPI SSL (from start_all) partially configured. Both key and cert files are required.") + + uvicorn.run(app, **uvicorn_kwargs) + + api_process = multiprocessing.Process(target=run_api_server_process) flight_process.start() api_process.start() + original_sigint_handler = signal.getsignal(signal.SIGINT) + original_sigterm_handler = signal.getsignal(signal.SIGTERM) + def shutdown_handler(signum, frame): - typer.echo("Received termination signal. Shutting down servers...") - flight_process.terminate() - api_process.terminate() - flight_process.join(timeout=5) + typer.echo(f"Signal {signum} received, initiating shutdown...") + + # Restore original handlers to prevent re-entry if issues occur during shutdown + signal.signal(signal.SIGINT, original_sigint_handler) + signal.signal(signal.SIGTERM, original_sigterm_handler) + + # Terminate FastAPI/Uvicorn first, as it might depend on the Flight server + # or just as a general order. 
+ if api_process.is_alive(): + typer.echo("Terminating FastAPI process (sending SIGTERM)...") + api_process.terminate() + if flight_process.is_alive(): - flight_process.kill() - api_process.join(timeout=5) + typer.echo("Terminating Flight server process (sending SIGTERM)...") + flight_process.terminate() + + api_shutdown_gracefully = True + flight_shutdown_gracefully = True + if api_process.is_alive(): - api_process.kill() - typer.echo("Servers shut down cleanly.") - exit(0) + typer.echo("Waiting for FastAPI process to exit (timeout 10s)...") + api_process.join(timeout=10) # Uvicorn's default graceful exit timeout is 5s, give a bit more + if api_process.is_alive(): + typer.echo("FastAPI process did not terminate gracefully, killing (sending SIGKILL)...") + api_process.kill() + api_shutdown_gracefully = False + + if flight_process.is_alive(): + typer.echo("Waiting for Flight server process to exit (timeout 10s)...") + flight_process.join(timeout=10) + if flight_process.is_alive(): + typer.echo("Flight server process did not terminate gracefully, killing (sending SIGKILL)...") + flight_process.kill() + flight_shutdown_gracefully = False + + if api_shutdown_gracefully and flight_shutdown_gracefully: + typer.echo("All servers shut down gracefully.") + else: + typer.echo("One or more servers required force killing.") + + # Raising SystemExit to ensure the main process exits cleanly after handling signals + # Using exit(0) directly in a signal handler can sometimes be problematic. + raise SystemExit(0) - # Handle SIGINT (Ctrl+C) and SIGTERM signal.signal(signal.SIGINT, shutdown_handler) signal.signal(signal.SIGTERM, shutdown_handler) + # Wait for both processes to complete. + # They will run until a signal is received and handled by shutdown_handler, + # which then raises SystemExit. + # If a process exits unexpectedly (e.g., due to an error), its join() will return. 
try: - while True: - time.sleep(1) # Keep main process running - except KeyboardInterrupt: - shutdown_handler(signal.SIGINT, None) + if flight_process.is_alive(): + flight_process.join() + if flight_process.exitcode != 0 and flight_process.exitcode is not None: # None if killed by signal not from terminate/kill + typer.echo(f"Warning: Flight server process exited with code {flight_process.exitcode}.", err=True) + + if api_process.is_alive(): + api_process.join() + if api_process.exitcode != 0 and api_process.exitcode is not None: + typer.echo(f"Warning: FastAPI process exited with code {api_process.exitcode}.", err=True) + + except SystemExit: # Caught from shutdown_handler + typer.echo("CLI `start-all` process is exiting due to signal.") + except Exception as e: # Catch any other unexpected errors in the main process + typer.echo(f"An unexpected error occurred in `start-all` main loop: {e}", err=True) + finally: + # Ensure processes are cleaned up if they are somehow still alive and SystemExit wasn't raised + # This is a fallback, primary cleanup is in shutdown_handler + if flight_process.is_alive(): + typer.echo("Final cleanup: Terminating lingering Flight server process.", err=True) + flight_process.kill() + if api_process.is_alive(): + typer.echo("Final cleanup: Terminating lingering FastAPI process.", err=True) + api_process.kill() + + typer.echo("FastFlight CLI `start-all` finished.") if __name__ == "__main__": diff --git a/src/fastflight/client.py b/src/fastflight/client.py index 613b7ec..149c5cb 100644 --- a/src/fastflight/client.py +++ b/src/fastflight/client.py @@ -9,6 +9,7 @@ import pyarrow as pa import pyarrow.flight as flight +from fastflight.config import bouncer_settings # Import bouncer_settings from fastflight.core.base import BaseParams from fastflight.exceptions import ( FastFlightConnectionError, @@ -17,6 +18,12 @@ FastFlightServerError, FastFlightTimeoutError, ) +from fastflight.metrics import ( # Import bouncer metrics + 
bouncer_connections_acquired_total, + bouncer_connections_released_total, + bouncer_pool_available_connections, + bouncer_pool_size, +) from fastflight.resilience import ResilienceConfig, ResilienceManager from fastflight.utils.stream_utils import AsyncToSyncConverter, write_arrow_data_to_stream @@ -82,7 +89,12 @@ class _FlightClientPool: """ def __init__( - self, flight_server_location: str, size: int = 5, converter: Optional[AsyncToSyncConverter] = None + self, + flight_server_location: str, + size: int = 5, + converter: Optional[AsyncToSyncConverter] = None, + auth_token: Optional[str] = None, + # tls_root_certs: Optional[bytes] = None, # For client-side TLS validation with custom CA ) -> None: """ Initialize the internal connection pool. @@ -91,14 +103,34 @@ def __init__( flight_server_location (str): The URI of the Flight server. size (int): The number of connections to maintain in the pool. converter (Optional[AsyncToSyncConverter]): Async-to-sync converter for compatibility. + auth_token (Optional[str]): Authentication token for ClientBasicAuthHandler. + # tls_root_certs (Optional[bytes]): PEM-encoded root certificates for TLS verification. """ self.flight_server_location = flight_server_location self.queue: asyncio.Queue[flight.FlightClient] = asyncio.Queue(maxsize=size) self.pool_size = size + + # Initialize Prometheus gauges for the pool + bouncer_pool_size.set(size) + # Initialize available connections. This assumes the pool is filled upon creation. 
+ bouncer_pool_available_connections.set(size) + + client_options = {} + if auth_token: + # Using empty username, token as password for Basic Auth + client_options["client_auth_handler"] = flight.ClientBasicAuthHandler("", auth_token) + logger.info(f"Client authentication (Basic Auth with token) enabled for connections to {flight_server_location}") + + # Example for client-side TLS if needed in future, based on config + # if "grpc+tls" in flight_server_location and tls_root_certs: + # client_options["tls_root_certs"] = tls_root_certs + for _ in range(size): - self.queue.put_nowait(flight.FlightClient(flight_server_location)) + # Pass additional options like client_auth_handler or tls_root_certs here + self.queue.put_nowait(flight.connect(flight_server_location, **client_options)) + self._converter = converter or GLOBAL_CONVERTER - logger.info(f"Created internal connection pool with {size} clients for {flight_server_location}") + logger.info(f"Created internal connection pool with {size} client(s) for {flight_server_location} (Auth: {'Enabled' if auth_token else 'Disabled'})") @asynccontextmanager async def acquire_async(self, timeout: Optional[float] = None) -> AsyncGenerator[flight.FlightClient, Any]: @@ -116,7 +148,10 @@ async def acquire_async(self, timeout: Optional[float] = None) -> AsyncGenerator """ try: client = await asyncio.wait_for(self.queue.get(), timeout=timeout) + bouncer_connections_acquired_total.inc() + bouncer_pool_available_connections.dec() except asyncio.TimeoutError: + # Consider adding a bouncer_acquisition_timeouts_total counter here raise FastFlightResourceExhaustionError( f"Connection pool exhausted - no connections available within {timeout}s (pool size: {self.pool_size})", resource_type="flight_connection_pool", @@ -127,9 +162,14 @@ async def acquire_async(self, timeout: Optional[float] = None) -> AsyncGenerator yield client except Exception as e: logger.error(f"Error during client operation: {e}", exc_info=True) + # If an error occurs, 
the connection is still returned to the pool in the finally block. + # This is standard practice unless the connection is known to be corrupted. raise finally: - await self.queue.put(client) + # This block runs after the `yield client` completes or if an exception occurs within the `with` statement. + await self.queue.put(client) # Return client to queue + bouncer_connections_released_total.inc() + bouncer_pool_available_connections.inc() @contextlib.contextmanager def acquire(self, timeout: Optional[float] = None) -> Generator[flight.FlightClient, Any, None]: @@ -147,7 +187,10 @@ def acquire(self, timeout: Optional[float] = None) -> Generator[flight.FlightCli """ try: client = self._converter.run_coroutine(asyncio.wait_for(self.queue.get(), timeout=timeout)) + bouncer_connections_acquired_total.inc() + bouncer_pool_available_connections.dec() except asyncio.TimeoutError: + # Consider adding a bouncer_acquisition_timeouts_total counter here raise FastFlightResourceExhaustionError( f"Connection pool exhausted - no connections available within {timeout}s (pool size: {self.pool_size})", resource_type="flight_connection_pool", @@ -158,9 +201,13 @@ def acquire(self, timeout: Optional[float] = None) -> Generator[flight.FlightCli yield client except Exception as e: logger.error(f"Error during client operation: {e}", exc_info=True) + # Client is returned to pool in finally block. raise finally: - self.queue.put_nowait(client) + # This 'finally' block always runs after the 'yield' if the try block was entered. 
+ self.queue.put_nowait(client) # Return client to queue + bouncer_connections_released_total.inc() + bouncer_pool_available_connections.inc() async def close_async(self): """Close all connections in the pool.""" @@ -212,9 +259,11 @@ def __init__( self, flight_server_location: str, registered_data_types: Dict[str, str] | None = None, - client_pool_size: int = 5, + client_pool_size: Optional[int] = None, # Allow None to use config default converter: Optional[AsyncToSyncConverter] = None, resilience_config: Optional[ResilienceConfig] = None, + auth_token: Optional[str] = None, # Added auth_token + # tls_root_certs_path: Optional[str] = None, # Path to client TLS root certs ): """ Initialize the Flight connection bouncer. @@ -222,13 +271,34 @@ def __init__( Args: flight_server_location (str): Target Flight server URI (e.g., 'grpc://localhost:8815'). registered_data_types (Dict[str, str] | None): Registry of available data service types. - client_pool_size (int): Number of pooled connections to maintain. Defaults to 5. + client_pool_size (Optional[int]): Number of pooled connections to maintain. + Defaults to `bouncer_settings.pool_size` if None. converter (Optional[AsyncToSyncConverter]): Async-to-sync converter for compatibility. resilience_config (Optional[ResilienceConfig]): Resilience patterns configuration (retry, circuit breaker, timeouts). - """ + auth_token (Optional[str]): Authentication token for client connections. + # tls_root_certs_path (Optional[str]): Path to PEM-encoded root certificates for TLS. 
+ """ + effective_pool_size = client_pool_size if client_pool_size is not None else bouncer_settings.pool_size + + # tls_root_certs_bytes: Optional[bytes] = None + # if tls_root_certs_path: + # try: + # with open(tls_root_certs_path, "rb") as f: + # tls_root_certs_bytes = f.read() + # logger.info(f"Loaded client TLS root certificates from {tls_root_certs_path}") + # except IOError as e: + # logger.error(f"Failed to load client TLS root certificates from {tls_root_certs_path}: {e}", exc_info=True) + # # Decide if this should be a fatal error or proceed without client TLS verification override + self._converter = converter or GLOBAL_CONVERTER - self._connection_pool = _FlightClientPool(flight_server_location, client_pool_size, converter=self._converter) + self._connection_pool = _FlightClientPool( + flight_server_location, + effective_pool_size, + converter=self._converter, + auth_token=auth_token, + # tls_root_certs=tls_root_certs_bytes + ) self._registered_data_types = dict(registered_data_types or {}) self._flight_server_location = flight_server_location @@ -239,7 +309,10 @@ def __init__( self._resilience_manager = ResilienceManager(default_config) - logger.info(f"Initialized FastFlightBouncer for {flight_server_location} with {client_pool_size} connections") + logger.info( + f"Initialized FastFlightBouncer for {flight_server_location} " + f"with {effective_pool_size} connections (Client Auth: {'Enabled' if auth_token else 'Disabled'})" + ) def get_registered_data_types(self) -> Dict[str, str]: """Get the registry of available data service types.""" diff --git a/src/fastflight/config.py b/src/fastflight/config.py new file mode 100644 index 0000000..526ba27 --- /dev/null +++ b/src/fastflight/config.py @@ -0,0 +1,40 @@ +from pydantic_settings import BaseSettings, SettingsConfigDict +from typing import Optional + +class LoggingSettings(BaseSettings): + log_level: str = "INFO" + log_format: str = "plain" # or "json" + model_config = 
SettingsConfigDict(env_prefix='FASTFLIGHT_LOGGING_') + +class FlightServerSettings(BaseSettings): + host: str = "0.0.0.0" + port: int = 8815 + log_level: str = "INFO" + auth_token: Optional[str] = None # For simple single-token auth + # For multiple valid tokens, consider `valid_auth_tokens: list[str] = []` + tls_server_cert_path: Optional[str] = None + tls_server_key_path: Optional[str] = None + model_config = SettingsConfigDict(env_prefix='FASTFLIGHT_SERVER_') + +class FastAPISettings(BaseSettings): + host: str = "0.0.0.0" + port: int = 8000 + log_level: str = "INFO" + flight_server_location: str = "grpc://localhost:8815" # Default if Flight server is local + valid_api_keys: list[str] = [] # List of valid API keys for X-API-Key header + ssl_keyfile: Optional[str] = None + ssl_certfile: Optional[str] = None + metrics_enabled: bool = True # For Prometheus metrics endpoint + model_config = SettingsConfigDict(env_prefix='FASTFLIGHT_API_') + +class BouncerSettings(BaseSettings): + # TODO: Define resilience settings, e.g., max_retries, timeout + pool_size: int = 10 + model_config = SettingsConfigDict(env_prefix='FASTFLIGHT_BOUNCER_') + +# Global settings instances that can be imported and used by other modules. +# These will be loaded from environment variables or .env files. 
+logging_settings = LoggingSettings() +flight_server_settings = FlightServerSettings() +fastapi_settings = FastAPISettings() +bouncer_settings = BouncerSettings() diff --git a/src/fastflight/core/base.py b/src/fastflight/core/base.py index b66696c..bbfa46e 100644 --- a/src/fastflight/core/base.py +++ b/src/fastflight/core/base.py @@ -83,7 +83,7 @@ def from_bytes(cls, data: bytes) -> Self: params_cls = cls.lookup(fqn) return params_cls.model_validate(json_data) except (json.JSONDecodeError, KeyError, ValueError) as e: - logger.error(f"Error deserializing params: {e}") + logger.error(f"Error deserializing params: {e}", exc_info=True) raise def to_json(self) -> dict: @@ -98,7 +98,7 @@ def to_json(self) -> dict: json_data["param_type"] = self.__class__.fqn() return json_data except (TypeError, ValueError) as e: - logger.error(f"Error serializing params: {e}") + logger.error(f"Error serializing params: {e}", exc_info=True) raise def to_bytes(self) -> bytes: @@ -133,7 +133,7 @@ def __init_subclass__(cls, **kwargs): try: cls._register(param_cls, cls) except ValueError as e: - logger.error(f"Automatic registration failed for {cls.fqn()}: {e}") + logger.error(f"Automatic registration failed for {cls.fqn()}: {e}", exc_info=True) break @classmethod diff --git a/src/fastflight/fastapi/app.py b/src/fastflight/fastapi/app.py index 60267d9..b3e600b 100644 --- a/src/fastflight/fastapi/app.py +++ b/src/fastflight/fastapi/app.py @@ -1,8 +1,15 @@ from typing import AsyncContextManager, Callable from fastapi import FastAPI +# Assuming starlette_prometheus is installed +from starlette_prometheus import exposição_métrica, PrometheusMiddleware + +from src.fastflight.config import fastapi_settings +from src.fastflight.utils.custom_logging import setup_logging +from src.fastflight.utils.registry_check import get_param_service_bindings_from_package, import_all_modules_in_package +# Import custom FastAPI metrics if starlette-prometheus doesn't cover everything or for specific needs. 
+# from src.fastflight.metrics import fastapi_requests_total, fastapi_request_duration_seconds -from ..utils.registry_check import get_param_service_bindings_from_package, import_all_modules_in_package from .lifespan import combine_lifespans from .router import fast_flight_router @@ -10,15 +17,38 @@ def create_app( module_paths: list[str], route_prefix: str = "/fastflight", - flight_location: str = "grpc://0.0.0.0:8815", + # flight_location is now sourced from fastapi_settings *lifespans: Callable[[FastAPI], AsyncContextManager], ) -> FastAPI: - # Import all custom data parameter and service classes, and check if they are registered + # Setup logging for the FastAPI application + setup_logging(service_name="FastAPIApp") + + # Import all custom data parameter and service classes registered_data_types = {} for mod in module_paths: import_all_modules_in_package(mod) registered_data_types.update(get_param_service_bindings_from_package(mod)) - app = FastAPI(lifespan=lambda a: combine_lifespans(a, registered_data_types, flight_location, *lifespans)) + app = FastAPI( + lifespan=lambda app_instance: combine_lifespans( + app_instance, + registered_data_types, + fastapi_settings.flight_server_location, # Use settings here + *lifespans, + ) + ) + + # Add Prometheus middleware and metrics endpoint if enabled + if fastapi_settings.metrics_enabled: + app.add_middleware(PrometheusMiddleware) + app.add_route("/metrics", exposição_métrica) + # Note: starlette-prometheus provides its own set of default metrics like + # starlette_requests_total, starlette_request_duration_seconds. + # If we defined fastapi_requests_total etc. in metrics.py for manual instrumentation, + # they would be separate unless we disable starlette-prometheus's default ones and use ours. + # For this task, we'll rely on starlette-prometheus for FastAPI specific HTTP metrics. 
+ # Our custom metrics defined in metrics.py (like bouncer, flight_server) will also be exposed + # via the same /metrics endpoint because prometheus_client uses a global registry by default. + app.include_router(fast_flight_router, prefix=route_prefix) return app diff --git a/src/fastflight/fastapi/router.py b/src/fastflight/fastapi/router.py index 19ab975..fdea658 100644 --- a/src/fastflight/fastapi/router.py +++ b/src/fastflight/fastapi/router.py @@ -1,21 +1,24 @@ import logging -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Security # Import Security from fastapi.responses import StreamingResponse from fastflight.client import FastFlightBouncer from fastflight.fastapi.dependencies import body_bytes, fast_flight_client +from fastflight.fastapi.security import get_api_key # Import get_api_key from fastflight.utils.stream_utils import write_arrow_data_to_stream logger = logging.getLogger(__name__) fast_flight_router = APIRouter() -@fast_flight_router.get("/registered_data_types") +@fast_flight_router.get("/registered_data_types", dependencies=[Security(get_api_key)]) def get_registered_data_types(ff_client: FastFlightBouncer = Depends(fast_flight_client)): """ Retrieve all registered data types from the Flight client. + Requires API Key authentication. + Returns a list of dictionaries, each mapping a registered BaseParams class fully qualified name (FQN) to its corresponding BaseDataService class FQN. This endpoint is useful for debugging or introspection in client applications to understand the available data types and their associated services. 
@@ -29,11 +32,13 @@ def get_registered_data_types(ff_client: FastFlightBouncer = Depends(fast_flight return result -@fast_flight_router.post("/stream") +@fast_flight_router.post("/stream", dependencies=[Security(get_api_key)]) async def read_data(body: bytes = Depends(body_bytes), ff_client: FastFlightBouncer = Depends(fast_flight_client)): """ Endpoint to read data from the Flight server and stream it back in Arrow format. + Requires API Key authentication. + Args: body (bytes): The raw request body bytes. The body should be a JSON-serialized `BaseParams` instance. Crucially, it must include the `param_type` field specifying the fully qualified name (FQN) of the data params class. @@ -46,3 +51,13 @@ async def read_data(body: bytes = Depends(body_bytes), ff_client: FastFlightBoun stream_reader = await ff_client.aget_stream_reader(body) stream = await write_arrow_data_to_stream(stream_reader) return StreamingResponse(stream, media_type="application/vnd.apache.arrow.stream") + + +@fast_flight_router.get("/health", status_code=200) +async def health_check(): + """ + Simple health check endpoint. + """ + # This endpoint being reachable means the FastAPI app is running. + # More sophisticated checks could be added here (e.g., DB connectivity, Flight server ping). 
+ return {"status": "healthy"} diff --git a/src/fastflight/fastapi/security.py b/src/fastflight/fastapi/security.py new file mode 100644 index 0000000..f98b3f7 --- /dev/null +++ b/src/fastflight/fastapi/security.py @@ -0,0 +1,32 @@ +from fastapi import HTTPException, Security +from fastapi.security.api_key import APIKeyHeader +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN + +from fastflight.config import fastapi_settings + +API_KEY_NAME = "X-API-Key" +api_key_header_auth = APIKeyHeader(name=API_KEY_NAME, auto_error=False) # auto_error=False to allow custom error + +async def get_api_key(api_key_header: str = Security(api_key_header_auth)): + """ + Dependency to validate the API key from the X-API-Key header. + Raises HTTPException if the key is missing or invalid. + """ + if not fastapi_settings.valid_api_keys: + # If no API keys are configured on the server, authentication is effectively disabled. + # Depending on policy, could also deny all requests if keys are expected but list is empty. + # For now, assume it means auth is optional/disabled. + return None # Or some indicator that auth is not enforced + + if not api_key_header: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated: X-API-Key header is missing." + ) + + if api_key_header not in fastapi_settings.valid_api_keys: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="Could not validate credentials: Invalid API Key." 
+ ) + return api_key_header # Return the key or a success indicator diff --git a/src/fastflight/metrics.py b/src/fastflight/metrics.py new file mode 100644 index 0000000..f6beab8 --- /dev/null +++ b/src/fastflight/metrics.py @@ -0,0 +1,90 @@ +from prometheus_client import Counter, Gauge, Histogram + +# --- Flight Server Metrics --- +flight_server_requests_total = Counter( + "flight_server_requests_total", + "Total number of requests to the Flight server.", + ["method", "status"] # e.g., method="do_get", status="success" / "error" +) + +flight_server_request_duration_seconds = Histogram( + "flight_server_request_duration_seconds", + "Histogram of Flight server request latencies.", + ["method"] # e.g., method="do_get" +) + +flight_server_active_connections = Gauge( + "flight_server_active_connections", + "Number of currently active connections on the Flight server." + # This might be hard to track accurately without deeper Flight integration. + # For now, it can be incremented on request start and decremented on request end for do_get. +) + +flight_server_bytes_transferred = Counter( + "flight_server_bytes_transferred", + "Total number of bytes transferred by the Flight server.", + ["method", "direction"] # e.g., method="do_get", direction="sent" / "received" + # This is challenging to implement accurately without deep hooks. + # We can count ticket bytes for "received" and RecordBatch bytes for "sent". +) + +# --- FastAPI Application Metrics --- +# These might be largely provided by starlette-prometheus if used. 
+# If implementing manually or needing additional custom metrics: +fastapi_requests_total = Counter( + "fastapi_requests_total", + "Total number of requests to the FastAPI application.", + ["method", "path", "status_code"] +) + +fastapi_request_duration_seconds = Histogram( + "fastapi_request_duration_seconds", + "Histogram of FastAPI request latencies.", + ["method", "path"] +) + +# --- FastFlightBouncer Metrics --- +bouncer_connections_acquired_total = Counter( + "bouncer_connections_acquired_total", + "Total number of connections acquired from the bouncer pool." +) + +bouncer_connections_released_total = Counter( + "bouncer_connections_released_total", + "Total number of connections released back to the bouncer pool." +) + +bouncer_pool_size = Gauge( + "bouncer_pool_size", + "Configured size of the bouncer connection pool." +) + +bouncer_pool_available_connections = Gauge( + "bouncer_pool_available_connections", + "Current number of available connections in the bouncer pool." +) + +bouncer_circuit_breaker_state = Gauge( + "bouncer_circuit_breaker_state", + "State of the circuit breaker.", + ["circuit_name", "state"], # state: 0 for closed, 1 for open, 2 for half-open +) + +bouncer_circuit_breaker_failures_total = Counter( + "bouncer_circuit_breaker_failures_total", + "Total number of failures tracked by the circuit breaker.", + ["circuit_name"] +) + +bouncer_circuit_breaker_successes_total = Counter( + "bouncer_circuit_breaker_successes_total", + "Total number of successes tracked by the circuit breaker (in half-open or closed state).", + ["circuit_name"] +) + +# Helper to map circuit breaker state string to a number for Prometheus +CIRCUIT_BREAKER_STATE_MAP = { + "closed": 0, + "open": 1, + "half-open": 2, +} diff --git a/src/fastflight/resilience/core/manager.py b/src/fastflight/resilience/core/manager.py index 065735f..b8ed06f 100644 --- a/src/fastflight/resilience/core/manager.py +++ b/src/fastflight/resilience/core/manager.py @@ -19,6 +19,13 @@ from 
..config import CircuitBreakerConfig, ResilienceConfig, RetryConfig from ..types import T from .circuit_breaker import CircuitBreaker +# Import Prometheus metrics for circuit breaker +from fastflight.metrics import ( + CIRCUIT_BREAKER_STATE_MAP, + bouncer_circuit_breaker_failures_total, + bouncer_circuit_breaker_state, + bouncer_circuit_breaker_successes_total, +) logger = logging.getLogger(__name__) @@ -65,7 +72,16 @@ def get_circuit_breaker(self, name: str, config: Optional[CircuitBreakerConfig] if name not in self.circuit_breakers: if config is None: raise ValueError(f"Circuit breaker '{name}' not found and no configuration provided") - self.circuit_breakers[name] = CircuitBreaker(name, config) + cb = CircuitBreaker(name, config) + self.circuit_breakers[name] = cb + # Initialize state metric for the new circuit breaker + current_state_str = cb.state.value # Assuming .value gives "closed", "open", "half-open" + bouncer_circuit_breaker_state.labels(circuit_name=name, state=current_state_str).set( + CIRCUIT_BREAKER_STATE_MAP.get(current_state_str, -1) # -1 for unknown/other + ) + # Initialize counters for this CB to zero if not already present (Prometheus handles this) + bouncer_circuit_breaker_failures_total.labels(circuit_name=name).inc(0) + bouncer_circuit_breaker_successes_total.labels(circuit_name=name).inc(0) return self.circuit_breakers[name] async def execute_with_resilience( @@ -105,14 +121,32 @@ async def execute_with_resilience( # Create a wrapped function that applies circuit breaker if asyncio.iscoroutinefunction(func): - async def circuit_wrapped_func(*a, **kw): - return await circuit_breaker.call(func, *a, **kw) - else: - - async def circuit_wrapped_func(*a, **kw): - return await circuit_breaker.call(func, *a, **kw) + # This common wrapper handles metrics for both async and sync funcs called via the async CB + async def common_circuit_wrapped_func_with_metrics(*a, **kw): + cb_name = circuit_breaker.name + try: + # The actual call to the function via 
the circuit breaker + # circuit_breaker.call is async and handles both async and sync original functions + result = await circuit_breaker.call(func, *a, **kw) + + # If circuit_breaker.call succeeds, it means the underlying func succeeded + # or the CB allowed the call and it succeeded. + bouncer_circuit_breaker_successes_total.labels(circuit_name=cb_name).inc() + except Exception as cb_exc: + # This exception could be from the func itself (if CB is closed/half-open and func fails, + # leading to CB counting a failure) or CircuitBreakerOpen if CB is open. + if not isinstance(cb_exc, asyncio.exceptions.CancelledError): # Don't count cancellations + bouncer_circuit_breaker_failures_total.labels(circuit_name=cb_name).inc() + raise # Re-raise the exception + finally: + # Always update the state gauge after a call, as the call might have changed the state + current_state_str = circuit_breaker.state.value + bouncer_circuit_breaker_state.labels(circuit_name=cb_name, state=current_state_str).set( + CIRCUIT_BREAKER_STATE_MAP.get(current_state_str, -1) + ) + return result # Return the result if no exception was raised or re-raised - wrapped_func = circuit_wrapped_func + wrapped_func = common_circuit_wrapped_func_with_metrics # Apply retry logic if configured if effective_config.retry_config: diff --git a/src/fastflight/security.py b/src/fastflight/security.py new file mode 100644 index 0000000..032de87 --- /dev/null +++ b/src/fastflight/security.py @@ -0,0 +1,79 @@ +import pyarrow.flight as fl + +class ServerAuthHandler(fl.ServerAuthHandler): + """A simple token-based authentication handler for the Flight server.""" + + def __init__(self, valid_tokens: list[str]): + super().__init__() + self.valid_tokens = {token.encode('utf-8') for token in valid_tokens} + if not valid_tokens: + print("Warning: ServerAuthHandler initialized with no valid tokens. All authentication will fail.") + + def authenticate(self, outgoing, incoming): + """ + Authenticates the client. 
+ Expects token to be sent by client via ClientBasicAuthHandler. + The token is available in `incoming.read()` + """ + auth_header = incoming.read() + if not auth_header: + raise fl.FlightUnauthenticatedError("No token provided.") + + # ClientBasicAuthHandler sends "Authorization: Basic " + # For simplicity, we'll assume the token is sent directly or we parse it. + # PyArrow's example for basic auth typically involves username/password. + # If ClientBasicAuthHandler sends username:token, then we need to adjust. + # Let's assume for now the token is sent as is, or is the "password" part of basic auth. + # A common pattern for token auth with basic auth is to use the token as the password, + # and often the username is ignored or a fixed string. + # PyArrow's ClientBasicAuth sends base64(username:password). + # If username is empty, it's base64(:token). + # The server side receives the decoded "username:password" string. + + # For now, let's assume the token is sent as the "password" field in Basic Auth, + # and the client uses an empty username. The `auth_header` would be `b":"`. + # Or, if the client sends only the token, it might be just `b""`. + + # Let's try to be flexible: check if the raw auth_header is a valid token, + # or if it's in the format b":" + + token_to_check = None + if auth_header.startswith(b':'): # Format from ClientBasicAuth(username="", password=token) + token_to_check = auth_header[1:] + else: # Assume raw token or some other format we might adapt to + token_to_check = auth_header + + if token_to_check and token_to_check in self.valid_tokens: + # Return the validated token as the peer identity + outgoing.write(token_to_check) # Send back the identity + return token_to_check # This becomes context.peer_identity + else: + if token_to_check: + print(f"Auth failed. Received token: {token_to_check.decode('utf-8', errors='replace')}") + else: + print("Auth failed. 
No token extracted from header.") + raise fl.FlightUnauthenticatedError("Invalid token.") + + def is_valid(self, token: bytes): + """ + Checks if the given token (peer identity) is still valid for subsequent actions. + The 'token' here is what was returned by `authenticate`. + """ + if not token: # Should not happen if authenticate did its job + raise fl.FlightUnauthenticatedError("No token associated with peer.") + if token not in self.valid_tokens: + print(f"Validation failed for token: {token.decode('utf-8', errors='replace')}") + raise fl.FlightUnauthenticatedError("Token is no longer valid.") + return None # Returning None (or any value) indicates success, exception indicates failure. + # The documentation is a bit sparse; official examples often return the token itself or some user identifier. + # For this method, raising an exception on invalid is the key. + # What is returned seems to be for the server's own use, not directly by Flight. + # Let's return None as per some interpretations that it's just a check. + # Update: is_valid should return a value that can be used by the server if needed, + # often the same peer identity. For now, returning None is fine if we don't use its return value. + # However, to be safe and align with some examples, let's return the token. + # return token + # Simpler: if it's not valid, we raise. If it is, we do nothing / return None implicitly. + # The method is more like "assert_is_valid". + # Let's stick to raising exception on invalid, and returning nothing (None) on valid. 
+ pass # Implicitly returns None diff --git a/src/fastflight/server.py b/src/fastflight/server.py index a660abe..f246c69 100644 --- a/src/fastflight/server.py +++ b/src/fastflight/server.py @@ -2,12 +2,22 @@ import logging import multiprocessing import sys -from typing import cast +import time # For timing +from typing import Optional, cast import pyarrow as pa from pyarrow import RecordBatchReader, flight +from fastflight.config import flight_server_settings from fastflight.core.base import BaseDataService, BaseParams +from fastflight.metrics import ( # Import metrics + flight_server_active_connections, + flight_server_bytes_transferred, + flight_server_request_duration_seconds, + flight_server_requests_total, +) +from fastflight.security import ServerAuthHandler +from fastflight.utils.custom_logging import setup_logging from fastflight.utils.debug import debuggable from fastflight.utils.stream_utils import AsyncToSyncConverter @@ -23,16 +33,23 @@ class FastFlightServer(flight.FlightServerBase): Attributes: location (str): The URI where the server is hosted. + auth_handler (Optional[ServerAuthHandler]): The server authentication handler. """ - def __init__(self, location: str): - super().__init__(location) - self.location = location + def __init__(self, location: str, auth_handler: Optional[ServerAuthHandler] = None): + # The location string for super().__init__ might need to be adjusted if TLS is enabled, + # e.g., from "grpc://host:port" to "grpc+tls://host:port" + # This will be handled in start_instance or main. 
+ super().__init__(location, auth_handler=auth_handler) + self.location = location # Store the original logical location + self._auth_handler = auth_handler self._converter = AsyncToSyncConverter() + # Initialize a counter for active do_get calls if not using more sophisticated connection tracking + self._active_do_get_calls = 0 def do_get(self, context, ticket: flight.Ticket) -> flight.RecordBatchStream: """ - Handles a data retrieval request from a client. + Handles a data retrieval request from a client, with Prometheus metrics. This method: - Parses the `ticket` to extract the request parameters. @@ -50,15 +67,41 @@ def do_get(self, context, ticket: flight.Ticket) -> flight.RecordBatchStream: flight.FlightUnavailableError: If the requested data service is not registered. flight.FlightInternalError: If an unexpected error occurs during retrieval. """ + method_name = "do_get" + start_time = time.monotonic() + flight_server_active_connections.inc() + self._active_do_get_calls += 1 # Manual tracking if gauge needs it per instance + try: - logger.debug("Received ticket: %s", ticket.ticket) + logger.debug("Received ticket (len): %s bytes", len(ticket.ticket) if ticket else 0) + flight_server_bytes_transferred.labels(method=method_name, direction="received").inc(len(ticket.ticket or b"")) + data_params, data_service = self._resolve_ticket(ticket) + # This is a RecordBatchReader; to count sent bytes, we'd need to iterate through it + # and sum byte sizes of batches, or wrap it. This is complex here. + # For now, we'll increment a placeholder for sent bytes or skip detailed byte counting for sent. + # Let's assume _get_batch_reader returns a custom reader that can track bytes if we go deep. + # As a simplification, we won't track sent bytes accurately here yet. + reader = self._get_batch_reader(data_service, data_params) + + # To accurately track sent bytes, we would need to wrap the reader or the stream. 
+ # For example, by creating a generator that yields batches and counts their size. + # This is an approximation for now, as actual sent bytes depend on client consumption. + # flight_server_bytes_transferred.labels(method=method_name, direction="sent").inc(APPROX_SIZE_OR_IMPLEMENT_TRACKING_READER) + + flight_server_requests_total.labels(method=method_name, status="success").inc() return flight.RecordBatchStream(reader) except Exception as e: + flight_server_requests_total.labels(method=method_name, status="error").inc() logger.error(f"Error processing request: {e}", exc_info=True) error_msg = f"Internal server error: {type(e).__name__}: {str(e)}" raise flight.FlightInternalError(error_msg) + finally: + duration = time.monotonic() - start_time + flight_server_request_duration_seconds.labels(method=method_name).observe(duration) + flight_server_active_connections.dec() + self._active_do_get_calls -= 1 def _get_batch_reader( self, data_service: BaseDataService, params: BaseParams, batch_size: int | None = None @@ -72,22 +115,41 @@ def _get_batch_reader( Returns: RecordBatchReader: A RecordBatchReader instance to read the data in batches. """ + # This is where actual data retrieval happens. + # For accurate sent byte counting, this method or the data_service methods would need to be modified + # to report the size of data produced. 
try: try: batch_iter = iter(data_service.get_batches(params, batch_size)) except NotImplementedError: batch_iter = self._converter.syncify_async_iter(data_service.aget_batches(params, batch_size)) - first = next(batch_iter) - return RecordBatchReader.from_batches(first.schema, itertools.chain((first,), batch_iter)) + first_batch = next(batch_iter) + # Example: Approximate sent bytes based on first batch (very rough) + # flight_server_bytes_transferred.labels(method="do_get", direction="sent").inc(first_batch.nbytes) + + # To count all bytes, we'd need a wrapper: + # def byte_counting_iterator(original_iterator): + # total_bytes = 0 + # for batch in original_iterator: + # total_bytes += batch.nbytes + # yield batch + # flight_server_bytes_transferred.labels(method="do_get", direction="sent").inc(total_bytes) + # chained_iterator = byte_counting_iterator(itertools.chain((first_batch,), batch_iter)) + # return RecordBatchReader.from_batches(first_batch.schema, chained_iterator) + + return RecordBatchReader.from_batches(first_batch.schema, itertools.chain((first_batch,), batch_iter)) except StopIteration: + logger.warning("Data service returned no batches for params: %s", params.fqn()) raise flight.FlightInternalError("Data service returned no batches.") - except AttributeError as e: + except AttributeError as e: # E.g. 
if data_service doesn't have get_batches or aget_batches + logger.error(f"Service method issue with {data_service.fqn()}: {e}", exc_info=True) raise flight.FlightInternalError(f"Service method issue: {e}") - except Exception as e: + except Exception as e: # Other data retrieval errors logger.error(f"Error retrieving data from {data_service.fqn()}: {e}", exc_info=True) raise flight.FlightInternalError(f"Error in data retrieval: {type(e).__name__}: {str(e)}") + @staticmethod def _resolve_ticket(ticket: flight.Ticket) -> tuple[BaseParams, BaseDataService]: try: @@ -113,20 +175,70 @@ def shutdown(self): super().shutdown() @classmethod - def start_instance(cls, location: str, debug: bool = False): - server = cls(location) - logger.info("Serving FastFlightServer in process %s", multiprocessing.current_process().name) - if debug or sys.gettrace() is not None: - logger.info("Enabling debug mode") + def start_instance( + cls, + host: str, + port: int, + auth_handler: Optional[ServerAuthHandler] = None, + tls_info: Optional[flight.ServerTLSInfo] = None, + debug: bool = False, + ): + scheme = "grpc+tls" if tls_info else "grpc" + location = f"{scheme}://{host}:{port}" + + server = cls(location, auth_handler=auth_handler) + logger.info( + "Serving FastFlightServer in process %s on %s (Auth: %s, TLS: %s)", + multiprocessing.current_process().name, + location, + "Enabled" if auth_handler else "Disabled", + "Enabled" if tls_info else "Disabled", + ) + + if debug or sys.gettrace() is not None or flight_server_settings.log_level.upper() == "DEBUG": + logger.info("Enabling debug mode for FastFlightServer.do_get") server.do_get = debuggable(server.do_get) # type: ignore[method-assign] - server.serve() + + server.serve(tls_info=tls_info, auth_handler=auth_handler) # Pass auth_handler to serve too def main(): - from fastflight.utils.custom_logging import setup_logging - - setup_logging() - FastFlightServer.start_instance("grpc://0.0.0.0:8815", True) + 
setup_logging(service_name="FastFlightServer") + + auth_handler_instance: Optional[ServerAuthHandler] = None + if flight_server_settings.auth_token: + logger.info("Authentication enabled for Flight Server.") + # For multiple tokens, ServerAuthHandler would need to be initialized with a list + auth_handler_instance = ServerAuthHandler(valid_tokens=[flight_server_settings.auth_token]) + else: + logger.info("Authentication disabled for Flight Server (no auth_token configured).") + + tls_info_instance: Optional[flight.ServerTLSInfo] = None + if flight_server_settings.tls_server_cert_path and flight_server_settings.tls_server_key_path: + logger.info("TLS enabled for Flight Server.") + with open(flight_server_settings.tls_server_cert_path, 'rb') as cert_file, \ + open(flight_server_settings.tls_server_key_path, 'rb') as key_file: + tls_info_instance = flight.ServerTLSInfo( + cert_chain=cert_file.read(), + private_key=key_file.read() + ) + else: + logger.info("TLS disabled for Flight Server (cert_path or key_path not configured).") + if flight_server_settings.tls_server_cert_path or flight_server_settings.tls_server_key_path: + logger.warning("TLS partially configured but not enabled: both cert and key paths are required.") + + + logger.info( + f"Starting FastFlight server with settings: host={flight_server_settings.host}, " + f"port={flight_server_settings.port}, log_level={flight_server_settings.log_level}" + ) + + FastFlightServer.start_instance( + host=flight_server_settings.host, + port=flight_server_settings.port, + auth_handler=auth_handler_instance, + tls_info=tls_info_instance, + ) if __name__ == "__main__": diff --git a/src/fastflight/test_logging.py b/src/fastflight/test_logging.py new file mode 100644 index 0000000..5a2a0e5 --- /dev/null +++ b/src/fastflight/test_logging.py @@ -0,0 +1,40 @@ +import logging +from src.fastflight.config import logging_settings +from src.fastflight.utils.custom_logging import setup_logging + +def run_logging_test(): + # Override 
settings for this test + logging_settings.log_format = "json" + logging_settings.log_level = "DEBUG" # Ensure DEBUG is effective + + # Setup logging using the overridden settings + # The setup_logging function in custom_logging.py should use these global settings + setup_logging( + console_log_level=logging_settings.log_level.upper(), # Pass directly to ensure it overrides defaults + file_format=logging_settings.log_format # type: ignore + ) + + logger = logging.getLogger("my_test_logger") + + logger.debug("This is a debug message from the test script.", extra={"key1": "value1", "num_key": 123}) + logger.info("This is an info message from the test script.", extra={"complex_key": {"k": "v"}}) + logger.warning("This is a warning from the test script.") + logger.error("This is an error message from the test script.") + + try: + x = 1 / 0 + except ZeroDivisionError as e: + # Log the exception info + logger.error("An exception occurred in the test script", exc_info=True, extra={"exception_test": True}) + # Also test logging the exception object directly (structlog might handle this) + # logger.error("Logging exception object directly", exception=e) # structlog specific, might need different handling + + # Test a message that might use specific structlog features if it was a structlog logger + # For stdlib logger with structlog backend, it will go through standard formatting + # structlog_logger = structlog.get_logger("my_structlog_test_logger") + # structlog_logger.info("A message with structlog specific field", structlog_field="structlog_value") + + print("Logging test finished. 
Check output above.") + +if __name__ == "__main__": + run_logging_test() diff --git a/src/fastflight/utils/custom_logging.py b/src/fastflight/utils/custom_logging.py index c0ae627..425d963 100644 --- a/src/fastflight/utils/custom_logging.py +++ b/src/fastflight/utils/custom_logging.py @@ -7,6 +7,7 @@ from typing import Literal import structlog +from fastflight.config import logging_settings shared_processors = [ # If log level is too low, abort pipeline and throw away log entry. @@ -38,10 +39,10 @@ def setup_logging( - console_log_level: str | int = "DEBUG", - log_file: None | Path | str = "app.log", - file_log_level: str | int = "INFO", - file_format: Literal["plain", "json"] = "plain", + console_log_level: str | int = logging_settings.log_level.upper(), + log_file: None | Path | str = "app.log", # Keep app.log as default, can be overridden by CLI or direct calls + file_log_level: str | int = logging_settings.log_level.upper(), + file_format: Literal["plain", "json"] = logging_settings.log_format, # type: ignore ): """ Set up the logging configuration for the application. diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..ff1b0d7 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,136 @@ +import os +import pytest +from unittest import mock + +# Before importing the config module, ensure any pre-existing relevant env vars are cleared +# or handled if they could interfere with tests. This is tricky globally. +# We'll rely on mock.patch.dict to set them specifically for each test. 
+ +from src.fastflight.config import ( + LoggingSettings, + FlightServerSettings, + FastAPISettings, + BouncerSettings +) + +@pytest.fixture(autouse=True) +def clear_env_vars(): + """Clear relevant environment variables before each test and restore after.""" + env_vars_to_manage = [ + "FASTFLIGHT_LOGGING_LOG_LEVEL", "FASTFLIGHT_LOGGING_LOG_FORMAT", + "FASTFLIGHT_SERVER_HOST", "FASTFLIGHT_SERVER_PORT", "FASTFLIGHT_SERVER_LOG_LEVEL", + "FASTFLIGHT_SERVER_AUTH_TOKEN", "FASTFLIGHT_SERVER_TLS_SERVER_CERT_PATH", "FASTFLIGHT_SERVER_TLS_SERVER_KEY_PATH", + "FASTFLIGHT_API_HOST", "FASTFLIGHT_API_PORT", "FASTFLIGHT_API_LOG_LEVEL", + "FASTFLIGHT_API_FLIGHT_SERVER_LOCATION", "FASTFLIGHT_API_VALID_API_KEYS", + "FASTFLIGHT_API_SSL_KEYFILE", "FASTFLIGHT_API_SSL_CERTFILE", "FASTFLIGHT_API_METRICS_ENABLED", + "FASTFLIGHT_BOUNCER_POOL_SIZE" + ] + original_values = {var: os.environ.get(var) for var in env_vars_to_manage} + + # Clear them for the test + for var in env_vars_to_manage: + if var in os.environ: + del os.environ[var] + + yield # Test runs here + + # Restore original values + for var, original_value in original_values.items(): + if original_value is not None: + os.environ[var] = original_value + elif var in os.environ: # If it was set during test but not originally + del os.environ[var] + + +def test_logging_settings_defaults(): + settings = LoggingSettings() + assert settings.log_level == "INFO" + assert settings.log_format == "plain" + +def test_logging_settings_from_env(): + with mock.patch.dict(os.environ, { + "FASTFLIGHT_LOGGING_LOG_LEVEL": "DEBUG", + "FASTFLIGHT_LOGGING_LOG_FORMAT": "json" + }): + settings = LoggingSettings() + assert settings.log_level == "DEBUG" + assert settings.log_format == "json" + +def test_flight_server_settings_defaults(): + settings = FlightServerSettings() + assert settings.host == "0.0.0.0" + assert settings.port == 8815 + assert settings.log_level == "INFO" + assert settings.auth_token is None + assert settings.tls_server_cert_path is 
None + assert settings.tls_server_key_path is None + +def test_flight_server_settings_from_env(): + with mock.patch.dict(os.environ, { + "FASTFLIGHT_SERVER_HOST": "127.0.0.1", + "FASTFLIGHT_SERVER_PORT": "9000", + "FASTFLIGHT_SERVER_LOG_LEVEL": "WARNING", + "FASTFLIGHT_SERVER_AUTH_TOKEN": "test_token", + "FASTFLIGHT_SERVER_TLS_SERVER_CERT_PATH": "/path/to/cert.pem", + "FASTFLIGHT_SERVER_TLS_SERVER_KEY_PATH": "/path/to/key.pem" + }): + settings = FlightServerSettings() + assert settings.host == "127.0.0.1" + assert settings.port == 9000 + assert settings.log_level == "WARNING" + assert settings.auth_token == "test_token" + assert settings.tls_server_cert_path == "/path/to/cert.pem" + assert settings.tls_server_key_path == "/path/to/key.pem" + +def test_fastapi_settings_defaults(): + settings = FastAPISettings() + assert settings.host == "0.0.0.0" + assert settings.port == 8000 + assert settings.log_level == "INFO" + assert settings.flight_server_location == "grpc://localhost:8815" + assert settings.valid_api_keys == [] + assert settings.ssl_keyfile is None + assert settings.ssl_certfile is None + assert settings.metrics_enabled is True + +def test_fastapi_settings_from_env(): + with mock.patch.dict(os.environ, { + "FASTFLIGHT_API_HOST": "127.0.0.2", + "FASTFLIGHT_API_PORT": "8001", + "FASTFLIGHT_API_LOG_LEVEL": "CRITICAL", + "FASTFLIGHT_API_FLIGHT_SERVER_LOCATION": "grpc://otherhost:1234", + "FASTFLIGHT_API_VALID_API_KEYS": "key1,key2, key3", # Test with spaces + "FASTFLIGHT_API_SSL_KEYFILE": "/ssl/key.pem", + "FASTFLIGHT_API_SSL_CERTFILE": "/ssl/cert.pem", + "FASTFLIGHT_API_METRICS_ENABLED": "false" # Test boolean parsing + }): + settings = FastAPISettings() + assert settings.host == "127.0.0.2" + assert settings.port == 8001 + assert settings.log_level == "CRITICAL" + assert settings.flight_server_location == "grpc://otherhost:1234" + assert settings.valid_api_keys == ["key1", "key2", "key3"] + assert settings.ssl_keyfile == "/ssl/key.pem" + assert 
settings.ssl_certfile == "/ssl/cert.pem" + assert settings.metrics_enabled is False # Pydantic automatically converts "false" + +def test_fastapi_settings_empty_api_keys_from_env(): + with mock.patch.dict(os.environ, {"FASTFLIGHT_API_VALID_API_KEYS": ""}): + settings = FastAPISettings() + # Pydantic v2 by default might convert "" to [''] for List[str] if not handled. + # However, if the default is [], and the field is Optional or has a default_factory, + # it might result in []. Let's verify Pydantic's behavior for comma-separated strings. + # For pydantic_settings and comma-separated lists, an empty string usually results in an empty list + # if the list items are simple strings. If it becomes `['']`, the test needs adjustment or the model needs refinement. + # Based on typical pydantic-settings behavior for `list[str]`, an empty string for the env var + # should result in an empty list, not `['']`. + assert settings.valid_api_keys == [] + +def test_bouncer_settings_defaults(): + settings = BouncerSettings() + assert settings.pool_size == 10 + +def test_bouncer_settings_from_env(): + with mock.patch.dict(os.environ, {"FASTFLIGHT_BOUNCER_POOL_SIZE": "20"}): + settings = BouncerSettings() + assert settings.pool_size == 20 diff --git a/tests/test_fastapi_auth.py b/tests/test_fastapi_auth.py new file mode 100644 index 0000000..1c7580c --- /dev/null +++ b/tests/test_fastapi_auth.py @@ -0,0 +1,104 @@ +import pytest +from fastapi.testclient import TestClient +from unittest import mock +import os + +# Assuming src.fastflight.fastapi.app.create_app is the entry point +# and src.fastflight.config.fastapi_settings is the global settings instance used by the app +from src.fastflight.fastapi.app import create_app +from src.fastflight.config import fastapi_settings as global_fastapi_settings +from src.fastflight.fastapi.security import API_KEY_NAME # To use the correct header name + +# Minimal list of module paths for testing, assuming demo_services has what's needed +# or that 
# the specific endpoints tested don't rely on deep service discovery.
TEST_MODULE_PATHS = ["src.fastflight.demo_services"]
TEST_API_KEY_VALID = "test-api-key-valid"
TEST_API_KEY_INVALID = "test-api-key-invalid"


@pytest.fixture(scope="module")
def client_with_api_key_auth():
    """TestClient whose app accepts exactly one valid API key.

    Settings restoration runs in ``finally`` so later tests see pristine
    settings even if app creation fails before the ``yield``.
    """
    original_valid_keys = global_fastapi_settings.valid_api_keys
    global_fastapi_settings.valid_api_keys = [TEST_API_KEY_VALID]
    try:
        app = create_app(module_paths=list(TEST_MODULE_PATHS))
        yield TestClient(app)
    finally:
        global_fastapi_settings.valid_api_keys = original_valid_keys


@pytest.fixture(scope="module")
def client_with_no_api_keys_configured():
    """TestClient whose app has no API keys configured (auth effectively disabled by policy)."""
    original_valid_keys = global_fastapi_settings.valid_api_keys
    global_fastapi_settings.valid_api_keys = []
    try:
        app = create_app(module_paths=list(TEST_MODULE_PATHS))
        yield TestClient(app)
    finally:
        global_fastapi_settings.valid_api_keys = original_valid_keys


# /fastflight/registered_data_types and /fastflight/stream are protected;
# /fastflight/health is not. GET endpoints keep the auth-header testing simple
# (POST to /fastflight/stream would also need a valid body).
PROTECTED_ENDPOINTS_GET = ["/fastflight/registered_data_types"]


@pytest.mark.parametrize("endpoint_path", PROTECTED_ENDPOINTS_GET)
def test_fastapi_auth_failure_no_key_header(client_with_api_key_auth: TestClient, endpoint_path: str):
    # get_api_key raises 401 when the header is missing and keys are configured.
    response = client_with_api_key_auth.get(endpoint_path)
    assert response.status_code == 401
    assert "Not authenticated" in response.json()["detail"]


@pytest.mark.parametrize("endpoint_path", PROTECTED_ENDPOINTS_GET)
def test_fastapi_auth_failure_invalid_key(client_with_api_key_auth: TestClient, endpoint_path: str):
    response = client_with_api_key_auth.get(endpoint_path, headers={API_KEY_NAME: TEST_API_KEY_INVALID})
    assert response.status_code == 403
    assert "Invalid API Key" in response.json()["detail"]


@pytest.mark.parametrize("endpoint_path", PROTECTED_ENDPOINTS_GET)
def test_fastapi_auth_success_valid_key(client_with_api_key_auth: TestClient, endpoint_path: str):
    response = client_with_api_key_auth.get(endpoint_path, headers={API_KEY_NAME: TEST_API_KEY_VALID})
    assert response.status_code == 200


@pytest.mark.parametrize("endpoint_path", PROTECTED_ENDPOINTS_GET)
def test_fastapi_auth_effectively_disabled_if_no_keys_configured(
    client_with_no_api_keys_configured: TestClient, endpoint_path: str
):
    # get_api_key returns None (allowing access) when valid_api_keys is empty.
    response = client_with_no_api_keys_configured.get(endpoint_path)
    assert response.status_code == 200

    response_with_random_key = client_with_no_api_keys_configured.get(
        endpoint_path, headers={API_KEY_NAME: "random-key-should-still-pass"}
    )
    assert response_with_random_key.status_code == 200


def test_fastapi_health_endpoint_no_auth_needed(client_with_api_key_auth: TestClient):
    # /fastflight/health must stay reachable without an API key.
    response = client_with_api_key_auth.get("/fastflight/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}


def test_fastapi_health_endpoint_no_auth_needed_when_keys_not_set(client_with_no_api_keys_configured: TestClient):
    response = client_with_no_api_keys_configured.get("/fastflight/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}


# TODO: a POST test for /fastflight/stream needs a valid serialized BaseParams
# body (and, unless the bouncer is mocked, a live Flight server). Header-based
# auth coverage via GET suffices for now.
# def test_fastapi_auth_success_stream_post(client_with_api_key_auth: TestClient):
#     pass

# --- tests/test_flight_auth.py ---
import pytest
import pyarrow as pa
import pyarrow.flight as fl
import threading
import time
import os
from unittest import mock

from src.fastflight.server import FastFlightServer
from src.fastflight.client import FastFlightBouncer
from src.fastflight.security import ServerAuthHandler
from src.fastflight.config import FlightServerSettings, flight_server_settings as global_flight_server_settings
from src.fastflight.core.base import BaseDataService, BaseParams, DataServiceCls, ParamsCls


# Minimal data service used by the auth tests.
class PingParams(BaseParams):
    pass


@BaseDataService._register(PingParams)  # manual registration for the test
class PingService(BaseDataService[PingParams]):
    def get_batches(self, params: PingParams, batch_size: int | None = None) -> pa.RecordBatchReader:
        data = [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])]
        schema = pa.schema([("col1", pa.int64()), ("col2", pa.string())])
        batch = pa.record_batch(data, schema=schema)
        return pa.RecordBatchReader.from_batches(schema, [batch])

    async def aget_batches(self, params: PingParams, batch_size: int | None = None):
        # Async variant simply re-yields the sync reader's batches.
        for batch in self.get_batches(params, batch_size):
            yield batch


TEST_HOST = "127.0.0.1"
TEST_PORT = 8890  # dedicated port so tests don't collide with a dev server
TEST_TOKEN = "test-auth-token-123"


@pytest.fixture(scope="module")
def flight_server_with_auth():
    """Run a token-authenticated Flight server in a daemon thread for this module.

    The server reads the (patched) global flight_server_settings when it starts,
    so the settings are patched before launch and restored in ``finally`` —
    even if startup fails before the ``yield``.
    """
    original_token = global_flight_server_settings.auth_token
    original_host = global_flight_server_settings.host
    original_port = global_flight_server_settings.port

    global_flight_server_settings.auth_token = TEST_TOKEN
    global_flight_server_settings.host = TEST_HOST
    global_flight_server_settings.port = TEST_PORT

    try:
        from src.fastflight.server import main as flight_server_main_actual

        # server.serve() blocks, so run it on a daemon thread; the thread dies
        # with the test process (graceful shutdown would need server.shutdown()
        # or a multiprocessing.Process with terminate()).
        server_thread = threading.Thread(target=flight_server_main_actual, daemon=True)
        server_thread.start()
        time.sleep(1.0)  # crude wait for the listener to come up

        # Probe the port: connection refused / deadline exceeded means the
        # server never started; an auth error means it is up, which is fine.
        try:
            # NOTE(review): confirm this pyarrow version accepts a `timeout`
            # kwarg on flight.connect — some versions only take generic_options.
            client = fl.connect(f"grpc://{TEST_HOST}:{TEST_PORT}", timeout=1)
            client.close()
        except Exception as e:
            if "Connection refused" in str(e) or "Deadline Exceeded" in str(e):
                pytest.fail(f"Flight server did not start on {TEST_HOST}:{TEST_PORT}: {e}")
            print(f"Flight server startup check got client connection error (potentially expected): {e}")

        yield f"grpc://{TEST_HOST}:{TEST_PORT}"
    finally:
        # Restore the patched global settings no matter what happened above.
        global_flight_server_settings.auth_token = original_token
        global_flight_server_settings.host = original_host
        global_flight_server_settings.port = original_port


def test_flight_auth_failure_no_token(flight_server_with_auth):
    location = flight_server_with_auth
    bouncer = FastFlightBouncer(flight_server_location=location)  # no auth_token
    params = PingParams()

    try:
        with pytest.raises(fl.FlightUnauthenticatedError) as exc_info:
            bouncer.get_pa_table(params)
        # ServerAuthHandler.authenticate raises "No token provided." when the
        # incoming header is empty; "Invalid token" otherwise.
        assert "No token provided" in str(exc_info.value) or "Invalid token" in str(exc_info.value)
    finally:
        bouncer.close_async_context_manager_sync_only()


def test_flight_auth_failure_incorrect_token(flight_server_with_auth):
    bouncer = FastFlightBouncer(flight_server_location=flight_server_with_auth, auth_token="incorrect-token-value")
    try:
        with pytest.raises(fl.FlightUnauthenticatedError) as exc_info:
            bouncer.get_pa_table(PingParams())
        assert "Invalid token" in str(exc_info.value)
    finally:
        bouncer.close_async_context_manager_sync_only()


def test_flight_auth_success(flight_server_with_auth):
    bouncer = FastFlightBouncer(flight_server_location=flight_server_with_auth, auth_token=TEST_TOKEN)
    try:
        table = bouncer.get_pa_table(PingParams())
        assert table is not None
        assert len(table) == 3
        assert table.column_names == ["col1", "col2"]
    finally:
        bouncer.close_async_context_manager_sync_only()


def close_bouncer_sync(bouncer: FastFlightBouncer):
    """Synchronously close the bouncer's async converter resources (test-only helper)."""
    if hasattr(bouncer, '_converter') and bouncer._converter:
        # Only close if the converter's event loop/thread is actually running.
        if hasattr(bouncer._converter, 'loop') and bouncer._converter.loop.is_running():
            bouncer._converter.close()


# Monkey-patched onto the class purely for synchronous test cleanup.
FastFlightBouncer.close_async_context_manager_sync_only = close_bouncer_sync

# Keep the global registries clean across the session in case these tests run
# repeatedly or other tests also touch the shared registries.
@pytest.fixture(autouse=True, scope="session")
def cleanup_registries():
    """Remove the test-only PingParams/PingService registry entries after the session."""
    yield
    if PingParams.fqn() in BaseParams.registry:
        del BaseParams.registry[PingParams.fqn()]
    if PingParams.fqn() in BaseDataService._registry:  # protected access is acceptable for test cleanup
        del BaseDataService._registry[PingParams.fqn()]

# --- tests/test_health_check.py ---
import pytest
from fastapi.testclient import TestClient

from src.fastflight.fastapi.app import create_app
from src.fastflight.config import fastapi_settings as global_fastapi_settings

# Minimal list of module paths for testing.
TEST_MODULE_PATHS = ["src.fastflight.demo_services"]


@pytest.fixture(scope="module")
def client():
    # Health checks don't depend on API keys or metrics, so default settings suffice.
    app = create_app(module_paths=list(TEST_MODULE_PATHS))
    yield TestClient(app)


def test_health_endpoint(client: TestClient):
    """/fastflight/health returns 200 OK with a fixed JSON body."""
    response = client.get("/fastflight/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}


def test_health_endpoint_trailing_slash(client: TestClient):
    """
    /fastflight/health/ (trailing slash) also ends in 200: FastAPI redirects to
    the slash-less route and TestClient follows redirects by default. To pin
    the redirect itself, use TestClient(app, follow_redirects=False) and
    assert 307/308 instead.
    """
    response = client.get("/fastflight/health/")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}

# --- tests/test_integration_auth.py ---
import pytest
from fastapi.testclient import TestClient
import threading
import time
import os
from unittest import mock
import pyarrow as pa
import pyarrow.flight  # required: `import pyarrow` alone does not expose pa.flight

from src.fastflight.config import (
    FlightServerSettings,
    FastAPISettings,
    flight_server_settings as global_flight_server_settings_for_integration,  # alias to avoid clash
    fastapi_settings as global_fastapi_settings_for_integration,  # alias
)
from src.fastflight.server import FastFlightServer, main as flight_server_main_runner
from src.fastflight.fastapi.app import create_app
from src.fastflight.security import ServerAuthHandler
from src.fastflight.fastapi.security import API_KEY_NAME
from src.fastflight.core.base import BaseParams, BaseDataService  # for type hints and cleanup

# Reuse the test service from test_flight_auth when possible; redefine inline
# when the relative import is unavailable (e.g. direct file invocation).
try:
    from .test_flight_auth import PingParams, PingService
except ImportError:
    class PingParams(BaseParams):  # type: ignore
        pass

    @BaseDataService._register(PingParams)
    class PingService(BaseDataService[PingParams]):  # type: ignore
        def get_batches(self, params: PingParams, batch_size: int | None = None) -> pa.RecordBatchReader:
            data = [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])]
            schema = pa.schema([("col1", pa.int64()), ("col2", pa.string())])
            batch = pa.record_batch(data, schema=schema)
            return pa.RecordBatchReader.from_batches(schema, [batch])

        async def aget_batches(self, params: PingParams, batch_size: int | None = None):
            for batch in self.get_batches(params, batch_size):
                yield batch

# Ports/tokens distinct from the other test modules so parallel runs don't collide.
INTEGRATION_FLIGHT_HOST = "127.0.0.1"
INTEGRATION_FLIGHT_PORT = 8891  # different from test_flight_auth
INTEGRATION_FLIGHT_TOKEN = "integration-flight-token-secure"

INTEGRATION_API_HOST = "127.0.0.1"  # TestClient doesn't use this directly
INTEGRATION_API_PORT = 8001  # different from other tests
INTEGRATION_API_KEY = "integration-api-key-valid"

# Module paths for FastAPI app creation; include this module so a locally
# redefined PingService is discoverable.
INTEGRATION_MODULE_PATHS = ["src.fastflight.demo_services", "tests.test_integration_auth"]


@pytest.fixture(scope="module")
def running_flight_server_for_integration():
    """Token-authenticated Flight server on a daemon thread for the integration tests.

    Patched global settings are restored in ``finally`` even if startup fails.
    """
    original_settings = {
        "auth_token": global_flight_server_settings_for_integration.auth_token,
        "host": global_flight_server_settings_for_integration.host,
        "port": global_flight_server_settings_for_integration.port,
    }
    global_flight_server_settings_for_integration.auth_token = INTEGRATION_FLIGHT_TOKEN
    global_flight_server_settings_for_integration.host = INTEGRATION_FLIGHT_HOST
    global_flight_server_settings_for_integration.port = INTEGRATION_FLIGHT_PORT

    try:
        # Re-register PingService only if the local fallback definition is in play.
        if 'PingService' in globals() and PingParams.fqn() not in BaseDataService._registry:
            BaseDataService._register(PingParams, PingService)  # type: ignore

        server_thread = threading.Thread(target=flight_server_main_runner, daemon=True)
        server_thread.start()
        time.sleep(1.5)  # give the server time to start; increase if flaky

        # Probe the port: connection refused / deadline exceeded means the
        # server never started; auth errors are expected and mean it is up.
        try:
            fl_client = pa.flight.connect(f"grpc://{INTEGRATION_FLIGHT_HOST}:{INTEGRATION_FLIGHT_PORT}", timeout=1)
            fl_client.close()
        except Exception as e:
            if "Connection refused" in str(e) or "Deadline Exceeded" in str(e):
                pytest.fail(f"Integration Flight server did not start: {e}")
            print(f"Integration Flight server check connection got error (may be expected auth issue): {e}")

        yield f"grpc://{INTEGRATION_FLIGHT_HOST}:{INTEGRATION_FLIGHT_PORT}"
    finally:
        # Restore patched settings; the server thread is daemonized (no clean shutdown).
        for key, value in original_settings.items():
            setattr(global_flight_server_settings_for_integration, key, value)
@pytest.fixture(scope="module")
def integration_test_client(running_flight_server_for_integration):
    """TestClient wired to the live Flight server, with API-key auth enabled.

    NOTE(review): this fixture assumes the app's lifespan initializes its
    FastFlightBouncer with INTEGRATION_FLIGHT_TOKEN — e.g. via a new
    FastAPISettings field (such as `flight_client_auth_token`) that
    fast_flight_client_lifespan reads when constructing the bouncer. The
    current lifespan does not pass `auth_token` to FastFlightBouncer, so
    without that mechanism the e2e test will see the Flight server reject the
    call, surfaced by FastAPI as an HTTP error (e.g. 500). Confirm/implement
    in `src/fastflight/fastapi/lifespan.py`.
    """
    flight_server_loc = running_flight_server_for_integration

    original_api_settings = {
        "valid_api_keys": global_fastapi_settings_for_integration.valid_api_keys,
        "flight_server_location": global_fastapi_settings_for_integration.flight_server_location,
    }
    global_fastapi_settings_for_integration.valid_api_keys = [INTEGRATION_API_KEY]
    global_fastapi_settings_for_integration.flight_server_location = flight_server_loc

    # Temporary attribute the lifespan is assumed to consume for the bouncer's token.
    setattr(global_fastapi_settings_for_integration, 'flight_client_auth_token', INTEGRATION_FLIGHT_TOKEN)

    app = create_app(
        module_paths=list(INTEGRATION_MODULE_PATHS),
    )
    client = TestClient(app)

    yield client

    # Restore the patched settings and drop the temporary attribute.
    for key, value in original_api_settings.items():
        setattr(global_fastapi_settings_for_integration, key, value)
    if hasattr(global_fastapi_settings_for_integration, 'flight_client_auth_token'):
        delattr(global_fastapi_settings_for_integration, 'flight_client_auth_token')

    # Clean up registries for the locally defined PingService, if any.
    if 'PingService' in globals():
        if PingParams.fqn() in BaseParams.registry:
            del BaseParams.registry[PingParams.fqn()]
        if PingParams.fqn() in BaseDataService._registry:
            del BaseDataService._registry[PingParams.fqn()]


def test_integration_e2e_authenticated_stream(integration_test_client: TestClient):
    """End-to-end: an API-key-authenticated POST streams Arrow data through the bouncer.

    Assumes the app's bouncer authenticates to the Flight server with
    INTEGRATION_FLIGHT_TOKEN (see the fixture's NOTE).
    """
    ping_request_body = PingParams().model_dump_json()  # JSON string body

    response = integration_test_client.post(
        "/fastflight/stream",
        content=ping_request_body,  # raw JSON string as the request content
        headers={
            API_KEY_NAME: INTEGRATION_API_KEY,
            "Content-Type": "application/json",
        },
    )

    assert response.status_code == 200
    assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"

    # Decode the Arrow IPC stream and verify the PingService payload.
    try:
        reader = pa.ipc.RecordBatchStreamReader(response.content)
        table = reader.read_all()
        assert table is not None
        assert len(table) == 3
        assert table.column_names == ["col1", "col2"]
        assert table.column("col1").to_pylist() == [1, 2, 3]
        assert table.column("col2").to_pylist() == ["a", "b", "c"]
    except Exception as e:
        pytest.fail(f"Error reading Arrow stream from response: {e}\nResponse content: {response.content[:500]}")

# NOTE(review): success here depends on fast_flight_client_lifespan passing an
# auth token to FastFlightBouncer. Globally patching FastFlightBouncer.__init__
# would be too invasive; the clean fix is a lifespan/config mechanism.
# For this test, we assume such a mechanism exists or will be added,
# e.g. `FastAPISettings` could have `flight_server_client_token` used by lifespan.

# --- tests/test_metrics.py ---
import pytest
from fastapi.testclient import TestClient
import time

from src.fastflight.fastapi.app import create_app
from src.fastflight.config import fastapi_settings as global_fastapi_settings

# Minimal list of module paths for testing.
TEST_MODULE_PATHS = ["src.fastflight.demo_services"]


@pytest.fixture(scope="module")
def client_with_metrics_enabled():
    """TestClient whose app was created with metrics turned on."""
    previous = global_fastapi_settings.metrics_enabled
    global_fastapi_settings.metrics_enabled = True
    app = create_app(module_paths=list(TEST_MODULE_PATHS))
    yield TestClient(app)
    global_fastapi_settings.metrics_enabled = previous


@pytest.fixture(scope="module")
def client_with_metrics_disabled():
    """TestClient whose app was created with metrics turned off."""
    previous = global_fastapi_settings.metrics_enabled
    global_fastapi_settings.metrics_enabled = False
    app = create_app(module_paths=list(TEST_MODULE_PATHS))
    yield TestClient(app)
    global_fastapi_settings.metrics_enabled = previous


def test_metrics_endpoint_available_when_enabled(client_with_metrics_enabled: TestClient):
    response = client_with_metrics_enabled.get("/metrics")
    assert response.status_code == 200
    # Look for typical Prometheus exposition content.
    assert "prometheus_client" in response.text or "starlette_requests_total" in response.text


def test_metrics_endpoint_not_available_when_disabled(client_with_metrics_disabled: TestClient):
    # The /metrics route is simply not registered when metrics are disabled.
    response = client_with_metrics_disabled.get("/metrics")
    assert response.status_code == 404


def test_basic_fastapi_request_metric_increment(client_with_metrics_enabled: TestClient):
    """starlette_requests_total for the health endpoint increments after one GET.

    Can be flaky if other activity touches the same metric between scrapes.
    """

    def get_metric_value(metrics_text: str, metric_name: str, labels: dict = None):
        # Simplified parser for the Prometheus text format: `name{labels} value`.
        for line in metrics_text.splitlines():
            if line.startswith("#") or not line.strip():
                continue
            name_part = line.split(" ")[0]
            value_part = line.split(" ")[1]
            current_metric_name = name_part.split("{")[0] if "{" in name_part else name_part
            if current_metric_name != metric_name:
                continue
            if labels:
                if all(f'{k}="{v}"' in name_part for k, v in labels.items()):
                    return float(value_part)
            else:
                # No labels requested: first matching sample wins (use with caution).
                return float(value_part)
        return None

    metrics_before_response = client_with_metrics_enabled.get("/metrics")
    assert metrics_before_response.status_code == 200

    # Hit the (unauthenticated) health endpoint once; its registered FastAPI
    # path is what appears in the metric's `path` label.
    health_path_for_metrics = "/fastflight/health"
    client_with_metrics_enabled.get(health_path_for_metrics)

    metrics_after_response = client_with_metrics_enabled.get("/metrics")
    assert metrics_after_response.status_code == 200

    # starlette-prometheus labels completed requests with method/path/status_code.
    labels_for_health_check = {"method": "GET", "path": health_path_for_metrics, "status_code": "200"}

    value_before = get_metric_value(metrics_before_response.text, "starlette_requests_total", labels_for_health_check) or 0.0
    value_after = get_metric_value(metrics_after_response.text, "starlette_requests_total", labels_for_health_check)

    assert value_after is not None, f"Metric starlette_requests_total with labels {labels_for_health_check} not found after request."
    assert value_after > value_before, \
        f"Metric starlette_requests_total for {health_path_for_metrics} did not increment. Before: {value_before}, After: {value_after}"

    # The lifespan-created FastFlightBouncer publishes its pool size gauge
    # (bouncer_pool_size, set in FastFlightBouncer.__init__); BouncerSettings
    # defaults the pool to 10.
    pool_size_metric_name = "bouncer_pool_size"
    pool_size_value = get_metric_value(metrics_after_response.text, pool_size_metric_name)

    assert pool_size_value is not None, f"Metric {pool_size_metric_name} not found."
    assert pool_size_value == 10, f"Expected {pool_size_metric_name} to be 10, got {pool_size_value}"