diff --git a/bsuite/bsuite.py b/bsuite/bsuite.py index 073899ad..702f005c 100644 --- a/bsuite/bsuite.py +++ b/bsuite/bsuite.py @@ -100,38 +100,41 @@ def load( return EXPERIMENT_NAME_TO_ENVIRONMENT[experiment_name](**kwargs) -def load_from_id(bsuite_id: str) -> base.Environment: +def load_from_id(bsuite_id: str, loading_message: bool = True) -> base.Environment: """Returns a bsuite environment given a bsuite_id.""" kwargs = sweep.SETTINGS[bsuite_id] experiment_name, _ = unpack_bsuite_id(bsuite_id) env = load(experiment_name, kwargs) - termcolor.cprint( - f'Loaded bsuite_id: {bsuite_id}.', color='white', attrs=['bold']) + if loading_message: + termcolor.cprint( + f'Loaded bsuite_id: {bsuite_id}.', color='white', attrs=['bold']) return env def load_and_record(bsuite_id: str, save_path: str, logging_mode: str = 'csv', - overwrite: bool = False) -> dm_env.Environment: + overwrite: bool = False, + loading_message: bool = True) -> dm_env.Environment: """Returns a bsuite environment wrapped with either CSV or SQLite logging.""" if logging_mode == 'csv': - return load_and_record_to_csv(bsuite_id, save_path, overwrite) + return load_and_record_to_csv(bsuite_id, save_path, overwrite, loading_message) elif logging_mode == 'sqlite': if not save_path.endswith('.db'): save_path += '.db' if overwrite: print('WARNING: overwrite option is ignored for SQLite logging.') - return load_and_record_to_sqlite(bsuite_id, save_path) + return load_and_record_to_sqlite(bsuite_id, save_path, loading_message) elif logging_mode == 'terminal': - return load_and_record_to_terminal(bsuite_id) + return load_and_record_to_terminal(bsuite_id, loading_message) else: raise ValueError((f'Unrecognised logging_mode "{logging_mode}". ' 'Must be "csv", "sqlite", or "terminal".')) def load_and_record_to_sqlite(bsuite_id: str, - db_path: str) -> dm_env.Environment: + db_path: str, + loading_message: bool = True) -> dm_env.Environment: """Returns a bsuite environment that saves results to an SQLite database. The returned environment will automatically save the results required for @@ -153,16 +156,18 @@ def load_and_record_to_sqlite(bsuite_id: str, created if it does not already exist. When generating results using multiple different processes, specify the *same* db_path for every bsuite_id. + loading_message: boolean flag (default true) for env loading messages Returns: A bsuite environment determined by the bsuite_id. """ raw_env = load_from_id(bsuite_id) experiment_name, setting_index = unpack_bsuite_id(bsuite_id) - termcolor.cprint( - f'Logging results to SQLite database in {db_path}.', - color='yellow', - attrs=['bold']) + if loading_message: + termcolor.cprint( + f'Logging results to SQLite database in {db_path}.', + color='yellow', + attrs=['bold']) return sqlite_logging.wrap_environment( env=raw_env, db_path=db_path, @@ -173,7 +178,8 @@ def load_and_record_to_sqlite(bsuite_id: str, def load_and_record_to_csv(bsuite_id: str, results_dir: str, - overwrite: bool = False) -> dm_env.Environment: + overwrite: bool = False, + log: bool = True) -> dm_env.Environment: """Returns a bsuite environment that saves results to CSV. To load the results, specify the file path in the provided notebook, or to @@ -196,10 +202,11 @@ def load_and_record_to_csv(bsuite_id: str, A bsuite environment determined by the bsuite_id. """ raw_env = load_from_id(bsuite_id) - termcolor.cprint( - f'Logging results to CSV file for each bsuite_id in {results_dir}.', - color='yellow', - attrs=['bold']) + if log: + termcolor.cprint( + f'Logging results to CSV file for each bsuite_id in {results_dir}.', + color='yellow', + attrs=['bold']) return csv_logging.wrap_environment( env=raw_env, bsuite_id=bsuite_id, @@ -208,9 +215,10 @@ def load_and_record_to_csv(bsuite_id: str, ) -def load_and_record_to_terminal(bsuite_id: str) -> dm_env.Environment: +def load_and_record_to_terminal(bsuite_id: str, log: bool = True) -> dm_env.Environment: """Returns a bsuite environment that logs to terminal.""" raw_env = load_from_id(bsuite_id) - termcolor.cprint( - 'Logging results to terminal.', color='yellow', attrs=['bold']) + if log: + termcolor.cprint( + 'Logging results to terminal.', color='yellow', attrs=['bold']) return terminal_logging.wrap_environment(raw_env)