Skip to content

Commit ddd3dcf

Browse files
authored
SDGym should be able to automatically discover SDV Enterprise synthesizers (#489)
1 parent 6a3b217 commit ddd3dcf

File tree

15 files changed

+724
-515
lines changed

15 files changed

+724
-515
lines changed

sdgym/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212

1313
import logging
1414

15-
from sdgym.benchmark import benchmark_single_table
15+
from sdgym.benchmark import benchmark_single_table, benchmark_single_table_aws
1616
from sdgym.cli.collect import collect_results
1717
from sdgym.cli.summary import make_summary_spreadsheet
1818
from sdgym.dataset_explorer import DatasetExplorer
1919
from sdgym.datasets import get_available_datasets, load_dataset
20-
from sdgym.synthesizers import create_sdv_synthesizer_variant, create_single_table_synthesizer
20+
from sdgym.synthesizers import (
21+
create_synthesizer_variant,
22+
create_single_table_synthesizer,
23+
create_multi_table_synthesizer,
24+
)
2125
from sdgym.result_explorer import ResultsExplorer
2226

2327
# Clear the logging wrongfully configured by tensorflow/absl
@@ -28,9 +32,11 @@
2832
'DatasetExplorer',
2933
'ResultsExplorer',
3034
'benchmark_single_table',
35+
'benchmark_single_table_aws',
3136
'collect_results',
32-
'create_sdv_synthesizer_variant',
37+
'create_synthesizer_variant',
3338
'create_single_table_synthesizer',
39+
'create_multi_table_synthesizer',
3440
'get_available_datasets',
3541
'load_dataset',
3642
'make_summary_spreadsheet',

sdgym/benchmark.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
write_csv,
5353
write_file,
5454
)
55-
from sdgym.synthesizers import CTGANSynthesizer, GaussianCopulaSynthesizer, UniformSynthesizer
55+
from sdgym.synthesizers import UniformSynthesizer
5656
from sdgym.synthesizers.base import BaselineSynthesizer
5757
from sdgym.utils import (
5858
calculate_score_time,
@@ -67,7 +67,7 @@
6767
)
6868

