From 341e22bf40be44dce2b6f495bca560fddf1d8ae0 Mon Sep 17 00:00:00 2001 From: 10fra <743893+10fra@users.noreply.github.com> Date: Fri, 20 Feb 2026 13:53:40 +0100 Subject: [PATCH] feat: add DE/EU locale support with German data source connectors - Add --locale flag (us|de) to CLI and AgentConfig - Create agent/connectors/ package with shared HTTP helper - Add Lobbyregister Bundestag API connector - Add abgeordnetenwatch.de API v2 connector - Add OffeneRegister bulk JSONL search connector - Add EU Transparency Register bulk CSV/XML connector - Create agent/normalizers/ with German entity normalization (umlauts, legal forms, titles, courts) and composite entity resolver - Wire 4 DE-locale tools into tool_defs, engine dispatch, and WorkspaceTools - Add DE prompt localization (entity resolution + political context sections) --- agent/__main__.py | 8 + agent/builder.py | 2 + agent/config.py | 4 + agent/connectors/__init__.py | 94 ++++++++++ agent/connectors/abgeordnetenwatch.py | 231 ++++++++++++++++++++++++ agent/connectors/eu_transparency.py | 223 ++++++++++++++++++++++++ agent/connectors/lobbyregister.py | 188 ++++++++++++++++++++ agent/connectors/offeneregister.py | 160 +++++++++++++++++ agent/engine.py | 50 +++++- agent/normalizers/__init__.py | 1 + agent/normalizers/entity_resolver.py | 120 +++++++++++++ agent/normalizers/german.py | 241 ++++++++++++++++++++++++++ agent/prompts.py | 71 ++++++++ agent/tool_defs.py | 120 ++++++++++++- agent/tools.py | 85 +++++++++ 15 files changed, 1594 insertions(+), 4 deletions(-) create mode 100644 agent/connectors/__init__.py create mode 100644 agent/connectors/abgeordnetenwatch.py create mode 100644 agent/connectors/eu_transparency.py create mode 100644 agent/connectors/lobbyregister.py create mode 100644 agent/connectors/offeneregister.py create mode 100644 agent/normalizers/__init__.py create mode 100644 agent/normalizers/entity_resolver.py create mode 100644 agent/normalizers/german.py diff --git a/agent/__main__.py 
b/agent/__main__.py index afc736ee..c04f62c6 100644 --- a/agent/__main__.py +++ b/agent/__main__.py @@ -133,6 +133,12 @@ def build_parser() -> argparse.ArgumentParser: action="store_true", help="Censor entity names and workspace path segments in output (UI-only).", ) + parser.add_argument( + "--locale", + choices=["us", "de"], + default=None, + help="Locale for data sources and entity resolution (default: us).", + ) return parser @@ -316,6 +322,8 @@ def _apply_runtime_overrides(cfg: AgentConfig, args: argparse.Namespace, creds: cfg.acceptance_criteria = True if args.demo: cfg.demo = True + if args.locale: + cfg.locale = args.locale def run_plain_repl(ctx: ChatContext) -> None: diff --git a/agent/builder.py b/agent/builder.py index 279c264b..b0f7fb58 100644 --- a/agent/builder.py +++ b/agent/builder.py @@ -146,6 +146,8 @@ def build_engine(cfg: AgentConfig) -> RLMEngine: max_search_hits=cfg.max_search_hits, exa_api_key=cfg.exa_api_key, exa_base_url=cfg.exa_base_url, + locale=cfg.locale, + lobbyregister_api_key=cfg.lobbyregister_api_key, ) try: diff --git a/agent/config.py b/agent/config.py index 36839499..d3a5a4b3 100644 --- a/agent/config.py +++ b/agent/config.py @@ -48,6 +48,8 @@ class AgentConfig: acceptance_criteria: bool = True max_plan_chars: int = 40_000 demo: bool = False + locale: str = "us" + lobbyregister_api_key: str = "" @classmethod def from_env(cls, workspace: str | Path) -> "AgentConfig": @@ -100,4 +102,6 @@ def from_env(cls, workspace: str | Path) -> "AgentConfig": acceptance_criteria=os.getenv("OPENPLANTER_ACCEPTANCE_CRITERIA", "true").strip().lower() in ("1", "true", "yes"), max_plan_chars=int(os.getenv("OPENPLANTER_MAX_PLAN_CHARS", "40000")), demo=os.getenv("OPENPLANTER_DEMO", "").strip().lower() in ("1", "true", "yes"), + locale=os.getenv("OPENPLANTER_LOCALE", "us").strip().lower(), + lobbyregister_api_key=os.getenv("OPENPLANTER_LOBBYREGISTER_API_KEY", ""), ) diff --git 
a/agent/connectors/__init__.py b/agent/connectors/__init__.py new file mode 100644 index 00000000..08aba0a9 --- /dev/null +++ b/agent/connectors/__init__.py @@ -0,0 +1,94 @@ +"""German/EU data source connectors for OpenPlanter. + +Shared HTTP helper following the urllib.request pattern from tools.py. +""" +from __future__ import annotations + +import json +import urllib.error +import urllib.request +import urllib.parse +from typing import Any + + +class ConnectorError(RuntimeError): + """Raised when a connector request fails.""" + + +def _api_request( + url: str, + payload: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + method: str = "GET", + timeout: int = 30, +) -> dict[str, Any]: + """Stdlib HTTP helper (urllib.request). Returns parsed JSON.""" + hdrs = { + "User-Agent": "OpenPlanter/1.0", + "Accept": "application/json", + } + if headers: + hdrs.update(headers) + + data: bytes | None = None + if payload is not None: + data = json.dumps(payload).encode("utf-8") + hdrs["Content-Type"] = "application/json" + + req = urllib.request.Request(url=url, data=data, headers=hdrs, method=method) + + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + raw = resp.read().decode("utf-8", errors="replace") + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8", errors="replace") + raise ConnectorError(f"HTTP {exc.code}: {body[:500]}") from exc + except urllib.error.URLError as exc: + raise ConnectorError(f"Connection error: {exc}") from exc + except OSError as exc: + raise ConnectorError(f"Network error: {exc}") from exc + + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise ConnectorError(f"Non-JSON response: {raw[:500]}") from exc + if not isinstance(parsed, dict): + raise ConnectorError(f"Expected JSON object, got {type(parsed).__name__}") + return parsed + + +def _api_request_list( + url: str, + headers: dict[str, str] | None = None, + timeout: int = 30, +) -> list[dict[str, Any]]: + 
"""Like _api_request but expects a JSON array at top level.""" + hdrs = { + "User-Agent": "OpenPlanter/1.0", + "Accept": "application/json", + } + if headers: + hdrs.update(headers) + + req = urllib.request.Request(url=url, headers=hdrs, method="GET") + + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + raw = resp.read().decode("utf-8", errors="replace") + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8", errors="replace") + raise ConnectorError(f"HTTP {exc.code}: {body[:500]}") from exc + except urllib.error.URLError as exc: + raise ConnectorError(f"Connection error: {exc}") from exc + except OSError as exc: + raise ConnectorError(f"Network error: {exc}") from exc + + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise ConnectorError(f"Non-JSON response: {raw[:500]}") from exc + if isinstance(parsed, dict): + return [parsed] + if not isinstance(parsed, list): + raise ConnectorError(f"Expected JSON array, got {type(parsed).__name__}") + return parsed diff --git a/agent/connectors/abgeordnetenwatch.py b/agent/connectors/abgeordnetenwatch.py new file mode 100644 index 00000000..7d98a61a --- /dev/null +++ b/agent/connectors/abgeordnetenwatch.py @@ -0,0 +1,231 @@ +"""abgeordnetenwatch.de API v2 connector. + +Accesses German MP data, votes, questions, and side income via the +public CC0-licensed REST API. +Base URL: https://www.abgeordnetenwatch.de/api/v2/ +No authentication required. +""" +from __future__ import annotations + +import json +import urllib.parse +from typing import Any + +from . 
import ConnectorError, _api_request + +_BASE_URL = "https://www.abgeordnetenwatch.de/api/v2" + + +def _build_url(endpoint: str, params: dict[str, Any] | None = None) -> str: + """Build API URL with optional query parameters.""" + url = f"{_BASE_URL}/{endpoint.lstrip('/')}" + if params: + # Filter out None values + clean = {k: str(v) for k, v in params.items() if v is not None} + if clean: + url += f"?{urllib.parse.urlencode(clean)}" + return url + + +def _normalize_politician(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize a politician record.""" + return { + "id": raw.get("id"), + "label": raw.get("label", ""), + "first_name": raw.get("first_name", ""), + "last_name": raw.get("last_name", ""), + "birth_name": raw.get("birth_name", ""), + "year_of_birth": raw.get("year_of_birth"), + "party": _extract_party(raw), + "occupation": raw.get("occupation", ""), + "education": raw.get("education", ""), + "url": raw.get("abgeordnetenwatch_url", ""), + "mandates": _extract_mandates(raw), + } + + +def _extract_party(raw: dict[str, Any]) -> str: + """Extract party name from nested party object.""" + party = raw.get("party") + if isinstance(party, dict): + return party.get("label", "") or party.get("full_name", "") + return str(party) if party else "" + + +def _extract_mandates(raw: dict[str, Any]) -> list[dict[str, Any]]: + """Extract mandates from related data.""" + mandates_raw = raw.get("mandates") or raw.get("related_data", {}).get("mandates", {}) + if isinstance(mandates_raw, dict): + mandates_raw = mandates_raw.get("data", []) + if not isinstance(mandates_raw, list): + return [] + result: list[dict[str, Any]] = [] + for m in mandates_raw: + if not isinstance(m, dict): + continue + result.append({ + "id": m.get("id"), + "label": m.get("label", ""), + "parliament_period": _nested_label(m.get("parliament_period")), + "fraction": _nested_label(m.get("fraction")), + "start_date": m.get("start_date", ""), + "end_date": m.get("end_date", ""), + }) + return result + + 
+def _nested_label(obj: Any) -> str: + if isinstance(obj, dict): + return obj.get("label", "") or obj.get("full_name", "") + return str(obj) if obj else "" + + +def _normalize_sidejob(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize a sidejob record.""" + return { + "id": raw.get("id"), + "label": raw.get("label", ""), + "organization": raw.get("sidejob_organization", {}).get("label", "") if isinstance(raw.get("sidejob_organization"), dict) else "", + "category": raw.get("category", ""), + "income_level": raw.get("income_level", ""), + "interval": raw.get("interval", ""), + "created": raw.get("created", ""), + "politician_id": _extract_nested_id(raw.get("mandate", {})), + } + + +def _extract_nested_id(obj: Any) -> int | None: + if isinstance(obj, dict): + pol = obj.get("politician") + if isinstance(pol, dict): + return pol.get("id") + return obj.get("id") + return None + + +def _normalize_vote(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize a vote record.""" + return { + "id": raw.get("id"), + "vote": raw.get("vote", ""), + "reason_no_show": raw.get("reason_no_show", ""), + "mandate": _nested_label(raw.get("mandate")), + "fraction": _nested_label(raw.get("fraction")), + "poll_id": raw.get("poll", {}).get("id") if isinstance(raw.get("poll"), dict) else None, + } + + +def search_politicians( + query: str | None = None, + parliament_period: int | None = None, + party_id: int | None = None, + max_results: int = 20, +) -> str: + """Search politicians on abgeordnetenwatch.""" + params: dict[str, Any] = { + "range_end": min(max_results, 100), + } + if parliament_period is not None: + params["parliament_period"] = parliament_period + if party_id is not None: + params["party"] = party_id + if query: + params["label[cn]"] = query + + url = _build_url("politicians", params) + + try: + data = _api_request(url, timeout=30) + except ConnectorError as exc: + return json.dumps({"error": str(exc), "query": query}) + + results: list[dict[str, Any]] = [] + entries = 
data.get("data", []) + if isinstance(entries, list): + for entry in entries[:max_results]: + if isinstance(entry, dict): + results.append(_normalize_politician(entry)) + + meta = data.get("meta", {}) + return json.dumps({ + "source": "abgeordnetenwatch", + "query": query, + "total_results": meta.get("result", {}).get("total", len(results)) if isinstance(meta, dict) else len(results), + "results": results, + }, ensure_ascii=False, indent=2) + + +def get_politician(politician_id: int) -> str: + """Fetch a single politician with mandates.""" + url = _build_url(f"politicians/{politician_id}", {"related_data": "mandates"}) + + try: + data = _api_request(url, timeout=30) + except ConnectorError as exc: + return json.dumps({"error": str(exc), "politician_id": politician_id}) + + entry = data.get("data", data) + if isinstance(entry, dict): + entry = _normalize_politician(entry) + + return json.dumps({ + "source": "abgeordnetenwatch", + "politician": entry, + }, ensure_ascii=False, indent=2) + + +def get_poll_votes(poll_id: int, max_results: int = 100) -> str: + """Fetch votes for a specific poll.""" + url = _build_url(f"polls/{poll_id}/votes", {"range_end": min(max_results, 500)}) + + try: + data = _api_request(url, timeout=30) + except ConnectorError as exc: + return json.dumps({"error": str(exc), "poll_id": poll_id}) + + results: list[dict[str, Any]] = [] + entries = data.get("data", []) + if isinstance(entries, list): + for entry in entries[:max_results]: + if isinstance(entry, dict): + results.append(_normalize_vote(entry)) + + return json.dumps({ + "source": "abgeordnetenwatch", + "poll_id": poll_id, + "total_votes": len(results), + "votes": results, + }, ensure_ascii=False, indent=2) + + +def search_sidejobs( + politician_id: int | None = None, + max_results: int = 50, +) -> str: + """Search sidejobs (Nebeneinkünfte).""" + params: dict[str, Any] = { + "range_end": min(max_results, 200), + } + if politician_id is not None: + params["politician"] = politician_id + + 
url = _build_url("sidejobs", params) + + try: + data = _api_request(url, timeout=30) + except ConnectorError as exc: + return json.dumps({"error": str(exc), "politician_id": politician_id}) + + results: list[dict[str, Any]] = [] + entries = data.get("data", []) + if isinstance(entries, list): + for entry in entries[:max_results]: + if isinstance(entry, dict): + results.append(_normalize_sidejob(entry)) + + return json.dumps({ + "source": "abgeordnetenwatch", + "politician_id": politician_id, + "total_results": len(results), + "sidejobs": results, + }, ensure_ascii=False, indent=2) diff --git a/agent/connectors/eu_transparency.py b/agent/connectors/eu_transparency.py new file mode 100644 index 00000000..08c0420f --- /dev/null +++ b/agent/connectors/eu_transparency.py @@ -0,0 +1,223 @@ +"""EU Transparency Register connector. + +Searches pre-downloaded bulk CSV data from the EU Transparency Register +(data.europa.eu). Handles both CSV and XML formats. +""" +from __future__ import annotations + +import csv +import json +import io +import re +from pathlib import Path +from typing import Any + +from ..normalizers.german import normalize_company_name, umlauts_to_ascii + + +def _fuzzy_match(query_normalized: str, name: str) -> bool: + """Check if query tokens all appear in the normalized name.""" + target = normalize_company_name(name) + tokens = query_normalized.split() + return all(tok in target for tok in tokens) + + +def _normalize_csv_entry(row: dict[str, str]) -> dict[str, Any]: + """Normalize a CSV row from the EU Transparency Register.""" + # The CSV format varies; handle common column names + name = ( + row.get("Name", "") + or row.get("Organisation name", "") + or row.get("name", "") + or row.get("organisationName", "") + ) + return { + "name": name, + "identification_code": row.get("Identification code", "") or row.get("identificationCode", ""), + "category": ( + row.get("Category", "") + or row.get("Section", "") + or row.get("category", "") + ), + "country": ( 
+ row.get("Country of head office", "") + or row.get("Head office country", "") + or row.get("country", "") + ), + "eu_lobbying_expenditure": ( + row.get("Estimated costs", "") + or row.get("Costs of direct lobbying", "") + or row.get("estimatedCosts", "") + ), + "num_lobbyists": ( + row.get("Number of persons", "") + or row.get("numberOfPersons", "") + ), + "legislative_interests": _split_field( + row.get("Fields of interest", "") + or row.get("fieldsOfInterest", "") + ), + "registration_date": ( + row.get("Registration date", "") + or row.get("registrationDate", "") + ), + "website": row.get("Website", "") or row.get("website", ""), + } + + +def _split_field(value: str) -> list[str]: + """Split a semicolon or comma-delimited field into a list.""" + if not value: + return [] + # Try semicolon first, then comma + if ";" in value: + return [v.strip() for v in value.split(";") if v.strip()] + return [v.strip() for v in value.split(",") if v.strip()] + + +def _parse_xml_entry(text: str) -> dict[str, Any]: + """Minimal XML tag extraction without external dependencies.""" + def _extract_tag(tag: str) -> str: + match = re.search(rf"<{tag}[^>]*>(.*?)</{tag}>", text, re.DOTALL) + return match.group(1).strip() if match else "" + + return { + "name": _extract_tag("name") or _extract_tag("organisationName"), + "identification_code": _extract_tag("identificationCode"), + "category": _extract_tag("category") or _extract_tag("section"), + "country": _extract_tag("country") or _extract_tag("headOfficeCountry"), + "eu_lobbying_expenditure": _extract_tag("estimatedCosts") or _extract_tag("lobbyCosts"), + "num_lobbyists": _extract_tag("numberOfPersons"), + "legislative_interests": _split_field(_extract_tag("fieldsOfInterest")), + "registration_date": _extract_tag("registrationDate"), + "website": _extract_tag("website"), + } + + +def search_eu_transparency( + query: str, + data_path: str, + max_results: int = 20, +) -> str: + """Search EU Transparency Register bulk data for matching 
entries. + + Args: + query: Organization name or search terms. + data_path: Path to the bulk CSV or XML file. + max_results: Maximum number of results to return. + + Returns: + JSON string with search results. + """ + if not query.strip(): + return json.dumps({"error": "Empty query"}) + + path = Path(data_path) + if not path.exists(): + return json.dumps({ + "error": f"Data file not found: {data_path}", + "hint": ( + "Download EU Transparency Register data from " + "https://data.europa.eu/data/datasets/transparency-register" + ), + }) + + suffix = path.suffix.lower() + if suffix == ".xml": + return _search_xml(query, path, max_results) + return _search_csv(query, path, max_results) + + +def _search_csv(query: str, path: Path, max_results: int) -> str: + """Search a CSV format EU Transparency Register file.""" + query_normalized = normalize_company_name(query) + query_ascii = umlauts_to_ascii(query.strip().lower()) + + results: list[dict[str, Any]] = [] + scanned = 0 + + try: + with open(path, "r", encoding="utf-8", errors="replace") as fh: + # Sniff delimiter + sample = fh.read(4096) + fh.seek(0) + try: + dialect = csv.Sniffer().sniff(sample, delimiters=",;\t") + except csv.Error: + dialect = csv.excel + + reader = csv.DictReader(fh, dialect=dialect) + for row in reader: + scanned += 1 + name = ( + row.get("Name", "") + or row.get("Organisation name", "") + or row.get("name", "") + or row.get("organisationName", "") + ) + if not name: + continue + + name_lower = name.lower() + if query_ascii not in name_lower and not _fuzzy_match(query_normalized, name): + continue + + results.append(_normalize_csv_entry(row)) + if len(results) >= max_results: + break + + except OSError as exc: + return json.dumps({"error": f"Failed to read data file: {exc}"}) + + return json.dumps({ + "source": "eu_transparency_register", + "query": query, + "data_path": str(path), + "rows_scanned": scanned, + "total_results": len(results), + "results": results, + }, ensure_ascii=False, indent=2) 
+ + +def _search_xml(query: str, path: Path, max_results: int) -> str: + """Search an XML format EU Transparency Register file.""" + query_normalized = normalize_company_name(query) + query_ascii = umlauts_to_ascii(query.strip().lower()) + + results: list[dict[str, Any]] = [] + scanned = 0 + + try: + content = path.read_text(encoding="utf-8", errors="replace") + except OSError as exc: + return json.dumps({"error": f"Failed to read data file: {exc}"}) + + # Split on record tags (common patterns) + record_pattern = re.compile( + r"<(?:interestRepresentative|entry|organisation)[^>]*>.*?</(?:interestRepresentative|entry|organisation)>", + re.DOTALL, + ) + + for match in record_pattern.finditer(content): + scanned += 1 + block = match.group() + block_lower = block.lower() + + if query_ascii not in block_lower: + continue + + entry = _parse_xml_entry(block) + name = entry.get("name", "") + if name and _fuzzy_match(query_normalized, name): + results.append(entry) + if len(results) >= max_results: + break + + return json.dumps({ + "source": "eu_transparency_register", + "query": query, + "data_path": str(path), + "records_scanned": scanned, + "total_results": len(results), + "results": results, + }, ensure_ascii=False, indent=2) diff --git a/agent/connectors/lobbyregister.py b/agent/connectors/lobbyregister.py new file mode 100644 index 00000000..92822ff3 --- /dev/null +++ b/agent/connectors/lobbyregister.py @@ -0,0 +1,188 @@ +"""Lobbyregister Bundestag API connector. + +Accesses the German federal lobby register via its public REST API. +Documentation: https://www.lobbyregister.bundestag.de/api/ +""" +from __future__ import annotations + +import json +import urllib.parse +from typing import Any + +from . import ConnectorError, _api_request + +_BASE_URL = "https://www.lobbyregister.bundestag.de/api/v1" + +# Valid sort parameters for search. 
+VALID_SORTS = ( + "ALPHABETICAL_ASC", + "ALPHABETICAL_DESC", + "FINANCIALEXPENSES_ASC", + "FINANCIALEXPENSES_DESC", + "DONATIONS_ASC", + "DONATIONS_DESC", + "REGISTRATION_DATE_ASC", + "REGISTRATION_DATE_DESC", +) + +# Fields-of-interest filter codes (Interessenbereiche). +FOI_CODES: dict[str, str] = { + "agriculture": "AGRICULTURE", + "defence": "DEFENCE", + "digital": "DIGITAL", + "economy": "ECONOMY", + "education": "EDUCATION", + "energy": "ENERGY", + "environment": "ENVIRONMENT", + "europe": "EUROPE", + "finance": "FINANCE", + "foreign": "FOREIGN", + "health": "HEALTH", + "home": "HOME", + "housing": "HOUSING", + "justice": "JUSTICE", + "labour": "LABOUR", + "media": "MEDIA", + "science": "SCIENCE", + "social": "SOCIAL", + "traffic": "TRAFFIC", +} + + +def _build_search_url( + query: str, + sort: str = "ALPHABETICAL_ASC", + page: int = 0, + size: int = 20, + foi_filter: str | None = None, + api_key: str = "", +) -> str: + """Build the search URL with query parameters.""" + params: dict[str, str] = { + "q": query, + "sort": sort if sort in VALID_SORTS else "ALPHABETICAL_ASC", + "page": str(max(0, page)), + "size": str(max(1, min(size, 50))), + } + if foi_filter and foi_filter.upper() in FOI_CODES.values(): + params["fieldOfInterest"] = foi_filter.upper() + elif foi_filter and foi_filter.lower() in FOI_CODES: + params["fieldOfInterest"] = FOI_CODES[foi_filter.lower()] + if api_key: + params["apikey"] = api_key + return f"{_BASE_URL}/sucheDetailJson?{urllib.parse.urlencode(params)}" + + +def _normalize_entry(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize a Lobbyregister entry to a standard shape.""" + general = raw.get("general", {}) or {} + financial = raw.get("financialInformation", {}) or {} + activity = raw.get("activity", {}) or {} + + # Extract clients list + clients: list[str] = [] + for client in (activity.get("clients") or []): + if isinstance(client, dict): + clients.append(client.get("name", "")) + elif isinstance(client, str): + 
clients.append(client) + + # Fields of interest + fois: list[str] = [] + for foi in (activity.get("fieldsOfInterest") or []): + if isinstance(foi, dict): + fois.append(foi.get("name", "") or foi.get("code", "")) + elif isinstance(foi, str): + fois.append(foi) + + return { + "name": general.get("name", "") or raw.get("name", ""), + "register_number": raw.get("registerNumber", ""), + "entry_id": raw.get("id", ""), + "org_type": general.get("organizationType", ""), + "legal_form": general.get("legalForm", ""), + "address": _format_address(general.get("address")), + "employees": general.get("numberOfEmployees", ""), + "financial_expenditure": financial.get("financialExpenditure", ""), + "financial_year": financial.get("financialYear", ""), + "donations_flag": bool(financial.get("donations")), + "fields_of_interest": fois, + "clients": clients, + "registration_date": raw.get("registrationDate", ""), + "last_update": raw.get("lastUpdate", ""), + } + + +def _format_address(addr: dict[str, Any] | None) -> str: + """Format an address dict to a single string.""" + if not addr or not isinstance(addr, dict): + return "" + parts = [ + addr.get("street", ""), + addr.get("zipCode", ""), + addr.get("city", ""), + addr.get("country", ""), + ] + return ", ".join(p for p in parts if p) + + +def search_lobbyregister( + query: str, + api_key: str = "", + sort: str = "ALPHABETICAL_ASC", + max_results: int = 20, + foi_filter: str | None = None, +) -> str: + """Search the Lobbyregister and return normalized JSON results.""" + if not query.strip(): + return json.dumps({"error": "Empty query"}) + + url = _build_search_url( + query=query.strip(), + sort=sort, + size=min(max_results, 50), + foi_filter=foi_filter, + api_key=api_key, + ) + + try: + data = _api_request(url, timeout=30) + except ConnectorError as exc: + return json.dumps({"error": str(exc), "query": query}) + + results: list[dict[str, Any]] = [] + entries = data.get("content", []) or data.get("results", []) or [] + if 
isinstance(entries, list): + for entry in entries[:max_results]: + if isinstance(entry, dict): + results.append(_normalize_entry(entry)) + + return json.dumps({ + "source": "lobbyregister_bundestag", + "query": query, + "total_results": data.get("totalElements", len(results)), + "results": results, + }, ensure_ascii=False, indent=2) + + +def get_lobbyregister_entry( + register_number: str, + entry_id: str, + api_key: str = "", +) -> str: + """Fetch a single Lobbyregister entry by register number and entry ID.""" + params: dict[str, str] = {} + if api_key: + params["apikey"] = api_key + qs = f"?{urllib.parse.urlencode(params)}" if params else "" + url = f"{_BASE_URL}/register/{urllib.parse.quote(register_number)}/{urllib.parse.quote(entry_id)}{qs}" + + try: + data = _api_request(url, timeout=30) + except ConnectorError as exc: + return json.dumps({"error": str(exc), "register_number": register_number}) + + return json.dumps({ + "source": "lobbyregister_bundestag", + "entry": _normalize_entry(data), + }, ensure_ascii=False, indent=2) diff --git a/agent/connectors/offeneregister.py b/agent/connectors/offeneregister.py new file mode 100644 index 00000000..f40f6190 --- /dev/null +++ b/agent/connectors/offeneregister.py @@ -0,0 +1,160 @@ +"""OffeneRegister bulk JSONL connector. + +Searches pre-downloaded bulk JSONL data from OffeneRegister (~5.1M German +companies, CC0 license). The agent uses run_shell to download the bulk +file; this connector searches the local copy. 
+ +Data format: one JSON object per line with fields like: + {"company_number":"...", "name":"...", "registered_address":"...", + "officers":[...], "all_attributes":{...}, ...} +""" +from __future__ import annotations + +import json +import os +import re +from pathlib import Path +from typing import Any + +from ..normalizers.german import ( + extract_legal_form, + normalize_company_name, + umlauts_to_ascii, +) + + +def _fuzzy_match(query_normalized: str, name: str) -> bool: + """Check if query tokens all appear in the normalized company name.""" + target = normalize_company_name(name) + tokens = query_normalized.split() + return all(tok in target for tok in tokens) + + +def _normalize_officer(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize an officer record from OffeneRegister.""" + return { + "name": raw.get("name", ""), + "role": raw.get("position", "") or raw.get("role", ""), + "start_date": raw.get("start_date", ""), + "end_date": raw.get("end_date", ""), + } + + +def _normalize_entry(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize an OffeneRegister company entry.""" + name = raw.get("name", "") + officers_raw = raw.get("officers", []) + officers = [_normalize_officer(o) for o in officers_raw if isinstance(o, dict)] + + # Extract register info from all_attributes or top level + attrs = raw.get("all_attributes", {}) or {} + registered_address = raw.get("registered_address", "") or attrs.get("registered_address", "") + + # Parse court and register number from company_number or all_attributes + company_number = raw.get("company_number", "") + court = attrs.get("court", "") or "" + register_type = "" + register_number = company_number + + # Try to split "HRB 12345" pattern + hrb_match = re.match(r"^(HR[AB])\s*(\d+.*)$", company_number, re.IGNORECASE) + if hrb_match: + register_type = hrb_match.group(1).upper() + register_number = hrb_match.group(2).strip() + + return { + "name": name, + "legal_form": extract_legal_form(name) or "", + 
"registered_office": registered_address, + "officers": [o["name"] for o in officers if o.get("name")], + "officers_detail": officers, + "hrb_number": company_number, + "register_type": register_type, + "register_number": register_number, + "court": court, + "status": raw.get("current_status", "") or attrs.get("current_status", ""), + "raw_id": raw.get("company_number", ""), + } + + +def search_offeneregister( + query: str, + data_path: str, + max_results: int = 20, +) -> str: + """Search OffeneRegister bulk JSONL file for matching companies. + + Args: + query: Company name or search terms. + data_path: Path to the bulk JSONL file. + max_results: Maximum number of results to return. + + Returns: + JSON string with search results. + """ + if not query.strip(): + return json.dumps({"error": "Empty query"}) + + path = Path(data_path) + if not path.exists(): + return json.dumps({ + "error": f"Data file not found: {data_path}", + "hint": ( + "Download the OffeneRegister bulk JSONL file first. " + "See https://offeneregister.de/daten/ for download links." 
+ ), + }) + + query_normalized = normalize_company_name(query) + # Also prepare ASCII-folded variant for broader matching + query_ascii = umlauts_to_ascii(query.strip().lower()) + + results: list[dict[str, Any]] = [] + scanned = 0 + errors = 0 + + try: + with open(path, "r", encoding="utf-8", errors="replace") as fh: + for line in fh: + scanned += 1 + line = line.strip() + if not line: + continue + + # Quick pre-filter: check if any query token appears in raw line + line_lower = line.lower() + if query_ascii not in line_lower and not any( + tok in line_lower for tok in query_normalized.split() + ): + continue + + try: + record = json.loads(line) + except json.JSONDecodeError: + errors += 1 + continue + + if not isinstance(record, dict): + continue + + name = record.get("name", "") + if not name: + continue + + if _fuzzy_match(query_normalized, name): + results.append(_normalize_entry(record)) + if len(results) >= max_results: + break + + except OSError as exc: + return json.dumps({"error": f"Failed to read data file: {exc}"}) + + return json.dumps({ + "source": "offeneregister", + "query": query, + "data_path": str(path), + "lines_scanned": scanned, + "parse_errors": errors, + "total_results": len(results), + "results": results, + }, ensure_ascii=False, indent=2) diff --git a/agent/engine.py b/agent/engine.py index 8bd2b65a..531078fe 100644 --- a/agent/engine.py +++ b/agent/engine.py @@ -140,9 +140,11 @@ def __post_init__(self) -> None: self.config.recursive, acceptance_criteria=self.config.acceptance_criteria, demo=self.config.demo, + locale=self.config.locale, ) ac = self.config.acceptance_criteria - tool_defs = get_tool_definitions(include_subtask=self.config.recursive, include_acceptance_criteria=ac) + locale = self.config.locale + tool_defs = get_tool_definitions(include_subtask=self.config.recursive, include_acceptance_criteria=ac, locale=locale) if hasattr(self.model, "tool_defs"): self.model.tool_defs = tool_defs @@ -855,10 +857,10 @@ def _apply_tool_call( 
# Give executor full tools (no subtask, no execute). _saved_defs = None if exec_model and hasattr(exec_model, "tool_defs"): - exec_model.tool_defs = get_tool_definitions(include_subtask=False, include_acceptance_criteria=self.config.acceptance_criteria) + exec_model.tool_defs = get_tool_definitions(include_subtask=False, include_acceptance_criteria=self.config.acceptance_criteria, locale=self.config.locale) elif exec_model is None and hasattr(cur, "tool_defs"): _saved_defs = cur.tool_defs - cur.tool_defs = get_tool_definitions(include_subtask=False, include_acceptance_criteria=self.config.acceptance_criteria) + cur.tool_defs = get_tool_definitions(include_subtask=False, include_acceptance_criteria=self.config.acceptance_criteria, locale=self.config.locale) self._emit(f"[d{depth}] >> executing leaf: {objective}", on_event) child_logger = replay_logger.child(depth, step) if replay_logger else None @@ -895,6 +897,48 @@ def _apply_tool_call( limit = int(args.get("limit", 100) or 100) return False, self._read_artifact(aid, offset, limit) + # -- DE-locale tools -- + if name == "search_lobbyregister": + query = str(args.get("query", "")).strip() + if not query: + return False, "search_lobbyregister requires query" + sort = str(args.get("sort", "ALPHABETICAL_ASC")) + raw_max = args.get("max_results", 20) + max_results = raw_max if isinstance(raw_max, int) else 20 + return False, self.tools.search_lobbyregister(query=query, sort=sort, max_results=max_results) + + if name == "search_abgeordnetenwatch": + query = str(args.get("query", "")).strip() + endpoint = str(args.get("endpoint", "politicians")).strip() + raw_pp = args.get("parliament_period") + parliament_period = int(raw_pp) if raw_pp is not None else None + raw_party = args.get("party_id") + party_id = int(raw_party) if raw_party is not None else None + raw_pol = args.get("politician_id") + politician_id = int(raw_pol) if raw_pol is not None else None + raw_max = args.get("max_results", 20) + max_results = raw_max if 
isinstance(raw_max, int) else 20 + return False, self.tools.search_abgeordnetenwatch( + query=query, endpoint=endpoint, parliament_period=parliament_period, + party_id=party_id, politician_id=politician_id, max_results=max_results, + ) + + if name == "search_offeneregister": + query = str(args.get("query", "")).strip() + if not query: + return False, "search_offeneregister requires query" + raw_max = args.get("max_results", 20) + max_results = raw_max if isinstance(raw_max, int) else 20 + return False, self.tools.search_offeneregister(query=query, max_results=max_results) + + if name == "search_eu_transparency": + query = str(args.get("query", "")).strip() + if not query: + return False, "search_eu_transparency requires query" + raw_max = args.get("max_results", 20) + max_results = raw_max if isinstance(raw_max, int) else 20 + return False, self.tools.search_eu_transparency(query=query, max_results=max_results) + return False, f"Unknown action type: {name}" # ------------------------------------------------------------------ diff --git a/agent/normalizers/__init__.py b/agent/normalizers/__init__.py new file mode 100644 index 00000000..3efa8291 --- /dev/null +++ b/agent/normalizers/__init__.py @@ -0,0 +1 @@ +"""German entity normalizers and resolution utilities.""" diff --git a/agent/normalizers/entity_resolver.py b/agent/normalizers/entity_resolver.py new file mode 100644 index 00000000..8e22c403 --- /dev/null +++ b/agent/normalizers/entity_resolver.py @@ -0,0 +1,120 @@ +"""Composite entity matching for German corporate entities. + +Provides multi-signal matching using register numbers, normalized names, +officer overlap, and address similarity. 
+""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from .german import ( + extract_legal_form, + normalize_company_name, + normalize_court, + normalize_person_name, +) + + +@dataclass(slots=True) +class MatchResult: + """Result of a company match attempt.""" + confidence: float # 0.0–1.0 + match_type: str # "exact_register", "name_form_city", "officer_address", "none" + details: dict[str, Any] + + +def _normalize_register(register: str | None) -> str: + """Normalize a register number for comparison (strip whitespace, uppercase).""" + if not register: + return "" + return register.strip().upper().replace(" ", "") + + +def _officer_overlap(officers_a: list[str], officers_b: list[str]) -> float: + """Return fraction of overlapping officers (by normalized name).""" + if not officers_a or not officers_b: + return 0.0 + set_a = {normalize_person_name(o) for o in officers_a} + set_b = {normalize_person_name(o) for o in officers_b} + intersection = set_a & set_b + union = set_a | set_b + if not union: + return 0.0 + return len(intersection) / len(union) + + +def match_company(a: dict[str, Any], b: dict[str, Any]) -> MatchResult: + """Composite matching of two company records. + + Expected record fields (all optional): + name: str — company name + legal_form: str — legal form (GmbH, AG, etc.) + court: str — register court + register_type: str — HRA or HRB + register_number: str — the register number + city: str — registered office city + officers: list[str] — officer/director names + address: str — full address + + Match tiers: + 1. Exact: court + register_type + register_number → confidence=1.0 + 2. High: normalized name + legal form + city → confidence=0.85 + 3. Medium: overlapping officers + similar city → confidence=0.6 + 4. 
None: no match signals → confidence=0.0 + """ + details: dict[str, Any] = {} + + # --- Tier 1: Exact register match --- + reg_a = _normalize_register(a.get("register_number")) + reg_b = _normalize_register(b.get("register_number")) + if reg_a and reg_b and reg_a == reg_b: + court_a = normalize_court(a.get("court", "")) + court_b = normalize_court(b.get("court", "")) + type_a = (a.get("register_type") or "").upper().strip() + type_b = (b.get("register_type") or "").upper().strip() + if court_a == court_b and type_a == type_b: + details["court"] = court_a + details["register_type"] = type_a + details["register_number"] = reg_a + return MatchResult(confidence=1.0, match_type="exact_register", details=details) + + # --- Tier 2: Normalized name + legal form + city --- + name_a = normalize_company_name(a.get("name", "")) + name_b = normalize_company_name(b.get("name", "")) + + if name_a and name_b and name_a == name_b: + form_a = extract_legal_form(a.get("name", "")) or a.get("legal_form", "") + form_b = extract_legal_form(b.get("name", "")) or b.get("legal_form", "") + city_a = (a.get("city") or "").strip().lower() + city_b = (b.get("city") or "").strip().lower() + + form_match = (form_a or "").lower() == (form_b or "").lower() if (form_a and form_b) else True + city_match = city_a == city_b if (city_a and city_b) else True + + if form_match and city_match and (form_a or city_a): + details["normalized_name"] = name_a + details["legal_form_match"] = form_match + details["city_match"] = city_match + return MatchResult(confidence=0.85, match_type="name_form_city", details=details) + + # --- Tier 3: Officer overlap + city --- + officers_a = a.get("officers", []) + officers_b = b.get("officers", []) + overlap = _officer_overlap(officers_a, officers_b) + + city_a = (a.get("city") or "").strip().lower() + city_b = (b.get("city") or "").strip().lower() + city_match = city_a == city_b if (city_a and city_b) else False + + if overlap >= 0.3 and city_match: + 
details["officer_overlap"] = round(overlap, 3) + details["city"] = city_a + return MatchResult(confidence=0.6, match_type="officer_address", details=details) + + if overlap >= 0.5: + details["officer_overlap"] = round(overlap, 3) + return MatchResult(confidence=0.5, match_type="officer_overlap_only", details=details) + + # --- No match --- + return MatchResult(confidence=0.0, match_type="none", details={}) diff --git a/agent/normalizers/german.py b/agent/normalizers/german.py new file mode 100644 index 00000000..8c206b7c --- /dev/null +++ b/agent/normalizers/german.py @@ -0,0 +1,241 @@ +"""German-specific name, legal form, and court normalization. + +Handles umlauts, legal form extraction, title stripping, and court aliases +for entity resolution across German corporate/political datasets. +""" +from __future__ import annotations + +import re +import unicodedata + +# --------------------------------------------------------------------------- +# Umlaut normalization — bidirectional ä↔ae, ö↔oe, ü↔ue, ß↔ss +# --------------------------------------------------------------------------- + +UMLAUT_TO_ASCII: dict[str, str] = { + "ä": "ae", "ö": "oe", "ü": "ue", "ß": "ss", + "Ä": "Ae", "Ö": "Oe", "Ü": "Ue", +} + +ASCII_TO_UMLAUT: dict[str, str] = { + "ae": "ä", "oe": "ö", "ue": "ü", "ss": "ß", + "Ae": "Ä", "Oe": "Ö", "Ue": "Ü", +} + +_UMLAUT_RE = re.compile(r"[äöüßÄÖÜ]") + + +def umlauts_to_ascii(text: str) -> str: + """Replace umlauts with ASCII digraphs: ä→ae, ö→oe, ü→ue, ß→ss.""" + return _UMLAUT_RE.sub(lambda m: UMLAUT_TO_ASCII.get(m.group(), m.group()), text) + + +def normalize_unicode(text: str) -> str: + """NFC-normalize and strip accents beyond standard umlauts.""" + return unicodedata.normalize("NFC", text) + + +# --------------------------------------------------------------------------- +# Legal forms — canonical mapping +# --------------------------------------------------------------------------- + +# Map of variations → canonical form. 
Order matters: longer patterns first. +LEGAL_FORMS: dict[str, str] = { + "gmbh & co. kgaa": "GmbH & Co. KGaA", + "gmbh & co. kg": "GmbH & Co. KG", + "gmbh & co. ohg": "GmbH & Co. OHG", + "gmbh & co.kg": "GmbH & Co. KG", + "ug (haftungsbeschränkt)": "UG (haftungsbeschränkt)", + "ug (haftungsbeschraenkt)": "UG (haftungsbeschränkt)", + "ug haftungsbeschränkt": "UG (haftungsbeschränkt)", + "kgaa": "KGaA", + "gmbh": "GmbH", + "ag": "AG", + "kg": "KG", + "ohg": "OHG", + "gbr": "GbR", + "e.v.": "e.V.", + "ev": "e.V.", + "eg": "eG", + "e.g.": "eG", + "se": "SE", + "se & co. kgaa": "SE & Co. KGaA", + "ug": "UG (haftungsbeschränkt)", + "partg": "PartG", + "partg mbb": "PartG mbB", + "vvag": "VVaG", + "ewiv": "EWIV", + "stiftung": "Stiftung", +} + +# Sorted by length descending so longer patterns match first. +_LEGAL_FORM_PATTERNS: list[tuple[re.Pattern[str], str]] = [ + (re.compile(r"\b" + re.escape(k) + r"\b", re.IGNORECASE), v) + for k, v in sorted(LEGAL_FORMS.items(), key=lambda kv: -len(kv[0])) +] + + +def extract_legal_form(name: str) -> str | None: + """Pull the legal form from a company name, or None if not recognized.""" + for pattern, canonical in _LEGAL_FORM_PATTERNS: + if pattern.search(name): + return canonical + return None + + +def strip_legal_form(name: str) -> str: + """Remove legal form suffix from company name.""" + for pattern, _ in _LEGAL_FORM_PATTERNS: + name = pattern.sub("", name) + return name.strip().rstrip(",").rstrip("&").strip() + + +# --------------------------------------------------------------------------- +# Title prefixes — common in German names +# --------------------------------------------------------------------------- + +TITLE_PREFIXES: list[str] = [ + "Prof. Dr. Dr. h.c.", + "Prof. Dr. Dr.", + "Prof. Dr.", + "Prof.", + "Dr. Dr.", + "Dr. h.c.", + "Dr. med.", + "Dr. jur.", + "Dr. rer. nat.", + "Dr. phil.", + "Dr. 
ing.", + "Dr.-Ing.", + "Dr.", + "Dipl.-Ing.", + "Dipl.-Kfm.", + "Dipl.-Vw.", +] + +# Nobility particles kept lowercase in German convention. +NOBILITY_PARTICLES: set[str] = { + "von", "zu", "von und zu", "vom", "zum", "zur", + "freiherr", "freifrau", "freiin", + "graf", "gräfin", + "fürst", "fürstin", + "prinz", "prinzessin", + "baron", "baronin", "baroness", + "ritter", +} + +_TITLE_RE = re.compile( + r"^(" + "|".join(re.escape(t) for t in TITLE_PREFIXES) + r")\s+", + re.IGNORECASE, +) + + +def strip_titles(name: str) -> str: + """Remove academic/professional title prefixes from a person name.""" + result = name.strip() + while True: + m = _TITLE_RE.match(result) + if not m: + break + result = result[m.end():].strip() + return result + + +def normalize_person_name(name: str) -> str: + """Normalize a person name: strip titles, normalize umlauts, case-fold.""" + result = strip_titles(name.strip()) + result = umlauts_to_ascii(result) + result = normalize_unicode(result) + # Collapse whitespace + result = re.sub(r"\s+", " ", result).strip() + return result.lower() + + +# --------------------------------------------------------------------------- +# Court aliases +# --------------------------------------------------------------------------- + +# Map of known court abbreviations/aliases → canonical name. 
+COURT_ALIASES: dict[str, str] = { + "ag münchen": "Amtsgericht München", + "amtsgericht münchen": "Amtsgericht München", + "münchen": "Amtsgericht München", + "muenchen": "Amtsgericht München", + "ag berlin charlottenburg": "Amtsgericht Berlin-Charlottenburg", + "amtsgericht berlin-charlottenburg": "Amtsgericht Berlin-Charlottenburg", + "berlin charlottenburg": "Amtsgericht Berlin-Charlottenburg", + "berlin-charlottenburg": "Amtsgericht Berlin-Charlottenburg", + "ag hamburg": "Amtsgericht Hamburg", + "amtsgericht hamburg": "Amtsgericht Hamburg", + "hamburg": "Amtsgericht Hamburg", + "ag frankfurt am main": "Amtsgericht Frankfurt am Main", + "amtsgericht frankfurt am main": "Amtsgericht Frankfurt am Main", + "frankfurt am main": "Amtsgericht Frankfurt am Main", + "frankfurt": "Amtsgericht Frankfurt am Main", + "ag köln": "Amtsgericht Köln", + "amtsgericht köln": "Amtsgericht Köln", + "ag koeln": "Amtsgericht Köln", + "köln": "Amtsgericht Köln", + "koeln": "Amtsgericht Köln", + "ag düsseldorf": "Amtsgericht Düsseldorf", + "amtsgericht düsseldorf": "Amtsgericht Düsseldorf", + "duesseldorf": "Amtsgericht Düsseldorf", + "düsseldorf": "Amtsgericht Düsseldorf", + "ag stuttgart": "Amtsgericht Stuttgart", + "amtsgericht stuttgart": "Amtsgericht Stuttgart", + "stuttgart": "Amtsgericht Stuttgart", + "ag nürnberg": "Amtsgericht Nürnberg", + "amtsgericht nürnberg": "Amtsgericht Nürnberg", + "nürnberg": "Amtsgericht Nürnberg", + "nuernberg": "Amtsgericht Nürnberg", + "ag hannover": "Amtsgericht Hannover", + "amtsgericht hannover": "Amtsgericht Hannover", + "hannover": "Amtsgericht Hannover", + "ag bremen": "Amtsgericht Bremen", + "amtsgericht bremen": "Amtsgericht Bremen", + "bremen": "Amtsgericht Bremen", + "ag leipzig": "Amtsgericht Leipzig", + "amtsgericht leipzig": "Amtsgericht Leipzig", + "leipzig": "Amtsgericht Leipzig", + "ag dresden": "Amtsgericht Dresden", + "amtsgericht dresden": "Amtsgericht Dresden", + "dresden": "Amtsgericht Dresden", +} + + +def 
normalize_court(court: str) -> str: + """Normalize a court name to its canonical form.""" + key = umlauts_to_ascii(court.strip()).lower() + # Try exact match first, then umlaut-folded + canonical = COURT_ALIASES.get(court.strip().lower()) + if canonical: + return canonical + canonical = COURT_ALIASES.get(key) + if canonical: + return canonical + # Fallback: prepend "Amtsgericht" if not already present + if not court.lower().startswith("ag ") and not court.lower().startswith("amtsgericht"): + return court.strip() + return court.strip() + + +# --------------------------------------------------------------------------- +# Company name normalization +# --------------------------------------------------------------------------- + +_PUNCT_RE = re.compile(r"[^\w\s]", re.UNICODE) +_MULTI_SPACE_RE = re.compile(r"\s+") + + +def normalize_company_name(name: str) -> str: + """Normalize a company name for matching. + + Strips legal form, normalizes umlauts to ASCII, removes punctuation, + collapses whitespace, and case-folds. + """ + result = strip_legal_form(name.strip()) + result = umlauts_to_ascii(result) + result = normalize_unicode(result) + result = _PUNCT_RE.sub(" ", result) + result = _MULTI_SPACE_RE.sub(" ", result).strip() + return result.lower() diff --git a/agent/prompts.py b/agent/prompts.py index 4ee76424..20853519 100644 --- a/agent/prompts.py +++ b/agent/prompts.py @@ -334,13 +334,84 @@ """ +DE_ENTITY_RESOLUTION_SECTION = """ + +== GERMAN ENTITY RESOLUTION == +When resolving entities across German datasets, apply these rules: + +Legal forms to recognize and normalize: + GmbH, UG (haftungsbeschränkt), AG, KG, KGaA, OHG, GbR, e.V., eG, + GmbH & Co. KG, GmbH & Co. KGaA, SE, SE & Co. KGaA, PartG, PartG mbB, + Stiftung, VVaG, EWIV + +HRB/HRA numbers: German commercial register numbers are court-specific + and NOT globally unique. Always match on (court + register_type + number) + together, never on register number alone. 
+ +Umlaut handling: Always normalize ä↔ae, ö↔oe, ü↔ue, ß↔ss for matching. + "Müller" must match "Mueller". "Straße" must match "Strasse". + +Person names: Strip academic/professional titles before matching: + Prof., Dr., Dr. h.c., Dipl.-Ing., Dipl.-Kfm. + Strip nobility particles: von, zu, Freiherr, Graf, Fürst, Prinz. + Normalize "Dr. jur. Hans-Peter von Müller" → "hans-peter mueller" for matching. + +Court names: Normalize variants — "AG München" = "Amtsgericht München" + = "München" = "Muenchen" all refer to the same register court. +""" + + +DE_CONTEXT_SECTION = """ + +== GERMAN POLITICAL/LEGAL CONTEXT == +Key structural knowledge for investigating German corporate/political networks: + +Government structure: + - Bundestag (federal parliament), Bundesrat (state chamber), 16 Landtage + - Coalition system with Koalitionsvertrag; Fraktionszwang (party discipline) + - 16 Bundesländer each with own transparency and procurement rules + +Key terminology: + - Interessenvertretung: lobbying / interest representation + - Rechenschaftsbericht: party financial accountability report + - Nebeneinkünfte: side income of parliamentarians + - Karenzzeit: cooling-off period for former officials + - Drehtür (revolving door): movement between government and industry + - Spende vs Sponsoring: donation vs sponsorship (different disclosure rules) + - Vergabeverfahren: public procurement process + - Handelsregister: commercial register (HRA/HRB entries at Amtsgerichte) + - Transparenzregister: beneficial ownership register (restricted access) + - Bundesanzeiger: federal gazette (Jahresabschlüsse / annual reports) + +Available data sources (accessible via tools): + - Lobbyregister Bundestag: registered lobbying organizations, expenditure, clients + - abgeordnetenwatch.de: MP profiles, votes, questions, side income (Nebeneinkünfte) + - OffeneRegister: ~5.1M German company registrations (bulk JSONL, CC0) + - EU Transparency Register: EU-level lobbying organizations + +Restricted sources 
(NOT directly accessible — use web_search/fetch_url): + - Transparenzregister: beneficial ownership (requires registration) + - Bundesanzeiger: annual financial reports (paywalled search) + - handelsregister.de: official register (per-query fees) + - DIP (Dokumentations- und Informationssystem): Bundestag documents +""" + + def build_system_prompt( recursive: bool, acceptance_criteria: bool = False, demo: bool = False, + locale: str = "us", ) -> str: """Assemble the system prompt, including recursion sections only when enabled.""" prompt = SYSTEM_PROMPT_BASE + if locale == "de": + prompt = prompt.replace( + "LLC, Inc, Corp, Ltd", + "GmbH, UG, AG, KG, KGaA, e.V., GmbH & Co. KG, SE", + ) + prompt += DE_ENTITY_RESOLUTION_SECTION + prompt += DE_CONTEXT_SECTION if recursive: prompt += RECURSIVE_SECTION if acceptance_criteria: diff --git a/agent/tool_defs.py b/agent/tool_defs.py index 949a0925..7ed69b00 100644 --- a/agent/tool_defs.py +++ b/agent/tool_defs.py @@ -403,6 +403,119 @@ ] +DE_TOOL_DEFINITIONS: list[dict[str, Any]] = [ + { + "name": "search_lobbyregister", + "description": ( + "Search the German Bundestag Lobbyregister for lobbying organizations. " + "Returns registrations with financial expenditure, clients, and fields of interest." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Organization name or search terms.", + }, + "sort": { + "type": "string", + "description": "Sort order: ALPHABETICAL_ASC, FINANCIALEXPENSES_DESC, DONATIONS_DESC, etc.", + }, + "max_results": { + "type": "integer", + "description": "Maximum results to return (1-50, default 20).", + }, + }, + "required": ["query"], + "additionalProperties": False, + }, + }, + { + "name": "search_abgeordnetenwatch", + "description": ( + "Search abgeordnetenwatch.de for German MP data, votes, and side income. " + "Endpoint selects which resource to query." 
+ ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search terms (person name, topic, etc.).", + }, + "endpoint": { + "type": "string", + "enum": ["politicians", "polls", "sidejobs"], + "description": "Which resource to search.", + }, + "parliament_period": { + "type": "integer", + "description": "Filter by parliament period ID.", + }, + "party_id": { + "type": "integer", + "description": "Filter by party ID.", + }, + "politician_id": { + "type": "integer", + "description": "Politician ID for sidejobs endpoint.", + }, + "max_results": { + "type": "integer", + "description": "Maximum results (1-100, default 20).", + }, + }, + "required": ["query", "endpoint"], + "additionalProperties": False, + }, + }, + { + "name": "search_offeneregister", + "description": ( + "Search OffeneRegister bulk data for German company registrations. " + "Requires the bulk JSONL file to be downloaded to the workspace first." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Company name or search terms.", + }, + "max_results": { + "type": "integer", + "description": "Maximum results (1-100, default 20).", + }, + }, + "required": ["query"], + "additionalProperties": False, + }, + }, + { + "name": "search_eu_transparency", + "description": ( + "Search the EU Transparency Register bulk data for lobbying organizations " + "active in EU institutions. Requires the CSV/XML file to be downloaded first." 
+ ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Organization name or search terms.", + }, + "max_results": { + "type": "integer", + "description": "Maximum results (1-100, default 20).", + }, + }, + "required": ["query"], + "additionalProperties": False, + }, + }, +] + + _ARTIFACT_TOOLS = {"list_artifacts", "read_artifact"} _DELEGATION_TOOLS = {"subtask", "execute", "list_artifacts", "read_artifact"} @@ -426,13 +539,15 @@ def get_tool_definitions( include_subtask: bool = True, include_artifacts: bool = False, include_acceptance_criteria: bool = False, + locale: str = "us", ) -> list[dict[str, Any]]: - """Return tool definitions based on mode. + """Return tool definitions based on mode and locale. - ``include_subtask=True`` (normal recursive) → everything except execute, artifact tools. - ``include_subtask=False`` (flat / executor) → no subtask, no execute, no artifact tools. - ``include_artifacts=True`` → add list_artifacts + read_artifact. - ``include_acceptance_criteria=False`` → strip acceptance_criteria from schemas. + - ``locale="de"`` → append German/EU data source tools. 
""" if include_subtask: defs = [d for d in TOOL_DEFINITIONS if d["name"] not in ("execute",) and d["name"] not in _ARTIFACT_TOOLS] @@ -442,6 +557,9 @@ def get_tool_definitions( if include_artifacts: defs += [d for d in TOOL_DEFINITIONS if d["name"] in _ARTIFACT_TOOLS] + if locale == "de": + defs += DE_TOOL_DEFINITIONS + if not include_acceptance_criteria: defs = _strip_acceptance_criteria(defs) return defs diff --git a/agent/tools.py b/agent/tools.py index bb015c76..6e88a3e9 100644 --- a/agent/tools.py +++ b/agent/tools.py @@ -55,6 +55,8 @@ class WorkspaceTools: max_search_hits: int = 200 exa_api_key: str | None = None exa_base_url: str = "https://api.exa.ai" + locale: str = "us" + lobbyregister_api_key: str = "" def __post_init__(self) -> None: self.root = self.root.expanduser().resolve() @@ -842,3 +844,86 @@ def fetch_url(self, urls: list[str]) -> str: "total": len(pages), } return self._clip(json.dumps(output, indent=2, ensure_ascii=True), self.max_file_chars) + + # ------------------------------------------------------------------ + # DE-locale connector wrappers + # ------------------------------------------------------------------ + + def search_lobbyregister( + self, + query: str, + sort: str = "ALPHABETICAL_ASC", + max_results: int = 20, + ) -> str: + from .connectors.lobbyregister import search_lobbyregister + return self._clip( + search_lobbyregister(query, api_key=self.lobbyregister_api_key, sort=sort, max_results=max_results), + self.max_file_chars, + ) + + def search_abgeordnetenwatch( + self, + query: str, + endpoint: str = "politicians", + parliament_period: int | None = None, + party_id: int | None = None, + politician_id: int | None = None, + max_results: int = 20, + ) -> str: + from .connectors.abgeordnetenwatch import ( + get_poll_votes, + search_politicians, + search_sidejobs, + ) + if endpoint == "politicians": + return self._clip( + search_politicians(query=query, parliament_period=parliament_period, party_id=party_id, 
max_results=max_results), + self.max_file_chars, + ) + if endpoint == "polls": + poll_id = parliament_period or 0 + return self._clip(get_poll_votes(poll_id=poll_id, max_results=max_results), self.max_file_chars) + if endpoint == "sidejobs": + return self._clip( + search_sidejobs(politician_id=politician_id, max_results=max_results), + self.max_file_chars, + ) + return f"Unknown endpoint: {endpoint}. Use: politicians, polls, sidejobs" + + def search_offeneregister(self, query: str, max_results: int = 20) -> str: + from .connectors.offeneregister import search_offeneregister + data_path = self._find_bulk_data("offeneregister*.jsonl", "offeneregister") + if not data_path: + return json.dumps({ + "error": "OffeneRegister bulk JSONL file not found in workspace.", + "hint": "Download from https://offeneregister.de/daten/ using run_shell, then retry.", + }) + return self._clip(search_offeneregister(query, str(data_path), max_results), self.max_file_chars) + + def search_eu_transparency(self, query: str, max_results: int = 20) -> str: + from .connectors.eu_transparency import search_eu_transparency + data_path = self._find_bulk_data("*transparency*", "eu_transparency") + if not data_path: + return json.dumps({ + "error": "EU Transparency Register data file not found in workspace.", + "hint": "Download from https://data.europa.eu/data/datasets/transparency-register using run_shell, then retry.", + }) + return self._clip(search_eu_transparency(query, str(data_path), max_results), self.max_file_chars) + + def _find_bulk_data(self, glob_pattern: str, subdir: str) -> Path | None: + """Locate a bulk data file in workspace or a data subdirectory.""" + import fnmatch as _fnmatch + for p in self.root.iterdir(): + if p.is_file() and _fnmatch.fnmatch(p.name.lower(), glob_pattern): + return p + data_dir = self.root / "data" + if data_dir.is_dir(): + for p in data_dir.iterdir(): + if p.is_file() and _fnmatch.fnmatch(p.name.lower(), glob_pattern): + return p + sub = self.root / subdir 
+ if sub.is_dir(): + for p in sub.iterdir(): + if p.is_file(): + return p + return None