diff --git a/configs/encoders/multimodal/kformer_config.json b/configs/encoders/multimodal/kformer_config.json new file mode 100644 index 0000000..47920e6 --- /dev/null +++ b/configs/encoders/multimodal/kformer_config.json @@ -0,0 +1,25 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 31090, + "encoder_width": 768, + "add_cross_attention": true, + "cross_attention_freq": 2, + "num_query_tokens": 32, + "contrastive_layer": 6 + } + \ No newline at end of file diff --git a/configs/moledit/molkformer-Graph-MegaMolBART.json b/configs/moledit/molkformer-Graph-MegaMolBART.json new file mode 100644 index 0000000..3ac9731 --- /dev/null +++ b/configs/moledit/molkformer-Graph-MegaMolBART.json @@ -0,0 +1,51 @@ +{ + "model": "molkformer-MegaMolBART", + "data": { + "mol": { + "modality": ["structure"], + "featurizer": { + "structure": { + "name": "MultiScale", + "scales": ["SMILES", "graph"], + "SMILES": { + "name": "moleculeSTM", + "transformer_type": "molbart", + "model_name_or_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt", + "max_length": 512 + }, + "graph": { + "name": "BaseGNN" + } + } + + } + }, + "text": { + "name": "TransformerTokenizer", + "transformer_type": "biot5", + "max_length": 512, + "model_name_or_path": "./ckpts/text_ckpts/t5-v1.1-base", + "path_selfies": "./assets/tokenizers/biot5/selfies_dict.txt" + } + }, + "network": { + "graph": { + "name": "molkformer", + "structure": { + "gin_hidden_dim": 300, + "gin_num_layers": 5, + "drop_ratio": 0.0 + }, + "decoder": { + "config_file": "./ckpts/text_ckpts/t5-v1.1-base/config.json" + }, + "kformer_config_file": 
"./configs/encoders/multimodal/kformer_config.json", + "encoder_tokenizer": "./ckpts/text_ckpts/scibert_scivocab_uncased", + "decoder_tokenizer": "./ckpts/text_ckpts/t5-v1.1-base", + "path_selfies": "./assets/tokenizers/biot5/selfies_dict.txt", + "max_n_atoms": 256, + "projection_dim": 256, + "init_checkpoint": "./ckpts/fusion_ckpts/molkformer/checkpoint_49.pth" + } + } +} \ No newline at end of file diff --git a/configs/moledit/molstm-Graph-MegaMolBART.json b/configs/moledit/molstm-Graph-MegaMolBART.json new file mode 100644 index 0000000..f55c7e7 --- /dev/null +++ b/configs/moledit/molstm-Graph-MegaMolBART.json @@ -0,0 +1,47 @@ +{ + "model": "molstm-MegaMolBART", + "data": { + "mol": { + "modality": ["structure"], + "featurizer": { + "structure": { + "name": "MultiScale", + "scales": ["SMILES", "graph"], + "SMILES": { + "name": "moleculeSTM", + "transformer_type": "molbart", + "model_name_or_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt", + "max_length": 512 + }, + "graph": { + "name": "BaseGNN" + } + } + } + } + }, + "network": { + "graph": { + "name": "molstm", + "structure": { + "name": "gnn", + "gin_hidden_dim": 300, + "gin_num_layers": 5, + "drop_ratio": 0.5, + "output_dim": 300, + "ckpt" : "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_Graph/molecule_model.pth", + "MegaMolBART_generation_model_dir" : "./ckpts/fusion_ckpts/pretrained_MegaMolBART/checkpoints", + "vocab_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt" + }, + "text": { + "output_dim": 768, + "ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_Graph/text_model.pth", + "bert_path": "./ckpts/text_ckpts/scibert_scivocab_uncased" + }, + "structure_proj_ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_Graph/mol2latent_model.pth", + "text_proj_ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_Graph/text2latent_model.pth", + "projection_dim": 256 + } + + } +} \ No newline at end of file diff --git 
a/configs/moledit/molstm-SMILES-MegaMolBART.json b/configs/moledit/molstm-SMILES-MegaMolBART.json new file mode 100644 index 0000000..aa3b841 --- /dev/null +++ b/configs/moledit/molstm-SMILES-MegaMolBART.json @@ -0,0 +1,40 @@ +{ + "model": "molstm-MegaMolBART", + "data": { + "mol": { + "modality": ["structure"], + "featurizer": { + "structure": { + "name": "MultiScale", + "scales": ["SMILES"], + "SMILES": { + "name": "moleculeSTM", + "transformer_type": "molbart", + "model_name_or_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt", + "max_length": 512 + } + } + } + } + }, + "network": { + "smiles": { + "name": "molstm", + "structure": { + "name": "magamolbart", + "output_dim": 256, + "MegaMolBART_generation_model_dir" : "./ckpts/fusion_ckpts/pretrained_MegaMolBART/checkpoints", + "vocab_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt" + }, + "text": { + "output_dim": 768, + "ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_SMILES/text_model.pth", + "bert_path": "./ckpts/text_ckpts/scibert_scivocab_uncased" + }, + "structure_proj_ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_SMILES/mol2latent_model.pth", + "text_proj_ckpt": "./ckpts/fusion_ckpts/moleculestm/demo_checkpoints_SMILES/text2latent_model.pth", + "projection_dim": 256 + } + + } +} \ No newline at end of file diff --git a/configs/moledit/momu-Graph-MegaMolBART.json b/configs/moledit/momu-Graph-MegaMolBART.json new file mode 100644 index 0000000..6989b95 --- /dev/null +++ b/configs/moledit/momu-Graph-MegaMolBART.json @@ -0,0 +1,43 @@ +{ + "model": "momu-MegaMolBART", + "data": { + "mol": { + "modality": ["structure"], + "featurizer": { + "structure": { + "name": "MultiScale", + "scales": ["SMILES", "graph"], + "SMILES": { + "name": "moleculeSTM", + "transformer_type": "molbart", + "model_name_or_path": "./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt", + "max_length": 512 + }, + "graph": { + "name": "ogb" + } + } + + } + } + }, + "network": 
{ + "graph": { + "name": "momu", + "gin_hidden_dim": 300, + "gin_num_layers": 5, + "drop_ratio": 0.0, + "graph_pooling": "sum", + "graph_self": false, + "max_n_nodes": -1, + "bert_dropout": 0.0, + "bert_hidden_dim": 768, + "output_dim": 300, + "projection_dim": 256, + "init_checkpoint": "./ckpts/fusion_ckpts/momu/MoMu-S.ckpt", + "param_key": "state_dict", + "stop_grad": false + } + + } +} \ No newline at end of file diff --git a/open_biomed/datasets/moledit_dataset.py b/open_biomed/datasets/moledit_dataset.py new file mode 100644 index 0000000..f7fa9db --- /dev/null +++ b/open_biomed/datasets/moledit_dataset.py @@ -0,0 +1,84 @@ +import logging +logger = logging.getLogger(__name__) + +from abc import ABC, abstractmethod + +import os +import csv + +import torch +from torch.utils.data import Dataset +import pandas as pd + +from feature.mol_featurizer import MolMultiModalFeaturizer +from feature.text_featurizer import TextTransformerTokFeaturizer +from utils.mol_utils import valid_smiles +from models.multimodal.mega_molbart.tokenizer import MolEncTokenizer + +class MoleditDataset(Dataset, ABC): + def __init__(self, path, config): + super(MoleditDataset, self).__init__() + self.path = path + self.config = config + self._load_data() + self._featurize() + + @abstractmethod + def _load_data(self): + raise NotImplementedError + + def _featurize(self): + featurizer = MolMultiModalFeaturizer(self.config) + self.mols = [featurizer(smi) for smi in self.smiles] + + smiles_emb = [d['structure'] for d in self.mols] + smiles_emb = [d['SMILES'] for d in smiles_emb] + smiles_emb = [item for sublist in smiles_emb for item in sublist] + tokens, orig_pad_masks = self._pad_seqs(smiles_emb) + smiles = [{'input_ids': tokens, 'pad_masks': pad_masks} for tokens, pad_masks in zip(tokens, orig_pad_masks)] + for i, dictionary in enumerate(smiles): + self.mols[i]["structure"]["SMILES"] = dictionary + + + @staticmethod + def _pad_seqs(seqs, pad_token = 0): + pad_length = max([len(seq) for seq in 
seqs]) + padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs] + masks = [([0] * len(seq)) + ([1] * (pad_length - len(seq))) for seq in seqs] + return padded, masks + + def __getitem__(self, index): + return self.mols[index] + + def __len__(self): + return len(self.mols) + + +class ZINC250K(MoleditDataset): + def __init__(self, path, config, split): + self.split = split + super(ZINC250K, self).__init__(path, config) + + def _load_data(self, subset_size=None): + + SMILES_file = os.path.join(self.path, "raw/250k_rndm_zinc_drugs_clean_3.csv") + df = pd.read_csv(SMILES_file) + smiles = df['smiles'].tolist() # Already canonical SMILES + self.smiles = [x.strip() for x in smiles] + + new_SMILES_file = os.path.join(self.path, "raw/smiles.csv") + if not os.path.exists(new_SMILES_file): + data_smiles_series = pd.Series(self.smiles) + print("saving to {}".format(new_SMILES_file)) + data_smiles_series.to_csv(new_SMILES_file, index=False, header=False) + + if subset_size is not None: + self.smiles = self.smiles[:subset_size] + + + + +SUPPORTED_MOLEDIT_DATASET = { + "ZINC250K": ZINC250K +} + diff --git a/open_biomed/feature/mol_featurizer.py b/open_biomed/feature/mol_featurizer.py index 9e92415..4189ef1 100644 --- a/open_biomed/feature/mol_featurizer.py +++ b/open_biomed/feature/mol_featurizer.py @@ -23,6 +23,7 @@ from sklearn.preprocessing import OneHotEncoder from torch_geometric.data import Data from transformers import BertTokenizer, T5Tokenizer +from models.multimodal.mega_molbart.tokenizer import MolEncTokenizer from feature.base_featurizer import BaseFeaturizer from feature.kg_featurizer import SUPPORTED_KG_FEATURIZER @@ -169,6 +170,21 @@ def __call__(self, data): result = self.tokenizer(data, max_length=self.max_length, padding=True, truncation=True) return result +class MolSTMFeaturizer(BaseFeaturizer): + name2tokenizer = { + "molbart": MolEncTokenizer, + } + + def __init__(self, config): + super(MolSTMFeaturizer, self).__init__() + self.tokenizer = 
self.name2tokenizer[config["transformer_type"]].from_pretrained(config["model_name_or_path"]) + + def __call__(self, data): + result = self.tokenizer.tokenize(data, pad=False) + result = self.tokenizer.convert_tokens_to_ids(result['original_tokens']) + + return result + class MolBPEFeaturizer(BaseFeaturizer): def __init__(self, config): super(MolBPEFeaturizer, self).__init__() @@ -783,6 +799,7 @@ def __getitem__(self, index): "OneHot": MolOneHotFeaturizer, "KV-PLM*": MolBPEFeaturizer, "transformer": MolTransformerTokFeaturizer, + "moleculeSTM": MolSTMFeaturizer, "fp": MolFPFeaturizer, "TGSA": MolTGSAFeaturizer, "ogb": MolGraphFeaturizer, diff --git a/open_biomed/models/__init__.py b/open_biomed/models/__init__.py index 587bf80..59c3363 100644 --- a/open_biomed/models/__init__.py +++ b/open_biomed/models/__init__.py @@ -4,6 +4,8 @@ from models.knowledge import * from models.text import * from models.multimodal import * +from models.multimodal.molkformer import * + SUPPORTED_MOL_ENCODER = { "cnn": MolCNN, @@ -18,7 +20,9 @@ "biomedgpt-10b": BioMedGPTV, "kv-plm": KVPLM, "momu": MoMu, - "molfm": MolFM + "molfm": MolFM, + "molkformer": MolKFormer, + "molstm": MoleculeSTM } SUPPORTED_MOL_DECODER = { diff --git a/open_biomed/models/multimodal/__init__.py b/open_biomed/models/multimodal/__init__.py index c628cc0..2626315 100644 --- a/open_biomed/models/multimodal/__init__.py +++ b/open_biomed/models/multimodal/__init__.py @@ -5,4 +5,7 @@ from models.multimodal.molfm.molfm import MolFM from models.multimodal.molfm.drugfm import DrugFM from models.multimodal.molt5 import MolT5 -from models.multimodal.text2mol import Text2MolMLP \ No newline at end of file +from models.multimodal.text2mol import Text2MolMLP +from models.multimodal.molkformer.mol_kformer import MolKFormer +from models.multimodal.mega_molbart.mega_mol_bart import MegaMolBART +from models.multimodal.moleculestm import MoleculeSTM \ No newline at end of file diff --git 
a/open_biomed/models/multimodal/mega_molbart/__init__.py b/open_biomed/models/multimodal/mega_molbart/__init__.py new file mode 100644 index 0000000..d7f244e --- /dev/null +++ b/open_biomed/models/multimodal/mega_molbart/__init__.py @@ -0,0 +1 @@ +from models.multimodal.mega_molbart.megatron_bart import MegatronBART \ No newline at end of file diff --git a/open_biomed/models/multimodal/mega_molbart/decoder.py b/open_biomed/models/multimodal/mega_molbart/decoder.py new file mode 100644 index 0000000..b7aaad1 --- /dev/null +++ b/open_biomed/models/multimodal/mega_molbart/decoder.py @@ -0,0 +1,426 @@ +# coding=utf-8 + +import torch +from rdkit import Chem, RDLogger +from .util import DEFAULT_MAX_SEQ_LEN + +class DecodeSampler: + def __init__( + self, + tokenizer, + max_seq_len=DEFAULT_MAX_SEQ_LEN + ): + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + + assert max_seq_len > 1, f"Max sequence must be at least 2, got {max_seq_len}" + + self.begin_token_id = self.tokenizer.vocab[self.tokenizer.begin_token] + self.pad_token_id = self.tokenizer.vocab[self.tokenizer.pad_token] + self.end_token_id = self.tokenizer.vocab[self.tokenizer.end_token] + + self.bad_token_ll = -1e5 + + RDLogger.DisableLog("rdApp.*") + + + def decode(self, decode_fn, batch_size, sampling_alg="greedy", device="cpu", **kwargs): + """ Sample a molecule from a model by calling the decode function argument + + Args: + decode_fn: A function mapping a batched sequence of token identifiers and their associated pad masks + to a log probability distribution over possible next tokens + batch_size: The number of elements to pass into the decode function in one batch + sampling_alg: Algorithm to use for sampling from the model + + Returns: + (SMILES of sampled molecules (List[str]), log likelihoods (List[float])) + """ + + if sampling_alg == "greedy": + output = self.greedy_decode(decode_fn, batch_size, device) + + elif sampling_alg == "beam": + output = self.beam_decode(decode_fn, batch_size, 
device, kwargs) + + else: + raise ValueError(f"Unknown sampling algorithm {sampling_alg}") + + return output + + + def greedy_decode(self, decode_fn, batch_size, device="cpu"): + """ Sample molecules from the model using greedy search + + Args: + decode_fn (fn): Function used to apply tokens to model and produce log probability distribution + batch_size (int): Number of molecules to sample + device: Torch device to create tensors on + + Returns: + (List[str], List[float]): Tuple of (molecules, their log likelihoods) + """ + + # Create tensors which will be reused + token_ids = [self.begin_token_id] + ([self.pad_token_id] * (self.max_seq_len - 1)) + token_ids = [token_ids] * batch_size + token_ids = torch.tensor(token_ids, device=device).transpose(0, 1) + pad_mask = torch.zeros((self.max_seq_len, batch_size), device=device, dtype=torch.bool) + log_lhs = torch.zeros((batch_size)) + + # Iteratively apply the tokens to the model and build up the sequence + for i in range(1, self.max_seq_len): + token_ids_seq = token_ids[:i, :] + pad_mask_seq = pad_mask[:i, :] + + # Sample next id for each element in the batch + output_dist = decode_fn(token_ids_seq, pad_mask_seq) + probs, output_ids = output_dist.max(dim=2) + new_ids = output_ids[-1, :] + new_probs = probs[-1, :] + + # Generate next elements in the pad mask. An element is padded if: + # 1. The previous token is an end token + # 2. 
The previous token is a pad token + is_end_token = token_ids[i-1, :] == self.end_token_id + is_pad_token = token_ids[i-1, :] == self.pad_token_id + new_pad_mask = torch.logical_or(is_end_token, is_pad_token) + + # Break if sampling is complete + if new_pad_mask.sum().item() == new_pad_mask.numel(): + break + + # Ensure all sequences contain an end token + if i == self.max_seq_len - 1: + new_ids[~new_pad_mask] = self.end_token_id + + # Set the token to pad where required, update the token ids and update lls + new_ids[new_pad_mask] = self.pad_token_id + token_ids[i, :] = new_ids + pad_mask[i, :] = new_pad_mask + log_lhs += new_probs.cpu() + + tokens = token_ids.transpose(0, 1).tolist() + tokens = self.tokenizer.convert_ids_to_tokens(tokens) + mol_strs = self.tokenizer.detokenize(tokens) + log_lhs = log_lhs.tolist() + + return mol_strs, log_lhs + + + def beam_decode(self, decode_fn, batch_size, device="cpu", k=5): + """ Sample molecules from the model using beam search + + Samples molecules by iteratively building up the sequence of SMILES characters using beam search. + Molecules are returned in a 2D list where batch_size is the outer dimension and k is the inner dimension. 
+ + Args: + decode_fn (fn): Function used to apply tokens to model and produce log probability distribution + batch_size (int): Number of molecules to sample + device: Torch device to create tensors on + k (int): Number of beams + + Returns: + (List[List[str]], List[List[float]]): Tuple of (molecules, their log likelihoods) + """ + + # Create tensors which will be reused + token_ids = [self.begin_token_id] + ([self.pad_token_id] * (self.max_seq_len - 1)) + token_ids = [token_ids] * batch_size + token_ids = torch.tensor(token_ids, device=device).transpose(0, 1) + pad_mask = torch.zeros((self.max_seq_len, batch_size), device=device, dtype=torch.bool) + + ts = token_ids[:1, :] + ms = pad_mask[:1, :] + ll = torch.zeros((batch_size)) + + # Apply starting token to model to get a distribution over next tokens + first_lls = self._beam_step(decode_fn, ts, ms, ll) + top_lls, top_idxs = torch.topk(first_lls, k, dim=1) + top_ids = list(top_idxs.T) + + # Setup tensors for each beam which will be reused + token_ids_list = [token_ids.clone() for _ in range(k)] + pad_mask_list = [pad_mask.clone() for _ in range(k)] + lls_list = list(top_lls.cpu().T) + + for beam_idx, ids in enumerate(top_ids): + token_ids_list[beam_idx][1, :] = ids + pad_mask_list[beam_idx][1, :] = 0 + + for i in range(2, self.max_seq_len): + complete = self._update_beams_(i, decode_fn, token_ids_list, pad_mask_list, lls_list) + if complete: + break + + tokens_list = [token_ids.transpose(0, 1).tolist() for token_ids in token_ids_list] + tokens_list = [self.tokenizer.convert_ids_to_tokens(tokens) for tokens in tokens_list] + mol_strs_list = [self.tokenizer.detokenize(tokens) for tokens in tokens_list] + log_lhs_list = [log_lhs.tolist() for log_lhs in lls_list] + + # Transpose and sort list of molecules based on ll + new_mol_strs = self._transpose_list(mol_strs_list) + new_log_lhs = self._transpose_list(log_lhs_list) + sorted_mols, sorted_lls = self._sort_beams(new_mol_strs, new_log_lhs) + + return sorted_mols, 
sorted_lls + + + def _update_beams_(self, i, decode_fn, token_ids_list, pad_mask_list, lls_list): + """ Update beam tokens and pad mask in-place using a single decode step + + Updates token ids and pad mask in-place by producing the probability distribution over next tokens + and choosing the top k (number of beams) log likelihoods to choose the next tokens. + Sampling is complete if every batch element in every beam has produced an end token. + + Args: + i (int): The current iteration counter + decode_fn (fn): Function used to apply tokens to model and produce log probability distribution + token_ids_list (List[torch.Tensor]): List of token_ids, each of shape [seq_len, batch_size] + pad_mask_list (List[torch.Tensor]): List of pad_masks, each of shape [seq_len, batch_size] + lls_list (List[torch.Tensor]): List of log likelihoods, each of shape [batch_size] + + Returns: + (bool): Specifies whether all of the beams are complete + """ + + assert len(token_ids_list) == len(pad_mask_list) == len(lls_list) + + num_beams = len(token_ids_list) + + ts = [token_ids[:i, :] for token_ids in token_ids_list] + ms = [pad_mask[:i, :] for pad_mask in pad_mask_list] + + # Apply current seqs to model to get a distribution over next tokens + # new_lls is a tensor of shape [batch_size, vocab_size * num_beams] + new_lls = [self._beam_step(decode_fn, t, m, lls) for t, m, lls in zip(ts, ms, lls_list)] + _, vocab_size = new_lls[0].shape + new_lls = torch.cat(new_lls, dim=1) + + # Keep lists (of length num_beams) of tensors of shape [batch_size] + top_lls, top_idxs = torch.topk(new_lls, num_beams, dim=1) + new_ids_list = list((top_idxs % vocab_size).T) + beam_idxs_list = list((top_idxs // vocab_size).T) + top_lls = list(top_lls.T) + + beam_complete = [] + new_ts_list = [] + new_pm_list = [] + new_lls_list = [] + + # Set the sampled tokens, pad masks and log likelihoods for each of the new beams + for new_beam_idx, (new_ids, beam_idxs, lls) in enumerate(zip(new_ids_list, beam_idxs_list, 
top_lls)): + # Get the previous sequences corresponding to the new beams + token_ids = [token_ids_list[beam_idx][:, b_idx] for b_idx, beam_idx in enumerate(beam_idxs)] + token_ids = torch.stack(token_ids).transpose(0, 1) + + # Generate next elements in the pad mask. An element is padded if: + # 1. The previous token is an end token + # 2. The previous token is a pad token + is_end_token = token_ids[i-1, :] == self.end_token_id + is_pad_token = token_ids[i-1, :] == self.pad_token_id + new_pad_mask = torch.logical_or(is_end_token, is_pad_token) + beam_complete.append(new_pad_mask.sum().item() == new_pad_mask.numel()) + + # Ensure all sequences contain an end token + if i == self.max_seq_len - 1: + new_ids[~new_pad_mask] = self.end_token_id + + # Set the tokens to pad if an end token as already been produced + new_ids[new_pad_mask] = self.pad_token_id + token_ids[i, :] = new_ids + + # Generate full pad mask sequence for new token sequence + pad_mask = [pad_mask_list[beam_idx][:, b_idx] for b_idx, beam_idx in enumerate(beam_idxs)] + pad_mask = torch.stack(pad_mask).transpose(0, 1) + pad_mask[i, :] = new_pad_mask + + # Add tokens, pad mask and lls to list to be updated after all beams have been processed + new_ts_list.append(token_ids) + new_pm_list.append(pad_mask) + new_lls_list.append(lls) + + complete = sum(beam_complete) == len(beam_complete) + + # Update all tokens, pad masks and lls + if not complete: + for beam_idx, (ts, pm, lls) in enumerate(zip(new_ts_list, new_pm_list, new_lls_list)): + token_ids_list[beam_idx] = ts + pad_mask_list[beam_idx] = pm + lls_list[beam_idx] = lls + + return complete + + def _beam_step(self, decode_fn, tokens, mask, lls): + """ Apply tokens to model to produce the log likelihoods for the full sequence + + A single iteration of decode is applied to the model to produce the next tokens in the sequences + and the log likelihoods for the entire sequences (including the next token) + The lls are returned as a distribution over all 
possible next tokens + + Args: + decode_fn (fn): Function used to apply tokens to model and produce log probability distribution + tokens (torch.Tensor): Tensor of shape [seq_len, batch_size] containing the current token ids + mask (torch.Tensor): BoolTensor of shape [seq_len, batch_size] containing the padding mask + lls (torch.Tensor): Tensor of shape [batch_size] containing log likelihoods for seqs so far + + Returns: + seq_lls (torch.Tensor): Tensor of shape [batch_size, vocab_size] + """ + + output_dist = decode_fn(tokens, mask) + next_token_lls = output_dist[-1, :, :].cpu() + + # Create a vector from which only a pad token can be sampled + # And use this vector in the output for sequences which are complete + _, vocab_size = tuple(next_token_lls.shape) + complete_seq_ll = torch.ones((1, vocab_size)) * self.bad_token_ll + complete_seq_ll[:, self.pad_token_id] = 0.0 + + is_end_token = tokens[-1, :] == self.end_token_id + is_pad_token = tokens[-1, :] == self.pad_token_id + ll_mask = torch.logical_or(is_end_token, is_pad_token).cpu().unsqueeze(1) + masked_lls = (ll_mask * complete_seq_ll) + (~ll_mask * next_token_lls) + + seq_lls = (lls + masked_lls.T).T + return seq_lls + + @staticmethod + def _transpose_list(l): + """ Transpose 2D list so that inner dimension is first + + Args: + l (List[Any]): List to be transposed + + Returns: + (List[Any]): Transposed list + """ + + outer_dim = len(l) + inner_dim = len(l[0]) + + transposed = [[[]] * outer_dim for _ in range(inner_dim)] + for outer_idx, inner in enumerate(l): + for inner_idx, item in enumerate(inner): + transposed[inner_idx][outer_idx] = item + + return transposed + + @staticmethod + def _sort_beams(mol_strs, log_lhs): + """ Return mols sorted by their log likelihood + + Args: + mol_strs (List[List[str]]): SMILES encoding of molecules + log_lhs (List[List[float]]): Log likelihood for each molecule + + Returns: + (List[str], List[float]): Tuple of sorted molecules and sorted log lhs + """ + + assert 
len(mol_strs) == len(log_lhs) + + sorted_mols = [] + sorted_lls = [] + + for mols, lls in zip(mol_strs, log_lhs): + mol_lls = sorted(zip(mols, lls), reverse=True, key=lambda mol_ll: mol_ll[1]) + mols, lls = tuple(zip(*mol_lls)) + sorted_mols.append(list(mols)) + sorted_lls.append(list(lls)) + + return sorted_mols, sorted_lls + + @staticmethod + def calc_sampling_metrics(sampled_smiles, target_smiles): + """ Calculate sampling metrics for the model + + If sampled_smiles is a List[List[str]] then the following metrics for beam search are calculated (up to the + maximum given by the number of elements in the inner lists): + - "top_1_accuracy" + - "top_5_accuracy" + - "top_10_accuracy" + - "top_20_accuracy" + - "top_50_accuracy" + The SMILES strings must be sorted in decreasing order of their predicted likelihood + + If the sampled_smiles is a List[str] then "accuracy" is calculated + + The the number of invalid SMILES "invalid" is also returned (for beam search this is just from the top_1) + + Args: + sampled_smiles: SMILES strings produced by decode function, + target_smiles: target molecules as canonicalised SMILES strings + + Returns: + dict containing results + """ + + num_sampled = len(sampled_smiles) + num_target = len(target_smiles) + err_msg = f"The number of sampled and target molecules must be the same, got {num_sampled} and {num_target}" + assert num_sampled == num_target, err_msg + + data_type = type(sampled_smiles[0]) + if data_type == str: + results = DecodeSampler._calc_greedy_metrics(sampled_smiles, target_smiles) + elif data_type == list: + results = DecodeSampler._calc_beam_metrics(sampled_smiles, target_smiles) + else: + raise TypeError(f"Elements of sampled_smiles must be either a str or a list, got {data_type}") + + return results + + @staticmethod + def _calc_greedy_metrics(sampled_smiles, target_smiles): + sampled_mols = [Chem.MolFromSmiles(smi) for smi in sampled_smiles] + invalid = [mol is None for mol in sampled_mols] + + canon_smiles = 
["Unknown" if mol is None else Chem.MolToSmiles(mol) for mol in sampled_mols] + target_mols = [Chem.MolFromSmiles(smi) for smi in target_smiles] + canon_target_smiles = [Chem.MolToSmiles(mol) for mol in target_mols] + correct_smiles = [canon_target_smiles[idx] == smi for idx, smi in enumerate(canon_smiles)] + + num_correct = sum(correct_smiles) + total = len(correct_smiles) + num_invalid = sum(invalid) + perc_invalid = num_invalid / total + accuracy = num_correct / total + + # Todo: need to move accuracy and perc_invalid to cuda for reducing later + metrics = { + "accuracy": accuracy, + "invalid": perc_invalid + } + + return metrics + + @staticmethod + def _calc_beam_metrics(sampled_smiles, target_smiles): + top_1_samples = [mols[0] for mols in sampled_smiles] + top_1_results = DecodeSampler._calc_greedy_metrics(top_1_samples, target_smiles) + + metrics = { + "top_1_accuracy": top_1_results["accuracy"], + "invalid": top_1_results["invalid"] + } + + ks = [2, 3, 5, 10, 20, 50] + num_samples_list = [k for k in ks if k <= len(sampled_smiles[0])] + + for num_samples in num_samples_list: + top_k_correct = [] + num_mols = len(sampled_smiles) + + for batch_idx, mols in enumerate(sampled_smiles): + samples = mols[:num_samples] + samples_mols = [Chem.MolFromSmiles(smi) for smi in samples] + samples_smiles = ["Unknown" if mol is None else Chem.MolToSmiles(mol) for mol in samples_mols] + correct_smiles = [smi == target_smiles[batch_idx] for smi in samples_smiles] + is_correct = sum(correct_smiles) >= 1 + top_k_correct.append(is_correct) + + accuracy = sum(top_k_correct) / num_mols + metrics[f"top_{str(num_samples)}_accuracy"] = accuracy + + return metrics \ No newline at end of file diff --git a/open_biomed/models/multimodal/mega_molbart/mega_mol_bart.py b/open_biomed/models/multimodal/mega_molbart/mega_mol_bart.py new file mode 100644 index 0000000..006a4ca --- /dev/null +++ b/open_biomed/models/multimodal/mega_molbart/mega_mol_bart.py @@ -0,0 +1,468 @@ +''' +Credit to 
https://github.com/NVIDIA/cheminformatics/blob/master/megamolbart/megamolbart/inference.py +''' +import logging +from functools import partial +from pathlib import Path +from typing import List +from rdkit import Chem +import random +import numpy as np + +import torch +from torch.nn.parallel import DistributedDataParallel as torchDDP +import pandas as pd +from megatron.checkpointing import load_checkpoint +import megatron.checkpointing as megatron_checkpointing +from megatron.global_vars import set_global_variables +from models.multimodal.mega_molbart.workflow import BaseGenerativeWorkflow, add_jitter +from .decoder import DecodeSampler +from megatron import get_args, mpu +from megatron.initialize import initialize_megatron +from .megatron_bart import MegatronBART +from .tokenizer import MolEncTokenizer +from .util import (REGEX, DEFAULT_CHEM_TOKEN_START, DEFAULT_MAX_SEQ_LEN, + DEFAULT_VOCAB_PATH, + DEFAULT_NUM_LAYERS, DEFAULT_D_MODEL, DEFAULT_NUM_HEADS) + + +logger = logging.getLogger(__name__) + + +@add_jitter.register(torch.Tensor) +def _(embedding, radius, cnt, shape): + if shape is not None: + embedding = torch.reshape(embedding, (1, shape[0], shape[1])).to(embedding.device) + permuted_emb = embedding.permute(1, 0, 2) + + distorteds = [] + for i in range(cnt): + noise = torch.normal(0, radius, permuted_emb.shape).to(embedding.device) + distorted = (noise + permuted_emb).permute(1, 0, 2) + distorteds.append(distorted) + + return distorteds + + +def use_model_module(model): + ''' Credit to https://github.com/MolecularAI/MolBART/blob/megatron-molbart-with-zinc/megatron_molbart/checkpointing.py#L20 ''' + use_model = isinstance(model, torchDDP) + try: + from deepspeed.runtime.engine import DeepSpeedEngine + except: + pass + else: + use_model = use_model | isinstance(model, DeepSpeedEngine) + return use_model + + +class MegaMolBART(BaseGenerativeWorkflow): + + def __init__(self, + input_dir=None, + output_dir=None, + max_seq_len=DEFAULT_MAX_SEQ_LEN, + 
vocab_path=DEFAULT_VOCAB_PATH, + regex=REGEX, + default_chem_token_start=DEFAULT_CHEM_TOKEN_START, + num_layers=DEFAULT_NUM_LAYERS, + hidden_size=DEFAULT_D_MODEL, + num_attention_heads=DEFAULT_NUM_HEADS, + decoder_max_seq_len=None, + grad_enabled=True) -> None: + super().__init__() + + torch.set_grad_enabled(grad_enabled) # Testing this instead of `with torch.no_grad():` context since it doesn't exit + + self.device = 'cuda' # Megatron arg loading seems to only work with GPU + self.min_jitter_radius = 1.0 + self.max_model_position_embeddings = max_seq_len + + args = { + 'num_layers': num_layers, + 'hidden_size': hidden_size, + 'num_attention_heads': num_attention_heads, + 'max_position_embeddings': self.max_model_position_embeddings, + 'tokenizer_type': 'GPT2BPETokenizer', + 'vocab_file': vocab_path, + } + if input_dir is not None: + args["load"] = input_dir + if output_dir is not None: + args["save"] = output_dir + args["save_interval"] = 1 + + initialize_megatron(args_defaults=args, ignore_unknown_args=True) + args = get_args() + self.tokenizer = self.load_tokenizer(args.vocab_file, regex, default_chem_token_start) + self.model = self.load_model(args, self.tokenizer, decoder_max_seq_len) + + def _compute_radius(self, scaled_radius): # TODO REMOVE + if scaled_radius: + return float(scaled_radius * self.min_jitter_radius) + else: + return self.min_jitter_radius + + def load_tokenizer(self, tokenizer_vocab_path, regex, default_chem_token_start): + """Load tokenizer from vocab file + + Params: + tokenizer_vocab_path: str, path to tokenizer vocab + + Returns: + MolEncTokenizer tokenizer object + """ + print("Loading vocab from {}.".format(tokenizer_vocab_path)) + tokenizer_vocab_path = Path(tokenizer_vocab_path) + tokenizer = MolEncTokenizer.from_vocab_file( + tokenizer_vocab_path, + regex, + default_chem_token_start) + + return tokenizer + + def load_model(self, args, tokenizer, decoder_max_seq_len=None): + """Load saved model checkpoint + + Params: + tokenizer: 
MolEncTokenizer tokenizer object + decoder_max_seq_len: int, maximum sequence length + args: Megatron initialized arguments + + Returns: + MegaMolBART trained model + """ + + vocab_size = len(tokenizer) + pad_token_idx = tokenizer.vocab[tokenizer.pad_token] + + # TODO how to handle length overrun for batch processing + if not decoder_max_seq_len: + decoder_max_seq_len = args.max_position_embeddings + + sampler = DecodeSampler(tokenizer, decoder_max_seq_len) + model = MegatronBART( + sampler, + pad_token_idx, + vocab_size, + args.hidden_size, + args.num_layers, + args.num_attention_heads, + args.hidden_size * 4, + args.max_position_embeddings, + dropout=0.1, + ) + if args.load is not None: + print("Loading from {}".format(args.load)) + self.iteration = load_checkpoint(model, None, None) + model = model.cuda() + return model + + def save_model(self, iteration, model, optimizer=None, lr_scheduler=None): + ''' Credit to https://github.com/MolecularAI/MolBART/blob/megatron-molbart-with-zinc/megatron_molbart/checkpointing.py#L46 ''' + + """Save a model checkpoint.""" + args = get_args() + + # Only rank zero of the data parallel writes to the disk. + if use_model_module(model): + model = model.module + + if mpu.get_data_parallel_rank() == 0: + + # Arguments, iteration, and model. + state_dict = {} + state_dict['args'] = args + state_dict['checkpoint_version'] = 2.0 + state_dict['iteration'] = iteration + state_dict['model'] = model.state_dict_for_save_checkpoint() + + # Optimizer stuff. + if not args.no_save_optim: + if optimizer is not None: + state_dict['optimizer'] = optimizer.state_dict() + if lr_scheduler is not None: + state_dict['lr_scheduler'] = lr_scheduler.state_dict() + + # RNG states. 
+ if not args.no_save_rng: + state_dict['random_rng_state'] = random.getstate() + state_dict['np_rng_state'] = np.random.get_state() + state_dict['torch_rng_state'] = torch.get_rng_state() + state_dict['cuda_rng_state'] = torch.cuda.get_rng_state() + state_dict['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() + + # Save. + checkpoint_name = megatron_checkpointing.get_checkpoint_name(args.save, iteration) + print('global rank {} is saving checkpoint at iteration {:7d} to {}'. + format(torch.distributed.get_rank(), iteration, + checkpoint_name)) + megatron_checkpointing.ensure_directory_exists(checkpoint_name) + torch.save(state_dict, checkpoint_name) + print(' successfully saved {}'.format(checkpoint_name)) + + # Wait so everyone is done (necessary) + torch.distributed.barrier() + # And update the latest iteration + if torch.distributed.get_rank() == 0: + tracker_filename = megatron_checkpointing.get_checkpoint_tracker_filename(args.save) + with open(tracker_filename, 'w') as f: + f.write(str(iteration)) + # Wait so everyone is done (not necessary) + torch.distributed.barrier() + return + + def smiles2embedding(self, smiles, pad_length=None): + """Calculate embedding and padding mask for smiles with optional extra padding + + Params + smiles: string, input SMILES molecule + pad_length: optional extra + + Returns + embedding array and boolean mask + """ + + assert isinstance(smiles, str) + if pad_length: + assert pad_length >= len(smiles) + 2 + + tokens = self.tokenizer.tokenize([smiles], pad=True) + + # Append to tokens and mask if appropriate + if pad_length: + for i in range(len(tokens['original_tokens'])): + n_pad = pad_length - len(tokens['original_tokens'][i]) + tokens['original_tokens'][i] += [self.tokenizer.pad_token] * n_pad + tokens['masked_pad_masks'][i] += [1] * n_pad + + token_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens['original_tokens'])).cuda().T + pad_mask = torch.tensor(tokens['masked_pad_masks']).bool().cuda().T + 
token_ids = token_ids[:self.max_model_position_embeddings] + pad_mask = pad_mask[:self.max_model_position_embeddings] + encode_input = {"encoder_input": token_ids, "encoder_pad_mask": pad_mask} + + embedding = self.model.encode(encode_input) + torch.cuda.empty_cache() + return embedding, pad_mask + + def smileslist2embedding(self, smiles_list): + if isinstance(smiles_list, dict): + self_smiles_list={} + self_smiles_list['input_ids'] = [tensor.unsqueeze(0) for tensor in smiles_list['input_ids']] + token_ids = torch.cat(self_smiles_list['input_ids'], dim=0).cuda() + self_smiles_list['pad_masks'] = [tensor.unsqueeze(0) for tensor in smiles_list['pad_masks']] + pad_mask = torch.cat(self_smiles_list['pad_masks'], dim=0).bool().cuda() + else: + tokens = self.tokenizer.tokenize(smiles_list, pad=True) + token_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens['original_tokens'])).cuda().T + pad_mask = torch.tensor(tokens['masked_pad_masks']).bool().cuda().T + + token_ids = token_ids[:self.max_model_position_embeddings] + pad_mask = pad_mask[:self.max_model_position_embeddings] + encode_input = {"encoder_input": token_ids, "encoder_pad_mask": pad_mask} + + embedding = self.model.encode(encode_input) + torch.cuda.empty_cache() + return embedding, pad_mask + + def smileslist2embedding_model_given(self, model, smiles_list): + if isinstance(smiles_list, dict): + self_smiles_list={} + self_smiles_list['input_ids'] = [tensor.unsqueeze(0) for tensor in smiles_list['input_ids']] + token_ids = torch.cat(self_smiles_list['input_ids'], dim=0).cuda() + self_smiles_list['pad_masks'] = [tensor.unsqueeze(0) for tensor in smiles_list['pad_masks']] + pad_mask = torch.cat(self_smiles_list['pad_masks'], dim=0).bool().cuda() + else: + tokens = self.tokenizer.tokenize(smiles_list, pad=True) + token_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens['original_tokens'])).cuda().T + pad_mask = torch.tensor(tokens['masked_pad_masks']).bool().cuda().T + + token_ids = 
token_ids[:self.max_model_position_embeddings] + pad_mask = pad_mask[:self.max_model_position_embeddings] + encode_input = {"encoder_input": token_ids, "encoder_pad_mask": pad_mask} + + embedding = model.encode(encode_input) + torch.cuda.empty_cache() + return embedding, pad_mask + + def inverse_transform(self, embeddings, mem_pad_mask, k=1, sanitize=True): + mem_pad_mask = mem_pad_mask.clone() + smiles_interp_list = [] + + batch_size = 1 # TODO: parallelize this loop as a batch + with torch.no_grad(): + for memory in embeddings: + + if isinstance(memory, list): + memory = torch.FloatTensor(memory).cuda() + + decode_fn = partial(self.model._decode_fn, + mem_pad_mask=mem_pad_mask.type(torch.LongTensor).cuda(), + memory=memory) + + mol_strs, _ = self.model.sampler.beam_decode(decode_fn, + batch_size=batch_size, + device='cuda', + k=k) + mol_strs = sum(mol_strs, []) # flatten list + + # TODO: add back sanitization and validity checking once model is trained + logger.warn('WARNING: MOLECULE VALIDATION AND SANITIZATION CURRENTLY DISABLED') + for smiles in mol_strs: + if sanitize: + mol = Chem.MolFromSmiles(smiles, sanitize=sanitize) + if mol: + sanitized_smiles = Chem.MolToSmiles(mol) + smiles_interp_list.append(sanitized_smiles) + logger.debug(f'Sanitized SMILES {sanitized_smiles} added...') + break + smiles_interp_list.append(smiles) + + return smiles_interp_list + + def interpolate_molecules(self, smiles1, smiles2, num_interp, tokenizer, k=1): + """Interpolate between two molecules in embedding space. + + Params + smiles1: str, input SMILES molecule + smiles2: str, input SMILES molecule + num_interp: int, number of molecules to interpolate + tokenizer: MolEncTokenizer tokenizer object + k: number of molecules for beam search, default 1. 
Can increase if there are issues with validity + + Returns + list of interpolated smiles molecules + """ + + pad_length = max(len(smiles1), len(smiles2)) + 2 # add 2 for start / stop + embedding1, pad_mask1 = self.smiles2embedding(smiles1, + pad_length=pad_length) + + embedding2, pad_mask2 = self.smiles2embedding(smiles2, + pad_length=pad_length) + + scale = torch.linspace(0.0, 1.0, num_interp + 2)[ + 1:-1] # skip first and last because they're the selected molecules + scale = scale.unsqueeze(0).unsqueeze(-1).cuda() + + interpolated_emb = torch.lerp(embedding1, embedding2, scale).cuda() # dims: batch, tokens, embedding + combined_mask = (pad_mask1 & pad_mask2).bool().cuda() + + embeddings = [] + dims = [] + for emb in interpolated_emb.permute(1, 0, 2): + dims.append(emb.shape) + embeddings.append(emb) + + generated_mols = self.inverse_transform(embeddings, + combined_mask, + k=k, + sanitize=True) + generated_mols = [smiles1] + generated_mols + [smiles2] + embeddings = [embedding1] + embeddings + [embedding2] + dims = [embedding1.shape] + dims + [embedding2.shape] + return generated_mols, embeddings, combined_mask, dims + + def find_similars_smiles_list(self, + smiles: str, + num_requested: int = 10, + scaled_radius=None, + force_unique=False): + distance = self._compute_radius(scaled_radius) + logger.info(f'Computing with distance {distance}...') + + embedding, pad_mask = self.smiles2embedding(smiles) + + neighboring_embeddings = self.addjitter(embedding, distance, cnt=num_requested) + + generated_mols = self.inverse_transform(neighboring_embeddings, + pad_mask.bool().cuda(), + k=1, sanitize=True) + if force_unique: + generated_mols = list(set(generated_mols)) + + generated_mols = [smiles] + generated_mols + neighboring_embeddings = [embedding] + neighboring_embeddings + return generated_mols, neighboring_embeddings, pad_mask + + def find_similars_smiles(self, + smiles: str, + num_requested: int = 10, + scaled_radius=None, + force_unique=False): + generated_mols, 
neighboring_embeddings, pad_mask = \ + self.find_similars_smiles_list(smiles, + num_requested=num_requested, + scaled_radius=scaled_radius, + force_unique=force_unique) + + # Rest of the applications and libraries use RAPIDS and cuPY libraries. + # For interoperability, we need to convert the embeddings to cupy. + embeddings = [] + dims = [] + for neighboring_embedding in neighboring_embeddings: + dims.append(neighboring_embedding.shape) + embeddings.append(neighboring_embedding.flatten().tolist()) + + generated_df = pd.DataFrame({'SMILES': generated_mols, + 'embeddings': embeddings, + 'embeddings_dim': dims, + 'Generated': [True for i in range(len(generated_mols))]}) + generated_df.iat[0, 3] = False + + if force_unique: + inv_transform_funct = partial(self.inverse_transform, + mem_pad_mask=pad_mask) + generated_df = self.compute_unique_smiles(generated_df, + inv_transform_funct, + scaled_radius=scaled_radius) + return generated_df + + def interpolate_smiles(self, + smiles: List, + num_points: int = 10, + scaled_radius=None, + force_unique=False): + num_points = int(num_points) + if len(smiles) < 2: + raise Exception('At-least two or more smiles are expected') + + k = 1 + result_df = [] + for idx in range(len(smiles) - 1): + interpolated_mol, interpolated_embeddings, combined_mask, dims = \ + self.interpolate_molecules(smiles[idx], + smiles[idx + 1], + num_points, + self.tokenizer, + k=k) + + # Rest of the applications and libraries use RAPIDS and cuPY libraries. + # For interoperability, we need to convert the embeddings to cupy. 
+ embeddings = [] + for interpolated_embedding in interpolated_embeddings: + embeddings.append(interpolated_embedding.cpu()) + + interp_df = pd.DataFrame({'SMILES': interpolated_mol, + 'embeddings': embeddings, + 'embeddings_dim': dims, + 'Generated': [True for i in range(len(interpolated_mol))]}) + + inv_transform_funct = partial(self.inverse_transform, mem_pad_mask=combined_mask) + + # Mark the source and desinations as not generated + interp_df.iat[0, 3] = False + interp_df.iat[-1, 3] = False + + if force_unique: + interp_df = self.compute_unique_smiles(interp_df, + inv_transform_funct, + scaled_radius=scaled_radius) + + result_df.append(interp_df) + + result_df = pd.concat(result_df) + smile_list = list(result_df['SMILES']) + + return result_df, smile_list \ No newline at end of file diff --git a/open_biomed/models/multimodal/mega_molbart/megatron_bart.py b/open_biomed/models/multimodal/mega_molbart/megatron_bart.py new file mode 100644 index 0000000..e780307 --- /dev/null +++ b/open_biomed/models/multimodal/mega_molbart/megatron_bart.py @@ -0,0 +1,800 @@ +from megatron.module import MegatronModule +from apex.normalization import FusedLayerNorm +from megatron import mpu +from torch.nn import init +import torch.nn as nn +import torch.nn.functional as F +import torch +import math +from functools import partial +from .tokenizer import load_tokenizer +from .util import DEFAULT_CHEM_TOKEN_START, DEFAULT_VOCAB_PATH, REGEX + + +class MultiheadAttention(MegatronModule): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + cross_attention=False, + init_method=init.xavier_uniform_, + ): + + super(MultiheadAttention, self).__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.attn_dropout = nn.Dropout(p=dropout) + self.bias = bias + self.cross_attention = cross_attention + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim ** -0.5 + self.init_method = init_method + self.skip_bias = not bias 
+ + # Self-Attention is Column Parallelized + self.query_key_value = mpu.ColumnParallelLinear(self.embed_dim, + 3 * self.embed_dim, gather_output=True, + init_method=self.init_method, + skip_bias_add=self.skip_bias) + + # Cross-Attention is Row and Column Parallelized + self.q_proj = mpu.RowParallelLinear(self.embed_dim, + self.embed_dim, input_is_parallel=False, + init_method=self.init_method, bias=bias, + skip_bias_add=self.skip_bias) + self.key_value = mpu.ColumnParallelLinear(self.embed_dim, 2 + * self.embed_dim, gather_output=True, + init_method=self.init_method, + skip_bias_add=self.skip_bias) + + # Final projection is Row Parallelized + self.out_proj = mpu.RowParallelLinear(self.embed_dim, + self.embed_dim, input_is_parallel=False, + init_method=self.init_method, bias=bias) + + def forward( + self, + query, + key=None, + value=None, + key_padding_mask=None, + attn_mask=None, + ): + """Input shape: Time x Batch x Channel + + Args: + query - tokens/states of shape [Time x Batch x Channel] + key - tokens/states of shape [Time x Batch x Channel] + value - tokens/states of shape [Time x Batch x Channel] + key_padding_mask - keys that are pads where padding + elements are indicated by 1s. Shape: [batch, src_len]. + attn_mask - typically used to implement causal attention, where + the mask prevents the attention from looking forward in time. + Shape: [tgt_len, src_len]. + Returns: + outputs - attention probability scores of shape (Time x Batch x Channel) + """ + + (tgt_len, bsz, embed_dim) = query.size() + + # Compute attention projections + if not self.cross_attention: + (q_k_v, bias) = self.query_key_value(query) + (q, k, v) = mpu.split_tensor_along_last_dim(q_k_v, 3) + else: + q, _ = self.q_proj(query) + if key is None: + assert value is None, \ + 'Cross attention mode: since key is None, value must also be None.' 
+ k = v = None + else: + (k_v, bias) = self.key_value(key) + (k, v) = mpu.split_tensor_along_last_dim(k_v, 2) + + # Scale query and reshape + q = q.contiguous() + q *= self.scaling + q = q.view(tgt_len, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + + # Compute attention scores + src_len = k.size(1) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_weights.size()) == [bsz * self.num_heads, + tgt_len, src_len] + + # Apply causal attention mask + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + # Apply padding mask + if key_padding_mask is not None: + attn_weights = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = \ + attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float('-inf')) + attn_weights = attn_weights.view(bsz * self.num_heads, + tgt_len, src_len) + + # Compute attention probabilities + attn_weights = F.softmax(attn_weights, dim=-1) + attn_probs = self.attn_dropout(attn_weights) + + # Compute context and output projection + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, + self.head_dim] + if attn.size(1) == 1: # a single decoder step (sequence length == 1) + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, + embed_dim) + (attn, bias) = self.out_proj(attn) + attn_output_weights = attn_probs.view(bsz, self.num_heads, + tgt_len, src_len) + attn_output_weights = attn_output_weights.sum(dim=1) \ + / self.num_heads + return (attn, attn_output_weights) + + +class EncoderLayer(MegatronModule): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + 
init_method=init.xavier_uniform_, + ): + + super(EncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + embed_dim, + num_heads, + dropout=dropout, + bias=bias, + cross_attention=False, + init_method=init_method, + ) + self.self_attn_layer_norm = FusedLayerNorm(embed_dim) + self.attn_dropout = nn.Dropout(p=dropout) + self.activation_fn = F.gelu + self.activation_dropout = nn.Dropout(p=dropout) + self.fc1 = mpu.ColumnParallelLinear(embed_dim, 4 + * embed_dim, gather_output=False, + init_method=init_method, skip_bias_add=False) + self.fc2 = mpu.RowParallelLinear(4 * embed_dim, + embed_dim, input_is_parallel=True, + init_method=init_method, skip_bias_add=False) + self.final_layer_norm = FusedLayerNorm(embed_dim) + + def forward( + self, + x, + encoder_padding_mask=None, + attn_mask=None, + ): + """ + Args: + x: input to the layer of shape (seq_len, batch, embed_dim) + encoder_padding_mask: binary ByteTensor of shape + (batch, seq_len) where padding elements are indicated by 1. + attn_mask: binary tensor of shape (tgt_len, src_len), + where tgt_len is the length of output and src_len is the + length of input, though here both are equal to seq_len. 
+ Returns: + encoded output of shape (seq_len, batch, embed_dim) + """ + + if attn_mask is not None: + attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), + -1e8) + residual = x + x = self.self_attn_layer_norm(x) + (x, weights) = self.self_attn(query=x, key=x, value=x, + key_padding_mask=encoder_padding_mask, + attn_mask=attn_mask) + x = self.attn_dropout(x) + x = x + residual + residual = x + x = self.final_layer_norm(x) + x, _ = self.fc1(x) + x = self.activation_fn(x) + x = self.activation_dropout(x) + x, _ = self.fc2(x) + x = self.attn_dropout(x) + x = x + residual + return x + + +class DecoderLayer(MegatronModule): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + init_method=init.xavier_uniform_, + ): + + super(DecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + embed_dim, + num_heads, + dropout=dropout, + bias=bias, + cross_attention=False, + init_method=init_method, + ) + self.self_attn_layer_norm = FusedLayerNorm(embed_dim) + self.encoder_attn = MultiheadAttention( + embed_dim, + num_heads, + dropout=dropout, + bias=bias, + cross_attention=True, + init_method=init_method, + ) + self.encoder_attn_layer_norm = FusedLayerNorm(embed_dim) + self.dropout = nn.Dropout(p=dropout) + self.activation_fn = F.gelu + self.activation_dropout = nn.Dropout(p=dropout) + self.fc1 = mpu.ColumnParallelLinear(embed_dim, 4 + * embed_dim, gather_output=False, + init_method=init_method, skip_bias_add=False) + self.fc2 = mpu.RowParallelLinear(4 * embed_dim, + embed_dim, input_is_parallel=True, + init_method=init_method, skip_bias_add=False) + self.final_layer_norm = FusedLayerNorm(embed_dim) + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + self_attn_mask=None, + self_attn_padding_mask=None, + ): + """ + Args: + x: input to decoder layer of shape (seq_len, batch, embed_dim) + encoder_out: output from the encoder + encoder_padding_mask: binary ByteTensor of shape + (batch, seq_len) where 
padding elements are indicated by 1 + self_attn_mask: binary tensor of shape (tgt_len, src_len), + where tgt_lent is the length of output and src_len is the + length of input, though here both are equal to seq_len. + self_attn_padding_mask: binary ByteTensor of shape + (batch, seq_len) where padding elements are indicated by 1. + Returns: + encoded output of shape (seq_len, batch, embed_dim) + """ + + residual = x + x = self.self_attn_layer_norm(x) + + # Self-Attention block + + (x, weights) = self.self_attn(query=x, key=x, value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask) + x = self.dropout(x) + x = x + residual + + # Cross-Attention block + if encoder_out is not None: + residual = x + x = self.encoder_attn_layer_norm(x) + (x, attn) = self.encoder_attn(query=x, key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask) + x = self.dropout(x) + x = x + residual + residual = x + x = self.final_layer_norm(x) + + # Fully-connected block + x, _ = self.fc1(x) + x = self.activation_fn(x) + x = self.activation_dropout(x) + x, _ = self.fc2(x) + x = self.dropout(x) + x = x + residual + return x + + +class ParallelTransformerEncoder(MegatronModule): + + def __init__( + self, + num_layers, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + init_method=init.xavier_uniform_, + ): + + super(ParallelTransformerEncoder, self).__init__() + self.layers = nn.ModuleList([]) + self.num_layers = num_layers + self.embed_dim = embed_dim + self.num_heads = num_heads + self.attn_dropout = dropout + self.bias = bias + self.init_method = init_method + self.layers.extend([self.build_encoder_layer() for i in + range(self.num_layers)]) + self.norm = FusedLayerNorm(self.embed_dim) + + def build_encoder_layer(self): + layer = EncoderLayer(self.embed_dim, self.num_heads, + dropout=self.attn_dropout, bias=self.bias, + init_method=self.init_method) + return layer + + def forward( + self, + src, + mask=None, + src_key_padding_mask=None, + ): + 
"""Pass the input through the encoder layers in turn. + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + Returns: + encoded output of shape (src_len, batch, embed_dim) + """ + + output = src + for mod in self.layers: + output = mod(output, attn_mask=mask, + encoder_padding_mask=src_key_padding_mask) + output = self.norm(output) + return output + + +class ParallelTransformerDecoder(MegatronModule): + + def __init__( + self, + num_layers, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + init_method=init.xavier_uniform_, + ): + + super(ParallelTransformerDecoder, self).__init__() + self.layers = nn.ModuleList([]) + self.num_layers = num_layers + self.embed_dim = embed_dim + self.num_heads = num_heads + self.attn_dropout = dropout + self.bias = bias + self.init_method = init_method + self.layers.extend([self.build_decoder_layer() for i in + range(self.num_layers)]) + self.norm = FusedLayerNorm(self.embed_dim) + + def build_decoder_layer(self): + layer = DecoderLayer(self.embed_dim, self.num_heads, + dropout=self.attn_dropout, bias=self.bias, + init_method=self.init_method) + return layer + + def forward( + self, + tgt, + memory, + tgt_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None, + ): + """Pass the inputs (and mask) through the decoder layer in turn. + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). 
+ Returns: + decoded output of shape (tgt_len, batch, embed_dim) + """ + + output = tgt + for mod in self.layers: + output = mod(output, encoder_out=memory, + encoder_padding_mask=memory_key_padding_mask, + self_attn_mask=tgt_mask, + self_attn_padding_mask=tgt_key_padding_mask) + output = self.norm(output) + return output + + +class MegatronBART(MegatronModule): + + def __init__( + self, + decode_sampler, + pad_token_idx, + vocab_size, + d_model, + num_layers, + num_heads, + d_feedforward, + max_seq_len, + dropout=0.0, + ): + + super().__init__() + + self.sampler = decode_sampler + self.pad_token_idx = pad_token_idx + self.val_sampling_alg = 'greedy' + self.num_beams = 5 + self.vocab_size = vocab_size + self.d_model = d_model + self.num_layers = num_layers + self.num_heads = num_heads + self.d_feedforward = d_feedforward + self.max_seq_len = max_seq_len + self.dropout = dropout + self.emb_dropout = nn.Dropout(p=dropout) + init_method = init.xavier_uniform_ + + self.emb = nn.Embedding(vocab_size, d_model) + self.dropout = dropout + self.encoder = ParallelTransformerEncoder( + self.num_layers, + self.d_model, + self.num_heads, + self.dropout, + bias=True, + init_method=init_method, + ) + self.decoder = ParallelTransformerDecoder( + self.num_layers, + self.d_model, + self.num_heads, + self.dropout, + bias=True, + init_method=init_method, + ) + self.token_fc = mpu.RowParallelLinear(d_model, vocab_size, + input_is_parallel=False, init_method=init_method, + skip_bias_add=False) + self.loss_fn = nn.CrossEntropyLoss(reduction='none', + ignore_index=pad_token_idx) + self.log_softmax = nn.LogSoftmax(dim=2) + self._init_params(init_method) + self.register_buffer('pos_emb', self._positional_embs()) + + def forward(self, x): + """ Apply SMILES strings to model + + The dictionary returned will be passed to other functions, so its contents are fairly flexible, + except that it must contain the key "token_output" which is the output of the model + (possibly after any fully 
connected layers) for each token. + + Arg: + x (dict { + "encoder_input": tensor of token_ids of shape (src_len, batch_size), + "encoder_pad_mask": bool tensor of padded elems of shape (src_len, batch_size), + "decoder_input": tensor of decoder token_ids of shape (tgt_len, batch_size) + "decoder_pad_mask": bool tensor of decoder padding mask of shape (tgt_len, batch_size) + }): + + Returns: + Output from model (dict containing key "token_output" and "model_output") + """ + + encoder_input = x['encoder_input'] + decoder_input = x['decoder_input'] + encoder_pad_mask = x['encoder_pad_mask'].transpose(0, 1) + decoder_pad_mask = x['decoder_pad_mask'].transpose(0, 1) + + encoder_embs = self._construct_input(encoder_input) + decoder_embs = self._construct_input(decoder_input) + + (seq_len, _, _) = tuple(decoder_embs.size()) + tgt_mask = \ + self._generate_square_subsequent_mask(seq_len).to(decoder_embs.device) + + memory = self.encoder(encoder_embs, + src_key_padding_mask=encoder_pad_mask) + model_output = self.decoder(decoder_embs, memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=decoder_pad_mask, + memory_key_padding_mask=encoder_pad_mask.clone()) + + token_output, _ = self.token_fc(model_output) + output = {'model_output': model_output, + 'token_output': token_output} + + return output + + def encode(self, batch): + """ Construct the memory embedding for an encoder input + + Args: + batch (dict { + "encoder_input": tensor of token_ids of shape (src_len, batch_size), + "encoder_pad_mask": bool tensor of padded elems of shape (src_len, batch_size), + }) + + Returns: + encoder memory (Tensor of shape (seq_len, batch_size, d_model)) + """ + + encoder_input = batch['encoder_input'] + encoder_pad_mask = batch['encoder_pad_mask'].transpose(0, 1) + encoder_embs = self._construct_input(encoder_input) + model_output = self.encoder(encoder_embs, + src_key_padding_mask=encoder_pad_mask) + return model_output + + def decode(self, batch): + """ Construct an output from a given 
decoder input + + Args: + batch (dict { + "decoder_input": tensor of decoder token_ids of shape (tgt_len, batch_size) + "decoder_pad_mask": bool tensor of decoder padding mask of shape (tgt_len, batch_size) + "memory_input": tensor from encoded input of shape (src_len, batch_size, d_model) + "memory_pad_mask": bool tensor of memory padding mask of shape (src_len, batch_size) + }) + """ + + decoder_input = batch['decoder_input'] + decoder_pad_mask = batch['decoder_pad_mask'].transpose(0, 1) + memory_input = batch['memory_input'] + memory_pad_mask = batch['memory_pad_mask'].transpose(0, 1) + + decoder_embs = self._construct_input(decoder_input) + + (seq_len, _, _) = tuple(decoder_embs.size()) + tgt_mask = \ + self._generate_square_subsequent_mask(seq_len).to(decoder_embs.device) + + model_output = self.decoder(decoder_embs, memory_input, + tgt_key_padding_mask=decoder_pad_mask, + memory_key_padding_mask=memory_pad_mask, + tgt_mask=tgt_mask) + token_output, _ = self.token_fc(model_output) + token_probs = self.log_softmax(token_output) + return token_probs + + def validation_step(self, batch, batch_idx=None): + self.eval() + # TODO: This can be further optimized + tokenizer = load_tokenizer(vocab_path=DEFAULT_VOCAB_PATH, chem_token_start=DEFAULT_CHEM_TOKEN_START, regex=REGEX) + + with torch.no_grad(): + model_output = self.forward(batch) + #target_smiles = batch['target_smiles'] + token_ids = batch['target'] + tokens = token_ids.transpose(0, 1).tolist() + tokens = tokenizer.convert_ids_to_tokens(tokens) + target_smiles = tokenizer.detokenize(tokens) + + loss = self._calc_loss(batch, model_output) + token_acc = self._calc_char_acc(batch, model_output) + perplexity = self._calc_perplexity(batch, model_output) + (mol_strs, log_lhs) = self.sample_molecules(batch, + sampling_alg=self.val_sampling_alg) + metrics = self.sampler.calc_sampling_metrics(mol_strs, + target_smiles) + + self.train() + + val_outputs = { + 'val_loss': loss.item(), + 'val_token_acc': token_acc, + 
'val_perplexity': perplexity, + 'val_molecular_accuracy': metrics['accuracy'], + 'val_invalid_smiles': metrics['invalid'], + } + return val_outputs + + def _calc_loss(self, batch_input, model_output): + """ Calculate the loss for the model + + Args: + batch_input (dict): Input given to model, + model_output (dict): Output from model + + Returns: + loss (singleton tensor), + """ + + tokens = batch_input['target'] + pad_mask = batch_input['target_pad_mask'] + token_output = model_output['token_output'] + token_mask_loss = self._calc_mask_loss(token_output, tokens, + pad_mask) + return token_mask_loss + + def _calc_mask_loss( + self, + token_output, + target, + target_mask, + ): + """ Calculate the loss for the token prediction task + + Args: + token_output (Tensor of shape (seq_len, batch_size, vocab_size)): token output from transformer + target (Tensor of shape (seq_len, batch_size)): Original (unmasked) SMILES token ids from the tokenizer + target_mask (Tensor of shape (seq_len, batch_size)): Pad mask for target tokens + + Output: + loss (singleton Tensor): Loss computed using cross-entropy, + """ + + (seq_len, batch_size) = tuple(target.size()) + token_pred = token_output.reshape((seq_len * batch_size, + -1)).float() + loss = self.loss_fn(token_pred, + target.reshape(-1)).reshape((seq_len, + batch_size)) + inv_target_mask = ~(target_mask > 0) + num_tokens = inv_target_mask.sum() + loss = loss.sum() / num_tokens + return loss + + def _calc_perplexity(self, batch_input, model_output): + target_ids = batch_input['target'] + target_mask = batch_input['target_pad_mask'] + vocab_dist_output = model_output['token_output'] + inv_target_mask = ~(target_mask > 0) + log_probs = vocab_dist_output.gather(2, + target_ids.unsqueeze(2)).squeeze(2) + log_probs = log_probs * inv_target_mask + log_probs = log_probs.sum(dim=0) + seq_lengths = inv_target_mask.sum(dim=0) + exp = -(1 / seq_lengths) + perp = torch.pow(log_probs.exp(), exp) + return perp.mean().item() + + def 
_calc_char_acc(self, batch_input, model_output): + token_ids = batch_input['target'] + target_mask = batch_input['target_pad_mask'] + token_output = model_output['token_output'] + target_mask = ~(target_mask > 0) + (_, pred_ids) = torch.max(token_output.float(), dim=2) + correct_ids = torch.eq(token_ids, pred_ids) + correct_ids = correct_ids * target_mask + num_correct = correct_ids.sum() + total = target_mask.sum() + accuracy = num_correct / total + return accuracy + + def sample_molecules(self, batch_input, sampling_alg='greedy'): + """ Sample molecules from the model + + Args: + batch_input (dict): Input given to model + sampling_alg (str): Algorithm to use to sample SMILES strings from model + + Returns: + ([[str]], [[float]]): Tuple of molecule SMILES strings and log lhs (outer dimension is batch) + """ + + enc_input = batch_input['encoder_input'] + enc_mask = batch_input['encoder_pad_mask'] + + # Freezing the weights reduces the amount of memory leakage in the transformer + #model.eval() + + with torch.no_grad(): + + encode_input = {'encoder_input': enc_input, + 'encoder_pad_mask': enc_mask} + memory = self.encode(encode_input) + mem_mask = enc_mask.clone() + (_, batch_size, _) = tuple(memory.size()) + decode_fn = partial(self._decode_fn, memory=memory, + mem_pad_mask=mem_mask) + #self.sampler.device = self.device + if sampling_alg == 'greedy': + (mol_strs, log_lhs) = \ + self.sampler.greedy_decode(decode_fn, batch_size,device=memory.device) + elif sampling_alg == 'beam': + (mol_strs, log_lhs) = \ + self.sampler.beam_decode(decode_fn, batch_size, + self.num_beams,device=memory.device) + + # Must remember to unfreeze! 
+ #model.train() + + return (mol_strs, log_lhs) + + def _decode_fn( + self, + token_ids, + pad_mask, + memory, + mem_pad_mask, + ): + decode_input = { + 'decoder_input': token_ids, + 'decoder_pad_mask': pad_mask, + 'memory_input': memory, + 'memory_pad_mask': mem_pad_mask, + } + model_output = self.decode(decode_input) + return model_output + + def _construct_input(self, token_ids, sentence_masks=None): + (seq_len, _) = tuple(token_ids.size()) + token_embs = self.emb(token_ids) + + # Scaling the embeddings like this is done in other transformer libraries + token_embs = token_embs * math.sqrt(self.d_model) + positional_embs = self.pos_emb[:seq_len, : + ].unsqueeze(0).transpose(0, 1) + embs = token_embs + positional_embs + embs = self.emb_dropout(embs) + return embs + + def _positional_embs(self): + """ Produces a tensor of positional embeddings for the model + + Returns a tensor of shape (self.max_seq_len, self.d_model) filled with positional embeddings, + which are created from sine and cosine waves of varying wavelength + """ + + encs = torch.tensor([dim / self.d_model for dim in range(0, + self.d_model, 2)]) + encs = 10000 ** encs + encs = [(torch.sin(pos / encs), torch.cos(pos / encs)) + for pos in range(self.max_seq_len)] + encs = [torch.stack(enc, dim=1).flatten()[:self.d_model] + for enc in encs] + encs = torch.stack(encs) + return encs + + def _generate_square_subsequent_mask(self, sz): + """ + Method copied from Pytorch nn.Transformer. + Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). 
+ + Args: + sz (int): Size of mask to generate + + Returns: + torch.Tensor: Square autoregressive mask for decode + """ + + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf' + )).masked_fill(mask == 1, float(0.0)) + return mask + + def _init_params(self, method): + """ + Apply initialisation of learnable weights + """ + + for p in self.parameters(): + if p.dim() > 1: + method(p) diff --git a/open_biomed/models/multimodal/mega_molbart/tokenizer.py b/open_biomed/models/multimodal/mega_molbart/tokenizer.py new file mode 100644 index 0000000..f62c319 --- /dev/null +++ b/open_biomed/models/multimodal/mega_molbart/tokenizer.py @@ -0,0 +1,483 @@ +# coding=utf-8 + +import re +import torch +import random +from pathlib import Path +from .util import (DEFAULT_BEGIN_TOKEN, DEFAULT_END_TOKEN, DEFAULT_PAD_TOKEN, \ + DEFAULT_UNK_TOKEN, DEFAULT_MASK_TOKEN, DEFAULT_SEP_TOKEN, \ + DEFAULT_MASK_PROB, DEFAULT_SHOW_MASK_TOKEN_PROB, DEFAULT_MASK_SCHEME, \ + DEFAULT_SPAN_LAMBDA, DEFAULT_VOCAB_PATH, DEFAULT_CHEM_TOKEN_START, REGEX) + + +class MolEncTokenizer(): + def __init__( + self, + vocab, + chem_token_idxs, + prog, + begin_token=DEFAULT_BEGIN_TOKEN, + end_token=DEFAULT_END_TOKEN, + pad_token=DEFAULT_PAD_TOKEN, + unk_token=DEFAULT_UNK_TOKEN, + mask_token=DEFAULT_MASK_TOKEN, + sep_token=DEFAULT_SEP_TOKEN, + mask_prob=DEFAULT_MASK_PROB, + show_mask_token_prob=DEFAULT_SHOW_MASK_TOKEN_PROB, + mask_scheme=DEFAULT_MASK_SCHEME, + span_lambda=DEFAULT_SPAN_LAMBDA + ): + """ Initialise the tokenizer + + Args: + vocab (List[str]): Vocabulary for tokenizer + chem_token_idxs (List[int]): List of idxs of chemical tokens + prog (re.Pattern): Regex object for tokenizing + begin_token (str): Token to use at start of each sequence + end_token (str): Token to use at end of each sequence + pad_token (str): Token to use when padding batches of sequences + unk_token (str): Token to use for tokens which are not in the vocabulary + 
mask_token (str): Token to use when masking pieces of the sequence + sep_token (str): Token to use when sepatating two sentences + mask_prob (float): Probability of token being masked when masking is enabled + show_mask_token_prob (float): Probability of a masked token being replaced with mask token + mask_scheme (str): Masking scheme used by the tokenizer when masking + span_lambda (float): Mean for poisson distribution when sampling a span of tokens + """ + + self.vocab = {t: i for i, t in enumerate(vocab)} + self.decode_vocab = {i: t for t, i in self.vocab.items()} + self.chem_token_idxs = chem_token_idxs + self.prog = prog + + self.begin_token = begin_token + self.end_token = end_token + self.pad_token = pad_token + self.unk_token = unk_token + self.mask_token = mask_token + self.sep_token = sep_token + + self.mask_prob = mask_prob + self.show_mask_token_prob = show_mask_token_prob + self.mask_scheme = mask_scheme + self.span_lambda = span_lambda + + self.unk_id = self.vocab[unk_token] + self.unk_token_cnt = {} + + @staticmethod + def from_vocab_file( + vocab_path, + regex, + chem_tokens_start_idx, + pad_token_idx=0, + unk_token_idx=1, + begin_token_idx=2, + end_token_idx=3, + mask_token_idx=4, + sep_token_idx=5, + mask_prob=DEFAULT_MASK_PROB, + show_mask_token_prob=DEFAULT_SHOW_MASK_TOKEN_PROB, + mask_scheme=DEFAULT_MASK_SCHEME, + span_lambda=DEFAULT_SPAN_LAMBDA + ): + """ Load the tokenizer object from a vocab file and regex + + Reads a newline separated list of tokens from a file to use as the vocabulary + Note: Assumes that the chemical tokens run from chem_tokens_start_idx to the end of the tokens list + Anything after the defined tokens and before chem_tokens_start_idx is assumed to be an extra token + and is added to the regex for tokenizing + + Args: + vocab_path (str): Path to vocab file + regex (str): Regex to use for tokenizing + chem_tokens_start_idx (int): Index of the start of the chemical tokens in the tokens list + + Returns: + MolEncTokenizer 
object + """ + + text = Path(vocab_path).read_text() + tokens = text.split("\n") + tokens = [t for t in tokens if t is not None and t != ""] + + token_idxs = [pad_token_idx, unk_token_idx, begin_token_idx, end_token_idx, mask_token_idx, sep_token_idx] + extra_tokens_idxs = range(max(token_idxs) + 1, chem_tokens_start_idx) + extra_tokens = [tokens[idx] for idx in extra_tokens_idxs] + prog = MolEncTokenizer._get_compiled_regex(regex, extra_tokens) + + pad_token = tokens[pad_token_idx] + unk_token = tokens[unk_token_idx] + begin_token = tokens[begin_token_idx] + end_token = tokens[end_token_idx] + mask_token = tokens[mask_token_idx] + sep_token = tokens[sep_token_idx] + + chem_tokens_idxs = list(range(chem_tokens_start_idx, len(tokens))) + tokenizer = MolEncTokenizer( + tokens, + chem_tokens_idxs, + prog, + begin_token=begin_token, + end_token=end_token, + pad_token=pad_token, + unk_token=unk_token, + mask_token=mask_token, + sep_token=sep_token, + mask_prob=mask_prob, + show_mask_token_prob=show_mask_token_prob, + mask_scheme=mask_scheme, + span_lambda=span_lambda + ) + return tokenizer + @staticmethod + + def from_pretrained( + vocab_path, + regex=REGEX, + chem_tokens_start_idx=DEFAULT_CHEM_TOKEN_START, + pad_token_idx=0, + unk_token_idx=1, + begin_token_idx=2, + end_token_idx=3, + mask_token_idx=4, + sep_token_idx=5, + mask_prob=DEFAULT_MASK_PROB, + show_mask_token_prob=DEFAULT_SHOW_MASK_TOKEN_PROB, + mask_scheme=DEFAULT_MASK_SCHEME, + span_lambda=DEFAULT_SPAN_LAMBDA + ): + """ Load the tokenizer object from a vocab file and regex + + Reads a newline separated list of tokens from a file to use as the vocabulary + Note: Assumes that the chemical tokens run from chem_tokens_start_idx to the end of the tokens list + Anything after the defined tokens and before chem_tokens_start_idx is assumed to be an extra token + and is added to the regex for tokenizing + + Args: + vocab_path (str): Path to vocab file + regex (str): Regex to use for tokenizing + 
chem_tokens_start_idx (int): Index of the start of the chemical tokens in the tokens list + + Returns: + MolEncTokenizer object + """ + + text = Path(vocab_path).read_text() + tokens = text.split("\n") + tokens = [t for t in tokens if t is not None and t != ""] + + token_idxs = [pad_token_idx, unk_token_idx, begin_token_idx, end_token_idx, mask_token_idx, sep_token_idx] + extra_tokens_idxs = range(max(token_idxs) + 1, chem_tokens_start_idx) + extra_tokens = [tokens[idx] for idx in extra_tokens_idxs] + prog = MolEncTokenizer._get_compiled_regex(regex, extra_tokens) + + pad_token = tokens[pad_token_idx] + unk_token = tokens[unk_token_idx] + begin_token = tokens[begin_token_idx] + end_token = tokens[end_token_idx] + mask_token = tokens[mask_token_idx] + sep_token = tokens[sep_token_idx] + + chem_tokens_idxs = list(range(chem_tokens_start_idx, len(tokens))) + tokenizer = MolEncTokenizer( + tokens, + chem_tokens_idxs, + prog, + begin_token=begin_token, + end_token=end_token, + pad_token=pad_token, + unk_token=unk_token, + mask_token=mask_token, + sep_token=sep_token, + mask_prob=mask_prob, + show_mask_token_prob=show_mask_token_prob, + mask_scheme=mask_scheme, + span_lambda=span_lambda + ) + return tokenizer + + @staticmethod + def from_smiles( + smiles, + regex, + extra_tokens=None, + begin_token=DEFAULT_BEGIN_TOKEN, + end_token=DEFAULT_END_TOKEN, + pad_token=DEFAULT_PAD_TOKEN, + unk_token=DEFAULT_UNK_TOKEN, + mask_token=DEFAULT_MASK_TOKEN, + sep_token=DEFAULT_SEP_TOKEN, + mask_prob=DEFAULT_MASK_PROB, + show_mask_token_prob=DEFAULT_SHOW_MASK_TOKEN_PROB, + mask_scheme=DEFAULT_MASK_SCHEME, + span_lambda=DEFAULT_SPAN_LAMBDA + ): + """ Build the tokenizer from smiles strings and a regex + + Args: + smiles (List[str]): SMILES strings to use to build vocabulary + regex (str): Regex to use for tokenizing + extra_tokens (Optional[List[str]]): Additional tokens to add to the vocabulary that + may not appear in the SMILES strings + """ + + vocab = { + pad_token: 0, + unk_token: 
1, + begin_token: 2, + end_token: 3, + mask_token: 4, + sep_token: 5 + } + + extra_tokens = [] if extra_tokens is None else extra_tokens + [vocab.setdefault(token, len(vocab)) for token in extra_tokens] + + chem_start_idx = len(vocab) + prog = MolEncTokenizer._get_compiled_regex(regex, extra_tokens) + print(f"Chemistry tokens start at index {chem_start_idx}") + + for smi in smiles: + for token in prog.findall(smi): + vocab.setdefault(token, len(vocab)) + + chem_token_idxs = list(range(chem_start_idx, len(vocab))) + + vocab = sorted(vocab.items(), key=lambda k_v: k_v[1]) + vocab = [key for key, val in vocab] + + tokenizer = MolEncTokenizer( + vocab, + chem_token_idxs, + prog, + begin_token=begin_token, + end_token=end_token, + pad_token=pad_token, + unk_token=unk_token, + mask_token=mask_token, + sep_token=sep_token, + mask_prob=mask_prob, + show_mask_token_prob=show_mask_token_prob, + mask_scheme=mask_scheme, + span_lambda=span_lambda + ) + return tokenizer + + def save_vocab(self, vocab_path): + tokens = sorted(self.vocab.items(), key=lambda k_v: k_v[1]) + tokens = [key for key, val in tokens] + + tokens_str = "" + for token in tokens: + tokens_str += f"{token}\n" + + p = Path(vocab_path) + p.write_text(tokens_str) + + def __len__(self): + return len(self.vocab) + + def tokenize(self, sents1, sents2=None, mask=False, pad=False): + if sents2 is not None and len(sents1) != len(sents2): + raise ValueError("Sentence 1 batch and sentence 2 batch must have the same number of elements") + + tokens = self._regex_match(sents1) + m_tokens, token_masks = self._mask_tokens(tokens, empty_mask=not mask) + + sent_masks = None + if sents2 is not None: + sents2_tokens = self._regex_match(sents2) + sents2_m_tokens, sents2_masks = self._mask_tokens(sents2_tokens, empty_mask=not mask) + tokens, sent_masks = self._concat_sentences(tokens, sents2_tokens, self.sep_token) + m_tokens, _ = self._concat_sentences(m_tokens, sents2_m_tokens, self.sep_token) + token_masks, _ = 
self._concat_sentences(token_masks, sents2_masks, False) + + + tokens = [[self.begin_token] + ts + [self.end_token] for ts in tokens] + m_tokens = [[self.begin_token] + ts + [self.end_token] for ts in m_tokens] + token_masks = [[False] + ts + [False] for ts in token_masks] + sent_masks = [[0] + mask + [1] for mask in sent_masks] if sent_masks is not None else None + + output = {} + + if pad: + tokens, orig_pad_masks = self._pad_seqs(tokens, self.pad_token) + m_tokens, masked_pad_masks = self._pad_seqs(m_tokens, self.pad_token) + token_masks, _ = self._pad_seqs(token_masks, False) + sent_masks, _ = self._pad_seqs(sent_masks, False) if sent_masks is not None else (None, None) + output["original_pad_masks"] = orig_pad_masks + output["masked_pad_masks"] = masked_pad_masks + + output["original_tokens"] = tokens + + if mask: + output["masked_tokens"] = m_tokens + output["token_masks"] = token_masks + + if sent_masks is not None: + output["sentence_masks"] = sent_masks + + return output + + def _regex_match(self, smiles): + tokenized = [] + data_type = type(smiles) + if data_type == str: + smiles = smiles.split() + # tokenized = self.prog.findall(smiles) + for smi in smiles: + tokens = self.prog.findall(smi) + tokenized.append(tokens) + + return tokenized + + @staticmethod + def _get_compiled_regex(regex, extra_tokens): + regex_string = r"(" + for token in extra_tokens: + processed_token = token + for special_character in "()[].|": + processed_token = processed_token.replace(special_character, f"\\{special_character}") + regex_string += processed_token + r"|" + + regex_string += regex + r"|" + regex_string += r".)" + return re.compile(regex_string) + + def _concat_sentences(self, tokens1, tokens2, sep): + tokens = [ts1 + [sep] + ts2 for ts1, ts2 in zip(tokens1, tokens2)] + sent_masks = [([0] * len(ts1)) + [0] + ([1] * len(ts2)) for ts1, ts2 in zip(tokens1, tokens2)] + return tokens, sent_masks + + def detokenize(self, tokens_list): + new_tokens_list = [] + for tokens in 
tokens_list: + if tokens[0] == self.begin_token: + tokens = tokens[1:] + + # Remove any tokens after the end token (and end token) if it's there + if self.end_token in tokens: + end_token_idx = tokens.index(self.end_token) + tokens = tokens[:end_token_idx] + + new_tokens_list.append(tokens) + + strs = ["".join(tokens) for tokens in new_tokens_list] + return strs + + def convert_tokens_to_ids(self, token_data): + ids_list = [] + for tokens in token_data: + for token in tokens: + token_id = self.vocab.get(token) + if token_id is None: + self._inc_in_dict(self.unk_token_cnt, token) + + ids = [self.vocab.get(token, self.unk_id) for token in tokens] + ids_list.append(ids) + + return ids_list + + def convert_ids_to_tokens(self, token_ids): + tokens_list = [] + for ids in token_ids: + for token_id in ids: + token = self.decode_vocab.get(token_id) + if token is None: + raise ValueError(f"Token id {token_id} is not recognised") + + tokens = [self.decode_vocab.get(token_id) for token_id in ids] + tokens_list.append(tokens) + + return tokens_list + + def print_unknown_tokens(self): + print(f"{'Token':<10}Count") + for token, cnt in self.unk_token_cnt.items(): + print(f"{token:<10}{cnt}") + + print() + + @staticmethod + def _inc_in_dict(coll, item): + cnt = coll.get(item, 0) + cnt += 1 + coll[item] = cnt + + def _mask_tokens(self, tokens, empty_mask=False): + if empty_mask: + mask = [[False] * len(ts) for ts in tokens] + return tokens, mask + + masked_tokens = [] + token_masks = [] + + for ts in tokens: + if self.mask_scheme == "replace": + masked, token_mask = self._mask_replace(ts) + elif self.mask_scheme == "span": + masked, token_mask = self._mask_span(ts) + else: + raise ValueError(f"Unrecognised mask scheme: {self.mask_scheme}") + + masked_tokens.append(masked) + token_masks.append(token_mask) + + return masked_tokens, token_masks + + def _mask_replace(self, ts): + mask_bools = [True, False] + weights = [self.mask_prob, 1 - self.mask_prob] + token_mask = 
random.choices(mask_bools, weights=weights, k=len(ts)) + masked = [self._mask_token(ts[i]) if m else ts[i] for i, m in enumerate(token_mask)] + return masked, token_mask + + def _mask_span(self, ts): + curr_token = 0 + masked = [] + token_mask = [] + + mask_bools = [True, False] + weights = [self.mask_prob, 1 - self.mask_prob] + sampled_mask = random.choices(mask_bools, weights=weights, k=len(ts)) + + while curr_token < len(ts): + # If mask, sample from a poisson dist to get length of mask + if sampled_mask[curr_token]: + mask_len = torch.poisson(torch.tensor(self.span_lambda)).long().item() + masked.append(self.mask_token) + token_mask.append(True) + curr_token += mask_len + + # Otherwise don't mask + else: + masked.append(ts[curr_token]) + token_mask.append(False) + curr_token += 1 + + return masked, token_mask + + def _mask_token(self, token): + rand = random.random() + if rand < self.show_mask_token_prob: + return self.mask_token + + elif rand < self.show_mask_token_prob + ((1 - self.show_mask_token_prob) / 2): + token_idx = random.choice(self.chem_token_idxs) + return self.decode_vocab[token_idx] + + else: + return token + + @staticmethod + def _pad_seqs(seqs, pad_token): + pad_length = max([len(seq) for seq in seqs]) + padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs] + masks = [([0] * len(seq)) + ([1] * (pad_length - len(seq))) for seq in seqs] + return padded, masks + + +def load_tokenizer(vocab_path=DEFAULT_VOCAB_PATH, chem_token_start=DEFAULT_CHEM_TOKEN_START, regex=REGEX): + tokenizer = MolEncTokenizer.from_vocab_file(vocab_path, regex, chem_token_start) + return tokenizer \ No newline at end of file diff --git a/open_biomed/models/multimodal/mega_molbart/util.py b/open_biomed/models/multimodal/mega_molbart/util.py new file mode 100644 index 0000000..37807ad --- /dev/null +++ b/open_biomed/models/multimodal/mega_molbart/util.py @@ -0,0 +1,21 @@ +DEFAULT_VOCAB_PATH = "bart_vocab.txt" + +# Tokenization and vocabulary 
+DEFAULT_MAX_SEQ_LEN = 512 +DEFAULT_CHEM_TOKEN_START = 272 +DEFAULT_BEGIN_TOKEN = "^" +DEFAULT_END_TOKEN = "&" +DEFAULT_PAD_TOKEN = "" +DEFAULT_UNK_TOKEN = "?" +DEFAULT_MASK_TOKEN = "" +DEFAULT_SEP_TOKEN = "" +DEFAULT_MASK_PROB = 0.15 +DEFAULT_SHOW_MASK_TOKEN_PROB = 1.0 +DEFAULT_MASK_SCHEME = "span" +DEFAULT_SPAN_LAMBDA = 3.0 +REGEX = "\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]" + +# Model parameters +DEFAULT_D_MODEL = 256 +DEFAULT_NUM_LAYERS = 4 +DEFAULT_NUM_HEADS = 8 \ No newline at end of file diff --git a/open_biomed/models/multimodal/mega_molbart/workflow.py b/open_biomed/models/multimodal/mega_molbart/workflow.py new file mode 100644 index 0000000..00f1efa --- /dev/null +++ b/open_biomed/models/multimodal/mega_molbart/workflow.py @@ -0,0 +1,210 @@ +import logging +# import torch +from functools import singledispatch +from typing import List + +import numpy as np +from rdkit.Chem import PandasTools, CanonSmiles + +logger = logging.getLogger(__name__) + + +@singledispatch +def add_jitter(embedding, radius, cnt, shape): + return NotImplemented + + +@add_jitter.register(np.ndarray) +def _(embedding, radius, cnt, shape): + + distorteds = [] + for i in range(cnt): + noise = np.random.normal(0, radius, embedding.shape) + distorted = noise + embedding + distorteds.append(distorted) + + return distorteds + +class GenerativeWfDao(object): + + def fetch_id_from_chembl(self, id: List): + """ + Fetch molecular details for a list of molecules. The values in the list + of molecules depends on database/service used. For e.g. it could be + ChemblId or molreg_id for Chemble database. 
+ """ + return NotImplemented + + +class BaseGenerativeWorkflow: + + def __init__(self, dao: GenerativeWfDao = None) -> None: + self.dao = dao + self.min_jitter_radius = None + + def get_iteration(self): + NotImplemented + + def smiles_to_embedding(self, + smiles: str, + padding: int): + NotImplemented + + def embedding_to_smiles(self, + embedding: float, + dim: int, + pad_mask): + NotImplemented + + def interpolate_smiles(self, + smiles: List, + num_points: int = 10, + scaled_radius=None, + force_unique=False): + NotImplemented + + def find_similars_smiles_list(self, + smiles: str, + num_requested: int = 10, + scaled_radius=None, + force_unique=False): + NotImplemented + + def find_similars_smiles(self, + smiles: str, + num_requested: int = 10, + scaled_radius=None, + force_unique=False): + NotImplemented + + def _compute_radius(self, scaled_radius): + if scaled_radius: + return float(scaled_radius * self.min_jitter_radius) + else: + return self.min_jitter_radius + + def addjitter(self, + embedding, + radius=None, + cnt=1, + shape=None): + radius = radius if radius else self.radius_scale + return add_jitter(embedding, radius, cnt, shape) + + def compute_unique_smiles(self, + interp_df, + embedding_funct, + scaled_radius=None): + """ + Identify duplicate SMILES and distorts the embedding. The input df + must have columns 'SMILES' and 'Generated' at 0th and 1st position. + 'Generated' colunm must contain boolean to classify SMILES into input + SMILES(False) and generated SMILES(True). + + This function does not make any assumptions about order of embeddings. + Instead it simply orders the df by SMILES to identify the duplicates. 
+ """ + + distance = self._compute_radius(scaled_radius) + embeddings = interp_df['embeddings'] + embeddings_dim = interp_df['embeddings_dim'] + for index, row in interp_df.iterrows(): + smile_string = row['SMILES'] + try: + canonical_smile = CanonSmiles(smile_string) + except: + # If a SMILES cannot be canonicalized, just use the original + canonical_smile = smile_string + + row['SMILES'] = canonical_smile + + for i in range(5): + smiles = interp_df['SMILES'].sort_values() + duplicates = set() + for idx in range(0, smiles.shape[0] - 1): + if smiles.iat[idx] == smiles.iat[idx + 1]: + duplicates.add(smiles.index[idx]) + duplicates.add(smiles.index[idx + 1]) + + if len(duplicates) > 0: + for dup_idx in duplicates: + if interp_df.iat[dup_idx, 3]: + # add jitter to generated molecules only + distored = self.addjitter(embeddings[dup_idx], + distance, + cnt=1, + shape=embeddings_dim[dup_idx]) + embeddings[dup_idx] = distored[0] + interp_df['SMILES'] = embedding_funct(embeddings.to_list()) + interp_df['embeddings'] = embeddings + else: + break + + # Ensure all generated molecules are valid. 
+ for i in range(5): + PandasTools.AddMoleculeColumnToFrame(interp_df, 'SMILES') + invalid_mol_df = interp_df[interp_df['ROMol'].isnull()] + + if not invalid_mol_df.empty: + invalid_index = invalid_mol_df.index.to_list() + for idx in invalid_index: + embeddings[idx] = self.addjitter(embeddings[idx], + distance, + cnt=1, + shape=embeddings_dim[idx])[0] + interp_df['SMILES'] = embedding_funct(embeddings.to_list()) + interp_df['embeddings'] = embeddings + else: + break + + # Cleanup + if 'ROMol' in interp_df.columns: + interp_df = interp_df.drop('ROMol', axis=1) + + return interp_df + + def interpolate_by_id(self, + ids: List, + id_type: str = 'chembleid', + num_points=10, + force_unique=False, + scaled_radius: int = 1): + smiles = None + + if not self.min_jitter_radius: + raise Exception('Property `radius_scale` must be defined in model class.') + + if id_type.lower() == 'chembleid': + smiles = [row[2] for row in self.dao.fetch_id_from_chembl(ids)] + if len(smiles) != len(ids): + raise Exception('One of the ids is invalid %s', ids) + else: + raise Exception('id type %s not supported' % id_type) + + return self.interpolate_smiles(smiles, + num_points=num_points, + scaled_radius=scaled_radius, + force_unique=force_unique) + + def find_similars_smiles_by_id(self, + chemble_id: str, + id_type: str = 'chembleid', + num_requested=10, + force_unique=False, + scaled_radius: int = 1): + smiles = None + + if not self.min_jitter_radius: + raise Exception('Property `radius_scale` must be defined in model class.') + + if id_type.lower() == 'chembleid': + smiles = [row[2] for row in self.dao.fetch_id_from_chembl(chemble_id)] + if len(smiles) != len(chemble_id): + raise Exception('One of the ids is invalid %s' + chemble_id) + else: + raise Exception('id type %s not supported' % id_type) + + return self.find_similars_smiles(smiles[0], + num_requested=num_requested, + scaled_radius=scaled_radius, + force_unique=force_unique) diff --git a/open_biomed/models/multimodal/moleculestm.py 
b/open_biomed/models/multimodal/moleculestm.py new file mode 100644 index 0000000..64dc313 --- /dev/null +++ b/open_biomed/models/multimodal/moleculestm.py @@ -0,0 +1,363 @@ +import logging +logger = logging.getLogger(__name__) + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections.abc import Sequence + +from torch_geometric.nn import (MessagePassing, global_add_pool, + global_max_pool, global_mean_pool) +from torch_geometric.utils import add_self_loops, softmax, degree + +from models.base_models import MolEncoder, TextEncoder +from transformers import BertModel +from models.multimodal.mega_molbart.mega_mol_bart import MegaMolBART + +class AtomEncoder(torch.nn.Module): + def __init__(self, emb_dim): + super(AtomEncoder, self).__init__() + + self.atom_embedding_list = torch.nn.ModuleList() + + for i, dim in enumerate([119, 4, 12, 12, 10, 6, 6, 2, 2]): + emb = torch.nn.Embedding(dim, emb_dim) + torch.nn.init.xavier_uniform_(emb.weight.data) + self.atom_embedding_list.append(emb) + + def forward(self, x): + x_embedding = 0 + for i in range(x.shape[1]): + x_embedding += self.atom_embedding_list[i](x[:,i]) + + return x_embedding + + +class BondEncoder(torch.nn.Module): + def __init__(self, emb_dim): + super(BondEncoder, self).__init__() + + self.bond_embedding_list = torch.nn.ModuleList() + + for i, dim in enumerate([5, 6, 2]): + emb = torch.nn.Embedding(dim, emb_dim) + torch.nn.init.xavier_uniform_(emb.weight.data) + self.bond_embedding_list.append(emb) + + def forward(self, edge_attr): + bond_embedding = 0 + for i in range(edge_attr.shape[1]): + bond_embedding += self.bond_embedding_list[i](edge_attr[:,i]) + + return bond_embedding + +class GINConv(MessagePassing): + def __init__(self, emb_dim, aggr="add"): + ''' + emb_dim (int): node embedding dimensionality + ''' + super(GINConv, self).__init__(aggr=aggr) + + self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), 
torch.nn.Linear(2*emb_dim, emb_dim)) + self.eps = torch.nn.Parameter(torch.Tensor([0])) + + self.bond_encoder = BondEncoder(emb_dim = emb_dim) + + def forward(self, x, edge_index, edge_attr): + edge_embedding = self.bond_encoder(edge_attr) + out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) + return out + + def message(self, x_j, edge_attr): + return F.relu(x_j + edge_attr) + + def update(self, aggr_out): + return aggr_out + + +class GCNConv(MessagePassing): + def __init__(self, emb_dim, aggr="add"): + super(GCNConv, self).__init__(aggr=aggr) + + self.linear = torch.nn.Linear(emb_dim, emb_dim) + self.root_emb = torch.nn.Embedding(1, emb_dim) + self.bond_encoder = BondEncoder(emb_dim = emb_dim) + + def forward(self, x, edge_index, edge_attr): + x = self.linear(x) + edge_embedding = self.bond_encoder(edge_attr) + + row, col = edge_index + + #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device) + deg = degree(row, x.size(0), dtype = x.dtype) + 1 + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] + + return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1) + + def message(self, x_j, edge_attr, norm): + return norm.view(-1, 1) * F.relu(x_j + edge_attr) + + def update(self, aggr_out): + return aggr_out + + +class GNN(nn.Module): + def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0., gnn_type="gin"): + + if num_layer < 2: + raise ValueError("Number of GNN layers must be greater than 1.") + + super(GNN, self).__init__() + self.drop_ratio = drop_ratio + self.num_layer = num_layer + self.JK = JK + + self.atom_encoder = AtomEncoder(emb_dim) + + ###List of MLPs + self.gnns = nn.ModuleList() + for layer in range(num_layer): + if gnn_type == "gin": + self.gnns.append(GINConv(emb_dim, aggr="add")) + elif gnn_type == "gcn": + 
self.gnns.append(GCNConv(emb_dim)) + + ###List of batchnorms + self.batch_norms = nn.ModuleList() + for layer in range(num_layer): + self.batch_norms.append(nn.BatchNorm1d(emb_dim)) + + # def forward(self, x, edge_index, edge_attr): + def forward(self, *argv): + if len(argv) == 3: + x, edge_index, edge_attr = argv[0], argv[1], argv[2] + elif len(argv) == 1: + data = argv[0] + x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr + else: + raise ValueError("unmatched number of arguments.") + + x = self.atom_encoder(x) + + h_list = [x] + for layer in range(self.num_layer): + h = self.gnns[layer](h_list[layer], edge_index, edge_attr) + h = self.batch_norms[layer](h) + # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) + if layer == self.num_layer - 1: + # remove relu for the last layer + h = F.dropout(h, self.drop_ratio, training=self.training) + else: + h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) + h_list.append(h) + + ### Different implementations of Jk-concat + if self.JK == "concat": + node_representation = torch.cat(h_list, dim=1) + elif self.JK == "last": + node_representation = h_list[-1] + elif self.JK == "max": + h_list = [h.unsqueeze_(0) for h in h_list] + node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0] + elif self.JK == "sum": + h_list = [h.unsqueeze_(0) for h in h_list] + node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0] + else: + raise ValueError("not implemented.") + return node_representation + + +class GNN_graphpred(nn.Module): + """ + Extension of GIN to incorporate edge information by concatenation. + + Args: + num_layer (int): the number of GNN layers + arg.emb_dim (int): dimensionality of embeddings + num_tasks (int): number of tasks in multi-task learning scenario + JK (str): last, concat, max or sum. 
+ graph_pooling (str): sum, mean, max, attention, set2set + + See https://arxiv.org/abs/1810.00826 + JK-net: https://arxiv.org/abs/1806.03536 """ + + def __init__(self, num_layer, emb_dim, num_tasks, JK, graph_pooling, molecule_node_model=None): + super(GNN_graphpred, self).__init__() + + if num_layer < 2: + raise ValueError("# layers must > 1.") + + self.molecule_node_model = molecule_node_model + self.num_layer = num_layer + self.emb_dim = emb_dim + self.num_tasks = num_tasks + self.JK = JK + + # Different kind of graph pooling + if graph_pooling == "sum": + self.pool = global_add_pool + elif graph_pooling == "mean": + self.pool = global_mean_pool + elif graph_pooling == "max": + self.pool = global_max_pool + else: + raise ValueError("Invalid graph pooling type.") + + # For graph-level binary classification + self.mult = 1 + + if self.JK == "concat": + self.graph_pred_linear = nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, + self.num_tasks) + else: + self.graph_pred_linear = nn.Linear(self.mult * self.emb_dim, self.num_tasks) + return + + def from_pretrained(self, model_file): + print("Loading from {} ...".format(model_file)) + state_dict = torch.load(model_file) + self.molecule_node_model.load_state_dict(state_dict) + return + + def forward(self, *argv): + if len(argv) == 4: + x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3] + elif len(argv) == 1: + data = argv[0] + x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch + else: + raise ValueError("unmatched number of arguments.") + + node_representation = self.molecule_node_model(x, edge_index, edge_attr) + graph_representation = self.pool(node_representation, batch) + output = self.graph_pred_linear(graph_representation) + return graph_representation, output + +class MoleculeSTM(MolEncoder, TextEncoder): + def __init__(self, config): + super().__init__() + self.config = config + + if config["structure"]["name"] == "magamolbart": + 
            self.MegaMolBART_wrapper = MegaMolBART(
                vocab_path=config["structure"]["vocab_path"],
                input_dir=config["structure"]["MegaMolBART_generation_model_dir"],
                output_dir=None
            )
            self.structure_encoder = self.MegaMolBART_wrapper.model
        elif config["structure"]["name"] == "gnn":
            # The MegaMolBART wrapper is still constructed in the GNN branch —
            # presumably the editing pipeline uses it for SMILES decoding while
            # the GIN encodes the graph; confirm against callers.
            self.MegaMolBART_wrapper = MegaMolBART(
                vocab_path=config["structure"]["vocab_path"],
                input_dir=config["structure"]["MegaMolBART_generation_model_dir"],
                output_dir=None
            )
            molecule_node_model = GNN(
                num_layer=config["structure"]["gin_num_layers"],
                emb_dim=config["structure"]["gin_hidden_dim"],
                gnn_type="gin",
                drop_ratio=config["structure"]["drop_ratio"],
                JK="last",
            )
            self.structure_encoder = GNN_graphpred(
                num_layer=config["structure"]["gin_num_layers"],
                emb_dim=config["structure"]["gin_hidden_dim"],
                graph_pooling="mean",
                JK="last",
                num_tasks=1,
                molecule_node_model=molecule_node_model
            )
        else:
            raise AttributeError
        # Optional pretrained weights for the structure tower (strict load).
        if "ckpt" in config["structure"]:
            logger.info("Loading structure encoder from %s" % (config["structure"]["ckpt"]))
            state_dict = torch.load(config["structure"]["ckpt"], map_location="cpu")
            self.structure_encoder.load_state_dict(state_dict)

        # Text tower: HF BERT initialized from a local SciBERT path, then
        # optionally overwritten by a MoleculeSTM text checkpoint (non-strict
        # load; mismatched keys are logged rather than raised).
        self.text_encoder = BertModel.from_pretrained(config["text"]["bert_path"])
        if "ckpt" in config["text"]:
            logger.info("Loading text encoder from %s" % (config["text"]["ckpt"]))
            state_dict = torch.load(config["text"]["ckpt"], map_location="cpu")
            missing_keys, unexpected_keys = self.text_encoder.load_state_dict(state_dict, strict=False)
            logger.info("missing keys: " + str(missing_keys))
            logger.info("unexpected keys: " + str(unexpected_keys))

        # Projection heads mapping each tower into the shared latent space.
        self.structure_proj_head = nn.Linear(config["structure"]["output_dim"], config["projection_dim"])
        self.text_proj_head = nn.Linear(config["text"]["output_dim"], config["projection_dim"])
        if "structure_proj_ckpt" in config:
            logger.info("Loading structure projection head from %s" % (config["structure_proj_ckpt"]))
            state_dict = torch.load(config["structure_proj_ckpt"], map_location="cpu")
            self.structure_proj_head.load_state_dict(state_dict)
        if "text_proj_ckpt" in config:
            logger.info("Loading text projection head from %s" % (config["text_proj_ckpt"]))
            state_dict = torch.load(config["text_proj_ckpt"], map_location="cpu")
            self.text_proj_head.load_state_dict(state_dict)
        self.norm = False

    def encode_mol(self, structure, proj=False, return_node_feats=False):
        # Encodes a molecule with the structure tower; optionally projects into
        # the shared space and/or also returns the encoder's second output.
        # NOTE(review): when the structure encoder is GNN_graphpred, its
        # forward returns (graph_representation, prediction_head_output), so
        # the value unpacked here as "node_embeds" is the head output, not
        # per-node features — confirm this is intended for each encoder type.
        mol_embeds, node_embeds = self.structure_encoder(structure)
        if proj:
            mol_embeds = self.structure_proj_head(mol_embeds)
        if not return_node_feats:
            return mol_embeds
        else:
            return mol_embeds, node_embeds

    def encode_text(self, text, proj=False):
        # Encodes tokenized text (expects "input_ids"/"attention_mask" keys)
        # via BERT's pooler_output; optionally projects into the shared space.
        text_embeds = self.text_encoder(text["input_ids"], attention_mask=text["attention_mask"])["pooler_output"]
        if proj:
            return self.text_proj_head(text_embeds)
        else:
            return text_embeds




