Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions configs/encoders/multimodal/kformer_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 31090,
"encoder_width": 768,
"add_cross_attention": true,
"cross_attention_freq": 2,
"num_query_tokens": 32,
"contrastive_layer": 6
}

51 changes: 51 additions & 0 deletions configs/moledit/molkformer-Graph-MegaMolBART.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
{
"model": "molkformer-MegaMolBART",
"data": {
"mol": {
"modality": ["structure"],
"featurizer": {
"structure": {
"name": "MultiScale",
"scales": ["SMILES", "graph"],
"SMILES": {
"name": "moleculeSTM",
"transformer_type": "molbart",
"model_name_or_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt",
"max_length": 512
},
"graph": {
"name": "BaseGNN"
}
}

}
},
"text": {
"name": "TransformerTokenizer",
"transformer_type": "biot5",
"max_length": 512,
"model_name_or_path": "./ckpts/text_ckpts/t5-v1.1-base",
"path_selfies": "./assets/tokenizers/biot5/selfies_dict.txt"
}
},
"network": {
"graph": {
"name": "molkformer",
"structure": {
"gin_hidden_dim": 300,
"gin_num_layers": 5,
"drop_ratio": 0.0
},
"decoder": {
"config_file": "./ckpts/text_ckpts/t5-v1.1-base/config.json"
},
"kformer_config_file": "./configs/encoders/multimodal/kformer_config.json",
"encoder_tokenizer": "./ckpts/text_ckpts/scibert_scivocab_uncased",
"decoder_tokenizer": "./ckpts/text_ckpts/t5-v1.1-base",
"path_selfies": "./assets/tokenizers/biot5/selfies_dict.txt",
"max_n_atoms": 256,
"projection_dim": 256,
"init_checkpoint": "./ckpts/fusion_ckpts/molkformer/checkpoint_49.pth"
}
}
}
47 changes: 47 additions & 0 deletions configs/moledit/molstm-Graph-MegaMolBART.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"model": "molstm-MegaMolBART",
"data": {
"mol": {
"modality": ["structure"],
"featurizer": {
"structure": {
"name": "MultiScale",
"scales": ["SMILES", "graph"],
"SMILES": {
"name": "moleculeSTM",
"transformer_type": "molbart",
"model_name_or_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt",
"max_length": 512
},
"graph": {
"name": "BaseGNN"
}
}
}
}
},
"network": {
"graph": {
"name": "molstm",
"structure": {
"name": "gnn",
"gin_hidden_dim": 300,
"gin_num_layers": 5,
"drop_ratio": 0.5,
"output_dim": 300,
"ckpt" : "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_Graph/molecule_model.pth",
"MegaMolBART_generation_model_dir" : "./ckpts/fusion_ckpts/pretrained_MegaMolBART/checkpoints",
"vocab_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt"
},
"text": {
"output_dim": 768,
"ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_Graph/text_model.pth",
"bert_path": "./ckpts/text_ckpts/scibert_scivocab_uncased"
},
"structure_proj_ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_Graph/mol2latent_model.pth",
"text_proj_ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_Graph/text2latent_model.pth",
"projection_dim": 256
}

}
}
40 changes: 40 additions & 0 deletions configs/moledit/molstm-SMILES-MegaMolBART.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"model": "molstm-MegaMolBART",
"data": {
"mol": {
"modality": ["structure"],
"featurizer": {
"structure": {
"name": "MultiScale",
"scales": ["SMILES"],
"SMILES": {
"name": "moleculeSTM",
"transformer_type": "molbart",
"model_name_or_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt",
"max_length": 512
}
}
}
}
},
"network": {
"smiles": {
"name": "molstm",
"structure": {
"name": "magamolbart",
"output_dim": 256,
"MegaMolBART_generation_model_dir" : "./ckpts/fusion_ckpts/pretrained_MegaMolBART/checkpoints",
"vocab_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt"
},
"text": {
"output_dim": 768,
"ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_SMILES/text_model.pth",
"bert_path": "./ckpts/text_ckpts/scibert_scivocab_uncased"
},
"structure_proj_ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_SMILES/mol2latent_model.pth",
"text_proj_ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_SMILES/text2latent_model.pth",
"projection_dim": 256
}

}
}
43 changes: 43 additions & 0 deletions configs/moledit/momu-Graph-MegaMolBART.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"model": "momu-MegaMolBART",
"data": {
"mol": {
"modality": ["structure"],
"featurizer": {
"structure": {
"name": "MultiScale",
"scales": ["SMILES", "graph"],
"SMILES": {
"name": "moleculeSTM",
"transformer_type": "molbart",
"model_name_or_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt",
"max_length": 512
},
"graph": {
"name": "ogb"
}
}

}
}
},
"network": {
"graph": {
"name": "momu",
"gin_hidden_dim": 300,
"gin_num_layers": 5,
"drop_ratio": 0.0,
"graph_pooling": "sum",
"graph_self": false,
"max_n_nodes": -1,
"bert_dropout": 0.0,
"bert_hidden_dim": 768,
"output_dim": 300,
"projection_dim": 256,
"init_checkpoint": "./ckpts/fusion_ckpts/momu/MoMu-S.ckpt",
"param_key": "state_dict",
"stop_grad": false
}

}
}
84 changes: 84 additions & 0 deletions open_biomed/datasets/moledit_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
logger = logging.getLogger(__name__)

from abc import ABC, abstractmethod

import os
import csv

import torch
from torch.utils.data import Dataset
import pandas as pd

