DATASETS.md (6 changes: 3 additions & 3 deletions)
@@ -66,12 +66,12 @@ Out[6]:
 ## Getting the list of all the datasets
 
 If you want to obtain the list of all the available datasets you can use the
-`sdgym.get_available_datasets` function:
+`DatasetExplorer.list_datasets` method:
 
 ```python
-In [7]: from sdgym import get_available_datasets
+In [7]: from sdgym.dataset_explorer import DatasetExplorer
 
-In [8]: get_available_datasets()
+In [8]: DatasetExplorer().list_datasets('single_table')
 Out[8]:
   dataset_name   size_MB  num_tables
 0       KRK_v1  0.072128           1
README.md (4 changes: 2 additions & 2 deletions)
@@ -103,10 +103,10 @@ Learn more in the [Custom Synthesizers Guide](https://docs.sdv.dev/sdgym/customi
 ## Customizing your datasets
 
 The SDGym library includes many publicly available datasets that you can include right away.
-List these using the ``get_available_datasets`` feature.
+List these using the ``DatasetExplorer.list_datasets`` method.
 
 ```python
-sdgym.get_available_datasets()
+sdgym.dataset_explorer.DatasetExplorer().list_datasets('single_table')
 ```
 
 ```
sdgym/__init__.py (3 changes: 1 addition & 2 deletions)
@@ -16,7 +16,7 @@
 from sdgym.cli.collect import collect_results
 from sdgym.cli.summary import make_summary_spreadsheet
 from sdgym.dataset_explorer import DatasetExplorer
-from sdgym.datasets import get_available_datasets, load_dataset
+from sdgym.datasets import load_dataset
 from sdgym.synthesizers import (
     create_synthesizer_variant,
     create_single_table_synthesizer,
@@ -37,7 +37,6 @@
     'create_synthesizer_variant',
     'create_single_table_synthesizer',
     'create_multi_table_synthesizer',
-    'get_available_datasets',
     'load_dataset',
     'make_summary_spreadsheet',
 ]
sdgym/cli/__main__.py (4 changes: 2 additions & 2 deletions)
@@ -97,7 +97,7 @@ def _download_datasets(args):
     _env_setup(args.logfile, args.verbose)
     datasets = args.datasets
     if not datasets:
-        datasets = sdgym.datasets.get_available_datasets(
+        datasets = sdgym.datasets._get_available_datasets(
             args.bucket, args.aws_access_key_id, args.aws_secret_access_key
         )['name']

@@ -118,7 +118,7 @@ def _list_downloaded(args):


 def _list_available(args):
-    datasets = sdgym.datasets.get_available_datasets(
+    datasets = sdgym.datasets._get_available_datasets(
         args.bucket, args.aws_access_key_id, args.aws_secret_access_key
     )
     _print_table(datasets, args.sort, args.reverse, {'size': humanfriendly.format_size})
sdgym/dataset_explorer.py (33 changes: 33 additions & 0 deletions)
@@ -275,3 +275,36 @@ def summarize_datasets(self, modality, output_filepath=None):
             dataset_summary.to_csv(output_filepath, index=False)
 
         return dataset_summary
+
+    def list_datasets(self, modality, output_filepath=None):
+        """List available datasets for a modality using metainfo only.
+
+        This is a lightweight alternative to ``summarize_datasets`` that does not load
+        the actual data. It reads dataset information from the ``metainfo.yaml`` files
+        in the bucket and returns a table equivalent to the legacy
+        ``get_available_datasets`` output.
+
+        Args:
+            modality (str):
+                It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
+            output_filepath (str, optional):
+                Full path to a ``.csv`` file where the resulting table will be written.
+                If not provided, the table is only returned.
+
+        Returns:
+            pd.DataFrame:
+                A DataFrame with columns: ``['dataset_name', 'size_MB', 'num_tables']``.
+        """
+        self._validate_output_filepath(output_filepath)
+        _validate_modality(modality)
+
+        dataframe = _get_available_datasets(
+            modality=modality,
+            bucket=self._bucket_name,
+            aws_access_key_id=self.aws_access_key_id,
+            aws_secret_access_key=self.aws_secret_access_key,
+        )
+        if output_filepath:
+            dataframe.to_csv(output_filepath, index=False)
+
+        return dataframe
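For reference, a minimal usage sketch of the new method introduced above. It assumes only what this diff shows: `DatasetExplorer` is constructible with no arguments (defaulting to the public bucket), `modality` is required, and `output_filepath` is optional; the CSV filename is illustrative.

```python
from sdgym.dataset_explorer import DatasetExplorer

# List the single-table datasets available in the default public bucket.
explorer = DatasetExplorer()
datasets = explorer.list_datasets('single_table')
print(datasets[['dataset_name', 'size_MB', 'num_tables']].head())

# Optionally also write the table to a CSV file (illustrative filename).
explorer.list_datasets('multi_table', output_filepath='multi_table_datasets.csv')
```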
sdgym/datasets.py (15 changes: 0 additions & 15 deletions)
@@ -254,21 +254,6 @@ def load_dataset(
     return data, metadata_dict
 
 
-def get_available_datasets(modality='single_table'):
-    """Get available single_table datasets.
-
-    Args:
-        modality (str):
-            It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
-
-    Return:
-        pd.DataFrame:
-            Table of available datasets and their sizes.
-    """
-    _validate_modality(modality)
-    return _get_available_datasets(modality)
-
-
 def get_dataset_paths(
     modality,
     datasets=None,
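For anyone migrating, a minimal before/after sketch based on the imports this diff removes and keeps (the `'single_table'` modality is illustrative):

```python
# Before this change (removed): the module-level helper.
# from sdgym import get_available_datasets
# datasets = get_available_datasets('single_table')

# After this change: the DatasetExplorer method.
from sdgym import DatasetExplorer

datasets = DatasetExplorer().list_datasets('single_table')
```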
tests/integration/test_datasets.py (32 changes: 16 additions & 16 deletions)
@@ -1,31 +1,31 @@
-from sdgym import get_available_datasets
+from sdgym import DatasetExplorer
 
 
-def test_get_available_datasets_single_table():
-    """Test that `get_available_datasets` returns single table datasets with expected properties."""
+def test_list_datasets_single_table():
+    """Test that `list_datasets` returns single table datasets with expected properties."""
     # Run
-    df = get_available_datasets('single_table')
+    dataframe = DatasetExplorer().list_datasets('single_table')
 
     # Assert
-    assert df.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
-    assert all(df['num_tables'] == 1)
+    assert dataframe.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
+    assert all(dataframe['num_tables'] == 1)
 
 
-def test_get_available_datasets_multi_table():
-    """Test that `get_available_datasets` returns multi table datasets with expected properties."""
+def test_list_datasets_multi_table():
+    """Test that `list_datasets` returns multi table datasets with expected properties."""
     # Run
-    df = get_available_datasets('multi_table')
+    dataframe = DatasetExplorer().list_datasets('multi_table')
 
     # Assert
-    assert df.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
-    assert all(df['num_tables'] > 1)
+    assert dataframe.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
+    assert all(dataframe['num_tables'] > 1)
 
 
-def test_get_available_datasets_sequential():
-    """Test that `get_available_datasets` returns sequential datasets with expected properties."""
+def test_list_datasets_sequential():
+    """Test that `list_datasets` returns sequential datasets with expected properties."""
     # Run
-    df = get_available_datasets('sequential')
+    dataframe = DatasetExplorer().list_datasets('sequential')
 
     # Assert
-    assert df.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
-    assert all(df['num_tables'] == 1)
+    assert dataframe.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
+    assert all(dataframe['num_tables'] == 1)
tests/unit/test_dataset_explorer.py (54 changes: 54 additions & 0 deletions)
@@ -277,3 +277,57 @@ def test_summarize_datasets_with_output(
         assert output_filepath.exists()
         assert isinstance(df, pd.DataFrame)
         assert df.columns.to_list() == SUMMARY_OUTPUT_COLUMNS
+
+    @patch('sdgym.dataset_explorer._validate_modality')
+    @patch('sdgym.dataset_explorer._get_available_datasets')
+    def test_list_datasets_without_output(self, mock_get_available, mock_validate_modality):
+        """Test that `list_datasets` returns the expected dataframe."""
+        # Setup
+        explorer = DatasetExplorer()
+        expected_df = pd.DataFrame([
+            {'dataset_name': 'ds1', 'size_MB': 12.5, 'num_tables': 1},
+            {'dataset_name': 'ds2', 'size_MB': 3.0, 'num_tables': 2},
+        ])
+        mock_get_available.return_value = expected_df
+
+        # Run
+        result = explorer.list_datasets('single_table')
+
+        # Assert
+        mock_validate_modality.assert_called_once_with('single_table')
+        mock_get_available.assert_called_once_with(
+            modality='single_table',
+            bucket='sdv-datasets-public',
+            aws_access_key_id=None,
+            aws_secret_access_key=None,
+        )
+        pd.testing.assert_frame_equal(result, expected_df)
+
+    @patch('sdgym.dataset_explorer._validate_modality')
+    @patch('sdgym.dataset_explorer._get_available_datasets')
+    def test_list_datasets_with_output(self, mock_get_available, mock_validate_modality, tmp_path):
+        """Test that `list_datasets` writes CSV when output path is provided."""
+        # Setup
+        explorer = DatasetExplorer()
+        expected_df = pd.DataFrame([
+            {'dataset_name': 'alpha', 'size_MB': 1.5, 'num_tables': 1},
+            {'dataset_name': 'beta', 'size_MB': 2.0, 'num_tables': 3},
+        ])
+        mock_get_available.return_value = expected_df
+        output_filepath = tmp_path / 'datasets_list.csv'
+
+        # Run
+        result = explorer.list_datasets('multi_table', output_filepath=str(output_filepath))
+
+        # Assert
+        mock_validate_modality.assert_called_once_with('multi_table')
+        mock_get_available.assert_called_once_with(
+            modality='multi_table',
+            bucket='sdv-datasets-public',
+            aws_access_key_id=None,
+            aws_secret_access_key=None,
+        )
+        assert output_filepath.exists()
+        loaded = pd.read_csv(output_filepath)
+        pd.testing.assert_frame_equal(loaded, expected_df)
+        pd.testing.assert_frame_equal(result, expected_df)
tests/unit/test_datasets.py (46 changes: 0 additions & 46 deletions)
@@ -2,10 +2,8 @@
 from unittest.mock import Mock, call, patch
 
 import numpy as np
-import pandas as pd
 import pytest
 
-from sdgym import get_available_datasets
 from sdgym.datasets import (
     DATASETS_PATH,
     _download_dataset,
@@ -361,50 +359,6 @@ def test_get_bucket_name_local_folder():
     assert bucket_name == 'bucket-name'
 
 
-@patch('sdgym.datasets._get_available_datasets')
-def test_get_available_datasets(helper_mock):
-    """Test that the modality is set to single-table."""
-    # Run
-    get_available_datasets()
-
-    # Assert
-    helper_mock.assert_called_once_with('single_table')
-
-
-def test_get_available_datasets_results():
-    # Run
-    tables_info = get_available_datasets()
-
-    # Assert
-    expected_table = pd.DataFrame({
-        'dataset_name': [
-            'adult',
-            'alarm',
-            'census',
-            'child',
-            'covtype',
-            'expedia_hotel_logs',
-            'insurance',
-            'intrusion',
-            'news',
-        ],
-        'size_MB': [
-            '3.907448',
-            '4.520128',
-            '98.165608',
-            '3.200128',
-            '255.645408',
-            '0.200128',
-            '3.340128',
-            '162.039016',
-            '18.712096',
-        ],
-        'num_tables': [1] * 9,
-    })
-    expected_table['size_MB'] = expected_table['size_MB'].astype(float).round(2)
-    assert len(expected_table.merge(tables_info.round(2))) == len(expected_table)
-
-
 @patch('sdgym.datasets._get_dataset_path_and_download')
 @patch('sdgym.datasets._path_contains_data_and_metadata', return_value=True)
 @patch('sdgym.datasets.Path')