Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions bsuite/bsuite.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,38 +100,41 @@ def load(
return EXPERIMENT_NAME_TO_ENVIRONMENT[experiment_name](**kwargs)


def load_from_id(bsuite_id: str) -> base.Environment:
def load_from_id(bsuite_id: str, loading_message: bool = True) -> base.Environment:
"""Returns a bsuite environment given a bsuite_id."""
kwargs = sweep.SETTINGS[bsuite_id]
experiment_name, _ = unpack_bsuite_id(bsuite_id)
env = load(experiment_name, kwargs)
termcolor.cprint(
f'Loaded bsuite_id: {bsuite_id}.', color='white', attrs=['bold'])
if loading_message:
termcolor.cprint(
f'Loaded bsuite_id: {bsuite_id}.', color='white', attrs=['bold'])
return env


def load_and_record(bsuite_id: str,
save_path: str,
logging_mode: str = 'csv',
overwrite: bool = False) -> dm_env.Environment:
overwrite: bool = False,
loading_message: bool = True) -> dm_env.Environment:
"""Returns a bsuite environment wrapped with either CSV or SQLite logging."""
if logging_mode == 'csv':
return load_and_record_to_csv(bsuite_id, save_path, overwrite)
return load_and_record_to_csv(bsuite_id, save_path, overwrite, loading_message)
elif logging_mode == 'sqlite':
if not save_path.endswith('.db'):
save_path += '.db'
if overwrite:
print('WARNING: overwrite option is ignored for SQLite logging.')
return load_and_record_to_sqlite(bsuite_id, save_path)
return load_and_record_to_sqlite(bsuite_id, save_path, loading_message)
elif logging_mode == 'terminal':
return load_and_record_to_terminal(bsuite_id)
return load_and_record_to_terminal(bsuite_id, loading_message)
else:
raise ValueError((f'Unrecognised logging_mode "{logging_mode}". '
'Must be "csv", "sqlite", or "terminal".'))


def load_and_record_to_sqlite(bsuite_id: str,
db_path: str) -> dm_env.Environment:
db_path: str,
loading_message: bool = True) -> dm_env.Environment:
"""Returns a bsuite environment that saves results to an SQLite database.

The returned environment will automatically save the results required for
Expand All @@ -153,16 +156,18 @@ def load_and_record_to_sqlite(bsuite_id: str,
created if it does not already exist. When generating results using
multiple different processes, specify the *same* db_path for every
bsuite_id.
loading_message: boolean flag (default true) for env loading messages

Returns:
A bsuite environment determined by the bsuite_id.
"""
raw_env = load_from_id(bsuite_id)
experiment_name, setting_index = unpack_bsuite_id(bsuite_id)
termcolor.cprint(
f'Logging results to SQLite database in {db_path}.',
color='yellow',
attrs=['bold'])
if loading_message:
termcolor.cprint(
f'Logging results to SQLite database in {db_path}.',
color='yellow',
attrs=['bold'])
return sqlite_logging.wrap_environment(
env=raw_env,
db_path=db_path,
Expand All @@ -173,7 +178,8 @@ def load_and_record_to_sqlite(bsuite_id: str,

def load_and_record_to_csv(bsuite_id: str,
results_dir: str,
overwrite: bool = False) -> dm_env.Environment:
overwrite: bool = False,
log: bool = True) -> dm_env.Environment:
"""Returns a bsuite environment that saves results to CSV.

To load the results, specify the file path in the provided notebook, or to
Expand All @@ -196,10 +202,11 @@ def load_and_record_to_csv(bsuite_id: str,
A bsuite environment determined by the bsuite_id.
"""
raw_env = load_from_id(bsuite_id)
termcolor.cprint(
f'Logging results to CSV file for each bsuite_id in {results_dir}.',
color='yellow',
attrs=['bold'])
if log:
termcolor.cprint(
f'Logging results to CSV file for each bsuite_id in {results_dir}.',
color='yellow',
attrs=['bold'])
return csv_logging.wrap_environment(
env=raw_env,
bsuite_id=bsuite_id,
Expand All @@ -208,9 +215,10 @@ def load_and_record_to_csv(bsuite_id: str,
)


def load_and_record_to_terminal(bsuite_id: str) -> dm_env.Environment:
def load_and_record_to_terminal(bsuite_id: str, log: bool = True) -> dm_env.Environment:
"""Returns a bsuite environment that logs to terminal."""
raw_env = load_from_id(bsuite_id)
termcolor.cprint(
'Logging results to terminal.', color='yellow', attrs=['bold'])
if log:
termcolor.cprint(
'Logging results to terminal.', color='yellow', attrs=['bold'])
return terminal_logging.wrap_environment(raw_env)