diff --git a/src/galileo/__future__/dataset.py b/src/galileo/__future__/dataset.py index 8398f398..4a6589fb 100644 --- a/src/galileo/__future__/dataset.py +++ b/src/galileo/__future__/dataset.py @@ -271,6 +271,7 @@ def generate( count: int = 10, data_types: list[str] | None = None, # type: ignore[valid-type] prompt_settings: dict[str, Any] | None = None, + timeout_seconds: int = 300, ) -> list[DatasetRow]: # type: ignore[valid-type] """ Generate synthetic dataset rows. @@ -282,11 +283,17 @@ def generate( count (int): The number of synthetic examples to generate. data_types (Optional[list[str]]): The types of data to generate. prompt_settings (Optional[dict[str, Any]]): Settings for the prompt generation. + timeout_seconds (int): Maximum seconds to wait for generation to complete. Defaults to 300. Returns ------- list[DatasetRow]: A list of generated dataset rows. + Raises + ------ + DatasetAPIException: If generation fails, stalls with no progress for 30 seconds, + or does not complete within timeout_seconds. + Examples -------- rows = Dataset.generate( @@ -304,6 +311,7 @@ def generate( count=count, data_types=data_types, prompt_settings=prompt_settings, + timeout_seconds=timeout_seconds, ) def get_content(self) -> DatasetContent | None: diff --git a/src/galileo/datasets.py b/src/galileo/datasets.py index f058e7e1..5d31dcd1 100644 --- a/src/galileo/datasets.py +++ b/src/galileo/datasets.py @@ -504,6 +504,7 @@ def extend( examples: Optional[builtins.list[str]] = None, data_types: Optional[builtins.list[str]] = None, count: int = 10, + timeout_seconds: int = 300, ) -> builtins.list[DatasetRow]: """ Extends a dataset with synthetically generated data based on the provided parameters. @@ -529,6 +530,8 @@ def extend( 'Sexist Content in Query'. count : int, default 10 The number of synthetic examples to generate. + timeout_seconds : int, default 300 + Maximum number of seconds to wait for the job to complete before raising an exception. Returns ------- @@ -538,9 +541,8 @@ def extend( Raises ------ DatasetAPIException - If the request to extend the dataset fails. - errors.UnexpectedStatus - If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + If the request to extend the dataset fails, the job stalls with no progress for 30 seconds, + or the job does not complete within timeout_seconds. httpx.TimeoutException If the request takes longer than Client.timeout. """ @@ -582,8 +584,14 @@ def extend( dataset_id = response.dataset_id - # Poll for job completion - while True: + # Poll for job completion. + # TODO: Replace stall detection with a proper status/error field once the API exposes one in JobProgress. + MAX_STALL_SECONDS = 30 + elapsed = 0 + last_steps_completed = None + stall_count = 0 + + while elapsed < timeout_seconds: job_progress = get_dataset_synthetic_extend_status_datasets_extend_dataset_id_get.sync( dataset_id=dataset_id, client=self.config.api_client ) @@ -594,25 +602,34 @@ def extend( if not job_progress or not isinstance(job_progress, JobProgress): raise DatasetAPIException("Invalid job progress response.") - # Check if job is complete + steps_done = job_progress.steps_completed + steps_total = job_progress.steps_total + + if job_progress.progress_message: + logger.info(f"({steps_done}/{steps_total}) {job_progress.progress_message}") + if ( - job_progress.steps_completed is not None - and job_progress.steps_total is not None - and job_progress.steps_completed == job_progress.steps_total + isinstance(steps_done, int) + and isinstance(steps_total, int) + and steps_total > 0 + and steps_done == steps_total ): - logger.info( - f"({job_progress.steps_completed}/{job_progress.steps_total}) {job_progress.progress_message}" - ) break - # Log progress message if available - if job_progress.progress_message: - logger.info( - f"({job_progress.steps_completed}/{job_progress.steps_total}) {job_progress.progress_message}" - ) + if steps_done == last_steps_completed: + stall_count += 1 + if stall_count >= MAX_STALL_SECONDS: + raise DatasetAPIException( + f"Dataset extension job stalled: no progress for {MAX_STALL_SECONDS} seconds." + ) + else: + stall_count = 0 + last_steps_completed = steps_done - # Wait 1 second before polling again time.sleep(1) + elapsed += 1 + else: + raise DatasetAPIException(f"Dataset extension job timed out after {timeout_seconds} seconds.") # Get the final dataset content dataset_content = get_dataset_content_datasets_dataset_id_content_get.sync( @@ -907,6 +924,7 @@ def extend_dataset( examples: Optional[list[str]] = None, data_types: Optional[list[str]] = None, count: int = 10, + timeout_seconds: int = 300, ) -> list[DatasetRow]: """ Extends a dataset with synthetically generated data based on the provided parameters. @@ -932,6 +950,8 @@ def extend_dataset( 'Sexist Content in Query'. count : int, default 10 The number of synthetic examples to generate. + timeout_seconds : int, default 300 + Maximum number of seconds to wait for the job to complete before raising an exception. Returns ------- @@ -941,9 +961,8 @@ def extend_dataset( Raises ------ DatasetAPIException - If the request to extend the dataset fails. - errors.UnexpectedStatus - If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + If the request to extend the dataset fails, the job stalls with no progress for 30 seconds, + or the job does not complete within timeout_seconds. httpx.TimeoutException If the request takes longer than Client.timeout. @@ -955,6 +974,7 @@ def extend_dataset( examples=examples, data_types=data_types, count=count, + timeout_seconds=timeout_seconds, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 56ee782c..6674a65c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -588,6 +588,69 @@ def test_extend_dataset_api_failure(extend_dataset_mock: Mock) -> None: extend_dataset(prompt_settings={"model_alias": "GPT-4o mini"}, prompt="Test prompt", count=1) +@patch("galileo.datasets.get_dataset_synthetic_extend_status_datasets_extend_dataset_id_get") +@patch("galileo.datasets.extend_dataset_content_datasets_extend_post") +@patch("galileo.datasets.time.sleep") +def test_extend_dataset_zero_steps_total_on_first_poll( + sleep_mock: Mock, extend_dataset_mock: Mock, get_extend_status_mock: Mock +) -> None: + # Given: the first poll returns steps_total=0 (job not yet started), then completes normally + extended_dataset_id = "a8b3d8e0-5e0b-4b0f-8b3a-3b9f4b3d3b3b" + extend_dataset_mock.sync.return_value = SyntheticDatasetExtensionResponse(dataset_id=extended_dataset_id) + get_extend_status_mock.sync.side_effect = [ + JobProgress(steps_completed=0, steps_total=0, progress_message="Queued"), + JobProgress(steps_completed=1, steps_total=2, progress_message="Processing"), + JobProgress(steps_completed=2, steps_total=2, progress_message="Done"), + ] + + with patch("galileo.datasets.get_dataset_content_datasets_dataset_id_content_get") as content_mock: + row = DatasetRow(index=0, row_id="row-1", values=["val"], values_dict={"col": "val"}, metadata=None) + content_mock.sync.return_value = DatasetContent(column_names=["col"], rows=[row]) + + # When: extending the dataset + result = extend_dataset(prompt="Test", count=2) + + # Then: all three polls are made and the rows are returned (not short-circuited on the first poll) + assert result == [row] + assert get_extend_status_mock.sync.call_count == 3 + + +@patch("galileo.datasets.get_dataset_synthetic_extend_status_datasets_extend_dataset_id_get") +@patch("galileo.datasets.extend_dataset_content_datasets_extend_post") +@patch("galileo.datasets.time.sleep") +def test_extend_dataset_timeout(sleep_mock: Mock, extend_dataset_mock: Mock, get_extend_status_mock: Mock) -> None: + # Given: the job never completes within the timeout + extended_dataset_id = "a8b3d8e0-5e0b-4b0f-8b3a-3b9f4b3d3b3c" + extend_dataset_mock.sync.return_value = SyntheticDatasetExtensionResponse(dataset_id=extended_dataset_id) + get_extend_status_mock.sync.return_value = JobProgress( + steps_completed=1, steps_total=10, progress_message="Processing" + ) + + # When/Then: a timeout exception is raised after timeout_seconds + with pytest.raises(DatasetAPIException, match="timed out after 3 seconds"): + extend_dataset(prompt="Test", count=5, timeout_seconds=3) + + assert get_extend_status_mock.sync.call_count == 3 + + +@patch("galileo.datasets.get_dataset_synthetic_extend_status_datasets_extend_dataset_id_get") +@patch("galileo.datasets.extend_dataset_content_datasets_extend_post") +@patch("galileo.datasets.time.sleep") +def test_extend_dataset_stall_detection( + sleep_mock: Mock, extend_dataset_mock: Mock, get_extend_status_mock: Mock +) -> None: + # Given: the job starts but steps_completed stops advancing + extended_dataset_id = "a8b3d8e0-5e0b-4b0f-8b3a-3b9f4b3d3b3d" + extend_dataset_mock.sync.return_value = SyntheticDatasetExtensionResponse(dataset_id=extended_dataset_id) + get_extend_status_mock.sync.return_value = JobProgress(steps_completed=2, steps_total=10, progress_message="Stuck") + + # When/Then: a stall exception is raised after 30 polls with no progress + with pytest.raises(DatasetAPIException, match="stalled"): + extend_dataset(prompt="Test", count=5, timeout_seconds=300) + + assert get_extend_status_mock.sync.call_count == 31 + + # =================================================================== # Project Association Tests for Dataset CRUD Operations # ===================================================================