Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion astrbot/cli/commands/cmd_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def set_config(key: str, value: str):

@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str = None):
def get_config(key: str | None = None):
"""获取配置项的值,不提供key则显示所有可配置项"""
config = _load_config()

Expand Down
3 changes: 2 additions & 1 deletion astrbot/cli/commands/cmd_init.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
from pathlib import Path

import click
from filelock import FileLock, Timeout

from ..utils import check_dashboard, get_astrbot_root


async def initialize_astrbot(astrbot_root) -> None:
async def initialize_astrbot(astrbot_root: Path) -> None:
"""执行 AstrBot 初始化逻辑"""
dot_astrbot = astrbot_root / ".astrbot"

Expand Down
8 changes: 4 additions & 4 deletions astrbot/cli/utils/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ def manage_plugin(
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")

# 备份现有插件
if is_update and backup_path.exists():
if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
if is_update:
if is_update and backup_path is not None:
shutil.copytree(target_path, backup_path)

try:
Expand All @@ -233,13 +233,13 @@ def manage_plugin(
get_git_repo(repo_url, target_path, proxy)

# 更新成功,删除备份
if is_update and backup_path.exists():
if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
except Exception as e:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
if is_update and backup_path.exists():
if is_update and backup_path is not None and backup_path.exists():
shutil.move(backup_path, target_path)
raise click.ClickException(
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
Expand Down
9 changes: 6 additions & 3 deletions astrbot/cli/utils/version_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ def split_version(version):
return -1
if isinstance(p1, str) and isinstance(p2, int):
return 1
if (isinstance(p1, int) and isinstance(p2, int)) or (
isinstance(p1, str) and isinstance(p2, str)
):
if isinstance(p1, int) and isinstance(p2, int):
if p1 > p2:
return 1
if p1 < p2:
return -1
elif isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
if p1 < p2:
Expand Down
4 changes: 2 additions & 2 deletions astrbot/core/config/astrbot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict = None,
schema: dict | None = None,
):
super().__init__()

Expand Down Expand Up @@ -142,7 +142,7 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""):

return has_new

def save_config(self, replace_config: dict = None):
def save_config(self, replace_config: dict | None = None):
"""将配置写入文件

如果传入 replace_config,则将配置替换为 replace_config
Expand Down
10 changes: 5 additions & 5 deletions astrbot/core/db/migration/sqlite_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,11 @@ def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: list[str] = None,
message_types: list[str] = None,
search_query: str = None,
exclude_ids: list[str] = None,
exclude_platforms: list[str] = None,
platforms: list[str] | None = None,
message_types: list[str] | None = None,
search_query: str | None = None,
exclude_ids: list[str] | None = None,
exclude_platforms: list[str] | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
Expand Down
6 changes: 3 additions & 3 deletions astrbot/core/persona_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ async def delete_persona(self, persona_id: str):
async def update_persona(
self,
persona_id: str,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
system_prompt: str | None = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id)
Expand Down
4 changes: 2 additions & 2 deletions astrbot/core/platform/astr_message_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def request_llm(
image_urls: list[str] | None = None,
contexts: list | None = None,
system_prompt: str = "",
conversation: Conversation = None,
conversation: Conversation | None = None,
) -> ProviderRequest:
"""创建一个 LLM 请求。

Expand Down Expand Up @@ -394,7 +394,7 @@ async def react(self, emoji: str):
"""
await self.send(MessageChain([Plain(emoji)]))

async def get_group(self, group_id: str = None, **kwargs) -> Group | None:
async def get_group(self, group_id: str | None = None, **kwargs) -> Group | None:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。

适配情况:
Expand Down
12 changes: 6 additions & 6 deletions astrbot/core/platform/astrbot_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@dataclass
class MessageMember:
user_id: str # 发送者id
nickname: str = None
nickname: str | None = None

def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
Expand All @@ -23,15 +23,15 @@ def __str__(self):
class Group:
group_id: str
"""群号"""
group_name: str = None
group_name: str | None = None
"""群名称"""
group_avatar: str = None
group_avatar: str | None = None
"""群头像"""
group_owner: str = None
group_owner: str | None = None
"""群主 id"""
group_admins: list[str] = None
group_admins: list[str] | None = None
"""群管理员 id"""
members: list[MessageMember] = None
members: list[MessageMember] | None = None
"""所有群成员"""

def __str__(self):
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/platform/message_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class MessageSession:
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
message_type: MessageType
session_id: str
platform_id: str = None
platform_id: str | None = None

def __str__(self):
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
Expand Down
8 changes: 4 additions & 4 deletions astrbot/core/platform/platform_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str
"""平台的描述"""
id: str = None
id: str | None = None
"""平台的唯一标识符,用于配置中识别特定平台"""

default_config_tmpl: dict = None
default_config_tmpl: dict | None = None
"""平台的默认配置模板"""
adapter_display_name: str = None
adapter_display_name: str | None = None
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str = None
logo_path: str | None = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""
6 changes: 3 additions & 3 deletions astrbot/core/platform/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
def register_platform_adapter(
adapter_name: str,
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None,
logo_path: str = None,
default_config_tmpl: dict | None = None,
adapter_display_name: str | None = None,
logo_path: str | None = None,
):
"""用于注册平台适配器的带参装饰器。

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def send_message(
message_chain: MessageChain,
event: Event | None = None,
is_group: bool = False,
session_id: str = None,
session_id: str | None = None,
):
"""发送消息至 QQ 协议端(aiocqhttp)。

Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/platform/sources/discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""

def __init__(self, token: str, proxy: str = None):
def __init__(self, token: str, proxy: str | None = None):
self.token = token
self.proxy = proxy

Expand Down
22 changes: 11 additions & 11 deletions astrbot/core/platform/sources/discord/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ class DiscordEmbed(BaseMessageComponent):

def __init__(
self,
title: str = None,
description: str = None,
color: int = None,
url: str = None,
thumbnail: str = None,
image: str = None,
footer: str = None,
fields: list[dict] = None,
title: str | None = None,
description: str | None = None,
color: int | None = None,
url: str | None = None,
thumbnail: str | None = None,
image: str | None = None,
footer: str | None = None,
fields: list[dict] | None = None,
):
self.title = title
self.description = description
Expand Down Expand Up @@ -66,10 +66,10 @@ class DiscordButton(BaseMessageComponent):
def __init__(
self,
label: str,
custom_id: str = None,
custom_id: str | None = None,
style: str = "primary",
emoji: str = None,
url: str = None,
emoji: str | None = None,
url: str | None = None,
disabled: bool = False,
):
self.label = label
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,9 @@ async def _collect_and_register_commands(self):
def _create_dynamic_callback(self, cmd_name: str):
"""为每个指令动态创建一个异步回调函数"""

async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None):
async def dynamic_callback(
ctx: discord.ApplicationContext, params: str | None = None
):
# 将平台特定的前缀'/'剥离,以适配通用的CommandFilter
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
logger.debug(f"[Discord] 回调函数参数: {ctx}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def send_streaming(self, generator, use_fallback: bool = False):

return await super().send_streaming(generator, use_fallback)

async def _post_send(self, stream: dict = None):
async def _post_send(self, stream: dict | None = None):
if not self.send_buffer:
return None

Expand Down Expand Up @@ -265,17 +265,17 @@ async def post_c2c_message(
self,
openid: str,
msg_type: int = 0,
content: str = None,
embed: message.Embed = None,
ark: message.Ark = None,
message_reference: message.Reference = None,
media: message.Media = None,
msg_id: str = None,
content: str | None = None,
embed: message.Embed | None = None,
ark: message.Ark | None = None,
message_reference: message.Reference | None = None,
media: message.Media | None = None,
msg_id: str | None = None,
msg_seq: str = 1,
event_id: str = None,
markdown: message.MarkdownPayload = None,
keyboard: message.Keyboard = None,
stream: dict = None,
event_id: str | None = None,
markdown: message.MarkdownPayload | None = None,
keyboard: message.Keyboard | None = None,
stream: dict | None = None,
) -> message.Message:
payload = locals()
payload.pop("self", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
is_private_chat: bool = False,
cached_texts=None,
cached_images=None,
raw_message: dict = None,
raw_message: dict | None = None,
downloader=None,
):
self._xml = None
Expand Down
4 changes: 2 additions & 2 deletions astrbot/core/platform_message_history_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ async def insert(
platform_id: str,
user_id: str,
content: list[dict], # TODO: parse from message chain
sender_id: str = None,
sender_name: str = None,
sender_id: str | None = None,
sender_name: str | None = None,
):
"""Insert a new platform message history record."""
await self.db.insert_platform_message_history(
Expand Down
24 changes: 12 additions & 12 deletions astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ async def get_models(self) -> list[str]:
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: ToolSet = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
session_id: str | None = None,
image_urls: list[str] | None = None,
func_tool: ToolSet | None = None,
contexts: list | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
**kwargs,
) -> LLMResponse:
Expand All @@ -115,12 +115,12 @@ async def text_chat(
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: ToolSet = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
session_id: str | None = None,
image_urls: list[str] | None = None,
func_tool: ToolSet | None = None,
contexts: list | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
Expand Down
4 changes: 2 additions & 2 deletions astrbot/core/provider/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def register_provider_adapter(
provider_type_name: str,
desc: str,
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
default_config_tmpl: dict = None,
provider_display_name: str = None,
default_config_tmpl: dict | None = None,
provider_display_name: str | None = None,
):
"""用于注册平台适配器的带参装饰器"""

Expand Down
4 changes: 2 additions & 2 deletions astrbot/core/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def save_temp_img(img: Image.Image | str) -> str:
async def download_image_by_url(
url: str,
post: bool = False,
post_data: dict = None,
path=None,
post_data: dict | None = None,
path: str | None = None,
) -> str:
"""下载图片, 返回 path"""
try:
Expand Down
Loading