class MLP(nn.Module):
    # Plain multi-layer perceptron with optional batch-norm, dropout, and
    # automatic residual connections when consecutive layer shapes match.
    def __init__(self, input_dim, hidden_dims, batch_norm=False, activation="relu", dropout=0):
        # hidden_dims may be a single int or a sequence of layer widths.
        # activation: either a torch.nn.functional name (str) or a callable.
        super(MLP, self).__init__()

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        self.dims = [input_dim] + hidden_dims

        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))
        if batch_norm:
            # One BatchNorm per hidden layer; the final output layer is not
            # normalized (hence dims - 2).
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.dims) - 2):
                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))
        else:
            self.batch_norms = None

    def forward(self, input):
        layer_input = input

        for i, layer in enumerate(self.layers):
            hidden = layer(layer_input)
            if i < len(self.layers) - 1:
                # BN/activation/dropout are applied to all but the last layer.
                if self.batch_norms:
                    # Flatten leading dims so BatchNorm1d sees (N, C), then
                    # restore the original shape.
                    x = hidden.flatten(0, -2)
                    hidden = self.batch_norms[i](x).view_as(hidden)
                hidden = self.activation(hidden)
                if self.dropout:
                    hidden = self.dropout(hidden)
            # Residual connection whenever the layer preserves the shape.
            if hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            layer_input = hidden

        return hidden
