diff --git a/.github/workflows/build-models.yml b/.github/workflows/build-models.yml index 48d44c7..00c875a 100644 --- a/.github/workflows/build-models.yml +++ b/.github/workflows/build-models.yml @@ -19,6 +19,10 @@ jobs: - uses: astral-sh/setup-uv@v5 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - id: find name: Discover models run: | @@ -31,7 +35,7 @@ jobs: strategy: fail-fast: false matrix: - model: ${{ fromJson(needs.discover.outputs.models) }} + include: ${{ fromJson(needs.discover.outputs.models) }} steps: - uses: actions/checkout@v4 with: @@ -39,25 +43,29 @@ jobs: - uses: astral-sh/setup-uv@v5 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install dependencies - run: uv sync --group all-models + run: uv sync --group ${{ matrix.dep_group }} - name: Download checkpoints - run: uv run dtai setup ${{ matrix.model }} + run: uv run dtai setup ${{ matrix.id }} - name: Convert to ONNX - run: uv run dtai convert ${{ matrix.model }} + run: uv run dtai convert ${{ matrix.id }} - name: Validate output - run: uv run dtai validate ${{ matrix.model }} + run: uv run dtai validate ${{ matrix.id }} - name: Package model - run: uv run dtai package ${{ matrix.model }} + run: uv run dtai package ${{ matrix.id }} - uses: actions/upload-artifact@v4 with: - name: ${{ matrix.model }} - path: output/${{ matrix.model }}.dtmodel + name: ${{ matrix.id }} + path: output/${{ matrix.id }}.dtmodel publish: needs: build diff --git a/.github/workflows/check-models.yml b/.github/workflows/check-models.yml new file mode 100644 index 0000000..acd93a2 --- /dev/null +++ b/.github/workflows/check-models.yml @@ -0,0 +1,68 @@ +name: Check Models + +on: + pull_request: + branches: [master] + push: + branches: [master] + +permissions: + contents: read + +jobs: + discover: + runs-on: ubuntu-latest + outputs: + models: ${{ steps.find.outputs.models }} + steps: + - uses: actions/checkout@v4 + with: + submodules: true + + - uses: astral-sh/setup-uv@v5 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - id: find + name: Discover non-skipped models + run: | + models=$(uv run dtai list --json-output) + echo "models=$models" >> "$GITHUB_OUTPUT" + echo "Will check: $models" + + check: + needs: discover + if: ${{ needs.discover.outputs.models != '[]' }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: ${{ fromJson(needs.discover.outputs.models) }} + name: ${{ matrix.id }} + steps: + - uses: actions/checkout@v4 + with: + submodules: true + + - uses: astral-sh/setup-uv@v5 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: uv sync --group ${{ matrix.dep_group }} + + - name: Download checkpoints + run: uv run dtai setup ${{ matrix.id }} + + - name: Convert to ONNX + run: uv run dtai convert ${{ matrix.id }} + + - name: Validate output + run: uv run dtai validate ${{ matrix.id }} + + - name: Run demo + run: uv run dtai demo ${{ matrix.id }} diff --git a/darktable_ai/cli.py b/darktable_ai/cli.py index 221532c..61ef60c 100644 --- a/darktable_ai/cli.py +++ b/darktable_ai/cli.py @@ -26,10 +26,27 @@ def _load_config(root: Path, model_id: str) -> ModelConfig: return load_model_config(model_dir, root) -def _for_each_model(root: Path, model_id: str | None, callback) -> None: +def _sync_deps(config: ModelConfig) -> None: + """Ensure the model's dependency group is installed.""" + group = config.dep_group + if group == "core": + return + click.echo(f" Syncing dependency group: {group}") + subprocess.run( + ["uv", "sync", "--group", group], + cwd=str(config.root_dir), + check=True, + ) + + +def _for_each_model( + root: Path, model_id: str | None, callback, *, sync: bool = False +) -> None: """Run callback for one model or all non-skipped models.""" if model_id: config = _load_config(root, model_id) + if sync: + _sync_deps(config) callback(config) else: for config in discover_models(root): @@ -39,6 +56,8 @@ def _for_each_model(root: Path, model_id: str | None, callback) -> None: click.echo(f"\n{'=' * 40}") click.echo(f" {config.id}") click.echo(f"{'=' * 40}") + if sync: + _sync_deps(config) callback(config) @@ -61,12 +80,16 @@ def setup(ctx, model_id): root = _get_root(ctx) def _setup(config: ModelConfig): - if config.checkpoints: - download_checkpoints(config.checkpoints, root) - - if config.repo and config.repo.setup: + if config.repo: repo_dir = config.repo_dir - if repo_dir and repo_dir.is_dir(): + if repo_dir and not repo_dir.is_dir(): + click.echo(f" Initializing submodule: {config.repo.submodule}") + subprocess.run( + ["git", "submodule", "update", "--init", config.repo.submodule], + cwd=str(root), check=True, + ) + + if config.repo.setup and repo_dir and repo_dir.is_dir(): click.echo(f" Running repo setup: {config.repo.setup}") env = os.environ.copy() env["DTAI_ROOT"] = str(root) @@ -74,12 +97,9 @@ def _setup(config: ModelConfig): config.repo.setup, shell=True, cwd=str(repo_dir), env=env, check=True, ) - else: - click.echo( - f" Warning: submodule not found at {repo_dir}. " - f"Run: git submodule update --init", - err=True, - ) + + if config.checkpoints: + download_checkpoints(config.checkpoints, root) _for_each_model(root, model_id, _setup) @@ -92,7 +112,7 @@ def convert(ctx, model_id): from darktable_ai.convert import run_conversion root = _get_root(ctx) - _for_each_model(root, model_id, run_conversion) + _for_each_model(root, model_id, run_conversion, sync=True) @main.command() @@ -103,7 +123,7 @@ def validate(ctx, model_id): from darktable_ai.validate import run_validation root = _get_root(ctx) - _for_each_model(root, model_id, run_validation) + _for_each_model(root, model_id, run_validation, sync=True) @main.command("package") @@ -125,7 +145,7 @@ def demo(ctx, model_id): from darktable_ai.demo import run_demo root = _get_root(ctx) - _for_each_model(root, model_id, run_demo) + _for_each_model(root, model_id, run_demo, sync=True) @main.command() @@ -157,7 +177,7 @@ def _run_pipeline(config: ModelConfig): click.echo("\n=== Demo ===") run_demo(config) - _for_each_model(root, model_id, _run_pipeline) + _for_each_model(root, model_id, _run_pipeline, sync=True) @main.command("list") @@ -169,8 +189,11 @@ def list_models(ctx, as_json): models = discover_models(root) if as_json: - ids = [m.id for m in models if not m.skip] - click.echo(json.dumps(ids)) + matrix = [ + {"id": m.id, "dep_group": m.dep_group} + for m in models if not m.skip + ] + click.echo(json.dumps(matrix)) else: for config in models: status = " (skipped)" if config.skip else "" diff --git a/darktable_ai/config.py b/darktable_ai/config.py index 9ec1850..666cc12 100644 --- a/darktable_ai/config.py +++ b/darktable_ai/config.py @@ -38,6 +38,7 @@ class ModelConfig: description: str task: str type: str = "single" + arch: str = "generic" tiling: bool = False dep_group: str = "core" skip: bool = False @@ -109,6 +110,7 @@ def load_model_config(model_dir: Path, root_dir: Path) -> ModelConfig: description=data["description"], task=data["task"], type=data.get("type", "single"), + arch=data.get("arch", "generic"), tiling=data.get("tiling", False), dep_group=data.get("dep_group", "core"), skip=skip, diff --git a/darktable_ai/convert.py b/darktable_ai/convert.py index 4738f83..a36f339 100644 --- a/darktable_ai/convert.py +++ b/darktable_ai/convert.py @@ -40,6 +40,7 @@ def generate_config_json(config: ModelConfig) -> None: "name": config.name, "description": config.description, "task": config.task, + "arch": config.arch, "backend": "onnx", "version": "1.0", "tiling": config.tiling,