From 5b4078304f56c924f910dd3dd2fa2f86179de5a0 Mon Sep 17 00:00:00 2001 From: Alex Loiko Date: Thu, 19 May 2022 23:28:05 +0200 Subject: [PATCH] Update to work with pytorch 1.11.0 --- train.py | 3 ++- utils/eval.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 5c7905e..07d7e7d 100644 --- a/train.py +++ b/train.py @@ -320,7 +320,8 @@ def validate(valloader, model, criterion, epoch, use_cuda, mode): data_time.update(time.time() - end) if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) + inputs = inputs.cuda() + targets = targets.type(torch.LongTensor).cuda(non_blocking=True) # compute output outputs = model(inputs) loss = criterion(outputs, targets) diff --git a/utils/eval.py b/utils/eval.py index 5051350..1643dc5 100644 --- a/utils/eval.py +++ b/utils/eval.py @@ -13,6 +13,6 @@ def accuracy(output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) + correct_k = correct[:k].reshape(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) return res \ No newline at end of file