\ No newline at end of file
diff --git a/open_biomed/models/multimodal/molkformer/kformer.py b/open_biomed/models/multimodal/molkformer/kformer.py
new file mode 100644
index 0000000..afe1946
--- /dev/null
+++ b/open_biomed/models/multimodal/molkformer/kformer.py
@@ -0,0 +1,1244 @@
"""
 * Copyright (c) 2023, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
"""

import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig

logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def
__init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + self.token_type_embeddings = nn.Embedding( + config.type_vocab_size, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if token_type_ids is None and input_ids is not None: + token_type_ids = torch.zeros(input_ids.size(), dtype=torch.long, device=self.position_ids.device) + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = 
self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = 
x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - 
len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + 
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        ):
            # Cross-attention only every `cross_attention_freq` layers
            # (Q-Former style); other layers are self-attention only.
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Separate feed-forward stack for the query tokens (first
        # `query_length` positions) vs. the text tokens.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # The first query_length positions are learned query tokens; only
            # they attend to the encoder states via cross-attention.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights

            # Query tokens go through the query feed-forward stack …
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                # … while any trailing text tokens use the text stack; the two
                # halves are re-concatenated along the sequence dimension.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Standard BERT FFN + residual for text positions.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # Parallel FFN + residual used only for query-token positions.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
        mode='text'
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        # Layer-range routing (project-specific extension): 'text' runs only
        # layers [0, contrastive_layer), 'fusion' runs the remaining layers
        # [contrastive_layer, num_hidden_layers), and any other mode runs the
        # whole stack. This lets a unimodal pass and a multimodal fusion pass
        # share the lower layers.
        if mode == 'text':
            start_layer = 0
            end_layer = self.config.contrastive_layer
        elif mode == 'fusion':
            start_layer = self.config.contrastive_layer
            end_layer = self.config.num_hidden_layers
        else:
            start_layer = 0
            end_layer = self.config.num_hidden_layers

        for i in range(start_layer, end_layer):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            *inputs, past_key_value, output_attentions, query_length
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                # NOTE(review): layer_outputs[2] only holds cross-attentions on
                # layers that have cross-attention and a positive query_length;
                # presumably callers only set output_attentions accordingly —
                # confirm, otherwise this index can point at the KV cache.
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense =
        nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    # Dense + activation + LayerNorm applied before the MLM decoder.
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    # Projects hidden states to vocabulary logits.
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    # Thin wrapper so the MLM head lives under the HF-conventional "cls.predictions" prefix.
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode=None + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 
+ use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.num_query_tokens + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + if query_embeds is not None: + embedding_output = torch.cat([query_embeds, embedding_output], dim=1) + + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + mode=mode + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. 
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None 
else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + 
query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/open_biomed/models/multimodal/molkformer/mol_kformer.py b/open_biomed/models/multimodal/molkformer/mol_kformer.py new file mode 100644 index 0000000..93e5cd0 
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from transformers import BertTokenizer, T5Tokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

from models.base_models import MolEncoder, TextEncoder
from models.molecule.gnn_graphmvp import GNNGraphMVP
from models.multimodal.molkformer.kformer import BertConfig, BertLMHeadModel
from utils.mol_utils import convert_pyg_batch, get_biot5_tokenizer
from torch_geometric.nn import (MessagePassing, global_add_pool, global_max_pool, global_mean_pool)


class MolKFormer(MolEncoder, TextEncoder):
    """BLIP-2-style molecule/text model: a GIN graph encoder supplies per-atom
    features that a BERT Q-Former attends to through cross-attention; a T5
    decoder (biot5 vocabulary) generates text from the fused representation.
    """

    def __init__(self, config):
        super().__init__()
        self.pooling = global_mean_pool
        self.structure_config = config["structure"]
        self.kformer_config = BertConfig.from_json_file(config["kformer_config_file"])
        self.projection_dim = config["projection_dim"]
        # Atom sequences are padded/truncated to this length before cross-attention.
        self.max_n_atoms = config["max_n_atoms"]
        self.num_query_tokens = self.kformer_config.num_query_tokens
        self.encoder_tokenizer = BertTokenizer.from_pretrained(config["encoder_tokenizer"])
        # self.decoder_tokenizer = T5Tokenizer.from_pretrained(config["decoder_tokenizer"])
        self.decoder_tokenizer = get_biot5_tokenizer(config["decoder_tokenizer"], config["path_selfies"])

        self.structure_encoder = GNNGraphMVP(
            num_layer=self.structure_config["gin_num_layers"],
            emb_dim=self.structure_config["gin_hidden_dim"],
            gnn_type="gin",
            drop_ratio=self.structure_config["drop_ratio"],
            JK="last",
        )
        if "ckpt" in self.structure_config:
            self.structure_encoder.load_state_dict(torch.load(self.structure_config["ckpt"], map_location="cpu"), strict=False)
        # Project GNN node features into the Q-Former hidden space.
        self.structure_linear = nn.Linear(self.structure_encoder.output_dim, self.kformer_config.hidden_size)
        self.structure_proj_head = nn.Linear(self.kformer_config.hidden_size, self.projection_dim)

        self.kformer = BertLMHeadModel(self.kformer_config)
        self.query_tokens = nn.Parameter(
            torch.zeros(1, self.num_query_tokens, self.kformer_config.hidden_size)
        )
        self.query_tokens.data.normal_(mean=0.0, std=self.kformer_config.initializer_range)
        self.text_proj_head = nn.Linear(self.kformer_config.hidden_size, self.projection_dim)
        # Binary molecule-text matching (MTM) classifier head.
        self.mtm_head = nn.Linear(self.kformer_config.hidden_size, 2)

        decoder_config = T5Config.from_json_file(config["decoder"]["config_file"])
        self.text_decoder = T5ForConditionalGeneration(decoder_config)
        # NOTE(review): 35073 is presumably the biot5 vocabulary size (T5 vocab
        # plus SELFIES tokens) — confirm against self.decoder_tokenizer.
        self.text_decoder.resize_token_embeddings(35073)
        # Maps Q-Former hidden states into the T5 decoder's hidden space.
        self.enc2dec = nn.Linear(self.kformer_config.hidden_size, self.text_decoder.config.hidden_size)

        self.norm = False

    def forward(self, mol, text=None, prompt=None, cal_loss=False):
        """With cal_loss=False: return the fused query-token representation for
        (molecule, prompt+text). With cal_loss=True: compute the molecule-text
        matching loss with in-batch hard negatives mined from the contrastive
        similarity matrix."""
        batch_size = torch.max(mol.batch).item() + 1
        _, node_embeds, node_attention_mask = self.get_graph_feats(mol, batch_size)
        query_embeds = self.query_tokens.expand(batch_size, -1, -1)

        if not cal_loss:
            input_ids, attention_mask = self.seq_wrap(prompt, text)
            attention_mask = torch.cat([torch.ones(query_embeds.shape[:-1], dtype=torch.long).to(query_embeds.device), attention_mask], dim=1)
            return self.kformer.bert(
                input_ids,
                query_embeds=query_embeds,
                attention_mask=attention_mask,
                encoder_hidden_states=node_embeds,
                encoder_attention_mask=node_attention_mask,
                return_dict=True,
            ).last_hidden_state[:, :self.num_query_tokens, :]
        else:
            if prompt is not None:
                prompt_embeds = self.kformer.get_input_embeddings()(prompt["input_ids"])
                # NOTE(review): `self.device` presumably comes from the
                # MolEncoder base class — confirm it is defined there.
                query_attention_mask = torch.cat([torch.ones(query_embeds.shape[:-1], dtype=torch.long).to(self.device), prompt["attention_mask"]], dim=1)
            else:
                prompt_embeds = None
                query_attention_mask = torch.ones(query_embeds.shape[:-1], dtype=torch.long).to(self.device)
            query_outputs = self.kformer.bert(
                encoder_embeds=prompt_embeds,
                query_embeds=query_embeds,
                attention_mask=query_attention_mask,
                encoder_hidden_states=node_embeds,
                encoder_attention_mask=node_attention_mask,
                return_dict=True
            ).last_hidden_state[:, :self.num_query_tokens, :]
            mol_feats = F.normalize(self.structure_proj_head(query_outputs), dim=-1)

            text_outputs = self.kformer.bert(
                text["input_ids"],
                attention_mask=text["attention_mask"],
                return_dict=True
            )
            text_embeds = text_outputs.last_hidden_state
            text_feats = F.normalize(self.text_proj_head(text_embeds[:, 0, :]), dim=-1)

            # Query-wise max similarity in both directions (BLIP-2 style).
            sim_m2t = torch.matmul(mol_feats.unsqueeze(1), text_feats.unsqueeze(-1)).squeeze()
            sim_m2t, _ = sim_m2t.max(dim=-1)
            sim_t2m = torch.matmul(text_feats.unsqueeze(1).unsqueeze(1), mol_feats.transpose(1, 2)).squeeze()
            sim_t2m, _ = sim_t2m.max(dim=-1)

            # Find hard negatives by sampling proportionally to similarity,
            # excluding the positive pair on the diagonal.
            with torch.no_grad():
                weights_m2t = F.softmax(sim_m2t, dim=1) + 1e-4
                weights_m2t.fill_diagonal_(0.0)
                weights_t2m = F.softmax(sim_t2m, dim=1) + 1e-4
                weights_t2m.fill_diagonal_(0.0)
                idx_neg_m2t = []
                for i in range(batch_size):
                    idx_neg_m2t.append(torch.multinomial(weights_m2t[i], 1).item())
                # BUGFIX: the original `.to(node_embeds)` converted the index
                # tensor to node_embeds' *float* dtype, which breaks the
                # advanced indexing below. Only the device should be taken.
                idx_neg_m2t = torch.tensor(idx_neg_m2t, dtype=torch.long, device=node_embeds.device)
                idx_neg_t2m = []
                for i in range(batch_size):
                    idx_neg_t2m.append(torch.multinomial(weights_t2m[i], 1).item())

            # Build [positive | mol-negative | text-negative] triplets.
            node_embeds_mtm = torch.cat([node_embeds, node_embeds, node_embeds[idx_neg_t2m]], dim=0)
            node_attention_mask_mtm = torch.cat([node_attention_mask, node_attention_mask, node_attention_mask[idx_neg_t2m]], dim=0)
            wrapped_input_ids, wrapped_attention_mask = self.seq_wrap(prompt, text)
            text_input_ids_mtm = torch.cat([wrapped_input_ids, wrapped_input_ids[idx_neg_m2t], wrapped_input_ids], dim=0)
            text_attention_mask_mtm = torch.cat([wrapped_attention_mask, wrapped_attention_mask[idx_neg_m2t], wrapped_attention_mask], dim=0)
            query_embeds_mtm = self.query_tokens.expand(node_embeds_mtm.shape[0], -1, -1)
            query_attention_mask_mtm = torch.ones(query_embeds_mtm.shape[:-1], dtype=torch.long).to(query_embeds_mtm.device)
            text_attention_mask_mtm = torch.cat([query_attention_mask_mtm, text_attention_mask_mtm], dim=1)
            mtm_labels = torch.cat([torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], dim=0).to(query_embeds_mtm.device)

            output = self.kformer.bert(
                input_ids=text_input_ids_mtm,
                query_embeds=query_embeds_mtm,
                attention_mask=text_attention_mask_mtm,
                encoder_hidden_states=node_embeds_mtm,
                encoder_attention_mask=node_attention_mask_mtm,
                return_dict=True
            )
            # Average the MTM logits over the query tokens.
            mtm_output = self.mtm_head(output["last_hidden_state"][:, : self.num_query_tokens, :]).mean(dim=1)
            loss_mtm = F.cross_entropy(mtm_output, mtm_labels)
            return loss_mtm

    def seq_wrap(self, seq1, seq2):
        """Splice seq2 into seq1 right after seq1's non-padded tokens, per
        sample; returns (input_ids, attention_mask). Either argument may be
        None, in which case the other is returned unchanged."""
        if seq1 is None:
            return seq2["input_ids"], seq2["attention_mask"]
        if seq2 is None:
            return seq1["input_ids"], seq1["attention_mask"]
        batch_size = seq1["input_ids"].shape[0]
        wrapped_inputs, wrapped_attention_mask = [], []
        for i in range(batch_size):
            cur_len = seq1["attention_mask"][i].sum()
            wrapped_inputs.append(torch.cat([
                seq1["input_ids"][i, :cur_len],
                seq2["input_ids"][i],
                seq1["input_ids"][i, cur_len:]
            ], dim=0))
            wrapped_attention_mask.append(torch.cat([
                seq1["attention_mask"][i, :cur_len],
                seq2["attention_mask"][i],
                seq1["attention_mask"][i, cur_len:]
            ], dim=0))
        return torch.stack(wrapped_inputs, dim=0), torch.stack(wrapped_attention_mask, dim=0)

    def get_graph_feats(self, graph, batch_size):
        """Run the GNN and serialize per-atom features into a dense
        (batch_size, max_n_atoms, hidden) tensor plus an attention mask;
        molecules larger than max_n_atoms are truncated."""
        graph_embeds, node_embeds = self.structure_encoder(graph)
        all_node_feats = self.structure_linear(node_embeds)
        node_feats = []
        node_attention_mask = []
        for i in range(batch_size):
            feat = all_node_feats[torch.where(graph.batch == i)]
            if feat.shape[0] < self.max_n_atoms:
                node_feats.append(torch.cat((
                    feat,
                    torch.zeros(self.max_n_atoms - feat.shape[0], feat.shape[1]).to(feat.device)
                ), dim=0))
                node_attention_mask.append(torch.cat((
                    torch.ones(feat.shape[0]).to(feat.device),
                    torch.zeros(self.max_n_atoms - feat.shape[0]).to(feat.device)
                ), dim=0))
            else:
                node_feats.append(feat[:self.max_n_atoms, :])
                node_attention_mask.append(torch.ones(self.max_n_atoms).to(feat.device))
        node_feats = torch.stack(node_feats, dim=0)
        node_attention_mask = torch.stack(node_attention_mask, dim=0)
        return graph_embeds, node_feats, node_attention_mask

    def encode_mol(self, mol, proj=False):
        """Encode a molecule (optionally fused with a text prompt) into
        query-token embeddings; optionally project into the contrastive space."""
        if "text" in mol:
            s = mol["structure"]
            if "graph" in mol["structure"]:
                s = s["graph"]
            mol_embeds = self.forward(s, prompt=mol["text"])
        else:
            mol = mol["structure"]
            batch_size = torch.max(mol.batch).item() + 1
            _, node_embeds, node_attention_mask = self.get_graph_feats(mol, batch_size)
            query_embeds = self.query_tokens.expand(batch_size, -1, -1)
            attention_mask = torch.ones(query_embeds.shape[:-1], dtype=torch.long).to(query_embeds.device)
            mol_embeds = self.kformer.bert(
                query_embeds=query_embeds,
                attention_mask=attention_mask,
                encoder_hidden_states=node_embeds,
                encoder_attention_mask=node_attention_mask,
                return_dict=True
            ).last_hidden_state
        if proj:
            mol_embeds = F.normalize(self.structure_proj_head(mol_embeds), dim=-1)
        return mol_embeds

    def encode_text(self, text, return_cls=True, proj=False):
        """Encode tokenized text with the Q-Former BERT; optionally take the
        [CLS] position and/or project into the contrastive space."""
        text_embeds = self.kformer.bert(
            text["input_ids"],
            attention_mask=text["attention_mask"],
            return_dict=True,
        ).last_hidden_state
        if return_cls:
            text_embeds = text_embeds[:, 0, :]
        if proj:
            text_embeds = F.normalize(self.text_proj_head(text_embeds), dim=-1)
        return text_embeds

    def _fuse_for_decoder(self, mol):
        """Build the T5 decoder's encoder outputs: projected Q-Former molecule
        embeddings concatenated with the T5-encoded SMILES sequence. Returns
        (encoder_outputs, attention_mask). Shared by decode() and
        causal_generation_loss()."""
        h_graph = self.enc2dec(self.encode_mol(mol))
        h_smi = self.text_decoder.encoder(**mol["structure"]["SMILES"]).last_hidden_state
        h = torch.cat([h_graph, h_smi], dim=1)
        attention_mask = torch.ones(h_graph.shape[:-1], dtype=torch.long).to(h.device)
        attention_mask = torch.cat([attention_mask, mol["structure"]["SMILES"].attention_mask], dim=1)
        encoder_outputs = BaseModelOutput(
            last_hidden_state=h,
            hidden_states=None,
            attentions=None
        )
        return encoder_outputs, attention_mask

    def decode(self, mol, num_beams, max_length):
        """Generate token ids describing the molecule via beam search."""
        encoder_outputs, attention_mask = self._fuse_for_decoder(mol)
        outputs = self.text_decoder.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            num_beams=num_beams,
            max_length=max_length
        )
        return outputs

    def predict_similarity_score(self, mol, text):
        """Probability (from the MTM head) that `mol` and `text` match."""
        if "text" in mol:
            prompt = mol["text"]
            mol = mol["structure"]
        else:
            prompt = None
        preds = self.forward(mol, text, prompt=prompt)
        return F.softmax(self.mtm_head(preds).mean(dim=1), dim=-1)[:, 1]

    def causal_generation_loss(self, mol, text):
        """Teacher-forced T5 loss for generating `text` from the molecule;
        padding positions are masked out with -100."""
        labels = text["input_ids"].masked_fill(~text["attention_mask"].bool(), -100)
        encoder_outputs, attention_mask = self._fuse_for_decoder(mol)
        return self.text_decoder(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            decoder_attention_mask=text["attention_mask"],
            return_dict=True,
            labels=labels
        ).loss
