Hi, I implemented your DiscriminativeLoss in PyTorch as below, but the clustering results seem to be bad. Could you give me some suggestions on it?
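For reference, the two terms I am trying to reproduce, as I understand them from the discriminative loss paper (δ_v and δ_d being the variance and distance margins), are:

```latex
L_{var}  = \frac{1}{C}\sum_{c=1}^{C}\frac{1}{N_c}\sum_{i=1}^{N_c}
           \big[\,\lVert \mu_c - x_i \rVert - \delta_v \,\big]_+^2

L_{dist} = \frac{1}{C(C-1)}\sum_{c_A=1}^{C}\sum_{\substack{c_B=1 \\ c_B \neq c_A}}^{C}
           \big[\, 2\delta_d - \lVert \mu_{c_A} - \mu_{c_B} \rVert \,\big]_+^2
```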
```python
import torch
import torch.nn as nn
from torch.autograd import Variable


class DiscriminativeLoss2d(nn.Module):
    def __init__(self, delta_var, delta_dist, size_average=True, reduce=True, usegpu=True):
        super().__init__()
        self.delta_var = float(delta_var)
        self.delta_dist = float(delta_dist)
        self.size_average = size_average
        self.reduce = reduce
        self.usegpu = usegpu

    def norm(self, inp, L):
        # L1 or L2 norm along dim 0
        if L == 1:
            n = torch.sum(torch.abs(inp), 0)
        else:
            n = torch.sqrt(torch.sum(torch.pow(inp, 2), 0) + 1e-8)
        return n

    def forward(self, inputs, targets):
        # inputs: b x c x h x w embeddings, targets: b x n_instance_maps x h x w labels
        b, c, h, w = inputs.size()
        n_instance_maps = targets.size(1)
        loss = Variable(torch.zeros(1))
        if self.usegpu:
            loss = loss.cuda()
        for i in range(b):
            input = inputs[i]  # c x h x w
            loss_var = 0.0
            loss_dist = 0.0
            for j in range(n_instance_maps):
                target = targets[i][j].view(1, h, w)
                means = []
                loss_v = 0.0
                loss_d = 0.0
                # center pull force
                max_id = torch.max(target.data)
                for l in range(1, max_id + 1):
                    mask = target.eq(l)  # 1 x h x w
                    mask_sum = torch.sum(mask.data)
                    if mask_sum > 1:
                        inst = input[mask.expand_as(input)].view(c, -1, 1)
                        # Calculate mean of instance
                        mean = torch.mean(inst, 1).view(c, 1, 1)  # c x 1 x 1
                        means.append(mean)
                        # Calculate variance of instance
                        var = self.norm(inst - mean.expand_as(inst), 2)  # 1 x -1 x 1
                        var = torch.clamp(var - self.delta_var, min=0.0)
                        var = torch.pow(var, 2)
                        var = var.view(-1)
                        var = torch.mean(var)
                        loss_v = loss_v + var
                loss_var = loss_var + loss_v
                # center push force
                if len(means) > 1:
                    for m in range(0, len(means)):
                        mean_A = means[m]  # c x 1 x 1
                        for n in range(m + 1, len(means)):
                            mean_B = means[n]  # c x 1 x 1
                            d = self.norm(mean_A - mean_B, 2)  # 1 x 1 x 1
                            d = torch.pow(torch.clamp(-(d - 2 * self.delta_var), min=0.0), 2)
                            loss_d = loss_d + d[0][0][0]
                    loss_dist = loss_dist + loss_d / ((len(means) - 1) + 1e-8)
            loss = loss + (loss_dist + loss_var)
        loss = loss / b
        return loss
```
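For completeness, this is roughly how I call it. The shapes, margin values, and label layout here are just placeholders from my own setup, not anything taken from your repo:

```python
# Hypothetical usage sketch; delta_var / delta_dist values and tensor shapes are assumptions.
import torch
from torch.autograd import Variable

criterion = DiscriminativeLoss2d(delta_var=0.5, delta_dist=1.5, usegpu=False)

# b x c x h x w embedding map predicted by the network
embeddings = Variable(torch.randn(2, 8, 64, 64), requires_grad=True)
# b x n_instance_maps x h x w integer instance labels (0 = background)
labels = Variable(torch.randint(0, 4, (2, 1, 64, 64)).long())

loss = criterion(embeddings, labels)
loss.backward()
print(loss.item())
```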