Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ count = True
max-line-length = 88
statistics = True
ignore = E731,W503,E741,E203
exclude = setup.py
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
install_requires=[
"numpy>=1.11, <2.0",
"pyyaml>=5.1, <6.0",
"rich>=10.0.0, <30.0.0",
"scikit-learn>=0.20.0, <0.30.0",
"scipy>=1.1.0, <2.0.0",
"tensorboard>=1.15.0, <3.0.0",
Expand Down
5 changes: 3 additions & 2 deletions src/emmental/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
import yaml
from rich.logging import RichHandler

from emmental.utils.seed import set_random_seed
from emmental.utils.utils import merge
Expand Down Expand Up @@ -110,12 +111,12 @@ def init_logging(
level=level,
handlers=[
logging.FileHandler(os.path.join(log_path, log_name)),
logging.StreamHandler(),
RichHandler(),
],
)
else:
logging.basicConfig(
format=format, level=logging.WARN, handlers=[logging.StreamHandler()]
format=format, level=logging.WARN, handlers=[RichHandler()]
)

# Notify user of log location
Expand Down
27 changes: 16 additions & 11 deletions src/emmental/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Emmental model."""
import importlib
import itertools
import logging
import os
Expand All @@ -10,6 +9,7 @@
import numpy as np
import torch
from numpy import ndarray
from rich.progress import BarColumn, Progress, TimeElapsedColumn
from torch import Tensor, nn
from torch.nn import ModuleDict

Expand All @@ -24,11 +24,6 @@
prob_to_pred,
)

if importlib.util.find_spec("ipywidgets") is not None:
from tqdm.auto import tqdm
else:
from tqdm import tqdm

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -490,12 +485,21 @@ def predict(
task_to_label_dict = dataloader.task_to_label_dict
uid = dataloader.uid

with torch.no_grad():
for batch_num, bdict in tqdm(
enumerate(dataloader),
# Set progress bar
progress = Progress(
"[progress.description]{task.description}",
"[magenta]{task.completed}/{task.total}",
BarColumn(),
TimeElapsedColumn(),
disable=not Meta.config["meta_config"]["verbose"],
)

with torch.no_grad(), progress:
task = progress.add_task(
f"Evaluating {dataloader.data_name} ({dataloader.split})",
total=len(dataloader),
desc=f"Evaluating {dataloader.data_name} ({dataloader.split})",
):
)
for bdict in dataloader:
if isinstance(bdict, dict) == 1:
X_bdict = bdict
Y_bdict = None
Expand Down Expand Up @@ -562,6 +566,7 @@ def predict(
out_dict[task_name][action_name].extend(
out_bdict[task_name][action_name]
)
progress.update(task, advance=1)

# Calculate average loss
if dataloader.is_learnable:
Expand Down