Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions matrix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .bot import Bot
from .group import Group
from .config import Config
from .context import Context
from .command import Command
Expand All @@ -8,6 +9,7 @@

__all__ = [
"Bot",
"Group",
"Config",
"Command",
"Context",
Expand Down
40 changes: 31 additions & 9 deletions matrix/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,23 @@
)

from .room import Room
from .group import Group
from .config import Config
from .context import Context
from .command import Command
from .help import HelpCommand
from .errors import AlreadyRegisteredError, CommandNotFoundError, CheckError
from .scheduler import Scheduler

from .errors import (
AlreadyRegisteredError,
CommandNotFoundError,
CheckError,
GroupAlreadyRegisteredError
)


Callback = Callable[..., Coroutine[Any, Any, Any]]
GroupCallable = Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group]
ErrorCallback = Callable[[Exception], Coroutine]


Expand Down Expand Up @@ -86,7 +94,7 @@ def __init__(
self.commands: Dict[str, Command] = {}
self.checks: List[Callback] = []
self.scheduler = Scheduler()

self._handlers: Dict[Type[Event], List[Callback]] = defaultdict(list)
self._on_error: Optional[ErrorCallback] = None

Expand Down Expand Up @@ -186,24 +194,21 @@ def wrapper(f: Callback) -> Callback:
def command(
self,
name: Optional[str] = None,
cooldown: Optional[tuple[int, float]] = None
**kwargs
) -> Callable[[Callback], Command]:
"""
Decorator to register a coroutine function as a command handler.

The command name defaults to the function name unless
explicitly provided.

:param name: The name of the command. If omitted, the function
name is used.
:type name: str, optional
:raises TypeError: If the decorated function is not a coroutine.
:raises ValueError: If a command with the same name is registered.
:return: Decorator that registers the command handler.
:rtype: Callback
"""
def wrapper(func: Callback) -> Command:
cmd = Command(func, name=name, cooldown=cooldown, prefix=self.prefix)
cmd = Command(func, name=name, prefix=self.prefix, **kwargs)
return self.register_command(cmd)
return wrapper

Expand Down Expand Up @@ -233,15 +238,32 @@ def wrapper(f: Callback) -> Callback:

return wrapper

def register_command(self, cmd: Command):
def register_command(self, cmd: Command) -> Command:
if cmd in self.commands:
raise AlreadyRegisteredError(cmd)

self.commands[cmd.name] = cmd
self.log.debug("command %s registered", cmd)
self.log.debug("command '%s' registered", cmd)

return cmd

def group(self, **kwargs) -> GroupCallable:
"""Decorator to register a custom error handler for the command."""

def wrapper(func: Callback) -> Group:
group = Group(func, prefix=self.prefix, **kwargs)
return self.register_group(group)
return wrapper

def register_group(self, group: Group) -> Group:
if group in self.commands:
raise GroupAlreadyRegisteredError(group)

self.commands[group.name] = group
self.log.debug("group '%s' registered", group)

return group

def error(self):
"""Decorator to register a custom error handler for the command."""

Expand Down
52 changes: 29 additions & 23 deletions matrix/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ def __init__(self, func: Callback, **kwargs: Any):

self.description: str = kwargs.get("description", "")
self.prefix: str = kwargs.get("prefix", "")
self.parent: str = kwargs.get("parent", "")
self.usage: str = kwargs.get("usage", self._build_usage())
self.help: str = self._build_help()

self._before_invoke: Optional[Callback] = None
self._after_invoke: Optional[Callback] = None
self._before_invoke_callback: Optional[Callback] = None
self._after_invoke_callback: Optional[Callback] = None
self._on_error: Optional[ErrorCallback] = None

self.cooldown_rate: Optional[int] = None
Expand Down Expand Up @@ -114,7 +115,12 @@ def _build_usage(self) -> str:
:rtype: str
"""
params = " ".join(f"[{p.name}]" for p in self.params)
return f"{self.prefix}{self.name} {params}"
command_name = self.name

if self.parent:
command_name = f"{self.parent} {self.name}"

return f"{self.prefix}{command_name} {params}"

def _parse_arguments(self, ctx: "Context") -> list[Any]:
parsed_args = []
Expand Down Expand Up @@ -185,7 +191,7 @@ def before_invoke(self, func: Callback) -> None:
if not asyncio.iscoroutinefunction(func):
raise TypeError('The hook must be a coroutine.')

self._before_invoke = func
self._before_invoke_callback = func

def after_invoke(self, func: Callback) -> None:
"""
Expand All @@ -200,7 +206,7 @@ def after_invoke(self, func: Callback) -> None:
if not asyncio.iscoroutinefunction(func):
raise TypeError('The hook must be a coroutine.')

self._after_invoke = func
self._after_invoke_callback = func

def error(self, func: ErrorCallback) -> None:
"""
Expand Down Expand Up @@ -234,17 +240,25 @@ async def on_error(self, ctx: "Context", error: Exception) -> None:
ctx.logger.exception("error while executing command '%s'", self)
raise error

async def __before_invoke(self, ctx: "Context") -> None:
for check in self.checks:
if not await check(ctx):
raise CheckError(self, check)
async def invoke(self, ctx):
parsed_args = self._parse_arguments(ctx)
await self.callback(ctx, *parsed_args)

async def _invoke(self, ctx: "Context"):
try:
for check in self.checks:
if not await check(ctx):
raise CheckError(self, check)

if self._before_invoke_callback:
await self._before_invoke_callback(ctx)

if self._before_invoke:
await self._before_invoke(ctx)
await self.invoke(ctx)

async def __after_invoke(self, ctx: "Context") -> None:
if self._after_invoke:
await self._after_invoke(ctx)
if self._after_invoke_callback:
await self._after_invoke_callback(ctx)
except Exception as error:
await self.on_error(ctx, error)

async def __call__(self, ctx: "Context") -> None:
"""
Expand All @@ -253,15 +267,7 @@ async def __call__(self, ctx: "Context") -> None:
:param ctx: The command execution context.
:type ctx: Context
"""
try:
await self.__before_invoke(ctx)

parsed_args = self._parse_arguments(ctx)
await self.callback(ctx, *parsed_args)

await self.__after_invoke(ctx)
except Exception as error:
await self.on_error(ctx, error)
await self._invoke(ctx)

def __eq__(self, other) -> bool:
return self.name == other
Expand Down
18 changes: 15 additions & 3 deletions matrix/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
if TYPE_CHECKING:
from .bot import Bot # pragma: no cover
from .command import Command # pragma: no cover
from .group import Group


class Context:
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self, bot: "Bot", room: MatrixRoom, event: Event):
# Command metdata
self.prefix: str = bot.prefix
self.command: Optional[Command] = None
self.subcommand: Optional[Command] = None
self._args: List[str] = shlex.split(self.body)

@property
Expand All @@ -54,8 +56,12 @@ def args(self) -> List[str]:
:return: The list of arguments.
:rtype: List[str]
"""
if self.subcommand:
return self._args[2:]

if self.command:
return self._args[1:]

return self._args

@property
Expand All @@ -80,6 +86,12 @@ async def reply(self, message: str) -> None:
raise MatrixError(f"Failed to send message: {e}")

async def send_help(self) -> None:
if not self.command:
return await self.bot.help.execute(self)
await self.reply(self.command.help)
if self.subcommand:
await self.reply(self.subcommand.help)
return

if self.command:
await self.reply(self.command.help)
return

await self.bot.help.execute(self)
11 changes: 10 additions & 1 deletion matrix/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@ def __init__(self, param):

class CheckError(CommandError):
def __init__(self, cmd, check):
super().__init__(f"'{check.__name__}' has failed for '{cmd.name}'!")
super().__init__(f"'{check.__name__}' has failed for '{cmd.name}'")


class GroupError(CommandError):
pass


class GroupAlreadyRegisteredError(GroupError):
def __init__(self, group):
super().__init__(f"Group '{group}' is already registered")


class ConfigError(MatrixError):
Expand Down
72 changes: 72 additions & 0 deletions matrix/group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import logging
from typing import TYPE_CHECKING, Optional, Dict, Any, Callable, Coroutine

from .command import Command
from .errors import AlreadyRegisteredError, CommandNotFoundError

if TYPE_CHECKING:
from .context import Context # pragma: no cover

logger = logging.getLogger(__name__)

Callback = Callable[..., Coroutine[Any, Any, Any]]
ErrorCallback = Callable[["Context", Exception], Coroutine[Any, Any, Any]]


class Group(Command):
def __init__(self, callback: Callback, **kwargs: Any):
self.commands: Dict[str, Command] = {}

super().__init__(callback, **kwargs)

def _build_usage(self):
return f"{self.prefix}{self.name} [subcommand]"

def get_command(self, cmd_name: str):
if cmd := self.commands.get(cmd_name):
return cmd
raise CommandNotFoundError(cmd_name)

def command(
self,
name: Optional[str] = None
) -> Callable[[Callback], Command]:
"""
Decorator to register a coroutine function as a command handler.

The command name defaults to the function name unless
explicitly provided.

:param name: The name of the command. If omitted, the function
name is used.
:type name: str, optional
:raises TypeError: If the decorated function is not a coroutine.
:raises ValueError: If a command with the same name is registered.
:return: Decorator that registers the command handler.
:rtype: Callback
"""
def wrapper(func: Callback) -> Command:
cmd = Command(
func,
name=name,
prefix=self.prefix,
parent=self.name
)
return self.register_command(cmd)
return wrapper

def register_command(self, cmd: Command):
if cmd in self.commands:
raise AlreadyRegisteredError(cmd)

self.commands[cmd.name] = cmd
logger.debug("command '%s' registered for group '%s'", cmd, self)

return cmd

async def invoke(self, ctx: "Context"):
if subcommand := ctx.args.pop(0):
ctx.subcommand = self.get_command(subcommand)
await ctx.subcommand(ctx)
else:
await self.callback(ctx)
Loading