Skip to content

Commit a37018b

Browse files
committed
model: Remove input channels parameter
1 parent a52efea commit a37018b

File tree

6 files changed

+12
-12
lines changed

6 files changed

+12
-12
lines changed

demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# 2. Model
2020
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21-
model = utils.utils.get_model(model_name, 3, config['num_classes'], config['pretrained_weights']).to(device)
21+
model = utils.utils.get_model(model_name, config['num_classes'], config['pretrained_weights']).to(device)
2222

2323
# 이미지 이름 불러오기
2424
image_names = []

eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def evaluate(model, testloader, num_classes: int, device):
9595

9696
# 2. Model
9797
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
98-
model = utils.utils.get_model(model_name, 3, config['num_classes'], config['pretrained_weights']).to(device)
98+
model = utils.utils.get_model(model_name, config['num_classes'], config['pretrained_weights']).to(device)
9999

100100
# 모델 평가
101101
val_loss, iou, miou, fps = evaluate(model, testloader, config['num_classes'], device)

models/proposed.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ def forward(self, x):
9191

9292

9393
class Proposed(nn.Module):
94-
def __init__(self, num_channels: int, num_classes: int):
94+
def __init__(self, num_classes: int):
9595
super(Proposed, self).__init__()
9696
resnet50 = torchvision.models.resnet50(pretrained=True)
9797

98-
self.encode1 = self.double_conv(num_channels, 64)
98+
self.encode1 = self.double_conv(3, 64)
9999
self.encode2 = resnet50.layer1 # 256
100100
self.encode3 = resnet50.layer2 # 512
101101
self.encode4 = resnet50.layer3 # 1024
@@ -154,7 +154,7 @@ def forward(self, x):
154154

155155
if __name__ == '__main__':
156156
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
157-
model = Proposed(3, 8).to(device)
157+
model = Proposed(8).to(device)
158158
model.eval()
159159

160160
torchsummary.torchsummary.summary(model, (3, 256, 512))

models/unet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77

88
class UNet(nn.Module):
9-
def __init__(self, num_channels: int, num_classes: int):
9+
def __init__(self, num_classes: int):
1010
super(UNet, self).__init__()
1111

12-
self.encode1 = self.double_conv(num_channels, 64)
12+
self.encode1 = self.double_conv(3, 64)
1313
self.encode2 = self.double_conv(64, 128)
1414
self.encode3 = self.double_conv(128, 256)
1515
self.encode4 = self.double_conv(256, 512)
@@ -58,7 +58,7 @@ def forward(self, x):
5858

5959
if __name__ == '__main__':
6060
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61-
model = UNet(3, 8).to(device)
61+
model = UNet(8).to(device)
6262
model.eval()
6363

6464
torchsummary.torchsummary.summary(model, (3, 256, 512))

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
# 2. Model
2121
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22-
model = utils.utils.get_model(model_name, 3, config['num_classes']).to(device)
22+
model = utils.utils.get_model(model_name, config['num_classes']).to(device)
2323

2424
# 3. Loss function, optimizer, lr scheduler
2525
criterion = nn.CrossEntropyLoss()

utils/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ def load_config():
3131
return model_name, config
3232

3333

34-
def get_model(model_name: str, num_channels: int, num_classes: int, pretrained: str = None) -> torch.nn.Module:
34+
def get_model(model_name: str, num_classes: int, pretrained: str = None) -> torch.nn.Module:
3535
if model_name == 'UNet':
36-
model = models.unet.UNet(num_channels, num_classes)
36+
model = models.unet.UNet(num_classes)
3737
elif model_name == 'Proposed':
38-
model = models.proposed.Proposed(num_channels, num_classes)
38+
model = models.proposed.Proposed(num_classes)
3939
else:
4040
raise NameError('Wrong model_name.')
4141

0 commit comments

Comments
 (0)