Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tensorframes/lframes/learning_lframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
fix_gravitational_axis: bool = False,
gravitational_axis_index: int = 1,
envelope: Union[torch.nn.Module, None] = EnvelopePoly(5),
normalize_relative_vec: bool = True,
use_double_cross_product: bool = False,
**mlp_kwargs: dict,
) -> None:
Expand All @@ -53,6 +54,7 @@ def __init__(
fix_gravitational_axis (bool, optional): Whether to fix the gravitational axis. Defaults to False.
gravitational_axis_index (int, optional): The index of the gravitational axis. Defaults to 1.
envelope (Union[torch.nn.Module, None], optional): The envelope module. Defaults to EnvelopePoly(5).
normalize_relative_vec (bool, optional): Whether to normalize the relative vectors. Defaults to True.
use_double_cross_product (bool, optional): Whether to use the double cross product method to compute the third vector. Defaults to False.
**mlp_kwargs (dict): Additional keyword arguments for the MLP.
"""
Expand Down Expand Up @@ -98,6 +100,7 @@ def __init__(
self.concat_receiver = concat_receiver
self.exceptional_choice = exceptional_choice
self.use_double_cross_product = use_double_cross_product
self.normalize_relative_vec = normalize_relative_vec

if self.cutoff is not None:
self.envelope = envelope
Expand Down Expand Up @@ -242,7 +245,8 @@ def message(

relative_vec = pos_j - pos_i
relative_norm = torch.clamp(torch.linalg.norm(relative_vec, dim=-1, keepdim=True), 1e-6)
relative_vec = relative_vec / relative_norm
if self.normalize_relative_vec:
relative_vec = relative_vec / relative_norm

out = torch.einsum("ij,ik->ijk", mlp_out, relative_vec).reshape(-1, self.num_pred_vecs * 3)

Expand Down