Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
- uses: astral-sh/setup-uv@v5
with:
cache-dependency-glob: ""
- run: uv run --with flake8 flake8 magicli.py --extend-ignore=E501
- run: uv run --with flake8 flake8 magicli.py --extend-ignore=E203,E501

pylint:
runs-on: ubuntu-latest
Expand All @@ -22,7 +22,7 @@ jobs:
- uses: astral-sh/setup-uv@v5
with:
cache-dependency-glob: ""
- run: uv run --with pylint pylint --disable=unidiomatic-typecheck,raise-missing-from magicli.py
- run: uv run --with pylint pylint magicli.py

ruff:
runs-on: ubuntu-latest
Expand Down
185 changes: 74 additions & 111 deletions magicli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,36 @@
import inspect
import subprocess
import sys
from functools import partial
from importlib import metadata
from pathlib import Path


def magicli():
"""
Parses command-line arguments and calls the appropriate function.
"""
if not sys.argv:
raise SystemExit(1)

"""Parses command-line arguments and calls the appropriate function."""
name = Path(sys.argv[0]).name
argv = sys.argv[1:]

if name == "magicli":
raise SystemExit(call(cli, argv, sys.modules["magicli"]))

module = load_module(name)
name = name.replace("-", "_")

if function := is_command(argv, module):
call(function, argv[1:], module, name)
elif inspect.isfunction(function := module.__dict__.get(name)):
call(function, argv, module)
if function := get_function_from_argv(argv, module, name.replace("-", "_")):
function()
else:
raise SystemExit(help_message(help_from_module, module))


def get_function_from_argv(argv, module, name):
"""Returns the module's function to call based on argv."""
if function := is_command(argv, module):
return partial(call, function, argv[1:], module, name)
if inspect.isfunction(function := module.__dict__.get(name)):
return partial(call, function, argv, module)
return None


def is_command(argv, module):
"""
Checks if the first argument is a valid command in the module and returns
Expand All @@ -57,21 +59,19 @@ def call(function, argv, module=None, name=None):
Displays a help message if an exception occurs.
"""
try:
docstring = get_docstring(function)
docstring = inspect.getdoc(function) or ""
parameters = inspect.signature(function).parameters

check_for_version(argv, parameters, docstring, module)

args, kwargs = args_and_kwargs(argv, parameters, docstring)
args, kwargs = parse_argv(argv, parameters, docstring)
function(*args, **kwargs)
except Exception:
raise SystemExit(help_message(help_from_function, function, name))


def args_and_kwargs(argv, parameters, docstring):
"""
Parses command-line arguments into positional and keyword arguments.
"""
def parse_argv(argv, parameters, docstring):
"""Convert argv into args and kwargs."""
parameter_list = list(parameters.values())
args, kwargs = [], {}

Expand All @@ -87,10 +87,28 @@ def args_and_kwargs(argv, parameters, docstring):
return args, kwargs


def parse_short_options(short_options, docstring, iter_argv, parameters, kwargs):
def parse_kwarg(key, argv, parameters):
"""
Converts short options into long options and casts into correct types.
Parses a single keyword argument from command-line arguments.
Handles '=' syntax for inline values. Casts `NoneType` values to `True`
and boolean values to `not default`.
"""
key, value = key.split("=", 1) if "=" in key else (key, None)
key = key.replace("-", "_")
cast_to = get_type(parameters.get(key))

if value is None:
if cast_to is bool:
return key, not parameters[key].default
if cast_to is type(None):
return key, True
value = next(argv)

return key, value if cast_to is str else cast_to(value)


def parse_short_options(short_options, docstring, iter_argv, parameters, kwargs):
"""Converts short options into long options and casts into correct types."""
for i, short in enumerate(short_options):
long = short_to_long_option(short, docstring)

Expand All @@ -110,45 +128,17 @@ def parse_short_options(short_options, docstring, iter_argv, parameters, kwargs)


def short_to_long_option(short, docstring):
"""
Converts a one character short option to a long option according to the help message.
"""
"""Converts a one character short option to a long option according to the help message."""
template = f"-{short}, --"
if (start := docstring.find(template)) != -1:
start += len(template)
chars = (" ", "\n", "]")

try:
end = min(i for ws in chars if (i := docstring.find(ws, start)) != -1)
return docstring[start:end]

except ValueError:
if len(docstring) - start > 1:
return docstring[start:]

if len(docstring) - start > 1:
chars = [" ", "\n", "]"]
indices = (i for char in chars if (i := docstring.find(char, start)) != -1)
return docstring[start : min(indices, default=None)]
raise SystemExit(f"-{short}: invalid short option")


def parse_kwarg(key, argv, parameters):
"""
Parses a single keyword argument from command-line arguments.
Handles '=' syntax for inline values. Casts `NoneType` values to `True`
and boolean values to `not default`.
"""
key, value = key.split("=", 1) if "=" in key else (key, None)
key = key.replace("-", "_")
cast_to = get_type(parameters.get(key))

if value is None:
if cast_to is bool:
return key, not parameters[key].default
if cast_to is type(None):
return key, True
value = next(argv)

return key, value if cast_to is str else cast_to(value)