import torch
import torch.nn as nn

from models import SUPPORTED_MOL_ENCODER


class MoleditModel(nn.Module):
    """Thin wrapper that instantiates a molecule encoder (molstm / molkformer /
    momu) from config, loads its checkpoint when required, and exposes a
    uniform `encode` for either tokenized text or graph input."""

    def __init__(self, config):
        super(MoleditModel, self).__init__()
        # BUGFIX: default every backbone flag to False so `encode` never hits
        # an undefined attribute — the original "smiles" branch only set
        # use_molstm, making encode() raise AttributeError for SMILES models.
        self.use_molkformer = False
        self.use_momu = False
        self.use_molstm = False
        if "smiles" in config:
            self.model = SUPPORTED_MOL_ENCODER[config["smiles"]["name"]](config["smiles"])
            self.use_molstm = config["smiles"]["name"] == "molstm"
        elif "graph" in config:
            self.model = SUPPORTED_MOL_ENCODER[config["graph"]["name"]](config["graph"])
            if config["graph"]["name"] == "molkformer":
                self.ckpt = torch.load(config["graph"]["init_checkpoint"], map_location="cpu")
                self.ckpt = self.ckpt["model"]
                self.model.load_state_dict(self.ckpt)

            if config["graph"]["name"] == "momu":
                self.ckpt = torch.load(config["graph"]["init_checkpoint"])
                if "param_key" in config["graph"]:
                    self.ckpt = self.ckpt[config["graph"]["param_key"]]
                self.model.load_state_dict(self.ckpt)

            self.use_molkformer = config["graph"]["name"] == "molkformer"
            self.use_momu = config["graph"]["name"] == "momu"
            self.use_molstm = config["graph"]["name"] == "molstm"

    def forward(self, mol):
        return self.encode(mol)

    def encode(self, mol):
        # Tokenized text input is identified by the presence of "input_ids".
        if "input_ids" in mol:
            return self.model.encode_text(mol, proj=True)
        # Graph input.
        if self.use_molkformer:
            # molkformer expects {"structure": <pyg batch>} and returns
            # per-query-token embeddings, which we mean-pool to one vector.
            graph_feats = self.model.encode_mol({"structure": mol}, proj=True)
            return graph_feats.mean(dim=1)
        # momu / molstm return a single pooled embedding directly.
        return self.model.encode_mol(mol, proj=True)
