Skip to content

Commit a52efea

Browse files
committed
Use inplace operations
1 parent 6ebdf4c commit a52efea

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

demo.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
os.makedirs(result_dir, exist_ok=True)
3434
os.makedirs(groundtruth_dir, exist_ok=True)
3535
for images, masks in tqdm.tqdm(testloader, desc='Demo'):
36-
images = images.to(device)
36+
# mask에 255를 곱하여 0~1 사이의 값을 0~255 값으로 변경 + 채널 차원 제거
37+
masks.mul_(255).squeeze_(dim=1)
38+
39+
# 이미지와 정답 정보를 GPU로 복사
40+
images, masks = images.to(device), masks.type(torch.LongTensor)
3741

3842
# 예측
3943
with torch.no_grad():
@@ -45,5 +49,5 @@
4549
assert masks.shape[0] == masks_pred.shape[0]
4650
for i in range(masks.shape[0]):
4751
plt.imsave(os.path.join(result_dir, image_names[step]), masks_pred[i].cpu())
48-
plt.imsave(os.path.join(groundtruth_dir, image_names[step]), masks[i].squeeze())
52+
plt.imsave(os.path.join(groundtruth_dir, image_names[step]), masks[i])
4953
step += 1

eval.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ def evaluate(model, testloader, num_classes: int, device):
5151
inference_time = 0
5252
for images, masks in tqdm.tqdm(testloader, desc='Eval', leave=False):
5353
# mask에 255를 곱하여 0~1 사이의 값을 0~255 값으로 변경 + 채널 차원 제거
54-
masks = torch.mul(masks, 255)
55-
masks = torch.squeeze(masks, dim=1)
54+
masks.mul_(255).squeeze_(dim=1)
5655

5756
# 이미지와 정답 정보를 GPU로 복사
58-
images, masks = images.to(device), masks.to(device, dtype=torch.long)
57+
images, masks = images.to(device), masks.to(device, dtype=torch.int64)
5958

6059
# 예측
6160
with torch.no_grad():

train.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,10 @@
3838

3939
for batch_idx, (images, masks) in enumerate(tqdm.tqdm(trainloader, desc='Train', leave=False)):
4040
# mask에 255를 곱하여 0~1 사이의 값을 0~255 값으로 변경 + 채널 차원 제거
41-
masks = torch.mul(masks, 255)
42-
masks = torch.squeeze(masks, dim=1)
41+
masks.mul_(255).squeeze_(dim=1)
4342

4443
# 이미지와 정답 정보를 GPU로 복사
45-
images, masks = images.to(device), masks.to(device, dtype=torch.long)
44+
images, masks = images.to(device), masks.to(device, dtype=torch.int64)
4645

4746
# 순전파 + 역전파 + 최적화
4847
optimizer.zero_grad()

0 commit comments

Comments
 (0)