def get_type(parameter):
"""
Determines the type based on function signature annotations or defaults.
Expand All @@ -162,21 +152,15 @@ def get_type(parameter):


def check_for_version(argv, parameters, docstring, module):
"""
Displays version information if --version is specified in the docstring.
"""
if (
"version" not in parameters
and any(
(argv == [arg] and string in docstring)
for arg, string in [
("--version", "--version"),
("-v", "-v, --version"),
("-V", "-V, --version"),
]
)
and module
):
"""Displays version information if --version is specified in the docstring."""
if "version" in parameters or not module or len(argv) != 1:
return
args = {
"--version": "--version",
"-v": "-v, --version",
"-V": "-V, --version",
}
if (doc := args.get(argv[0])) and doc in docstring:
print(get_version(module))
raise SystemExit

Expand Down Expand Up @@ -212,17 +196,17 @@ def help_from_module(module):
Generates a help message for a module and lists available commands.
Lists all public functions that are not excluded in `__all__`.
"""
message = []
blocks = []

if version := get_version(module):
message.append([f"{module.__name__} {version}"])
blocks.append([f"{module.__name__} {version}"])

message.append(["usage:", f"{module.__name__} command"])
blocks.append(["usage:", f"{module.__name__} command"])

if commands := get_commands(module):
message.append(["commands:", *commands])
blocks.append(["commands:", *commands])

return format_blocks(message)
return format_blocks(blocks)


def format_blocks(blocks, sep="\n "):
Expand All @@ -239,35 +223,24 @@ def load_module(name):


def get_commands(module):
"""Returns list of public commands that are not present in `__all__`."""
"""Returns list of public commands that are not excluded by `__all__`."""
return [
name
for name, _ in inspect.getmembers(module, inspect.isfunction)
if not name.startswith("_") and name in module.__dict__.get("__all__", [name])
]


def get_docstring(function):
"""
Returns the cleaned up docstring of a function or an empty string.
"""
return inspect.getdoc(function) or ""


def get_version(module):
"""
Returns the version of a module from its metadata or `__version__` attribute.
"""
"""Returns the version of a module from its metadata or `__version__` attribute."""
try:
return metadata.version(module.__name__)
except metadata.PackageNotFoundError:
return module.__dict__.get("__version__")


def get_project_name():
"""
Detect project name from project structure.
"""
"""Detect project name from project structure."""
single_file_layout = [path.stem for path in Path().glob("*.py")]
flat_layout = [
path.parent.name
Expand All @@ -293,37 +266,28 @@ def get_output(command):
).stdout
except FileNotFoundError:
return None
return output.removesuffix("\n") if output else None
return output.removesuffix("\n") or None


def get_homepage(url=None):
"""Return a homepage url from a git remote url."""
url = url or get_output("git remote get-url origin") or ""
url = url.removesuffix(".git")
if url.startswith("git@"):
url = "https://" + url.replace(":", "/")[4:]
return url
url = "https://" + url.removeprefix("git@").replace(":", "/")
return url.removesuffix(".git")


def get_description(name):
"""Return the first paragraph of a module's docstring if available."""
try:
if doc := (importlib.import_module(name).__doc__ or "").split("\n\n"):
return " ".join(
[stripped for line in doc[0].splitlines() if (stripped := line.strip())]
)
module = importlib.import_module(name)
except ModuleNotFoundError:
pass
return None
return None
doc = (module.__doc__ or "").split("\n\n")[0]
return " ".join(stripped for line in doc.splitlines() if (stripped := line.strip()))


def cli(
name="",
author="",
email="",
description="",
homepage="",
):
def cli(name="", author="", email="", description="", homepage=""):
"""
magiCLI✨

Expand Down Expand Up @@ -363,11 +327,11 @@ def cli(
if authors:
project.append(f"authors = [{{{', '.join(authors)}}}]")

if Path(readme := "README.md").exists():
project.append(f'readme = "{readme}"')
if Path("README.md").exists():
project.append('readme = "README.md"')

if Path(license_file := "LICENSE").exists():
project.append(f'license-files = ["{license_file}"]')
if Path("LICENSE").exists():
project.append('license-files = ["LICENSE"]')

if description or (description := get_description(name)):
project.append(f'description = "{description}"')
Expand All @@ -387,11 +351,10 @@ def cli(

pyproject.write_text(format_blocks(blocks, sep="\n") + "\n", encoding="utf-8")

message = ["pyproject.toml created! ✨"]
if Path(".git").exists():
message.append("You can specify the version with `git tag`")
git_note = "You can specify the version with `git tag`"
else:
message.append(
git_note = (
"Error: Not a git repo. Run `git init`. Specify version with `git tag`."
)
print(*message, sep="\n")
print("pyproject.toml created! ✨", git_note, sep="\n")
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ dev = ["pytest"]

[project.urls]
Home = "https://github.com/PatrickElmer/magicli"

[tool.pylint."messages control"]
disable = [
"unidiomatic-typecheck",
"raise-missing-from",
]
6 changes: 0 additions & 6 deletions tests/test_magicli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ def test_wrong_command_not_called(mocked):
magicli()


def test_empty_sys_argv():
sys.argv = []
with pytest.raises(SystemExit):
magicli()


@mock.patch("importlib.import_module", side_effect=module_empty)
def test_module_without_functions(mocked):
sys.argv = ["name"]
Expand Down
Loading
Loading