index 0000000..422648d --- /dev/null +++ b/open_biomed/tasks/mol_edit/moledit_step_01_Space_Alignment.py @@ -0,0 +1,340 @@ +import argparse +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +import numpy as np +from tqdm import tqdm +import time +import json +import re +import copy +import pickle + +import torch +import torch.nn as nn +from torch import optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader +from torch_geometric.loader import DataLoader as pyg_DataLoader + + +from utils.molstm_utils import get_molecule_repr_MoleculeSTM +from models.multimodal.moleculestm import MLP +from utils.molstm_utils import load_molecule_models +from utils.molstm_utils import freeze_network +from datasets.moledit_dataset import SUPPORTED_MOLEDIT_DATASET +from models.task_model.moledit_model import MoleditModel +from models.multimodal.mega_molbart.mega_mol_bart import MegaMolBART + +def cycle_index(num, shift): + arr = torch.arange(num) + shift + arr[-shift:] = torch.arange(shift) + return arr + + +def do_CL(X, Y, args): + if args.normalize: + X = F.normalize(X, dim=-1) + Y = F.normalize(Y, dim=-1) + + if args.SSL_loss == 'EBM_NCE': + criterion = nn.BCEWithLogitsLoss() + neg_Y = torch.cat([Y[cycle_index(len(Y), i + 1)] for i in range(args.CL_neg_samples)], dim=0) + neg_X = X.repeat((args.CL_neg_samples, 1)) + + pred_pos = torch.sum(X * Y, dim=1) / args.T + pred_neg = torch.sum(neg_X * neg_Y, dim=1) / args.T + + loss_pos = criterion(pred_pos, torch.ones(len(pred_pos)).to(pred_pos.device)) + loss_neg = criterion(pred_neg, torch.zeros(len(pred_neg)).to(pred_neg.device)) + SSL_loss = (loss_pos + args.CL_neg_samples * loss_neg) / (1 + args.CL_neg_samples) + + SSL_acc = (torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()) / \ + (len(pred_pos) + len(pred_neg)) + SSL_acc = SSL_acc.detach().cpu().item() + + elif args.SSL_loss == 'InfoNCE': + criterion = 
nn.CrossEntropyLoss() + B = X.size()[0] + logits = torch.mm(X, Y.transpose(1, 0)) # B*B + logits = torch.div(logits, args.T) + labels = torch.arange(B).long().to(logits.device) # B*1 + + SSL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + SSL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B + + elif args.SSL_loss == 'RR': + criterion = nn.MSELoss() + SSL_loss = criterion(X, Y) + SSL_acc = 0 + + else: + raise Exception + + return SSL_loss, SSL_acc + +def mean_pooling(token_embeddings, attention_mask): + attention_mask = ~attention_mask + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() # [pad, B, d] + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0) # [B, d] + sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9) # [B, d] + return sum_embeddings / sum_mask + + +def get_molecule_repr_generation(molecule_data, molecule_model, molecule_type="MegaMolBART", MegaMolBART_wrapper=None): + if molecule_type == "MegaMolBART": + embedding, pad_mask = MegaMolBART_wrapper.smileslist2embedding_model_given(molecule_model, molecule_data) # [pad, B, d], [pad, B] + molecule_repr = mean_pooling(embedding, pad_mask) + else: + molecule_repr, _ = molecule_model(molecule_data) + return molecule_repr + + +def save_model(save_best, epoch=None): + if args.output_path is not None: + if save_best: + global optimal_loss + print("save model with loss: {:.5f}".format(optimal_loss)) + model_file = "model.pth" + + elif epoch is None: + model_file = "model_final.pth" + + else: + model_file = "model_{}.pth".format(epoch) + + saved_file_path = os.path.join(args.output_path, "generation2MoleculeSTM_{}".format(model_file)) + torch.save(generation2MoleculeSTM.state_dict(), saved_file_path) + + saved_file_path = os.path.join(args.output_path, "MoleculeSTM2generation_{}".format(model_file)) + torch.save(MoleculeSTM2generation.state_dict(), saved_file_path) + return + + +def train(epoch): + if 
args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + + start_time = time.time() + accum_loss, accum_acc = 0, 0 + batch_num=0 + + for batch in L: + if args.MoleculeSTM_molecule_type == "SMILES": + SMILES_list = batch["structure"]["SMILES"] + else: + SMILES_list = batch["structure"]["SMILES"] + graph = batch["structure"]["graph"] + graph = graph.to(device) + + if args.use_static_files==1: + molecule_repr_MoleculeSTM = molecule_repr_MoleculeSTM_list[batch_num].to(device) + molecule_repr_MoleculeSTM2generation = MoleculeSTM2generation(molecule_repr_MoleculeSTM) + molecule_repr_generation = molecule_repr_generation_list[batch_num].to(device) + molecule_repr_generation2MoleculeSTM = generation2MoleculeSTM(molecule_repr_generation) + batch_num+=1 + else: + if args.MoleculeSTM_molecule_type == "SMILES": + molecule_repr_MoleculeSTM = get_molecule_repr_MoleculeSTM( + SMILES_list, molecule_model=molecule_model_MoleculeSTM, mol2latent=mol2latent_MoleculeSTM, + molecule_type=args.MoleculeSTM_molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper + ) + else: + molecule_repr_MoleculeSTM = get_molecule_repr_MoleculeSTM( + graph, molecule_model=molecule_model_MoleculeSTM, mol2latent=mol2latent_MoleculeSTM, + molecule_type=args.MoleculeSTM_molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper + ) + if args.generation_model == "MegaMolBART": + molecule_repr_generation = get_molecule_repr_generation( + SMILES_list, molecule_model=molecule_model_generation, + molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper + ) + molecule_repr_MoleculeSTM2generation = MoleculeSTM2generation(molecule_repr_MoleculeSTM) + molecule_repr_generation2MoleculeSTM = generation2MoleculeSTM(molecule_repr_generation) + + loss_01, acc_01 = do_CL(molecule_repr_generation, molecule_repr_MoleculeSTM2generation, args) + loss_02, acc_02 = do_CL(molecule_repr_MoleculeSTM, molecule_repr_generation2MoleculeSTM, args) + loss = (loss_01 + loss_02) / 2 + acc = (acc_01 + acc_02) / 2 + + 
optimizer.zero_grad() + loss.backward() + optimizer.step() + + accum_loss += loss.item() + accum_acc += acc + + accum_loss /= len(L) + accum_acc /= len(L) + + global optimal_loss + temp_loss = accum_loss + if temp_loss < optimal_loss: + optimal_loss = temp_loss + save_model(save_best=True, epoch=epoch) + print("SSL Loss: {:.5f}\tSSL Acc: {:.5f}\tTime: {:.5f}".format(accum_loss, accum_acc, time.time() - start_time)) + return + +def generate_static_files(): + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + molecule_repr_MoleculeSTM_list = [] + molecule_repr_generation_list = [] + for batch in L: + if args.MoleculeSTM_molecule_type == "SMILES": + SMILES_list = batch["structure"]["SMILES"] + else: + SMILES_list = batch["structure"]["SMILES"] + graph = batch["structure"]["graph"] + graph = graph.to(device) + if args.MoleculeSTM_molecule_type == "SMILES": + molecule_repr_MoleculeSTM = get_molecule_repr_MoleculeSTM( + SMILES_list, molecule_model=molecule_model_MoleculeSTM, mol2latent=mol2latent_MoleculeSTM, + molecule_type=args.MoleculeSTM_molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper + ) + else: + molecule_repr_MoleculeSTM = get_molecule_repr_MoleculeSTM( + graph, molecule_model=molecule_model_MoleculeSTM, mol2latent=mol2latent_MoleculeSTM, + molecule_type=args.MoleculeSTM_molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper + ) + molecule_repr_MoleculeSTM_list.append(molecule_repr_MoleculeSTM) + if args.generation_model == "MegaMolBART": + molecule_repr_generation = get_molecule_repr_generation( + SMILES_list, molecule_model=molecule_model_generation, + molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper + ) + molecule_repr_generation_list.append(molecule_repr_generation) + saved_file_path = os.path.join(args.static_files_path, "molecule_repr_MoleculeSTM_list.pkl") + with open(saved_file_path, 'wb') as f: + pickle.dump(molecule_repr_MoleculeSTM_list, f) + saved_file_path = os.path.join(args.static_files_path, 
"molecule_repr_generation_list.pkl") + with open(saved_file_path, 'wb') as f: + pickle.dump(molecule_repr_generation_list, f) + return + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=str, default="cuda:3") + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--dataset_path", type=str, default="./datasets/mol_edit/ZINC250K_data") + parser.add_argument("--static_files_path", type=str, default="./datasets/mol_edit/ZINC250K_data/static_files/molkformer-Graph") + parser.add_argument("--dataset", type=str, default="ZINC250K") + parser.add_argument("--MoleculeSTM_molecule_type", type=str, default="Graph", choices=["SMILES", "Graph"]) + parser.add_argument("--output_path", type=str, default="./ckpts/finetune_ckpts/moledit/molkformer/Graph") + parser.add_argument("--config_path", type=str, default="./configs/moledit/molkformer-Graph-MegaMolBART.json") + parser.add_argument("--mode", type=str, default="train") + ########## for MoleculeSTM ########## + parser.add_argument("--MoleculeSTM_model_dir", type=str, default=None) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for generation ########## + parser.add_argument('--generation_model', type=str, default="MegaMolBART", choices=["MegaMolBART"]) + + ######### for MegaMolBART ########## + parser.add_argument("--MegaMolBART_generation_model_dir", type=str, default="./ckpts/fusion_ckpts/pretrained_MegaMolBART/checkpoints") + parser.add_argument("--vocab_path", 
type=str, default="./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt") + + ########## for optimization ########## + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=5) + parser.add_argument("--decay", type=float, default=0) + parser.add_argument("--generation_lr", type=float, default=1e-2) + parser.add_argument("--MoleculeSTM_lr", type=float, default=1e-2) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE", "RR"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument('--use_normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + parser.add_argument("--MASTER_PORT", type=str, default='6000') + parser.add_argument("--use_processed_dataset_250K", type=int, default=0) + parser.add_argument("--generate_static_files", type=int, default=0) + parser.add_argument("--use_static_files", type=int, default=0) + + args = parser.parse_args() + print(args) + + config = json.load(open(args.config_path)) + os.environ['MASTER_PORT'] = args.MASTER_PORT + + # load dataset + if args.use_processed_dataset_250K==1: # skip SUPPORTED_MOLEDIT_DATASET + with open("./datasets/mol_edit/ZINC250K_data/dataset_zinc250K.pkl", "rb") as f: + dataset = pickle.load(f) + else: + dataset = SUPPORTED_MOLEDIT_DATASET[args.dataset](args.dataset_path, config["data"]["mol"], split="train") + + dataloader_class = pyg_DataLoader + + device = torch.device(args.device) \ + if torch.cuda.is_available() else torch.device("cpu") + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + # load model + molecule_model_MoleculeSTM = MoleditModel(config["network"]) + 
mol2latent_MoleculeSTM = None + if config["model"]== "molstm-MegaMolBART": + MegaMolBART_wrapper = molecule_model_MoleculeSTM.model.MegaMolBART_wrapper + molecule_model_generation = copy.deepcopy(MegaMolBART_wrapper.model) + else: + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.MegaMolBART_generation_model_dir, output_dir=None) + molecule_model_generation = copy.deepcopy(MegaMolBART_wrapper.model) + + torch.cuda.set_device(int(re.search(r'\d+', args.device).group())) + + molecule_model_generation = molecule_model_generation.to(device) + molecule_model_MoleculeSTM = molecule_model_MoleculeSTM.to(device) + freeze_network(molecule_model_generation) + freeze_network(molecule_model_MoleculeSTM) + molecule_model_generation.eval() + molecule_model_MoleculeSTM.eval() + + + dataloader = dataloader_class(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + + + molecule_dim_generation = 256 + molecule_dim_MoleculeSTM = args.SSL_emb_dim + generation2MoleculeSTM = MLP(molecule_dim_generation, [molecule_dim_MoleculeSTM, molecule_dim_MoleculeSTM]).to(device) + MoleculeSTM2generation = MLP(molecule_dim_MoleculeSTM, [molecule_dim_generation, molecule_dim_generation]).to(device) + + model_param_group = [ + {"params": generation2MoleculeSTM.parameters(), "lr": args.generation_lr}, + {"params": MoleculeSTM2generation.parameters(), "lr": args.MoleculeSTM_lr}, + ] + optimizer = optim.Adam(model_param_group, weight_decay=args.decay) + optimal_loss = 1e10 + + if args.generate_static_files==1: + generate_static_files() + + if args.use_static_files==1: + saved_file_path = os.path.join(args.static_files_path, "molecule_repr_MoleculeSTM_list.pkl") + with open(saved_file_path, 'rb') as f: + molecule_repr_MoleculeSTM_list = pickle.load(f) + saved_file_path = os.path.join(args.static_files_path, "molecule_repr_generation_list.pkl") + with open(saved_file_path, 'rb') as f: + molecule_repr_generation_list = pickle.load(f) + + + for e in 
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Cosine ramp-down / linear ramp-up learning-rate schedule.

    `t` is the normalized optimization progress in [0, 1]: the rate rises
    linearly over the first `rampup` fraction and decays with a half-cosine
    over the final `rampdown` fraction; `initial_lr` scales the result.
    """
    down = min(1, (1 - t) / rampdown)
    shaped = 0.5 - 0.5 * math.cos(down * math.pi)
    up = min(1, t / rampup)
    return initial_lr * (shaped * up)


def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings over the sequence axis, ignoring padding.

    token_embeddings: [pad, B, d]; attention_mask: [pad, B] where True marks
    padding positions (the mask is inverted before pooling, per the original
    inline shape comments).
    """
    keep = (~attention_mask).unsqueeze(-1).expand(token_embeddings.size()).float()  # [pad, B, d]
    total = (token_embeddings * keep).sum(0)          # [B, d]
    denom = keep.sum(0).clamp(min=1e-9)               # [B, d]
    return total / denom
MegaMolBART_wrapper.smileslist2embedding([SMILES]) # [pad, B, d], [pad, B] + first_and_second_SMILES_list.append(SMILES) + + regenerated_mols = MegaMolBART_wrapper.inverse_transform([latent_code_init], pad_mask_init.bool().cuda(), k=1, sanitize=True) + first_and_second_SMILES_list.append(regenerated_mols[0]) + + l2_lambda_list = [ + 1e1, 1e0, 1e-1, 1e-2, 1e-3 + ] + result_SMILES_list_one_pair, result_eval_list_one_pair = [], [] + + if args.use_noise_for_init: + print("Use random noise for init") + random_noise = torch.randn(latent_code_init.size()).to(device) + + for l2_lambda in l2_lambda_list: + print("l2 lambda: {}".format(l2_lambda)) + current_SMILES_list = [first_and_second_SMILES_list[0]] + [first_and_second_SMILES_list[1]] + if args.use_noise_for_init: + print("Use random noise for init") + latent = latent_code_init.detach().clone() + random_noise + else: + print("No random noise for init") + latent = latent_code_init.detach().clone() + pad_mask = pad_mask_init.detach().clone() + latent.requires_grad = True + optimizer = optim.Adam([latent], lr=args.lr) + + if args.verbose: + L = tqdm(range(args.epochs)) + else: + L = range(args.epochs) + + for i in L: + t = i / args.epochs + lr = get_lr(t, args.lr) + optimizer.param_groups[0]["lr"] = lr + + molecule_repr_generation = mean_pooling(latent, pad_mask) # [B, d] + if args.normalize: + molecule_repr_generation = F.normalize(molecule_repr_generation, dim=-1) + molecule_repr_MoleculeSTM = generation2MoleculeSTM(molecule_repr_generation) + + clip_loss_ = clip_loss_for_edit(molecule_repr_MoleculeSTM, text_repr) + l2_loss_ = l2_lambda * ((latent_code_init - latent) ** 2).mean() + + loss = clip_loss_ + l2_loss_ + + optimizer.zero_grad() + loss.backward(retain_graph=True) + optimizer.step() + print("clip loss: {:.5f}\tL2 loss: {:.5f}".format(clip_loss_.item(), l2_loss_.item())) + + generated_mols = MegaMolBART_wrapper.inverse_transform([latent], pad_mask.bool().cuda(), k=1, sanitize=True) + 
current_SMILES_list.append(generated_mols[0]) + result_SMILES_list_one_pair.append([text] + current_SMILES_list + ['{}'.format(l2_lambda)]) + + current_result_list = evaluate_SMILES_list(current_SMILES_list, text) + result_eval_list_one_pair.append(current_result_list) + print() + + result_eval_list_one_pair = np.array(result_eval_list_one_pair) + result_eval_list_one_pair = np.any(result_eval_list_one_pair, axis=0, keepdims=True) + print("result_eval_list_one_pair\n", result_eval_list_one_pair) + return result_SMILES_list_one_pair, result_eval_list_one_pair + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=str, default=None) + parser.add_argument("--verbose", type=int, default=1) + + ########## for editing ########## + parser.add_argument("--input_description", type=str, default=None) + parser.add_argument("--input_description_id", type=int, default=101) + parser.add_argument("--input_SMILES", type=str, default=None) + parser.add_argument("--input_SMILES_file", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + parser.add_argument("--use_noise_for_init", dest="use_noise_for_init", action="store_true") + parser.add_argument("--no_noise_for_init", dest="use_noise_for_init", action="store_false") + parser.set_defaults(use_noise_for_init=False) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + parser.add_argument("--dataset_path", type=str, default=None) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--max_seq_len", type=int, default=512) + parser.add_argument("--config_path", type=str, default=None) + ########## for MoleculeSTM ########## + parser.add_argument("--MoleculeSTM_molecule_type", type=str, default=None, choices=["SMILES", 
"Graph"]) + + ########## for MegaMolBART ########## + parser.add_argument("--MegaMolBART_generation_model_dir", type=str, default=None) + parser.add_argument("--vocab_path", type=str, default=None) + parser.add_argument("--text_mode", type=str, default=None) + + ########## for MoleculeSTM and generation projection ########## + parser.add_argument("--language_edit_model_dir", type=str, default=None) + ########## for editing ########## + parser.add_argument("--lr_rampup", type=float, default=0.05) + parser.add_argument("--lr", type=float, default=0.1) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--MASTER_PORT", type=str, default='6000') + args = parser.parse_args() + + print(args) + + config = json.load(open(args.config_path)) + os.environ['MASTER_PORT'] = args.MASTER_PORT + device = torch.device(args.device) \ + if torch.cuda.is_available() else torch.device("cpu") + + + # load model + text_model = MoleditModel(config["network"]) + text_tokenizer = BertTokenizer.from_pretrained(args.text_mode, model_max_length=512, cache_dir=args.text_mode) + if config["model"]== "molstm-MegaMolBART": + MegaMolBART_wrapper = text_model.model.MegaMolBART_wrapper + molecule_model = MegaMolBART_wrapper.model + else: + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.MegaMolBART_generation_model_dir, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + + torch.cuda.set_device(int(re.search(r'\d+', args.device).group())) + + generation2MoleculeSTM = MLP(256, [args.SSL_emb_dim, args.SSL_emb_dim]) + input_model_path = os.path.join(args.language_edit_model_dir, "generation2MoleculeSTM_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + generation2MoleculeSTM.load_state_dict(state_dict) + + + MoleculeSTM2generation = MLP(args.SSL_emb_dim, [256, 256]) + input_model_path = os.path.join(args.language_edit_model_dir, 
"MoleculeSTM2generation_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + MoleculeSTM2generation.load_state_dict(state_dict) + + text_model = text_model.to(device) + molecule_model = molecule_model.to(device) + generation2MoleculeSTM.to(device) + MoleculeSTM2generation.to(device) + text_model.eval() + molecule_model.eval() + generation2MoleculeSTM.eval() + MoleculeSTM2generation.eval() + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device(args.device) \ + if torch.cuda.is_available() else torch.device("cpu") + + print("\n\n\nstart editing\n\n\n") + + source_SMILES_list = get_SMILES_list(args)[0:1] + description_list = get_description_list(args) + + + + for description in description_list: + print("===== for description {} =====".format(description)) + result_SMILES_list, result_acc_list = [], [] + for SMILES in source_SMILES_list: + print("===== for SMILES {} =====".format(SMILES)) + result_SMILES_list_, result_acc_list_ = check_edit(SMILES, description) + result_SMILES_list.extend(result_SMILES_list_) + result_acc_list.append(result_acc_list_) + print("\n\n\n") + + result_acc_list = np.concatenate(result_acc_list, axis=0) + result_acc_list = np.sum(result_acc_list, axis=0) + result_acc_list = 100. 
def get_biot5_tokenizer(path_t5, path_selfies):
    """Build the BioT5 tokenizer: a base T5 tokenizer extended with prefixed
    amino-acid tokens, the SELFIES dictionary, and task special tokens.

    NOTE(review): the angle-bracket token strings in the checked-in copy were
    stripped to empty strings (likely by HTML-like sanitization); they are
    reconstructed here from the upstream BioT5 tokenizer convention — confirm
    against the BioT5 repository before relying on them.

    Fix: the selfies dictionary file is now closed via a context manager
    (it was previously opened and never closed).
    """
    from transformers import T5Tokenizer
    tokenizer = T5Tokenizer.from_pretrained(path_t5)
    tokenizer.model_max_length = int(1e9)

    amino_acids = [
        "A", "C", "D", "E", "F",
        "G", "H", "I", "K", "L",
        "M", "N", "P", "Q", "R",
        "S", "T", "V", "W", "Y"
    ]
    # BioT5 marks amino-acid tokens with a <p> prefix to keep them distinct
    # from plain text tokens.
    prefixed_amino_acids = [f"<p>{aa}" for aa in amino_acids]
    tokenizer.add_tokens(prefixed_amino_acids)

    with open(path_selfies, "r") as fh:
        selfies_dict_list = [line.strip() for line in fh]
    tokenizer.add_tokens(selfies_dict_list)

    special_tokens_dict = {'additional_special_tokens':
                           ['<bom>', '<eom>',
                            '<bop>', '<eop>',
                            'MOLECULE NAME', 'DESCRIPTION',
                            'PROTEIN NAME', 'FUNCTION', 'SUBCELLULAR LOCATION', 'PROTEIN FAMILIES']}
    tokenizer.add_special_tokens(special_tokens_dict, replace_additional_special_tokens=False)
    return tokenizer
{aa}" for aa in amino_acids] + tokenizer.add_tokens(prefixed_amino_acids) + + selfies_dict_list = [line.strip() for line in open(path_selfies, "r")] + tokenizer.add_tokens(selfies_dict_list) + + special_tokens_dict = {'additional_special_tokens': + ['', '', + '', '', + 'MOLECULE NAME', 'DESCRIPTION', + 'PROTEIN NAME', 'FUNCTION', 'SUBCELLULAR LOCATION', 'PROTEIN FAMILIES']} + tokenizer.add_special_tokens(special_tokens_dict, replace_additional_special_tokens=False) + return tokenizer \ No newline at end of file diff --git a/open_biomed/utils/molstm_utils.py b/open_biomed/utils/molstm_utils.py new file mode 100644 index 0000000..a0ede4c --- /dev/null +++ b/open_biomed/utils/molstm_utils.py @@ -0,0 +1,556 @@ +import os +import copy +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoModel, AutoTokenizer +from models.multimodal.mega_molbart.mega_mol_bart import MegaMolBART +from models.multimodal.moleculestm import MLP +from models.multimodal.moleculestm import GNN, GNN_graphpred +from rdkit import Chem, RDLogger +from rdkit.Chem import AllChem, Descriptors +from rdkit import DataStructs +lg = RDLogger.logger() +lg.setLevel(RDLogger.CRITICAL) + + + + +# This is for BERT +def padarray(A, size, value=0): + t = size - len(A) + return np.pad(A, pad_width=(0, t), mode='constant', constant_values = value) + + +# This is for BERT +def preprocess_each_sentence(sentence, tokenizer, max_seq_len): + text_input = tokenizer( + sentence, truncation=True, max_length=max_seq_len, + padding='max_length', return_tensors='np') + + input_ids = text_input['input_ids'].squeeze() + attention_mask = text_input['attention_mask'].squeeze() + + sentence_tokens_ids = padarray(input_ids, max_seq_len) + sentence_masks = padarray(attention_mask, max_seq_len) + return [sentence_tokens_ids, sentence_masks] + + +# This is for BERT +def prepare_text_tokens(device, description, tokenizer, max_seq_len): + B = len(description) + 
def get_molecule_repr_MoleculeSTM(molecule_data, mol2latent=None, molecule_type="SMILES", MegaMolBART_wrapper=None, molecule_model=None):
    """Encode a batch of molecules with the MoleculeSTM encoder.

    SMILES input goes through MegaMolBART_wrapper.smileslist2embedding and the
    first token's embedding is taken as the molecule summary; any other
    molecule_type is passed directly to molecule_model. When mol2latent is
    provided, it projects the representation into the joint latent space.
    """
    if molecule_type == "SMILES":
        embedding, pad_mask = MegaMolBART_wrapper.smileslist2embedding(molecule_data)  # [pad, B, d], [pad, B]
        molecule_repr = embedding[0, :, :]  # first-token summary -> [B, d]
    else:
        molecule_repr = molecule_model(molecule_data)

    if mol2latent is not None:
        molecule_repr = mol2latent(molecule_repr)
    return molecule_repr


