From 389a8e79829504fce493b3b4c1901954d6189683 Mon Sep 17 00:00:00 2001 From: deesh-code Date: Thu, 6 Nov 2025 18:06:19 +0530 Subject: [PATCH] Add proxy support for corporate networks Added support for routing traffic through SOCKS and HTTP proxies. This helps when running strix behind corporate firewalls or when you need to route different types of traffic through different proxies. New environment variables: - STRIX_PROXY_ALL for all traffic - STRIX_PROXY_TOOLS for just tool traffic - STRIX_PROXY_LLM for just LLM requests Works with corporate proxies, SSH tunnels, and WAF allowlists. --- README.md | 7 + pyproject.toml | 2 + strix/interface/main.py | 61 ++++++++ strix/proxy_config.py | 154 +++++++++++++++++++ strix/runtime/docker_runtime.py | 48 ++++-- strix/tools/executor.py | 62 ++++++-- strix/tools/proxy/proxy_manager.py | 2 + strix/tools/web_search/web_search_actions.py | 9 +- 8 files changed, 316 insertions(+), 29 deletions(-) create mode 100644 strix/proxy_config.py diff --git a/README.md b/README.md index d627611e..da07c459 100644 --- a/README.md +++ b/README.md @@ -145,8 +145,15 @@ export LLM_API_KEY="your-api-key" # Optional export LLM_API_BASE="your-api-base-url" # if using a local model, e.g. Ollama, LMStudio export PERPLEXITY_API_KEY="your-api-key" # for search capabilities + +# Proxy Configuration (optional) +export STRIX_PROXY_ALL="socks5://proxy.example.com:1080" # Proxy for all traffic +export STRIX_PROXY_TOOLS="http://proxy.example.com:8080" # Proxy for tool traffic only +export STRIX_PROXY_LLM="https://proxy.example.com:8080" # Proxy for LLM traffic only ``` +**Proxy Support**: Strix supports both HTTP and SOCKS5 proxies for routing traffic through corporate networks, WAF allow-lists, or SSH tunnels. Configure separate proxies for tool traffic and LLM requests, or use `STRIX_PROXY_ALL` for unified routing. + [📚 View supported AI models](https://docs.litellm.ai/docs/providers) ### 🤖 Headless Mode diff --git a/pyproject.toml b/pyproject.toml index 974087a0..1c3bd0e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,9 @@ textual = "^4.0.0" xmltodict = "^0.13.0" pyte = "^0.8.1" requests = "^2.32.0" +requests-socks = "^2.0.0" # SOCKS proxy support for requests libtmux = "^0.46.2" +httpx-socks = "^0.9.1" # SOCKS proxy support for httpx [tool.poetry.group.dev.dependencies] # Type checking and static analysis diff --git a/strix/interface/main.py b/strix/interface/main.py index 063dc10d..1e411112 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -32,6 +32,7 @@ process_pull_line, validate_llm_response, ) +from strix.proxy_config import configure_global_proxies from strix.runtime.docker_runtime import STRIX_IMAGE from strix.telemetry.tracer import get_global_tracer @@ -68,6 +69,34 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 if not os.getenv("PERPLEXITY_API_KEY"): missing_optional_vars.append("PERPLEXITY_API_KEY") + # Validate proxy configuration + try: + configure_global_proxies() + except ValueError as e: + error_text = Text() + error_text.append("❌ ", style="bold red") + error_text.append("INVALID PROXY CONFIGURATION", style="bold red") + error_text.append("\n\n", style="white") + error_text.append(str(e), style="white") + error_text.append("\n\nSupported proxy formats:\n", style="white") + error_text.append("• http://proxy.example.com:8080\n", style="dim white") + error_text.append("• https://proxy.example.com:8080\n", style="dim white") + error_text.append("• socks5://proxy.example.com:1080\n", style="dim white") + error_text.append("• socks5h://proxy.example.com:1080\n", style="dim white") + + panel = Panel( + error_text, + title="[bold red]🛡️ STRIX CONFIGURATION ERROR", + title_align="center", + border_style="red", + padding=(1, 2), + ) + + console.print("\n") + console.print(panel) + console.print() + sys.exit(1) + if missing_required_vars: error_text = Text() error_text.append("❌ ", style="bold red") @@ -123,6 +152,25 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 style="white", ) + # Add proxy configuration documentation + proxy_configured = any([ + os.getenv("STRIX_PROXY_ALL"), + os.getenv("STRIX_PROXY_TOOLS"), + os.getenv("STRIX_PROXY_LLM") + ]) + + if proxy_configured or missing_optional_vars: + error_text.append("\nProxy configuration (optional):\n", style="white") + error_text.append("• ", style="white") + error_text.append("STRIX_PROXY_ALL", style="bold cyan") + error_text.append(" - Proxy for all traffic (tools and LLM)\n", style="white") + error_text.append("• ", style="white") + error_text.append("STRIX_PROXY_TOOLS", style="bold cyan") + error_text.append(" - Proxy for tool traffic only\n", style="white") + error_text.append("• ", style="white") + error_text.append("STRIX_PROXY_LLM", style="bold cyan") + error_text.append(" - Proxy for LLM traffic only\n", style="white") + error_text.append("\nExample setup:\n", style="white") error_text.append("export STRIX_LLM='openai/gpt-5'\n", style="dim white") @@ -147,6 +195,19 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 "export PERPLEXITY_API_KEY='your-perplexity-key-here'\n", style="dim white" ) + # Add proxy examples if any proxy is configured + proxy_configured = any([ + os.getenv("STRIX_PROXY_ALL"), + os.getenv("STRIX_PROXY_TOOLS"), + os.getenv("STRIX_PROXY_LLM") + ]) + + if proxy_configured: + error_text.append("\nProxy examples:\n", style="white") + error_text.append("export STRIX_PROXY_ALL='socks5://proxy.example.com:1080'\n", style="dim white") + error_text.append("export STRIX_PROXY_TOOLS='http://proxy.example.com:8080'\n", style="dim white") + error_text.append("export STRIX_PROXY_LLM='https://llm-proxy.example.com:8080'\n", style="dim white") + panel = Panel( error_text, title="[bold red]🛡️ STRIX CONFIGURATION ERROR", diff --git a/strix/proxy_config.py b/strix/proxy_config.py new file mode 100644 index 00000000..eb312f69 --- /dev/null +++ b/strix/proxy_config.py @@ -0,0 +1,154 @@ +""" +Proxy configuration module for Strix. + +This module handles upstream proxy configuration for both tool traffic and LLM traffic. +Supports both SOCKS5 and HTTP proxies as requested in: +https://github.com/usestrix/strix/issues/19 +""" + +import os +from dataclasses import dataclass +from typing import Any +from urllib.parse import urlparse + + +@dataclass +class ProxyConfig: + """Configuration for upstream proxies.""" + + tools_proxy: str | None = None + llm_proxy: str | None = None + all_proxy: str | None = None + + def __post_init__(self) -> None: + """Validate proxy configurations.""" + for proxy_name, proxy_url in [ + ("STRIX_PROXY_TOOLS", self.tools_proxy), + ("STRIX_PROXY_LLM", self.llm_proxy), + ("STRIX_PROXY_ALL", self.all_proxy), + ]: + if proxy_url: + self._validate_proxy_url(proxy_url, proxy_name) + + def _validate_proxy_url(self, proxy_url: str, env_var_name: str) -> None: + """Validate proxy URL format.""" + try: + parsed = urlparse(proxy_url) + if parsed.scheme not in ["http", "https", "socks5", "socks5h"]: + raise ValueError( + f"Invalid proxy scheme in {env_var_name}: {parsed.scheme}. " + "Supported schemes: http, https, socks5, socks5h" + ) + if not parsed.hostname: + raise ValueError(f"Missing hostname in {env_var_name}: {proxy_url}") + if not parsed.port: + raise ValueError(f"Missing port in {env_var_name}: {proxy_url}") + except Exception as e: + raise ValueError(f"Invalid proxy URL in {env_var_name}: {proxy_url}") from e + + def get_tools_proxy(self) -> str | None: + """Get proxy configuration for tools traffic.""" + return self.tools_proxy or self.all_proxy + + def get_llm_proxy(self) -> str | None: + """Get proxy configuration for LLM traffic.""" + return self.llm_proxy or self.all_proxy + + def get_requests_proxies(self, proxy_type: str = "tools") -> dict[str, str] | None: + """ + Get proxy configuration in requests library format. + + Args: + proxy_type: Either 'tools' or 'llm' to determine which proxy to use. + + Returns: + Dictionary with 'http' and 'https' keys, or None if no proxy configured. + """ + proxy_url = self.get_tools_proxy() if proxy_type == "tools" else self.get_llm_proxy() + if not proxy_url: + return None + + return {"http": proxy_url, "https": proxy_url} + + def get_httpx_proxies(self, proxy_type: str = "tools") -> dict[str, str] | None: + """ + Get proxy configuration in httpx library format. + + Args: + proxy_type: Either 'tools' or 'llm' to determine which proxy to use. + + Returns: + Dictionary with protocol keys, or None if no proxy configured. + + Note: + For SOCKS proxies with httpx, we need to use httpx-socks library + and create AsyncProxyTransport instead of simple URL strings. + """ + proxy_url = self.get_tools_proxy() if proxy_type == "tools" else self.get_llm_proxy() + if not proxy_url: + return None + + # For httpx, we can return the same format as requests for HTTP proxies + # SOCKS proxies need special handling with httpx-socks + parsed = urlparse(proxy_url) + if parsed.scheme in ["socks5", "socks5h"]: + # We'll handle SOCKS in the calling code using httpx-socks + return {"_socks_proxy": proxy_url} + else: + # HTTP/HTTPS proxies work the same as requests + return {"http://": proxy_url, "https://": proxy_url} + + def get_litellm_proxy_env(self) -> dict[str, str]: + """ + Get environment variables for litellm proxy configuration. + + Returns: + Dictionary of environment variables to set for litellm. + """ + env_vars = {} + llm_proxy = self.get_llm_proxy() + + if llm_proxy: + # litellm supports standard proxy environment variables + env_vars["HTTP_PROXY"] = llm_proxy + env_vars["HTTPS_PROXY"] = llm_proxy + + return env_vars + + +def load_proxy_config() -> ProxyConfig: + """Load proxy configuration from environment variables.""" + return ProxyConfig( + tools_proxy=os.getenv("STRIX_PROXY_TOOLS"), + llm_proxy=os.getenv("STRIX_PROXY_LLM"), + all_proxy=os.getenv("STRIX_PROXY_ALL"), + ) + + +def configure_global_proxies() -> ProxyConfig: + """ + Configure global proxy settings and return the configuration. + + This function should be called early in the application startup + to ensure proxy settings are applied globally. + """ + config = load_proxy_config() + + # Set environment variables for litellm if LLM proxy is configured + llm_proxy_env = config.get_litellm_proxy_env() + for key, value in llm_proxy_env.items(): + os.environ[key] = value + + return config + + +# Global proxy configuration instance +_global_proxy_config: ProxyConfig | None = None + + +def get_proxy_config() -> ProxyConfig: + """Get the global proxy configuration instance.""" + global _global_proxy_config # noqa: PLW0603 + if _global_proxy_config is None: + _global_proxy_config = configure_global_proxies() + return _global_proxy_config \ No newline at end of file diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index 32cc6252..c32f34ad 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -11,6 +11,8 @@ from docker.errors import DockerException, ImageNotFound, NotFound from docker.models.containers import Container +from strix.proxy_config import get_proxy_config + from .runtime import AbstractRuntime, SandboxInfo @@ -340,18 +342,40 @@ async def _register_agent_with_tool_server( ) -> None: import httpx - try: - async with httpx.AsyncClient(trust_env=False) as client: - response = await client.post( - f"{api_url}/register_agent", - params={"agent_id": agent_id}, - headers={"Authorization": f"Bearer {token}"}, - timeout=30, - ) - response.raise_for_status() - logger.info(f"Registered agent {agent_id} with tool server") - except (httpx.RequestError, httpx.HTTPStatusError) as e: - logger.warning(f"Failed to register agent {agent_id}: {e}") + proxy_config = get_proxy_config() + proxies = proxy_config.get_httpx_proxies("tools") + + # Handle SOCKS proxies with httpx-socks + if proxies and "_socks_proxy" in proxies: + from httpx_socks import AsyncProxyTransport + + socks_url = proxies["_socks_proxy"] + transport = AsyncProxyTransport.from_url(socks_url) + try: + async with httpx.AsyncClient(transport=transport, trust_env=False) as client: + response = await client.post( + f"{api_url}/register_agent", + params={"agent_id": agent_id}, + headers={"Authorization": f"Bearer {token}"}, + timeout=30, + ) + response.raise_for_status() + logger.info(f"Registered agent {agent_id} with tool server") + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.warning(f"Failed to register agent {agent_id}: {e}") + else: + try: + async with httpx.AsyncClient(trust_env=False, proxies=proxies) as client: + response = await client.post( + f"{api_url}/register_agent", + params={"agent_id": agent_id}, + headers={"Authorization": f"Bearer {token}"}, + timeout=30, + ) + response.raise_for_status() + logger.info(f"Registered agent {agent_id} with tool server") + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.warning(f"Failed to register agent {agent_id}: {e}") async def get_sandbox_url(self, container_id: str, port: int) -> str: try: diff --git a/strix/tools/executor.py b/strix/tools/executor.py index 6dd1b04d..e6e300dc 100644 --- a/strix/tools/executor.py +++ b/strix/tools/executor.py @@ -4,6 +4,8 @@ import httpx +from strix.proxy_config import get_proxy_config + if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false": from strix.runtime import get_runtime @@ -62,22 +64,50 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A "Content-Type": "application/json", } - async with httpx.AsyncClient(trust_env=False) as client: - try: - response = await client.post( - request_url, json=request_data, headers=headers, timeout=None - ) - response.raise_for_status() - response_data = response.json() - if response_data.get("error"): - raise RuntimeError(f"Sandbox execution error: {response_data['error']}") - return response_data.get("result") - except httpx.HTTPStatusError as e: - if e.response.status_code == 401: - raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e - raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e - except httpx.RequestError as e: - raise RuntimeError(f"Request error calling tool server: {e}") from e + proxy_config = get_proxy_config() + proxies = proxy_config.get_httpx_proxies("tools") + + # Handle SOCKS proxies with httpx-socks + if proxies and "_socks_proxy" in proxies: + from httpx_socks import AsyncProxyTransport + from urllib.parse import urlparse + + socks_url = proxies["_socks_proxy"] + parsed = urlparse(socks_url) + transport = AsyncProxyTransport.from_url(socks_url) + async with httpx.AsyncClient(transport=transport, trust_env=False) as client: + try: + response = await client.post( + request_url, json=request_data, headers=headers, timeout=None + ) + response.raise_for_status() + response_data = response.json() + if response_data.get("error"): + raise RuntimeError(f"Sandbox execution error: {response_data['error']}") + return response_data.get("result") + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e + raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e + except httpx.RequestError as e: + raise RuntimeError(f"Request error calling tool server: {e}") from e + else: + async with httpx.AsyncClient(trust_env=False, proxies=proxies) as client: + try: + response = await client.post( + request_url, json=request_data, headers=headers, timeout=None + ) + response.raise_for_status() + response_data = response.json() + if response_data.get("error"): + raise RuntimeError(f"Sandbox execution error: {response_data['error']}") + return response_data.get("result") + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e + raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e + except httpx.RequestError as e: + raise RuntimeError(f"Request error calling tool server: {e}") from e async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwargs: Any) -> Any: diff --git a/strix/tools/proxy/proxy_manager.py b/strix/tools/proxy/proxy_manager.py index e02d85b7..1ca15b23 100644 --- a/strix/tools/proxy/proxy_manager.py +++ b/strix/tools/proxy/proxy_manager.py @@ -11,6 +11,8 @@ from gql.transport.requests import RequestsHTTPTransport from requests.exceptions import ProxyError, RequestException, Timeout +from strix.proxy_config import get_proxy_config + if TYPE_CHECKING: from collections.abc import Callable diff --git a/strix/tools/web_search/web_search_actions.py b/strix/tools/web_search/web_search_actions.py index 52f00a97..bc224641 100644 --- a/strix/tools/web_search/web_search_actions.py +++ b/strix/tools/web_search/web_search_actions.py @@ -3,6 +3,7 @@ import requests +from strix.proxy_config import get_proxy_config from strix.tools.registry import register_tool @@ -53,7 +54,13 @@ def web_search(query: str) -> dict[str, Any]: ], } - response = requests.post(url, headers=headers, json=payload, timeout=300) + response = requests.post( + url, + headers=headers, + json=payload, + timeout=300, + proxies=get_proxy_config().get_requests_proxies("tools") + ) response.raise_for_status() response_data = response.json()