diff --git a/clarifai/cli/pipeline_run.py b/clarifai/cli/pipeline_run.py new file mode 100644 index 00000000..efc6ed57 --- /dev/null +++ b/clarifai/cli/pipeline_run.py @@ -0,0 +1,426 @@ +import os +import shutil + +import click + +from clarifai.cli.base import cli +from clarifai.utils.cli import AliasedGroup, from_yaml, validate_context +from clarifai.utils.logging import logger + + +def _load_pipeline_params_from_config(user_id, app_id, pipeline_id, pipeline_version_id): + """Load pipeline parameters from config-lock.yaml if not all provided. + + Args: + user_id: User ID (may be None) + app_id: App ID (may be None) + pipeline_id: Pipeline ID (may be None) + pipeline_version_id: Pipeline Version ID (may be None) + + Returns: + tuple: (user_id, app_id, pipeline_id, pipeline_version_id) + """ + if not all([user_id, app_id, pipeline_id, pipeline_version_id]): + lockfile_path = os.path.join(os.getcwd(), "config-lock.yaml") + if os.path.exists(lockfile_path): + logger.info("Loading parameters from config-lock.yaml") + lockfile_data = from_yaml(lockfile_path) + + if 'pipeline' in lockfile_data: + pipeline_config = lockfile_data['pipeline'] + user_id = user_id or pipeline_config.get('user_id') + app_id = app_id or pipeline_config.get('app_id') + pipeline_id = pipeline_id or pipeline_config.get('id') + pipeline_version_id = pipeline_version_id or pipeline_config.get('version_id') + + return user_id, app_id, pipeline_id, pipeline_version_id + + +def _validate_pipeline_params(user_id, app_id, pipeline_id, pipeline_version_id): + """Validate that all required pipeline parameters are present. + + Args: + user_id: User ID + app_id: App ID + pipeline_id: Pipeline ID + pipeline_version_id: Pipeline Version ID + + Raises: + click.UsageError: If any required parameter is missing + """ + if not all([user_id, app_id, pipeline_id, pipeline_version_id]): + raise click.UsageError( + "Missing required parameters. Either provide --user_id, --app_id, " + "--pipeline_id, and --pipeline_version_id, or ensure config-lock.yaml exists." + ) + + +def _create_pipeline(ctx, user_id, app_id, pipeline_id, pipeline_version_id): + """Create and return a Pipeline object. + + Args: + ctx: Click context + user_id: User ID + app_id: App ID + pipeline_id: Pipeline ID + pipeline_version_id: Pipeline Version ID + + Returns: + Pipeline: Configured Pipeline object + """ + from clarifai.client.pipeline import Pipeline + + return Pipeline( + pipeline_id=pipeline_id, + pipeline_version_id=pipeline_version_id, + user_id=user_id, + app_id=app_id, + pat=ctx.obj.current.pat, + base_url=ctx.obj.current.api_base, + ) + + +@cli.group( + ['pipelinerun', 'pr'], + cls=AliasedGroup, + context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, +) +def pipelinerun(): + """Manage Pipeline Version Runs: pause, cancel, resume, monitor""" + + +@pipelinerun.command() +@click.argument('pipeline_version_run_id', required=False) +@click.option( + '--pipeline_version_run_id', + 'pipeline_version_run_id_flag', + required=False, + help='Pipeline Version Run ID to pause.', +) +@click.option('--user_id', required=False, help='User ID that owns the pipeline.') +@click.option('--app_id', required=False, help='App ID that contains the pipeline.') +@click.option('--pipeline_id', required=False, help='Pipeline ID.') +@click.option('--pipeline_version_id', required=False, help='Pipeline Version ID.') +@click.pass_context +def pause( + ctx, + pipeline_version_run_id, + pipeline_version_run_id_flag, + user_id, + app_id, + pipeline_id, + pipeline_version_id, +): + """Pause a pipeline version run. + + Pausing is allowed only when the pipeline run is in Queued or Running state. + + Examples: + + # Using positional argument + clarifai pr pause + + # Using flag + clarifai pipelinerun pause --pipeline_version_run_id= + + # With explicit parameters + clarifai pr pause \\ + --user_id=USER_ID \\ + --app_id=APP_ID \\ + --pipeline_id=PIPELINE_ID \\ + --pipeline_version_id=VERSION_ID + """ + from clarifai_grpc.grpc.api.status import status_code_pb2 + + validate_context(ctx) + + # Resolve pipeline_version_run_id from positional or flag + run_id = pipeline_version_run_id or pipeline_version_run_id_flag + if not run_id: + raise click.UsageError( + "pipeline_version_run_id is required. " + "Provide it as a positional argument or use --pipeline_version_run_id flag." + ) + + # Load parameters from config-lock.yaml if not provided + user_id, app_id, pipeline_id, pipeline_version_id = _load_pipeline_params_from_config( + user_id, app_id, pipeline_id, pipeline_version_id + ) + + # Validate required parameters + _validate_pipeline_params(user_id, app_id, pipeline_id, pipeline_version_id) + + # Create Pipeline object + pipeline = _create_pipeline(ctx, user_id, app_id, pipeline_id, pipeline_version_id) + + # Patch the pipeline version run to JOB_PAUSED + try: + result = pipeline.patch_pipeline_version_run( + pipeline_version_run_id=run_id, + orchestration_status_code=status_code_pb2.JOB_PAUSED, + ) + logger.info(f"Successfully paused pipeline version run {run_id}") + click.echo(f"Pipeline version run {run_id} has been paused.") + except Exception as e: + logger.error(f"Failed to pause pipeline version run: {e}") + raise click.ClickException(str(e)) + + +@pipelinerun.command() +@click.argument('pipeline_version_run_id', required=False) +@click.option( + '--pipeline_version_run_id', + 'pipeline_version_run_id_flag', + required=False, + help='Pipeline Version Run ID to cancel.', +) +@click.option('--user_id', required=False, help='User ID that owns the pipeline.') +@click.option('--app_id', required=False, help='App ID that contains the pipeline.') +@click.option('--pipeline_id', required=False, help='Pipeline ID.') +@click.option('--pipeline_version_id', required=False, help='Pipeline Version ID.') +@click.pass_context +def cancel( + ctx, + pipeline_version_run_id, + pipeline_version_run_id_flag, + user_id, + app_id, + pipeline_id, + pipeline_version_id, +): + """Cancel a pipeline version run. + + Cancelling is allowed when the pipeline run is not already in a terminal state. + + Examples: + + # Using positional argument + clarifai pr cancel + + # Using flag + clarifai pipelinerun cancel --pipeline_version_run_id= + + # With explicit parameters + clarifai pr cancel \\ + --user_id=USER_ID \\ + --app_id=APP_ID \\ + --pipeline_id=PIPELINE_ID \\ + --pipeline_version_id=VERSION_ID + """ + from clarifai_grpc.grpc.api.status import status_code_pb2 + + validate_context(ctx) + + # Resolve pipeline_version_run_id from positional or flag + run_id = pipeline_version_run_id or pipeline_version_run_id_flag + if not run_id: + raise click.UsageError( + "pipeline_version_run_id is required. " + "Provide it as a positional argument or use --pipeline_version_run_id flag." + ) + + # Load parameters from config-lock.yaml if not provided + user_id, app_id, pipeline_id, pipeline_version_id = _load_pipeline_params_from_config( + user_id, app_id, pipeline_id, pipeline_version_id + ) + + # Validate required parameters + _validate_pipeline_params(user_id, app_id, pipeline_id, pipeline_version_id) + + # Create Pipeline object + pipeline = _create_pipeline(ctx, user_id, app_id, pipeline_id, pipeline_version_id) + + # Patch the pipeline version run to JOB_CANCELLED + try: + result = pipeline.patch_pipeline_version_run( + pipeline_version_run_id=run_id, + orchestration_status_code=status_code_pb2.JOB_CANCELLED, + ) + logger.info(f"Successfully cancelled pipeline version run {run_id}") + click.echo(f"Pipeline version run {run_id} has been cancelled.") + except Exception as e: + logger.error(f"Failed to cancel pipeline version run: {e}") + raise click.ClickException(str(e)) + + +@pipelinerun.command() +@click.argument('pipeline_version_run_id', required=False) +@click.option( + '--pipeline_version_run_id', + 'pipeline_version_run_id_flag', + required=False, + help='Pipeline Version Run ID to resume.', +) +@click.option('--user_id', required=False, help='User ID that owns the pipeline.') +@click.option('--app_id', required=False, help='App ID that contains the pipeline.') +@click.option('--pipeline_id', required=False, help='Pipeline ID.') +@click.option('--pipeline_version_id', required=False, help='Pipeline Version ID.') +@click.pass_context +def resume( + ctx, + pipeline_version_run_id, + pipeline_version_run_id_flag, + user_id, + app_id, + pipeline_id, + pipeline_version_id, +): + """Resume a paused pipeline version run. + + Resuming is allowed only when the pipeline run is in Paused state. + + Examples: + + # Using positional argument + clarifai pr resume + + # Using flag + clarifai pipelinerun resume --pipeline_version_run_id= + + # With explicit parameters + clarifai pr resume \\ + --user_id=USER_ID \\ + --app_id=APP_ID \\ + --pipeline_id=PIPELINE_ID \\ + --pipeline_version_id=VERSION_ID + """ + from clarifai_grpc.grpc.api.status import status_code_pb2 + + validate_context(ctx) + + # Resolve pipeline_version_run_id from positional or flag + run_id = pipeline_version_run_id or pipeline_version_run_id_flag + if not run_id: + raise click.UsageError( + "pipeline_version_run_id is required. " + "Provide it as a positional argument or use --pipeline_version_run_id flag." + ) + + # Load parameters from config-lock.yaml if not provided + user_id, app_id, pipeline_id, pipeline_version_id = _load_pipeline_params_from_config( + user_id, app_id, pipeline_id, pipeline_version_id + ) + + # Validate required parameters + _validate_pipeline_params(user_id, app_id, pipeline_id, pipeline_version_id) + + # Create Pipeline object + pipeline = _create_pipeline(ctx, user_id, app_id, pipeline_id, pipeline_version_id) + + # Patch the pipeline version run to JOB_RUNNING to resume + try: + result = pipeline.patch_pipeline_version_run( + pipeline_version_run_id=run_id, + orchestration_status_code=status_code_pb2.JOB_RUNNING, + ) + logger.info(f"Successfully resumed pipeline version run {run_id}") + click.echo(f"Pipeline version run {run_id} has been resumed.") + except Exception as e: + logger.error(f"Failed to resume pipeline version run: {e}") + raise click.ClickException(str(e)) + + +@pipelinerun.command() +@click.argument('pipeline_version_run_id', required=False) +@click.option( + '--pipeline_version_run_id', + 'pipeline_version_run_id_flag', + required=False, + help='Pipeline Version Run ID to monitor.', +) +@click.option('--user_id', required=False, help='User ID that owns the pipeline.') +@click.option('--app_id', required=False, help='App ID that contains the pipeline.') +@click.option('--pipeline_id', required=False, help='Pipeline ID.') +@click.option('--pipeline_version_id', required=False, help='Pipeline Version ID.') +@click.option( + '--timeout', + type=int, + default=3600, + help='Maximum time to wait for completion in seconds. Default 3600 (1 hour).', +) +@click.option( + '--monitor_interval', + type=int, + default=10, + help='Interval between status checks in seconds. Default 10.', +) +@click.option( + '--log_file', + type=click.Path(), + required=False, + help='Path to file where logs should be written. If not provided, logs are displayed on console.', +) +@click.pass_context +def monitor( + ctx, + pipeline_version_run_id, + pipeline_version_run_id_flag, + user_id, + app_id, + pipeline_id, + pipeline_version_id, + timeout, + monitor_interval, + log_file, +): + """Monitor an existing pipeline version run. + + Monitor the current status and logs of a running pipeline. + + Examples: + + # Using positional argument + clarifai pr monitor + + # Using flag + clarifai pipelinerun monitor --pipeline_version_run_id= + + # With explicit parameters + clarifai pr monitor \\ + --user_id=USER_ID \\ + --app_id=APP_ID \\ + --pipeline_id=PIPELINE_ID \\ + --pipeline_version_id=VERSION_ID + + # With custom timeout and interval + clarifai pr monitor \\ + --timeout=7200 \\ + --monitor_interval=5 + """ + import json + + validate_context(ctx) + + # Resolve pipeline_version_run_id from positional or flag + run_id = pipeline_version_run_id or pipeline_version_run_id_flag + if not run_id: + raise click.UsageError( + "pipeline_version_run_id is required. " + "Provide it as a positional argument or use --pipeline_version_run_id flag." + ) + + # Load parameters from config-lock.yaml if not provided + user_id, app_id, pipeline_id, pipeline_version_id = _load_pipeline_params_from_config( + user_id, app_id, pipeline_id, pipeline_version_id + ) + + # Validate required parameters + _validate_pipeline_params(user_id, app_id, pipeline_id, pipeline_version_id) + + # Create Pipeline object + pipeline = _create_pipeline(ctx, user_id, app_id, pipeline_id, pipeline_version_id) + + # Set the pipeline_version_run_id for monitoring + pipeline.pipeline_version_run_id = run_id + + # Set log file if provided + if log_file: + pipeline.log_file = log_file + + # Monitor the pipeline run + try: + result = pipeline.monitor_only(timeout=timeout, monitor_interval=monitor_interval) + click.echo(json.dumps(result, indent=2, default=str)) + except Exception as e: + logger.error(f"Failed to monitor pipeline version run: {e}") + raise click.ClickException(str(e)) diff --git a/clarifai/client/pipeline.py b/clarifai/client/pipeline.py index ef028277..e6385023 100644 --- a/clarifai/client/pipeline.py +++ b/clarifai/client/pipeline.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional from clarifai_grpc.grpc.api import resources_pb2, service_pb2 -from clarifai_grpc.grpc.api.status import status_code_pb2 +from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2 from google.protobuf import json_format from clarifai.client.base import BaseClient @@ -342,3 +342,61 @@ def _display_new_logs(self, run_id: str, seen_logs: set, current_page: int = 1) logger.debug(f"Error fetching logs: {e}") # Return current page on error to retry the same page next fetch return current_page + + def patch_pipeline_version_run( + self, + pipeline_version_run_id: str, + orchestration_status_code: int, + ) -> Dict: + """Patch a pipeline version run's orchestration status. + + This method can be used to pause, cancel, or resume a pipeline run. + + Args: + pipeline_version_run_id (str): The pipeline version run ID to patch. + orchestration_status_code (int): The status code to set (e.g., JOB_PAUSED, JOB_CANCELLED, JOB_RUNNING). + + Returns: + Dict: The response as a dictionary. + + Raises: + UserError: If the patch request fails. + """ + # Create the orchestration status + orchestration_status = resources_pb2.OrchestrationStatus( + status=status_pb2.Status(code=orchestration_status_code) + ) + + # Create the pipeline version run with only the ID and status + pipeline_version_run = resources_pb2.PipelineVersionRun( + id=pipeline_version_run_id, orchestration_status=orchestration_status + ) + + # Create the patch request + patch_request = service_pb2.PatchPipelineVersionRunsRequest() + patch_request.user_app_id.CopyFrom(self.user_app_id) + patch_request.pipeline_id = self.pipeline_id + patch_request.pipeline_version_id = self.pipeline_version_id or "" + patch_request.pipeline_version_runs.append(pipeline_version_run) + + # Make the API call + response = self.STUB.PatchPipelineVersionRuns( + patch_request, metadata=self.auth_helper.metadata + ) + + # Check for errors + if response.status.code != status_code_pb2.StatusCode.SUCCESS: + raise UserError( + f"Failed to patch pipeline version run: {response.status.description}. " + f"Details: {response.status.details}. " + f"Code: {status_code_pb2.StatusCode.Name(response.status.code)}." + ) + + logger.info( + f"Successfully patched pipeline version run {pipeline_version_run_id} " + f"to status code {orchestration_status_code} " + f"(user_id: {self.user_app_id.user_id}, app_id: {self.user_app_id.app_id}, " + f"pipeline_id: {self.pipeline_id}, pipeline_version_id: {self.pipeline_version_id})" + ) + + return json_format.MessageToDict(response, preserving_proto_field_name=True) diff --git a/requirements.txt b/requirements.txt index 6af33754..9749937f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -clarifai-grpc>=11.11.1 +clarifai-grpc>=11.11.3 clarifai-protocol>=0.0.33 numpy>=1.22.0 tqdm>=4.65.0 diff --git a/tests/cli/test_pipeline_run.py b/tests/cli/test_pipeline_run.py new file mode 100644 index 00000000..aa3591f8 --- /dev/null +++ b/tests/cli/test_pipeline_run.py @@ -0,0 +1,405 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from clarifai_grpc.grpc.api.status import status_code_pb2 +from click.testing import CliRunner + +from clarifai.cli.pipeline_run import cancel, pause, resume + + +@pytest.fixture +def runner(): + """Create a Click CLI runner for testing.""" + return CliRunner() + + +@pytest.fixture +def mock_context(): + """Create a mock context object for CLI commands.""" + ctx = Mock() + ctx.obj = Mock() + ctx.obj.current = Mock() + ctx.obj.current.pat = "test-pat" + ctx.obj.current.api_base = "https://api.clarifai.com" + ctx.obj.current.user_id = "test-user" + return ctx + + +@pytest.fixture +def config_lock_data(): + """Sample config-lock.yaml data.""" + return { + 'pipeline': { + 'id': 'test-pipeline', + 'user_id': 'test-user', + 'app_id': 'test-app', + 'version_id': 'v1', + } + } + + +class TestPipelineRunPause: + """Test cases for pause command.""" + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.client.pipeline.Pipeline') + def test_pause_with_positional_arg( + self, mock_pipeline_class, mock_validate, runner, mock_context + ): + """Test pause command with positional argument.""" + # Setup mock + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.patch_pipeline_version_run.return_value = {'status': 'success'} + + # Run command + result = runner.invoke( + pause, + [ + 'test-run-id', + '--user_id=test-user', + '--app_id=test-app', + '--pipeline_id=test-pipeline', + '--pipeline_version_id=v1', + ], + obj=mock_context.obj, + ) + + # Assertions + assert result.exit_code == 0 + assert 'has been paused' in result.output + mock_pipeline.patch_pipeline_version_run.assert_called_once_with( + pipeline_version_run_id='test-run-id', + orchestration_status_code=status_code_pb2.JOB_PAUSED, + ) + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.client.pipeline.Pipeline') + def test_pause_with_flag(self, mock_pipeline_class, mock_validate, runner, mock_context): + """Test pause command with --pipeline_version_run_id flag.""" + # Setup mock + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.patch_pipeline_version_run.return_value = {'status': 'success'} + + # Run command + result = runner.invoke( + pause, + [ + '--pipeline_version_run_id=test-run-id', + '--user_id=test-user', + '--app_id=test-app', + '--pipeline_id=test-pipeline', + '--pipeline_version_id=v1', + ], + obj=mock_context.obj, + ) + + # Assertions + assert result.exit_code == 0 + assert 'has been paused' in result.output + mock_pipeline.patch_pipeline_version_run.assert_called_once() + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.cli.pipeline_run.from_yaml') + @patch('os.path.exists') + @patch('clarifai.client.pipeline.Pipeline') + def test_pause_with_config_lock( + self, + mock_pipeline_class, + mock_exists, + mock_from_yaml, + mock_validate, + runner, + mock_context, + config_lock_data, + ): + """Test pause command loading parameters from config-lock.yaml.""" + # Setup mocks + mock_exists.return_value = True + mock_from_yaml.return_value = config_lock_data + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.patch_pipeline_version_run.return_value = {'status': 'success'} + + # Run command + result = runner.invoke(pause, ['test-run-id'], obj=mock_context.obj) + + # Assertions + assert result.exit_code == 0 + mock_pipeline_class.assert_called_once_with( + pipeline_id='test-pipeline', + pipeline_version_id='v1', + user_id='test-user', + app_id='test-app', + pat='test-pat', + base_url='https://api.clarifai.com', + ) + + @patch('clarifai.cli.pipeline_run.validate_context') + def test_pause_without_run_id(self, mock_validate, runner, mock_context): + """Test pause command fails without pipeline_version_run_id.""" + result = runner.invoke( + pause, ['--user_id=test-user', '--app_id=test-app'], obj=mock_context.obj + ) + + assert result.exit_code != 0 + assert 'pipeline_version_run_id is required' in result.output + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('os.path.exists') + def test_pause_without_required_params(self, mock_exists, mock_validate, runner, mock_context): + """Test pause command fails without required parameters and no config-lock.yaml.""" + mock_exists.return_value = False + + result = runner.invoke(pause, ['test-run-id'], obj=mock_context.obj) + + assert result.exit_code != 0 + assert 'Missing required parameters' in result.output + + +class TestPipelineRunCancel: + """Test cases for cancel command.""" + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.client.pipeline.Pipeline') + def test_cancel_with_positional_arg( + self, mock_pipeline_class, mock_validate, runner, mock_context + ): + """Test cancel command with positional argument.""" + # Setup mock + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.patch_pipeline_version_run.return_value = {'status': 'success'} + + # Run command + result = runner.invoke( + cancel, + [ + 'test-run-id', + '--user_id=test-user', + '--app_id=test-app', + '--pipeline_id=test-pipeline', + '--pipeline_version_id=v1', + ], + obj=mock_context.obj, + ) + + # Assertions + assert result.exit_code == 0 + assert 'has been cancelled' in result.output + mock_pipeline.patch_pipeline_version_run.assert_called_once_with( + pipeline_version_run_id='test-run-id', + orchestration_status_code=status_code_pb2.JOB_CANCELLED, + ) + + +class TestPipelineRunResume: + """Test cases for resume command.""" + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.client.pipeline.Pipeline') + def test_resume_with_positional_arg( + self, mock_pipeline_class, mock_validate, runner, mock_context + ): + """Test resume command with positional argument.""" + # Setup mock + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.patch_pipeline_version_run.return_value = {'status': 'success'} + + # Run command + result = runner.invoke( + resume, + [ + 'test-run-id', + '--user_id=test-user', + '--app_id=test-app', + '--pipeline_id=test-pipeline', + '--pipeline_version_id=v1', + ], + obj=mock_context.obj, + ) + + # Assertions + assert result.exit_code == 0 + assert 'has been resumed' in result.output + mock_pipeline.patch_pipeline_version_run.assert_called_once_with( + pipeline_version_run_id='test-run-id', + orchestration_status_code=status_code_pb2.JOB_RUNNING, + ) + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.cli.pipeline_run.from_yaml') + @patch('os.path.exists') + @patch('clarifai.client.pipeline.Pipeline') + def test_resume_with_config_lock( + self, + mock_pipeline_class, + mock_exists, + mock_from_yaml, + mock_validate, + runner, + mock_context, + config_lock_data, + ): + """Test resume command loading parameters from config-lock.yaml.""" + # Setup mocks + mock_exists.return_value = True + mock_from_yaml.return_value = config_lock_data + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.patch_pipeline_version_run.return_value = {'status': 'success'} + + # Run command + result = runner.invoke(resume, ['test-run-id'], obj=mock_context.obj) + + # Assertions + assert result.exit_code == 0 + assert 'has been resumed' in result.output + + +class TestPipelineRunMonitor: + """Test cases for monitor command.""" + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.client.pipeline.Pipeline') + def test_monitor_with_positional_arg( + self, mock_pipeline_class, mock_validate, runner, mock_context + ): + """Test monitor command with positional argument.""" + # Setup mock + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.monitor_only.return_value = {'status': 'success', 'run_id': 'test-run-id'} + + from clarifai.cli.pipeline_run import monitor + + # Run command + result = runner.invoke( + monitor, + [ + 'test-run-id', + '--user_id=test-user', + '--app_id=test-app', + '--pipeline_id=test-pipeline', + '--pipeline_version_id=v1', + ], + obj=mock_context.obj, + ) + + # Assertions + assert result.exit_code == 0 + mock_pipeline.monitor_only.assert_called_once_with(timeout=3600, monitor_interval=10) + # Check that pipeline_version_run_id was set + assert mock_pipeline.pipeline_version_run_id == 'test-run-id' + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.client.pipeline.Pipeline') + def test_monitor_with_flag(self, mock_pipeline_class, mock_validate, runner, mock_context): + """Test monitor command with --pipeline_version_run_id flag.""" + # Setup mock + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.monitor_only.return_value = {'status': 'success'} + + from clarifai.cli.pipeline_run import monitor + + # Run command + result = runner.invoke( + monitor, + [ + '--pipeline_version_run_id=test-run-id', + '--user_id=test-user', + '--app_id=test-app', + '--pipeline_id=test-pipeline', + '--pipeline_version_id=v1', + ], + obj=mock_context.obj, + ) + + # Assertions + assert result.exit_code == 0 + mock_pipeline.monitor_only.assert_called_once() + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.cli.pipeline_run.from_yaml') + @patch('os.path.exists') + @patch('clarifai.client.pipeline.Pipeline') + def test_monitor_with_config_lock( + self, + mock_pipeline_class, + mock_exists, + mock_from_yaml, + mock_validate, + runner, + mock_context, + config_lock_data, + ): + """Test monitor command loading parameters from config-lock.yaml.""" + # Setup mocks + mock_exists.return_value = True + mock_from_yaml.return_value = config_lock_data + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.monitor_only.return_value = {'status': 'success'} + + from clarifai.cli.pipeline_run import monitor + + # Run command + result = runner.invoke(monitor, ['test-run-id'], obj=mock_context.obj) + + # Assertions + assert result.exit_code == 0 + mock_pipeline_class.assert_called_once_with( + pipeline_id='test-pipeline', + pipeline_version_id='v1', + user_id='test-user', + app_id='test-app', + pat='test-pat', + base_url='https://api.clarifai.com', + ) + + @patch('clarifai.cli.pipeline_run.validate_context') + @patch('clarifai.client.pipeline.Pipeline') + def test_monitor_with_custom_timeout( + self, mock_pipeline_class, mock_validate, runner, mock_context + ): + """Test monitor command with custom timeout and interval.""" + # Setup mock + mock_pipeline = MagicMock() + mock_pipeline_class.return_value = mock_pipeline + mock_pipeline.monitor_only.return_value = {'status': 'success'} + + from clarifai.cli.pipeline_run import monitor + + # Run command + result = runner.invoke( + monitor, + [ + 'test-run-id', + '--user_id=test-user', + '--app_id=test-app', + '--pipeline_id=test-pipeline', + '--pipeline_version_id=v1', + '--timeout=7200', + '--monitor_interval=5', + ], + obj=mock_context.obj, + ) + + # Assertions + assert result.exit_code == 0 + mock_pipeline.monitor_only.assert_called_once_with(timeout=7200, monitor_interval=5) + + @patch('clarifai.cli.pipeline_run.validate_context') + def test_monitor_without_run_id(self, mock_validate, runner, mock_context): + """Test monitor command fails without pipeline_version_run_id.""" + from clarifai.cli.pipeline_run import monitor + + result = runner.invoke( + monitor, ['--user_id=test-user', '--app_id=test-app'], obj=mock_context.obj + ) + + assert result.exit_code != 0 + assert 'pipeline_version_run_id is required' in result.output