diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index 1125db8..b87700a 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -65,30 +65,48 @@ def forward(self, x, conditioning): return ca +# class CondFuser(nn.Module): +# def __init__(self, chan, cond_chan=1): +# super().__init__() +# self.cca = ConditionedChannelAttention(chan * 2, cond_chan) +# # self.spa = nn.Conv2d( +# # in_channels=chan * 2, +# # out_channels=1, +# # kernel_size=3, +# # padding=1, +# # stride=1, +# # groups=1, +# # bias=True, +# # ) + +# def forward(self, x1, x2, cond): +# x = torch.cat([x1, x2], dim=1) +# x = self.cca(x, cond) * x +# # spa = torch.sigmoid(self.spa(x)) + +# x1, x2 = x.chunk(2, dim=1) +# # return x1 * spa + x2 * (1 - spa) +# return x1 + x2 + + class CondFuser(nn.Module): def __init__(self, chan, cond_chan=1): super().__init__() self.cca = ConditionedChannelAttention(chan * 2, cond_chan) - # self.spa = nn.Conv2d( - # in_channels=chan * 2, - # out_channels=1, - # kernel_size=3, - # padding=1, - # stride=1, - # groups=1, - # bias=True, - # ) + self.sig = nn.Sigmoid() + self.sa = nn.Sequential( + nn.Conv2d(in_channels = 2 * chan, out_channels=chan, kernel_size=3, + padding=1, stride=1, + groups=1, bias=True), + nn.Sigmoid() + ) + def forward(self, x1, x2, cond): x = torch.cat([x1, x2], dim=1) - x = self.cca(x, cond) * x - # spa = torch.sigmoid(self.spa(x)) - - x1, x2 = x.chunk(2, dim=1) - # return x1 * spa + x2 * (1 - spa) + x2 = 1 * self.sig(self.cca(x)) * self.sa(x) * x2 return x1 + x2 - class NKA(nn.Module): def __init__(self, dim, channel_reduction = 8): super().__init__()