diff --git a/flyvis/solver.py b/flyvis/solver.py index 12973fb..cccc40d 100644 --- a/flyvis/solver.py +++ b/flyvis/solver.py @@ -353,14 +353,14 @@ def handle_batch(data, steady_state): loss = loss.detach().cpu() for task in self.task.dataset.tasks: loss_per_task[f"loss_{task}"].append( - losses[task].detach().cpu() + losses[task].detach().cpu().item() ) - loss_over_iters.append(loss) + loss_over_iters.append(loss.item()) activity = activity.detach().cpu() mean_activity = activity.mean() - activity_over_iters.append(mean_activity) - activity_min_over_iters.append(activity.min()) - activity_max_over_iters.append(activity.max()) + activity_over_iters.append(mean_activity.item()) + activity_min_over_iters.append(activity.min().item()) + activity_max_over_iters.append(activity.max().item()) return loss, mean_activity # Call closure.