From 4f48474f8e888ee2d974ac345805eb14c9a38660 Mon Sep 17 00:00:00 2001 From: PeteLipp <73332106+PeteLipp@users.noreply.github.com> Date: Thu, 14 Nov 2024 19:34:59 +0100 Subject: [PATCH] option to normalize vectors in learned lframes --- tensorframes/lframes/learning_lframes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorframes/lframes/learning_lframes.py b/tensorframes/lframes/learning_lframes.py index 4498f8e..79982c7 100644 --- a/tensorframes/lframes/learning_lframes.py +++ b/tensorframes/lframes/learning_lframes.py @@ -35,6 +35,7 @@ def __init__( fix_gravitational_axis: bool = False, gravitational_axis_index: int = 1, envelope: Union[torch.nn.Module, None] = EnvelopePoly(5), + normalize_relative_vec: bool = True, use_double_cross_product: bool = False, **mlp_kwargs: dict, ) -> None: @@ -53,6 +54,7 @@ def __init__( fix_gravitational_axis (bool, optional): Whether to fix the gravitational axis. Defaults to False. gravitational_axis_index (int, optional): The index of the gravitational axis. Defaults to 1. envelope (Union[torch.nn.Module, None], optional): The envelope module. Defaults to EnvelopePoly(5). + normalize_relative_vec (bool, optional): Whether to normalize the relative vectors. Defaults to True. use_double_cross_product (bool, optional): Whether to use the double cross product method to compute the third vector. Defaults to False. **mlp_kwargs (dict): Additional keyword arguments for the MLP. """ @@ -98,6 +100,7 @@ def __init__( self.concat_receiver = concat_receiver self.exceptional_choice = exceptional_choice self.use_double_cross_product = use_double_cross_product + self.normalize_relative_vec = normalize_relative_vec if self.cutoff is not None: self.envelope = envelope @@ -242,7 +245,8 @@ def message( relative_vec = pos_j - pos_i relative_norm = torch.clamp(torch.linalg.norm(relative_vec, dim=-1, keepdim=True), 1e-6) - relative_vec = relative_vec / relative_norm + if self.normalize_relative_vec: + relative_vec = relative_vec / relative_norm out = torch.einsum("ij,ik->ijk", mlp_out, relative_vec).reshape(-1, self.num_pred_vecs * 3)