
Problem with where function #28

@zarif98sjs


Shouldn't the where function be this?

import torch

def where(q, a, b):
    "Use this function to replace an if-statement."
    # logical_not treats every nonzero entry of q as True, whatever q's dtype.
    return (q * a) + torch.logical_not(q) * b
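A quick sanity check of the proposed version, assuming PyTorch is installed: it compares the function against the built-in torch.where, using q != 0 as the boolean condition, for a bool mask and an integer 0/1 mask. The tensors a and b are made-up sample data.

```python
import torch

def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + torch.logical_not(q) * b

# Hypothetical sample data.
a = torch.tensor([10, 20, 30, 40])
b = torch.tensor([-1, -2, -3, -4])

# The same mask expressed with two different dtypes.
conditions = [
    torch.tensor([True, False, True, False]),  # bool mask
    torch.tensor([1, 0, 1, 0]),                # int64 0/1 mask
]

for q in conditions:
    expected = torch.where(q != 0, a, b)  # reference result from torch.where
    print(where(q, a, b).tolist(), expected.tolist())  # [10, -2, 30, -4] twice
```

Only 0/1 masks are checked here: for other nonzero integers (say 5), q * a would scale a instead of selecting it, so the arithmetic form still assumes q holds only 0s and 1s.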

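Otherwise, if we use ~q, isn't the result technically incorrect with respect to the desired outcome? When q is an integer tensor rather than a bool tensor, ~ is a bitwise NOT (so ~0 is -1), not a logical NOT.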

If we use ~q, where(arange(4) * 0, 0, 1) returns tensor([-1, -1, -1, -1]), but the desired output is tensor([1, 1, 1, 1]).
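A minimal reproduction of the difference, assuming PyTorch is installed. where_bitwise and where_logical are hypothetical names for the ~q version and the proposed torch.logical_not version.

```python
import torch

def where_bitwise(q, a, b):
    # ~q is a logical NOT only for bool tensors; on integers it flips every bit.
    return (q * a) + (~q) * b

def where_logical(q, a, b):
    # logical_not maps 0 -> True and any nonzero value -> False, for any dtype.
    return (q * a) + torch.logical_not(q) * b

q_int = torch.arange(4) * 0            # int64 zeros, not a bool tensor
print(~q_int)                          # tensor([-1, -1, -1, -1])
print(where_bitwise(q_int, 0, 1))      # tensor([-1, -1, -1, -1])
print(where_logical(q_int, 0, 1))      # tensor([1, 1, 1, 1])

q_bool = q_int == 0                    # tensor([True, True, True, True])
print(where_bitwise(q_bool, 0, 1))     # tensor([0, 0, 0, 0])
print(where_logical(q_bool, 0, 1))     # tensor([0, 0, 0, 0])
```

With a bool condition both versions agree; they only diverge when the condition is an integer tensor such as arange(4) * 0.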
