Skip to content

Commit 5f2c754

Browse files
committed
feat: enforce use tls
1 parent 5ba1aee commit 5f2c754

3 files changed

Lines changed: 60 additions & 15 deletions

File tree

README.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- 🛡️ **类型安全**: 完整的类型注解和运行时检查
1111
-**异步优先**: 全面支持 asyncio 和异步操作
1212
- 🔧 **高度可配置**: 自定义用量计算和元数据提取
13+
- 🔐 **默认 TLS 加密**: 默认使用 TLS 连接确保通信安全
1314

1415
## 前置要求
1516

@@ -35,17 +36,39 @@ uv sync
3536
uv sync --extra dev --extra test
3637
```
3738

39+
## 安全性
40+
41+
### TLS 加密连接
42+
43+
SDK 默认使用 TLS 加密连接
44+
45+
- **传输加密**: 所有 MQTT 通信都通过 TLS 1.2+ 加密
46+
- **自动连接**: 客户端自动使用安全连接
47+
48+
### 连接状态监控
49+
50+
```python
51+
# 检查连接状态
52+
if client.is_connected():
53+
print("已连接到 MQTT 代理")
54+
else:
55+
print("未连接")
56+
57+
# 手动连接(如果需要)
58+
await client.connect()
59+
```
60+
3861
## 快速开始
3962

4063
### 1. 初始化全局单例
4164

4265
```python
4366
from billing_sdk import BillingClient
4467

45-
# 初始化全局单例(整个应用只需要初始化一次,会自动连接 MQTT)
68+
# 初始化全局单例(整个应用只需要初始化一次,会自动通过 TLS 连接 MQTT)
4669
client = BillingClient(
4770
broker_host="localhost",
48-
broker_port=1883,
71+
broker_port=8883, # TLS 默认端口
4972
username="your_username",
5073
password="your_password"
5174
)

src/billing_sdk/client.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import logging
4+
import ssl
45
import time
56
from collections.abc import Callable
67
from dataclasses import dataclass
@@ -36,7 +37,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "BillingClient":
3637
def __init__(
3738
self,
3839
broker_host: str,
39-
broker_port: int = 1883,
40+
broker_port: int = 8883, # TLS 默认端口
4041
username: str | None = None,
4142
password: str | None = None,
4243
logger: logging.Logger | None = None,
@@ -49,6 +50,7 @@ def __init__(
4950
self.broker_port = broker_port
5051
self.username = username
5152
self.password = password
53+
5254
self._client: AsyncMQTTClient | None = None
5355
self._is_connected = False
5456
# 用于缓存有效的 API keys,从 MQTT 推送中动态更新
@@ -65,6 +67,16 @@ def __init__(
6567
# 自动连接 MQTT
6668
self._auto_connect()
6769

70+
def _create_tls_context(self) -> ssl.SSLContext:
71+
"""创建默认的 TLS SSL 上下文"""
72+
# 创建 SSL 上下文,默认忽略证书验证
73+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
74+
context.minimum_version = ssl.TLSVersion.TLSv1_2
75+
context.check_hostname = False # MQTT 通常不使用主机名验证
76+
context.verify_mode = ssl.CERT_NONE # 忽略证书校验
77+
78+
return context
79+
6880
@classmethod
6981
def get_instance(cls) -> "BillingClient":
7082
"""获取单例实例"""
@@ -92,24 +104,34 @@ def is_connected(self) -> bool:
92104
return self._is_connected
93105

94106
async def connect(self) -> None:
95-
"""连接到 MQTT 代理"""
107+
"""连接到 MQTT 代理(默认使用 TLS)"""
96108
async with self._lock:
97109
if self._is_connected:
98110
self._logger.info("BillingClient 已经连接,跳过重复连接")
99111
return
100112

101113
try:
102-
self._client = AsyncMQTTClient(
103-
hostname=self.broker_host,
104-
port=self.broker_port,
105-
username=self.username,
106-
password=self.password,
114+
# 创建 TLS 上下文
115+
tls_context = self._create_tls_context()
116+
117+
# 配置 MQTT 客户端参数,默认使用 TLS
118+
client_kwargs = {
119+
"hostname": self.broker_host,
120+
"port": self.broker_port,
121+
"username": self.username,
122+
"password": self.password,
123+
"tls_context": tls_context,
124+
}
125+
126+
self._logger.info(
127+
f"使用 TLS 连接到 MQTT 代理 {self.broker_host}:{self.broker_port}"
107128
)
129+
130+
self._client = AsyncMQTTClient(**client_kwargs)
108131
await self._client.connect()
109132
self._is_connected = True
110-
self._logger.info(
111-
f"已连接到 MQTT 代理 {self.broker_host}:{self.broker_port}"
112-
)
133+
134+
self._logger.info("已通过 TLS 连接到 MQTT 代理")
113135

114136
# 订阅 Key 状态更新
115137
await self._client.subscribe("key-status-updates")

tests/test_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ async def test_singleton_pattern(self):
1919
"""测试单例模式"""
2020
# 第一次初始化
2121
with patch("asyncio.create_task"):
22-
client1 = BillingClient("localhost", 1883)
22+
client1 = BillingClient("localhost", 8883)
2323

2424
# 第二次获取应该返回同一个实例
2525
with patch("asyncio.create_task"):
@@ -29,7 +29,7 @@ async def test_singleton_pattern(self):
2929
assert client1 is client2
3030
# 配置应该是第一次的配置
3131
assert client1.broker_host == "localhost"
32-
assert client1.broker_port == 1883
32+
assert client1.broker_port == 8883
3333

3434
@pytest.mark.asyncio
3535
async def test_get_instance_before_initialization(self):
@@ -41,7 +41,7 @@ async def test_get_instance_before_initialization(self):
4141
async def test_get_instance_after_initialization(self):
4242
"""测试初始化后获取实例"""
4343
with patch("asyncio.create_task"):
44-
original = BillingClient("localhost", 1883)
44+
original = BillingClient("localhost", 8883)
4545

4646
instance = BillingClient.get_instance()
4747
assert instance is original

0 commit comments

Comments
 (0)