diff --git a/.flake8 b/.flake8 index 5d8afa7..bbf8c58 100644 --- a/.flake8 +++ b/.flake8 @@ -8,3 +8,4 @@ count = True max-line-length = 88 statistics = True ignore = E731,W503,E741,E203 +exclude = setup.py diff --git a/setup.py b/setup.py index 0af4bcf..3c4dca7 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ install_requires=[ "numpy>=1.11, <2.0", "pyyaml>=5.1, <6.0", + "rich>=10.0.0, <30.0.0", "scikit-learn>=0.20.0, <0.30.0", "scipy>=1.1.0, <2.0.0", "tensorboard>=1.15.0, <3.0.0", diff --git a/src/emmental/meta.py b/src/emmental/meta.py index f6ec90a..93cab47 100644 --- a/src/emmental/meta.py +++ b/src/emmental/meta.py @@ -10,6 +10,7 @@ import torch import yaml +from rich.logging import RichHandler from emmental.utils.seed import set_random_seed from emmental.utils.utils import merge @@ -110,12 +111,12 @@ def init_logging( level=level, handlers=[ logging.FileHandler(os.path.join(log_path, log_name)), - logging.StreamHandler(), + RichHandler(), ], ) else: logging.basicConfig( - format=format, level=logging.WARN, handlers=[logging.StreamHandler()] + format=format, level=logging.WARN, handlers=[RichHandler()] ) # Notify user of log location diff --git a/src/emmental/model.py b/src/emmental/model.py index 2a79dd8..120c6e4 100644 --- a/src/emmental/model.py +++ b/src/emmental/model.py @@ -1,5 +1,4 @@ """Emmental model.""" -import importlib import itertools import logging import os @@ -10,6 +9,7 @@ import numpy as np import torch from numpy import ndarray +from rich.progress import BarColumn, Progress, TimeElapsedColumn from torch import Tensor, nn from torch.nn import ModuleDict @@ -24,11 +24,6 @@ prob_to_pred, ) -if importlib.util.find_spec("ipywidgets") is not None: - from tqdm.auto import tqdm -else: - from tqdm import tqdm - logger = logging.getLogger(__name__) @@ -490,12 +485,21 @@ def predict( task_to_label_dict = dataloader.task_to_label_dict uid = dataloader.uid - with torch.no_grad(): - for batch_num, bdict in tqdm( - enumerate(dataloader), + # Set progress bar + progress = Progress( + "[progress.description]{task.description}", + "[magenta]{task.completed}/{task.total}", + BarColumn(), + TimeElapsedColumn(), + disable=not Meta.config["meta_config"]["verbose"], + ) + + with torch.no_grad(), progress: + task = progress.add_task( + f"Evaluating {dataloader.data_name} ({dataloader.split})", total=len(dataloader), - desc=f"Evaluating {dataloader.data_name} ({dataloader.split})", - ): + ) + for bdict in dataloader: if isinstance(bdict, dict) == 1: X_bdict = bdict Y_bdict = None @@ -562,6 +566,7 @@ def predict( out_dict[task_name][action_name].extend( out_bdict[task_name][action_name] ) + progress.update(task, advance=1) # Calculate average loss if dataloader.is_learnable: