diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt index 92d0b77a..9ce39f9d 100644 --- a/chebai/preprocessing/bin/smiles_token/tokens.txt +++ b/chebai/preprocessing/bin/smiles_token/tokens.txt @@ -819,3 +819,168 @@ p [16N] [17N] [14N] +[Pb+2] +[AlH4-] +[BH4-] +[Pt-2] +[Cl+2] +[I+3] +[Br+2] +[Cl+3] +[Os-2] +[Cr-2] +[Hg-2] +[PH] +[Br+3] +[I+2] +[AsH2] +[SH] +[W-2] +[Cd-2] +[Ir-2] +[Ru-2] +[Rh-2] +[Ag-2] +[Be-2] +[TeH2+] +[13c] +[13cH] +[PH4] +[AsH4] +[As-2] +[SbH3+] +[SbH4] +[BiH3] +[BH3-] +[GeH3] +[GeH2] +[SiH2-] +[SiH2+] +[SnH2] +[SnH3] +[SnH] +[PbH] +[PbH3] +[Al-2] +[B+2] +[N+2] +[SbH] +[SbH2] +[InH2] +[GaH2] +[TlH2] +[Au+2] +[sH+] +[Hg+2] +[Si-2] +[Sn-2] +[Pb-2] +[AsH3] +[Cr+2] +[Ag+2] +[V-2] +[Ce-2] +[13C@] +[*+2] +[He+2] +[4He+2] +[3He+2] +[Eu+2] +[Ge+2] +[Os+2] +[Y+2] +[Gd+2] +[La+2] +[Se+2] +[NH-2] +[TeH2-] +[AlH3-] +[SbH3-] +[AsH3-] +[BiH3-] +[PH3-] +[CH2-2] +[AsH4+] +[AlH3+] +[BiH3+] +[FH+] +[CH3+] +[Te-2] +[OH] +[CH3] +[18OH2] +[OH3+] +[OH4+2] +[SH3] +[SH3+] +[SH3-] +[SH4] +[SeH2] +[SeH-] +[SeH3+] +[SeH3-] +[SeH3] +[SeH+] +[TeH2] +[TeH-] +[TeH3-] +[TeH3+] +[TeH+] +[TeH3] +[TeH4] +[PoH2] +[NH2] +[NH+2] +[PH5] +[PH4+] +[PH-2] +[PH4-] +[PH+2] +[AsH2+] +[AsH2-] +[AsH+2] +[AsH-2] +[AsH5] +[SbH3] +[SbH4+] +[SbH5] +[BiH4+] +[BiH5] +[BiH4-] +[BH2] +[BH2+] +[BH2-] +[BH-2] +[BH+2] +[GeH4] +[GeH3+] +[GeH3-] +[SiH3-] +[SiH3+] +[SiH+] +[SiH4] +[HeH+2] +[HeH+] +[AlH] +[AlH+] +[SnH4] +[SnH3-] +[SnH3+] +[PbH4] +[PbH3-] +[PbH3+] +[BeH4-2] +[BeH] +[BeH+] +[BeH-] +[BeH2] +[AtH] +[InH3] +[GaH3] +[TlH3] +[IH3] +[FeH6-4] +[FH2+] +[ClH2+] +[BrH2+] +[IH2+] diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index aa9960f9..4b1b0353 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -8,6 +8,7 @@ import deepsmiles import selfies as sf from pysmiles.read_smiles import _tokenize +from rdkit import Chem from transformers import RobertaTokenizerFast from 
chebai.preprocessing.collate import DefaultCollator, RaggedCollator
@@ -177,5 +178,20 @@ class ChemDataReader(TokenIndexerReader):
     COLLATOR = RaggedCollator
 
+    def __init__(self, canonicalize_smiles=True, *args, **kwargs) -> None:
+        """
+        Initializes the reader and records whether SMILES strings are canonicalized.
+
+        Args:
+            canonicalize_smiles (bool): Whether `_read_data` canonicalizes each SMILES string
+                with RDKit before tokenizing it. Defaults to True. It precedes `*args`, so a
+                first positional argument binds to this flag instead of being forwarded.
+            *args: Positional arguments forwarded to the parent constructor.
+            **kwargs: Keyword arguments forwarded to the parent constructor.
+        """
+        super().__init__(*args, **kwargs)
+        self.canonicalize_smiles = canonicalize_smiles
+        print(f"Using SMILES canonicalization: {self.canonicalize_smiles}")
+
     @classmethod
     def name(cls) -> str:
         """Returns the name of the data reader."""
@@ -184,6 +200,9 @@ def name(cls) -> str:
     def _read_data(self, raw_data: str) -> List[int]:
         """
         Reads and tokenizes raw SMILES data into a list of token indices.
+
+        If `self.canonicalize_smiles` is set, the string is first rewritten into RDKit's
+        canonical SMILES; input that RDKit cannot parse is tokenized as given.
 
         Args:
             raw_data (str): The raw SMILES string to be tokenized.
@@ -191,4 +210,17 @@ def _read_data(self, raw_data: str) -> List[int]:
         Returns:
             List[int]: A list of integers representing the indices of the SMILES tokens.
         """
+        if self.canonicalize_smiles:
+            try:
+                mol = Chem.MolFromSmiles(raw_data.strip())
+                if mol is not None:
+                    raw_data = Chem.MolToSmiles(mol, canonical=True)
+                else:
+                    # A None result means RDKit could not build a molecule; report the
+                    # fallback to the raw string instead of letting it pass silently.
+                    print(f"RDKit failed to process {raw_data}")
+            except Exception as e:
+                print(f"RDKit failed to process {raw_data}")
+                print(f"\t{e}")
+
         return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]