Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/galileo/__future__/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def generate(
count: int = 10,
data_types: list[str] | None = None, # type: ignore[valid-type]
prompt_settings: dict[str, Any] | None = None,
timeout_seconds: int = 300,
) -> list[DatasetRow]: # type: ignore[valid-type]
"""
Generate synthetic dataset rows.
Expand All @@ -282,11 +283,17 @@ def generate(
count (int): The number of synthetic examples to generate.
data_types (Optional[list[str]]): The types of data to generate.
prompt_settings (Optional[dict[str, Any]]): Settings for the prompt generation.
timeout_seconds (int): Maximum seconds to wait for generation to complete. Defaults to 300.

Returns
-------
list[DatasetRow]: A list of generated dataset rows.

Raises
------
DatasetAPIException: If generation fails, stalls with no progress for 30 seconds,
or does not complete within timeout_seconds.

Examples
--------
rows = Dataset.generate(
Expand All @@ -304,6 +311,7 @@ def generate(
count=count,
data_types=data_types,
prompt_settings=prompt_settings,
timeout_seconds=timeout_seconds,
)

def get_content(self) -> DatasetContent | None:
Expand Down
62 changes: 41 additions & 21 deletions src/galileo/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def extend(
examples: Optional[builtins.list[str]] = None,
data_types: Optional[builtins.list[str]] = None,
count: int = 10,
timeout_seconds: int = 300,
) -> builtins.list[DatasetRow]:
"""
Extends a dataset with synthetically generated data based on the provided parameters.
Expand All @@ -529,6 +530,8 @@ def extend(
'Sexist Content in Query'.
count : int, default 10
The number of synthetic examples to generate.
timeout_seconds : int, default 300
Maximum number of seconds to wait for the job to complete before raising an exception.

Returns
-------
Expand All @@ -538,9 +541,8 @@ def extend(
Raises
------
DatasetAPIException
If the request to extend the dataset fails.
errors.UnexpectedStatus
If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
If the request to extend the dataset fails, the job stalls with no progress for 30 seconds,
or the job does not complete within timeout_seconds.
httpx.TimeoutException
If the request takes longer than Client.timeout.
"""
Expand Down Expand Up @@ -582,8 +584,14 @@ def extend(

dataset_id = response.dataset_id

# Poll for job completion
while True:
# Poll for job completion.
# TODO: Replace stall detection with a proper status/error field once the API exposes one in JobProgress.
MAX_STALL_SECONDS = 30
elapsed = 0
last_steps_completed = None
stall_count = 0

while elapsed < timeout_seconds:
job_progress = get_dataset_synthetic_extend_status_datasets_extend_dataset_id_get.sync(
dataset_id=dataset_id, client=self.config.api_client
)
Expand All @@ -594,25 +602,34 @@ def extend(
if not job_progress or not isinstance(job_progress, JobProgress):
raise DatasetAPIException("Invalid job progress response.")

# Check if job is complete
steps_done = job_progress.steps_completed
steps_total = job_progress.steps_total

if job_progress.progress_message:
logger.info(f"({steps_done}/{steps_total}) {job_progress.progress_message}")

if (
job_progress.steps_completed is not None
and job_progress.steps_total is not None
and job_progress.steps_completed == job_progress.steps_total
isinstance(steps_done, int)
and isinstance(steps_total, int)
and steps_total > 0
and steps_done == steps_total
):
logger.info(
f"({job_progress.steps_completed}/{job_progress.steps_total}) {job_progress.progress_message}"
)
break

# Log progress message if available
if job_progress.progress_message:
logger.info(
f"({job_progress.steps_completed}/{job_progress.steps_total}) {job_progress.progress_message}"
)
if steps_done == last_steps_completed:
stall_count += 1
if stall_count >= MAX_STALL_SECONDS:
raise DatasetAPIException(
f"Dataset extension job stalled: no progress for {MAX_STALL_SECONDS} seconds."
)
else:
stall_count = 0
last_steps_completed = steps_done

# Wait 1 second before polling again
time.sleep(1)
elapsed += 1
else:
raise DatasetAPIException(f"Dataset extension job timed out after {timeout_seconds} seconds.")

# Get the final dataset content
dataset_content = get_dataset_content_datasets_dataset_id_content_get.sync(
Expand Down Expand Up @@ -907,6 +924,7 @@ def extend_dataset(
examples: Optional[list[str]] = None,
data_types: Optional[list[str]] = None,
count: int = 10,
timeout_seconds: int = 300,
) -> list[DatasetRow]:
"""
Extends a dataset with synthetically generated data based on the provided parameters.
Expand All @@ -932,6 +950,8 @@ def extend_dataset(
'Sexist Content in Query'.
count : int, default 10
The number of synthetic examples to generate.
timeout_seconds : int, default 300
Maximum number of seconds to wait for the job to complete before raising an exception.

Returns
-------
Expand All @@ -941,9 +961,8 @@ def extend_dataset(
Raises
------
DatasetAPIException
If the request to extend the dataset fails.
errors.UnexpectedStatus
If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
If the request to extend the dataset fails, the job stalls with no progress for 30 seconds,
or the job does not complete within timeout_seconds.
httpx.TimeoutException
If the request takes longer than Client.timeout.

Expand All @@ -955,6 +974,7 @@ def extend_dataset(
examples=examples,
data_types=data_types,
count=count,
timeout_seconds=timeout_seconds,
)


Expand Down
63 changes: 63 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,69 @@ def test_extend_dataset_api_failure(extend_dataset_mock: Mock) -> None:
extend_dataset(prompt_settings={"model_alias": "GPT-4o mini"}, prompt="Test prompt", count=1)


@patch("galileo.datasets.get_dataset_synthetic_extend_status_datasets_extend_dataset_id_get")
@patch("galileo.datasets.extend_dataset_content_datasets_extend_post")
@patch("galileo.datasets.time.sleep")
def test_extend_dataset_zero_steps_total_on_first_poll(
sleep_mock: Mock, extend_dataset_mock: Mock, get_extend_status_mock: Mock
) -> None:
# Given: the first poll returns steps_total=0 (job not yet started), then completes normally
extended_dataset_id = "a8b3d8e0-5e0b-4b0f-8b3a-3b9f4b3d3b3b"
extend_dataset_mock.sync.return_value = SyntheticDatasetExtensionResponse(dataset_id=extended_dataset_id)
get_extend_status_mock.sync.side_effect = [
JobProgress(steps_completed=0, steps_total=0, progress_message="Queued"),
JobProgress(steps_completed=1, steps_total=2, progress_message="Processing"),
JobProgress(steps_completed=2, steps_total=2, progress_message="Done"),
]

with patch("galileo.datasets.get_dataset_content_datasets_dataset_id_content_get") as content_mock:
row = DatasetRow(index=0, row_id="row-1", values=["val"], values_dict={"col": "val"}, metadata=None)
content_mock.sync.return_value = DatasetContent(column_names=["col"], rows=[row])

# When: extending the dataset
result = extend_dataset(prompt="Test", count=2)

# Then: all three polls are made and the rows are returned (not short-circuited on the first poll)
assert result == [row]
assert get_extend_status_mock.sync.call_count == 3


@patch("galileo.datasets.get_dataset_synthetic_extend_status_datasets_extend_dataset_id_get")
@patch("galileo.datasets.extend_dataset_content_datasets_extend_post")
@patch("galileo.datasets.time.sleep")
def test_extend_dataset_timeout(sleep_mock: Mock, extend_dataset_mock: Mock, get_extend_status_mock: Mock) -> None:
# Given: the job never completes within the timeout
extended_dataset_id = "a8b3d8e0-5e0b-4b0f-8b3a-3b9f4b3d3b3c"
extend_dataset_mock.sync.return_value = SyntheticDatasetExtensionResponse(dataset_id=extended_dataset_id)
get_extend_status_mock.sync.return_value = JobProgress(
steps_completed=1, steps_total=10, progress_message="Processing"
)

# When/Then: a timeout exception is raised after timeout_seconds
with pytest.raises(DatasetAPIException, match="timed out after 3 seconds"):
extend_dataset(prompt="Test", count=5, timeout_seconds=3)

assert get_extend_status_mock.sync.call_count == 3


@patch("galileo.datasets.get_dataset_synthetic_extend_status_datasets_extend_dataset_id_get")
@patch("galileo.datasets.extend_dataset_content_datasets_extend_post")
@patch("galileo.datasets.time.sleep")
def test_extend_dataset_stall_detection(
sleep_mock: Mock, extend_dataset_mock: Mock, get_extend_status_mock: Mock
) -> None:
# Given: the job starts but steps_completed stops advancing
extended_dataset_id = "a8b3d8e0-5e0b-4b0f-8b3a-3b9f4b3d3b3d"
extend_dataset_mock.sync.return_value = SyntheticDatasetExtensionResponse(dataset_id=extended_dataset_id)
get_extend_status_mock.sync.return_value = JobProgress(steps_completed=2, steps_total=10, progress_message="Stuck")

# When/Then: a stall exception is raised after 30 polls with no progress
with pytest.raises(DatasetAPIException, match="stalled"):
extend_dataset(prompt="Test", count=5, timeout_seconds=300)

assert get_extend_status_mock.sync.call_count == 31


# ===================================================================
# Project Association Tests for Dataset CRUD Operations
# ===================================================================
Expand Down
Loading