6969
LOGGER = logging.getLogger(__name__)
70-
DEFAULT_SYNTHESIZERS = [GaussianCopulaSynthesizer, CTGANSynthesizer, UniformSynthesizer]
70+
DEFAULT_SYNTHESIZERS = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer', 'UniformSynthesizer']
7171
DEFAULT_DATASETS = [
7272
'adult',
7373
'alarm',
@@ -861,6 +861,7 @@ def _directory_exists(bucket_name, s3_file_path):
861861

862862

863863
def _check_write_permissions(s3_client, bucket_name):
864+
s3_client = s3_client or boto3.client('s3')
864865
try:
865866
s3_client.put_object(Bucket=bucket_name, Key='__test__', Body=b'')
866867
write_permission = True
@@ -881,7 +882,7 @@ def _create_sdgym_script(params, output_filepath):
881882
bucket_name, key_prefix = parse_s3_path(output_filepath)
882883
if not _directory_exists(bucket_name, key_prefix):
883884
raise ValueError(f'Directories in {key_prefix} do not exist')
884-
if not _check_write_permissions(bucket_name):
885+
if not _check_write_permissions(None, bucket_name):
885886
raise ValueError('No write permissions allowed for the bucket.')
886887

887888
# Add quotes to parameter strings
@@ -893,23 +894,22 @@ def _create_sdgym_script(params, output_filepath):
893894
params['output_filepath'] = "'" + params['output_filepath'] + "'"
894895

895896
# Generate the output script to run on the EC2 instance
896-
synthesizer_string = 'synthesizers=['
897-
for synthesizer in params['synthesizers']:
897+
synthesizers = params.get('synthesizers', [])
898+
names = []
899+
for synthesizer in synthesizers:
898900
if isinstance(synthesizer, str):
899-
synthesizer_string += synthesizer + ', '
901+
names.append(synthesizer)
902+
elif hasattr(synthesizer, '__name__'):
903+
names.append(synthesizer.__name__)
900904
else:
901-
synthesizer_string += synthesizer.__name__ + ', '
902-
if params['synthesizers']:
903-
synthesizer_string = synthesizer_string[:-2]
904-
synthesizer_string += ']'
905+
names.append(synthesizer.__class__.__name__)
906+
907+
all_names = '", "'.join(names)
908+
synthesizer_string = f'synthesizers=["{all_names}"]'
905909
# The indentation of the string is important for the python script
906910
script_content = f"""import boto3
907911
from io import StringIO
908912
import sdgym
909-
from sdgym.synthesizers.sdv import (CopulaGANSynthesizer, CTGANSynthesizer,
910-
GaussianCopulaSynthesizer, HMASynthesizer, PARSynthesizer, SDVRelationalSynthesizer,
911-
SDVTabularSynthesizer, TVAESynthesizer)
912-
from sdgym.synthesizers import RealTabFormerSynthesizer
913913
914914
results = sdgym.benchmark_single_table(
915915
{synthesizer_string}, custom_synthesizers={params['custom_synthesizers']},
@@ -1186,7 +1186,7 @@ def benchmark_single_table(
11861186
custom_synthesizers (list[class] or ``None``):
11871187
A list of custom synthesizer classes to use. These can be completely custom or
11881188
they can be synthesizer variants (the output from ``create_single_table_synthesizer``
1189-
or ``create_sdv_synthesizer_variant``). Defaults to ``None``.
1189+
or ``create_synthesizer_variant``). Defaults to ``None``.
11901190
sdv_datasets (list[str] or ``None``):
11911191
Names of the SDV demo datasets to use for the benchmark. Defaults to
11921192
``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,

sdgym/synthesizers/__init__.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,32 @@
11
"""Synthesizers module."""
22

33
from sdgym.synthesizers.generate import (
4-
SYNTHESIZER_MAPPING,
5-
create_multi_table_synthesizer,
6-
create_sdv_synthesizer_variant,
7-
create_sequential_synthesizer,
4+
create_synthesizer_variant,
85
create_single_table_synthesizer,
6+
create_multi_table_synthesizer,
97
)
108
from sdgym.synthesizers.identity import DataIdentity
119
from sdgym.synthesizers.column import ColumnSynthesizer
1210
from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
13-
from sdgym.synthesizers.sdv import (
14-
CopulaGANSynthesizer,
15-
CTGANSynthesizer,
16-
GaussianCopulaSynthesizer,
17-
HMASynthesizer,
18-
PARSynthesizer,
19-
SDVRelationalSynthesizer,
20-
SDVTabularSynthesizer,
21-
TVAESynthesizer,
22-
)
2311
from sdgym.synthesizers.uniform import UniformSynthesizer
12+
from sdgym.synthesizers.utils import (
13+
get_available_single_table_synthesizers,
14+
get_available_multi_table_synthesizers,
15+
)
16+
from sdgym.synthesizers.sdv import create_sdv_synthesizer_class, _get_all_sdv_synthesizers
17+
2418

25-
__all__ = (
19+
__all__ = [
2620
'DataIdentity',
2721
'ColumnSynthesizer',
28-
'CTGANSynthesizer',
29-
'TVAESynthesizer',
3022
'UniformSynthesizer',
31-
'CopulaGANSynthesizer',
32-
'GaussianCopulaSynthesizer',
33-
'HMASynthesizer',
34-
'PARSynthesizer',
35-
'SDVTabularSynthesizer',
36-
'SDVRelationalSynthesizer',
23+
'RealTabFormerSynthesizer',
3724
'create_single_table_synthesizer',
3825
'create_multi_table_synthesizer',
39-
'create_sdv_synthesizer_variant',
40-
'create_sequential_synthesizer',
41-
'SYNTHESIZER_MAPPING',
42-
'RealTabFormerSynthesizer',
43-
)
26+
'create_synthesizer_variant',
27+
'get_available_single_table_synthesizers',
28+
'get_available_multi_table_synthesizers',
29+
]
30+
31+
for sdv_name in _get_all_sdv_synthesizers():
32+
create_sdv_synthesizer_class(sdv_name)

sdgym/synthesizers/base.py

Lines changed: 14 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
class BaselineSynthesizer(abc.ABC):
1313
"""Base class for all the ``SDGym`` baselines."""
1414

15+
_MODEL_KWARGS = {}
16+
_NATIVELY_SUPPORTED = True
17+
1518
@classmethod
1619
def get_subclasses(cls, include_parents=False):
1720
"""Recursively find subclasses of this Baseline.
@@ -30,6 +33,17 @@ def get_subclasses(cls, include_parents=False):
3033

3134
return subclasses
3235

36+
@classmethod
37+
def _get_supported_synthesizers(cls):
38+
"""Get the natively supported synthesizer class names."""
39+
subclasses = cls.get_subclasses(include_parents=True)
40+
synthesizers = set()
41+
for name, subclass in subclasses.items():
42+
if subclass._NATIVELY_SUPPORTED:
43+
synthesizers.add(name)
44+
45+
return sorted(synthesizers)
46+
3347
@classmethod
3448
def get_baselines(cls):
3549
"""Get baseline classes."""
@@ -76,79 +90,3 @@ def sample_from_synthesizer(self, synthesizer, n_samples):
7690
should be a dict mapping table name to DataFrame.
7791
"""
7892
return self._sample_from_synthesizer(synthesizer, n_samples)
79-
80-
81-
class MultiSingleTableBaselineSynthesizer(BaselineSynthesizer, abc.ABC):
82-
"""Base class for SingleTableBaselines that are used on multi table scenarios.
83-
84-
These classes model and sample each table independently and then just
85-
randomly choose ids from the parent tables to form the relationships.
86-
87-
NOTE: doesn't currently work.
88-
"""
89-
90-
def get_trained_synthesizer(self, data, metadata):
91-
"""Get the trained synthesizer.
92-
93-
Args:
94-
data (dict):
95-
A dict mapping table name to table data.
96-
metadata (sdv.metadata.multi_table.MultiTableMetadata):
97-
The multi-table metadata.
98-
99-
Returns:
100-
dict:
101-
A mapping of table name to synthesizers.
102-
"""
103-
self.metadata = metadata
104-
synthesizers = {
105-
table_name: self._get_trained_synthesizer(table, metadata.tables[table_name])
106-
for table_name, table in data.items()
107-
}
108-
self.table_columns = {table_name: data[table_name].columns for table_name in data.keys()}
109-
110-
return synthesizers
111-
112-
def _get_foreign_keys(self, metadata, table_name, child_name):
113-
foreign_keys = []
114-
for relation in metadata.relationships:
115-
if (
116-
table_name == relation['parent_table_name']
117-
and child_name == relation['child_table_name']
118-
):
119-
foreign_keys.append(relation['child_foreign_key'])
120-
121-
return foreign_keys
122-
123-
def sample_from_synthesizer(self, synthesizers, n_samples):
124-
"""Sample from the given synthesizers.
125-
126-
Args:
127-
synthesizers (dict):
128-
A dict mapping table name to table synthesizer.
129-
n_samples (int):
130-
The number of samples.
131-
132-
Returns:
133-
dict:
134-
A mapping of table name to sampled table data.
135-
"""
136-
tables = {
137-
table_name: self._sample_from_synthesizer(synthesizer, n_samples)
138-
for table_name, synthesizer in synthesizers.items()
139-
}
140-
141-
for table_name, table in tables.items():
142-
table_metadata = self.metadata.tables[table_name]
143-
parents = list(table_metadata._get_parent_map().keys())
144-
for parent_name in parents:
145-
parent = tables[parent_name]
146-
primary_key = self.metadata.tables[table_name].primary_key
147-
foreign_keys = self._get_foreign_keys(self.metadata, parent_name, table_name)
148-
for foreign_key in foreign_keys:
149-
foreign_key_values = parent[primary_key].sample(len(table), replace=True)
150-
table[foreign_key] = foreign_key_values.to_numpy()
151-
152-
tables[table_name] = table[self.table_columns[table_name]]
153-
154-
return tables

0 commit comments

Comments
 (0)