diff --git a/models/enet.py b/models/enet.py index ffadcfd..cecaca7 100644 --- a/models/enet.py +++ b/models/enet.py @@ -43,7 +43,7 @@ def __init__(self, # the extension branch self.main_branch = nn.Conv2d( in_channels, - out_channels - 3, + out_channels - in_channels, kernel_size=3, stride=2, padding=1, @@ -478,10 +478,10 @@ class ENet(nn.Module): """ - def __init__(self, num_classes, encoder_relu=False, decoder_relu=True): + def __init__(self, in_channels, num_classes, encoder_relu=False, decoder_relu=True): super().__init__() - self.initial_block = InitialBlock(3, 16, relu=encoder_relu) + self.initial_block = InitialBlock(in_channels, 16, relu=encoder_relu) # Stage 1 - Encoder self.downsample1_0 = DownsamplingBottleneck(