Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions backend/app_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

from __future__ import annotations

import base64
import hmac
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import Response as StarletteResponse

from _routes._errors import HTTPError
from _routes.generation import router as generation_router
Expand Down Expand Up @@ -45,6 +49,7 @@ def create_app(
handler: "AppHandler",
allowed_origins: list[str] | None = None,
title: str = "LTX-2 Video Generation Server",
auth_token: str = "",
) -> FastAPI:
"""Create a configured FastAPI app bound to the provided handler."""
init_state_service(handler)
Expand All @@ -57,6 +62,38 @@ def create_app(
allow_headers=["*"],
)

@app.middleware("http")
async def _auth_middleware( # pyright: ignore[reportUnusedFunction]
request: Request,
call_next: Callable[[Request], Awaitable[StarletteResponse]],
) -> StarletteResponse:
if not auth_token:
return await call_next(request)
if request.method == "OPTIONS":
return await call_next(request)

def _token_matches(candidate: str) -> bool:
return hmac.compare_digest(candidate, auth_token)

# WebSocket: check query param
if request.headers.get("upgrade", "").lower() == "websocket":
if _token_matches(request.query_params.get("token", "")):
return await call_next(request)
return JSONResponse(status_code=401, content={"error": "Unauthorized"})
# HTTP: Bearer or Basic auth
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer ") and _token_matches(auth_header[7:]):
return await call_next(request)
if auth_header.startswith("Basic "):
try:
decoded = base64.b64decode(auth_header[6:]).decode()
_, _, password = decoded.partition(":")
if _token_matches(password):
return await call_next(request)
except Exception:
pass
return JSONResponse(status_code=401, content={"error": "Unauthorized"})

async def _route_http_error_handler(request: Request, exc: Exception) -> JSONResponse:
if isinstance(exc, HTTPError):
log_http_error(request, exc)
Expand Down
34 changes: 28 additions & 6 deletions backend/ltx2_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def patched_sdpa(
# Constants & Paths
# ============================================================

PORT = 8000
PORT = 0 # 0 = pick a free port; Electron parses "Server running on ..." for the actual URL


def _get_device() -> torch.device:
Expand Down Expand Up @@ -219,7 +219,8 @@ def _resolve_force_api_generations() -> bool:
)

handler = build_initial_state(runtime_config, DEFAULT_APP_SETTINGS)
app = create_app(handler=handler, allowed_origins=DEFAULT_ALLOWED_ORIGINS)
auth_token = os.environ.get("LTX_AUTH_TOKEN", "")
app = create_app(handler=handler, allowed_origins=DEFAULT_ALLOWED_ORIGINS, auth_token=auth_token)


def precache_model_files(model_dir: Path) -> int:
Expand Down Expand Up @@ -267,9 +268,11 @@ def log_hardware_info() -> None:


if __name__ == "__main__":
import asyncio
import socket as _socket
import uvicorn

port = int(os.environ.get("LTX_PORT", PORT))
port = int(os.environ.get("LTX_PORT", "") or PORT)
logger.info("=" * 60)
logger.info("LTX-2 Video Generation Server (FastAPI + Uvicorn)")
log_hardware_info()
Expand All @@ -281,8 +284,13 @@ def log_hardware_info() -> None:
queue_thread = threading.Thread(target=queue_worker_loop, daemon=True)
queue_thread.start()

# Use our root logging config so uvicorn logs go to stdout (not its
# default stderr), letting Electron tag them correctly as INFO.
# Bind the socket ourselves so we know the actual port before uvicorn starts.
# Electron parses the "Server running on ..." line to get the backend URL.
sock = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM)
sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, 1)
sock.bind(("127.0.0.1", port))
actual_port = int(sock.getsockname()[1])

log_config: dict[str, object] = {
"version": 1,
"disable_existing_loggers": False,
Expand All @@ -298,4 +306,18 @@ def log_hardware_info() -> None:
"uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
},
}
uvicorn.run(app, host="127.0.0.1", port=port, log_level="info", access_log=False, log_config=log_config)
config = uvicorn.Config(
app, host="127.0.0.1", port=actual_port, log_level="info", access_log=False, log_config=log_config
)
server = uvicorn.Server(config)

_orig_startup = server.startup

async def _startup_with_ready_msg(sockets: list[_socket.socket] | None = None) -> None:
await _orig_startup(sockets=sockets)
if server.started:
print(f"Server running on http://127.0.0.1:{actual_port}", flush=True)

server.startup = _startup_with_ready_msg # type: ignore[assignment]

