From 6c2745337dc129db96d831467511f09620096444 Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 16:00:42 +0200 Subject: [PATCH 01/17] feat: add custom analysis pipelines and performance optimizations Port Script_MEE-style dashboard analyses to pypsa-app as a new analysis service layer alongside the existing PyPSA built-in statistics. Adds 7 analysis types (dispatch_area, line_loading_histogram/timeseries, price_duration_curve/timeseries, cross_border_flows, capacity_mix) with canonical carrier colors, dispatch ordering, and country color theming. Performance: orjson for faster JSON serialization in cache and serializer, 128KB file hash chunks, NaN/Inf sanitization in DataFrame serialization, optional Polars for vectorized aggregations on large networks. Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 3 + src/pypsa_app/backend/api/routes/analysis.py | 39 ++ src/pypsa_app/backend/cache.py | 8 +- src/pypsa_app/backend/main.py | 6 + src/pypsa_app/backend/schemas/analysis.py | 42 ++ src/pypsa_app/backend/services/analysis.py | 557 ++++++++++++++++++ src/pypsa_app/backend/services/network.py | 4 +- src/pypsa_app/backend/tasks.py | 8 + src/pypsa_app/backend/utils/allowlists.py | 14 +- src/pypsa_app/backend/utils/carrier_colors.py | 78 +++ src/pypsa_app/backend/utils/serializers.py | 46 +- 11 files changed, 797 insertions(+), 8 deletions(-) create mode 100644 src/pypsa_app/backend/api/routes/analysis.py create mode 100644 src/pypsa_app/backend/schemas/analysis.py create mode 100644 src/pypsa_app/backend/services/analysis.py create mode 100644 src/pypsa_app/backend/utils/carrier_colors.py diff --git a/pyproject.toml b/pyproject.toml index 680cfd5..4d8ae28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ dependencies = [ "httpx", "python-multipart", "alembic", + "orjson", + "plotly", ] [project.optional-dependencies] @@ -37,6 +39,7 @@ full = [ "psycopg2-binary", "redis", "celery[redis]", + "polars", ] dev = [ diff --git a/src/pypsa_app/backend/api/routes/analysis.py b/src/pypsa_app/backend/api/routes/analysis.py new file mode 100644 index 0000000..a2722aa --- /dev/null +++ b/src/pypsa_app/backend/api/routes/analysis.py @@ -0,0 +1,39 @@ +"""API routes for custom analysis endpoints (dispatch, line loading, prices, etc.)""" + +import logging + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from pypsa_app.backend.api.deps import get_db, get_networks, require_permission +from pypsa_app.backend.api.utils.task_utils import queue_task +from pypsa_app.backend.models import Permission, User +from pypsa_app.backend.schemas.analysis import AnalysisRequest +from pypsa_app.backend.schemas.task import TaskQueuedResponse +from pypsa_app.backend.tasks import run_analysis_task + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.post("/", response_model=TaskQueuedResponse) +def create_analysis( + request: AnalysisRequest, + db: Session = Depends(get_db), + user: User = Depends(require_permission(Permission.NETWORKS_VIEW)), +) -> dict: + """Run a custom analysis on one or more networks. + + Available analysis types: dispatch_area, line_loading_histogram, + line_loading_timeseries, price_duration_curve, price_timeseries, + cross_border_flows, capacity_mix. + """ + networks = get_networks(db, request.network_ids, user) + file_paths = [net.file_path for net in networks] + + return queue_task( + run_analysis_task, + file_paths=file_paths, + analysis_type=request.analysis_type, + parameters=request.parameters, + ) diff --git a/src/pypsa_app/backend/cache.py b/src/pypsa_app/backend/cache.py index 5ba3d6f..2d95a0e 100644 --- a/src/pypsa_app/backend/cache.py +++ b/src/pypsa_app/backend/cache.py @@ -2,13 +2,13 @@ import hashlib import inspect -import json import logging from collections.abc import Callable from functools import wraps from typing import Any from pypsa_app.backend.settings import settings +from pypsa_app.backend.utils.serializers import fast_json_dumps, fast_json_loads try: import redis @@ -50,7 +50,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # Generate cache key by hashing all parameters cache_hash = hashlib.md5( # noqa: S324 - json.dumps(serializable, sort_keys=True).encode() + fast_json_dumps(serializable).encode() ).hexdigest()[:12] cache_key = f"{key_template.split(':', maxsplit=1)[0]}:{cache_hash}" @@ -90,12 +90,12 @@ def get(self, key: str) -> dict | None: """Get cached data by key""" cached_data = self.redis_client.get(key) if cached_data: - return json.loads(cached_data) + return fast_json_loads(cached_data) return None def set(self, key: str, value: dict, ttl: int) -> bool: """Set cached data with TTL""" - serialized = json.dumps(value) + serialized = fast_json_dumps(value) size_bytes = len(serialized) size_mb = size_bytes / (1024 * 1024) diff --git a/src/pypsa_app/backend/main.py b/src/pypsa_app/backend/main.py index 446034b..9e84cc2 100644 --- a/src/pypsa_app/backend/main.py +++ b/src/pypsa_app/backend/main.py @@ -14,6 +14,7 @@ from pypsa_app.backend.__version__ import __description__, __version__ from pypsa_app.backend.api.routes import ( admin, + analysis, api_keys, auth, cache, @@ -291,6 +292,11 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp networks.router, prefix=f"{API_V1_PREFIX}/networks", tags=["networks"] ) app.include_router(plots.router, prefix=f"{API_V1_PREFIX}/plots", tags=["plots"]) +app.include_router( + analysis.router, + prefix=f"{API_V1_PREFIX}/analysis", + tags=["analysis"], +) app.include_router( statistics.router, prefix=f"{API_V1_PREFIX}/statistics", diff --git a/src/pypsa_app/backend/schemas/analysis.py b/src/pypsa_app/backend/schemas/analysis.py new file mode 100644 index 0000000..b318292 --- /dev/null +++ b/src/pypsa_app/backend/schemas/analysis.py @@ -0,0 +1,42 @@ +"""Request/response schemas for custom analysis endpoints.""" + +from typing import Any + +from pydantic import BaseModel, Field, field_validator + +from pypsa_app.backend.utils.allowlists import ALLOWED_ANALYSIS_TYPES + + +class AnalysisRequest(BaseModel): + """Request for custom analysis (dispatch, line loading, prices, etc.)""" + + network_ids: list[str] = Field( + ..., description="List of network UUIDs" + ) + analysis_type: str = Field( + ..., + description="Analysis type (e.g., 'dispatch_area')", + ) + parameters: dict[str, Any] = Field( + default_factory=dict, + description="Analysis-specific parameters (country, resample, top_n, etc.)", + ) + + @field_validator("network_ids") + @classmethod + def validate_network_ids(cls, v: list[str]) -> list[str]: + if not v: + msg = "At least one network ID is required" + raise ValueError(msg) + return v + + @field_validator("analysis_type") + @classmethod + def validate_analysis_type(cls, v: str) -> str: + if v not in ALLOWED_ANALYSIS_TYPES: + msg = ( + f"Invalid analysis_type '{v}'. " + f"Allowed: {sorted(ALLOWED_ANALYSIS_TYPES)}" + ) + raise ValueError(msg) + return v diff --git a/src/pypsa_app/backend/services/analysis.py b/src/pypsa_app/backend/services/analysis.py new file mode 100644 index 0000000..777c341 --- /dev/null +++ b/src/pypsa_app/backend/services/analysis.py @@ -0,0 +1,557 @@ +"""Custom analysis pipelines for PyPSA networks. + +Goes beyond PyPSA's built-in statistics — implements Script_MEE-style dashboard +analyses: dispatch profiles, line loading, price duration curves, and +cross-border flows. + +All functions accept a loaded PyPSA network and return Plotly figure JSON dicts, +ready for caching and frontend rendering. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +from plotly.subplots import make_subplots + +from pypsa_app.backend.services.network import load_service +from pypsa_app.backend.utils.carrier_colors import ( + COUNTRY_COLORS, + PLOTLY_TEMPLATE, + get_carrier_color, + sort_carriers_by_dispatch_order, +) + +logger = logging.getLogger(__name__) + +# Try polars for fast aggregations, fall back to pandas-only +try: + import polars as pl + + POLARS_AVAILABLE = True +except ImportError: + POLARS_AVAILABLE = False + + +# ── Dispatch / Energy Balance ──────────────────────────────────────────────── + + +def dispatch_area( + file_paths: list[str], + *, + country: str | None = None, + carrier_filter: list[str] | None = None, + resample: str | None = None, +) -> dict[str, Any]: + """Stacked area chart of generation dispatch by carrier over time. + + Mirrors Script_MEE's dispatch stacking with canonical carrier colors + and dispatch order. + """ + service = load_service(file_paths, use_cache=True) + n = service.n + + # Get generator dispatch (p) — rows=snapshots, cols=generators + gen_p = n.generators_t.p + if gen_p.empty: + return _empty_figure("No generator dispatch data available") + + # Map generators to carriers + carrier_map = n.generators.carrier + dispatch = gen_p.T.groupby(carrier_map).sum().T + + # Optional country filter + if country: + country_buses = n.buses.index[n.buses.get("country", "") == country] + country_gens = n.generators.index[n.generators.bus.isin(country_buses)] + gen_p_filtered = gen_p[gen_p.columns.intersection(country_gens)] + carrier_map_filtered = n.generators.carrier.loc[gen_p_filtered.columns] + dispatch = gen_p_filtered.T.groupby(carrier_map_filtered).sum().T + + # Optional carrier filter + if carrier_filter: + dispatch = dispatch[[c for c in carrier_filter if c in dispatch.columns]] + + # Optional resampling (e.g., "D" for daily, "W" for weekly) + if resample and len(dispatch) > 1: + dispatch = dispatch.resample(resample).mean() + + # Sort columns by dispatch order + ordered_carriers = sort_carriers_by_dispatch_order(list(dispatch.columns)) + dispatch = dispatch[ordered_carriers] + + # Build stacked area plot + fig = go.Figure() + for carrier in ordered_carriers: + fig.add_trace( + go.Scatter( + x=dispatch.index.astype(str), + y=dispatch[carrier].values, + name=carrier, + mode="lines", + stackgroup="dispatch", + fillcolor=get_carrier_color(carrier), + line={"width": 0.5, "color": get_carrier_color(carrier)}, + ) + ) + + fig.update_layout( + template=PLOTLY_TEMPLATE, + title="Generation Dispatch" + (f" — {country}" if country else ""), + xaxis_title="Time", + yaxis_title="Power (MW)", + hovermode="x unified", + legend={"traceorder": "normal"}, + ) + + return json.loads(fig.to_json()) + + +# ── Line Loading ───────────────────────────────────────────────────────────── + + +def line_loading_histogram( + file_paths: list[str], + *, + security_margin_pct: float = 70.0, + n1_limit_pct: float = 100.0, + top_n: int = 20, +) -> dict[str, Any]: + """Line loading distribution histogram + top-N congested lines bar chart. + + Two subplots: + 1. Histogram of max loading % across all lines + 2. Horizontal bar chart of top-N most loaded lines + """ + service = load_service(file_paths, use_cache=True) + n = service.n + + if not len(n.lines) or n.lines_t.p0.empty: + return _empty_figure("No line loading data available") + + # Compute loading percentage: |p0| / s_nom + s_nom = n.lines.s_nom + s_nom_safe = s_nom.replace(0, np.nan) + max_loading = (n.lines_t.p0.abs().max() / s_nom_safe * 100).dropna() + + if POLARS_AVAILABLE: + # Vectorized with polars for speed on large networks + df = pl.DataFrame({ + "line": max_loading.index, "loading_pct": max_loading.values, + }) + top_lines = df.sort("loading_pct", descending=True).head(top_n) + top_names = top_lines["line"].to_list() + top_values = top_lines["loading_pct"].to_list() + hist_values = df["loading_pct"].to_list() + else: + hist_values = max_loading.values + top = max_loading.nlargest(top_n) + top_names = top.index.tolist() + top_values = top.values.tolist() + + fig = make_subplots( + rows=1, + cols=2, + subplot_titles=["Loading Distribution", f"Top {top_n} Congested Lines"], + column_widths=[0.5, 0.5], + ) + + # Histogram + fig.add_trace( + go.Histogram( + x=hist_values, + nbinsx=50, + marker_color="steelblue", + opacity=0.7, + name="Lines", + ), + row=1, + col=1, + ) + + # Threshold lines + for pct, color, label in [ + (security_margin_pct, "orange", f"Security ({security_margin_pct:.0f}%)"), + (n1_limit_pct, "red", f"N-1 Limit ({n1_limit_pct:.0f}%)"), + ]: + fig.add_vline( + x=pct, line_dash="dash", line_color=color, + annotation_text=label, row=1, col=1, + ) + + # Top-N bar chart with color coding + bar_colors = [ + "red" if v > n1_limit_pct else "orange" if v > security_margin_pct else "green" + for v in top_values + ] + fig.add_trace( + go.Bar( + y=top_names, + x=top_values, + orientation="h", + marker_color=bar_colors, + name="Max Loading %", + ), + row=1, + col=2, + ) + + fig.update_layout( + template=PLOTLY_TEMPLATE, + title="Line Loading Analysis", + showlegend=False, + height=500, + ) + fig.update_xaxes(title_text="Max Loading (%)", row=1, col=1) + fig.update_xaxes(title_text="Max Loading (%)", row=1, col=2) + + return json.loads(fig.to_json()) + + +def line_loading_timeseries( + file_paths: list[str], + *, + line_names: list[str] | None = None, + top_n: int = 5, +) -> dict[str, Any]: + """Time series of line loading for selected or top-N lines.""" + service = load_service(file_paths, use_cache=True) + n = service.n + + if not len(n.lines) or n.lines_t.p0.empty: + return _empty_figure("No line loading data available") + + s_nom = n.lines.s_nom + s_nom_safe = s_nom.replace(0, np.nan) + loading_pct = n.lines_t.p0.abs().div(s_nom_safe) * 100 + + # Select lines to plot + if line_names: + selected = [ln for ln in line_names if ln in loading_pct.columns] + else: + max_loading = loading_pct.max() + selected = max_loading.nlargest(top_n).index.tolist() + + if not selected: + return _empty_figure("No matching lines found") + + fig = go.Figure() + for line_name in selected: + fig.add_trace( + go.Scatter( + x=loading_pct.index.astype(str), + y=loading_pct[line_name].values, + name=line_name, + mode="lines", + ) + ) + + fig.add_hline( + y=100, line_dash="dash", line_color="red", annotation_text="N-1 Limit", + ) + fig.update_layout( + template=PLOTLY_TEMPLATE, + title="Line Loading Time Series", + xaxis_title="Time", + yaxis_title="Loading (%)", + hovermode="x unified", + ) + + return json.loads(fig.to_json()) + + +# ── Price Analysis ─────────────────────────────────────────────────────────── + + +def price_duration_curve( + file_paths: list[str], + *, + country: str | None = None, + bus_names: list[str] | None = None, +) -> dict[str, Any]: + """Price duration curve — sorted marginal prices from high to low. + + Shows the full-year price profile in a single view, essential for + understanding price volatility and baseload/peak economics. + """ + service = load_service(file_paths, use_cache=True) + n = service.n + + if not hasattr(n, "buses_t") or n.buses_t.marginal_price.empty: + return _empty_figure("No marginal price data available") + + prices = n.buses_t.marginal_price + + if bus_names: + cols = [b for b in bus_names if b in prices.columns] + if not cols: + return _empty_figure("No matching buses found") + prices = prices[cols] + elif country: + country_buses = n.buses.index[n.buses.get("country", "") == country] + matching = prices.columns.intersection(country_buses) + if matching.empty: + return _empty_figure(f"No buses found for country {country}") + prices = prices[matching] + + fig = go.Figure() + + max_individual_curves = 10 + if prices.shape[1] <= max_individual_curves: + # Plot individual bus duration curves + for col in prices.columns: + sorted_vals = np.sort(prices[col].dropna().values)[::-1] + hours = np.arange(1, len(sorted_vals) + 1) + fig.add_trace( + go.Scatter(x=hours, y=sorted_vals, name=col, mode="lines") + ) + else: + # Aggregate: mean price across all selected buses + mean_prices = prices.mean(axis=1) + sorted_vals = np.sort(mean_prices.dropna().values)[::-1] + hours = np.arange(1, len(sorted_vals) + 1) + label = country or "Average" + fig.add_trace( + go.Scatter(x=hours, y=sorted_vals, name=label, mode="lines", + line={"color": COUNTRY_COLORS.get(country, "#2563eb")}) + ) + + fig.update_layout( + template=PLOTLY_TEMPLATE, + title="Price Duration Curve" + (f" — {country}" if country else ""), + xaxis_title="Hours (sorted)", + yaxis_title="Marginal Price (EUR/MWh)", + hovermode="x unified", + ) + + return json.loads(fig.to_json()) + + +def price_timeseries( + file_paths: list[str], + *, + country: str | None = None, + resample: str | None = None, +) -> dict[str, Any]: + """Price time series — marginal prices over time.""" + service = load_service(file_paths, use_cache=True) + n = service.n + + if not hasattr(n, "buses_t") or n.buses_t.marginal_price.empty: + return _empty_figure("No marginal price data available") + + prices = n.buses_t.marginal_price + + if country: + country_buses = n.buses.index[n.buses.get("country", "") == country] + matching = prices.columns.intersection(country_buses) + if matching.empty: + return _empty_figure(f"No buses found for country {country}") + prices = prices[matching].mean(axis=1).to_frame(name=country) + else: + prices = prices.mean(axis=1).to_frame(name="Average") + + if resample and len(prices) > 1: + prices = prices.resample(resample).mean() + + fig = go.Figure() + for col in prices.columns: + color = COUNTRY_COLORS.get(col, "#2563eb") + fig.add_trace( + go.Scatter( + x=prices.index.astype(str), + y=prices[col].values, + name=col, + mode="lines", + line={"color": color}, + ) + ) + + fig.update_layout( + template=PLOTLY_TEMPLATE, + title="Marginal Prices" + (f" — {country}" if country else ""), + xaxis_title="Time", + yaxis_title="Price (EUR/MWh)", + hovermode="x unified", + ) + + return json.loads(fig.to_json()) + + +# ── Cross-Border Flows ─────────────────────────────────────────────────────── + + +def cross_border_flows( + file_paths: list[str], + *, + country: str = "NL", + resample: str | None = "D", +) -> dict[str, Any]: + """Cross-border flow analysis for a given country. + + Shows net imports/exports per interconnection over time. + """ + service = load_service(file_paths, use_cache=True) + n = service.n + + if n.links_t.p0.empty and n.lines_t.p0.empty: + return _empty_figure("No flow data available") + + country_buses = set(n.buses.index[n.buses.get("country", "") == country]) + if not country_buses: + return _empty_figure(f"No buses found for country {country}") + + flows = {} + + # Check links (HVDC interconnectors) + for link_name in n.links.index: + bus0 = n.links.at[link_name, "bus0"] + bus1 = n.links.at[link_name, "bus1"] + bus0_country = n.buses.get("country", "").get(bus0, "") + bus1_country = n.buses.get("country", "").get(bus1, "") + + if bus0 in country_buses and bus1 not in country_buses: + # Export from country + label = f"{country} → {bus1_country}" + if link_name in n.links_t.p0.columns: + flows.setdefault(label, []).append(n.links_t.p0[link_name]) + elif bus1 in country_buses and bus0 not in country_buses: + # Import to country + label = f"{bus0_country} → {country}" + if link_name in n.links_t.p0.columns: + flows.setdefault(label, []).append(-n.links_t.p0[link_name]) + + if not flows: + return _empty_figure(f"No cross-border links found for {country}") + + fig = go.Figure() + for label, series_list in flows.items(): + combined = pd.concat(series_list, axis=1).sum(axis=1) + if resample and len(combined) > 1: + combined = combined.resample(resample).mean() + + # Color by partner country + partner = label.replace(f"{country} → ", "").replace(f" → {country}", "") + color = COUNTRY_COLORS.get(partner, "#888888") + + fig.add_trace( + go.Scatter( + x=combined.index.astype(str), + y=combined.values, + name=label, + mode="lines", + line={"color": color}, + ) + ) + + fig.add_hline(y=0, line_dash="dot", line_color="grey") + fig.update_layout( + template=PLOTLY_TEMPLATE, + title=f"Cross-Border Flows — {country}", + xaxis_title="Time", + yaxis_title="Flow (MW, positive = export)", + hovermode="x unified", + ) + + return json.loads(fig.to_json()) + + +# ── Capacity Mix ───────────────────────────────────────────────────────────── + + +def capacity_mix( + file_paths: list[str], + *, + country: str | None = None, +) -> dict[str, Any]: + """Installed capacity bar chart by carrier.""" + service = load_service(file_paths, use_cache=True) + n = service.n + + gens = n.generators.copy() + if country: + country_buses = n.buses.index[n.buses.get("country", "") == country] + gens = gens[gens.bus.isin(country_buses)] + + if gens.empty: + return _empty_figure("No generator data available") + + capacity = gens.groupby("carrier")["p_nom"].sum() + ordered = sort_carriers_by_dispatch_order(list(capacity.index)) + capacity = capacity.reindex(ordered) + colors = [get_carrier_color(c) for c in ordered] + + fig = go.Figure( + go.Bar( + x=ordered, + y=capacity.values, + marker_color=colors, + ) + ) + fig.update_layout( + template=PLOTLY_TEMPLATE, + title="Installed Capacity" + (f" — {country}" if country else ""), + xaxis_title="Carrier", + yaxis_title="Capacity (MW)", + ) + + return json.loads(fig.to_json()) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _empty_figure(message: str) -> dict[str, Any]: + """Return an empty Plotly figure with a centered message.""" + fig = go.Figure() + fig.add_annotation( + text=message, + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font={"size": 16, "color": "grey"}, + ) + fig.update_layout(template=PLOTLY_TEMPLATE) + return json.loads(fig.to_json()) + + +# ── Public dispatch table ──────────────────────────────────────────────────── + +ANALYSIS_TYPES: dict[str, callable] = { + "dispatch_area": dispatch_area, + "line_loading_histogram": line_loading_histogram, + "line_loading_timeseries": line_loading_timeseries, + "price_duration_curve": price_duration_curve, + "price_timeseries": price_timeseries, + "cross_border_flows": cross_border_flows, + "capacity_mix": capacity_mix, +} + + +def run_analysis( + file_paths: list[str], + analysis_type: str, + parameters: dict[str, Any], +) -> dict[str, Any]: + """Run a named analysis pipeline and return Plotly figure JSON.""" + if analysis_type not in ANALYSIS_TYPES: + msg = ( + f"Unknown analysis type '{analysis_type}'. " + f"Available: {sorted(ANALYSIS_TYPES)}" + ) + raise ValueError(msg) + + func = ANALYSIS_TYPES[analysis_type] + logger.debug( + "Running analysis", + extra={ + "analysis_type": analysis_type, + "num_networks": len(file_paths), + "parameters": parameters, + }, + ) + return func(file_paths, **parameters) diff --git a/src/pypsa_app/backend/services/network.py b/src/pypsa_app/backend/services/network.py index 1d936e2..05e6162 100644 --- a/src/pypsa_app/backend/services/network.py +++ b/src/pypsa_app/backend/services/network.py @@ -262,10 +262,10 @@ def _generate_unique_names_from_paths(file_paths: list[Path]) -> list[str]: def _calculate_file_hash(file_path: Path) -> str: - """Calculate SHA256 hash of a file.""" + """Calculate SHA256 hash of a file using 128KB chunks for throughput.""" sha256_hash = hashlib.sha256() with file_path.open("rb") as f: - for byte_block in iter(lambda: f.read(4096), b""): + for byte_block in iter(lambda: f.read(131072), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() diff --git a/src/pypsa_app/backend/tasks.py b/src/pypsa_app/backend/tasks.py index 644fef5..6385105 100644 --- a/src/pypsa_app/backend/tasks.py +++ b/src/pypsa_app/backend/tasks.py @@ -12,6 +12,7 @@ from pypsa_app.backend.database import SessionLocal from pypsa_app.backend.models import Run, RunStatus, SnakedispatchBackend from pypsa_app.backend.schemas.task import TaskResultResponse +from pypsa_app.backend.services.analysis import run_analysis as run_analysis_service from pypsa_app.backend.services.callback import fire_callback_sync from pypsa_app.backend.services.network import import_network_file from pypsa_app.backend.services.run import SnakedispatchClient @@ -69,6 +70,13 @@ def get_plot_task(self: Any, **kwargs: Any) -> dict[str, Any]: return _execute_task(self, "Plot generation", func, **kwargs) +@task_app.task(bind=True, name="tasks.run_analysis") +def run_analysis_task(self: Any, **kwargs: Any) -> dict[str, Any]: + """Background task for custom analysis generation""" + func = cache("analysis", ttl=settings.plot_cache_ttl)(run_analysis_service) + return _execute_task(self, "Analysis generation", func, **kwargs) + + @task_app.task(bind=True, name="tasks.import_run_outputs") def import_run_outputs_task(self: Any, job_id: str) -> None: # noqa: PLR0915 """Download .nc outputs from a completed run and import as networks.""" diff --git a/src/pypsa_app/backend/utils/allowlists.py b/src/pypsa_app/backend/utils/allowlists.py index a2499b1..0805078 100644 --- a/src/pypsa_app/backend/utils/allowlists.py +++ b/src/pypsa_app/backend/utils/allowlists.py @@ -37,4 +37,16 @@ } ) -__all__ = ["ALLOWED_STATISTICS", "ALLOWED_CHART_TYPES"] +ALLOWED_ANALYSIS_TYPES: Final[frozenset[str]] = frozenset( + { + "dispatch_area", + "line_loading_histogram", + "line_loading_timeseries", + "price_duration_curve", + "price_timeseries", + "cross_border_flows", + "capacity_mix", + } +) + +__all__ = ["ALLOWED_STATISTICS", "ALLOWED_CHART_TYPES", "ALLOWED_ANALYSIS_TYPES"] diff --git a/src/pypsa_app/backend/utils/carrier_colors.py b/src/pypsa_app/backend/utils/carrier_colors.py new file mode 100644 index 0000000..963b18d --- /dev/null +++ b/src/pypsa_app/backend/utils/carrier_colors.py @@ -0,0 +1,78 @@ +"""Canonical carrier colors and dispatch order for energy system visualizations. + +Ported from PSA_MEE analysis.dashboard_utils — single source of truth for all +plot styling in pypsa-app. +""" + +from typing import Final + +# Technology-specific colors for energy dispatch charts +CARRIER_COLORS: Final[dict[str, str]] = { + "Nuclear": "#d62728", + "Gas": "#8c564b", + "CHP": "#006400", + "Coal": "#2c2c2c", + "Lignite": "#654321", + "Wind": "#1f77b4", + "Wind_Onshore": "#1f77b4", + "Wind_Offshore": "#00008b", + "Solar": "#ffdd44", + "Hydro": "#87CEEB", + "Biomass": "#98df8a", + "Waste": "#bc5090", + "Oil": "#e377c2", + "Other": "#aaa5a0", + "Slack": "#d62728", + "Battery": "#800080", + "PHS": "#7f7f7f", + "DemandResponse": "#bcbd22", + "Load": "#708090", + "HVDC_Losses": "#e74c3c", +} + +# Dispatch stacking order: baseload at bottom, variable renewables on top +DISPATCH_ORDER: Final[list[str]] = [ + "Nuclear", + "Hydro", + "Coal", + "Lignite", + "Biomass", + "Waste", + "CHP", + "Gas", + "Oil", + "Other", + "DemandResponse", + "Battery", + "PHS", + "Wind_Offshore", + "Wind_Onshore", + "Wind", + "Solar", + "Slack", +] + +# Country/zone colors for cross-border and price visualizations +COUNTRY_COLORS: Final[dict[str, str]] = { + "NL": "#FF6B00", + "DE": "#000000", + "BE": "#FFD700", + "FR": "#002395", + "GB": "#00247D", + "DK": "#C8102E", + "NO": "#003DA5", + "LU": "#00A3E0", +} + +PLOTLY_TEMPLATE: Final[str] = "simple_white" + + +def get_carrier_color(carrier: str) -> str: + """Get color for a carrier, falling back to grey for unknown carriers.""" + return CARRIER_COLORS.get(carrier, "#aaa5a0") + + +def sort_carriers_by_dispatch_order(carriers: list[str]) -> list[str]: + """Sort carriers according to canonical dispatch stacking order.""" + order_map = {name: i for i, name in enumerate(DISPATCH_ORDER)} + return sorted(carriers, key=lambda c: order_map.get(c, len(DISPATCH_ORDER))) diff --git a/src/pypsa_app/backend/utils/serializers.py b/src/pypsa_app/backend/utils/serializers.py index 8065167..b1b4b47 100644 --- a/src/pypsa_app/backend/utils/serializers.py +++ b/src/pypsa_app/backend/utils/serializers.py @@ -3,8 +3,31 @@ import math from typing import Any +import numpy as np import pandas as pd +# Use orjson for faster JSON serialization when available +try: + import orjson + + def fast_json_dumps(obj: Any) -> str: + """Serialize to JSON string using orjson (3-10x faster than stdlib).""" + return orjson.dumps( + obj, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS + ).decode() + + def fast_json_loads(data: str | bytes) -> Any: + """Deserialize JSON using orjson.""" + return orjson.loads(data) + + ORJSON_AVAILABLE = True +except ImportError: + import json + + fast_json_dumps = json.dumps # type: ignore[assignment] + fast_json_loads = json.loads # type: ignore[assignment] + ORJSON_AVAILABLE = False + def serialize_df(data: pd.DataFrame | pd.Series) -> dict: """Convert pandas DataFrame or Series to JSON-serializable dict""" @@ -19,14 +42,35 @@ def serialize_df(data: pd.DataFrame | pd.Series) -> dict: result["columns"][0] if result["columns"] else None, tuple ): result["columns"] = [str(col) for col in result["columns"]] + # Replace NaN/Inf in data values for JSON compatibility + if result.get("data"): + result["data"] = _sanitize_nested_floats(result["data"]) return result elif isinstance(data, pd.Series): - return {str(k): v for k, v in data.to_dict().items()} + return {str(k): _sanitize_float(v) for k, v in data.to_dict().items()} else: msg = f"Expected DataFrame or Series, got {type(data)}" raise TypeError(msg) +def _sanitize_float(v: Any) -> Any: + """Sanitize a single float value.""" + if isinstance(v, float) and (math.isinf(v) or math.isnan(v)): + return None + if isinstance(v, np.floating): + v = float(v) + if math.isinf(v) or math.isnan(v): + return None + return v + + +def _sanitize_nested_floats(data: Any) -> Any: + """Recursively sanitize float values in nested lists (for DataFrame data).""" + if isinstance(data, list): + return [_sanitize_nested_floats(item) for item in data] + return _sanitize_float(data) + + def sanitize_metadata(data: Any) -> Any: """Recursively sanitize metadata to be JSON-compatible (removes inf/nan)""" if isinstance(data, dict): From ec79c6ffff59bf76ba2c70a6d716b59bcaf05664 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Apr 2026 14:01:45 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pypsa_app/backend/schemas/analysis.py | 4 +-- src/pypsa_app/backend/services/analysis.py | 42 +++++++++++++++------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/src/pypsa_app/backend/schemas/analysis.py b/src/pypsa_app/backend/schemas/analysis.py index b318292..fcac652 100644 --- a/src/pypsa_app/backend/schemas/analysis.py +++ b/src/pypsa_app/backend/schemas/analysis.py @@ -10,9 +10,7 @@ class AnalysisRequest(BaseModel): """Request for custom analysis (dispatch, line loading, prices, etc.)""" - network_ids: list[str] = Field( - ..., description="List of network UUIDs" - ) + network_ids: list[str] = Field(..., description="List of network UUIDs") analysis_type: str = Field( ..., description="Analysis type (e.g., 'dispatch_area')", diff --git a/src/pypsa_app/backend/services/analysis.py b/src/pypsa_app/backend/services/analysis.py index 777c341..66a0fa6 100644 --- a/src/pypsa_app/backend/services/analysis.py +++ b/src/pypsa_app/backend/services/analysis.py @@ -141,9 +141,12 @@ def line_loading_histogram( if POLARS_AVAILABLE: # Vectorized with polars for speed on large networks - df = pl.DataFrame({ - "line": max_loading.index, "loading_pct": max_loading.values, - }) + df = pl.DataFrame( + { + "line": max_loading.index, + "loading_pct": max_loading.values, + } + ) top_lines = df.sort("loading_pct", descending=True).head(top_n) top_names = top_lines["line"].to_list() top_values = top_lines["loading_pct"].to_list() @@ -180,8 +183,12 @@ def line_loading_histogram( (n1_limit_pct, "red", f"N-1 Limit ({n1_limit_pct:.0f}%)"), ]: fig.add_vline( - x=pct, line_dash="dash", line_color=color, - annotation_text=label, row=1, col=1, + x=pct, + line_dash="dash", + line_color=color, + annotation_text=label, + row=1, + col=1, ) # Top-N bar chart with color coding @@ -252,7 +259,10 @@ def line_loading_timeseries( ) fig.add_hline( - y=100, line_dash="dash", line_color="red", annotation_text="N-1 Limit", + y=100, + line_dash="dash", + line_color="red", + annotation_text="N-1 Limit", ) fig.update_layout( template=PLOTLY_TEMPLATE, @@ -307,9 +317,7 @@ def price_duration_curve( for col in prices.columns: sorted_vals = np.sort(prices[col].dropna().values)[::-1] hours = np.arange(1, len(sorted_vals) + 1) - fig.add_trace( - go.Scatter(x=hours, y=sorted_vals, name=col, mode="lines") - ) + fig.add_trace(go.Scatter(x=hours, y=sorted_vals, name=col, mode="lines")) else: # Aggregate: mean price across all selected buses mean_prices = prices.mean(axis=1) @@ -317,8 +325,13 @@ def price_duration_curve( hours = np.arange(1, len(sorted_vals) + 1) label = country or "Average" fig.add_trace( - go.Scatter(x=hours, y=sorted_vals, name=label, mode="lines", - line={"color": COUNTRY_COLORS.get(country, "#2563eb")}) + go.Scatter( + x=hours, + y=sorted_vals, + name=label, + mode="lines", + line={"color": COUNTRY_COLORS.get(country, "#2563eb")}, + ) ) fig.update_layout( @@ -511,8 +524,11 @@ def _empty_figure(message: str) -> dict[str, Any]: fig = go.Figure() fig.add_annotation( text=message, - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, font={"size": 16, "color": "grey"}, ) fig.update_layout(template=PLOTLY_TEMPLATE) From 2d756c5e0e5e5d5ab71f531c1ed0ce80219c1b5d Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 18:53:10 +0200 Subject: [PATCH 03/17] fix: safe bus country lookups and cross-border AC line support Fix critical bugs in analysis service: - Replace unsafe n.buses.get("country", "") pattern that crashes on networks without a "country" column (ValueError on pandas >= 2.0) - Extract _get_bus_countries/_buses_in_country/_bus_country helpers for safe, consistent country lookups across all analysis functions - Add AC line cross-border flow detection (previously only HVDC links) - Handle NaN country values gracefully - Add .fillna(0) to capacity_mix reindex to prevent NaN bars - Refactor cross_border_flows to reduce cyclomatic complexity (PLR0912) Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pypsa_app/backend/services/analysis.py | 127 +++++++++++++++------ 1 file changed, 91 insertions(+), 36 deletions(-) diff --git a/src/pypsa_app/backend/services/analysis.py b/src/pypsa_app/backend/services/analysis.py index 66a0fa6..7b0e07c 100644 --- a/src/pypsa_app/backend/services/analysis.py +++ b/src/pypsa_app/backend/services/analysis.py @@ -38,6 +38,31 @@ POLARS_AVAILABLE = False +# ── Bus country helpers ────────────────────────────────────────────────────── + + +def _get_bus_countries(n: Any) -> pd.Series: + """Safely extract bus country column, returning empty Series if absent.""" + if "country" in n.buses.columns: + return n.buses["country"] + return pd.Series(dtype=str, index=n.buses.index) + + +def _buses_in_country(n: Any, country: str) -> pd.Index: + """Return bus indices belonging to a given country.""" + countries = _get_bus_countries(n) + return countries.index[countries == country] + + +def _bus_country(n: Any, bus: str) -> str: + """Return the country of a single bus, or empty string if unknown.""" + countries = _get_bus_countries(n) + if bus in countries.index: + val = countries.at[bus] + return str(val) if pd.notna(val) else "" + return "" + + # ── Dispatch / Energy Balance ──────────────────────────────────────────────── @@ -65,13 +90,15 @@ def dispatch_area( carrier_map = n.generators.carrier dispatch = gen_p.T.groupby(carrier_map).sum().T - # Optional country filter + # Optional country filter — applied before groupby if country: - country_buses = n.buses.index[n.buses.get("country", "") == country] + country_buses = _buses_in_country(n, country) country_gens = n.generators.index[n.generators.bus.isin(country_buses)] - gen_p_filtered = gen_p[gen_p.columns.intersection(country_gens)] - carrier_map_filtered = n.generators.carrier.loc[gen_p_filtered.columns] - dispatch = gen_p_filtered.T.groupby(carrier_map_filtered).sum().T + gen_p = gen_p[gen_p.columns.intersection(country_gens)] + carrier_map = carrier_map.loc[gen_p.columns] + if gen_p.empty: + return _empty_figure(f"No generators found for {country}") + dispatch = gen_p.T.groupby(carrier_map).sum().T # Optional carrier filter if carrier_filter: @@ -303,7 +330,7 @@ def price_duration_curve( return _empty_figure("No matching buses found") prices = prices[cols] elif country: - country_buses = n.buses.index[n.buses.get("country", "") == country] + country_buses = _buses_in_country(n, country) matching = prices.columns.intersection(country_buses) if matching.empty: return _empty_figure(f"No buses found for country {country}") @@ -317,20 +344,26 @@ def price_duration_curve( for col in prices.columns: sorted_vals = np.sort(prices[col].dropna().values)[::-1] hours = np.arange(1, len(sorted_vals) + 1) - fig.add_trace(go.Scatter(x=hours, y=sorted_vals, name=col, mode="lines")) + fig.add_trace( + go.Scatter( + x=hours, y=sorted_vals, name=col, mode="lines", + ) + ) else: # Aggregate: mean price across all selected buses mean_prices = prices.mean(axis=1) sorted_vals = np.sort(mean_prices.dropna().values)[::-1] hours = np.arange(1, len(sorted_vals) + 1) label = country or "Average" + color = ( + COUNTRY_COLORS.get(country, "#2563eb") + if country + else "#2563eb" + ) fig.add_trace( go.Scatter( - x=hours, - y=sorted_vals, - name=label, - mode="lines", - line={"color": COUNTRY_COLORS.get(country, "#2563eb")}, + x=hours, y=sorted_vals, name=label, + mode="lines", line={"color": color}, ) ) @@ -361,7 +394,7 @@ def price_timeseries( prices = n.buses_t.marginal_price if country: - country_buses = n.buses.index[n.buses.get("country", "") == country] + country_buses = _buses_in_country(n, country) matching = prices.columns.intersection(country_buses) if matching.empty: return _empty_figure(f"No buses found for country {country}") @@ -399,6 +432,32 @@ def price_timeseries( # ── Cross-Border Flows ─────────────────────────────────────────────────────── +def _collect_cross_border_flows( + n: Any, + component_df: pd.DataFrame, + component_t_p0: pd.DataFrame, + country: str, + country_buses: set[str], + flows: dict[str, list[pd.Series]], +) -> None: + """Collect cross-border flows from a component (lines or links).""" + for name in component_df.index: + bus0 = component_df.at[name, "bus0"] + bus1 = component_df.at[name, "bus1"] + bus0_c = _bus_country(n, bus0) + bus1_c = _bus_country(n, bus1) + + if bus0_c == bus1_c or name not in component_t_p0.columns: + continue + + if bus0 in country_buses and bus1 not in country_buses: + label = f"{country} → {bus1_c or '?'}" + flows.setdefault(label, []).append(component_t_p0[name]) + elif bus1 in country_buses and bus0 not in country_buses: + label = f"{bus0_c or '?'} → {country}" + flows.setdefault(label, []).append(-component_t_p0[name]) + + def cross_border_flows( file_paths: list[str], *, @@ -408,39 +467,35 @@ def cross_border_flows( """Cross-border flow analysis for a given country. Shows net imports/exports per interconnection over time. + Checks both links (HVDC) and AC lines crossing borders. """ service = load_service(file_paths, use_cache=True) n = service.n - if n.links_t.p0.empty and n.lines_t.p0.empty: + has_links = not n.links_t.p0.empty + has_lines = not n.lines_t.p0.empty + if not has_links and not has_lines: return _empty_figure("No flow data available") - country_buses = set(n.buses.index[n.buses.get("country", "") == country]) + country_buses = set(_buses_in_country(n, country)) if not country_buses: return _empty_figure(f"No buses found for country {country}") - flows = {} - - # Check links (HVDC interconnectors) - for link_name in n.links.index: - bus0 = n.links.at[link_name, "bus0"] - bus1 = n.links.at[link_name, "bus1"] - bus0_country = n.buses.get("country", "").get(bus0, "") - bus1_country = n.buses.get("country", "").get(bus1, "") + flows: dict[str, list[pd.Series]] = {} - if bus0 in country_buses and bus1 not in country_buses: - # Export from country - label = f"{country} → {bus1_country}" - if link_name in n.links_t.p0.columns: - flows.setdefault(label, []).append(n.links_t.p0[link_name]) - elif bus1 in country_buses and bus0 not in country_buses: - # Import to country - label = f"{bus0_country} → {country}" - if link_name in n.links_t.p0.columns: - flows.setdefault(label, []).append(-n.links_t.p0[link_name]) + if has_links: + _collect_cross_border_flows( + n, n.links, n.links_t.p0, country, country_buses, flows, + ) + if has_lines: + _collect_cross_border_flows( + n, n.lines, n.lines_t.p0, country, country_buses, flows, + ) if not flows: - return _empty_figure(f"No cross-border links found for {country}") + return _empty_figure( + f"No cross-border connections found for {country}" + ) fig = go.Figure() for label, series_list in flows.items(): @@ -488,7 +543,7 @@ def capacity_mix( gens = n.generators.copy() if country: - country_buses = n.buses.index[n.buses.get("country", "") == country] + country_buses = _buses_in_country(n, country) gens = gens[gens.bus.isin(country_buses)] if gens.empty: @@ -496,7 +551,7 @@ def capacity_mix( capacity = gens.groupby("carrier")["p_nom"].sum() ordered = sort_carriers_by_dispatch_order(list(capacity.index)) - capacity = capacity.reindex(ordered) + capacity = capacity.reindex(ordered).fillna(0) colors = [get_carrier_color(c) for c in ordered] fig = go.Figure( From a3400f664e1880fee48a6f0fc10f316f831cf564 Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 18:54:37 +0200 Subject: [PATCH 04/17] feat: add nodal_balance and line_flow_snapshot analysis types (closes #6) Add two new analysis types addressing GitHub issue #6: - nodal_balance: per-bus generation vs load at a given snapshot, showing net injection as overlay bar chart (top-N by absolute injection) - line_flow_snapshot: per-line loading % and flow direction at a given snapshot, with color-coded congestion thresholds Both support snapshot_idx parameter for time-step navigation, enabling the frontend to build a "grid playback" tool. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pypsa_app/backend/services/analysis.py | 170 +++++++++++++++++++++ src/pypsa_app/backend/utils/allowlists.py | 2 + 2 files changed, 172 insertions(+) diff --git a/src/pypsa_app/backend/services/analysis.py b/src/pypsa_app/backend/services/analysis.py index 7b0e07c..cfc4431 100644 --- a/src/pypsa_app/backend/services/analysis.py +++ b/src/pypsa_app/backend/services/analysis.py @@ -571,6 +571,174 @@ def capacity_mix( return json.loads(fig.to_json()) +# ── Nodal Balance (GitHub issue #6) ────────────────────────────────────────── + + +def nodal_balance( + file_paths: list[str], + *, + snapshot_idx: int = 0, + country: str | None = None, + top_n: int = 30, +) -> dict[str, Any]: + """Per-node generation, load, and net injection at a given snapshot. + + Returns a bar chart of top-N buses by absolute net injection, + showing generation vs load balance per node. + """ + service = load_service(file_paths, use_cache=True) + n = service.n + + if n.generators_t.p.empty: + return _empty_figure("No generator dispatch data available") + + # Clamp snapshot index + snapshot_idx = max(0, min(snapshot_idx, len(n.snapshots) - 1)) + + # Generation per bus at this snapshot + gen_p = n.generators_t.p.iloc[snapshot_idx] + gen_by_bus = gen_p.groupby(n.generators.bus).sum() + + # Load per bus at this snapshot + load_by_bus = pd.Series(dtype=float) + if not n.loads_t.p_set.empty: + load_p = n.loads_t.p_set.iloc[snapshot_idx] + load_by_bus = load_p.groupby(n.loads.bus).sum() + + # Combine into a single DataFrame + all_buses = sorted(set(gen_by_bus.index) | set(load_by_bus.index)) + balance = pd.DataFrame( + { + "generation": gen_by_bus.reindex(all_buses, fill_value=0), + "load": load_by_bus.reindex(all_buses, fill_value=0), + }, + index=all_buses, + ) + balance["net_injection"] = balance["generation"] - balance["load"] + + # Optional country filter + if country: + country_buses = _buses_in_country(n, country) + balance = balance.loc[balance.index.isin(country_buses)] + + if balance.empty: + return _empty_figure("No nodal data available") + + # Select top-N by absolute net injection + balance = balance.reindex( + balance["net_injection"].abs().nlargest(top_n).index + ) + balance = balance.sort_values("net_injection", ascending=True) + + snapshot_label = str(n.snapshots[snapshot_idx]) + + fig = go.Figure() + fig.add_trace( + go.Bar( + y=balance.index, + x=balance["generation"].values, + name="Generation", + orientation="h", + marker_color="#2ca02c", + ) + ) + fig.add_trace( + go.Bar( + y=balance.index, + x=-balance["load"].values, + name="Load", + orientation="h", + marker_color="#d62728", + ) + ) + + fig.update_layout( + template=PLOTLY_TEMPLATE, + title=f"Nodal Balance — {snapshot_label}" + + (f" ({country})" if country else ""), + xaxis_title="Power (MW, + = generation, - = load)", + barmode="overlay", + height=max(400, top_n * 20), + hovermode="y unified", + ) + + return json.loads(fig.to_json()) + + +def line_flow_snapshot( + file_paths: list[str], + *, + snapshot_idx: int = 0, + top_n: int = 30, +) -> dict[str, Any]: + """Per-line power flow and loading % at a given snapshot. + + Returns a horizontal bar chart of top-N lines by loading %, + with flow magnitude and direction indicated. + """ + service = load_service(file_paths, use_cache=True) + n = service.n + + if not len(n.lines) or n.lines_t.p0.empty: + return _empty_figure("No line flow data available") + + snapshot_idx = max(0, min(snapshot_idx, len(n.snapshots) - 1)) + snapshot_label = str(n.snapshots[snapshot_idx]) + + # Flow at this snapshot + flow = n.lines_t.p0.iloc[snapshot_idx] + s_nom = n.lines.s_nom + s_nom_safe = s_nom.replace(0, np.nan) + loading = (flow.abs() / s_nom_safe * 100).dropna() + + # Top-N by loading + top = loading.nlargest(top_n) + top_flow = flow.loc[top.index] + + # Build labels with direction + labels = [] + for line_name in top.index: + bus0 = n.lines.at[line_name, "bus0"] + bus1 = n.lines.at[line_name, "bus1"] + direction = "→" if top_flow[line_name] >= 0 else "←" + labels.append(f"{bus0} {direction} {bus1}") + + n1_threshold = 100 + warning_threshold = 70 + bar_colors = [ + "red" + if v > n1_threshold + else "orange" + if v > warning_threshold + else "green" + for v in top.values + ] + + fig = go.Figure() + fig.add_trace( + go.Bar( + y=labels[::-1], + x=top.values[::-1], + orientation="h", + marker_color=bar_colors[::-1], + text=[ + f"{top_flow[ln]:.0f} MW" + for ln in reversed(top.index) + ], + textposition="outside", + ) + ) + + fig.update_layout( + template=PLOTLY_TEMPLATE, + title=f"Line Loading — {snapshot_label}", + xaxis_title="Loading (%)", + height=max(400, top_n * 22), + ) + + return json.loads(fig.to_json()) + + # ── Helpers ────────────────────────────────────────────────────────────────── @@ -600,6 +768,8 @@ def _empty_figure(message: str) -> dict[str, Any]: "price_timeseries": price_timeseries, "cross_border_flows": cross_border_flows, "capacity_mix": capacity_mix, + "nodal_balance": nodal_balance, + "line_flow_snapshot": line_flow_snapshot, } diff --git a/src/pypsa_app/backend/utils/allowlists.py b/src/pypsa_app/backend/utils/allowlists.py index 0805078..81d83e3 100644 --- a/src/pypsa_app/backend/utils/allowlists.py +++ b/src/pypsa_app/backend/utils/allowlists.py @@ -46,6 +46,8 @@ "price_timeseries", "cross_border_flows", "capacity_mix", + "nodal_balance", + "line_flow_snapshot", } ) From 2fc73e427e6eba2039672d67b15af5ba0b6a9004 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Apr 2026 16:54:53 +0000 Subject: [PATCH 05/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pypsa_app/backend/services/analysis.py | 51 +++++++++++----------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/pypsa_app/backend/services/analysis.py b/src/pypsa_app/backend/services/analysis.py index cfc4431..13decec 100644 --- a/src/pypsa_app/backend/services/analysis.py +++ b/src/pypsa_app/backend/services/analysis.py @@ -346,7 +346,10 @@ def price_duration_curve( hours = np.arange(1, len(sorted_vals) + 1) fig.add_trace( go.Scatter( - x=hours, y=sorted_vals, name=col, mode="lines", + x=hours, + y=sorted_vals, + name=col, + mode="lines", ) ) else: @@ -355,15 +358,14 @@ def price_duration_curve( sorted_vals = np.sort(mean_prices.dropna().values)[::-1] hours = np.arange(1, len(sorted_vals) + 1) label = country or "Average" - color = ( - COUNTRY_COLORS.get(country, "#2563eb") - if country - else "#2563eb" - ) + color = COUNTRY_COLORS.get(country, "#2563eb") if country else "#2563eb" fig.add_trace( go.Scatter( - x=hours, y=sorted_vals, name=label, - mode="lines", line={"color": color}, + x=hours, + y=sorted_vals, + name=label, + mode="lines", + line={"color": color}, ) ) @@ -485,17 +487,25 @@ def cross_border_flows( if has_links: _collect_cross_border_flows( - n, n.links, n.links_t.p0, country, country_buses, flows, + n, + n.links, + n.links_t.p0, + country, + country_buses, + flows, ) if has_lines: _collect_cross_border_flows( - n, n.lines, n.lines_t.p0, country, country_buses, flows, + n, + n.lines, + n.lines_t.p0, + country, + country_buses, + flows, ) if not flows: - return _empty_figure( - f"No cross-border connections found for {country}" - ) + return _empty_figure(f"No cross-border connections found for {country}") fig = go.Figure() for label, series_list in flows.items(): @@ -625,9 +635,7 @@ def nodal_balance( return _empty_figure("No nodal data available") # Select top-N by absolute net injection - balance = balance.reindex( - balance["net_injection"].abs().nlargest(top_n).index - ) + balance = balance.reindex(balance["net_injection"].abs().nlargest(top_n).index) balance = balance.sort_values("net_injection", ascending=True) snapshot_label = str(n.snapshots[snapshot_idx]) @@ -706,11 +714,7 @@ def line_flow_snapshot( n1_threshold = 100 warning_threshold = 70 bar_colors = [ - "red" - if v > n1_threshold - else "orange" - if v > warning_threshold - else "green" + "red" if v > n1_threshold else "orange" if v > warning_threshold else "green" for v in top.values ] @@ -721,10 +725,7 @@ def line_flow_snapshot( x=top.values[::-1], orientation="h", marker_color=bar_colors[::-1], - text=[ - f"{top_flow[ln]:.0f} MW" - for ln in reversed(top.index) - ], + text=[f"{top_flow[ln]:.0f} MW" for ln in reversed(top.index)], textposition="outside", ) ) From dcf44de61798ceb8d53964601259a3943752efec Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 19:33:13 +0200 Subject: [PATCH 06/17] feat: add network component data browser with paginated API Add full-stack component browser for viewing PyPSA network component data (buses, generators, lines, etc.) with pagination, sorting, search, and time-series metadata. Backend includes PATCH endpoint for editing component data with safe cache invalidation and path validation. Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/app/src/lib/api/client.ts | 50 +++ frontend/app/src/lib/types.ts | 39 ++ .../src/routes/database/network/+page.svelte | 9 + .../components/ComponentBrowser.svelte | 349 ++++++++++++++++++ .../backend/api/routes/components.py | 269 ++++++++++++++ src/pypsa_app/backend/main.py | 4 + src/pypsa_app/backend/schemas/components.py | 57 +++ 7 files changed, 777 insertions(+) create mode 100644 frontend/app/src/routes/database/network/components/ComponentBrowser.svelte create mode 100644 src/pypsa_app/backend/api/routes/components.py create mode 100644 src/pypsa_app/backend/schemas/components.py diff --git a/frontend/app/src/lib/api/client.ts b/frontend/app/src/lib/api/client.ts index e92ee92..b4a6107 100644 --- a/frontend/app/src/lib/api/client.ts +++ b/frontend/app/src/lib/api/client.ts @@ -18,6 +18,9 @@ import type { Visibility, PaginatedResponse, Workflow, + ComponentListResponse, + ComponentDataResponse, + ComponentTimeseriesResponse, } from "$lib/types.js"; const API_BASE = '/api/v1'; @@ -132,6 +135,53 @@ export const networks = { method: 'PATCH', body: JSON.stringify({ visibility }) }); + }, + async getComponents(id: string): Promise { + return request(`/networks/${id}/components`, {}, `components-${id}`); + }, + async getComponentData( + id: string, + componentName: string, + params: { skip?: number; limit?: number; sort_by?: string; sort_desc?: boolean; search?: string } = {} + ): Promise { + const searchParams = new URLSearchParams(); + if (params.skip !== undefined) searchParams.set('skip', String(params.skip)); + if (params.limit !== undefined) searchParams.set('limit', String(params.limit)); + if (params.sort_by) searchParams.set('sort_by', params.sort_by); + if (params.sort_desc) searchParams.set('sort_desc', 'true'); + if (params.search) searchParams.set('search', params.search); + const qs = searchParams.toString(); + return request( + `/networks/${id}/components/${componentName}${qs ? '?' + qs : ''}`, + {}, + `component-data-${id}-${componentName}` + ); + }, + async getComponentTimeseries( + id: string, + componentName: string, + attr: string, + params: { skip?: number; limit?: number } = {} + ): Promise { + const searchParams = new URLSearchParams(); + if (params.skip !== undefined) searchParams.set('skip', String(params.skip)); + if (params.limit !== undefined) searchParams.set('limit', String(params.limit)); + const qs = searchParams.toString(); + return request( + `/networks/${id}/components/${componentName}/timeseries/${attr}${qs ? '?' + qs : ''}`, + {}, + `component-ts-${id}-${componentName}-${attr}` + ); + }, + async updateComponentData( + id: string, + componentName: string, + updates: Record> + ): Promise<{ message: string }> { + return request<{ message: string }>(`/networks/${id}/components/${componentName}`, { + method: 'PATCH', + body: JSON.stringify({ updates }) + }); } }; diff --git a/frontend/app/src/lib/types.ts b/frontend/app/src/lib/types.ts index a1c2467..e5b28bb 100644 --- a/frontend/app/src/lib/types.ts +++ b/frontend/app/src/lib/types.ts @@ -267,6 +267,45 @@ export interface Workflow { errors: WorkflowError[]; } +// Component types + +export interface ComponentSummary { + name: string; + list_name: string; + count: number; + category: string | null; + attrs: string[]; + has_dynamic: boolean; + dynamic_attrs: string[]; +} + +export interface ComponentListResponse { + components: ComponentSummary[]; + total_components: number; +} + +export interface ComponentDataResponse { + component: string; + columns: string[]; + index: string[]; + data: (string | number | boolean | null)[][]; + dtypes: Record; + total: number; + skip: number; + limit: number; +} + +export interface ComponentTimeseriesResponse { + component: string; + attr: string; + columns: string[]; + index: string[]; + data: (number | null)[][]; + total_snapshots: number; + skip: number; + limit: number; +} + // API error type export interface ApiError extends Error { diff --git a/frontend/app/src/routes/database/network/+page.svelte b/frontend/app/src/routes/database/network/+page.svelte index d1fba1e..351ebc6 100644 --- a/frontend/app/src/routes/database/network/+page.svelte +++ b/frontend/app/src/routes/database/network/+page.svelte @@ -5,6 +5,7 @@ import { goto } from '$app/navigation'; import { networks, plots } from '$lib/api/client.js'; import type { Network as NetworkType, PlotData, PlotResponse, ApiError } from '$lib/types.js'; + import ComponentBrowser from './components/ComponentBrowser.svelte'; import { formatFileSize, formatDate, formatRelativeTime, formatNumber, getDirectoryPath, getTagType, getTagColor } from '$lib/utils.js'; import { Network, AlertCircle, FolderOpen, Clock, CalendarRange, Waypoints, ChevronLeft, ChevronRight, SlidersHorizontal, PanelRight } from 'lucide-svelte'; import { toast } from 'svelte-sonner'; @@ -1457,6 +1458,14 @@ async function loadPlot(statistic: string, plotType: string, parameters: Record< + + {#if networkId} +
+

Component Data

+ +
+ {/if} +

Plots & Statistics

diff --git a/frontend/app/src/routes/database/network/components/ComponentBrowser.svelte b/frontend/app/src/routes/database/network/components/ComponentBrowser.svelte new file mode 100644 index 0000000..d754d79 --- /dev/null +++ b/frontend/app/src/routes/database/network/components/ComponentBrowser.svelte @@ -0,0 +1,349 @@ + + +
+
+ +
+
+

+ + Components +

+
+ + {#if loadingComponents} +
+ {#each Array(5) as _} +
+ {/each} +
+ {:else} +
+ {#each components as comp} + + {/each} +
+ {/if} +
+ + +
+ {#if selectedComponent} + +
+
+

{selectedComponent.name}

+ + {selectedComponent.count} rows + + {#if selectedComponent.category} + + {getCategoryLabel(selectedComponent.category)} + + {/if} + {#if selectedComponent.has_dynamic} + + Time Series + + {/if} +
+ + +
+ + +
+
+ + +
+ {#if loadingData} +
+ +
+ {:else if componentData && componentData.data.length > 0} + + + + + + {#each componentData.columns as col} + + {/each} + + + + {#each componentData.data as row, rowIdx} + + + {#each row as cell, colIdx} + {@const formatted = formatCellValue(cell, componentData.dtypes[componentData.columns[colIdx]] || '')} + + {/each} + + {/each} + +
+ + + +
+ {componentData.index[rowIdx]} + + {formatted} +
+ {:else if componentData} +
+ {searchQuery ? 'No matching rows found' : 'No data available'} +
+ {/if} +
+ + + {#if componentData && componentData.total > pageSize} +
+
+ Showing {skip + 1}-{Math.min(skip + pageSize, componentData.total)} of {componentData.total.toLocaleString()} +
+
+ + + + {currentPage} / {totalPages} + + + +
+
+ {/if} + {:else} +
+ Select a component type to view its data +
+ {/if} +
+
+
diff --git a/src/pypsa_app/backend/api/routes/components.py b/src/pypsa_app/backend/api/routes/components.py new file mode 100644 index 0000000..f12c5ca --- /dev/null +++ b/src/pypsa_app/backend/api/routes/components.py @@ -0,0 +1,269 @@ +"""API routes for browsing and editing network component data.""" + +import logging +from pathlib import Path + +import pandas as pd +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session + +from pypsa_app.backend.api.deps import ( + Authorized, + get_db, + require_network, +) +from pypsa_app.backend.models import Network +from pypsa_app.backend.schemas.components import ( + ComponentDataResponse, + ComponentListResponse, + ComponentSummary, + ComponentTimeseriesResponse, + ComponentUpdateRequest, +) +from pypsa_app.backend.services.network import NetworkService, _network_cache +from pypsa_app.backend.utils.path_validation import validate_path +from pypsa_app.backend.utils.serializers import _sanitize_float + +router = APIRouter() +logger = logging.getLogger(__name__) + +EXCLUDED_SUFFIXES = ("Type",) + + +def _load_network(network: Network, *, use_cache: bool = True) -> NetworkService: + """Load a PyPSA network from its file path.""" + return NetworkService(network.file_path, use_cache=use_cache) + + +def _find_component(n, component_name: str): + """Find a component by name or list_name. Raises 404 if not found.""" + for c in n.components: + if c.name == component_name or c.list_name == component_name: + return c + raise HTTPException(404, f"Component '{component_name}' not found in network") + + +def _get_dynamic_attrs(n, list_name: str) -> list[str]: + """Get non-empty time-varying attribute names for a component.""" + dynamic_attr = f"{list_name}_t" + if not hasattr(n, dynamic_attr): + return [] + + dynamic_store = getattr(n, dynamic_attr) + attrs = [] + # Use pandas-based iteration over known DataFrame attributes + # rather than dir() which can expose internal Python attributes + for attr_name in dynamic_store: + try: + attr_val = getattr(dynamic_store, attr_name, None) + if isinstance(attr_val, pd.DataFrame) and len(attr_val) > 0: + attrs.append(attr_name) + except Exception: + continue + return sorted(attrs) + + +@router.get("/{network_id}/components", response_model=ComponentListResponse) +def list_components( + auth: Authorized[Network] = Depends(require_network("read")), +) -> ComponentListResponse: + """List all component types in a network with their counts and attributes.""" + service = _load_network(auth.model) + n = service.n + + components = [] + for c in n.components: + if c.name.endswith(EXCLUDED_SUFFIXES): + continue + if len(c) == 0: + continue + + static_df = getattr(n, c.list_name) + dynamic_attrs = _get_dynamic_attrs(n, c.list_name) + + components.append( + ComponentSummary( + name=c.name, + list_name=c.list_name, + count=len(c), + category=getattr(c, "category", None) or None, + attrs=list(static_df.columns), + has_dynamic=len(dynamic_attrs) > 0, + dynamic_attrs=dynamic_attrs, + ) + ) + + # Sort by count descending + components.sort(key=lambda c: c.count, reverse=True) + + return ComponentListResponse( + components=components, + total_components=len(components), + ) + + +@router.get( + "/{network_id}/components/{component_name}", + response_model=ComponentDataResponse, +) +def get_component_data( + component_name: str, + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + sort_by: str | None = Query(None, description="Column to sort by"), + sort_desc: bool = Query(False, description="Sort descending"), + search: str | None = Query(None, description="Filter rows by index substring"), + auth: Authorized[Network] = Depends(require_network("read")), +) -> ComponentDataResponse: + """Get paginated static data for a specific component type.""" + service = _load_network(auth.model) + n = service.n + + component = _find_component(n, component_name) + df = getattr(n, component.list_name).copy() + + # Apply search filter on index (literal substring match, not regex) + if search: + mask = df.index.astype(str).str.contains( + search, case=False, na=False, regex=False + ) + df = df[mask] + + total = len(df) + + # Apply sorting + if sort_by and sort_by in df.columns: + df = df.sort_values(sort_by, ascending=not sort_desc, na_position="last") + + # Paginate + df_page = df.iloc[skip : skip + limit] + + # Sanitize values for JSON + data = [] + for _, row in df_page.iterrows(): + data.append([_sanitize_float(v) for v in row.tolist()]) + + dtypes = {col: str(df[col].dtype) for col in df.columns} + + return ComponentDataResponse( + component=component.name, + columns=list(df.columns), + index=[str(idx) for idx in df_page.index], + data=data, + dtypes=dtypes, + total=total, + skip=skip, + limit=limit, + ) + + +@router.get( + "/{network_id}/components/{component_name}/timeseries/{attr}", + response_model=ComponentTimeseriesResponse, +) +def get_component_timeseries( + component_name: str, + attr: str, + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=5000), + auth: Authorized[Network] = Depends(require_network("read")), +) -> ComponentTimeseriesResponse: + """Get time-varying data for a specific component attribute.""" + service = _load_network(auth.model) + n = service.n + + component = _find_component(n, component_name) + + dynamic_attr = f"{component.list_name}_t" + if not hasattr(n, dynamic_attr): + raise HTTPException(404, f"No time-varying data for '{component_name}'") + + dynamic_store = getattr(n, dynamic_attr) + if not hasattr(dynamic_store, attr): + raise HTTPException( + 404, f"Attribute '{attr}' not found in {component_name} time series" + ) + + ts_df = getattr(dynamic_store, attr) + if not isinstance(ts_df, pd.DataFrame) or len(ts_df) == 0: + raise HTTPException( + 404, f"No data for '{attr}' in {component_name} time series" + ) + + total_snapshots = len(ts_df) + ts_page = ts_df.iloc[skip : skip + limit] + + data = [] + for _, row in ts_page.iterrows(): + data.append([_sanitize_float(v) for v in row.tolist()]) + + return ComponentTimeseriesResponse( + component=component.name, + attr=attr, + columns=[str(c) for c in ts_page.columns], + index=[str(idx) for idx in ts_page.index], + data=data, + total_snapshots=total_snapshots, + skip=skip, + limit=limit, + ) + + +@router.patch("/{network_id}/components/{component_name}") +def update_component_data( + component_name: str, + body: ComponentUpdateRequest, + auth: Authorized[Network] = Depends(require_network("modify")), + db: Session = Depends(get_db), +) -> dict: + """Update static data for specific component rows. + + Loads a fresh (uncached) copy, applies changes, exports to disk, + then invalidates only the affected cache entry. + """ + # Load fresh copy to avoid mutating shared cached objects + service = _load_network(auth.model, use_cache=False) + n = service.n + + component = _find_component(n, component_name) + df = getattr(n, component.list_name) + + # Validate all indices exist + missing = [idx for idx in body.updates if idx not in df.index] + if missing: + raise HTTPException(404, f"Component indices not found: {missing}") + + # Validate all columns exist + all_columns = set() + for changes in body.updates.values(): + all_columns.update(changes.keys()) + invalid_cols = all_columns - set(df.columns) + if invalid_cols: + raise HTTPException(400, f"Invalid columns: {list(invalid_cols)}") + + # Apply updates + changes_count = 0 + for idx, changes in body.updates.items(): + for col, value in changes.items(): + df.at[idx, col] = value + changes_count += 1 + + # Validate and export to file + safe_path = validate_path(Path(auth.model.file_path), must_exist=True) + n.export(safe_path) + + # Invalidate only the affected network in cache + with _network_cache._lock: + _network_cache.cache.pop(str(safe_path), None) + + logger.info( + "Component data updated", + extra={ + "network_id": str(auth.model.id), + "component": component_name, + "changes_count": changes_count, + "updated_by": auth.user.username, + }, + ) + + return {"message": f"Updated {changes_count} values in {component_name}"} diff --git a/src/pypsa_app/backend/main.py b/src/pypsa_app/backend/main.py index 9e84cc2..f5e4dc3 100644 --- a/src/pypsa_app/backend/main.py +++ b/src/pypsa_app/backend/main.py @@ -18,6 +18,7 @@ api_keys, auth, cache, + components, networks, plots, runs, @@ -291,6 +292,9 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router( networks.router, prefix=f"{API_V1_PREFIX}/networks", tags=["networks"] ) +app.include_router( + components.router, prefix=f"{API_V1_PREFIX}/networks", tags=["components"] +) app.include_router(plots.router, prefix=f"{API_V1_PREFIX}/plots", tags=["plots"]) app.include_router( analysis.router, diff --git a/src/pypsa_app/backend/schemas/components.py b/src/pypsa_app/backend/schemas/components.py new file mode 100644 index 0000000..0ca7e83 --- /dev/null +++ b/src/pypsa_app/backend/schemas/components.py @@ -0,0 +1,57 @@ +"""Schemas for network component data browsing and editing.""" + +from typing import Any + +from pydantic import BaseModel + + +class ComponentSummary(BaseModel): + """Summary of a single network component type.""" + + name: str + list_name: str + count: int + category: str | None = None + attrs: list[str] = [] + has_dynamic: bool = False + dynamic_attrs: list[str] = [] + + +class ComponentListResponse(BaseModel): + """List of all components in a network.""" + + components: list[ComponentSummary] + total_components: int + + +class ComponentDataResponse(BaseModel): + """Paginated component data (DataFrame rows).""" + + component: str + columns: list[str] + index: list[str] + data: list[list[Any]] + dtypes: dict[str, str] + total: int + skip: int + limit: int + + +class ComponentTimeseriesResponse(BaseModel): + """Time-varying data for a component attribute.""" + + component: str + attr: str + columns: list[str] + index: list[str] + data: list[list[Any]] + total_snapshots: int + skip: int + limit: int + + +class ComponentUpdateRequest(BaseModel): + """Request to update component rows.""" + + updates: dict[str, dict[str, Any]] + """Mapping of index (component name) -> {column: new_value}""" From 3f6fd6c0f89d840349fd42075b8247ebc3dec30a Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 19:35:05 +0200 Subject: [PATCH 07/17] feat: add inline editing for network component data Edit mode with pencil toggle, batch save with yellow highlights for modified cells, confirm-on-discard for unsaved changes, and pagination/ sort disabled during editing to prevent data loss. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../components/ComponentBrowser.svelte | 238 +++++++++++++++--- 1 file changed, 205 insertions(+), 33 deletions(-) diff --git a/frontend/app/src/routes/database/network/components/ComponentBrowser.svelte b/frontend/app/src/routes/database/network/components/ComponentBrowser.svelte index d754d79..3292306 100644 --- a/frontend/app/src/routes/database/network/components/ComponentBrowser.svelte +++ b/frontend/app/src/routes/database/network/components/ComponentBrowser.svelte @@ -2,9 +2,10 @@ import { onDestroy } from 'svelte'; import { networks } from '$lib/api/client.js'; import type { ComponentSummary, ComponentDataResponse, ApiError } from '$lib/types.js'; - import { Database, Search, ChevronLeft, ChevronRight, ChevronsLeft, ChevronsRight, ArrowUpDown, ArrowUp, ArrowDown, Loader2 } from 'lucide-svelte'; + import { Database, Search, ChevronLeft, ChevronRight, ChevronsLeft, ChevronsRight, ArrowUpDown, ArrowUp, ArrowDown, Loader2, Pencil, Save, X } from 'lucide-svelte'; import Badge from '$lib/components/ui/badge/badge.svelte'; import Button from '$lib/components/ui/button/button.svelte'; + import { toast } from 'svelte-sonner'; let { networkId }: { networkId: string } = $props(); @@ -28,9 +29,20 @@ let searchQuery = $state(''); let searchTimeout: ReturnType; + // Edit mode + let editMode = $state(false); + let pendingEdits = $state>>(new Map()); + let saving = $state(false); + // Derived let totalPages = $derived(componentData ? Math.ceil(componentData.total / pageSize) : 0); let skip = $derived((currentPage - 1) * pageSize); + let hasEdits = $derived(pendingEdits.size > 0); + let editCount = $derived(() => { + let count = 0; + for (const cols of pendingEdits.values()) count += cols.size; + return count; + }); // Cleanup on unmount onDestroy(() => { @@ -47,7 +59,6 @@ // Reload data when pagination/sort/search changes (single reactive source of truth) $effect(() => { if (selectedComponent && networkId) { - // Track all reactive values that should trigger a reload const _page = currentPage; const _sort = sortBy; const _sortDir = sortDesc; @@ -58,14 +69,14 @@ async function loadComponents() { loadingComponents = true; - // Reset stale state from previous network selectedComponent = null; componentData = null; + editMode = false; + pendingEdits = new Map(); error = null; try { const response = await networks.getComponents(networkId); components = response.components; - // Auto-select first component (the $effect above will trigger data load) if (components.length > 0) { selectedComponent = components[0]; currentPage = 1; @@ -82,12 +93,16 @@ } function selectComponent(comp: ComponentSummary) { + if (editMode && hasEdits) { + if (!confirm('You have unsaved changes. Discard them?')) return; + } selectedComponent = comp; currentPage = 1; sortBy = null; sortDesc = false; searchQuery = ''; - // The $effect watching selectedComponent will trigger loadComponentData + editMode = false; + pendingEdits = new Map(); } async function loadComponentData() { @@ -110,9 +125,9 @@ } function handleSort(column: string) { + if (editMode) return; // Disable sort while editing if (sortBy === column) { if (sortDesc) { - // Third click: clear sort sortBy = null; sortDesc = false; } else { @@ -123,7 +138,6 @@ sortDesc = false; } currentPage = 1; - // The $effect will trigger loadComponentData } function handleSearch(e: Event) { @@ -132,7 +146,6 @@ searchTimeout = setTimeout(() => { searchQuery = value; currentPage = 1; - // The $effect will trigger loadComponentData }, 300); } @@ -142,7 +155,96 @@ } } - // Category display helpers + // Edit mode functions + function toggleEditMode() { + if (editMode && hasEdits) { + if (!confirm('You have unsaved changes. Discard them?')) return; + } + editMode = !editMode; + pendingEdits = new Map(); + } + + function handleCellEdit(rowIndex: string, column: string, value: string, dtype: string) { + const parsed = parseValue(value, dtype); + const rowEdits = pendingEdits.get(rowIndex) ?? new Map(); + + // Check if the new value matches the original + if (componentData) { + const rowIdx = componentData.index.indexOf(rowIndex); + const colIdx = componentData.columns.indexOf(column); + if (rowIdx >= 0 && colIdx >= 0) { + const original = componentData.data[rowIdx][colIdx]; + if (parsed === original || (parsed === null && original === null)) { + rowEdits.delete(column); + if (rowEdits.size === 0) { + pendingEdits.delete(rowIndex); + } + pendingEdits = new Map(pendingEdits); // trigger reactivity + return; + } + } + } + + rowEdits.set(column, parsed); + pendingEdits.set(rowIndex, rowEdits); + pendingEdits = new Map(pendingEdits); // trigger reactivity + } + + function parseValue(value: string, dtype: string): unknown { + if (value === '' || value === '-') return null; + if (dtype.startsWith('float')) { + const num = parseFloat(value); + return isNaN(num) ? value : num; + } + if (dtype.startsWith('int') || dtype.startsWith('uint')) { + const num = parseInt(value, 10); + return isNaN(num) ? value : num; + } + if (dtype === 'bool') return value === 'true' || value === '1'; + return value; + } + + function getEditedValue(rowIndex: string, column: string): unknown | undefined { + return pendingEdits.get(rowIndex)?.get(column); + } + + function isEdited(rowIndex: string, column: string): boolean { + return pendingEdits.get(rowIndex)?.has(column) ?? false; + } + + async function saveEdits() { + if (!selectedComponent || !hasEdits) return; + saving = true; + + // Convert Map to plain object for API + const updates: Record> = {}; + for (const [rowIndex, colEdits] of pendingEdits) { + updates[rowIndex] = {}; + for (const [col, val] of colEdits) { + updates[rowIndex][col] = val; + } + } + + try { + const result = await networks.updateComponentData(networkId, selectedComponent.name, updates); + toast.success(result.message); + pendingEdits = new Map(); + editMode = false; + // Reload data to reflect saved changes + await loadComponentData(); + } catch (err: unknown) { + toast.error((err as Error).message || 'Failed to save changes'); + } finally { + saving = false; + } + } + + function cancelEdits() { + pendingEdits = new Map(); + editMode = false; + } + + // Display helpers function getCategoryLabel(category: string | null): string { if (!category) return 'Core'; return category.split('_').map(w => w.charAt(0).toUpperCase() + w.slice(1)).join(' '); @@ -168,6 +270,28 @@ if (typeof value === 'boolean') return value ? 'true' : 'false'; return String(value); } + + function getCellDisplayValue(rowIdx: number, colIdx: number): string { + if (!componentData) return '-'; + const rowIndex = componentData.index[rowIdx]; + const column = componentData.columns[colIdx]; + const dtype = componentData.dtypes[column] || ''; + const edited = getEditedValue(rowIndex, column); + if (edited !== undefined) { + return formatCellValue(edited, dtype); + } + return formatCellValue(componentData.data[rowIdx][colIdx], dtype); + } + + function getCellRawValue(rowIdx: number, colIdx: number): string { + if (!componentData) return ''; + const rowIndex = componentData.index[rowIdx]; + const column = componentData.columns[colIdx]; + const edited = getEditedValue(rowIndex, column); + const value = edited !== undefined ? edited : componentData.data[rowIdx][colIdx]; + if (value === null || value === undefined) return ''; + return String(value); + }
@@ -229,16 +353,46 @@ {/if}
- -
- - +
+ + {#if !editMode} +
+ + +
+ {/if} + + + {#if editMode} + {#if hasEdits} + + {editCount()} change{editCount() !== 1 ? 's' : ''} + + {/if} + + + {:else} + + {/if}
@@ -252,11 +406,11 @@ - {#each componentData.data as row, rowIdx} + {@const rowIndex = componentData.index[rowIdx]} {#each row as cell, colIdx} - {@const formatted = formatCellValue(cell, componentData.dtypes[componentData.columns[colIdx]] || '')} - + {@const column = componentData.columns[colIdx]} + {@const dtype = componentData.dtypes[column] || ''} + {@const edited = isEdited(rowIndex, column)} + {#if editMode} + + {:else} + {@const formatted = getCellDisplayValue(rowIdx, colIdx)} + + {/if} {/each} {/each} @@ -321,19 +493,19 @@ Showing {skip + 1}-{Math.min(skip + pageSize, componentData.total)} of {componentData.total.toLocaleString()}
- - {currentPage} / {totalPages} - -
From 8875398a32a885d4ed88f09fcca92f9e35cf7fb8 Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 20:10:43 +0200 Subject: [PATCH 08/17] feat: add saved dashboard views with CRUD API and frontend UI Users can save the current dashboard configuration (active tab, carrier/ country filters, individual plot mode) as named views, load them later, and share public views. Includes SavedView model, Alembic migration, CRUD API endpoints, and Save/Load UI in the network detail header. Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/app/src/lib/api/client.ts | 35 ++++ frontend/app/src/lib/types.ts | 34 ++++ .../src/routes/database/network/+page.svelte | 55 +++++- .../network/components/SaveViewDialog.svelte | 160 ++++++++++++++++ .../network/components/ViewSelector.svelte | 141 ++++++++++++++ .../alembic/versions/0003_add_saved_views.py | 68 +++++++ src/pypsa_app/backend/api/routes/views.py | 176 ++++++++++++++++++ src/pypsa_app/backend/main.py | 2 + src/pypsa_app/backend/models.py | 32 ++++ src/pypsa_app/backend/schemas/views.py | 80 ++++++++ 10 files changed, 778 insertions(+), 5 deletions(-) create mode 100644 frontend/app/src/routes/database/network/components/SaveViewDialog.svelte create mode 100644 frontend/app/src/routes/database/network/components/ViewSelector.svelte create mode 100644 src/pypsa_app/backend/alembic/versions/0003_add_saved_views.py create mode 100644 src/pypsa_app/backend/api/routes/views.py create mode 100644 src/pypsa_app/backend/schemas/views.py diff --git a/frontend/app/src/lib/api/client.ts b/frontend/app/src/lib/api/client.ts index b4a6107..dba926f 100644 --- a/frontend/app/src/lib/api/client.ts +++ b/frontend/app/src/lib/api/client.ts @@ -21,6 +21,9 @@ import type { ComponentListResponse, ComponentDataResponse, ComponentTimeseriesResponse, + SavedView, + SavedViewListResponse, + ViewConfig, } from "$lib/types.js"; const API_BASE = '/api/v1'; @@ -349,6 +352,38 @@ export const runs = { } }; +// Views API +export const savedViews = { + async list(networkId?: string, skip = 0, limit = 50): Promise { + const params = new URLSearchParams({ skip: String(skip), limit: String(limit) }); + if (networkId) params.set('network_id', networkId); + return request(`/views/?${params}`); + }, + async get(id: string): Promise { + return request(`/views/${id}`); + }, + async create(body: { + name: string; + description?: string; + network_id?: string; + visibility?: string; + config: ViewConfig; + }): Promise { + return request('/views/', { method: 'POST', body: JSON.stringify(body) }); + }, + async update(id: string, body: { + name?: string; + description?: string; + visibility?: string; + config?: ViewConfig; + }): Promise { + return request(`/views/${id}`, { method: 'PATCH', body: JSON.stringify(body) }); + }, + async delete(id: string): Promise { + return request(`/views/${id}`, { method: 'DELETE' }); + } +}; + // Cache API export const cache = { async clearNetwork(networkId: string): Promise { diff --git a/frontend/app/src/lib/types.ts b/frontend/app/src/lib/types.ts index e5b28bb..70ed830 100644 --- a/frontend/app/src/lib/types.ts +++ b/frontend/app/src/lib/types.ts @@ -306,6 +306,40 @@ export interface ComponentTimeseriesResponse { limit: number; } +// Saved view types + +export interface ViewConfig { + active_tab?: string; + statistic?: string; + plot_type?: string; + selected_carriers: string[]; + selected_countries: string[]; + individual_plots: boolean; + analysis_type?: string; + analysis_parameters: Record; + selected_component?: string; + component_columns?: string[]; + compare_network_ids: string[]; + extra: Record; +} + +export interface SavedView { + id: string; + name: string; + description?: string; + network_id?: string; + visibility: Visibility; + config: ViewConfig; + owner: User; + created_at?: string; + updated_at?: string; +} + +export interface SavedViewListResponse { + data: SavedView[]; + total: number; +} + // API error type export interface ApiError extends Error { diff --git a/frontend/app/src/routes/database/network/+page.svelte b/frontend/app/src/routes/database/network/+page.svelte index 351ebc6..060229d 100644 --- a/frontend/app/src/routes/database/network/+page.svelte +++ b/frontend/app/src/routes/database/network/+page.svelte @@ -5,7 +5,10 @@ import { goto } from '$app/navigation'; import { networks, plots } from '$lib/api/client.js'; import type { Network as NetworkType, PlotData, PlotResponse, ApiError } from '$lib/types.js'; + import type { ViewConfig } from '$lib/types.js'; import ComponentBrowser from './components/ComponentBrowser.svelte'; + import SaveViewDialog from './components/SaveViewDialog.svelte'; + import ViewSelector from './components/ViewSelector.svelte'; import { formatFileSize, formatDate, formatRelativeTime, formatNumber, getDirectoryPath, getTagType, getTagColor } from '$lib/utils.js'; import { Network, AlertCircle, FolderOpen, Clock, CalendarRange, Waypoints, ChevronLeft, ChevronRight, SlidersHorizontal, PanelRight } from 'lucide-svelte'; import { toast } from 'svelte-sonner'; @@ -824,6 +827,42 @@ async function loadPlot(statistic: string, plotType: string, parameters: Record< } } + function buildCurrentViewConfig(): ViewConfig { + return { + active_tab: activeTab, + statistic: tabs.find(t => t.id === activeTab)?.statistic, + plot_type: tabs.find(t => t.id === activeTab)?.plotType, + selected_carriers: selectedCarriers, + selected_countries: selectedCountries, + individual_plots: individualPlots, + analysis_type: undefined, + analysis_parameters: {}, + selected_component: undefined, + component_columns: undefined, + compare_network_ids: compareMode ? networkIds : [], + extra: {}, + }; + } + + function handleLoadView(config: ViewConfig) { + // Apply carriers + if (config.selected_carriers?.length > 0) { + selectedCarriersStore.set(new Set(config.selected_carriers)); + } + // Apply countries + if (config.selected_countries?.length > 0) { + selectedCountriesStore.set(new Set(config.selected_countries)); + } + // Apply individual plots + if (config.individual_plots !== undefined) { + showIndividualPlots.set(config.individual_plots); + } + // Apply tab + if (config.active_tab) { + activeTab = config.active_tab; + } + } + function buildFilterParameters(tabConfig: TabConfig, carriers: string[]) { const params: Record = { ...tabConfig.parameters @@ -1228,11 +1267,17 @@ async function loadPlot(statistic: string, plotType: string, parameters: Record<
-
-

{network.filename}

- {#if network.name} -

{network.name}

- {/if} +
+
+

{network.filename}

+ {#if network.name} +

{network.name}

+ {/if} +
+
+ + +
diff --git a/frontend/app/src/routes/database/network/components/SaveViewDialog.svelte b/frontend/app/src/routes/database/network/components/SaveViewDialog.svelte new file mode 100644 index 0000000..88337e1 --- /dev/null +++ b/frontend/app/src/routes/database/network/components/SaveViewDialog.svelte @@ -0,0 +1,160 @@ + + + + +{#if open} + + +{/if} diff --git a/frontend/app/src/routes/database/network/components/ViewSelector.svelte b/frontend/app/src/routes/database/network/components/ViewSelector.svelte new file mode 100644 index 0000000..a186a67 --- /dev/null +++ b/frontend/app/src/routes/database/network/components/ViewSelector.svelte @@ -0,0 +1,141 @@ + + + + +
+ + + {#if open} +
+ {#if loading} +
+ +
+ {:else if views.length === 0} +
+ No saved views yet +
+ {:else} +
+ {#each views as view} +
handleSelect(view)} + onkeydown={(e) => { if (e.key === 'Enter') handleSelect(view); }} + role="button" + tabindex="0" + class="w-full text-left px-3 py-2.5 hover:bg-muted/50 transition-colors border-b border-border/30 last:border-b-0 group cursor-pointer" + > +
+
+
+ {#if view.visibility === 'public'} + + {:else} + + {/if} + {view.name} +
+ {#if view.description} +
{view.description}
+ {/if} +
+ {view.owner.username} — {formatDate(view.updated_at || view.created_at)} +
+
+ +
+
+ {/each} +
+ {/if} +
+ {/if} +
diff --git a/src/pypsa_app/backend/alembic/versions/0003_add_saved_views.py b/src/pypsa_app/backend/alembic/versions/0003_add_saved_views.py new file mode 100644 index 0000000..39e4edf --- /dev/null +++ b/src/pypsa_app/backend/alembic/versions/0003_add_saved_views.py @@ -0,0 +1,68 @@ +"""Add saved_views table for custom dashboard configurations. + +Revision ID: 0003 +Revises: 0002 +Create Date: 2026-04-06 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0003" +down_revision: str | None = "0002" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "saved_views", + sa.Column("id", sa.Uuid(), primary_key=True), + sa.Column( + "user_id", + sa.Uuid(), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "network_id", + sa.Uuid(), + sa.ForeignKey("networks.id", ondelete="CASCADE"), + nullable=True, + index=True, + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column( + "visibility", + sa.Enum( + "public", + "private", + name="visibility", + native_enum=True, + create_type=False, + ), + nullable=False, + server_default="private", + ), + sa.Column("config", sa.JSON(), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(), + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(), + server_default=sa.func.now(), + ), + ) + + +def downgrade() -> None: + op.drop_table("saved_views") diff --git a/src/pypsa_app/backend/api/routes/views.py b/src/pypsa_app/backend/api/routes/views.py new file mode 100644 index 0000000..50fc5d1 --- /dev/null +++ b/src/pypsa_app/backend/api/routes/views.py @@ -0,0 +1,176 @@ +"""API routes for saved dashboard views.""" + +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import or_ +from sqlalchemy.orm import Session, joinedload + +from pypsa_app.backend.api.deps import get_db, require_permission +from pypsa_app.backend.models import Permission, SavedView, User, Visibility +from pypsa_app.backend.permissions import has_permission +from pypsa_app.backend.schemas.views import ( + SavedViewCreate, + SavedViewListResponse, + SavedViewResponse, + SavedViewUpdate, +) + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.post("/", response_model=SavedViewResponse, status_code=201) +def create_view( + body: SavedViewCreate, + db: Session = Depends(get_db), + user: User = Depends(require_permission(Permission.NETWORKS_VIEW)), +) -> SavedView: + """Create a new saved dashboard view.""" + view = SavedView( + user_id=user.id, + network_id=body.network_id, + name=body.name, + description=body.description, + visibility=body.visibility, + config=body.config.model_dump(), + ) + db.add(view) + db.commit() + db.refresh(view) + + logger.info( + "Saved view created", + extra={ + "view_id": str(view.id), + "view_name": view.name, + "user": user.username, + }, + ) + return view + + +@router.get("/", response_model=SavedViewListResponse) +def list_views( + network_id: UUID | None = Query(None, description="Filter by network ID"), + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), + db: Session = Depends(get_db), + user: User = Depends(require_permission(Permission.NETWORKS_VIEW)), +) -> SavedViewListResponse: + """List saved views accessible to the current user.""" + query = db.query(SavedView).options(joinedload(SavedView.owner)) + + # Users see their own views + public views + if not has_permission(user, Permission.NETWORKS_MANAGE_ALL): + query = query.filter( + or_( + SavedView.user_id == user.id, + SavedView.visibility == Visibility.PUBLIC, + ) + ) + + if network_id: + query = query.filter( + or_( + SavedView.network_id == network_id, + SavedView.network_id.is_(None), # Global views apply to any network + ) + ) + + total = query.count() + views = ( + query.order_by(SavedView.updated_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + return SavedViewListResponse(data=views, total=total) + + +@router.get("/{view_id}", response_model=SavedViewResponse) +def get_view( + view_id: UUID, + db: Session = Depends(get_db), + user: User = Depends(require_permission(Permission.NETWORKS_VIEW)), +) -> SavedView: + """Get a saved view by ID.""" + view = ( + db.query(SavedView) + .options(joinedload(SavedView.owner)) + .filter(SavedView.id == view_id) + .first() + ) + if not view: + raise HTTPException(404, "View not found") + + # Check access: own views or public + if view.user_id != user.id and view.visibility != Visibility.PUBLIC: + if not has_permission(user, Permission.NETWORKS_MANAGE_ALL): + raise HTTPException(404, "View not found") + + return view + + +@router.patch("/{view_id}", response_model=SavedViewResponse) +def update_view( + view_id: UUID, + body: SavedViewUpdate, + db: Session = Depends(get_db), + user: User = Depends(require_permission(Permission.NETWORKS_VIEW)), +) -> SavedView: + """Update a saved view. Only the owner can update.""" + view = ( + db.query(SavedView) + .options(joinedload(SavedView.owner)) + .filter(SavedView.id == view_id) + .first() + ) + if not view: + raise HTTPException(404, "View not found") + + if view.user_id != user.id and not has_permission( + user, Permission.NETWORKS_MANAGE_ALL + ): + raise HTTPException(403, "You can only update your own views") + + if body.name is not None: + view.name = body.name + if body.description is not None: + view.description = body.description + if body.visibility is not None: + view.visibility = body.visibility + if body.config is not None: + view.config = body.config.model_dump() + + db.commit() + db.refresh(view) + return view + + +@router.delete("/{view_id}") +def delete_view( + view_id: UUID, + db: Session = Depends(get_db), + user: User = Depends(require_permission(Permission.NETWORKS_VIEW)), +) -> dict: + """Delete a saved view. Only the owner can delete.""" + view = db.query(SavedView).filter(SavedView.id == view_id).first() + if not view: + raise HTTPException(404, "View not found") + + if view.user_id != user.id and not has_permission( + user, Permission.NETWORKS_MANAGE_ALL + ): + raise HTTPException(403, "You can only delete your own views") + + db.delete(view) + db.commit() + + logger.info( + "Saved view deleted", + extra={"view_id": str(view_id), "user": user.username}, + ) + return {"message": "View deleted"} diff --git a/src/pypsa_app/backend/main.py b/src/pypsa_app/backend/main.py index f5e4dc3..85c2d0f 100644 --- a/src/pypsa_app/backend/main.py +++ b/src/pypsa_app/backend/main.py @@ -25,6 +25,7 @@ statistics, tasks, version, + views, ) from pypsa_app.backend.auth.authenticate import set_auth_disabled_user from pypsa_app.backend.cache import cache_service @@ -306,6 +307,7 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp prefix=f"{API_V1_PREFIX}/statistics", tags=["statistics"], ) +app.include_router(views.router, prefix=f"{API_V1_PREFIX}/views", tags=["views"]) app.include_router(cache.router, prefix=f"{API_V1_PREFIX}/cache", tags=["cache"]) app.include_router(version.router, prefix=f"{API_V1_PREFIX}/version", tags=["version"]) app.include_router(tasks.router, prefix=f"{API_V1_PREFIX}/tasks", tags=["tasks"]) diff --git a/src/pypsa_app/backend/models.py b/src/pypsa_app/backend/models.py index 08a095c..352e05e 100644 --- a/src/pypsa_app/backend/models.py +++ b/src/pypsa_app/backend/models.py @@ -248,6 +248,38 @@ def tags(self) -> list | None: return tags if isinstance(tags, list) else None +class SavedView(Base): + """A saved dashboard view configuration for a network.""" + + __tablename__ = "saved_views" + + id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4) + user_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("users.id", ondelete="CASCADE"), + index=True, + ) + owner: Mapped["User"] = relationship(foreign_keys=[user_id]) + network_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("networks.id", ondelete="CASCADE"), + index=True, + ) + network: Mapped["Network | None"] = relationship(foreign_keys=[network_id]) + name: Mapped[str] = mapped_column(String(255)) + description: Mapped[str | None] = mapped_column(Text) + visibility: Mapped[Visibility] = mapped_column( + str_enum(Visibility, "visibility"), + default=Visibility.PRIVATE, + nullable=False, + ) + config: Mapped[Any] = mapped_column(JSON, nullable=False) + created_at: Mapped[datetime | None] = mapped_column( + TIMESTAMP, server_default=func.now() + ) + updated_at: Mapped[datetime | None] = mapped_column( + TIMESTAMP, server_default=func.now(), onupdate=func.now() + ) + + class RunStatus(enum.StrEnum): """Run status, mirrors Snakedispatch's JobStatus.""" diff --git a/src/pypsa_app/backend/schemas/views.py b/src/pypsa_app/backend/schemas/views.py new file mode 100644 index 0000000..0c67c40 --- /dev/null +++ b/src/pypsa_app/backend/schemas/views.py @@ -0,0 +1,80 @@ +"""Schemas for saved dashboard views.""" + +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import BaseModel, ConfigDict + +from pypsa_app.backend.models import Visibility +from pypsa_app.backend.schemas.auth import UserPublicResponse + + +class ViewConfig(BaseModel): + """Configuration for a saved view.""" + + # Active tab/plot + active_tab: str | None = None + statistic: str | None = None + plot_type: str | None = None + + # Filters + selected_carriers: list[str] = [] + selected_countries: list[str] = [] + individual_plots: bool = False + + # Analysis settings + analysis_type: str | None = None + analysis_parameters: dict[str, Any] = {} + + # Component browser state + selected_component: str | None = None + component_columns: list[str] | None = None + + # Compare mode + compare_network_ids: list[str] = [] + + # Custom parameters + extra: dict[str, Any] = {} + + +class SavedViewCreate(BaseModel): + """Request to create a saved view.""" + + name: str + description: str | None = None + network_id: UUID | None = None + visibility: Visibility = Visibility.PRIVATE + config: ViewConfig + + +class SavedViewUpdate(BaseModel): + """Request to update a saved view.""" + + name: str | None = None + description: str | None = None + visibility: Visibility | None = None + config: ViewConfig | None = None + + +class SavedViewResponse(BaseModel): + """Response for a saved view.""" + + id: UUID + name: str + description: str | None = None + network_id: UUID | None = None + visibility: Visibility + config: dict[str, Any] + owner: UserPublicResponse + created_at: datetime | None = None + updated_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + +class SavedViewListResponse(BaseModel): + """Paginated list of saved views.""" + + data: list[SavedViewResponse] + total: int From 2b87e68a7163f1cf31b540f6426ccba083893d57 Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 23:02:01 +0200 Subject: [PATCH 09/17] feat: add interactive Leaflet map showing network buses and branches Lightweight map visualization using Leaflet + OpenStreetMap tiles. Shows buses as circle markers and lines/links as colored polylines with capacity-based width. Popups show component details. Map auto-fits to network bounds and responds to carrier filter selection. Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/app/package-lock.json | 25 +++ frontend/app/package.json | 2 + frontend/app/src/lib/api/client.ts | 13 ++ frontend/app/src/lib/types.ts | 24 +++ .../src/routes/database/network/+page.svelte | 10 + .../network/components/NetworkMap.svelte | 193 ++++++++++++++++++ src/pypsa_app/backend/api/routes/map.py | 141 +++++++++++++ src/pypsa_app/backend/main.py | 4 + 8 files changed, 412 insertions(+) create mode 100644 frontend/app/src/routes/database/network/components/NetworkMap.svelte create mode 100644 src/pypsa_app/backend/api/routes/map.py diff --git a/frontend/app/package-lock.json b/frontend/app/package-lock.json index efb2377..febdbe0 100644 --- a/frontend/app/package-lock.json +++ b/frontend/app/package-lock.json @@ -11,6 +11,7 @@ "@tanstack/svelte-table": "^9.0.0-alpha.10", "class-variance-authority": "^0.7.1", "elkjs": "^0.11.1", + "leaflet": "^1.9.4", "lucide-react": "^0.575.0", "lucide-svelte": "^0.553.0", "mode-watcher": "^1.1.0", @@ -27,6 +28,7 @@ "@sveltejs/vite-plugin-svelte": "^6.2.4", "@tailwindcss/vite": "^4.2.1", "@tanstack/table-core": "^8.21.3", + "@types/leaflet": "^1.9.21", "bits-ui": "^2.16.2", "clsx": "^2.1.1", "svelte": "^5.53.6", @@ -1440,6 +1442,23 @@ "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", "license": "MIT" }, + "node_modules/@types/geojson": { + "version": "7946.0.16", + "resolved": "https://registry.npmjs.org/@types/geojson/-/geojson-7946.0.16.tgz", + "integrity": "sha512-6C8nqWur3j98U6+lXDfTUWIfgvZU+EumvpHKcYjujKH7woYyLj2sUmff0tRhrqM7BohUw7Pz3ZB1jj2gW9Fvmg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/leaflet": { + "version": "1.9.21", + "resolved": "https://registry.npmjs.org/@types/leaflet/-/leaflet-1.9.21.tgz", + "integrity": "sha512-TbAd9DaPGSnzp6QvtYngntMZgcRk+igFELwR2N99XZn7RXUdKgsXMR+28bUO0rPsWp8MIu/f47luLIQuSLYv/w==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/geojson": "*" + } + }, "node_modules/@types/trusted-types": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/@types/trusted-types/-/trusted-types-2.0.7.tgz", @@ -1736,6 +1755,12 @@ "node": ">=6" } }, + "node_modules/leaflet": { + "version": "1.9.4", + "resolved": "https://registry.npmjs.org/leaflet/-/leaflet-1.9.4.tgz", + "integrity": "sha512-nxS1ynzJOmOlHp+iL3FyWqK89GtNL8U8rvlMOsQdTTssxZwCXh8N2NB3GDQOL+YR3XnWyZAxwQixURb+FA74PA==", + "license": "BSD-2-Clause" + }, "node_modules/lightningcss": { "version": "1.31.1", "resolved": "https://registry.npmjs.org/lightningcss/-/lightningcss-1.31.1.tgz", diff --git a/frontend/app/package.json b/frontend/app/package.json index f41f382..01313f6 100644 --- a/frontend/app/package.json +++ b/frontend/app/package.json @@ -20,6 +20,7 @@ "@sveltejs/vite-plugin-svelte": "^6.2.4", "@tailwindcss/vite": "^4.2.1", "@tanstack/table-core": "^8.21.3", + "@types/leaflet": "^1.9.21", "bits-ui": "^2.16.2", "clsx": "^2.1.1", "svelte": "^5.53.6", @@ -35,6 +36,7 @@ "@tanstack/svelte-table": "^9.0.0-alpha.10", "class-variance-authority": "^0.7.1", "elkjs": "^0.11.1", + "leaflet": "^1.9.4", "lucide-react": "^0.575.0", "lucide-svelte": "^0.553.0", "mode-watcher": "^1.1.0", diff --git a/frontend/app/src/lib/api/client.ts b/frontend/app/src/lib/api/client.ts index dba926f..d0b4994 100644 --- a/frontend/app/src/lib/api/client.ts +++ b/frontend/app/src/lib/api/client.ts @@ -24,6 +24,7 @@ import type { SavedView, SavedViewListResponse, ViewConfig, + MapDataResponse, } from "$lib/types.js"; const API_BASE = '/api/v1'; @@ -176,6 +177,18 @@ export const networks = { `component-ts-${id}-${componentName}-${attr}` ); }, + async getMapData(id: string, carriers?: string[]): Promise { + const params = new URLSearchParams(); + if (carriers && carriers.length > 0) { + carriers.forEach(c => params.append('carriers', c)); + } + const qs = params.toString(); + return request( + `/networks/${id}/map${qs ? '?' + qs : ''}`, + {}, + `map-${id}` + ); + }, async updateComponentData( id: string, componentName: string, diff --git a/frontend/app/src/lib/types.ts b/frontend/app/src/lib/types.ts index 70ed830..0616c87 100644 --- a/frontend/app/src/lib/types.ts +++ b/frontend/app/src/lib/types.ts @@ -306,6 +306,30 @@ export interface ComponentTimeseriesResponse { limit: number; } +// Map types + +export interface GeoJSONFeature { + type: 'Feature'; + geometry: { + type: 'Point' | 'LineString'; + coordinates: number[] | number[][]; + }; + properties: Record; +} + +export interface GeoJSONFeatureCollection { + type: 'FeatureCollection'; + features: GeoJSONFeature[]; +} + +export interface MapDataResponse { + buses: GeoJSONFeatureCollection; + branches: GeoJSONFeatureCollection; + bounds: { southwest: [number, number]; northeast: [number, number] } | null; + total_buses: number; + total_branches: number; +} + // Saved view types export interface ViewConfig { diff --git a/frontend/app/src/routes/database/network/+page.svelte b/frontend/app/src/routes/database/network/+page.svelte index 060229d..c6fe914 100644 --- a/frontend/app/src/routes/database/network/+page.svelte +++ b/frontend/app/src/routes/database/network/+page.svelte @@ -7,6 +7,7 @@ import type { Network as NetworkType, PlotData, PlotResponse, ApiError } from '$lib/types.js'; import type { ViewConfig } from '$lib/types.js'; import ComponentBrowser from './components/ComponentBrowser.svelte'; + import NetworkMap from './components/NetworkMap.svelte'; import SaveViewDialog from './components/SaveViewDialog.svelte'; import ViewSelector from './components/ViewSelector.svelte'; import { formatFileSize, formatDate, formatRelativeTime, formatNumber, getDirectoryPath, getTagType, getTagColor } from '$lib/utils.js'; @@ -1501,6 +1502,15 @@ async function loadPlot(statistic: string, plotType: string, parameters: Record<
+ + {#if networkId} +
+
+ +
+
+ {/if} + diff --git a/frontend/app/src/routes/database/network/components/NetworkMap.svelte b/frontend/app/src/routes/database/network/components/NetworkMap.svelte new file mode 100644 index 0000000..b2db207 --- /dev/null +++ b/frontend/app/src/routes/database/network/components/NetworkMap.svelte @@ -0,0 +1,193 @@ + + +
+
+ + {#if loading} +
+
+ + Loading map... +
+
+ {/if} + + {#if error} +
+
{error}
+
+ {/if} + + {#if mapData && !loading} +
+ + {mapData.total_buses} buses, {mapData.total_branches} branches +
+ {/if} +
+ + diff --git a/src/pypsa_app/backend/api/routes/map.py b/src/pypsa_app/backend/api/routes/map.py new file mode 100644 index 0000000..7537424 --- /dev/null +++ b/src/pypsa_app/backend/api/routes/map.py @@ -0,0 +1,141 @@ +"""API routes for network geographic/map data.""" + +import logging +from typing import Any + +import pandas as pd +from fastapi import APIRouter, Depends, Query + +from pypsa_app.backend.api.deps import ( + Authorized, + require_network, +) +from pypsa_app.backend.models import Network +from pypsa_app.backend.services.network import NetworkService + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def _load_network(network: Network) -> NetworkService: + return NetworkService(network.file_path, use_cache=True) + + +def _extract_bus_features(n, carriers: list[str] | None = None) -> list[dict]: + """Extract bus positions as GeoJSON Point features.""" + buses = n.buses + if carriers: + buses = buses[buses.carrier.isin(carriers)] + + features = [] + for idx, row in buses.iterrows(): + x, y = row.get("x"), row.get("y") + if pd.isna(x) or pd.isna(y) or (x == 0 and y == 0): + continue + props: dict[str, Any] = {"name": str(idx)} + for col in ("carrier", "v_nom", "country"): + if col in row.index: + val = row[col] + if pd.notna(val): + props[col] = float(val) if isinstance(val, (int, float)) else str(val) + features.append( + { + "type": "Feature", + "geometry": {"type": "Point", "coordinates": [float(x), float(y)]}, + "properties": props, + } + ) + return features + + +def _extract_branch_features( + n, component: str, buses_df: pd.DataFrame +) -> list[dict]: + """Extract branch (Line/Link) data as GeoJSON LineString features.""" + df = getattr(n, component, None) + if df is None or len(df) == 0: + return [] + + features = [] + for idx, row in df.iterrows(): + bus0_name, bus1_name = row.get("bus0"), row.get("bus1") + if bus0_name not in buses_df.index or bus1_name not in buses_df.index: + continue + + bus0 = buses_df.loc[bus0_name] + bus1 = buses_df.loc[bus1_name] + + coords = [] + for bus in (bus0, bus1): + x, y = bus.get("x"), bus.get("y") + if pd.isna(x) or pd.isna(y): + break + coords.append([float(x), float(y)]) + + if len(coords) != 2: + continue + + props: dict[str, Any] = { + "name": str(idx), + "bus0": str(bus0_name), + "bus1": str(bus1_name), + "type": component.rstrip("s"), # "lines" -> "line", "links" -> "link" + } + + # Add capacity info + if "s_nom" in row.index and pd.notna(row["s_nom"]): + props["capacity"] = float(row["s_nom"]) + elif "p_nom" in row.index and pd.notna(row["p_nom"]): + props["capacity"] = float(row["p_nom"]) + + features.append( + { + "type": "Feature", + "geometry": {"type": "LineString", "coordinates": coords}, + "properties": props, + } + ) + return features + + +@router.get("/{network_id}/map") +def get_map_data( + carriers: list[str] | None = Query(None, description="Filter buses by carrier"), + auth: Authorized[Network] = Depends(require_network("read")), +) -> dict: + """Get GeoJSON data for network visualization on a map. + + Returns buses as Points and lines/links as LineStrings. + """ + service = _load_network(auth.model) + n = service.n + + buses = n.buses + bus_features = _extract_bus_features(n, carriers) + line_features = _extract_branch_features(n, "lines", buses) + link_features = _extract_branch_features(n, "links", buses) + + # Compute bounds for map centering + xs = [f["geometry"]["coordinates"][0] for f in bus_features] + ys = [f["geometry"]["coordinates"][1] for f in bus_features] + + bounds = None + if xs and ys: + bounds = { + "southwest": [min(ys), min(xs)], + "northeast": [max(ys), max(xs)], + } + + return { + "buses": { + "type": "FeatureCollection", + "features": bus_features, + }, + "branches": { + "type": "FeatureCollection", + "features": line_features + link_features, + }, + "bounds": bounds, + "total_buses": len(bus_features), + "total_branches": len(line_features) + len(link_features), + } diff --git a/src/pypsa_app/backend/main.py b/src/pypsa_app/backend/main.py index 85c2d0f..fd3c020 100644 --- a/src/pypsa_app/backend/main.py +++ b/src/pypsa_app/backend/main.py @@ -19,6 +19,7 @@ auth, cache, components, + map as map_routes, networks, plots, runs, @@ -296,6 +297,9 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router( components.router, prefix=f"{API_V1_PREFIX}/networks", tags=["components"] ) +app.include_router( + map_routes.router, prefix=f"{API_V1_PREFIX}/networks", tags=["map"] +) app.include_router(plots.router, prefix=f"{API_V1_PREFIX}/plots", tags=["plots"]) app.include_router( analysis.router, From 23db4ff3bfdde198e2f679c2bd71f20da0a1c9f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:03:15 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pypsa_app/backend/api/routes/map.py | 8 ++++---- src/pypsa_app/backend/api/routes/views.py | 7 +------ src/pypsa_app/backend/main.py | 8 ++++---- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/pypsa_app/backend/api/routes/map.py b/src/pypsa_app/backend/api/routes/map.py index 7537424..c210c03 100644 --- a/src/pypsa_app/backend/api/routes/map.py +++ b/src/pypsa_app/backend/api/routes/map.py @@ -37,7 +37,9 @@ def _extract_bus_features(n, carriers: list[str] | None = None) -> list[dict]: if col in row.index: val = row[col] if pd.notna(val): - props[col] = float(val) if isinstance(val, (int, float)) else str(val) + props[col] = ( + float(val) if isinstance(val, (int, float)) else str(val) + ) features.append( { "type": "Feature", @@ -48,9 +50,7 @@ def _extract_bus_features(n, carriers: list[str] | None = None) -> list[dict]: return features -def _extract_branch_features( - n, component: str, buses_df: pd.DataFrame -) -> list[dict]: +def _extract_branch_features(n, component: str, buses_df: pd.DataFrame) -> list[dict]: """Extract branch (Line/Link) data as GeoJSON LineString features.""" df = getattr(n, component, None) if df is None or len(df) == 0: diff --git a/src/pypsa_app/backend/api/routes/views.py b/src/pypsa_app/backend/api/routes/views.py index 50fc5d1..34298ba 100644 --- a/src/pypsa_app/backend/api/routes/views.py +++ b/src/pypsa_app/backend/api/routes/views.py @@ -80,12 +80,7 @@ def list_views( ) total = query.count() - views = ( - query.order_by(SavedView.updated_at.desc()) - .offset(skip) - .limit(limit) - .all() - ) + views = query.order_by(SavedView.updated_at.desc()).offset(skip).limit(limit).all() return SavedViewListResponse(data=views, total=total) diff --git a/src/pypsa_app/backend/main.py b/src/pypsa_app/backend/main.py index fd3c020..3a9ebe0 100644 --- a/src/pypsa_app/backend/main.py +++ b/src/pypsa_app/backend/main.py @@ -19,7 +19,6 @@ auth, cache, components, - map as map_routes, networks, plots, runs, @@ -28,6 +27,9 @@ version, views, ) +from pypsa_app.backend.api.routes import ( + map as map_routes, +) from pypsa_app.backend.auth.authenticate import set_auth_disabled_user from pypsa_app.backend.cache import cache_service from pypsa_app.backend.database import SessionLocal, engine @@ -297,9 +299,7 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router( components.router, prefix=f"{API_V1_PREFIX}/networks", tags=["components"] ) -app.include_router( - map_routes.router, prefix=f"{API_V1_PREFIX}/networks", tags=["map"] -) +app.include_router(map_routes.router, prefix=f"{API_V1_PREFIX}/networks", tags=["map"]) app.include_router(plots.router, prefix=f"{API_V1_PREFIX}/plots", tags=["plots"]) app.include_router( analysis.router, From 9b06be12e9f046515cea23832617f8c52463eb1b Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 23:06:33 +0200 Subject: [PATCH 11/17] feat: add network sharing with specific users Networks can now be shared with specific users beyond public/private visibility. Adds network_shares join table, extends can_access() to check shares, adds share/unshare API endpoints, and includes a Share dialog in the network detail header with user search. Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/app/src/lib/api/client.ts | 13 ++ frontend/app/src/lib/types.ts | 6 + .../src/routes/database/network/+page.svelte | 2 + .../network/components/ShareDialog.svelte | 201 ++++++++++++++++++ .../versions/0004_add_network_shares.py | 45 ++++ src/pypsa_app/backend/api/routes/networks.py | 93 +++++++- src/pypsa_app/backend/models.py | 27 +++ src/pypsa_app/backend/permissions.py | 13 +- src/pypsa_app/backend/schemas/network.py | 14 ++ 9 files changed, 409 insertions(+), 5 deletions(-) create mode 100644 frontend/app/src/routes/database/network/components/ShareDialog.svelte create mode 100644 src/pypsa_app/backend/alembic/versions/0004_add_network_shares.py diff --git a/frontend/app/src/lib/api/client.ts b/frontend/app/src/lib/api/client.ts index d0b4994..1423818 100644 --- a/frontend/app/src/lib/api/client.ts +++ b/frontend/app/src/lib/api/client.ts @@ -1,6 +1,7 @@ import type { User, Network, + NetworkShareResponse, Run, RunSummary, Backend, @@ -134,6 +135,18 @@ export const networks = { async delete(id: string): Promise { return request(`/networks/${id}`, { method: 'DELETE' }); }, + async getShares(id: string): Promise { + return request(`/networks/${id}/shares`); + }, + async shareWith(id: string, userId: string): Promise { + return request(`/networks/${id}/shares`, { + method: 'POST', + body: JSON.stringify({ user_id: userId }) + }); + }, + async unshare(id: string, userId: string): Promise { + return request(`/networks/${id}/shares/${userId}`, { method: 'DELETE' }); + }, async updateVisibility(id: string, visibility: Visibility): Promise { return request(`/networks/${id}`, { method: 'PATCH', diff --git a/frontend/app/src/lib/types.ts b/frontend/app/src/lib/types.ts index 0616c87..be39a16 100644 --- a/frontend/app/src/lib/types.ts +++ b/frontend/app/src/lib/types.ts @@ -41,6 +41,7 @@ export interface Network { file_size?: number; visibility: Visibility; owner: User; + shared_with?: User[]; source_run_id?: string; dimensions?: Record; dimensions_count?: number; @@ -52,6 +53,11 @@ export interface Network { updated_at?: string; } +export interface NetworkShareResponse { + network_id: string; + shared_with: User[]; +} + export type Visibility = "public" | "private"; export interface BackendPublic { diff --git a/frontend/app/src/routes/database/network/+page.svelte b/frontend/app/src/routes/database/network/+page.svelte index c6fe914..9151cc8 100644 --- a/frontend/app/src/routes/database/network/+page.svelte +++ b/frontend/app/src/routes/database/network/+page.svelte @@ -9,6 +9,7 @@ import ComponentBrowser from './components/ComponentBrowser.svelte'; import NetworkMap from './components/NetworkMap.svelte'; import SaveViewDialog from './components/SaveViewDialog.svelte'; + import ShareDialog from './components/ShareDialog.svelte'; import ViewSelector from './components/ViewSelector.svelte'; import { formatFileSize, formatDate, formatRelativeTime, formatNumber, getDirectoryPath, getTagType, getTagColor } from '$lib/utils.js'; import { Network, AlertCircle, FolderOpen, Clock, CalendarRange, Waypoints, ChevronLeft, ChevronRight, SlidersHorizontal, PanelRight } from 'lucide-svelte'; @@ -1276,6 +1277,7 @@ async function loadPlot(statistic: string, plotType: string, parameters: Record< {/if}
+
diff --git a/frontend/app/src/routes/database/network/components/ShareDialog.svelte b/frontend/app/src/routes/database/network/components/ShareDialog.svelte new file mode 100644 index 0000000..c6ded8e --- /dev/null +++ b/frontend/app/src/routes/database/network/components/ShareDialog.svelte @@ -0,0 +1,201 @@ + + +{#if isOwner} + +{/if} + +{#if open} + +{/if} diff --git a/src/pypsa_app/backend/alembic/versions/0004_add_network_shares.py b/src/pypsa_app/backend/alembic/versions/0004_add_network_shares.py new file mode 100644 index 0000000..5564551 --- /dev/null +++ b/src/pypsa_app/backend/alembic/versions/0004_add_network_shares.py @@ -0,0 +1,45 @@ +"""Add network_shares table for sharing networks with specific users. + +Revision ID: 0004 +Revises: 0003 +Create Date: 2026-04-06 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0004" +down_revision: str | None = "0003" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "network_shares", + sa.Column( + "network_id", + sa.Uuid(), + sa.ForeignKey("networks.id", ondelete="CASCADE"), + primary_key=True, + ), + sa.Column( + "user_id", + sa.Uuid(), + sa.ForeignKey("users.id", ondelete="CASCADE"), + primary_key=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(), + server_default=sa.func.now(), + ), + ) + + +def downgrade() -> None: + op.drop_table("network_shares") diff --git a/src/pypsa_app/backend/api/routes/networks.py b/src/pypsa_app/backend/api/routes/networks.py index 84af7bf..ce9ff15 100644 --- a/src/pypsa_app/backend/api/routes/networks.py +++ b/src/pypsa_app/backend/api/routes/networks.py @@ -16,12 +16,14 @@ from pypsa_app.backend.api.utils.network_utils import ( delete_network as delete_network_and_file, ) -from pypsa_app.backend.models import Network, Permission, User, Visibility +from pypsa_app.backend.models import Network, Permission, User, Visibility, network_shares from pypsa_app.backend.permissions import has_permission from pypsa_app.backend.schemas.common import MessageResponse from pypsa_app.backend.schemas.network import ( NetworkListResponse, NetworkResponse, + NetworkShareRequest, + NetworkShareResponse, NetworkUpdate, ) from pypsa_app.backend.services.network import import_network_file @@ -105,10 +107,16 @@ def list_networks( visibility_filter = None if not has_permission(user, Permission.NETWORKS_MANAGE_ALL): - # Non-admin users see: own networks + public + # Non-admin users see: own networks + public + shared with them + shared_ids = ( + db.query(network_shares.c.network_id) + .filter(network_shares.c.user_id == user.id) + .subquery() + ) visibility_filter = or_( Network.user_id == user.id, Network.visibility == Visibility.PUBLIC, + Network.id.in_(shared_ids), ) query = query.filter(visibility_filter) @@ -199,3 +207,84 @@ def delete_network( """Delete network from database and file system""" message = delete_network_and_file(auth.model, db) return {"message": message} + + +# --- Sharing endpoints --- + + +@router.get("/{network_id}/shares", response_model=NetworkShareResponse) +def get_network_shares( + auth: Authorized[Network] = Depends(require_network("modify")), +) -> dict: + """Get list of users this network is shared with. Owner only.""" + network = auth.model + return { + "network_id": network.id, + "shared_with": network.shared_with, + } + + +@router.post("/{network_id}/shares", response_model=NetworkShareResponse) +def share_network( + body: NetworkShareRequest, + auth: Authorized[Network] = Depends(require_network("modify")), + db: Session = Depends(get_db), +) -> dict: + """Share a network with another user. Owner only.""" + network = auth.model + target_user = db.query(User).filter(User.id == body.user_id).first() + if not target_user: + raise HTTPException(404, "User not found") + + if target_user.id == auth.user.id: + raise HTTPException(400, "Cannot share a network with yourself") + + if target_user in network.shared_with: + raise HTTPException(400, "Network is already shared with this user") + + network.shared_with.append(target_user) + db.commit() + db.refresh(network) + + logger.info( + "Network shared", + extra={ + "network_id": str(network.id), + "shared_with": target_user.username, + "shared_by": auth.user.username, + }, + ) + return { + "network_id": network.id, + "shared_with": network.shared_with, + } + + +@router.delete("/{network_id}/shares/{user_id}", response_model=NetworkShareResponse) +def unshare_network( + user_id: _uuid.UUID, + auth: Authorized[Network] = Depends(require_network("modify")), + db: Session = Depends(get_db), +) -> dict: + """Remove a user's access to a shared network. Owner only.""" + network = auth.model + target_user = db.query(User).filter(User.id == user_id).first() + if not target_user or target_user not in network.shared_with: + raise HTTPException(404, "User not found in share list") + + network.shared_with.remove(target_user) + db.commit() + db.refresh(network) + + logger.info( + "Network unshared", + extra={ + "network_id": str(network.id), + "unshared_from": target_user.username, + "unshared_by": auth.user.username, + }, + ) + return { + "network_id": network.id, + "shared_with": network.shared_with, + } diff --git a/src/pypsa_app/backend/models.py b/src/pypsa_app/backend/models.py index 352e05e..4810ce7 100644 --- a/src/pypsa_app/backend/models.py +++ b/src/pypsa_app/backend/models.py @@ -53,6 +53,25 @@ def str_enum(enum_cls: type[enum.Enum], name: str) -> Enum: ) +network_shares = Table( + "network_shares", + Base.metadata, + Column( + "network_id", + Uuid, + ForeignKey("networks.id", ondelete="CASCADE"), + primary_key=True, + ), + Column( + "user_id", + Uuid, + ForeignKey("users.id", ondelete="CASCADE"), + primary_key=True, + ), + Column("created_at", TIMESTAMP, server_default=func.now()), +) + + class SnakedispatchBackend(Base): """A registered Snakedispatch execution backend.""" @@ -242,6 +261,14 @@ class Network(Base): facets: Mapped[Any | None] = mapped_column(JSON) topology_svg: Mapped[str | None] = mapped_column(Text) + shared_with: Mapped[list["User"]] = relationship( + secondary=network_shares, backref="shared_networks" + ) + + @property + def shared_user_ids(self) -> list[uuid.UUID]: + return [u.id for u in self.shared_with] + @property def tags(self) -> list | None: tags = self.meta.get("tags") if self.meta else None diff --git a/src/pypsa_app/backend/permissions.py b/src/pypsa_app/backend/permissions.py index 149e1d7..8d7c351 100644 --- a/src/pypsa_app/backend/permissions.py +++ b/src/pypsa_app/backend/permissions.py @@ -72,13 +72,20 @@ class ResourcePerms: def can_access(user: User, resource: Network | Run) -> bool: - """Can user view this resource? True if public, owner, or admin.""" + """Can user view this resource? True if public, owner, shared, or admin.""" perms = RESOURCE_PERMS[type(resource)] - return ( + if ( resource.visibility == Visibility.PUBLIC or resource.user_id == user.id or has_permission(user, perms.manage_all) - ) + ): + return True + + # Check if the network is shared with this user + if isinstance(resource, Network) and hasattr(resource, "shared_user_ids"): + return user.id in resource.shared_user_ids + + return False def can_modify(user: User, resource: Network | Run) -> bool: diff --git a/src/pypsa_app/backend/schemas/network.py b/src/pypsa_app/backend/schemas/network.py index bc306bd..969faad 100644 --- a/src/pypsa_app/backend/schemas/network.py +++ b/src/pypsa_app/backend/schemas/network.py @@ -32,6 +32,7 @@ class NetworkResponse(BaseModel): # Ownership, visibility and provenance visibility: Visibility = Visibility.PRIVATE owner: UserPublicResponse + shared_with: list[UserPublicResponse] = [] source_run_id: UUID | None = None # Model properties @@ -62,3 +63,16 @@ class NetworkAdminUpdate(NetworkUpdate): """Admin-only fields""" user_id: UUID | None = None + + +class NetworkShareRequest(BaseModel): + """Request to share a network with a user""" + + user_id: UUID + + +class NetworkShareResponse(BaseModel): + """Response for network share status""" + + network_id: UUID + shared_with: list[UserPublicResponse] From c801fd8d771d01a2b81d0675a6acba649a8dde68 Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 23:12:27 +0200 Subject: [PATCH 12/17] feat: add New Run creation dialog for self-service workflow submission Form with workflow URL, branch/tag, config file, backend selection, and advanced options (snakemake args, import networks, visibility). Validates inputs, submits to API, and navigates to run detail on success. Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/app/src/routes/runs/+page.svelte | 7 + .../runs/components/CreateRunDialog.svelte | 306 ++++++++++++++++++ 2 files changed, 313 insertions(+) create mode 100644 frontend/app/src/routes/runs/components/CreateRunDialog.svelte diff --git a/frontend/app/src/routes/runs/+page.svelte b/frontend/app/src/routes/runs/+page.svelte index 459eedb..38d8319 100644 --- a/frontend/app/src/routes/runs/+page.svelte +++ b/frontend/app/src/routes/runs/+page.svelte @@ -14,6 +14,7 @@ import StatusBadge from './cells/StatusBadge.svelte'; import { Play } from 'lucide-svelte'; import { toast } from 'svelte-sonner'; + import CreateRunDialog from './components/CreateRunDialog.svelte'; import PaginatedTable from '$lib/components/PaginatedTable.svelte'; import { createColumns } from './components/columns.js'; import { authStore } from '$lib/stores/auth.svelte.js'; @@ -248,6 +249,12 @@
+ +
+
+ +
+ {#if viewState !== 'loading' && viewState !== 'empty'} + import { runs } from '$lib/api/client.js'; + import type { BackendPublic, Run } from '$lib/types.js'; + import { Plus, X, Loader2, ChevronDown, Server, GitBranch, FileCode, Settings2 } from 'lucide-svelte'; + import Button from '$lib/components/ui/button/button.svelte'; + import Badge from '$lib/components/ui/badge/badge.svelte'; + import { toast } from 'svelte-sonner'; + import { goto } from '$app/navigation'; + + let open = $state(false); + let submitting = $state(false); + let backends = $state([]); + let loadingBackends = $state(false); + let showAdvanced = $state(false); + + // Form fields + let workflow = $state(''); + let gitRef = $state(''); + let configfile = $state(''); + let backendId = $state(''); + let snakemakeArgs = $state(''); + let importNetworks = $state(''); + let visibility = $state<'private' | 'public'>('private'); + + // Validation + let errors = $state>({}); + + function validate(): boolean { + errors = {}; + if (!workflow.trim()) { + errors.workflow = 'Workflow URL is required'; + } else { + try { + new URL(workflow.trim()); + } catch { + errors.workflow = 'Must be a valid URL (e.g., https://github.com/org/repo)'; + } + } + if (backends.length > 1 && !backendId) { + errors.backend = 'Select a backend'; + } + return Object.keys(errors).length === 0; + } + + async function openDialog() { + open = true; + workflow = ''; + gitRef = ''; + configfile = ''; + backendId = ''; + snakemakeArgs = ''; + importNetworks = ''; + visibility = 'private'; + showAdvanced = false; + errors = {}; + await loadBackends(); + } + + function closeDialog() { + open = false; + } + + async function loadBackends() { + loadingBackends = true; + try { + backends = await runs.backends(); + // Auto-select if only one backend + if (backends.length === 1) { + backendId = backends[0].id; + } + } catch { + backends = []; + } finally { + loadingBackends = false; + } + } + + async function handleSubmit() { + if (!validate()) return; + submitting = true; + + try { + const body: Record = { + workflow: workflow.trim(), + visibility, + }; + + if (gitRef.trim()) body.git_ref = gitRef.trim(); + if (configfile.trim()) body.configfile = configfile.trim(); + if (backendId) body.backend_id = backendId; + + if (snakemakeArgs.trim()) { + body.snakemake_args = snakemakeArgs.trim().split(/\s+/).filter(Boolean); + } + if (importNetworks.trim()) { + body.import_networks = importNetworks.trim().split(/[,\s]+/).filter(Boolean); + } + + const run: Run = await runs.create(body as any); + toast.success(`Run submitted`); + closeDialog(); + goto(`/runs/${run.id}`); + } catch (err: unknown) { + toast.error((err as Error).message || 'Failed to submit run'); + } finally { + submitting = false; + } + } + + function handleKeydown(e: KeyboardEvent) { + if (e.key === 'Escape') closeDialog(); + } + + + + +{#if open} + +{/if} From 56c73beda9b231cb1dfec16fcd2456fae0876ad6 Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Mon, 6 Apr 2026 23:20:12 +0200 Subject: [PATCH 13/17] fix: correct export method and add non-admin user search for sharing - Fix n.export() -> n.export_to_netcdf() (PyPSA has no .export()) - Add GET /networks/users/search endpoint for non-admin user search - Fix ShareDialog to use new endpoint instead of admin-only listUsers - Remove deadlock risk by dropping explicit cache lock acquisition Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/app/src/lib/api/client.ts | 5 ++++ .../network/components/ShareDialog.svelte | 10 +++---- .../backend/api/routes/components.py | 5 ++-- src/pypsa_app/backend/api/routes/networks.py | 28 +++++++++++++++++++ 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/frontend/app/src/lib/api/client.ts b/frontend/app/src/lib/api/client.ts index 1423818..be47449 100644 --- a/frontend/app/src/lib/api/client.ts +++ b/frontend/app/src/lib/api/client.ts @@ -135,6 +135,11 @@ export const networks = { async delete(id: string): Promise { return request(`/networks/${id}`, { method: 'DELETE' }); }, + async searchUsers(q: string): Promise<{ id: string; username: string; avatar_url?: string }[]> { + return request<{ id: string; username: string; avatar_url?: string }[]>( + `/networks/users/search?q=${encodeURIComponent(q)}` + ); + }, async getShares(id: string): Promise { return request(`/networks/${id}/shares`); }, diff --git a/frontend/app/src/routes/database/network/components/ShareDialog.svelte b/frontend/app/src/routes/database/network/components/ShareDialog.svelte index c6ded8e..f0c9c61 100644 --- a/frontend/app/src/routes/database/network/components/ShareDialog.svelte +++ b/frontend/app/src/routes/database/network/components/ShareDialog.svelte @@ -1,5 +1,5 @@ -
+
{#if loading} @@ -179,9 +242,66 @@ {/if} {#if mapData && !loading} -
- - {mapData.total_buses} buses, {mapData.total_branches} branches + +
+ + {mapData.total_buses} buses · {mapData.total_branches} branches +
+ + +
+ + + + {#if legendExpanded} +
+ +
+ + +
+ + + {#each legendEntries as [name, color]} +
+ + {name} +
+ {/each} + + {#if legendEntries.length === 0} +
No data
+ {/if} + + +
+
+ + AC Line +
+
+ + DC Link +
+
+ + Cross-border +
+
+
+ {/if}
{/if}
@@ -189,5 +309,13 @@ diff --git a/src/pypsa_app/backend/api/routes/map.py b/src/pypsa_app/backend/api/routes/map.py index c210c03..414faa3 100644 --- a/src/pypsa_app/backend/api/routes/map.py +++ b/src/pypsa_app/backend/api/routes/map.py @@ -12,6 +12,7 @@ ) from pypsa_app.backend.models import Network from pypsa_app.backend.services.network import NetworkService +from pypsa_app.backend.utils.carrier_colors import CARRIER_COLORS, COUNTRY_COLORS router = APIRouter() logger = logging.getLogger(__name__) @@ -21,12 +22,40 @@ def _load_network(network: Network) -> NetworkService: return NetworkService(network.file_path, use_cache=True) +def _build_bus_stats(n) -> dict[str, dict]: + """Pre-compute per-bus generator/load counts and total capacity.""" + stats: dict[str, dict] = {} + + if len(n.generators) > 0: + for bus, group in n.generators.groupby("bus"): + s = stats.setdefault(str(bus), {}) + s["generators"] = len(group) + s["gen_capacity"] = float(group["p_nom"].sum()) if "p_nom" in group else 0 + # Dominant generator carrier at this bus + if "carrier" in group.columns: + s["gen_carriers"] = sorted(group["carrier"].unique().tolist()) + + if len(n.loads) > 0: + for bus, group in n.loads.groupby("bus"): + s = stats.setdefault(str(bus), {}) + s["loads"] = len(group) + + if len(n.storage_units) > 0: + for bus, group in n.storage_units.groupby("bus"): + s = stats.setdefault(str(bus), {}) + s["storage"] = len(group) + + return stats + + def _extract_bus_features(n, carriers: list[str] | None = None) -> list[dict]: - """Extract bus positions as GeoJSON Point features.""" + """Extract bus positions as GeoJSON Point features with attached stats.""" buses = n.buses if carriers: buses = buses[buses.carrier.isin(carriers)] + bus_stats = _build_bus_stats(n) + features = [] for idx, row in buses.iterrows(): x, y = row.get("x"), row.get("y") @@ -40,6 +69,11 @@ def _extract_bus_features(n, carriers: list[str] | None = None) -> list[dict]: props[col] = ( float(val) if isinstance(val, (int, float)) else str(val) ) + + # Attach per-bus stats + if str(idx) in bus_stats: + props.update(bus_stats[str(idx)]) + features.append( { "type": "Feature", @@ -82,12 +116,23 @@ def _extract_branch_features(n, component: str, buses_df: pd.DataFrame) -> list[ "type": component.rstrip("s"), # "lines" -> "line", "links" -> "link" } - # Add capacity info + # Add capacity and carrier info if "s_nom" in row.index and pd.notna(row["s_nom"]): props["capacity"] = float(row["s_nom"]) elif "p_nom" in row.index and pd.notna(row["p_nom"]): props["capacity"] = float(row["p_nom"]) + if "carrier" in row.index and pd.notna(row["carrier"]): + props["carrier"] = str(row["carrier"]) + + # Cross-border flag + if bus0_name in buses_df.index and bus1_name in buses_df.index: + c0 = buses_df.loc[bus0_name].get("country", "") + c1 = buses_df.loc[bus1_name].get("country", "") + if c0 and c1 and c0 != c1: + props["cross_border"] = True + props["countries"] = f"{c0}-{c1}" + features.append( { "type": "Feature", @@ -126,6 +171,19 @@ def get_map_data( "northeast": [max(ys), max(xs)], } + # Collect unique carriers and countries for legend + bus_carriers = { + f["properties"]["carrier"] + for f in bus_features + if "carrier" in f["properties"] + } + carrier_colors = {c: CARRIER_COLORS.get(c, "#94a3b8") for c in bus_carriers} + country_colors = { + c: COUNTRY_COLORS.get(c, "#94a3b8") + for f in bus_features + if (c := f["properties"].get("country")) + } + return { "buses": { "type": "FeatureCollection", @@ -138,4 +196,6 @@ def get_map_data( "bounds": bounds, "total_buses": len(bus_features), "total_branches": len(line_features) + len(link_features), + "carrier_colors": carrier_colors, + "country_colors": country_colors, } From f08f8412248c13e64c8cddc78911d19fed1be13d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:30:59 +0000 Subject: [PATCH 15/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pypsa_app/backend/api/routes/map.py | 4 +--- src/pypsa_app/backend/api/routes/networks.py | 8 +++++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/pypsa_app/backend/api/routes/map.py b/src/pypsa_app/backend/api/routes/map.py index 414faa3..418bf40 100644 --- a/src/pypsa_app/backend/api/routes/map.py +++ b/src/pypsa_app/backend/api/routes/map.py @@ -173,9 +173,7 @@ def get_map_data( # Collect unique carriers and countries for legend bus_carriers = { - f["properties"]["carrier"] - for f in bus_features - if "carrier" in f["properties"] + f["properties"]["carrier"] for f in bus_features if "carrier" in f["properties"] } carrier_colors = {c: CARRIER_COLORS.get(c, "#94a3b8") for c in bus_carriers} country_colors = { diff --git a/src/pypsa_app/backend/api/routes/networks.py b/src/pypsa_app/backend/api/routes/networks.py index 793c487..3fb81df 100644 --- a/src/pypsa_app/backend/api/routes/networks.py +++ b/src/pypsa_app/backend/api/routes/networks.py @@ -16,7 +16,13 @@ from pypsa_app.backend.api.utils.network_utils import ( delete_network as delete_network_and_file, ) -from pypsa_app.backend.models import Network, Permission, User, Visibility, network_shares +from pypsa_app.backend.models import ( + Network, + Permission, + User, + Visibility, + network_shares, +) from pypsa_app.backend.permissions import has_permission from pypsa_app.backend.schemas.common import MessageResponse from pypsa_app.backend.schemas.network import ( From 46bb6abba9a487fea11559153f91e777ac7af116 Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Tue, 7 Apr 2026 09:55:29 +0200 Subject: [PATCH 16/17] style: fix ruff lint errors (type annotations, line length, imports) Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pypsa_app/backend/api/routes/components.py | 13 ++++++------- src/pypsa_app/backend/api/routes/map.py | 14 ++++++++++---- src/pypsa_app/backend/api/routes/views.py | 9 ++++++--- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/pypsa_app/backend/api/routes/components.py b/src/pypsa_app/backend/api/routes/components.py index ac4a82b..7a9fceb 100644 --- a/src/pypsa_app/backend/api/routes/components.py +++ b/src/pypsa_app/backend/api/routes/components.py @@ -4,6 +4,7 @@ from pathlib import Path import pandas as pd +import pypsa from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.orm import Session @@ -35,15 +36,15 @@ def _load_network(network: Network, *, use_cache: bool = True) -> NetworkService return NetworkService(network.file_path, use_cache=use_cache) -def _find_component(n, component_name: str): +def _find_component(n: pypsa.Network, component_name: str): # noqa: ANN202 """Find a component by name or list_name. Raises 404 if not found.""" for c in n.components: - if c.name == component_name or c.list_name == component_name: + if component_name in (c.name, c.list_name): return c raise HTTPException(404, f"Component '{component_name}' not found in network") -def _get_dynamic_attrs(n, list_name: str) -> list[str]: +def _get_dynamic_attrs(n: pypsa.Network, list_name: str) -> list[str]: """Get non-empty time-varying attribute names for a component.""" dynamic_attr = f"{list_name}_t" if not hasattr(n, dynamic_attr): @@ -51,15 +52,13 @@ def _get_dynamic_attrs(n, list_name: str) -> list[str]: dynamic_store = getattr(n, dynamic_attr) attrs = [] - # Use pandas-based iteration over known DataFrame attributes - # rather than dir() which can expose internal Python attributes for attr_name in dynamic_store: try: attr_val = getattr(dynamic_store, attr_name, None) if isinstance(attr_val, pd.DataFrame) and len(attr_val) > 0: attrs.append(attr_name) - except Exception: - continue + except Exception: # noqa: S112 + logger.debug("Skipping dynamic attr %s.%s", list_name, attr_name) return sorted(attrs) diff --git a/src/pypsa_app/backend/api/routes/map.py b/src/pypsa_app/backend/api/routes/map.py index 418bf40..a72687a 100644 --- a/src/pypsa_app/backend/api/routes/map.py +++ b/src/pypsa_app/backend/api/routes/map.py @@ -4,6 +4,7 @@ from typing import Any import pandas as pd +import pypsa from fastapi import APIRouter, Depends, Query from pypsa_app.backend.api.deps import ( @@ -22,7 +23,7 @@ def _load_network(network: Network) -> NetworkService: return NetworkService(network.file_path, use_cache=True) -def _build_bus_stats(n) -> dict[str, dict]: +def _build_bus_stats(n: pypsa.Network) -> dict[str, dict]: """Pre-compute per-bus generator/load counts and total capacity.""" stats: dict[str, dict] = {} @@ -48,7 +49,9 @@ def _build_bus_stats(n) -> dict[str, dict]: return stats -def _extract_bus_features(n, carriers: list[str] | None = None) -> list[dict]: +def _extract_bus_features( + n: pypsa.Network, carriers: list[str] | None = None +) -> list[dict]: """Extract bus positions as GeoJSON Point features with attached stats.""" buses = n.buses if carriers: @@ -84,7 +87,9 @@ def _extract_bus_features(n, carriers: list[str] | None = None) -> list[dict]: return features -def _extract_branch_features(n, component: str, buses_df: pd.DataFrame) -> list[dict]: +def _extract_branch_features( + n: pypsa.Network, component: str, buses_df: pd.DataFrame +) -> list[dict]: """Extract branch (Line/Link) data as GeoJSON LineString features.""" df = getattr(n, component, None) if df is None or len(df) == 0: @@ -106,7 +111,8 @@ def _extract_branch_features(n, component: str, buses_df: pd.DataFrame) -> list[ break coords.append([float(x), float(y)]) - if len(coords) != 2: + expected_endpoints = 2 + if len(coords) != expected_endpoints: continue props: dict[str, Any] = { diff --git a/src/pypsa_app/backend/api/routes/views.py b/src/pypsa_app/backend/api/routes/views.py index 34298ba..7ef47df 100644 --- a/src/pypsa_app/backend/api/routes/views.py +++ b/src/pypsa_app/backend/api/routes/views.py @@ -102,9 +102,12 @@ def get_view( raise HTTPException(404, "View not found") # Check access: own views or public - if view.user_id != user.id and view.visibility != Visibility.PUBLIC: - if not has_permission(user, Permission.NETWORKS_MANAGE_ALL): - raise HTTPException(404, "View not found") + if ( + view.user_id != user.id + and view.visibility != Visibility.PUBLIC + and not has_permission(user, Permission.NETWORKS_MANAGE_ALL) + ): + raise HTTPException(404, "View not found") return view From 263e1e7b3748015025eafc1ccdf1e143d304a670 Mon Sep 17 00:00:00 2001 From: Mayk Thewessen Date: Tue, 7 Apr 2026 13:48:04 +0200 Subject: [PATCH 17/17] fix: handle NaN category in PyPSA component metadata Bus components have category=NaN (float) which passed through the `or None` check and caused Pydantic serialization to fail with 500. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pypsa_app/backend/api/routes/components.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/pypsa_app/backend/api/routes/components.py b/src/pypsa_app/backend/api/routes/components.py index 7a9fceb..c106baa 100644 --- a/src/pypsa_app/backend/api/routes/components.py +++ b/src/pypsa_app/backend/api/routes/components.py @@ -1,6 +1,7 @@ """API routes for browsing and editing network component data.""" import logging +import math from pathlib import Path import pandas as pd @@ -62,6 +63,16 @@ def _get_dynamic_attrs(n: pypsa.Network, list_name: str) -> list[str]: return sorted(attrs) +def _safe_category(c) -> str | None: # noqa: ANN001, ANN202 + """Extract component category, handling NaN values.""" + cat = getattr(c, "category", None) + if cat is None: + return None + if isinstance(cat, float) and math.isnan(cat): + return None + return str(cat) if cat else None + + @router.get("/{network_id}/components", response_model=ComponentListResponse) def list_components( auth: Authorized[Network] = Depends(require_network("read")), @@ -85,7 +96,7 @@ def list_components( name=c.name, list_name=c.list_name, count=len(c), - category=getattr(c, "category", None) or None, + category=_safe_category(c), attrs=list(static_df.columns), has_dynamic=len(dynamic_attrs) > 0, dynamic_attrs=dynamic_attrs,
@@ -274,7 +428,8 @@ @@ -293,15 +448,32 @@
- {componentData.index[rowIdx]} + {rowIndex} - {formatted} - + handleCellEdit(rowIndex, column, (e.target as HTMLInputElement).value, dtype)} + class="w-full h-full px-2 py-1 text-xs bg-transparent border border-transparent rounded + hover:border-border focus:border-primary focus:outline-none + {edited ? 'border-yellow-400 dark:border-yellow-600' : ''}" + /> + + {formatted} +