Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions .github/workflows/build-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ jobs:

- uses: astral-sh/setup-uv@v5

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- id: find
name: Discover models
run: |
Expand All @@ -31,33 +35,37 @@ jobs:
strategy:
fail-fast: false
matrix:
model: ${{ fromJson(needs.discover.outputs.models) }}
include: ${{ fromJson(needs.discover.outputs.models) }}
steps:
- uses: actions/checkout@v4
with:
submodules: true

- uses: astral-sh/setup-uv@v5

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install dependencies
run: uv sync --group all-models
run: uv sync --group ${{ matrix.dep_group }}

- name: Download checkpoints
run: uv run dtai setup ${{ matrix.model }}
run: uv run dtai setup ${{ matrix.id }}

- name: Convert to ONNX
run: uv run dtai convert ${{ matrix.model }}
run: uv run dtai convert ${{ matrix.id }}

- name: Validate output
run: uv run dtai validate ${{ matrix.model }}
run: uv run dtai validate ${{ matrix.id }}

- name: Package model
run: uv run dtai package ${{ matrix.model }}
run: uv run dtai package ${{ matrix.id }}

- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.model }}
path: output/${{ matrix.model }}.dtmodel
name: ${{ matrix.id }}
path: output/${{ matrix.id }}.dtmodel

publish:
needs: build
Expand Down
68 changes: 68 additions & 0 deletions .github/workflows/check-models.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
name: Check Models

on:
pull_request:
branches: [master]
push:
branches: [master]

permissions:
contents: read

jobs:
discover:
runs-on: ubuntu-latest
outputs:
models: ${{ steps.find.outputs.models }}
steps:
- uses: actions/checkout@v4
with:
submodules: true

- uses: astral-sh/setup-uv@v5

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- id: find
name: Discover non-skipped models
run: |
models=$(uv run dtai list --json-output)
echo "models=$models" >> "$GITHUB_OUTPUT"
echo "Will check: $models"

check:
needs: discover
if: ${{ needs.discover.outputs.models != '[]' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include: ${{ fromJson(needs.discover.outputs.models) }}
name: ${{ matrix.id }}
steps:
- uses: actions/checkout@v4
with:
submodules: true

- uses: astral-sh/setup-uv@v5

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install dependencies
run: uv sync --group ${{ matrix.dep_group }}

- name: Download checkpoints
run: uv run dtai setup ${{ matrix.id }}

- name: Convert to ONNX
run: uv run dtai convert ${{ matrix.id }}

- name: Validate output
run: uv run dtai validate ${{ matrix.id }}

- name: Run demo
run: uv run dtai demo ${{ matrix.id }}
59 changes: 41 additions & 18 deletions darktable_ai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,27 @@ def _load_config(root: Path, model_id: str) -> ModelConfig:
return load_model_config(model_dir, root)


def _for_each_model(root: Path, model_id: str | None, callback) -> None:
def _sync_deps(config: ModelConfig) -> None:
"""Ensure the model's dependency group is installed."""
group = config.dep_group
if group == "core":
return
click.echo(f" Syncing dependency group: {group}")
subprocess.run(
["uv", "sync", "--group", group],
cwd=str(config.root_dir),
check=True,
)


def _for_each_model(
root: Path, model_id: str | None, callback, *, sync: bool = False
) -> None:
"""Run callback for one model or all non-skipped models."""
if model_id:
config = _load_config(root, model_id)
if sync:
_sync_deps(config)
callback(config)
else:
for config in discover_models(root):
Expand All @@ -39,6 +56,8 @@ def _for_each_model(root: Path, model_id: str | None, callback) -> None:
click.echo(f"\n{'=' * 40}")
click.echo(f" {config.id}")
click.echo(f"{'=' * 40}")
if sync:
_sync_deps(config)
callback(config)


Expand All @@ -61,25 +80,26 @@ def setup(ctx, model_id):
root = _get_root(ctx)

def _setup(config: ModelConfig):
if config.checkpoints:
download_checkpoints(config.checkpoints, root)

if config.repo and config.repo.setup:
if config.repo:
repo_dir = config.repo_dir
if repo_dir and repo_dir.is_dir():
if repo_dir and not repo_dir.is_dir():
click.echo(f" Initializing submodule: {config.repo.submodule}")
subprocess.run(
["git", "submodule", "update", "--init", config.repo.submodule],
cwd=str(root), check=True,
)

if config.repo.setup and repo_dir and repo_dir.is_dir():
click.echo(f" Running repo setup: {config.repo.setup}")
env = os.environ.copy()
env["DTAI_ROOT"] = str(root)
subprocess.run(
config.repo.setup, shell=True,
cwd=str(repo_dir), env=env, check=True,
)
else:
click.echo(
f" Warning: submodule not found at {repo_dir}. "
f"Run: git submodule update --init",
err=True,
)

if config.checkpoints:
download_checkpoints(config.checkpoints, root)

_for_each_model(root, model_id, _setup)

Expand All @@ -92,7 +112,7 @@ def convert(ctx, model_id):
from darktable_ai.convert import run_conversion

root = _get_root(ctx)
_for_each_model(root, model_id, run_conversion)
_for_each_model(root, model_id, run_conversion, sync=True)


@main.command()
Expand All @@ -103,7 +123,7 @@ def validate(ctx, model_id):
from darktable_ai.validate import run_validation

root = _get_root(ctx)
_for_each_model(root, model_id, run_validation)
_for_each_model(root, model_id, run_validation, sync=True)


@main.command("package")
Expand All @@ -125,7 +145,7 @@ def demo(ctx, model_id):
from darktable_ai.demo import run_demo

root = _get_root(ctx)
_for_each_model(root, model_id, run_demo)
_for_each_model(root, model_id, run_demo, sync=True)


@main.command()
Expand Down Expand Up @@ -157,7 +177,7 @@ def _run_pipeline(config: ModelConfig):
click.echo("\n=== Demo ===")
run_demo(config)

_for_each_model(root, model_id, _run_pipeline)
_for_each_model(root, model_id, _run_pipeline, sync=True)


@main.command("list")
Expand All @@ -169,8 +189,11 @@ def list_models(ctx, as_json):
models = discover_models(root)

if as_json:
ids = [m.id for m in models if not m.skip]
click.echo(json.dumps(ids))
matrix = [
{"id": m.id, "dep_group": m.dep_group}
for m in models if not m.skip
]
click.echo(json.dumps(matrix))
else:
for config in models:
status = " (skipped)" if config.skip else ""
Expand Down
2 changes: 2 additions & 0 deletions darktable_ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ModelConfig:
description: str
task: str
type: str = "single"
arch: str = "generic"
tiling: bool = False
dep_group: str = "core"
skip: bool = False
Expand Down Expand Up @@ -109,6 +110,7 @@ def load_model_config(model_dir: Path, root_dir: Path) -> ModelConfig:
description=data["description"],
task=data["task"],
type=data.get("type", "single"),
arch=data.get("arch", "generic"),
tiling=data.get("tiling", False),
dep_group=data.get("dep_group", "core"),
skip=skip,
Expand Down
1 change: 1 addition & 0 deletions darktable_ai/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def generate_config_json(config: ModelConfig) -> None:
"name": config.name,
"description": config.description,
"task": config.task,
"arch": config.arch,
"backend": "onnx",
"version": "1.0",
"tiling": config.tiling,
Expand Down
Loading