Skip to content

loss function implemented in pytorch #10

@boundles

Description

@boundles

Hi, I implemented your DiscriminativeLoss in PyTorch as shown below, but the clustering results are poor. Could you give me some suggestions?

class DiscriminativeLoss2d(nn.Module):
    """Pull/push discriminative loss over per-pixel embeddings.

    For every sample and every instance map, pixels sharing a positive
    instance id are pulled towards their mean embedding (hinged at
    ``delta_var``), and the mean embeddings of different instances are
    pushed apart (hinged at ``2 * delta_dist``).

    Args:
        delta_var: Pull margin; pixels closer than this to their instance
            mean contribute nothing to the pull term.
        delta_dist: Push margin; pairs of instance means farther apart than
            ``2 * delta_dist`` contribute nothing to the push term.
        size_average: Stored on the module; not consulted by ``forward``.
        reduce: Stored on the module; not consulted by ``forward``.
        usegpu: Stored on the module. The loss accumulator is created on the
            device of ``inputs``, so this flag no longer forces a CUDA
            allocation.

    NOTE(review): the pull term is summed (not averaged) over instances and
    the push term is divided by ``len(means) - 1`` rather than by the number
    of pairs. This scaling is kept from the original implementation;
    confirm it is intended.
    """

    def __init__(self, delta_var, delta_dist, size_average=True, reduce=True, usegpu=True):
        super().__init__()

        self.delta_var = float(delta_var)
        self.delta_dist = float(delta_dist)

        # Kept so existing callers that pass or read these keep working.
        self.size_average = size_average
        self.reduce = reduce
        self.usegpu = usegpu

    def norm(self, inp, L):
        """Return the L1 (``L == 1``) or L2 (any other ``L``) norm along dim 0.

        The L2 branch adds 1e-8 under the square root so the gradient stays
        finite when the input is exactly zero. The reduced dimension is
        dropped from the result.
        """
        if L == 1:
            return torch.sum(torch.abs(inp), 0)
        return torch.sqrt(torch.sum(torch.pow(inp, 2), 0) + 1e-8)

    def _pull_term(self, embedding, labels):
        """Compute instance means and the summed pull loss for one label map.

        Args:
            embedding: ``c x (h*w)`` embedding of one sample.
            labels: flat ``h*w`` label map; ids ``1..max`` are instances.

        Returns:
            ``(means, pull)`` where ``means`` is a list of ``c x 1`` tensors
            and ``pull`` is the sum of the per-instance hinged variances
            (``0.0`` when no instance qualifies).
        """
        means = []
        pull = 0.0

        # ``int(...)`` makes the loop bound valid for float label maps too.
        max_id = int(labels.max().item())
        for instance_id in range(1, max_id + 1):
            mask = labels.eq(instance_id)
            # Instances with fewer than two pixels are skipped entirely.
            if int(mask.sum().item()) <= 1:
                continue

            pixels = embedding[:, mask]  # c x n
            center = pixels.mean(dim=1, keepdim=True)  # c x 1
            means.append(center)

            distance = self.norm(pixels - center, 2)  # n
            hinged = torch.clamp(distance - self.delta_var, min=0.0)
            pull = pull + torch.mean(torch.pow(hinged, 2))

        return means, pull

    def _push_term(self, means):
        """Return the push loss between every unordered pair of ``means``."""
        if len(means) <= 1:
            return 0.0

        push = 0.0
        for idx, mean_a in enumerate(means):
            for mean_b in means[idx + 1:]:
                gap = self.norm(mean_a - mean_b, 2)  # single element
                # The push margin is delta_dist, not delta_var.
                hinged = torch.clamp(2 * self.delta_dist - gap, min=0.0)
                push = push + torch.pow(hinged, 2).sum()

        return push / ((len(means) - 1) + 1e-8)

    def forward(self, inputs, targets):
        """Compute the batch-averaged discriminative loss.

        Args:
            inputs: ``b x c x h x w`` embedding tensor.
            targets: ``b x n_maps x h x w`` instance-id maps; ids greater
                than zero denote instances, other values are ignored.

        Returns:
            A tensor of shape ``(1,)`` on the device of ``inputs`` holding
            the sum of pull and push terms, averaged over the batch.
        """
        b, c, h, w = inputs.size()
        n_instance_maps = targets.size(1)

        # Follow the input's device instead of hard-coding ``.cuda()``.
        loss = torch.zeros(1, device=inputs.device)

        for i in range(b):
            embedding = inputs[i].reshape(c, h * w)
            loss_var = 0.0
            loss_dist = 0.0

            for j in range(n_instance_maps):
                labels = targets[i][j].reshape(h * w)

                means, pull = self._pull_term(embedding, labels)
                loss_var = loss_var + pull
                loss_dist = loss_dist + self._push_term(means)

            loss = loss + (loss_dist + loss_var)

        loss = loss / b

        return loss

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions