From 29d67e13c2bf0e7b23d8659ad9fb1c926db4c2e6 Mon Sep 17 00:00:00 2001 From: Jeff Abrahamson Date: Sun, 7 Dec 2025 20:17:01 +0100 Subject: [PATCH 1/3] Add sway network status tool Periodically pings a set of hosts once per second, reads speedtest data, and writes a file that i3status will use to integrate into swaybar. Align speedtest network detection so that they use the same network interface names. In particular, use `ip route` to find the network associated with the first default route, then, if that's wifi, look up the SSID name. --- bin/bin/speedtest | 46 +- i3/i3/i3status.conf | 4 +- python/bandwidth_tool/__init__.py | 54 ++ sway/sway/bin/sway-network-status | 1144 +++++++++++++++++++++++++++++ sway/sway/config_base | 1 + tests/test_bandwidth.py | 70 +- 6 files changed, 1308 insertions(+), 11 deletions(-) create mode 100755 sway/sway/bin/sway-network-status diff --git a/bin/bin/speedtest b/bin/bin/speedtest index 4f40b9d..7ed9585 100755 --- a/bin/bin/speedtest +++ b/bin/bin/speedtest @@ -25,6 +25,30 @@ tmp_json=$(mktemp -t speedtest_json_XXXXXX) cleanup() { rm -f "$tmp_json"; } trap cleanup EXIT +detect_default_interface() { + local output line + if ! output=$(ip route show default 2>/dev/null); then + return + fi + while IFS= read -r line; do + [[ $line == default* ]] || continue + read -ra parts <<<"$line" + for idx in "${!parts[@]}"; do + if [[ ${parts[$idx]} == "dev" ]] && (( idx + 1 < ${#parts[@]} )); then + echo "${parts[$((idx + 1))]}" + return + fi + done + done <<<"$output" +} + +detect_wifi_ssid() { + local interface=$1 + if command -v iwgetid >/dev/null 2>&1; then + iwgetid "$interface" --raw 2>>/tmp/iwgetid.log | head -n1 + fi +} + # Run speedtest -> JSON if ! 
"$SPEEDTEST_BIN" -f json >"$tmp_json" 2>/dev/null; then # Fall back to invocation time only if we must @@ -74,14 +98,22 @@ record_value "$epoch_ts" "$ping_ms" ping record_value "$epoch_ts" "$dl_MiBps" download record_value "$epoch_ts" "$ul_MiBps" upload -# SSID (if on Wi-Fi; ignore errors) -if command -v iwgetid >/dev/null 2>&1; then - ssid=$(iwgetid -r 2>>/tmp/iwgetid.log || true) - ssid=${ssid:-unknown-ssid} - record_value "$epoch_ts" "$ssid" ssid -else - record_value "$epoch_ts" "$interface_name" ssid +# Determine current network identifier (SSID if wifi, else interface name) +network_id="unknown" +default_iface=$(detect_default_interface) +if [[ -n "${default_iface:-}" ]]; then + ssid=$(detect_wifi_ssid "$default_iface" || true) + if [[ -n "${ssid:-}" ]]; then + network_id="$ssid" + else + network_id="$default_iface" + fi +elif [[ -n "${interface_name:-}" ]]; then + # Fall back to interface reported by speedtest JSON + network_id="$interface_name" fi +record_value "$epoch_ts" "$network_id" ssid + exit 0 diff --git a/i3/i3/i3status.conf b/i3/i3/i3status.conf index bb0f683..624ba72 100644 --- a/i3/i3/i3status.conf +++ b/i3/i3/i3status.conf @@ -8,7 +8,7 @@ general { colors = true - interval = 5 + interval = 1 } #order += "ipv6" @@ -61,7 +61,7 @@ read_file network_status { format = "%content" #format_bad = "%title - %errno: %error" format_bad = "" - path = "/home/jeff/.uptime-status" + path = "/home/jeff/.network-status" # Max_characters = 255 } diff --git a/python/bandwidth_tool/__init__.py b/python/bandwidth_tool/__init__.py index a55ea76..e821d2a 100644 --- a/python/bandwidth_tool/__init__.py +++ b/python/bandwidth_tool/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import socket +import subprocess from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -17,6 +18,56 @@ SSID_FILE = "speedtest-ssid" +def detect_default_interface() -> Optional[str]: + """Return the interface used for the default route, or None.""" 
+ + try: + output = subprocess.check_output( + ["ip", "route", "show", "default"], + stderr=subprocess.DEVNULL, + text=True, + ) + except Exception: + return None + + for line in output.splitlines(): + line = line.strip() + if not line.startswith("default "): + continue + parts = line.split() + if "dev" in parts: + idx = parts.index("dev") + if idx + 1 < len(parts): + return parts[idx + 1] + return None + + +def detect_wifi_ssid(interface: str) -> Optional[str]: + """Return SSID for a wifi interface, or None if unavailable.""" + + try: + output = subprocess.check_output( + ["iwgetid", interface, "--raw"], + stderr=subprocess.DEVNULL, + text=True, + ).strip() + except Exception: + return None + return output or None + + +def current_network_identifier() -> str: + """Return SSID if wifi, otherwise default interface name or "unknown".""" + + interface = detect_default_interface() + if not interface: + return "unknown" + ssid = detect_wifi_ssid(interface) + if ssid: + return ssid + return interface + + @dataclass class Measurement: """Represents a single measurement row.""" @@ -301,4 +352,7 @@ def render_stats(measurements: Sequence[Measurement], *, text: bool, bins: int = "render_stats_graphical", "render_stats_text", "render_table", + "detect_default_interface", + "detect_wifi_ssid", + "current_network_identifier", ] diff --git a/sway/sway/bin/sway-network-status b/sway/sway/bin/sway-network-status new file mode 100755 index 0000000..b2bf10c --- /dev/null +++ b/sway/sway/bin/sway-network-status @@ -0,0 +1,1144 @@ +#!/usr/bin/env python3 +""" +sway-network-monitor + +Periodically: +- Pings a set of hosts once per second. +- Maintains short and long sparklines of "up" status across hosts. +- Computes EWMA of ping latencies. +- Reads speedtest summary files and computes bandwidth estimates. +- Writes a single status line either to stdout or to ~/.network-status + (via an atomic ~/.network-status.new -> ~/.network-status rename). 
+ +Intended to be used by i3status' read_file module under sway. +""" + +import argparse +import dataclasses +import logging +import math +import os +import socket +import subprocess +import sys +import time +from typing import Dict, List, Optional, Tuple + +try: + import tomllib # Python 3.11+ +except ImportError: # pragma: no cover - older Python + tomllib = None + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +HOME = os.path.expanduser("~") +HOSTNAME_SHORT = socket.gethostname().split(".", 1)[0] + +DEFAULT_CONFIG_PATH = os.path.join( + HOME, ".config", "sway-network-monitor.toml" +) +DEFAULT_LOG_PATH = "/tmp/sway-network-monitor.log" + +DEFAULT_PING_HOSTS = ["1.1.1.1", "8.8.8.8"] +DEFAULT_LONG_INTERVAL = 15 # seconds +DEFAULT_SPEEDTEST_MARGIN_MINUTES = 40 # minutes +DEFAULT_SPEEDTEST_AGGREGATION = "ewma" # or "mean" +DEFAULT_EWMA_ALPHA = 0.3 +DEFAULT_ONE_SHOT_COLLECT_TIME = 10 # seconds + +NO_DATA_LENGTH = 60 # initial filler length +NETWORK_STATUS_PATH = os.path.join(HOME, ".network-status") +NETWORK_STATUS_TMP_PATH = NETWORK_STATUS_PATH + ".new" + +# Default speedtest directory: ${HOME}/data/hosts/$(hostname -s)/ +DEFAULT_SPEEDTEST_DIR = os.path.join(HOME, "data", "hosts", HOSTNAME_SHORT) + +SPEEDTEST_SSID_FILE = "speedtest-ssid" +SPEEDTEST_DOWNLOAD_FILE = "speedtest-download" +SPEEDTEST_UPLOAD_FILE = "speedtest-upload" +SPEEDTEST_PING_FILE = "speedtest-ping" +SPEEDTEST_FAILURE_FILE = "speedtest-failure" + +# The real sparkline chars are " ▁▂▃▄▅▆▇█", but it's really noisy to +# see the full rectangle when all is well, so lighten that to "|". 
+SPARKLINE_CHARS = ( + " ▁▂▃▄▅▆▇|" # index 0 = "no signal", 1..7 show increasing height +) + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class PingSample: + """Per-second aggregate over all ping_hosts.""" + + timestamp: int + up_fraction: float # fraction of hosts that responded + latencies_ms: List[float] # latencies for successful pings at this second + + +@dataclasses.dataclass +class SpeedtestSample: + """Single successful speedtest result.""" + + timestamp: int + ssid: str + download_mbps: float + upload_mbps: float + ping_ms: float + + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- + + +def setup_logging(log_path: str) -> None: + """Configure file-based logging with required time format.""" + os.makedirs(os.path.dirname(log_path), exist_ok=True) + logging.basicConfig( + level=logging.INFO, + filename=log_path, + filemode="a", + format="%(asctime)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + +# --------------------------------------------------------------------------- +# Config handling (TOML) +# --------------------------------------------------------------------------- + + +def load_config( + config_path: str, + default_ping_hosts: List[str], + default_speedtest_dir: str, +) -> Dict[str, object]: + """Load configuration from TOML file, applying defaults. + + Expected keys: + ping_hosts: list of hostnames/IPs + speedtest_dir: directory path + + On config error (bad types, empty required values), prints an + error and exits. + + """ + config: Dict[str, object] = { + "ping_hosts": list(default_ping_hosts), + "speedtest_dir": default_speedtest_dir, + } + + if not tomllib: + # tomllib missing; config file must be ignored, but that + # itself isn't a fatal error. 
+ logging.info( + "tomllib not available; skipping config file parsing, using defaults." + ) + return config + + try: + with open(config_path, "rb") as f: + parsed = tomllib.load(f) + logging.info("Loaded config file: %s", config_path) + except FileNotFoundError: + logging.info("Config file not found: %s; using defaults.", config_path) + return config + except Exception as exc: + msg = f"Config error reading {config_path}: {exc}" + logging.error(msg) + print(msg, file=sys.stderr) + sys.exit(1) + + # ping_hosts + if "ping_hosts" in parsed: + value = parsed["ping_hosts"] + if ( + not isinstance(value, list) + or not value + or any(not isinstance(x, str) or not x.strip() for x in value) + ): + msg = "Config error: ping_hosts must be a non-empty list of non-empty strings." + logging.error(msg) + print(msg, file=sys.stderr) + sys.exit(1) + config["ping_hosts"] = value + + # speedtest_dir + if "speedtest_dir" in parsed: + value = parsed["speedtest_dir"] + if not isinstance(value, str) or not value.strip(): + msg = "Config error: speedtest_dir must be a non-empty string." + logging.error(msg) + print(msg, file=sys.stderr) + sys.exit(1) + # Allow ~ expansion etc. 
+ config["speedtest_dir"] = os.path.expanduser(value) + + logging.info( + "Config values after applying defaults and overrides: ping_hosts=%s, speedtest_dir=%s", + config["ping_hosts"], + config["speedtest_dir"], + ) + + return config + + +# --------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + + +def ewma(values: List[float], alpha: float) -> Optional[float]: + """Compute EWMA of a list of values; returns None if list is empty.""" + if not values: + return None + ema = None + for v in values: + if ema is None: + ema = v + else: + ema = alpha * v + (1.0 - alpha) * ema + return ema + + +def sparkline_from_values(values: List[Optional[float]]) -> str: + """ + Convert a list of values in [0,1] (or None) to a sparkline string. + + None -> ' ' (gap) + 0..1 -> mapped linearly onto SPARKLINE_CHARS[1..]. + """ + chars = [] + n_levels = ( + len(SPARKLINE_CHARS) - 1 + ) # exclude index 0 which is not used for real data + for v in values: + if v is None: + chars.append(" ") + else: + vv = min(max(v, 0.0), 1.0) + idx = int(round(vv * n_levels)) + idx = max(1, min(n_levels, idx)) + chars.append(SPARKLINE_CHARS[idx]) + return "".join(chars) + + +def truncate_status_file_initial() -> None: + """On startup, write NO_DATA_LENGTH spaces to the status file.""" + try: + with open(NETWORK_STATUS_PATH, "w", encoding="utf-8") as f: + f.write(" " * NO_DATA_LENGTH) + except Exception as exc: + logging.error("Failed to initialize %s: %s", NETWORK_STATUS_PATH, exc) + + +def write_status_atomically(line: str) -> None: + """Write a single line atomically to NETWORK_STATUS_PATH.""" + try: + with open(NETWORK_STATUS_TMP_PATH, "w", encoding="utf-8") as f: + f.write(line) + os.replace(NETWORK_STATUS_TMP_PATH, NETWORK_STATUS_PATH) + except Exception as exc: + logging.error("Failed to write network status file: %s", exc) + + +# 
--------------------------------------------------------------------------- +# Network identity detection +# --------------------------------------------------------------------------- + + +def detect_default_interface() -> Optional[str]: + """ + Determine the interface used for the default route using `ip route`. + Returns interface name or None on failure. + """ + try: + output = subprocess.check_output( + ["ip", "route", "show", "default"], + stderr=subprocess.DEVNULL, + text=True, + ) + except Exception as exc: + logging.error("Failed to run 'ip route show default': %s", exc) + return None + + for line in output.splitlines(): + line = line.strip() + if not line or not line.startswith("default "): + continue + parts = line.split() + # Example: default via 192.168.1.1 dev wlp3s0 proto ... + if "dev" in parts: + idx = parts.index("dev") + if idx + 1 < len(parts): + return parts[idx + 1] + return None + + +def detect_wifi_ssid(interface: str) -> Optional[str]: + """ + Try to obtain SSID for a wifi interface using iwgetid. + Returns SSID string or None. + """ + try: + output = subprocess.check_output( + ["iwgetid", interface, "--raw"], + stderr=subprocess.DEVNULL, + text=True, + ).strip() + if output: + return output + except Exception: + pass + return None + + +def get_current_network_identifier() -> str: + """ + Determine the current logical network identifier: + + - If default route is via wifi (SSID available): use SSID. + - Otherwise: use interface name (e.g., enp2s0). + + If detection fails, returns "unknown". + """ + iface = detect_default_interface() + if not iface: + logging.error( + "Unable to detect default interface; using 'unknown' network id." 
+ ) + return "unknown" + + ssid = detect_wifi_ssid(iface) + if ssid: + return ssid + + # Fall back to interface name + return iface + + +# --------------------------------------------------------------------------- +# Ping handling +# --------------------------------------------------------------------------- + + +def ping_host_once( + host: str, timeout: float = 1.0 +) -> Tuple[int, Optional[float]]: + """ + Ping a host once using the system 'ping' command. + + Returns (up, latency_ms): + up: 1 if ping succeeded, 0 otherwise + latency_ms: float latency in ms if available, else None + """ + # Using '-c 1' for a single packet, '-n' for numeric, '-w timeout' for timeout in seconds. + try: + # Note: for IPv6 you'd use 'ping6' or 'ping -6', but we stay with 'ping' as per spec. + result = subprocess.run( + [ + "ping", + "-n", + "-c", + "1", + "-w", + str(int(math.ceil(timeout))), + host, + ], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + text=True, + ) + except Exception as exc: + logging.error("Failed to execute ping for host %s: %s", host, exc) + return 0, None + + if result.returncode != 0: + return 0, None + + # Parse latency from typical ping output 'time=XYZ ms' + latency_ms = None + for line in result.stdout.splitlines(): + line = line.strip() + if "time=" in line and " ms" in line: + # Example: '64 bytes from ...: icmp_seq=1 ttl=63 time=10.3 ms' + try: + # Find substring between 'time=' and ' ms' + t_part = line.split("time=", 1)[1].split(" ", 1)[0] + latency_ms = float(t_part) + except Exception: + latency_ms = None + break + + return 1, latency_ms + + +def collect_ping_sample(ping_hosts: List[str], timestamp: int) -> PingSample: + """ + Ping all hosts once (in sequence) and aggregate into a PingSample. 
+ """ + ups: List[int] = [] + latencies: List[float] = [] + + for host in ping_hosts: + up, latency = ping_host_once(host) + ups.append(up) + if latency is not None: + latencies.append(latency) + + if not ups: + up_fraction = 0.0 + else: + up_fraction = sum(ups) / float(len(ups)) + + return PingSample( + timestamp=timestamp, up_fraction=up_fraction, latencies_ms=latencies + ) + + +def prune_ping_history( + history: Dict[int, PingSample], + now_ts: int, + max_age_seconds: int, +) -> None: + """ + Remove ping samples older than now_ts - max_age_seconds. + """ + cutoff = now_ts - max_age_seconds + for ts in list(history.keys()): + if ts < cutoff: + del history[ts] + + +def collect_latencies_in_window( + history: Dict[int, PingSample], start_ts: int, end_ts: int +) -> List[float]: + """Collect all latencies from ping samples between start_ts and end_ts inclusive.""" + result: List[float] = [] + for ts, sample in history.items(): + if start_ts <= ts <= end_ts: + result.extend(sample.latencies_ms) + return sorted( + result, key=float + ) # sorted just to have deterministic EWMA iteration order + + +def build_ping_sparkline_and_latency( + history: Dict[int, PingSample], + now_ts: int, + bucket_seconds: int, + num_buckets: int, + ewma_alpha: float, +) -> Tuple[str, Optional[float]]: + """ + Build sparkline and EWMA latency for a given window. + + For sparkline: + - Window length = bucket_seconds * num_buckets. + - The oldest bucket maps to left-most character; newest to right-most. + - Each bucket aggregates the up_fraction across its per-second samples. + + For latency: + - EWMA over all latencies in the whole window. 
+ """ + window_length = bucket_seconds * num_buckets + window_start = now_ts - window_length + 1 + bucket_values: List[Optional[float]] = [] + + for i in range(num_buckets): + bucket_start = window_start + i * bucket_seconds + bucket_end = bucket_start + bucket_seconds - 1 + # Collect per-second up_fraction values within this bucket + ups: List[float] = [ + s.up_fraction + for ts, s in history.items() + if bucket_start <= ts <= bucket_end + ] + if ups: + bucket_values.append(sum(ups) / float(len(ups))) + else: + bucket_values.append(None) + + spark = sparkline_from_values(bucket_values) + + # Latency EWMA over entire window across all hosts + latencies = collect_latencies_in_window(history, window_start, now_ts) + latency_ewma = ewma(latencies, ewma_alpha) if latencies else None + + return spark, latency_ewma + + +# --------------------------------------------------------------------------- +# Speedtest handling +# --------------------------------------------------------------------------- + + +class SpeedtestReader: + """ + Handles reading and caching speedtest data from a directory. + + - Uses os.stat to detect file changes. + - Reads only the last line of each file when needed. + - Maintains in-memory list of successful speedtests (SpeedtestSample). + """ + + def __init__(self, base_dir: str) -> None: + self.base_dir = base_dir + self.samples: List[SpeedtestSample] = [] + self.last_success_timestamp: Optional[int] = None + self.last_mtimes: Dict[str, float] = {} + self.bad_read_streak: int = 0 + self.suppressed: bool = False # true after 3 consecutive bad reads + + def _path(self, fname: str) -> str: + return os.path.join(self.base_dir, fname) + + @staticmethod + def _read_last_line(path: str) -> Optional[str]: + """ + Efficiently read the last line of a text file using seek. + + Returns the line as a decoded string (without trailing newline), + or None if the file is empty or cannot be read. 
+ """ + try: + with open(path, "rb") as f: + f.seek(0, os.SEEK_END) + size = f.tell() + if size == 0: + return None + # Move backward until we find a newline or reach start + pos = size - 1 + while pos >= 0: + f.seek(pos) + ch = f.read(1) + if ch == b"\n" and pos != size - 1: + # read remainder of file + break + pos -= 1 + if pos < 0: + f.seek(0) + line = f.readline().decode("utf-8", errors="replace") + return line.rstrip("\r\n") + except FileNotFoundError: + logging.info("Speedtest file not found: %s", path) + return None + except Exception as exc: + logging.error("Error reading last line from %s: %s", path, exc) + return None + + def _stat_files(self) -> Optional[Dict[str, os.stat_result]]: + """ + Stat the primary speedtest files (ssid, download, upload, ping). + + Returns dict name->stat_result or None if any required file is missing. + """ + files = [ + SPEEDTEST_SSID_FILE, + SPEEDTEST_DOWNLOAD_FILE, + SPEEDTEST_UPLOAD_FILE, + SPEEDTEST_PING_FILE, + ] + stats: Dict[str, os.stat_result] = {} + for fname in files: + path = self._path(fname) + try: + stats[fname] = os.stat(path) + except FileNotFoundError: + logging.info("Speedtest file missing: %s", path) + return None + except Exception as exc: + logging.error( + "Error stat'ing speedtest file %s: %s", path, exc + ) + return None + return stats + + def maybe_update(self) -> None: + """ + If the speedtest files have changed, read and append any new successful + SpeedtestSample. + + Sets self.suppressed to True if 3 consecutive changes cannot be parsed + or are inconsistent. Clears suppression when successfully updated. + """ + stats = self._stat_files() + if stats is None: + # Required files missing; treat as no-update, but potentially as a bad read. 
+ self.bad_read_streak += 1 + if self.bad_read_streak >= 3: + self.suppressed = True + return + + # Check if any file mtime changed + mtimes = {name: st.st_mtime for name, st in stats.items()} + changed = any( + self.last_mtimes.get(name) != st.st_mtime + for name, st in stats.items() + ) + + if not changed: + # Nothing new + self.bad_read_streak = 0 + return + + # Ensure mtimes are within 1 second to avoid reading while being rewritten + mtime_values = list(mtimes.values()) + if max(mtime_values) - min(mtime_values) > 1.0: + logging.warning( + "Speedtest files mtimes not aligned (likely being rewritten); skipping this iteration." + ) + self.bad_read_streak += 1 + if self.bad_read_streak >= 3: + self.suppressed = True + return + + # Read last lines + path_ssid = self._path(SPEEDTEST_SSID_FILE) + path_down = self._path(SPEEDTEST_DOWNLOAD_FILE) + path_up = self._path(SPEEDTEST_UPLOAD_FILE) + path_ping = self._path(SPEEDTEST_PING_FILE) + + line_ssid = self._read_last_line(path_ssid) + line_down = self._read_last_line(path_down) + line_up = self._read_last_line(path_up) + line_ping = self._read_last_line(path_ping) + + if not all([line_ssid, line_down, line_up, line_ping]): + logging.warning( + "Speedtest files have empty or unreadable last lines; skipping." 
+ ) + self.bad_read_streak += 1 + if self.bad_read_streak >= 3: + self.suppressed = True + return + + try: + # SSID line: " " + parts_ssid = line_ssid.strip().split(maxsplit=1) + ts_ssid = int(parts_ssid[0]) + ssid = parts_ssid[1] if len(parts_ssid) > 1 else "" + + ts_down_str, down_str = line_down.strip().split(maxsplit=1) + ts_up_str, up_str = line_up.strip().split(maxsplit=1) + ts_ping_str, ping_str = line_ping.strip().split(maxsplit=1) + + ts_down = int(ts_down_str) + ts_up = int(ts_up_str) + ts_ping = int(ts_ping_str) + + # Ensure timestamps match + if not (ts_ssid == ts_down == ts_up == ts_ping): + logging.warning( + "Speedtest timestamps mismatch: ssid=%d, down=%d, up=%d, ping=%d; skipping.", + ts_ssid, + ts_down, + ts_up, + ts_ping, + ) + self.bad_read_streak += 1 + if self.bad_read_streak >= 3: + self.suppressed = True + return + + download_mbps = float(down_str) + upload_mbps = float(up_str) + ping_ms = float(ping_str) + + except Exception as exc: + logging.error("Failed parsing speedtest lines: %s", exc) + self.bad_read_streak += 1 + if self.bad_read_streak >= 3: + self.suppressed = True + return + + # Check if this is a new sample + ts = ts_ssid + if ( + self.last_success_timestamp is not None + and ts <= self.last_success_timestamp + ): + # Not newer; maybe files were touched without new data + self.bad_read_streak = 0 + self.last_mtimes = mtimes + return + + # New successful sample + self.samples.append( + SpeedtestSample( + timestamp=ts, + ssid=ssid, + download_mbps=download_mbps, + upload_mbps=upload_mbps, + ping_ms=ping_ms, + ) + ) + self.last_success_timestamp = ts + self.last_mtimes = mtimes + self.bad_read_streak = 0 + self.suppressed = False + logging.info( + "Recorded speedtest sample at %d for SSID '%s': down=%.2f Mbps, up=%.2f Mbps, ping=%.2f ms", + ts, + ssid, + download_mbps, + upload_mbps, + ping_ms, + ) + + def drop_all_samples(self) -> None: + """Clear all cached speedtest samples (e.g., on network change).""" + self.samples.clear() + 
self.last_success_timestamp = None + # Do not reset last_mtimes; we want to detect changes when files are updated again. + + def compute_bandwidth_summary( + self, + current_network_id: str, + now_ts: int, + margin_minutes: int, + aggregation: str, + ewma_alpha: float, + ) -> Tuple[str, bool]: + """ + Compute the bandwidth summary string for the current network. + + Returns (summary_string, has_data_for_network). + + If self.suppressed is True or there is no qualifying data for current_network_id + in the time window, the string is "- down / - up / - ms". + """ + if self.suppressed: + logging.warning( + "Speedtest data suppressed due to repeated read errors; writing empty bandwidth section." + ) + return "- down / - up / - ms", False + + margin_seconds = margin_minutes * 60 + cutoff = now_ts - margin_seconds + recent_samples = [s for s in self.samples if s.timestamp >= cutoff] + + # Filter by current network identifier (SSID or interface name) + matching = [s for s in recent_samples if s.ssid == current_network_id] + + if not matching: + # No qualifying data + # Log what we were looking for and what we found + all_ssids = sorted({s.ssid for s in recent_samples}) + logging.info( + "No speedtest data for network '%s' in last %d minutes (cutoff=%d). 
" + "Available SSIDs/networks in that period: %s", + current_network_id, + margin_minutes, + cutoff, + ", ".join(all_ssids) if all_ssids else "(none)", + ) + return "- down / - up / - ms", False + + # Extract metrics + downs = [s.download_mbps for s in matching] + ups = [s.upload_mbps for s in matching] + pings = [s.ping_ms for s in matching] + + if aggregation == "mean": + + def _mean(vals: List[float]) -> float: + return sum(vals) / float(len(vals)) if vals else 0.0 + + down_val = _mean(downs) + up_val = _mean(ups) + ping_val = _mean(pings) + else: # ewma + down_val = ewma(downs, ewma_alpha) or 0.0 + up_val = ewma(ups, ewma_alpha) or 0.0 + ping_val = ewma(pings, ewma_alpha) or 0.0 + + summary = f"{down_val:.1f} down / {up_val:.1f} up / {ping_val:.1f} ms" + return summary, True + + +# --------------------------------------------------------------------------- +# Status line assembly +# --------------------------------------------------------------------------- + + +def format_latency_ms(latency: Optional[float]) -> str: + """Format latency as string suitable for '(X.Y ms)'; '-' if None.""" + if latency is None: + return "-" + return f"{latency:.1f}" + + +def build_status_line( + ping_history: Dict[int, PingSample], + now_ts: int, + long_interval: int, + ewma_alpha: float, + speedtest_reader: SpeedtestReader, + current_network_id: str, + speedtest_margin_minutes: int, + speedtest_aggregation: str, +) -> str: + """ + Build the full status line: + + SHORT_SPARK(E) (latency ms) | LONG_SPARK(E) (latency ms) | BW_STRING + """ + # Short section: last 10 seconds, 1 second per bucket + short_spark, short_latency = build_ping_sparkline_and_latency( + ping_history, + now_ts, + bucket_seconds=1, + num_buckets=10, + ewma_alpha=ewma_alpha, + ) + short_latency_str = format_latency_ms(short_latency) + short_section = f"{short_spark:10s} ({short_latency_str} ms)" + + # Long section: last 10 * long_interval seconds, non-overlapping buckets + long_spark, long_latency = 
build_ping_sparkline_and_latency( + ping_history, + now_ts, + bucket_seconds=long_interval, + num_buckets=10, + ewma_alpha=ewma_alpha, + ) + long_latency_str = format_latency_ms(long_latency) + long_section = f"{long_spark:10s} ({long_latency_str} ms)" + + # Bandwidth section + bw_string, _has_bw_data = speedtest_reader.compute_bandwidth_summary( + current_network_id=current_network_id, + now_ts=now_ts, + margin_minutes=speedtest_margin_minutes, + aggregation=speedtest_aggregation, + ewma_alpha=ewma_alpha, + ) + + # return f"{short_section} | {long_section} | {bw_string}" + return f"{short_section} {long_section} {bw_string}" + + +# --------------------------------------------------------------------------- +# Main loop +# --------------------------------------------------------------------------- + + +def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Monitor network pings and recent bandwidth for sway/i3status." + ) + parser.add_argument( + "--stdout", + action="store_true", + help=( + "Write status line to stdout instead of file; " + "if set, ~/.network-status is not modified." + ), + ) + parser.add_argument( + "--long", + type=int, + default=DEFAULT_LONG_INTERVAL, + help="Number of seconds per interval in the long section (default: 15).", + ) + parser.add_argument( + "--speedtest-margin", + type=int, + default=DEFAULT_SPEEDTEST_MARGIN_MINUTES, + help="Number of minutes of speedtest data to consider (default: 40).", + ) + parser.add_argument( + "--speedtest-aggregation", + choices=["mean", "ewma"], + default=DEFAULT_SPEEDTEST_AGGREGATION, + help="Aggregation method for speedtest values: mean or ewma (default: ewma).", + ) + parser.add_argument( + "--ewma-alpha", + type=float, + default=DEFAULT_EWMA_ALPHA, + help="Alpha parameter for EWMA computations (default: 0.3).", + ) + parser.add_argument( + "--one-shot", + action="store_true", + help=( + "Run only once and quit (implies --stdout). 
" + "Collect data for a limited time and then output a single status line." + ), + ) + parser.add_argument( + "--one-shot-collect-time", + type=int, + default=DEFAULT_ONE_SHOT_COLLECT_TIME, + help=( + "Seconds to accumulate data in one-shot mode (default: 10, " + "capped at 10 * --long)." + ), + ) + parser.add_argument( + "--log", + type=str, + default=DEFAULT_LOG_PATH, + help=f"Path to log file (default: {DEFAULT_LOG_PATH}).", + ) + parser.add_argument( + "--config", + type=str, + default=DEFAULT_CONFIG_PATH, + help=f"Path to config file (default: {DEFAULT_CONFIG_PATH}).", + ) + return parser.parse_args(argv) + + +def main(argv: Optional[List[str]] = None) -> None: + args = parse_args(argv) + + # Initialize logging first, so we can log subsequent decisions. + setup_logging(args.log) + logging.info("Starting sway-network-monitor.") + logging.info("Command-line arguments: %s", vars(args)) + + # Handle one-shot -> implies stdout + effective_stdout = bool(args.stdout or args.one_shot) + + # Cap one-shot collect time if needed (this must be logged) + if args.one_shot: + max_collect = 10 * args.long + collect_time = args.one_shot_collect_time + if collect_time > max_collect: + logging.info( + "one-shot-collect-time (%d) > 10 * long (%d); capping to %d seconds.", + collect_time, + args.long, + max_collect, + ) + collect_time = max_collect + args.one_shot_collect_time = collect_time + logging.info( + "One-shot mode: collecting data for %d seconds.", collect_time + ) + + # Load config (ping_hosts, speedtest_dir) + config = load_config( + config_path=args.config, + default_ping_hosts=DEFAULT_PING_HOSTS, + default_speedtest_dir=DEFAULT_SPEEDTEST_DIR, + ) + + ping_hosts: List[str] = config["ping_hosts"] # type: ignore[assignment] + speedtest_dir: str = config["speedtest_dir"] # type: ignore[assignment] + + logging.info("Using ping_hosts=%s", ping_hosts) + logging.info("Using speedtest_dir=%s", speedtest_dir) + logging.info("Using ewma_alpha=%.3f", args.ewma_alpha) + 
logging.info("Using speedtest_margin=%d minutes", args.speedtest_margin) + logging.info("Using speedtest_aggregation=%s", args.speedtest_aggregation) + logging.info("Using long_interval=%d seconds", args.long) + + # Initialize status file with spaces (unless writing only to stdout) + if not effective_stdout: + truncate_status_file_initial() + + # Initialize ping history & speedtest reader + ping_history: Dict[int, PingSample] = {} + max_ping_window = args.long * 10 # seconds for long section + speedtest_reader = SpeedtestReader(speedtest_dir) + + # Detect initial network identifier + current_network_id = get_current_network_identifier() + logging.info("Initial network identifier: %s", current_network_id) + + # Main logic split: one-shot vs continuous + if args.one_shot: + run_one_shot( + ping_hosts=ping_hosts, + long_interval=args.long, + ewma_alpha=args.ewma_alpha, + speedtest_reader=speedtest_reader, + speedtest_margin_minutes=args.speedtest_margin, + speedtest_aggregation=args.speedtest_aggregation, + collect_time=args.one_shot_collect_time, + ping_history=ping_history, + max_ping_window=max_ping_window, + ) + return + + run_continuous( + ping_hosts=ping_hosts, + long_interval=args.long, + ewma_alpha=args.ewma_alpha, + speedtest_reader=speedtest_reader, + speedtest_margin_minutes=args.speedtest_margin, + speedtest_aggregation=args.speedtest_aggregation, + ping_history=ping_history, + max_ping_window=max_ping_window, + effective_stdout=effective_stdout, + ) + + +def run_one_shot( + ping_hosts: List[str], + long_interval: int, + ewma_alpha: float, + speedtest_reader: SpeedtestReader, + speedtest_margin_minutes: int, + speedtest_aggregation: str, + collect_time: int, + ping_history: Dict[int, PingSample], + max_ping_window: int, +) -> None: + """ + One-shot mode: collect data for a specified number of seconds, then + emit a single status line to stdout and exit. 
+ """ + start_time = time.time() + end_time = start_time + collect_time + + logging.info("Entering one-shot data collection loop.") + + while True: + loop_start = time.time() + now_ts = int(loop_start) + + # Detect current network each iteration + current_network_id = get_current_network_identifier() + + # If network changed -> drop history + # In one-shot mode, this simply restarts accumulation + # (previous data would be from a different network). + # For a short collect window this is acceptable. + # Note: we track last_network_id locally. + if not ping_history: + last_network_id = current_network_id + else: + last_network_id = ( + current_network_id # effectively ignore old network + ) + + # Ping all hosts + sample = collect_ping_sample(ping_hosts, now_ts) + ping_history[now_ts] = sample + prune_ping_history(ping_history, now_ts, max_ping_window) + + # Update speedtest reader (if new data) + speedtest_reader.maybe_update() + + # Sleep until next second or end of collection + elapsed = time.time() - loop_start + remaining = 1.0 - elapsed + if remaining > 0: + time.sleep(remaining) + + if time.time() >= end_time: + break + + # Final status line after collection + final_now_ts = int(time.time()) + current_network_id = get_current_network_identifier() + status_line = build_status_line( + ping_history=ping_history, + now_ts=final_now_ts, + long_interval=long_interval, + ewma_alpha=ewma_alpha, + speedtest_reader=speedtest_reader, + current_network_id=current_network_id, + speedtest_margin_minutes=speedtest_margin_minutes, + speedtest_aggregation=speedtest_aggregation, + ) + # Only output is this single line + print(status_line) + + +def run_continuous( + ping_hosts: List[str], + long_interval: int, + ewma_alpha: float, + speedtest_reader: SpeedtestReader, + speedtest_margin_minutes: int, + speedtest_aggregation: str, + ping_history: Dict[int, PingSample], + max_ping_window: int, + effective_stdout: bool, +) -> None: + """ + Continuous mode: loop forever, updating 
ping and speedtest data + approximately once per second and writing status lines. + """ + logging.info("Entering continuous monitoring loop.") + last_network_id: Optional[str] = None + + while True: + loop_start = time.time() + now_ts = int(loop_start) + + # Detect current network + current_network_id = get_current_network_identifier() + if current_network_id != last_network_id: + if last_network_id is not None: + logging.info( + "Network identifier changed from '%s' to '%s'; dropping cached data.", + last_network_id, + current_network_id, + ) + else: + logging.info( + "Network identifier detected: '%s'", current_network_id + ) + last_network_id = current_network_id + + # Drop ping history & speedtest samples when network changes + ping_history.clear() + speedtest_reader.drop_all_samples() + + # Collect ping data + sample = collect_ping_sample(ping_hosts, now_ts) + ping_history[now_ts] = sample + prune_ping_history(ping_history, now_ts, max_ping_window) + + # Update speedtest data (if files changed) + speedtest_reader.maybe_update() + + # Build status line + status_line = build_status_line( + ping_history=ping_history, + now_ts=now_ts, + long_interval=long_interval, + ewma_alpha=ewma_alpha, + speedtest_reader=speedtest_reader, + current_network_id=current_network_id, + speedtest_margin_minutes=speedtest_margin_minutes, + speedtest_aggregation=speedtest_aggregation, + ) + + # Write out + if effective_stdout: + print(status_line, flush=True) + else: + write_status_atomically(status_line) + + # Sleep to maintain ~1-second period + elapsed = time.time() - loop_start + remaining = 1.0 - elapsed + if remaining > 0: + time.sleep(remaining) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + logging.info("Received KeyboardInterrupt; exiting.") + except Exception as exc: + logging.exception( + "Unhandled exception in sway-network-monitor: %s", exc + ) + # Let systemd or other supervisor restart if needed. 
+        sys.exit(1) diff --git a/sway/sway/config_base b/sway/sway/config_base index b51b309..084ebfc 100644 --- a/sway/sway/config_base +++ b/sway/sway/config_base @@ -39,6 +39,7 @@ include /etc/sway/config-vars.d/* # Run my personal daemons. exec --no-startup-id sway-env-stats exec --no-startup-id sway-power-monitor 30 daemon +exec --no-startup-id sway-network-status # Pick a random image from my desktop image files. # The '/*' is necessary to force a full path on output. diff --git a/tests/test_bandwidth.py b/tests/test_bandwidth.py index 79a73a6..2e9a950 100644 --- a/tests/test_bandwidth.py +++ b/tests/test_bandwidth.py @@ -8,13 +8,18 @@ import pytest +import subprocess -# Make the repository's python utilities importable when tests execute directly. +# Make repo python utilities importable when tests run directly. PROJECT_ROOT = Path(__file__).resolve().parents[1] PYTHON_DIR = PROJECT_ROOT / "python" if str(PYTHON_DIR) not in sys.path: sys.path.insert(0, str(PYTHON_DIR)) -from bandwidth_tool import ( +import bandwidth_tool  # noqa: E402 +from bandwidth_tool import (  # noqa: E402 + current_network_identifier, + detect_default_interface, + detect_wifi_ssid, limit_measurements, load_measurements, render_stats, @@ -54,6 +59,68 @@ def write_file(name: str, values: list[str]) -> None: return tmp_path + +def test_detect_default_interface_parses_output( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def fake_check_output( + cmd: list[str], *, stderr: object, text: bool + ) -> str: + assert cmd == ["ip", "route", "show", "default"] + return "default via 192.168.1.1 dev wlp3s0 proto dhcp metric 600\n" + + monkeypatch.setattr(subprocess, "check_output", fake_check_output) + + assert detect_default_interface() == "wlp3s0" + + +def test_detect_wifi_ssid(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_check_output( + cmd: list[str], *, stderr: object, text: bool + ) -> str: + assert cmd == ["iwgetid", "wlp3s0", "--raw"] + return "OfficeWifi\n" + + monkeypatch.setattr(subprocess, "check_output", 
fake_check_output) + + assert detect_wifi_ssid("wlp3s0") == "OfficeWifi" + + +def test_current_network_identifier_prefers_ssid( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + bandwidth_tool, "detect_default_interface", lambda: "wlp3s0" + ) + monkeypatch.setattr( + bandwidth_tool, "detect_wifi_ssid", lambda interface: "OfficeWifi" + ) + + assert current_network_identifier() == "OfficeWifi" + + +def test_current_network_identifier_falls_back_to_interface( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + bandwidth_tool, "detect_default_interface", lambda: "enp0s1" + ) + monkeypatch.setattr( + bandwidth_tool, "detect_wifi_ssid", lambda interface: None + ) + + assert current_network_identifier() == "enp0s1" + + +def test_current_network_identifier_unknown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + bandwidth_tool, "detect_default_interface", lambda: None + ) + + assert current_network_identifier() == "unknown" + + def test_render_table_with_limit(sample_data: Path) -> None: measurements = load_measurements(sample_data) limited = limit_measurements(measurements, 2) From 4818f0746c9439b35d78bede5bf4172dcb942ac7 Mon Sep 17 00:00:00 2001 From: Jeff Abrahamson Date: Sun, 7 Dec 2025 21:01:28 +0100 Subject: [PATCH 2/3] Fix logging path handling and one-shot network resets * Avoid creating logging directories when the log path omits a parent folder. * Reset one-shot collection state when the network identifier changes to avoid mixing samples. 
--- sway/sway/bin/sway-network-status | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/sway/sway/bin/sway-network-status b/sway/sway/bin/sway-network-status index b2bf10c..ddda5bb 100755 --- a/sway/sway/bin/sway-network-status +++ b/sway/sway/bin/sway-network-status @@ -99,7 +99,9 @@ class SpeedtestSample: def setup_logging(log_path: str) -> None: """Configure file-based logging with required time format.""" - os.makedirs(os.path.dirname(log_path), exist_ok=True) + log_dir = os.path.dirname(log_path) + if log_dir: + os.makedirs(log_dir, exist_ok=True) logging.basicConfig( level=logging.INFO, filename=log_path, @@ -1003,6 +1005,7 @@ def run_one_shot( end_time = start_time + collect_time logging.info("Entering one-shot data collection loop.") + last_network_id: Optional[str] = None while True: loop_start = time.time() @@ -1014,14 +1017,17 @@ def run_one_shot( # If network changed -> drop history # In one-shot mode, this simply restarts accumulation # (previous data would be from a different network). - # For a short collect window this is acceptable. - # Note: we track last_network_id locally. - if not ping_history: - last_network_id = current_network_id - else: - last_network_id = ( - current_network_id  # effectively ignore old network + if last_network_id is None: + logging.info("Network identifier detected: '%s'", current_network_id) + elif current_network_id != last_network_id: + logging.info( + "Network identifier changed from '%s' to '%s'; dropping cached data.", + last_network_id, + current_network_id, ) + ping_history.clear() + speedtest_reader.drop_all_samples() + last_network_id = current_network_id # Ping all hosts sample = collect_ping_sample(ping_hosts, now_ts) From 8d8f84c3e96daf8be6703726107cd6a1ae2773f5 Mon Sep 17 00:00:00 2001 From: Jeff Abrahamson Date: Sun, 7 Dec 2025 21:14:39 +0100 Subject: [PATCH 3/3] Add CI workflow for linting and tests Also fix formatting so everything passes. 
--- .github/workflows/ci.yml | 60 ++++++++++++++++++++++ python/bandwidth_tool/__init__.py | 85 ++++++++++++++++++++++++------- python/bandwidth_tool/tabular.py | 20 ++++++-- 3 files changed, 143 insertions(+), 22 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..7b2d387 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,60 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + workflow_dispatch: + +jobs: + lint: + name: Lint Python + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install black flake8 + + - name: Check formatting with black + run: black --check --verbose --line-length 79 bin/bin/tsd-plot.py src/tsd_plot tests python + + - name: Lint with flake8 + run: | + flake8 --tee --output-file flake8.report bin/bin/tsd-plot.py src/tsd_plot tests python + + - name: Upload flake8 report + if: always() + uses: actions/upload-artifact@v4 + with: + name: flake8-report + path: flake8.report + + tests: + name: Run Tests + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + + - name: Run pytest + run: pytest diff --git a/python/bandwidth_tool/__init__.py b/python/bandwidth_tool/__init__.py index e821d2a..32a317e 100644 --- a/python/bandwidth_tool/__init__.py +++ b/python/bandwidth_tool/__init__.py @@ -1,4 +1,5 @@ """Utilities for inspecting recorded bandwidth measurements.""" + from __future__ import annotations import socket @@ -79,7 +80,9 @@ class Measurement: ssid: Optional[str] = None def as_row(self) -> Sequence[str]: 
- dt = datetime.fromtimestamp(self.timestamp).strftime("%Y-%m-%d %H:%M:%S") + dt = datetime.fromtimestamp(self.timestamp).strftime( + "%Y-%m-%d %H:%M:%S" + ) return ( dt, format_number(self.upload), @@ -149,7 +152,9 @@ def load_measurements(directory: Path) -> List[Measurement]: pings = _parse_numeric_file(paths["ping"]) ssids = _parse_text_file(paths["ssid"]) - timestamps = sorted({*uploads.keys(), *downloads.keys(), *pings.keys(), *ssids.keys()}) + timestamps = sorted( + {*uploads.keys(), *downloads.keys(), *pings.keys(), *ssids.keys()} + ) measurements: List[Measurement] = [] for ts in timestamps: measurements.append( @@ -164,7 +169,9 @@ def load_measurements(directory: Path) -> List[Measurement]: return measurements -def limit_measurements(measurements: Sequence[Measurement], limit: int) -> List[Measurement]: +def limit_measurements( + measurements: Sequence[Measurement], limit: int +) -> List[Measurement]: if limit <= 0: return [] return list(measurements[-limit:]) @@ -173,10 +180,14 @@ def limit_measurements(measurements: Sequence[Measurement], limit: int) -> List[ def render_table(measurements: Iterable[Measurement]) -> str: rows = [measurement.as_row() for measurement in measurements] headers = ("datetime", "upload MiBps", "download MiBps", "ping ms", "ssid") - return format_table(headers, rows, colalign=("left", "right", "right", "right", "left")) + return format_table( + headers, rows, colalign=("left", "right", "right", "right", "left") + ) -def _collect_values(measurements: Iterable[Measurement], attribute: str) -> List[float]: +def _collect_values( + measurements: Iterable[Measurement], attribute: str +) -> List[float]: values: List[float] = [] for measurement in measurements: value = getattr(measurement, attribute) @@ -202,7 +213,9 @@ def _determine_edges(values: Sequence[float], bins: int = 10) -> List[float]: return edges -def _histogram_from_edges(values: Sequence[float], edges: Sequence[float]) -> List[int]: +def _histogram_from_edges( + values: 
Sequence[float], edges: Sequence[float] +) -> List[int]: counts = [0 for _ in range(len(edges) - 1)] if not values: return counts @@ -231,7 +244,9 @@ def _format_range(start: float, end: float, is_last: bool) -> str: return f"[{start:7.2f}, {end:7.2f}{right}" -def _bar(count: int, max_count: int, width: int, *, reverse: bool = False) -> str: +def _bar( + count: int, max_count: int, width: int, *, reverse: bool = False +) -> str: if max_count <= 0 or count <= 0: bar = "" else: @@ -242,7 +257,11 @@ def _bar(count: int, max_count: int, width: int, *, reverse: bool = False) -> st return bar.ljust(width) -def _render_violin_text(edges: Sequence[float], upload_counts: Sequence[int], download_counts: Sequence[int]) -> List[str]: +def _render_violin_text( + edges: Sequence[float], + upload_counts: Sequence[int], + download_counts: Sequence[int], +) -> List[str]: width = 16 lines = ["Upload/Download speeds (MiBps)"] lines.append("upload".rjust(width) + " │ " + "download".ljust(width)) @@ -256,7 +275,9 @@ def _render_violin_text(edges: Sequence[float], upload_counts: Sequence[int], do return lines -def _render_ping_text(edges: Sequence[float], counts: Sequence[int]) -> List[str]: +def _render_ping_text( + edges: Sequence[float], counts: Sequence[int] +) -> List[str]: width = 32 lines = ["Ping times (ms)"] max_count = max([*counts, 0]) @@ -268,7 +289,9 @@ def _render_ping_text(edges: Sequence[float], counts: Sequence[int]) -> List[str return lines -def render_stats_text(measurements: Sequence[Measurement], bins: int = 10) -> str: +def render_stats_text( + measurements: Sequence[Measurement], bins: int = 10 +) -> str: uploads = _collect_values(measurements, "upload") downloads = _collect_values(measurements, "download") pings = _collect_values(measurements, "ping") @@ -279,7 +302,9 @@ def render_stats_text(measurements: Sequence[Measurement], bins: int = 10) -> st edges = _determine_edges([*uploads, *downloads], bins=bins) upload_counts = _histogram_from_edges(uploads, edges) 
download_counts = _histogram_from_edges(downloads, edges) - lines.extend(_render_violin_text(edges, upload_counts, download_counts)) + lines.extend( + _render_violin_text(edges, upload_counts, download_counts) + ) else: lines.append("No upload/download data available.") @@ -295,11 +320,17 @@ def render_stats_text(measurements: Sequence[Measurement], bins: int = 10) -> st return "\n".join(lines) -def render_stats_graphical(measurements: Sequence[Measurement], bins: int = 10) -> None: +def render_stats_graphical( + measurements: Sequence[Measurement], bins: int = 10 +) -> None: try: import matplotlib.pyplot as plt # type: ignore - except ImportError as exc: # pragma: no cover - depends on optional dependency - raise RuntimeError("Matplotlib is required for graphical statistics") from exc + except ( + ImportError + ) as exc: # pragma: no cover - depends on optional dependency + raise RuntimeError( + "Matplotlib is required for graphical statistics" + ) from exc uploads = _collect_values(measurements, "upload") downloads = _collect_values(measurements, "download") @@ -313,8 +344,22 @@ def render_stats_graphical(measurements: Sequence[Measurement], bins: int = 10) centers = _bin_centers(edges) heights = [edges[i + 1] - edges[i] for i in range(len(edges) - 1)] - ax_speed.barh(centers, upload_counts, height=heights, align="center", color="tab:blue", label="Upload") - ax_speed.barh(centers, [-count for count in download_counts], height=heights, align="center", color="tab:orange", label="Download") + ax_speed.barh( + centers, + upload_counts, + height=heights, + align="center", + color="tab:blue", + label="Upload", + ) + ax_speed.barh( + centers, + [-count for count in download_counts], + height=heights, + align="center", + color="tab:orange", + label="Download", + ) ax_speed.axvline(0, color="black", linewidth=0.8) ax_speed.set_xlabel("Sample count") ax_speed.set_ylabel("MiBps") @@ -324,7 +369,9 @@ def render_stats_graphical(measurements: Sequence[Measurement], bins: int = 
10) ping_edges = _determine_edges(pings, bins=bins) ping_counts = _histogram_from_edges(pings, ping_edges) ping_centers = _bin_centers(ping_edges) - widths = [ping_edges[i + 1] - ping_edges[i] for i in range(len(ping_edges) - 1)] + widths = [ + ping_edges[i + 1] - ping_edges[i] for i in range(len(ping_edges) - 1) + ] ax_ping.bar(ping_centers, ping_counts, width=widths, color="tab:green") ax_ping.set_xlabel("Ping (ms)") @@ -335,7 +382,9 @@ def render_stats_graphical(measurements: Sequence[Measurement], bins: int = 10) plt.show() -def render_stats(measurements: Sequence[Measurement], *, text: bool, bins: int = 10) -> Optional[str]: +def render_stats( + measurements: Sequence[Measurement], *, text: bool, bins: int = 10 +) -> Optional[str]: if text: return render_stats_text(measurements, bins=bins) render_stats_graphical(measurements, bins=bins) diff --git a/python/bandwidth_tool/tabular.py b/python/bandwidth_tool/tabular.py index 8a58d39..c4c6bfb 100644 --- a/python/bandwidth_tool/tabular.py +++ b/python/bandwidth_tool/tabular.py @@ -5,6 +5,7 @@ or test time. The :func:`format_table` helper aligns text in a way that is compatible with traditional command line utilities such as ``column``. """ + from __future__ import annotations from typing import Iterable, List, Sequence @@ -22,7 +23,12 @@ def _align_cell(text: str, width: int, align: str) -> str: return text.ljust(width) -def format_table(headers: Sequence[str], rows: Iterable[Row], *, colalign: Alignment | None = None) -> str: +def format_table( + headers: Sequence[str], + rows: Iterable[Row], + *, + colalign: Alignment | None = None, +) -> str: """Return a string representing a table with aligned columns. 
Parameters @@ -49,17 +55,23 @@ def format_table(headers: Sequence[str], rows: Iterable[Row], *, colalign: Align widths = [len(headers[i]) for i in range(num_cols)] for row in row_list: if len(row) != num_cols: - raise ValueError("row has different number of columns than headers") + raise ValueError( + "row has different number of columns than headers" + ) for i, cell in enumerate(row): widths[i] = max(widths[i], len(cell)) header_line = " ".join( - _align_cell(headers[i], widths[i], colalign[i]) for i in range(num_cols) + _align_cell(headers[i], widths[i], colalign[i]) + for i in range(num_cols) ) divider = " ".join("-" * widths[i] for i in range(num_cols)) body_lines = [ - " ".join(_align_cell(cell, widths[i], colalign[i]) for i, cell in enumerate(row)) + " ".join( + _align_cell(cell, widths[i], colalign[i]) + for i, cell in enumerate(row) + ) for row in row_list ]