Skip to content

Commit a5cb285

Browse files
committed
✨ Add SSL support for database connections and refactor connection configuration handling
1 parent e2e190e commit a5cb285

File tree

1 file changed

+96
-9
lines changed

1 file changed

+96
-9
lines changed

src/database/config/__init__.py

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# Copyright (c) NiceBots
22
# SPDX-License-Identifier: MIT
33

4+
import ssl
45
from collections import defaultdict
56
from logging import getLogger
7+
from pathlib import Path
68
from typing import Any
9+
from urllib.parse import parse_qs, urlparse
710

811
import aerich
912
from tortoise import Tortoise
@@ -13,12 +16,72 @@
1316
logger = getLogger("bot").getChild("database")
1417

1518

19+
def create_ssl_context() -> ssl.SSLContext:
20+
"""Create SSL context for Database connection."""
21+
cert_paths = [
22+
Path(__file__).parent.parent.parent / ".postgres" / "root.crt",
23+
Path.cwd() / ".postgres" / "root.crt",
24+
]
25+
26+
cert_path = None
27+
for path in cert_paths:
28+
if path.exists():
29+
cert_path = path
30+
break
31+
32+
if cert_path is None:
33+
logger.error("Database certificate not found!")
34+
raise FileNotFoundError("Database SSL certificate not found")
35+
36+
logger.debug(f"Loading SSL certificate from: {cert_path}")
37+
38+
with open(cert_path) as f:
39+
cert_content = f.read()
40+
logger.debug(f"Certificate loaded, length: {len(cert_content)} bytes")
41+
42+
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
43+
ssl_context.check_hostname = True
44+
ssl_context.verify_mode = ssl.CERT_REQUIRED
45+
46+
ssl_context.load_verify_locations(cafile=str(cert_path))
47+
48+
logger.debug(
49+
f"SSL context created: verify_mode={ssl_context.verify_mode}, check_hostname={ssl_context.check_hostname}"
50+
)
51+
logger.debug("SSL context will ONLY trust the custom CA cert (no system CAs)")
52+
53+
return ssl_context
54+
55+
56+
def parse_postgres_url(url: str) -> dict[str, Any]:
57+
"""Parse postgres URL into credentials dict."""
58+
parsed = urlparse(url)
59+
60+
query_params = parse_qs(parsed.query)
61+
62+
credentials = {
63+
"host": parsed.hostname,
64+
"port": parsed.port or 5432,
65+
"user": parsed.username,
66+
"password": parsed.password,
67+
"database": parsed.path.lstrip("/"),
68+
}
69+
70+
for key, value in query_params.items():
71+
if key not in ["ssl", "sslmode"]: # Skip SSL params
72+
credentials[key] = value[0] if len(value) == 1 else value
73+
74+
return credentials
75+
76+
1677
def apply_params(uri: str, params: dict[str, Any] | None) -> str:
1778
if params is None:
1879
return uri
1980

2081
first: bool = True
2182
for param, value in params.items():
83+
if param == "ssl":
84+
continue
2285
if value is not None:
2386
uri += f"{'?' if first else '&'}{param}={value}"
2487
first = False
@@ -42,26 +105,50 @@ def get_url_apps_mapping() -> dict[str, list[str]]:
42105
return mapping
43106

44107

45-
def parse_url_apps_mapping(url_apps_mapping: dict[str, list[str]]) -> tuple[dict[str, str], dict[str, str]]:
108+
def parse_url_apps_mapping(url_apps_mapping: dict[str, list[str]]) -> tuple[dict[str, str], dict[str, dict[str, Any]]]:
46109
app_connection: dict[str, str] = {}
47-
connection_url: dict[str, str] = {}
110+
connection_config: dict[str, dict[str, Any]] = {}
111+
112+
ssl_context = create_ssl_context() if config.db.params and config.db.params.get("ssl") else None
48113

49114
for i, (url, apps) in enumerate(url_apps_mapping.items()):
50115
connection_name = f"connection_{i}"
51-
connection_url[connection_name] = url
116+
117+
credentials = parse_postgres_url(url)
118+
if ssl_context:
119+
credentials["ssl"] = ssl_context
120+
121+
logger.debug(f"Connection {connection_name} credentials keys: {list(credentials.keys())}")
122+
logger.debug(f"SSL in credentials: {'ssl' in credentials}")
123+
124+
connection_config[connection_name] = {"engine": "tortoise.backends.asyncpg", "credentials": credentials}
125+
52126
for app in apps:
53127
app_connection[app] = connection_name
54128

55129
app_connection["models"] = "default"
56-
connection_url["default"] = apply_params(config.db.url, config.db.params)
130+
default_url = apply_params(config.db.url, config.db.params)
131+
132+
logger.debug(f"Default URL (sanitized): {default_url.split('@')[1] if '@' in default_url else 'N/A'}")
133+
134+
credentials = parse_postgres_url(default_url)
135+
if ssl_context:
136+
credentials["ssl"] = ssl_context
137+
138+
logger.debug(f"Default connection credentials keys: {list(credentials.keys())}")
139+
logger.debug(f"SSL in credentials: {'ssl' in credentials}")
140+
if "ssl" in credentials:
141+
logger.debug(f"SSL value type: {type(credentials['ssl'])}")
142+
143+
connection_config["default"] = {"engine": "tortoise.backends.asyncpg", "credentials": credentials}
57144

58-
return app_connection, connection_url
145+
return app_connection, connection_config
59146

60147

61148
APP_CONNECTION_MAPPING: dict[str, str]
62-
CONNECTION_URL_MAPPING: dict[str, str]
149+
CONNECTION_CONFIG_MAPPING: dict[str, dict[str, Any]]
63150

64-
APP_CONNECTION_MAPPING, CONNECTION_URL_MAPPING = parse_url_apps_mapping(get_url_apps_mapping()) # pyright: ignore[reportConstantRedefinition]
151+
APP_CONNECTION_MAPPING, CONNECTION_CONFIG_MAPPING = parse_url_apps_mapping(get_url_apps_mapping())
65152

66153

67154
def get_apps() -> dict[str, dict[str, list[str] | str]]:
@@ -80,7 +167,7 @@ def get_apps() -> dict[str, dict[str, list[str] | str]]:
80167

81168

82169
TORTOISE_ORM = {
83-
"connections": CONNECTION_URL_MAPPING,
170+
"connections": CONNECTION_CONFIG_MAPPING,
84171
"apps": get_apps(),
85172
}
86173

@@ -93,7 +180,7 @@ async def init() -> None:
93180
)
94181
await command.init()
95182
migrated = await command.upgrade(run_in_transaction=True)
96-
logger.success(f"Successfully migrated {migrated} migrations") # pyright: ignore [reportAttributeAccessIssue]
183+
logger.success(f"Successfully migrated {migrated} migrations")
97184
await Tortoise.init(config=TORTOISE_ORM)
98185

99186

0 commit comments

Comments
 (0)