Loss dimensions #7

@Cram3r95

Hi, thank you so much for your fantastic work.

What are the expected order and dimensions of the inputs to this function?

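import torch
import torch.nn as nn

def l1_ewta_loss(prediction, target, k=6, eps=1e-7, mr=2.0):
    # prediction: (batch, num_mixtures, timesteps, data_dim)
    # target:     (batch, timesteps, data_dim)
    # Note: eps and mr are not used anywhere in this excerpt.
    num_mixtures = prediction.shape[1]

    # Broadcast the target across the mixture dimension
    target = target.unsqueeze(1).expand(-1, num_mixtures, -1, -1)
    # Sum the L1 error over timesteps and coordinates -> (batch, num_mixtures)
    l1_loss = nn.functional.l1_loss(prediction, target, reduction='none').sum(dim=[2, 3])

    # Keep the k lowest-loss mixtures for each sample (sorted along the mixture dim)
    mixture_loss_sorted, mixture_ranks = torch.sort(l1_loss, descending=False)
    mixture_loss_topk = mixture_loss_sorted.narrow(1, 0, k)

    # Average over batch, timesteps, and the k selected mixtures
    loss = mixture_loss_topk.sum()
    loss = loss / target.size(0)
    loss = loss / target.size(2)
    loss = loss / k
    return loss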

I cannot get results as good as the ones I obtain with NLL. My inputs are:

predictions: batch_size x num_modes x pred_len x data_dim (e.g. 1024 x 6 x 30 x 2)
gt: batch_size x pred_len x data_dim (e.g. 1024 x 30 x 2)

Is this correct?
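As a quick sanity check of those shapes, here is a minimal sketch that repeats the steps of the function above on random tensors. The small batch size, the random inputs, and the helper name `topk_mean` are assumptions made only for illustration; they are not part of the repository.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: a small batch, otherwise the layout described above.
batch_size, num_modes, pred_len, data_dim = 4, 6, 30, 2

torch.manual_seed(0)
prediction = torch.randn(batch_size, num_modes, pred_len, data_dim)
target = torch.randn(batch_size, pred_len, data_dim)

# Same broadcasting and reduction as in l1_ewta_loss.
target_exp = target.unsqueeze(1).expand(-1, num_modes, -1, -1)
per_mode = nn.functional.l1_loss(prediction, target_exp, reduction='none').sum(dim=[2, 3])
print(per_mode.shape)  # one summed L1 error per (sample, mode) pair

def topk_mean(per_mode_loss, k):
    # Hypothetical helper: keep the k lowest-loss modes per sample, then normalize
    # by batch size, prediction length, and k, as the function above does.
    sorted_loss, _ = torch.sort(per_mode_loss, descending=False)
    topk = sorted_loss.narrow(1, 0, k)
    return topk.sum() / per_mode_loss.size(0) / pred_len / k

loss_all = topk_mean(per_mode, k=num_modes)  # k equal to the number of modes
loss_best = topk_mean(per_mode, k=1)         # only the best mode per sample
print(loss_all.item(), loss_best.item())
```

With the default `k=6` and six modes, `narrow` keeps every mode, so the loss reduces to the mean L1 error over all modes rather than over the best ones. Whether this explains the gap to NLL is not something this sketch can confirm.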