from feature.mol_featurizer import MolMultiModalFeaturizer
from feature.text_featurizer import TextTransformerTokFeaturizer
from utils.mol_utils import valid_smiles
from models.multimodal.mega_molbart.tokenizer import MolEncTokenizer

class MoleditDataset(Dataset, ABC):
def __init__(self, path, config):
super(MoleditDataset, self).__init__()
self.path = path
self.config = config
self._load_data()
self._featurize()

@abstractmethod
def _load_data(self):
raise NotImplementedError

def _featurize(self):
featurizer = MolMultiModalFeaturizer(self.config)
self.mols = [featurizer(smi) for smi in self.smiles]

smiles_emb = [d['structure'] for d in self.mols]
smiles_emb = [d['SMILES'] for d in smiles_emb]
smiles_emb = [item for sublist in smiles_emb for item in sublist]
tokens, orig_pad_masks = self._pad_seqs(smiles_emb)
smiles = [{'input_ids': tokens, 'pad_masks': pad_masks} for tokens, pad_masks in zip(tokens, orig_pad_masks)]
for i, dictionary in enumerate(smiles):
self.mols[i]["structure"]["SMILES"] = dictionary


@staticmethod
def _pad_seqs(seqs, pad_token = 0):
pad_length = max([len(seq) for seq in seqs])
padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
masks = [([0] * len(seq)) + ([1] * (pad_length - len(seq))) for seq in seqs]
return padded, masks

def __getitem__(self, index):
return self.mols[index]

def __len__(self):
return len(self.mols)


class ZINC250K(MoleditDataset):
def __init__(self, path, config, split):
self.split = split
super(ZINC250K, self).__init__(path, config)

def _load_data(self, subset_size=None):

SMILES_file = os.path.join(self.path, "raw/250k_rndm_zinc_drugs_clean_3.csv")
df = pd.read_csv(SMILES_file)
smiles = df['smiles'].tolist() # Already canonical SMILES
self.smiles = [x.strip() for x in smiles]

new_SMILES_file = os.path.join(self.path, "raw/smiles.csv")
if not os.path.exists(new_SMILES_file):
data_smiles_series = pd.Series(self.smiles)
print("saving to {}".format(new_SMILES_file))
data_smiles_series.to_csv(new_SMILES_file, index=False, header=False)

if subset_size is not None:
self.smiles = self.smiles[:subset_size]




SUPPORTED_MOLEDIT_DATASET = {
"ZINC250K": ZINC250K
}

17 changes: 17 additions & 0 deletions open_biomed/feature/mol_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sklearn.preprocessing import OneHotEncoder
from torch_geometric.data import Data
from transformers import BertTokenizer, T5Tokenizer
from models.multimodal.mega_molbart.tokenizer import MolEncTokenizer

from feature.base_featurizer import BaseFeaturizer
from feature.kg_featurizer import SUPPORTED_KG_FEATURIZER
Expand Down Expand Up @@ -169,6 +170,21 @@ def __call__(self, data):
result = self.tokenizer(data, max_length=self.max_length, padding=True, truncation=True)
return result

class MolSTMFeaturizer(BaseFeaturizer):
name2tokenizer = {
"molbart": MolEncTokenizer,
}

def __init__(self, config):
super(MolSTMFeaturizer, self).__init__()
self.tokenizer = self.name2tokenizer[config["transformer_type"]].from_pretrained(config["model_name_or_path"])

def __call__(self, data):
result = self.tokenizer.tokenize(data, pad=False)
result = self.tokenizer.convert_tokens_to_ids(result['original_tokens'])

return result

class MolBPEFeaturizer(BaseFeaturizer):
def __init__(self, config):
super(MolBPEFeaturizer, self).__init__()
Expand Down Expand Up @@ -783,6 +799,7 @@ def __getitem__(self, index):
"OneHot": MolOneHotFeaturizer,
"KV-PLM*": MolBPEFeaturizer,
"transformer": MolTransformerTokFeaturizer,
"moleculeSTM": MolSTMFeaturizer,
"fp": MolFPFeaturizer,
"TGSA": MolTGSAFeaturizer,
"ogb": MolGraphFeaturizer,
Expand Down
6 changes: 5 additions & 1 deletion open_biomed/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from models.knowledge import *
from models.text import *
from models.multimodal import *
from models.multimodal.molkformer import *


SUPPORTED_MOL_ENCODER = {
"cnn": MolCNN,
Expand All @@ -18,7 +20,9 @@
"biomedgpt-10b": BioMedGPTV,
"kv-plm": KVPLM,
"momu": MoMu,
"molfm": MolFM
"molfm": MolFM,
"molkformer": MolKFormer,
"molstm": MoleculeSTM
}

SUPPORTED_MOL_DECODER = {
Expand Down
5 changes: 4 additions & 1 deletion open_biomed/models/multimodal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@
from models.multimodal.molfm.molfm import MolFM
from models.multimodal.molfm.drugfm import DrugFM
from models.multimodal.molt5 import MolT5
from models.multimodal.text2mol import Text2MolMLP
from models.multimodal.text2mol import Text2MolMLP
from models.multimodal.molkformer.mol_kformer import MolKFormer
from models.multimodal.mega_molbart.mega_mol_bart import MegaMolBART
from models.multimodal.moleculestm import MoleculeSTM
1 change: 1 addition & 0 deletions open_biomed/models/multimodal/mega_molbart/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from models.multimodal.mega_molbart.megatron_bart import MegatronBART
Loading