diff --git a/README.md b/README.md
index f5e8901..648d230 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 # SmolVM
 
-**Secure runtime for AI agents and tools**
+**Secure runtime for AI agents to execute untrusted code**
 
 [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
 [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
diff --git a/src/smolvm/api.py b/src/smolvm/api.py
index 1a3149d..023b1c9 100644
--- a/src/smolvm/api.py
+++ b/src/smolvm/api.py
@@ -217,6 +217,7 @@ def add_network_interface(
         iface_id: str,
         host_dev_name: str,
         guest_mac: str,
+        rate_limit_mbps: int | None = None,
     ) -> None:
         """Add a network interface.
 
@@ -224,15 +225,33 @@
             iface_id: Interface identifier.
             host_dev_name: TAP device name on host.
             guest_mac: MAC address for the guest.
+            rate_limit_mbps: Optional rate limit in Mbps (applied to RX and TX).
         """
+        payload: dict[str, Any] = {
+            "iface_id": iface_id,
+            "host_dev_name": host_dev_name,
+            "guest_mac": guest_mac,
+        }
+
+        if rate_limit_mbps is not None and rate_limit_mbps > 0:
+            # 1 Mbps = 125,000 bytes/s.
+            bytes_per_sec = rate_limit_mbps * 125000
+            # Refill every 100 ms so throughput stays smooth under the cap.
+            size_per_100ms = max(bytes_per_sec // 10, 1)
+
+            token_bucket = {
+                "bandwidth": {
+                    "size": size_per_100ms,
+                    "one_time_burst": bytes_per_sec,
+                    "refill_time": 100,
+                }
+            }
+            payload["rx_rate_limiter"] = token_bucket
+            payload["tx_rate_limiter"] = token_bucket
+
         self._request(
             "PUT",
             f"/network-interfaces/{iface_id}",
-            json={
-                "iface_id": iface_id,
-                "host_dev_name": host_dev_name,
-                "guest_mac": guest_mac,
-            },
+            json=payload,
         )
         logger.debug("Network interface added: %s -> %s", iface_id, host_dev_name)
diff --git a/src/smolvm/build.py b/src/smolvm/build.py
index 0f70e8e..f512f92 100644
--- a/src/smolvm/build.py
+++ b/src/smolvm/build.py
@@ -17,13 +17,17 @@
 Automatically builds VM images with SSH using Docker.
 """
 
+import hashlib
+import json
 import logging
 import platform
+import re
 import shlex
 import shutil
 import subprocess
 import tarfile
 import tempfile
+import typing
 import urllib.error
 import urllib.request
 from pathlib import Path
@@ -36,6 +40,11 @@
 # Default boot args that include init=/init for our custom init script
 SSH_BOOT_ARGS = "console=ttyS0 reboot=k panic=1 pci=off root=/dev/vda rw init=/init"
 
+# Boot args for OpenClaw VMs — 8250.nr_uarts=0 disables the serial UART to avoid
+# vCPU exits on /dev/ttyS0 writes, which become a measurable host CPU tax at
+# 200+ VMs. No console= since we don't need serial output in production.
+OPENCLAW_BOOT_ARGS = "reboot=k panic=1 pci=off init=/init 8250.nr_uarts=0"
+
 # Firecracker-compatible uncompressed kernels.
 FIRECRACKER_KERNEL_URLS = {
     "x86_64": "https://s3.amazonaws.com/spec.ccfc.min/firecracker-ci/v1.6/x86_64/vmlinux-5.10.198",
@@ -142,9 +151,18 @@ def build_alpine_ssh(
         kernel_path = image_dir / "vmlinux.bin"
         rootfs_path = image_dir / "rootfs.ext4"
 
-        # Return cached image if it exists
-        if kernel_path.exists() and rootfs_path.exists():
-            logger.info("Image '%s' already exists at %s", name, image_dir)
+        # Check fingerprint cache
+        fingerprint_data = {
+            "rootfs_size_mb": rootfs_size_mb,
+            "kernel_url": kernel_url,
+            "ssh_password": ssh_password,
+        }
+        if (
+            kernel_path.exists()
+            and rootfs_path.exists()
+            and self._check_fingerprint(image_dir, fingerprint_data)
+        ):
+            logger.info("Image '%s' already exists and fingerprint matches at %s", name, image_dir)
             return (kernel_path, rootfs_path)
 
         logger.info("Building Alpine SSH image '%s'...", name)
@@ -153,9 +171,11 @@
         # The /init script runs as PID 1 inside the VM and brings up SSH.
         init_script = self._default_init_script()
 
-        dockerfile_content = f"""
+        dockerfile_content = """
 FROM alpine:3.19
 
+ARG SSH_PASSWORD
+
 # Install SSH and networking utilities
 RUN apk add --no-cache \\
     openssh \\
@@ -167,7 +187,7 @@
 RUN ssh-keygen -A && \\
     sed -i 's/#PermitRootLogin.*/PermitRootLogin yes/' /etc/ssh/sshd_config && \\
     sed -i 's/#PasswordAuthentication.*/PasswordAuthentication yes/' /etc/ssh/sshd_config && \\
-    echo 'root:{ssh_password}' | chpasswd
+    echo "root:${SSH_PASSWORD}" | chpasswd
 
 # Install our custom init script
 COPY init /init
@@ -183,7 +203,9 @@
                 kernel_path,
                 rootfs_path,
                 rootfs_size_mb,
+                build_args={"SSH_PASSWORD": ssh_password},
                 kernel_url=kernel_url,
+                fingerprint_data=fingerprint_data,
             )
         except (subprocess.CalledProcessError, ImageError) as e:
             # Clean up partial build
@@ -228,45 +250,23 @@ def build_alpine_ssh_key(
         kernel_path = image_dir / "vmlinux.bin"
         rootfs_path = image_dir / "rootfs.ext4"
 
+        fingerprint_data = {
+            "rootfs_size_mb": rootfs_size_mb,
+            "kernel_url": kernel_url,
+            "ssh_public_key": key_value,
+        }
+
         if kernel_path.exists() and rootfs_path.exists():
-            # Check if the image is stale (older than the provided key file)
-            is_stale = False
-
-            # Resolve key path from input if possible
-            key_path_check: Path | None = None
-            if isinstance(ssh_public_key, Path):
-                key_path_check = ssh_public_key
-            elif isinstance(ssh_public_key, str):
-                try:
-                    p = Path(ssh_public_key)
-                    if p.exists():
-                        key_path_check = p
-                except OSError:
-                    pass
-
-            # If we found a key file, check its mtime
-            if key_path_check and key_path_check.exists():
-                try:
-                    key_mtime = key_path_check.stat().st_mtime
-                    img_mtime = rootfs_path.stat().st_mtime
-                    if key_mtime > img_mtime:
-                        logger.info(
-                            "SSH key '%s' is newer than cached image. Rebuilding...",
-                            key_path_check.name,
-                        )
-                        is_stale = True
-                except OSError:
-                    pass
-
-            if not is_stale:
-                logger.info("Image '%s' already exists at %s", name, image_dir)
+            if self._check_fingerprint(image_dir, fingerprint_data):
+                logger.info(
+                    "Image '%s' already exists and fingerprint matches at %s", name, image_dir
+                )
                 return (kernel_path, rootfs_path)
+            logger.info("SSH key or config changed for image '%s'. Rebuilding...", name)
 
             # Remove stale files
-            if kernel_path.exists():
-                kernel_path.unlink()
-            if rootfs_path.exists():
-                rootfs_path.unlink()
+            kernel_path.unlink(missing_ok=True)
+            rootfs_path.unlink(missing_ok=True)
 
         logger.info("Building Alpine key-only SSH image '%s'...", name)
         image_dir.mkdir(parents=True, exist_ok=True)
@@ -306,6 +306,7 @@
                 rootfs_size_mb,
                 extra_files={"authorized_keys": f"{key_value}\n"},
                 kernel_url=kernel_url,
+                fingerprint_data=fingerprint_data,
             )
         except (subprocess.CalledProcessError, ImageError) as e:
             if rootfs_path.exists():
@@ -351,45 +352,24 @@ def build_debian_ssh_key(
         kernel_path = image_dir / "vmlinux.bin"
         rootfs_path = image_dir / "rootfs.ext4"
 
+        fingerprint_data = {
+            "rootfs_size_mb": rootfs_size_mb,
+            "kernel_url": kernel_url,
+            "ssh_public_key": key_value,
+            "base_image": base_image,
+        }
+
         if kernel_path.exists() and rootfs_path.exists():
-            # Check if the image is stale (older than the provided key file)
-            is_stale = False
-
-            # Resolve key path from input if possible
-            key_path_check: Path | None = None
-            if isinstance(ssh_public_key, Path):
-                key_path_check = ssh_public_key
-            elif isinstance(ssh_public_key, str):
-                try:
-                    p = Path(ssh_public_key)
-                    if p.exists():
-                        key_path_check = p
-                except OSError:
-                    pass
-
-            # If we found a key file, check its mtime
-            if key_path_check and key_path_check.exists():
-                try:
-                    key_mtime = key_path_check.stat().st_mtime
-                    img_mtime = rootfs_path.stat().st_mtime
-                    if key_mtime > img_mtime:
-                        logger.info(
-                            "SSH key '%s' is newer than cached image. Rebuilding...",
-                            key_path_check.name,
-                        )
-                        is_stale = True
-                except OSError:
-                    pass
-
-            if not is_stale:
-                logger.info("Image '%s' already exists at %s", name, image_dir)
+            if self._check_fingerprint(image_dir, fingerprint_data):
+                logger.info(
+                    "Image '%s' already exists and fingerprint matches at %s", name, image_dir
+                )
                 return (kernel_path, rootfs_path)
+            logger.info("Inputs changed for image '%s'. Rebuilding...", name)
 
             # Remove stale files
-            if kernel_path.exists():
-                kernel_path.unlink()
-            if rootfs_path.exists():
-                rootfs_path.unlink()
+            kernel_path.unlink(missing_ok=True)
+            rootfs_path.unlink(missing_ok=True)
 
         logger.info("Building Debian key-only SSH image '%s'...", name)
         image_dir.mkdir(parents=True, exist_ok=True)
@@ -433,6 +413,244 @@
                 rootfs_size_mb,
                 extra_files={"authorized_keys": f"{key_value}\n"},
                 kernel_url=kernel_url,
+                fingerprint_data=fingerprint_data,
+            )
+        except (subprocess.CalledProcessError, ImageError) as e:
+            if rootfs_path.exists():
+                rootfs_path.unlink()
+            if kernel_path.exists():
+                kernel_path.unlink()
+            if isinstance(e, ImageError):
+                raise
+            raise ImageError(f"Image build failed: {e}") from e
+
+        logger.info("Image '%s' built successfully at %s", name, image_dir)
+        return (kernel_path, rootfs_path)
+
+    def build_openclaw_rootfs(
+        self,
+        name: str = "openclaw",
+        # Note: 'smolvm' is intentionally kept as the default for simplified local
+        # demos and testing fixtures. Production deployments should override this value.
+        ssh_password: str = "smolvm",
+        ssh_public_key: str | Path | None = None,
+        rootfs_size_mb: int = 2048,
+        kernel_url: str | None = None,
+        extra_packages: list[str] | None = None,
+    ) -> tuple[Path, Path]:
+        """Build OpenClaw rootfs with Node.js, sidecars, and init wiring.
+
+        The resulting image contains:
+
+        - Node.js >= 22.12.0 (``node:22.12.0-bookworm-slim`` base)
+        - OpenClaw pre-installed at ``/opt/openclaw/`` (symlinked to ``/usr/local/bin/openclaw``)
+        - ``inotify-tools`` and the device-approver sidecar
+        - SSH server for ``vm.run()`` management commands
+        - Custom ``/init`` that boots networking, sshd, and the sidecar
+        - Custom system packages like ``git`` (for npm source dependencies)
+
+        Boot the resulting VM with :data:`OPENCLAW_BOOT_ARGS`.
+
+        Args:
+            name: Image name for caching.
+            ssh_password: Root password for SSH (default: smolvm).
+            ssh_public_key: Public key content or path to a public key file.
+            rootfs_size_mb: Size of rootfs in MB (default: 2048).
+            kernel_url: Optional kernel URL override.
+            extra_packages: List of apt packages to install (defaults to ['git']).
+
+        Returns:
+            Tuple of (kernel_path, rootfs_path).
+
+        Raises:
+            ImageError: If Docker is not available or build fails.
+        """
+        if not self.check_docker():
+            raise ImageError(
+                "Docker is required to build images. "
+                "Install Docker Desktop (macOS) or docker.io (Linux)."
+            )
+
+        if extra_packages is None:
+            extra_packages = ["git"]
+
+        # Validate package names to prevent Dockerfile string-interpolation injection
+        valid_pkg_regex = re.compile(r"^[a-z0-9\.\+\-]+$")
+        for pkg in extra_packages:
+            if not valid_pkg_regex.match(pkg):
+                raise ImageError(f"Invalid package name requested for installation: '{pkg}'")
+
+        if ssh_public_key is None:
+            key_path = Path.home() / ".smolvm" / "keys" / "id_ed25519.pub"
+            try:
+                key_value = key_path.read_text().strip()
+            except OSError:
+                key_value = ""
+        else:
+            key_value = self._resolve_public_key(ssh_public_key)
+
+        image_dir = self.cache_dir / name
+        kernel_path = image_dir / "vmlinux.bin"
+        rootfs_path = image_dir / "rootfs.ext4"
+
+        fingerprint_data = {
+            "rootfs_size_mb": rootfs_size_mb,
+            "kernel_url": kernel_url,
+            "ssh_password": ssh_password,
+            "ssh_public_key": key_value,
+            "extra_packages": extra_packages,
+        }
+
+        if kernel_path.exists() and rootfs_path.exists():
+            if self._check_fingerprint(image_dir, fingerprint_data):
+                logger.info(
+                    "Image '%s' already exists and fingerprint matches at %s", name, image_dir
+                )
+                return (kernel_path, rootfs_path)
+
+            logger.info("Inputs changed for OpenClaw image '%s'. Rebuilding...", name)
+            kernel_path.unlink(missing_ok=True)
+            rootfs_path.unlink(missing_ok=True)
+
+        packages_str = " ".join(extra_packages)
+
+        logger.info("Building OpenClaw image '%s' with extra packages: %s...", name, packages_str)
+        image_dir.mkdir(parents=True, exist_ok=True)
+
+        init_script = self._openclaw_init_script()
+
+        # --- Sidecar scripts (TDD Decision 1.2.5) ---
+        device_approver_py = r"""#!/usr/bin/env python3
+import json, time
+
+BASE = "/home/node/.openclaw/devices"
+PENDING = f"{BASE}/pending.json"
+PAIRED = f"{BASE}/paired.json"
+
+def approve():
+    try:
+        pending = json.loads(open(PENDING).read())
+    except (FileNotFoundError, json.JSONDecodeError):
+        return
+    if not pending:
+        return
+    try:
+        paired = json.loads(open(PAIRED).read())
+    except (FileNotFoundError, json.JSONDecodeError):
+        paired = {}
+    now_ms = int(time.time() * 1000)
+    for _, entry in pending.items():
+        device_id = entry.get("deviceId")
+        if not device_id:
+            continue
+        paired[device_id] = {**entry, "pairedAt": now_ms}
+    open(PAIRED, "w").write(json.dumps(paired, indent=2))
+    open(PENDING, "w").write(json.dumps({}))
+
+approve()
+"""
+
+        watch_devices_sh = r"""#!/bin/bash
+# Watch the DIRECTORY, not the file — handles atomic rename writes
+while inotifywait -e close_write,moved_to \
+    /home/node/.openclaw/devices 2>/dev/null; do
+    python3 /usr/local/bin/device-approver.py
+done
+"""
+
+        systemctl_proxy_sh = r"""#!/bin/bash
+if [ "$1" = "start" ] && [ "$2" = "openclaw" ]; then
+  echo "Starting openclaw via dummy systemctl..."
+  # The reconciler provisions the config via SSH then calls `systemctl start openclaw`.
+  # We use --allow-unconfigured so the gateway starts even before pairing completes.
+  #
+  # `/dev/null || true
+
+# Prepare OpenClaw directories and workspace
+RUN useradd -m -s /bin/bash node 2>/dev/null || true && \\
+    mkdir -p /opt/openclaw /home/node/.openclaw/devices /workspace && \\
+    chown -R node:node /opt/openclaw /home/node/.openclaw /workspace
+
+WORKDIR /opt/openclaw
+RUN npm init -y && \\
+    npm --prefix /opt/openclaw install -g openclaw && \\
+    ln -sf /opt/openclaw/bin/openclaw /usr/local/bin/openclaw && \\
+    touch /var/log/openclaw.log && \\
+    chown node:node /var/log/openclaw.log
+
+# Sidecar and proxy scripts
+COPY device-approver.py /usr/local/bin/device-approver.py
+COPY watch-devices.sh /usr/local/bin/watch-devices.sh
+COPY systemctl /usr/local/bin/systemctl
+RUN chmod +x \\
+    /usr/local/bin/device-approver.py \\
+    /usr/local/bin/watch-devices.sh \\
+    /usr/local/bin/systemctl
+
+# Init script
+COPY init /init
+RUN chmod +x /init
+"""
+
+        try:
+            self._do_build(
+                name,
+                dockerfile_content,
+                init_script,
+                image_dir,
+                kernel_path,
+                rootfs_path,
+                rootfs_size_mb,
+                extra_files={
+                    "device-approver.py": device_approver_py,
+                    "watch-devices.sh": watch_devices_sh,
+                    "systemctl": systemctl_proxy_sh,
+                    "authorized_keys": f"{key_value}\n" if key_value else "",
+                },
+                build_args={"SSH_PASSWORD": ssh_password},
+                kernel_url=kernel_url,
+                fingerprint_data=fingerprint_data,
             )
         except (subprocess.CalledProcessError, ImageError) as e:
             if rootfs_path.exists():
@@ -456,9 +674,14 @@ def _resolve_public_key(self, ssh_public_key: str | Path) -> str:
             raise ImageError("Invalid SSH public key format")
         return key_text
 
-    def _default_init_script(self) -> str:
-        """Default PID 1 init script used by SSH-capable images."""
-        return r"""#!/bin/sh
+    def _base_init_script(self, custom_hostname: str = "smolvm", custom_commands: str = "") -> str:
+        """Base PID 1 init script used by SSH-capable images.
+
+        Args:
+            custom_hostname: Hostname to set (default: smolvm).
+            custom_commands: Additional shell commands to inject before the PID 1 sleep loop.
+        """
+        return f"""#!/bin/sh
 # SmolVM custom init - runs as PID 1 inside Firecracker VM
 
 # ── Signal handling ──────────────────────────────────────────
@@ -467,29 +690,29 @@ def _default_init_script(self) -> str:
 # kernel_restart() which tries a hardware reboot (doesn't exist
 # in Firecracker, so the VM hangs). We disable CAD so the
 # kernel sends SIGINT to PID 1 instead, where we trap it.
-shutdown() {
+shutdown() {{
     echo "SmolVM init: shutting down..."
     kill -TERM -1 2>/dev/null
     sleep 0.2
     sync
     poweroff -f
-}
+}}
 trap shutdown INT TERM PWR
 
 # ── Timestamp helpers (for host-side startup profiling) ──────
-ts_uptime() {
+ts_uptime() {{
     cut -d' ' -f1 /proc/uptime 2>/dev/null || echo "0.00"
-}
+}}
 
 # date +%s is widely supported by busybox/coreutils.
-ts_epoch() {
+ts_epoch() {{
     date +%s 2>/dev/null || echo "0"
-}
+}}
 
-log_ts() {
+log_ts() {{
     STAGE="$1"
-    echo "SMOLVM_TS stage=${STAGE} epoch_s=$(ts_epoch) uptime_s=$(ts_uptime)"
-}
+    echo "SMOLVM_TS stage=${{STAGE}} epoch_s=$(ts_epoch) uptime_s=$(ts_uptime)"
+}}
 
 log_ts "init-start"
 
@@ -530,15 +753,14 @@
 ip link set lo up
 ip link set eth0 up
-ip addr add "${GUEST_IP}/24" dev eth0 2>/dev/null || true
-ip route add default via "${GATEWAY}" dev eth0 2>/dev/null || true
+ip addr add "${{GUEST_IP}}/24" dev eth0 2>/dev/null || true
+ip route add default via "${{GATEWAY}}" dev eth0 2>/dev/null || true
 
 # DNS
 echo "nameserver 8.8.8.8" > /etc/resolv.conf
 echo "nameserver 8.8.4.4" >> /etc/resolv.conf
 
-# Set hostname
-hostname smolvm
+hostname {custom_hostname}
 
 log_ts "net-ready"
 
 # ── SSH ──────────────────────────────────────────────────────
@@ -553,9 +775,12 @@
 /usr/sbin/sshd -e
 log_ts "sshd-invoked"
 
-echo "SmolVM init complete: IP=${GUEST_IP}, SSH listening on port 22"
+echo "SmolVM init complete: IP=${{GUEST_IP}}, SSH listening on port 22"
 log_ts "init-complete"
 
+# ── Custom Injections ───────────────────────────────────────
+{custom_commands}
+
 # ── Keep PID 1 alive ──────────────────────────────────────────
 # Use 'wait' so signals are delivered promptly (plain 'sleep'
 # in a while-loop prevents signal delivery until sleep exits).
@@ -565,14 +790,51 @@
 done
 """
 
+    def _default_init_script(self) -> str:
+        """Default PID 1 init script used by SSH-capable images."""
+        return self._base_init_script()
+
+    def _openclaw_init_script(self) -> str:
+        """PID 1 init script for OpenClaw images.
+
+        Extends the base init with:
+        - Device-approver sidecar launched as a background process
+        - ``/home/node/.openclaw/devices`` directory setup
+        - Hostname set to ``openclaw``
+        """
+        device_approver_block = r"""
+# ── Device-Approver Sidecar ─────────────────────────────────
+# Launched as a background process — no systemd required.
+# watch-devices.sh uses inotifywait on the directory (not file) to
+# handle atomic-rename writes from OpenClaw.
+log_ts "device-approver-start"
+mkdir -p /home/node/.openclaw/devices
+chown -R 1000:1000 /home/node/.openclaw
+/usr/local/bin/watch-devices.sh &
+DEVICE_APPROVER_PID=$!
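+# The PID is recorded for logging only; init does not supervise the watcher.
+# If the sidecar dies, new devices simply stop being auto-approved.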
+log_ts "device-approver-started"
+echo "Device-approver running with PID=${DEVICE_APPROVER_PID}"
+"""
+        return self._base_init_script(
+            custom_hostname="openclaw", custom_commands=device_approver_block
+        )
+
     def _loopfs_helper_path(self) -> Path | None:
         """Return installed privileged helper path if available."""
         if LOOPFS_HELPER_PATH.is_file():
             return LOOPFS_HELPER_PATH
         return None
 
-    def _run_loopfs(self, action: str, *args: Path) -> None:
-        """Run a privileged loopfs action through the scoped helper."""
+    def _run_loopfs(self, action: str, *args: Path, timeout: int = 30) -> None:
+        """Run a privileged loopfs action through the scoped helper.
+
+        Args:
+            action: One of ``mount``, ``extract``, ``umount``.
+            *args: Positional path arguments forwarded to the helper.
+            timeout: Command timeout in seconds. Mount/umount are fast
+                (default 30 s); callers should pass a larger value for
+                ``extract`` when working with large images.
+        """
         helper = self._loopfs_helper_path()
         if helper is None:
             raise ImageError(
@@ -583,7 +845,7 @@
         cmd = [str(helper), action, *(str(arg) for arg in args)]
         try:
-            run_command(cmd, use_sudo=True, check=True, capture_output=True)
+            run_command(cmd, use_sudo=True, check=True, capture_output=True, timeout=timeout)
         except SmolVMError as e:
             raise ImageError(
                 "Image build loopfs operation failed.\n"
@@ -612,12 +874,40 @@ def qemu_kernel_url_for_host(self) -> str:
         arch_key = self._host_arch_key()
         return QEMU_KERNEL_URLS[arch_key]
 
+    def _check_fingerprint(self, image_dir: Path, data: dict[str, typing.Any]) -> bool:
+        """Check if the cached image fingerprint matches the current build inputs."""
+        fingerprint_file = image_dir / ".fingerprint"
+        if not fingerprint_file.exists():
+            return False
+
+        expected_hash = self._hash_fingerprint_data(data)
+        try:
+            stored_hash = fingerprint_file.read_text().strip()
+            return stored_hash == expected_hash
+        except OSError:
+            return False
+
+    def _write_fingerprint(self, image_dir: Path, data: dict[str, typing.Any]) -> None:
+        """Write the build input fingerprint to the cache directory."""
+        fingerprint_file = image_dir / ".fingerprint"
+        try:
+            fingerprint_file.write_text(self._hash_fingerprint_data(data))
+        except OSError as e:
+            logger.warning("Failed to write image fingerprint cache: %s", e)
+
+    def _hash_fingerprint_data(self, data: dict[str, typing.Any]) -> str:
+        """Compute SHA-256 hash of a JSON-serializable dictionary."""
+        json_str = json.dumps(data, sort_keys=True)
+        return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
+
     def _download_kernel(self, url: str, dest: Path) -> None:
         """Download kernel image to *dest* without external wget dependency."""
         try:
-            with urllib.request.urlopen(url, timeout=180) as response:
-                with open(dest, "wb") as out:
-                    shutil.copyfileobj(response, out)
+            with (
+                urllib.request.urlopen(url, timeout=180) as response,
+                open(dest, "wb") as out,
+            ):
+                shutil.copyfileobj(response, out)
         except (urllib.error.URLError, OSError) as e:
             raise ImageError(f"Failed to download kernel from {url}: {e}") from e
 
@@ -646,14 +936,20 @@ def _create_ext4_with_loopfs(
             mount_dir = tmpdir / "mnt"
             mount_dir.mkdir()
             self._run_loopfs("mount", rootfs_path, mount_dir)
+
+            # Scale extract timeout with image size: tar-extracting thousands of
+            # Node.js module files onto a loop-mounted ext4 is inode-bound, not
+            # throughput-bound. 30 s is sufficient for mount/umount but far too
+            # short for a 4 GB+ rootfs on a standard (non-SSD) disk.
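+            # max(300, size_mb // 8) assumes a worst-case effective extract
+            # rate of ~8 MB/s, with a five-minute floor for small images.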
+            extract_timeout = max(300, rootfs_size_mb // 8)
             tar_error: Exception | None = None
             try:
-                self._run_loopfs("extract", tar_path, mount_dir)
+                self._run_loopfs("extract", tar_path, mount_dir, timeout=extract_timeout)
             except Exception as e:
                 tar_error = e
             finally:
                 try:
-                    self._run_loopfs("umount", mount_dir)
+                    self._run_loopfs("umount", mount_dir, timeout=extract_timeout)
                 except ImageError:
                     if tar_error is None:
                         raise
@@ -748,7 +1044,9 @@ def _do_build(
         rootfs_path: Path,
         rootfs_size_mb: int,
         extra_files: dict[str, str] | None = None,
+        build_args: dict[str, str] | None = None,
         kernel_url: str | None = None,
+        fingerprint_data: dict[str, typing.Any] | None = None,
     ) -> None:
         """Execute the Docker build and image conversion."""
         docker_tag = f"smolvm-{name}"
@@ -765,8 +1063,14 @@
             # 1. Build Docker image
             logger.info(" [1/4] Building Docker image...")
+            build_cmd = ["docker", "build", "-t", docker_tag]
+            if build_args:
+                for k, v in build_args.items():
+                    build_cmd.extend(["--build-arg", f"{k}={v}"])
+            build_cmd.append(str(tmp_path))
+
             subprocess.run(
-                ["docker", "build", "-t", docker_tag, str(tmp_path)],
+                build_cmd,
                 check=True,
                 stdout=subprocess.DEVNULL,
                 stderr=subprocess.PIPE,
@@ -807,3 +1111,7 @@
                 resolved_kernel_url,
             )
             self._download_kernel(resolved_kernel_url, kernel_path)
+
+        # 5. Write the cache fingerprint after a successful build
+        if fingerprint_data is not None:
+            self._write_fingerprint(image_dir, fingerprint_data)
diff --git a/src/smolvm/doctor.py b/src/smolvm/doctor.py
index 2a95377..77615d2 100644
--- a/src/smolvm/doctor.py
+++ b/src/smolvm/doctor.py
@@ -18,6 +18,7 @@
 import json
 import platform
+import re
 import subprocess
 from dataclasses import asdict, dataclass
 from pathlib import Path
@@ -39,6 +40,7 @@ class DoctorCheck:
     name: str
     status: DoctorStatus
     detail: str
+    fix: str | None = None
 
 
 @dataclass(frozen=True)
@@ -66,7 +68,8 @@ def _check_command(binary: str, package_hint: str) -> DoctorCheck:
         return DoctorCheck(
             name=f"command:{binary}",
             status="fail",
-            detail=f"'{binary}' not found (install {package_hint})",
+            detail="Not found",
+            fix=f"Install {package_hint}",
         )
     return DoctorCheck(
         name=f"command:{binary}",
@@ -115,6 +118,222 @@ def _check_nft_table(family: str, table: str) -> DoctorCheck:
     )
 
 
+class WorkerNodeSecurityError(SmolVMError):
+    """Raised when one or more host-level security checks fail.
+
+    The reconciler must refuse to start when this is raised rather than
+    running in a degraded security posture (C2: Defence in Depth).
+    """
+
+
+# ---------------------------------------------------------------------------
+# Worker-node security invariants (Decision 1.1.5)
+# ---------------------------------------------------------------------------
+
+_KSM_RUN = Path("/sys/kernel/mm/ksm/run")
+_THP_ENABLED = Path("/sys/kernel/mm/transparent_hugepage/enabled")
+_KVM_NX_PARAM = Path("/sys/module/kvm/parameters/nx_huge_pages")
+_KVM_DEV = Path("/dev/kvm")
+_PROC_MEMINFO = Path("/proc/meminfo")
+_FSTAB = Path("/etc/fstab")
+
+
+def _check_swap_disabled() -> DoctorCheck:
+    """C1: swap off prevents guest memory pages (potentially containing
+    secrets) from being written to the host disk."""
+    name = "worker:swap-disabled"
+
+    # 1. Is swap inactive right now?
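+    # SwapTotal in /proc/meminfo covers every active swap device and file,
+    # so one read answers the runtime half of this check; /etc/fstab below
+    # answers the persistence half.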
+    try:
+        meminfo = _PROC_MEMINFO.read_text()
+    except OSError as exc:
+        return DoctorCheck(name=name, status="fail", detail=f"Cannot read /proc/meminfo: {exc}")
+
+    match = re.search(r"^SwapTotal:\s+(\d+)", meminfo, re.MULTILINE)
+    swap_total_kb = int(match.group(1)) if match else 0
+    if swap_total_kb != 0:
+        return DoctorCheck(
+            name=name,
+            status="fail",
+            detail=f"Active ({swap_total_kb} kB)",
+            fix="sudo swapoff -a",
+        )
+
+    # 2. Will it stay off after a reboot (/etc/fstab)?
+    try:
+        fstab_text = _FSTAB.read_text()
+    except OSError:
+        fstab_text = ""
+
+    swap_entries = [
+        line
+        for line in fstab_text.splitlines()
+        if line.strip() and not line.strip().startswith("#") and "swap" in line
+    ]
+    if swap_entries:
+        return DoctorCheck(
+            name=name,
+            status="fail",
+            detail="Swap entries found in /etc/fstab",
+            fix="sudo sed -i '/\\bswap\\b/d' /etc/fstab",
+        )
+
+    return DoctorCheck(name=name, status="pass", detail="Inactive and absent from /etc/fstab")
+
+
+def _check_ksm_disabled() -> DoctorCheck:
+    """KSM off: prevents cross-VM memory-page timing side-channels."""
+    name = "worker:ksm-disabled"
+    if not _KSM_RUN.exists():
+        # KSM not compiled in; that is fine — it cannot be used.
+        return DoctorCheck(name=name, status="pass", detail="Not compiled in kernel")
+    try:
+        value = _KSM_RUN.read_text().strip()
+    except OSError as exc:
+        return DoctorCheck(name=name, status="fail", detail=f"Cannot read {_KSM_RUN}: {exc}")
+
+    if value == "0":
+        return DoctorCheck(name=name, status="pass", detail="Disabled")
+    return DoctorCheck(
+        name=name,
+        status="fail",
+        detail=f"Active (run={value})",
+        fix="sudo sh -c 'echo 0 > /sys/kernel/mm/ksm/run'",
+    )
+
+
+def _check_thp_disabled() -> DoctorCheck:
+    """THP=never: prevents latency spikes that could disrupt VM timing guarantees."""
+    name = "worker:thp-disabled"
+    if not _THP_ENABLED.exists():
+        return DoctorCheck(name=name, status="pass", detail="Not compiled in kernel")
+    try:
+        raw = _THP_ENABLED.read_text()
+    except OSError as exc:
+        return DoctorCheck(name=name, status="fail", detail=f"Cannot read {_THP_ENABLED}: {exc}")
+
+    # File content looks like: "always madvise [never]"
+    bracket_match = re.search(r"\[(\w+)\]", raw)
+    active = bracket_match.group(1) if bracket_match else raw.split()[0]
+
+    if active == "never":
+        return DoctorCheck(name=name, status="pass", detail="Disabled ('never')")
+    return DoctorCheck(
+        name=name,
+        status="fail",
+        detail=f"Active ('{active}')",
+        fix="sudo sh -c 'echo never > /sys/kernel/mm/transparent_hugepage/enabled'",
+    )
+
+
+def _check_kvm_nx_huge_pages() -> DoctorCheck:
+    """CVE-2018-12207 (KVM iTLB multihit): nx_huge_pages must be 'never'."""
+    name = "worker:kvm-nx-huge-pages"
+    if not _KVM_NX_PARAM.exists():
+        return DoctorCheck(
+            name=name,
+            status="warn",
+            detail="Module parameter absent (kvm module not loaded?)",
+            fix="sudo modprobe kvm nx_huge_pages=never",
+        )
+    try:
+        value = _KVM_NX_PARAM.read_text().strip().lower()
+    except OSError as exc:
+        return DoctorCheck(name=name, status="fail", detail=f"Cannot read {_KVM_NX_PARAM}: {exc}")
+
+    # Kernel exposes this as "never" or "N" depending on version.
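+    # "never" (recent kernels) has the side benefit of not spawning the
+    # per-VM nx huge-page recovery worker, which matters at 200+ VMs/host.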
+    if value in {"never", "n"}:
+        return DoctorCheck(name=name, status="pass", detail=f"Mitigated ('{value}')")
+
+    return DoctorCheck(
+        name=name,
+        status="fail",
+        detail=f"nx_huge_pages='{value}'",
+        fix="sudo modprobe -r kvm_intel kvm && sudo modprobe kvm nx_huge_pages=never",
+    )
+
+
+def _check_kvm_permissions() -> DoctorCheck:
+    """Firecracker (jailer) needs /dev/kvm with 660 + kvm group ownership."""
+    name = "worker:kvm-permissions"
+    if not _KVM_DEV.exists():
+        return DoctorCheck(name=name, status="fail", detail="/dev/kvm not found")
+
+    stat = _KVM_DEV.stat()
+    # Mode bits: 0o660 means rw-rw----
+    current_mode = oct(stat.st_mode & 0o777)
+    import grp as _grp  # stdlib; imported locally because it is Unix-only
+
+    try:
+        current_group = _grp.getgrgid(stat.st_gid).gr_name
+    except KeyError:
+        current_group = str(stat.st_gid)
+
+    ok_perms = (stat.st_mode & 0o777) == 0o660
+    ok_group = current_group == "kvm"
+
+    if ok_perms and ok_group:
+        return DoctorCheck(name=name, status="pass", detail=f"{current_mode} group={current_group}")
+
+    problems = []
+    if not ok_perms:
+        problems.append(f"mode={current_mode}")
+    if not ok_group:
+        problems.append(f"group={current_group}")
+
+    return DoctorCheck(
+        name=name,
+        status="fail",
+        detail="Incorrect permissions: " + " and ".join(problems),
+        fix="sudo chmod 660 /dev/kvm && sudo chgrp kvm /dev/kvm",
+    )
+
+
+def check_worker_node_security() -> list[DoctorCheck]:
+    """Run all host-level security invariants required before starting the reconciler.
+
+    Returns a list of :class:`DoctorCheck` results. Raises
+    :class:`WorkerNodeSecurityError` if **any** check is non-passing.
+
+    Design rationale (C2 — Defence in Depth):
+        These checks operate at the host kernel level. No amount of application
+        code can compensate for a wrong setting here. The reconciler must call
+        this function at startup and abort if it raises.
+
+    Example usage in a reconciler entrypoint::
+
+        from smolvm.doctor import check_worker_node_security, WorkerNodeSecurityError
+
+        try:
+            check_worker_node_security()
+        except WorkerNodeSecurityError as exc:
+            logger.critical("Worker node security check failed: %s", exc)
+            sys.exit(1)
+    """
+    checks: list[DoctorCheck] = [
+        _check_swap_disabled(),
+        _check_ksm_disabled(),
+        _check_thp_disabled(),
+        _check_kvm_nx_huge_pages(),
+        _check_kvm_permissions(),
+    ]
+
+    non_passing = [c for c in checks if c.status != "pass"]
+    if non_passing:
+        msg_parts = []
+        for c in non_passing:
+            part = f"{c.name} ({c.status}): {c.detail}"
+            if c.fix:
+                part += f" - Fix: {c.fix}"
+            msg_parts.append(part)
+        lines = " | ".join(msg_parts)
+        raise WorkerNodeSecurityError(
+            f"Worker node security checks failed ({len(non_passing)}/{len(checks)}): {lines}"
+        )
+
+    return checks
+
+
 def generate_doctor_report(backend: str | None = None) -> DoctorReport:
     """Collect diagnostics for the selected runtime backend."""
     requested = (backend or BACKEND_AUTO).strip().lower()
@@ -173,6 +392,20 @@
         checks.append(_check_nft_table("ip", "smolvm_nat"))
         checks.append(_check_nft_table("inet", "smolvm_filter"))
 
+        # Worker-node host-level security invariants (Decision 1.1.5).
+        # These are included as informational checks in the doctor report;
+        # the reconciler startup guard calls check_worker_node_security()
+        # separately and refuses to start on failure.
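+        # Unlike this report, the startup guard is strict: any non-pass
+        # status (including "warn") makes check_worker_node_security() raise.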
+        checks.extend(
+            [
+                _check_swap_disabled(),
+                _check_ksm_disabled(),
+                _check_thp_disabled(),
+                _check_kvm_nx_huge_pages(),
+                _check_kvm_permissions(),
+            ]
+        )
+
     elif resolved == BACKEND_QEMU:
         qemu = _find_qemu_binary()
         if qemu is None:
@@ -181,8 +414,7 @@
                 name="qemu",
                 status="fail",
                 detail=(
-                    "QEMU not found. Install one of: qemu-system-aarch64, "
-                    "qemu-system-x86_64"
+                    "QEMU not found. Install one of: qemu-system-aarch64, qemu-system-x86_64"
                 ),
             )
         )
@@ -252,29 +484,39 @@
 
 def _print_human_report(report: DoctorReport, strict: bool) -> None:
-    print("SmolVM Doctor")
-    print(f"Backend: {report.backend_resolved} (requested: {report.backend_requested})")
+    print(f"SmolVM Doctor (Backend: {report.backend_resolved})")
     print(f"Platform: {report.system} {report.arch}")
     print("")
 
     markers = {
-        "pass": "PASS",
-        "warn": "WARN",
-        "fail": "FAIL",
+        "pass": "✓",
+        "warn": "!",
+        "fail": "✗",
+    }
+
+    colors = {
+        "pass": "\033[92m",
+        "warn": "\033[93m",
+        "fail": "\033[91m",
+        "reset": "\033[0m",
     }
 
     for check in report.checks:
-        print(f"[{markers[check.status]}] {check.name}: {check.detail}")
+        color = colors[check.status]
+        reset = colors["reset"]
+        print(f"  [{color}{markers[check.status]}{reset}] {check.name}: {check.detail}")
+        if check.fix and check.status != "pass":
+            print(f"      {color}Fix: {check.fix}{reset}")
 
     print("")
     failures = len(report.failures)
     warnings = len(report.warnings)
     if failures == 0 and (warnings == 0 or not strict):
-        print("Doctor result: OK")
+        print("Doctor result: \033[92mOK\033[0m")
     elif strict and warnings and not failures:
-        print("Doctor result: FAIL (strict mode treats warnings as failures)")
+        print("Doctor result: \033[91mFAIL\033[0m (strict mode treats warnings as failures)")
     else:
-        print("Doctor result: FAIL")
+        print("Doctor result: \033[91mFAIL\033[0m")
 
 
 def run_doctor(
diff --git a/src/smolvm/network.py b/src/smolvm/network.py
index f54dc57..4aa6a6b 100644
--- a/src/smolvm/network.py
+++ b/src/smolvm/network.py
@@ -418,6 +418,33 @@ def _delete_nft_rules(
         if delete_lines:
             self._run_nft_script("\n".join(delete_lines) + "\n")
 
+    def _find_nft_delete_rule_lines(
+        self,
+        family: str,
+        table: str,
+        *,
+        comment: str | None = None,
+        comment_prefix: str | None = None,
+    ) -> list[str]:
+        """Return nft 'delete rule' lines for rules matching comment filters."""
+        if comment is None and comment_prefix is None:
+            raise ValueError("comment or comment_prefix must be provided")
+
+        table_output = self._nft_list_table(family, table, handles=True)
+        if not table_output:
+            return []
+
+        handles = self._extract_table_rule_handles(table_output)
+        delete_lines: list[str] = []
+        for chain, rule_comment, handle in handles:
+            if comment is not None and rule_comment != comment:
+                continue
+            if comment_prefix is not None and not rule_comment.startswith(comment_prefix):
+                continue
+            delete_lines.append(f"delete rule {family} {table} {chain} handle {handle}")
+
+        return delete_lines
+
     # ------------------------------------------------------------------
     # Public firewall/NAT API
     # ------------------------------------------------------------------
@@ -464,10 +491,7 @@ def setup_nat(self, tap_name: str) -> None:
                 _NFT_FILTER_FAMILY,
                 _NFT_FILTER_TABLE,
                 "forward",
-                (
-                    f"iifname {self._quote('tap*')} "
-                    f"oifname {self._quote('tap*')} counter drop"
-                ),
+                (f"iifname {self._quote('tap*')} oifname {self._quote('tap*')} counter drop"),
"smolvm:global:forward:tap-isolation", ), ] @@ -669,6 +693,125 @@ def cleanup_nat_rules(self, tap_name: str) -> None: comment=comment, ) + def apply_egress_allowlist( + self, + tap_device: str, + allowed_ips: list[str], + ) -> None: + """Restrict outbound traffic from *tap_device* to *allowed_ips* only. + + Installs per-TAP rules in the SmolVM filter forward chain keyed by tap + name so they are isolated between tenants:: + + # pass matching return traffic (established sessions) + iifname ct state established,related counter accept + # allow the configured destination set + iifname ip daddr { , , ... } counter accept + # drop everything else going out from this tap + iifname ip daddr != { , , ... } counter drop + + The function is fail-closed and update-safe: it applies a single nft + transaction that stages new rules first, then removes stale rules and + any generic per-TAP NAT accept rule. If the transaction fails, old rules + remain in place unchanged. + + Args: + tap_device: TAP interface name (e.g., ``tap42``). + allowed_ips: CIDR or host addresses that the guest may reach. + Pass an empty list to deny *all* outbound IP traffic. + + Raises: + ValueError: If ``tap_device`` is empty. + NetworkError: If the nft call fails. + """ + if not tap_device: + raise ValueError("tap_device cannot be empty") + + logger.info( + "Applying egress allowlist for %s: %s", + tap_device, + allowed_ips or "", + ) + + self._ensure_nftables_base() + iface = self.outbound_interface + + comment_prefix = f"smolvm:egress:{tap_device}" + old_egress_delete_lines = self._find_nft_delete_rule_lines( + _NFT_FILTER_FAMILY, + _NFT_FILTER_TABLE, + comment_prefix=f"{comment_prefix}:", + ) + old_nat_accept_delete_lines = self._find_nft_delete_rule_lines( + _NFT_FILTER_FAMILY, + _NFT_FILTER_TABLE, + comment=f"smolvm:nat:tap:{tap_device}:to:{iface}", + ) + script_lines = [ + ( + f"add rule {_NFT_FILTER_FAMILY} {_NFT_FILTER_TABLE} forward " + f"iifname {self._quote(tap_device)} ct state established,related " + f"counter accept comment {self._quote(f'{comment_prefix}:established')}" + ), + ] + + if allowed_ips: + # nftables anonymous set: ip daddr != { a, b, c } + ip_set = ", ".join(allowed_ips) + script_lines.append( + ( + f"add rule {_NFT_FILTER_FAMILY} {_NFT_FILTER_TABLE} forward " + f"iifname {self._quote(tap_device)} ip daddr {{ {ip_set} }} " + f"counter accept comment {self._quote(f'{comment_prefix}:allow')}" + ), + ) + drop_expr = ( + f"iifname {self._quote(tap_device)} " + f"ip daddr != {{ {ip_set} }} counter drop" + ) + else: + # No IPs allowed — drop unconditionally. + drop_expr = f"iifname {self._quote(tap_device)} counter drop" + + script_lines.append( + ( + f"add rule {_NFT_FILTER_FAMILY} {_NFT_FILTER_TABLE} forward " + f"{drop_expr} comment {self._quote(f'{comment_prefix}:drop')}" + ) + ) + + script_lines.extend(old_egress_delete_lines) + script_lines.extend(old_nat_accept_delete_lines) + + self._run_nft_script("\n".join(script_lines) + "\n") + + def remove_egress_rules(self, tap_device: str) -> None: + """Remove all egress allowlist rules for *tap_device*. + + Must be called **before** ``vm.delete()`` to prevent a rule-table leak. + nftables rules survive VM termination; this cleans them up atomically + using a comment-prefix match. + + The call is best-effort: if the table no longer exists (e.g., host + reboot) the function returns silently. + + Args: + tap_device: TAP interface name used in :meth:`apply_egress_allowlist`. + + Raises: + ValueError: If ``tap_device`` is empty. 
+ """ + if not tap_device: + raise ValueError("tap_device cannot be empty") + + logger.info("Removing egress rules for %s", tap_device) + + self._delete_nft_rules( + _NFT_FILTER_FAMILY, + _NFT_FILTER_TABLE, + comment_prefix=f"smolvm:egress:{tap_device}:", + ) + def generate_mac(self, vm_number: int) -> str: """Generate deterministic VM MAC address for vm_number in [0,255].""" if vm_number < 0 or vm_number > 255: diff --git a/src/smolvm/types.py b/src/smolvm/types.py index 203b945..1155d29 100644 --- a/src/smolvm/types.py +++ b/src/smolvm/types.py @@ -57,6 +57,7 @@ class VMConfig(BaseModel): mem_size_mib: Memory size in MiB (128-16384). kernel_path: Path to the kernel image. rootfs_path: Path to the root filesystem image. + extra_drives: Additional block-device image paths to attach at boot. boot_args: Kernel boot arguments. backend: Optional runtime backend override ("firecracker" or "qemu"). disk_mode: Disk lifecycle mode: @@ -73,18 +74,20 @@ class VMConfig(BaseModel): str, Field( default_factory=_generate_vm_id, - pattern=r"^[a-z0-9][a-z0-9-]{0,62}[a-z0-9]$|^[a-z0-9]$", + pattern=r"^[a-z0-9][a-z0-9_-]{0,62}[a-z0-9]$|^[a-z0-9]$", ), ] vcpu_count: Annotated[int, Field(ge=1, le=32)] = 2 mem_size_mib: Annotated[int, Field(ge=128, le=16384)] = 512 kernel_path: Path rootfs_path: Path + extra_drives: list[Path] = [] boot_args: str = "console=ttyS0 reboot=k panic=1 pci=off" backend: str | None = None disk_mode: Literal["isolated", "shared"] = "isolated" retain_disk_on_delete: bool = False env_vars: dict[str, str] = {} + network_rate_limit_mbps: Annotated[int, Field(ge=1)] | None = None vsock: VsockConfig = VsockConfig() @field_validator("vm_id", mode="before") @@ -99,6 +102,19 @@ def default_vm_id_when_none(cls, v: object) -> object: @classmethod def validate_path_exists(cls, v: Path) -> Path: """Ensure paths exist on the filesystem.""" + return cls._validate_file_path(v) + + @field_validator("extra_drives") + @classmethod + def validate_extra_drives(cls, v: list[Path]) -> list[Path]: + """Ensure all extra drive paths exist and are files.""" + for path in v: + cls._validate_file_path(path) + return v + + @staticmethod + def _validate_file_path(v: Path) -> Path: + """Validate a filesystem path points to an existing file.""" if not v.exists(): raise ValueError(f"Path does not exist: {v}") if not v.is_file(): diff --git a/src/smolvm/vm.py b/src/smolvm/vm.py index d72e308..9c86b4c 100644 --- a/src/smolvm/vm.py +++ b/src/smolvm/vm.py @@ -634,11 +634,20 @@ def start( is_root_device=True, is_read_only=False, ) + for index, drive_path in enumerate(vm_info.config.extra_drives): + drive_id = "data_drive" if index == 0 else f"data_drive_{index}" + client.add_drive( + drive_id, + drive_path, + is_root_device=False, + is_read_only=False, + ) assert vm_info.network is not None client.add_network_interface( "eth0", vm_info.network.tap_device, vm_info.network.guest_mac, + rate_limit_mbps=vm_info.config.network_rate_limit_mbps, ) if vm_info.config.vsock.enabled: @@ -1197,6 +1206,8 @@ def _cleanup_resources(self, vm_id: str) -> None: if backend == BACKEND_FIRECRACKER: # Cleanup Linux TAP/NAT only for Firecracker backend. 
+            with suppress(Exception):
+                self.network.remove_egress_rules(tap_device)
             self.network.cleanup_nat_rules(tap_device)
             self.network.cleanup_tap(tap_device)
 
diff --git a/tests/test_doctor.py b/tests/test_doctor.py
index 6af0aac..4c06da0 100644
--- a/tests/test_doctor.py
+++ b/tests/test_doctor.py
@@ -17,13 +17,30 @@
 from pathlib import Path
 from unittest.mock import MagicMock, patch
 
-from smolvm.doctor import generate_doctor_report, run_doctor
+import pytest
+
+from smolvm.doctor import (
+    DoctorCheck,
+    WorkerNodeSecurityError,
+    check_worker_node_security,
+    generate_doctor_report,
+    run_doctor,
+)
 from smolvm.exceptions import SmolVMError
 
 
+def _pass(name: str) -> DoctorCheck:
+    return DoctorCheck(name=name, status="pass", detail="ok")
+
+
 class TestDoctorFirecracker:
     """Firecracker backend diagnostic tests."""
 
+    @patch("smolvm.doctor._check_kvm_permissions", new=lambda: _pass("worker:kvm-permissions"))
+    @patch("smolvm.doctor._check_kvm_nx_huge_pages", new=lambda: _pass("worker:kvm-nx-huge-pages"))
+    @patch("smolvm.doctor._check_thp_disabled", new=lambda: _pass("worker:thp-disabled"))
+    @patch("smolvm.doctor._check_ksm_disabled", new=lambda: _pass("worker:ksm-disabled"))
+    @patch("smolvm.doctor._check_swap_disabled", new=lambda: _pass("worker:swap-disabled"))
     @patch("smolvm.doctor.run_command")
     @patch("smolvm.doctor.check_network_prerequisites", return_value=[])
     @patch("smolvm.doctor.which")
@@ -55,6 +72,11 @@ def _which_side_effect(binary: str) -> Path | None:
         assert report.failures == []
         assert report.warnings == []
 
+    @patch("smolvm.doctor._check_kvm_permissions", new=lambda: _pass("worker:kvm-permissions"))
+    @patch("smolvm.doctor._check_kvm_nx_huge_pages", new=lambda: _pass("worker:kvm-nx-huge-pages"))
+    @patch("smolvm.doctor._check_thp_disabled", new=lambda: _pass("worker:thp-disabled"))
+    @patch("smolvm.doctor._check_ksm_disabled", new=lambda: _pass("worker:ksm-disabled"))
+    @patch("smolvm.doctor._check_swap_disabled", new=lambda: _pass("worker:swap-disabled"))
     @patch("smolvm.doctor.run_command", side_effect=SmolVMError("No such file or directory"))
     @patch("smolvm.doctor.check_network_prerequisites", return_value=[])
     @patch("smolvm.doctor.which")
@@ -104,3 +126,24 @@ def test_generate_report_qemu_missing_binary(
 
         assert report.backend_resolved == "qemu"
         assert any(check.name == "qemu" and check.status == "fail" for check in report.checks)
+
+
+class TestWorkerNodeSecurityChecks:
+    """Tests for strict worker-node startup guard behavior."""
+
+    @patch("smolvm.doctor._check_kvm_permissions", new=lambda: _pass("worker:kvm-permissions"))
+    @patch(
+        "smolvm.doctor._check_kvm_nx_huge_pages",
+        new=lambda: DoctorCheck(
+            name="worker:kvm-nx-huge-pages",
+            status="warn",
+            detail="kvm module not loaded",
+        ),
+    )
+    @patch("smolvm.doctor._check_thp_disabled", new=lambda: _pass("worker:thp-disabled"))
+    @patch("smolvm.doctor._check_ksm_disabled", new=lambda: _pass("worker:ksm-disabled"))
+    @patch("smolvm.doctor._check_swap_disabled", new=lambda: _pass("worker:swap-disabled"))
+    def test_check_worker_node_security_raises_on_warn(self) -> None:
+        """Startup guard should reject non-pass security checks, including warnings."""
+        with pytest.raises(WorkerNodeSecurityError, match=r"worker:kvm-nx-huge-pages \(warn\)"):
+            check_worker_node_security()
diff --git a/tests/test_network.py b/tests/test_network.py
index 7b7cd20..4ca5c5c 100644
--- a/tests/test_network.py
+++ b/tests/test_network.py
@@ -256,3 +256,58 @@ def test_check_network_prerequisites_checks_scoped_sudo_commands(
         assert ["ip", "link", "show"] in commands
"show"] in commands assert ["nft", "list", "tables"] in commands assert ["sysctl", "net.ipv4.ip_forward"] in commands + + +class TestEgressAllowlist: + """Tests for egress allowlist rule behavior.""" + + def test_apply_egress_allowlist_applies_atomic_add_then_delete_update( + self, + ) -> None: + """Allowlist updates should stage new rules before deleting old ones.""" + nm = NetworkManager() + nm._outbound_interface = "eth0" + nm._ensure_nftables_base = MagicMock() + nm._nft_list_table = MagicMock( + return_value=( + "table inet smolvm_filter {\n" + " chain forward {\n" + ' iifname "tap42" ct state established,related counter accept comment "smolvm:egress:tap42:established" # handle 41\n' + ' iifname "tap42" counter drop comment "smolvm:egress:tap42:drop" # handle 42\n' + ' iifname "tap42" oifname "eth0" counter accept comment "smolvm:nat:tap:tap42:to:eth0" # handle 43\n' + " }\n" + "}\n" + ) + ) + nm._run_nft_script = MagicMock() + + nm.apply_egress_allowlist("tap42", ["1.1.1.1", "8.8.8.8"]) + + assert nm._nft_list_table.call_count == 2 + script = nm._run_nft_script.call_args.args[0] + + add_established = ( + 'add rule inet smolvm_filter forward iifname "tap42" ct state established,related ' + 'counter accept comment "smolvm:egress:tap42:established"' + ) + add_allow = ( + 'add rule inet smolvm_filter forward iifname "tap42" ip daddr { 1.1.1.1, 8.8.8.8 } ' + 'counter accept comment "smolvm:egress:tap42:allow"' + ) + add_drop = ( + 'add rule inet smolvm_filter forward iifname "tap42" ip daddr != { 1.1.1.1, 8.8.8.8 } ' + 'counter drop comment "smolvm:egress:tap42:drop"' + ) + delete_old_established = "delete rule inet smolvm_filter forward handle 41" + delete_old_drop = "delete rule inet smolvm_filter forward handle 42" + delete_old_nat_accept = "delete rule inet smolvm_filter forward handle 43" + + assert add_established in script + assert add_allow in script + assert add_drop in script + assert delete_old_established in script + assert delete_old_drop in script + assert delete_old_nat_accept in script + + # Fail-closed sequencing: all adds are staged before old rules are deleted. 
+        assert script.index(add_drop) < script.index(delete_old_established)
diff --git a/tests/test_types.py b/tests/test_types.py
index 49a34b7..0e24f23 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -58,7 +58,7 @@ def test_vm_id_auto_generated_when_omitted(self, tmp_path: Path) -> None:
         )
 
         assert config.vm_id.startswith("vm-")
-        assert re.fullmatch(r"^[a-z0-9][a-z0-9-]{0,62}[a-z0-9]$|^[a-z0-9]$", config.vm_id)
+        assert re.fullmatch(r"^[a-z0-9][a-z0-9_-]{0,62}[a-z0-9]$|^[a-z0-9]$", config.vm_id)
 
     def test_vm_id_auto_generated_when_none(self, tmp_path: Path) -> None:
         """Test VM ID is generated when explicitly set to None."""
@@ -269,6 +269,42 @@ def test_disk_mode_defaults_to_isolated(self, tmp_path: Path) -> None:
         assert config.disk_mode == "isolated"
         assert config.retain_disk_on_delete is False
 
+    def test_extra_drives_default_empty(self, tmp_path: Path) -> None:
+        """Test extra_drives defaults to an empty list."""
+        kernel = tmp_path / "vmlinux"
+        rootfs = tmp_path / "rootfs.ext4"
+        kernel.touch()
+        rootfs.touch()
+
+        config = VMConfig(vm_id="vm001", kernel_path=kernel, rootfs_path=rootfs)
+
+        assert config.extra_drives == []
+
+    def test_extra_drives_must_be_existing_files(self, tmp_path: Path) -> None:
+        """Test extra drive paths must exist and point to files."""
+        kernel = tmp_path / "vmlinux"
+        rootfs = tmp_path / "rootfs.ext4"
+        data_drive = tmp_path / "data.ext4"
+        kernel.touch()
+        rootfs.touch()
+        data_drive.touch()
+
+        config = VMConfig(
+            vm_id="vm001",
+            kernel_path=kernel,
+            rootfs_path=rootfs,
+            extra_drives=[data_drive],
+        )
+        assert config.extra_drives == [data_drive]
+
+        with pytest.raises(ValidationError, match="does not exist"):
+            VMConfig(
+                vm_id="vm002",
+                kernel_path=kernel,
+                rootfs_path=rootfs,
+                extra_drives=[tmp_path / "missing.ext4"],
+            )
+
     def test_invalid_disk_mode_rejected(self, tmp_path: Path) -> None:
         """Test unsupported disk_mode values are rejected."""
         kernel = tmp_path / "vmlinux"
diff --git a/tests/test_vm.py b/tests/test_vm.py
index 0d8f4de..a56b7ea 100644
--- a/tests/test_vm.py
+++ b/tests/test_vm.py
@@ -17,7 +17,7 @@
 import subprocess
 from pathlib import Path
 from types import SimpleNamespace
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, call, patch
 
 import pytest
 
@@ -363,6 +363,7 @@ def test_delete_cleans_local_forward_rules(
         smol_vm.delete("vm001")
 
         mock_network.cleanup_all_local_port_forwards.assert_called_once_with("vm001")
+        mock_network.remove_egress_rules.assert_called_once_with("tap2")
 
 
 class TestIPBasedTAPNaming:
@@ -565,6 +566,47 @@ def test_start_preserves_existing_ip_boot_arg(
         boot_args = mock_client.set_boot_source.call_args[0][1]
         assert boot_args == config.boot_args
 
+    @patch("smolvm.vm.FirecrackerClient")
+    @patch.object(SmolVMManager, "_start_firecracker")
+    @patch("smolvm.vm.NetworkManager")
+    def test_start_attaches_extra_drives(
+        self,
+        mock_network_class: MagicMock,
+        mock_start_fc: MagicMock,
+        mock_client_cls: MagicMock,
+        smol_vm: SmolVMManager,
+        sample_config: VMConfig,
+        tmp_path: Path,
+    ) -> None:
+        """Test start() attaches configured extra drives via Firecracker drives API."""
+        mock_network = MagicMock()
+        mock_network.host_ip = "172.16.0.1"
+        mock_network.generate_mac.return_value = "AA:FC:00:00:00:02"
+        mock_network_class.return_value = mock_network
+        smol_vm.network = mock_network
+
+        mock_process = MagicMock()
+        mock_process.pid = 12345
+        mock_start_fc.return_value = mock_process
+
+        mock_client = MagicMock()
+        mock_client_cls.return_value = mock_client
+
+        data_drive = tmp_path / "data.ext4"
+        data_drive.touch()
+        config = sample_config.model_copy(update={"extra_drives": [data_drive]})
+
+        smol_vm.create(config)
+        smol_vm.start("vm001")
+
+        assert mock_client.add_drive.call_count == 2
+        assert mock_client.add_drive.call_args_list[1] == call(
+            "data_drive",
+            data_drive,
+            is_root_device=False,
+            is_read_only=False,
+        )
+
     @patch("smolvm.vm.NetworkManager")
     def test_get_ssh_commands_returns_private_and_forwarded(
         self,