diff --git a/README.md b/README.md index c28dee3..450800d 100644 --- a/README.md +++ b/README.md @@ -75,3 +75,96 @@ The list can be found in the `configs/data/chebi50_graph_properties.yml` file. ```bash python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml ``` + +## Augmented Graphs + +Below is the command for the model and data configuration that achieved the best classification performance using augmented graphs. + + +```bash +python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0 +``` + +### Model Hyperparameters + +#### **GAT Architecture** + +To use a GAT-based model, choose **one** of the following configs: + +- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml` +- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml` +- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/gat.yml` + +#### GAT-specific hyperparameters + +- **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4) +- **Attention heads**: `--model.config.heads=4` (default: 8) + > Note: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified). +- **Use GATv2**: `--model.config.v2=True` (default: False) + +#### **ResGated Architecture** + +To use a ResGated GNN model, choose **one** of the following configs: + +- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/res_aug_amgpool.yml` +- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/res_aug_aagpool.yml` +- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/resgated.yml` + +#### **Common Hyperparameters** + +These can be used for both GAT and ResGated architectures: + +- **Dropout**: `--model.config.dropout=0.1` (default: 0) +- **Number of final linear layers**: `--model.n_linear_layers=2` (default: 1) + +# Random Node Initialization + +## Static Node Initialization + +In this type of node initialization, the node features (and/or edge features) of the given molecular graph are initialized only once during dataset creation with the given initialization scheme. + +```bash +python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0 +``` + +In the above command, for each node we use the 158 node features (corresponding the node properties defined in `chebi50_graph_properties.yml`) which are retrieved from RDKit and add 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (default). + +You can change the distribution using the following config in above command: `--data.distribution=zeros` + +Available distributions: `"normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"` + + +Similarly, each edge is initialized with 7 RDKit features and 4 additional features drawn from the given distribution. + + +If you want all node (and edge) features to be drawn from a given distribution (i.e., ignore RDKit features), use: `--data=../python-chebai-graph/configs/data/chebi50_static_gni.yml` + + +Refer to the data class code for details. + + +## Dynamic Node Initialization + +In this type of node initialization, the node features (and/or edge features) of the molecular graph are initialized at **each forward pass** of the model using the given initialization scheme. + + + +Currently, dynamic node initialization is implemented only for the **resgated** architecture by specifying: `--model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml` + +To keep RDKit features and *add* dynamically initialized features use the following config in the command: + +``` +--model.config.complete_randomness=False +--model.config.pad_node_features=45 +``` + +The additional features are drawn from normal distribution (default). You can change it using:`--model.config.distribution=uniform` + +If all features should be initialized from the given distribution, remove the complete_randomness flag (default is True). + + +Please find below the command for a typical dynamic node initialization: + +```bash +python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.config.complete_randomness=False --model.config.pad_node_features=45 --model.config.pad_edge_features=4 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_dres_props+rand_s0 +``` diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index 20a6c6a..bb48571 100644 --- a/chebai_graph/models/dynamic_gni.py +++ b/chebai_graph/models/dynamic_gni.py @@ -1,3 +1,23 @@ +""" +ResGatedDynamicGNIGraphPred +------------------------------------------------ + +Module providing a ResGated GNN model that applies Random Node Initialization +(RNI) dynamically at each forward pass. This follows the approach from: + +Abboud, R., et al. (2020). "The surprising power of graph neural networks with +random node initialization." arXiv preprint arXiv:2010.01179. + +The module exposes: +- ResGatedDynamicGNI: a model that can either completely replace node/edge + features with random tensors each forward pass or pad existing features with + additional random features. +- ResGatedDynamicGNIGraphPred: a thin wrapper that instantiates the above for + graph-level prediction pipelines. +""" + +__all__ = ["ResGatedDynamicGNIGraphPred"] + from typing import Any import torch @@ -14,12 +34,37 @@ class ResGatedDynamicGNI(GraphModelBase): """ - Base model class for applying ResGatedGraphConv layers to graph-structured data - with dynamic initialization of features for nodes and edges. - - Args: - config (dict): Configuration dictionary containing model hyperparameters. - **kwargs: Additional keyword arguments for parent class. + ResGated GNN with dynamic Random Node Initialization (RNI). + + This model supports two modes controlled by the `config`: + + - complete_randomness (bool-like): If True, **replace** node and edge + features entirely with randomly initialized tensors each forward pass. + If False, the model **pads** existing features with extra randomly + initialized features on-the-fly. + + - pad_node_features (int, optional): Number of random columns to append + to each node feature vector when `complete_randomness` is False. + + - pad_edge_features (int, optional): Number of random columns to append + to each edge feature vector when `complete_randomness` is False. + + - distribution (str): Distribution for random initialization. Must be one + of RandomFeatureInitializationReader.DISTRIBUTIONS. + + Parameters + ---------- + config : Dict[str, Any] + Configuration dictionary containing model hyperparameters. Expected keys + used by this class: + - distribution (optional, default "normal") + - complete_randomness (optional, default "True") + - pad_node_features (optional, int) + - pad_edge_features (optional, int) + Keys required by GraphModelBase (e.g., in_channels, hidden_channels, + out_channels, num_layers, edge_dim) should also be present. + **kwargs : Any + Additional keyword arguments forwarded to GraphModelBase. """ def __init__(self, config: dict[str, Any], **kwargs: Any): @@ -96,6 +141,8 @@ def forward(self, batch: dict[str, Any]) -> Tensor: new_x = None new_edge_attr = None + + # If replacing features entirely with random values if self.complete_randomness: new_x = torch.empty( graph_data.x.shape[0], graph_data.x.shape[1], device=self.device @@ -110,6 +157,8 @@ def forward(self, batch: dict[str, Any]) -> Tensor: RandomFeatureInitializationReader.random_gni( new_edge_attr, self.distribution ) + + # If padding existing features with additional random columns else: if self.pad_node_features is not None: pad_node = torch.empty( diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index 106c528..da9f847 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -1,19 +1,70 @@ """ -Abboud, Ralph, et al. -"The surprising power of graph neural networks with random node initialization." -arXiv preprint arXiv:2010.01179 (2020). +RandomFeatureInitializationReader +-------------------------------- -Code Reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py +Implements random node / edge / molecule feature initialization for graph neural +networks following: + +Abboud, R., et al. (2020). "The surprising power of graph neural networks with +random node initialization." arXiv preprint arXiv:2010.01179. + +Code reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py + +This module provides a reader that replaces node/edge/molecule features with +randomly initialized tensors drawn from a selected distribution. + +Notes +----- +- This reader subclasses GraphPropertyReader and is intended to be used where a + graph object with attributes `x`, `edge_attr`, and optionally `molecule_attr` + is expected (e.g., `torch_geometric.data.Data`). +- The reader only performs random initialization and does not support reading + specific properties from the input data. """ +from typing import Any, Optional + import torch +from torch import Tensor from torch_geometric.data import Data as GeomData from .reader import GraphPropertyReader class RandomFeatureInitializationReader(GraphPropertyReader): - DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"] + """ + Reader that initializes node, bond (edge), and molecule features with + random values according to a chosen distribution. + + Supported distributions: + - "normal" : standard normal (mean=0, std=1) + - "uniform" : uniform in [-1, 1] + - "xavier_normal" : Xavier normal initialization + - "xavier_uniform" : Xavier uniform initialization + - "zeros" : all zeros + + Parameters + ---------- + num_node_properties : int + Number of features to generate per node. + num_bond_properties : int + Number of features to generate per edge/bond. + num_molecule_properties : int + Number of global molecule-level features to generate. + distribution : str, optional + One of the supported distributions (default: "normal"). + *args, **kwargs : Any + Additional positional and keyword arguments passed to the parent + GraphPropertyReader. + """ + + DISTRIBUTIONS = [ + "normal", + "uniform", + "xavier_normal", + "xavier_uniform", + "zeros", + ] def __init__( self, @@ -21,27 +72,56 @@ def __init__( num_bond_properties: int, num_molecule_properties: int, distribution: str = "normal", - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) - self.num_node_properties = num_node_properties - self.num_bond_properties = num_bond_properties - self.num_molecule_properties = num_molecule_properties - assert distribution in self.DISTRIBUTIONS - self.distribution = distribution + if distribution not in self.DISTRIBUTIONS: + raise ValueError( + f"distribution must be one of {self.DISTRIBUTIONS}, got '{distribution}'" + ) + + self.num_node_properties: int = int(num_node_properties) + self.num_bond_properties: int = int(num_bond_properties) + self.num_molecule_properties: int = int(num_molecule_properties) + self.distribution: str = distribution def name(self) -> str: """ - Get the name identifier of the reader. + Return a human-readable identifier for this reader configuration. + + Returns + ------- + str + A name encoding the chosen distribution and generated feature sizes. + """ + return ( + f"gni-{self.distribution}" + f"-node{self.num_node_properties}" + f"-bond{self.num_bond_properties}" + f"-mol{self.num_molecule_properties}" + ) - Returns: - str: The name of the reader. + def _read_data(self, raw_data: Any) -> Optional[GeomData]: """ - return f"gni-{self.distribution}-node{self.num_node_properties}-bond{self.num_bond_properties}-mol{self.num_molecule_properties}" + Read and return a `torch_geometric.data.Data` object with randomized + node/edge/molecule features. - def _read_data(self, raw_data): - data: GeomData = super()._read_data(raw_data) + This method calls the parent's `_read_data` to obtain a graph object, + then replaces `x`, `edge_attr` and sets `molecule_attr` with new tensors. + + Parameters + ---------- + raw_data : Any + Raw input that the parent reader understands. + + Returns + ------- + Optional[GeomData] + A `Data` object with randomized attributes or `None` if the parent + `_read_data` returned `None`. + """ + data: Optional[GeomData] = super()._read_data(raw_data) if data is None: return None @@ -51,24 +131,55 @@ def _read_data(self, raw_data): ) random_molecule_properties = torch.empty(1, self.num_molecule_properties) + # Initialize them according to the chosen distribution. self.random_gni(random_x, self.distribution) self.random_gni(random_edge_attr, self.distribution) self.random_gni(random_molecule_properties, self.distribution) + # Assign randomized attributes back to the data object. data.x = random_x data.edge_attr = random_edge_attr + # Use `molecule_attr` as the name in this codebase; if your Data object + # expects a different name (e.g., `u` or `global_attr`) adapt accordingly. data.molecule_attr = random_molecule_properties + return data - def read_property(self, *args, **kwargs) -> Exception: - """This reader does not support reading specific properties.""" - raise NotImplementedError("This reader only performs random initialization.") + def read_property(self, *args: Any, **kwargs: Any) -> None: + """ + This reader does not support reading specific properties from the input. + It only performs random initialization of features. + + Raises + ------ + NotImplementedError + Always raised to indicate unsupported operation. + """ + raise NotImplementedError( + "RandomFeatureInitializationReader only performs random initialization." + ) @staticmethod - def random_gni(tensor: torch.Tensor, distribution: str) -> None: + def random_gni(tensor: Tensor, distribution: str) -> None: + """ + Fill `tensor` in-place according to the requested initialization. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to initialize in-place. + distribution : str + One of the supported distribution identifiers. + + Raises + ------ + ValueError + If an unknown distribution string is provided. + """ if distribution == "normal": torch.nn.init.normal_(tensor) elif distribution == "uniform": + # Uniform in [-1, 1] torch.nn.init.uniform_(tensor, a=-1.0, b=1.0) elif distribution == "xavier_normal": torch.nn.init.xavier_normal_(tensor) @@ -77,4 +188,4 @@ def random_gni(tensor: torch.Tensor, distribution: str) -> None: elif distribution == "zeros": torch.nn.init.zeros_(tensor) else: - raise ValueError("Unknown distribution type") + raise ValueError(f"Unknown distribution type: '{distribution}'")