diff --git a/coutils/eval.py b/coutils/eval.py index b494634..353e9c6 100644 --- a/coutils/eval.py +++ b/coutils/eval.py @@ -41,17 +41,17 @@ def compute_numeric_gradient(f, x, dy=None, h=1e-5): dx = torch.zeros_like(x) # Get flattened views of everything - x_flat = x.contiguous().view(-1) - y_flat = y.contiguous().view(-1) - dx_flat = dx.contiguous().view(-1) - dy_flat = dy.contiguous().view(-1) + x_flat = torch.flatten(x) + y_flat = torch.flatten(y) + dx_flat = torch.flatten(dx) + dy_flat = torch.flatten(dy) for i in range(dx_flat.shape[0]): # Compute numeric derivative dy/dx[i] orig = x_flat[i].item() x_flat[i] = orig + h - yph = f(x).clone().view(-1) + yph = torch.flatten(f(x).clone()) x_flat[i] = orig - h - ymh = f(x).clone().view(-1) + ymh = torch.flatten(f(x).clone()) x_flat[i] = orig dy_dxi = (yph - ymh) / (2.0 * h)