def freeze_network(model):
    """Disable gradient tracking for every parameter of `model` (in place)."""
    for param in model.parameters():
        param.requires_grad = False
    return


def get_SMILES_list(args):
    """Return the SMILES to edit: the single --input_SMILES if given,
    otherwise every non-empty line of --input_SMILES_file.

    Fix: the input file is now closed via a context manager (it was
    previously opened and never closed).
    """
    if args.input_SMILES is not None:
        return [args.input_SMILES]
    SMILES_list = []
    with open(args.input_SMILES_file, 'r') as f:
        for line in f:
            SMILES = line.strip()
            if len(SMILES) > 0:
                SMILES_list.append(SMILES)
    return SMILES_list
def get_description_list(args):
    """Resolve the editing prompt: an explicit --input_description wins,
    otherwise --input_description_id is looked up in description_dict.

    Raises:
        ValueError: when neither --input_description nor
            --input_description_id is provided.
    """
    if args.input_description is not None:
        description_list = [args.input_description]
    elif args.input_description_id is None:
        raise ValueError
    else:
        # Fix: corrected typo in the log message ("descrition" -> "description").
        print("Use {} description.".format(args.input_description_id))
        description_list = [description_dict[args.input_description_id]]
    print("description_list", description_list)
    return description_list
+ """ + if args.MoleculeSTM_molecule_type == "SMILES": + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.MegaMolBART_generation_model_dir, output_dir=None) + molecule_model_generation = copy.deepcopy(MegaMolBART_wrapper.model) + print("Loading from pretrained MegaMolBART ({}).".format(args.MegaMolBART_generation_model_dir)) + molecule_dim_generation = 256 + + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "molecule_model.pth") + molecule_model_MoleculeSTM = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + print("Loading from {}...".format(input_model_path)) + molecule_model_MoleculeSTM.load_state_dict(state_dict) + molecule_dim_MoleculeSTM = args.SSL_emb_dim + + mol2latent_MoleculeSTM = nn.Linear(256, molecule_dim_MoleculeSTM) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent_MoleculeSTM.load_state_dict(state_dict) + + else: + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.MegaMolBART_generation_model_dir, output_dir=None) + molecule_model_generation = copy.deepcopy(MegaMolBART_wrapper.model) + print("Loading from pretrained MegaMolBART ({}).".format(args.MegaMolBART_generation_model_dir)) + molecule_dim_generation = 256 + + # This is loading GNN from the pretrained_GNN + molecule_node_model = GNN(num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type) + molecule_model_MoleculeSTM = GNN_graphpred(num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling, num_tasks=1, molecule_node_model=molecule_node_model) + print("Start from pretrained model (MoleculeSTM) in 
{}.".format(args.MoleculeSTM_model_dir)) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "molecule_model.pth") + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model_MoleculeSTM.load_state_dict(state_dict) + molecule_dim_MoleculeSTM = args.SSL_emb_dim + + mol2latent_MoleculeSTM = nn.Linear(300, molecule_dim_MoleculeSTM) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent_MoleculeSTM.load_state_dict(state_dict) + + return MegaMolBART_wrapper, molecule_model_generation, molecule_dim_generation, \ + molecule_model_MoleculeSTM, mol2latent_MoleculeSTM, molecule_dim_MoleculeSTM + + +def load_language_molecule_and_edit_models(args): + pretrained_SciBERT_folder = os.path.join(args.dataset_path, 'pretrained_SciBERT') + text_tokenizer = AutoTokenizer.from_pretrained(args.text_mode, cache_dir=pretrained_SciBERT_folder) + text_model = AutoModel.from_pretrained(args.text_mode, cache_dir=pretrained_SciBERT_folder) + + text_dim = 768 + + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "text_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text_model.load_state_dict(state_dict) + + """ + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + MegaMolBART_wrapper = MegaMolBART(input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + """ + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.MegaMolBART_generation_model_dir, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + 
print("Loading from pretrained MegaMolBART ({}).".format(args.MegaMolBART_generation_model_dir)) + molecule_dim_generation = 256 + if args.MoleculeSTM_molecule_type == "SMILES": # For MegaMolBART + molecule_dim_MoleculeSTM = 256 + else: # For GIN + molecule_dim_MoleculeSTM = 300 + + text2latent = nn.Linear(text_dim, args.SSL_emb_dim) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "text2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text2latent.load_state_dict(state_dict) + + mol2latent = nn.Linear(molecule_dim_MoleculeSTM, args.SSL_emb_dim) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent.load_state_dict(state_dict) + + # generation2MoleculeSTM = nn.Linear(molecule_dim_generation, args.SSL_emb_dim) + generation2MoleculeSTM = MLP(molecule_dim_generation, [args.SSL_emb_dim, args.SSL_emb_dim]) + input_model_path = os.path.join(args.language_edit_model_dir, "generation2foundation_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + generation2MoleculeSTM.load_state_dict(state_dict) + + # MoleculeSTM2generation = nn.Linear(args.SSL_emb_dim, molecule_dim_generation) + MoleculeSTM2generation = MLP(args.SSL_emb_dim, [molecule_dim_generation, molecule_dim_generation]) + input_model_path = os.path.join(args.language_edit_model_dir, "foundation2generation_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + MoleculeSTM2generation.load_state_dict(state_dict) + + return text_model, text_tokenizer, text_dim, molecule_model, MegaMolBART_wrapper, molecule_dim_generation, text2latent, mol2latent, generation2MoleculeSTM, MoleculeSTM2generation + + +def 
clip_loss_for_edit(molecule_repr, text_repr): + molecule_repr = F.normalize(molecule_repr, dim=-1) + text_repr = F.normalize(text_repr, dim=-1) + + similarity = -torch.mm(molecule_repr, text_repr.transpose(0, 1))[0] + return similarity + + +def get_molecule_similarity(mol_a, mol_b): + fp_a = AllChem.GetMorganFingerprintAsBitVect(mol_a, 2, nBits=1024) + fp_b = AllChem.GetMorganFingerprintAsBitVect(mol_b, 2, nBits=1024) + sim = DataStructs.TanimotoSimilarity(fp_a, fp_b) + return sim + + +def evaluate_SMILES_list(SMILES_list, description): + print("SMILES_list:", SMILES_list) + mol_list = [] + for SMILES in SMILES_list: + mol = Chem.MolFromSmiles(SMILES) + # Chem.SanitizeMol(mol) + # print(SMILES, mol) + if mol is None: + continue + mol_list.append(mol) + print("valid mol list:", len(mol_list)) + + if len(mol_list) < 3: + return [False] + + if "soluble" in description and "insoluble" not in description: + props = ["MolLogP"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] > value_list[2]: + answer = [True] + else: + answer = [False] + + elif "insoluble" in description: + props = ["MolLogP"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule is more like a drug.", "This molecule is like a drug."]: + props = ["qed"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in 
zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule is less like a drug.", "This molecule is not like a drug."]: + props = ["qed"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] > value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has higher permeability.", "This molecule has high permeability."]: + props = ["TPSA"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] > value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has lower permeability.", "This molecule has low permeability."]: + props = ["TPSA"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has higher molecular weight.", "This molecule has high molecular weight."]: + props = ["MolWt"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = 
func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has lower molecular weight.", "This molecule has low molecular weight."]: + props = ["MolWt"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] > value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has more hydrogen bond acceptors."]: + props = ["NumHAcceptors"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has more hydrogen bond donors."]: + props = ["NumHDonors"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif "penicillin" in description or "Penicillin" in description: + target_mol = Chem.MolFromSmiles(Penicillin_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between penicillin and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = 
SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between penicillin and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: + answer = [True] + else: + answer = [False] + + elif "aspirin" in description or "Aspirin" in description: + target_mol = Chem.MolFromSmiles(Aspirin_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between aspirin and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between aspirin and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: # check original_similarity >< 0.8 + answer = [True] + else: + answer = [False] + + elif "caffeine" in description or "Caffeine" in description: + target_mol = Chem.MolFromSmiles(Caffeine_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between caffeine and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between caffeine and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: + answer = [True] + else: + answer = [False] + + elif "cholesterol" in description or "Cholesterol" in description: + target_mol = Chem.MolFromSmiles(Cholesterol_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = 
get_molecule_similarity(target_mol, original_mol) + print("similarity between cholesterol and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between cholesterol and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: # check original_similarity >< 0.8 + answer = [True] + else: + answer = [False] + + elif "dopamine" in description or "Dopamine" in description: + target_mol = Chem.MolFromSmiles(Dopamine_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between dopamine and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between dopamine and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: + answer = [True] + else: + answer = [False] + + elif "cysteine" in description or "Cysteine" in description: + target_mol = Chem.MolFromSmiles(Cysteine_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between cysteine and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between cysteine and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: # check original_similarity >< 0.8 + answer = [True] + else: + answer = 
[False] + + elif "glutathione" in description or "Glutathione" in description: + target_mol = Chem.MolFromSmiles(Glutathione_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between glutathione and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between glutathione and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: # check original_similarity >< 0.8 + answer = [True] + else: + answer = [False] + + else: + print("Not implemented.") + answer = [False] + + return answer \ No newline at end of file diff --git a/scripts/multimodal/moledit/edit.sh b/scripts/multimodal/moledit/edit.sh new file mode 100644 index 0000000..0e063ba --- /dev/null +++ b/scripts/multimodal/moledit/edit.sh @@ -0,0 +1,21 @@ +#!/bin/bash molkformer--Graph momu--Graph molstm--SMILES/Graph ID---./models/MoleculeSTM/downstream_molecule_edit_utils.py +MODE="test" +MODEL="molkformer" +DEVICE=$1 +EPOCHS=100 +TYPE="Graph" +ID=101 + +python open_biomed/tasks/mol_edit/moledit_step_02_Latent_Optimization.py \ +--device ${DEVICE} \ +--config_path ./configs/moledit/${MODEL}-${TYPE}-MegaMolBART.json \ +--input_SMILES_file ./datasets/mol_edit/Editing_data/single_multi_property_SMILES.txt \ +--language_edit_model_dir ./ckpts/finetune_ckpts/moledit/${MODEL}/${TYPE} \ +--output_model_dir ./open_biomed/tasks/mol_edit \ +--text_mode ./ckpts/text_ckpts/scibert_scivocab_uncased \ +--epochs ${EPOCHS} \ +--input_description_id ${ID} \ +--MoleculeSTM_molecule_type ${TYPE} \ +--vocab_path ./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt \ +--MegaMolBART_generation_model_dir ./ckpts/fusion_ckpts/pretrained_MegaMolBART/checkpoints \ +--MASTER_PORT 
'6000' \ No newline at end of file diff --git a/scripts/multimodal/moledit/train.sh b/scripts/multimodal/moledit/train.sh new file mode 100755 index 0000000..8d54106 --- /dev/null +++ b/scripts/multimodal/moledit/train.sh @@ -0,0 +1,27 @@ +#!/bin/bash molkformer--Graph momu--Graph molstm--SMILES/Graph +MODE="train" +MODEL="molkformer" +DEVICE=$1 +EPOCHS=100 +TYPE="Graph" + +mkdir ./ckpts/finetune_ckpts/moledit/${MODEL} + +python open_biomed/tasks/mol_edit/moledit_step_01_Space_Alignment.py \ +--device ${DEVICE} \ +--MoleculeSTM_molecule_type ${TYPE} \ +--config_path ./configs/moledit/${MODEL}-${TYPE}-MegaMolBART.json \ +--dataset ZINC250K \ +--dataset_path ./datasets/mol_edit/ZINC250K_data \ +--output_path ./ckpts/finetune_ckpts/moledit/${MODEL}/${TYPE} \ +--static_files_path ./datasets/mol_edit/ZINC250K_data/static_files/${MODEL}-${TYPE} \ +--mode ${MODE} \ +--epochs ${EPOCHS} \ +--num_workers 8 \ +--batch_size 256 \ +--vocab_path ./ckpts/fusion_ckpts/pretrained_MegaMolBART/bart_vocab.txt \ +--MegaMolBART_generation_model_dir ./ckpts/fusion_ckpts/pretrained_MegaMolBART/checkpoints \ +--MASTER_PORT '6000' \ +--use_processed_dataset_250K 0 \ +--generate_static_files 0 \ +--use_static_files 0 \ No newline at end of file