-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreference.py
More file actions
71 lines (60 loc) · 2.36 KB
/
reference.py
File metadata and controls
71 lines (60 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
def get_rigidity_map_zcxy(
    field: torch.Tensor, power: float = 2, diagonal_mult: float = 1.0
) -> torch.Tensor:
    """Compute a per-pixel rigidity (elastic deformation energy) map of a displacement field.

    Each pixel is modeled as connected by springs to its right, bottom, and two
    diagonal neighbors (rest lengths 1 and sqrt(2)). The map is the weighted sum
    of ``|spring deformation| ** power`` over the 8 springs each interior pixel
    participates in, normalized by the total spring weight.

    Args:
        field: Displacement field, shape ``(batch, 2, H, W)`` (ZCXY layout —
            assumed from the reshape/indexing below; channel dim holds x/y).
        power: Exponent applied to each spring's deformation.
        diagonal_mult: Relative weight of the diagonal springs.

    Returns:
        Tensor of shape ``(batch, H, W)`` with ``field``'s dtype and device.
        The 2-pixel border is zeroed because those values are corrupted by the
        zero padding of the convolution.
    """
    # Zero field == identity transform -> zero rigidity everywhere.
    if field.abs().sum() == 0:
        # Fix: propagate the input dtype; previously this fast path returned
        # default float32 even for float64/float16 inputs, unlike the main path.
        return torch.zeros(
            (field.shape[0], field.shape[2], field.shape[3]),
            dtype=field.dtype,
            device=field.device,
        )
    batch = field.shape[0]
    # Finite-difference kernels for the 4 neighbor directions:
    # right, down, down-right, down-left (each yields neighbor - center).
    diff_ker = torch.tensor(
        [
            [
                [[0, 0, 0], [-1, 1, 0], [0, 0, 0]],
                [[0, -1, 0], [0, 1, 0], [0, 0, 0]],
                [[-1, 0, 0], [0, 1, 0], [0, 0, 0]],
                [[0, 0, -1], [0, 1, 0], [0, 0, 0]],
            ]
        ],
        dtype=field.dtype,
        device=field.device,
    )
    # (4, 1, 3, 3) repeated for both displacement components; groups=2 applies
    # the same 4 kernels independently to the x and y channels (8 outputs).
    diff_ker = diff_ker.permute(1, 0, 2, 3).repeat(2, 1, 1, 1)
    # The conv bias injects the inter-pixel rest offset of each direction, so
    # `delta` holds absolute spring vectors, not just displacement differences.
    # Layout: x-components of the 4 directions, then y-components.
    diff_bias = torch.tensor(
        [1.0, 0.0, 1.0, -1.0, 0.0, 1.0, 1.0, 1.0],
        dtype=field.dtype,
        device=field.device,
    )
    delta = torch.conv2d(field, diff_ker, diff_bias, groups=2, padding=[2, 2])
    # -> (batch, 4 directions, H+2, W+2, 2 components)
    delta = delta.reshape(batch, 2, 4, *delta.shape[-2:]).permute(0, 2, 3, 4, 1)
    spring_lengths = torch.norm(delta, dim=-1)
    # Hoisted loop-invariants: diagonal rest length and per-spring weight
    # (applied pre-pow, hence the 1/power exponent).
    rest_diag = 2 ** (1 / 2)
    diag_weight = diagonal_mult ** (1 / power)
    # Each interior pixel touches 8 springs: each of the 4 directions seen from
    # both endpoints, selected via the shifted slices below.
    spring_defs = torch.stack(
        [
            spring_lengths[:, 0, 1:-1, 1:-1] - 1,
            spring_lengths[:, 0, 1:-1, 2:] - 1,
            spring_lengths[:, 1, 1:-1, 1:-1] - 1,
            spring_lengths[:, 1, 2:, 1:-1] - 1,
            (spring_lengths[:, 2, 1:-1, 1:-1] - rest_diag) * diag_weight,
            (spring_lengths[:, 2, 2:, 2:] - rest_diag) * diag_weight,
            (spring_lengths[:, 3, 1:-1, 1:-1] - rest_diag) * diag_weight,
            (spring_lengths[:, 3, 2:, 0:-2] - rest_diag) * diag_weight,
        ]
    )
    # Slightly faster than sum() + pow(), and no need for abs() if power is odd
    result = torch.norm(spring_defs, p=power, dim=0).pow(power)
    # Normalize by total spring weight: 4 axis-aligned + 4 weighted diagonals.
    total = 4 + 4 * diagonal_mult
    result /= total
    # Remove incorrect smoothness values caused by 2px zero padding
    result[..., 0:2, :] = 0
    result[..., -2:, :] = 0
    result[..., :, 0:2] = 0
    result[..., :, -2:] = 0
    return result