asyncio.run(server.serve(sockets=[sock]))
74 changes: 74 additions & 0 deletions backend/tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Tests for shared-secret authentication middleware."""

from __future__ import annotations

import base64

from starlette.testclient import TestClient

from app_factory import create_app


def test_request_without_token_returns_401(test_state): # noqa: ANN001
app = create_app(handler=test_state, auth_token="test-secret")
with TestClient(app) as client:
response = client.get("/health")
assert response.status_code == 401
assert response.json() == {"error": "Unauthorized"}


def test_request_with_correct_bearer_token(test_state): # noqa: ANN001
app = create_app(handler=test_state, auth_token="test-secret")
with TestClient(app) as client:
response = client.get("/health", headers={"Authorization": "Bearer test-secret"})
assert response.status_code == 200


def test_request_with_correct_basic_auth(test_state): # noqa: ANN001
app = create_app(handler=test_state, auth_token="test-secret")
credentials = base64.b64encode(b":test-secret").decode()
with TestClient(app) as client:
response = client.get("/health", headers={"Authorization": f"Basic {credentials}"})
assert response.status_code == 200


def test_request_with_wrong_token_returns_401(test_state): # noqa: ANN001
app = create_app(handler=test_state, auth_token="test-secret")
with TestClient(app) as client:
response = client.get("/health", headers={"Authorization": "Bearer wrong-token"})
assert response.status_code == 401


def test_health_without_token_returns_401(test_state): # noqa: ANN001
"""Health endpoint is NOT exempt from auth."""
app = create_app(handler=test_state, auth_token="test-secret")
with TestClient(app) as client:
response = client.get("/health")
assert response.status_code == 401


def test_no_auth_token_disables_middleware(test_state): # noqa: ANN001
"""When auth_token is empty string, auth is disabled (dev/test mode)."""
app = create_app(handler=test_state, auth_token="")
with TestClient(app) as client:
response = client.get("/health")
assert response.status_code == 200


def test_websocket_with_token_query_param(test_state): # noqa: ANN001
app = create_app(handler=test_state, auth_token="test-secret")
with TestClient(app) as client:
# WebSocket upgrade without token should fail with 401
response = client.get(
"/ws/download/test",
headers={"upgrade": "websocket", "connection": "upgrade"},
)
assert response.status_code == 401

# WebSocket upgrade with correct token query param
response = client.get(
"/ws/download/test?token=test-secret",
headers={"upgrade": "websocket", "connection": "upgrade"},
)
# The route may not exist, but auth should pass (not 401)
assert response.status_code != 401
2 changes: 0 additions & 2 deletions electron/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ import { app } from 'electron'
import path from 'path'
import os from 'os'

export const PYTHON_PORT = 8000
export const BACKEND_BASE_URL = `http://localhost:${PYTHON_PORT}`
export const isDev = !app.isPackaged

// Get directory - works in both CJS and ESM contexts
Expand Down
13 changes: 8 additions & 5 deletions electron/gpu.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import { execSync } from 'child_process'
import { BACKEND_BASE_URL } from './config'
import { logger } from './logger'
import { getPythonPath } from './python-backend'
import { getAuthToken, getBackendUrl, getPythonPath } from './python-backend'

// Check if NVIDIA GPU is available
export async function checkGPU(): Promise<{ available: boolean; name?: string; vram?: number }> {
try {
// Try to get GPU info from the backend API first (more reliable)
const response = await fetch(`${BACKEND_BASE_URL}/api/gpu-info`, {
const url = getBackendUrl()
if (!url) throw new Error('Backend URL not available yet')
const headers: Record<string, string> = { 'Content-Type': 'application/json' }
const token = getAuthToken()
if (token) headers['Authorization'] = `Bearer ${token}`
const response = await fetch(`${url}/api/gpu-info`, {
method: 'GET',
headers: { 'Content-Type': 'application/json' },
headers,
})

if (response.ok) {
Expand Down
7 changes: 3 additions & 4 deletions electron/ipc/app-handlers.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import { app, ipcMain } from 'electron'
import path from 'path'
import fs from 'fs'
import { BACKEND_BASE_URL } from '../config'
import { checkGPU } from '../gpu'
import { isPythonReady, downloadPythonEmbed } from '../python-setup'
import { getBackendHealthStatus, startPythonBackend } from '../python-backend'
import { getBackendHealthStatus, getBackendUrl, getAuthToken, startPythonBackend } from '../python-backend'
import { getMainWindow } from '../window'
import { getAnalyticsState, setAnalyticsEnabled, sendAnalyticsEvent } from '../analytics'

Expand Down Expand Up @@ -68,8 +67,8 @@ function markLicenseAccepted(settingsPath: string): void {
}

export function registerAppHandlers(): void {
ipcMain.handle('get-backend-url', () => {
return BACKEND_BASE_URL
ipcMain.handle('get-backend', () => {
return { url: getBackendUrl() ?? '', token: getAuthToken() ?? '' }
})

ipcMain.handle('get-models-path', () => {
Expand Down
6 changes: 3 additions & 3 deletions electron/preload.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ const { contextBridge, ipcRenderer } = require('electron')

// Expose protected methods to the renderer process
contextBridge.exposeInMainWorld('electronAPI', {
// Get the backend URL
getBackendUrl: (): Promise<string> => ipcRenderer.invoke('get-backend-url'),
// Get the backend URL and auth token (empty token when auth disabled)
getBackend: (): Promise<{ url: string; token: string }> => ipcRenderer.invoke('get-backend'),

// Get the path where models are stored
getModelsPath: (): Promise<string> => ipcRenderer.invoke('get-models-path'),
Expand Down Expand Up @@ -141,7 +141,7 @@ interface BackendHealthStatus {
declare global {
interface Window {
electronAPI: {
getBackendUrl: () => Promise<string>
getBackend: () => Promise<{ url: string; token: string }>
getModelsPath: () => Promise<string>
readLocalFile: (filePath: string) => Promise<{ data: string; mimeType: string }>
checkGpu: () => Promise<{ available: boolean; name?: string; vram?: number }>
Expand Down
Loading