11# Copyright (c) NiceBots
22# SPDX-License-Identifier: MIT
33
4+ import ssl
45from collections import defaultdict
56from logging import getLogger
7+ from pathlib import Path
68from typing import Any
9+ from urllib .parse import parse_qs , urlparse
710
811import aerich
912from tortoise import Tortoise
1316logger = getLogger ("bot" ).getChild ("database" )
1417
1518
19+ def create_ssl_context () -> ssl .SSLContext :
20+ """Create SSL context for Database connection."""
21+ cert_paths = [
22+ Path (__file__ ).parent .parent .parent / ".postgres" / "root.crt" ,
23+ Path .cwd () / ".postgres" / "root.crt" ,
24+ ]
25+
26+ cert_path = None
27+ for path in cert_paths :
28+ if path .exists ():
29+ cert_path = path
30+ break
31+
32+ if cert_path is None :
33+ logger .error ("Database certificate not found!" )
34+ raise FileNotFoundError ("Database SSL certificate not found" )
35+
36+ logger .debug (f"Loading SSL certificate from: { cert_path } " )
37+
38+ with open (cert_path ) as f :
39+ cert_content = f .read ()
40+ logger .debug (f"Certificate loaded, length: { len (cert_content )} bytes" )
41+
42+ ssl_context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
43+ ssl_context .check_hostname = True
44+ ssl_context .verify_mode = ssl .CERT_REQUIRED
45+
46+ ssl_context .load_verify_locations (cafile = str (cert_path ))
47+
48+ logger .debug (
49+ f"SSL context created: verify_mode={ ssl_context .verify_mode } , check_hostname={ ssl_context .check_hostname } "
50+ )
51+ logger .debug ("SSL context will ONLY trust the custom CA cert (no system CAs)" )
52+
53+ return ssl_context
54+
55+
56+ def parse_postgres_url (url : str ) -> dict [str , Any ]:
57+ """Parse postgres URL into credentials dict."""
58+ parsed = urlparse (url )
59+
60+ query_params = parse_qs (parsed .query )
61+
62+ credentials = {
63+ "host" : parsed .hostname ,
64+ "port" : parsed .port or 5432 ,
65+ "user" : parsed .username ,
66+ "password" : parsed .password ,
67+ "database" : parsed .path .lstrip ("/" ),
68+ }
69+
70+ for key , value in query_params .items ():
71+ if key not in ["ssl" , "sslmode" ]: # Skip SSL params
72+ credentials [key ] = value [0 ] if len (value ) == 1 else value
73+
74+ return credentials
75+
76+
1677def apply_params (uri : str , params : dict [str , Any ] | None ) -> str :
1778 if params is None :
1879 return uri
1980
2081 first : bool = True
2182 for param , value in params .items ():
83+ if param == "ssl" :
84+ continue
2285 if value is not None :
2386 uri += f"{ '?' if first else '&' } { param } ={ value } "
2487 first = False
@@ -42,26 +105,50 @@ def get_url_apps_mapping() -> dict[str, list[str]]:
42105 return mapping
43106
44107
45- def parse_url_apps_mapping (url_apps_mapping : dict [str , list [str ]]) -> tuple [dict [str , str ], dict [str , str ]]:
108+ def parse_url_apps_mapping (url_apps_mapping : dict [str , list [str ]]) -> tuple [dict [str , str ], dict [str , dict [ str , Any ] ]]:
46109 app_connection : dict [str , str ] = {}
47- connection_url : dict [str , str ] = {}
110+ connection_config : dict [str , dict [str , Any ]] = {}
111+
112+ ssl_context = create_ssl_context () if config .db .params and config .db .params .get ("ssl" ) else None
48113
49114 for i , (url , apps ) in enumerate (url_apps_mapping .items ()):
50115 connection_name = f"connection_{ i } "
51- connection_url [connection_name ] = url
116+
117+ credentials = parse_postgres_url (url )
118+ if ssl_context :
119+ credentials ["ssl" ] = ssl_context
120+
121+ logger .debug (f"Connection { connection_name } credentials keys: { list (credentials .keys ())} " )
122+ logger .debug (f"SSL in credentials: { 'ssl' in credentials } " )
123+
124+ connection_config [connection_name ] = {"engine" : "tortoise.backends.asyncpg" , "credentials" : credentials }
125+
52126 for app in apps :
53127 app_connection [app ] = connection_name
54128
55129 app_connection ["models" ] = "default"
56- connection_url ["default" ] = apply_params (config .db .url , config .db .params )
130+ default_url = apply_params (config .db .url , config .db .params )
131+
132+ logger .debug (f"Default URL (sanitized): { default_url .split ('@' )[1 ] if '@' in default_url else 'N/A' } " )
133+
134+ credentials = parse_postgres_url (default_url )
135+ if ssl_context :
136+ credentials ["ssl" ] = ssl_context
137+
138+ logger .debug (f"Default connection credentials keys: { list (credentials .keys ())} " )
139+ logger .debug (f"SSL in credentials: { 'ssl' in credentials } " )
140+ if "ssl" in credentials :
141+ logger .debug (f"SSL value type: { type (credentials ['ssl' ])} " )
142+
143+ connection_config ["default" ] = {"engine" : "tortoise.backends.asyncpg" , "credentials" : credentials }
57144
58- return app_connection , connection_url
145+ return app_connection , connection_config
59146
60147
61148APP_CONNECTION_MAPPING : dict [str , str ]
62- CONNECTION_URL_MAPPING : dict [str , str ]
149+ CONNECTION_CONFIG_MAPPING : dict [str , dict [ str , Any ] ]
63150
64- APP_CONNECTION_MAPPING , CONNECTION_URL_MAPPING = parse_url_apps_mapping (get_url_apps_mapping ()) # pyright: ignore[reportConstantRedefinition]
151+ APP_CONNECTION_MAPPING , CONNECTION_CONFIG_MAPPING = parse_url_apps_mapping (get_url_apps_mapping ())
65152
66153
67154def get_apps () -> dict [str , dict [str , list [str ] | str ]]:
@@ -80,7 +167,7 @@ def get_apps() -> dict[str, dict[str, list[str] | str]]:
80167
81168
82169TORTOISE_ORM = {
83- "connections" : CONNECTION_URL_MAPPING ,
170+ "connections" : CONNECTION_CONFIG_MAPPING ,
84171 "apps" : get_apps (),
85172}
86173
@@ -93,7 +180,7 @@ async def init() -> None:
93180 )
94181 await command .init ()
95182 migrated = await command .upgrade (run_in_transaction = True )
96- logger .success (f"Successfully migrated { migrated } migrations" ) # pyright: ignore [reportAttributeAccessIssue]
183+ logger .success (f"Successfully migrated { migrated } migrations" )
97184 await Tortoise .init (config = TORTOISE_ORM )
98185
99186
0 commit comments