From 408c93e296d77eb21f9adb5c6e2c1cc3c29deea4 Mon Sep 17 00:00:00 2001 From: Thomas Macrina Date: Fri, 21 Aug 2020 17:26:52 -0400 Subject: [PATCH] Swap from torch.mm to torch.matmul for broadcasting --- torchfields/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchfields/fields.py b/torchfields/fields.py index f37cc33..d0144b0 100644 --- a/torchfields/fields.py +++ b/torchfields/fields.py @@ -332,7 +332,7 @@ def affine_field(cls, aff, size, offset=(0., 0.), device=None, dtype=None): [0., 1., -offset[1]], [0., 0., 1.]], device=device) Bi = Bi.expand(N, *Bi.shape) - aff = torch.mm(Bi, torch.mm(A, B))[:, :2] + aff = torch.matmul(Bi, torch.matmul(A, B))[:, :2] M = F.affine_grid(aff, size, align_corners=False) # Id is an identity mapping without the overhead of `identity_mapping` id_aff = tensor_type([[1, 0, 0], [0, 1, 0]], device=device)