Skip to content

Commit 45bc645

Browse files
committed
Hotfix dtype casting issue
1 parent 4cfbce5 commit 45bc645

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

demo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838
os.makedirs(result_dir, exist_ok=True)
3939
os.makedirs(groundtruth_dir, exist_ok=True)
4040
for images, masks in tqdm.tqdm(testloader, desc='Demo'):
41-
images = images.to(device)
42-
4341
# mask에 255를 곱하여 0~1 사이의 값을 0~255 값으로 변경 + 채널 차원 제거
44-
masks.mul_(255).squeeze_(dim=1).type(torch.LongTensor)
42+
masks.mul_(255).squeeze_(dim=1)
43+
44+
images, masks = images.to(device), masks.type(torch.LongTensor)
4545

4646
# 예측
4747
with torch.no_grad():

eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ def evaluate(model, testloader, criterion, num_classes: int, device):
5050
val_loss = 0
5151
inference_time = 0
5252
for images, masks in tqdm.tqdm(testloader, desc='Eval', leave=False):
53-
images, masks = images.to(device), masks.to(device, dtype=torch.int64)
54-
5553
# mask에 255를 곱하여 0~1 사이의 값을 0~255 값으로 변경 + 채널 차원 제거
5654
masks.mul_(255).squeeze_(dim=1)
5755

56+
images, masks = images.to(device), masks.to(device, dtype=torch.int64)
57+
5858
# 예측
5959
with torch.no_grad():
6060
start_time = time.time()

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@
3838
model.train()
3939

4040
for batch_idx, (images, masks) in enumerate(tqdm.tqdm(trainloader, desc='Train', leave=False)):
41-
images, masks = images.to(device), masks.to(device, dtype=torch.int64)
42-
4341
# mask에 255를 곱하여 0~1 사이의 값을 0~255 값으로 변경 + 채널 차원 제거
4442
masks.mul_(255).squeeze_(dim=1)
4543

44+
images, masks = images.to(device), masks.to(device, dtype=torch.int64)
45+
4646
# 순전파 + 역전파 + 최적화
4747
optimizer.zero_grad()
4848
masks_pred = model(images)

0 commit comments

Comments
 (0)