120 changes: 120 additions & 0 deletions .gitignore
@@ -0,0 +1,120 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# PyCharm files
.idea/
*.iws

# documentation
_build/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
/conda_recipe/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other info into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints
*.ipynb

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# mac os
.DS_Store
135 changes: 82 additions & 53 deletions NGF/layers.py
@@ -7,35 +7,38 @@
from copy import deepcopy

from keras import layers
from keras.utils.layer_utils import layer_from_config
import theano.tensor as T
# from keras.utils.layer_utils import layer_from_config # keras 1.2
from keras.layers import deserialize as layer_from_config # keras 2.*

import tensorflow as tf
# import theano.tensor as T
import keras.backend as K

from .utils import filter_func_args, mol_shapes_to_dims

def temporal_padding(x, paddings=(1, 0), padvalue=0):
'''Pad the middle dimension of a 3D tensor
with `padding[0]` values left and `padding[1]` values right.

Modified from keras.backend.temporal_padding
https://github.com/fchollet/keras/blob/3bf913d/keras/backend/theano_backend.py#L590

TODO: Implement for tensorflow (supposedly easier)
'''
if not isinstance(paddings, (tuple, list, ndarray)):
paddings = (paddings, paddings)

input_shape = x.shape
output_shape = (input_shape[0],
input_shape[1] + sum(paddings),
input_shape[2])
output = T.zeros(output_shape)

# Set pad value and set subtensor of actual tensor
output = T.set_subtensor(output[:, :paddings[0], :], padvalue)
output = T.set_subtensor(output[:, paddings[1]:, :], padvalue)
output = T.set_subtensor(output[:, paddings[0]:x.shape[1] + paddings[0], :], x)
return output
# def temporal_padding(x, paddings=(1, 0), padvalue=0):
# '''Pad the middle dimension of a 3D tensor
# with `padding[0]` values left and `padding[1]` values right.
#
# Modified from keras.backend.temporal_padding
# https://github.com/fchollet/keras/blob/3bf913d/keras/backend/theano_backend.py#L590
#
# TODO: Implement for tensorflow (supposedly easier)
# '''
# if not isinstance(paddings, (tuple, list, ndarray)):
# paddings = (paddings, paddings)
#
# input_shape = x.shape
# output_shape = (input_shape[0],
# input_shape[1] + sum(paddings),
# input_shape[2])
# output = T.zeros(output_shape)
#
# # Set pad value and set subtensor of actual tensor
# output = T.set_subtensor(output[:, :paddings[0], :], padvalue)
# output = T.set_subtensor(output[:, paddings[1]:, :], padvalue)
# output = T.set_subtensor(output[:, paddings[0]:x.shape[1] + paddings[0], :], x)
# return output

def neighbour_lookup(atoms, edges, maskvalue=0, include_self=False):
''' Looks up the features of all atoms' neighbours, for a batch of molecules.
@@ -64,8 +67,11 @@ def neighbour_lookup(atoms, edges, maskvalue=0, include_self=False):
masked_edges = edges + 1
# We then add a padding vector at index 0 by padding to the left of the
# lookup matrix with the value that the new mask should get
masked_atoms = temporal_padding(atoms, (1,0), padvalue=maskvalue)

# Theano padding:
# masked_atoms = temporal_padding(atoms, (1,0), padvalue=maskvalue)
# Tensorflow padding:
paddings = tf.constant([[0, 0], [1, 0], [0, 0]])
masked_atoms = tf.pad(atoms, paddings, "CONSTANT")
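
Note: `tf.pad` with mode `"CONSTANT"` pads with zeros unless told otherwise, so the call above effectively fixes `maskvalue` at 0. A minimal sketch of a TF replacement that keeps the `padvalue` argument, assuming a TensorFlow version where `tf.pad` accepts `constant_values`:

```python
import tensorflow as tf

def temporal_padding_tf(x, paddings=(1, 0), padvalue=0):
    # Pad the middle (atom) axis of a (batch, atoms, features) tensor
    # with paddings[0] rows on the left and paddings[1] on the right.
    pad_spec = [[0, 0], [paddings[0], paddings[1]], [0, 0]]
    return tf.pad(x, pad_spec, mode="CONSTANT", constant_values=padvalue)
```

For the default `maskvalue=0` used throughout this file, the inline `tf.pad` call above is equivalent.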

# Import dimensions
atoms_shape = K.shape(masked_atoms)
@@ -79,23 +85,34 @@

# create broadcastable offset
offset_shape = (batch_n, 1, 1)
offset = K.reshape(T.arange(batch_n, dtype=K.dtype(masked_edges)), offset_shape)
offset *= lookup_size
# Theano arange and cast:
# offset = K.reshape(T.arange(batch_n, dtype=K.dtype(masked_edges)), offset_shape)
# offset *= lookup_size
# Tensorflow range and backend cast:
offset = K.reshape(K.cast(tf.range(batch_n, dtype='int32'), dtype=K.dtype(masked_edges)), offset_shape)
offset *= K.cast(lookup_size, dtype=K.dtype(offset))
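
The offset turns batch-local neighbour indices into global row indices for the flatten-and-gather below: atom `i` of molecule `b` ends up at row `b * lookup_size + i`. A toy NumPy check, with hypothetical sizes:

```python
import numpy as np

batch_n, lookup_size, n_feat = 2, 4, 1               # hypothetical sizes
atoms = np.arange(8).reshape(batch_n, lookup_size, n_feat)
edges = np.array([[[1], [2]], [[3], [0]]])           # per-atom neighbour indices
offset = np.arange(batch_n).reshape(batch_n, 1, 1) * lookup_size
flat_atoms = atoms.reshape(-1, n_feat)               # (8, 1)
flat_edges = (edges + offset).reshape(batch_n, -1)   # [[1, 2], [7, 4]]
gathered = flat_atoms[flat_edges]                    # rows from the correct molecule
```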


# apply offset to account for the fact that after reshape, all individual
# batch_n indices will be combined into a single big index
flattened_atoms = K.reshape(masked_atoms, (-1, num_atom_features))
flattened_edges = K.reshape(masked_edges + offset, (batch_n, -1))

# Gather flattened
flattened_result = K.gather(flattened_atoms, flattened_edges)
# flattened_result = K.gather(flattened_atoms, flattened_edges)
# Tensorflow/backend cast:
flattened_result = K.gather(flattened_atoms, K.cast(flattened_edges, dtype='int32'))

# Unflatten result
output_shape = (batch_n, max_atoms, max_degree, num_atom_features)
output = T.reshape(flattened_result, output_shape)
#output = T.reshape(flattened_result, output_shape)
# Tensorflow/backend reshape:
output = K.reshape(flattened_result, output_shape)

if include_self:
return K.concatenate([K.expand_dims(atoms, dim=2), output], axis=2)
# return K.concatenate([K.expand_dims(atoms, dim=2), output], axis=2)
# Tensorflow concat:
return tf.concat([K.expand_dims(atoms, axis=2), output], axis=2) # Keras 2: replaced dim with axis
return output
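
A hypothetical usage sketch of the `neighbour_lookup` defined above, assuming the TF backend (under TF 1.x graph mode the result would still need a session to evaluate):

```python
import numpy as np
import keras.backend as K

# atoms: (batch, max_atoms, num_features)
# edges: (batch, max_atoms, max_degree), padded with -1 for missing bonds
atoms = K.constant(np.random.rand(1, 3, 2))
edges = K.constant([[[1, 2], [0, -1], [0, -1]]])
out = neighbour_lookup(atoms, edges, include_self=True)
# out shape: (1, 3, 3, 2) -> (batch, max_atoms, max_degree + 1, features)
```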

class NeuralGraphHidden(layers.Layer):
@@ -166,14 +183,15 @@ def __init__(self, inner_layer_arg, **kwargs):
# Case 1: Check if inner_layer_arg is conv_width
if isinstance(inner_layer_arg, (int, long)):
self.conv_width = inner_layer_arg
dense_layer_kwargs, kwargs = filter_func_args(layers.Dense.__init__,
kwargs, overrule_args=['name'])
self.create_inner_layer_fn = lambda: layers.Dense(self.conv_width, **dense_layer_kwargs)
# Keras2: we assume all the kwargs should be passed to the Dense layer
# dense_layer_kwargs, kwargs = filter_func_args(layers.Dense.__init__,
# kwargs, overrule_args=['name'])
self.create_inner_layer_fn = lambda: layers.Dense(self.conv_width, **kwargs) #dense_layer_kwargs)

# Case 2: Check if an initialised keras layer is given
elif isinstance(inner_layer_arg, layers.Layer):
assert inner_layer_arg.built == False, 'When initialising with a keras layer, it cannot be built.'
_, self.conv_width = inner_layer_arg.get_output_shape_for((None, None))
_, self.conv_width = inner_layer_arg.compute_output_shape((None, 1))
# layer_from_config will mutate the config dict, therefore create a get fn
self.create_inner_layer_fn = lambda: layer_from_config(dict(
class_name=inner_layer_arg.__class__.__name__,
@@ -184,13 +202,13 @@ def __init__(self, inner_layer_arg, **kwargs):
example_instance = inner_layer_arg()
assert isinstance(example_instance, layers.Layer), 'When initialising with a function, the function has to return a keras layer'
assert example_instance.built == False, 'When initialising with a keras layer, it cannot be built.'
_, self.conv_width = example_instance.get_output_shape_for((None, None))
_, self.conv_width = example_instance.compute_output_shape((None, 1))
self.create_inner_layer_fn = inner_layer_arg

else:
raise ValueError('NeuralGraphHidden has to be initialised with 1) an int conv_width, 2) a keras layer instance, or 3) a function returning a keras layer instance.')

super(NeuralGraphHidden, self).__init__(**kwargs)
super(NeuralGraphHidden, self).__init__() # Keras2: all the kwargs will be passed to the Dense layer only
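
Two Keras 2 details worth noting here, stated as this port's assumptions: `Dense.compute_output_shape` asserts a concrete last input dimension, which is why the probe shape changed from `(None, None)` to `(None, 1)`; and calling `super().__init__()` without kwargs means standard `Layer` kwargs such as `name` are no longer forwarded to the base class. A quick probe of the first point, assuming standalone Keras 2.x where `Dense` returns a plain tuple:

```python
from keras import layers

inner = layers.Dense(32)
_, conv_width = inner.compute_output_shape((None, 1))  # conv_width == 32
# inner.compute_output_shape((None, None)) would fail the layer's
# internal assert that the last input dimension is concrete.
```

The same reasoning applies to `NeuralGraphOutput.__init__` below.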

def build(self, inputs_shape):

@@ -233,7 +251,9 @@ def call(self, inputs, mask=None):
num_bond_features = bonds._keras_shape[-1]

# Create a matrix that stores for each atom, the degree it is
atom_degrees = K.sum(K.not_equal(edges, -1), axis=-1, keepdims=True)
# atom_degrees = K.sum(K.not_equal(edges, -1), axis=-1, keepdims=True)
# backend cast to floatx:
atom_degrees = K.sum(K.cast(K.not_equal(edges, -1), dtype=K.floatx()), axis=-1, keepdims=True)
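
With the TF backend, `K.not_equal` returns a boolean tensor, which `K.sum` cannot reduce directly; hence the explicit cast to floatx. A minimal illustration with hypothetical values:

```python
import keras.backend as K

edges = K.constant([[0, 2, -1],
                    [1, -1, -1]])
atom_degrees = K.sum(K.cast(K.not_equal(edges, -1), K.floatx()),
                     axis=-1, keepdims=True)
# evaluates to [[2.], [1.]]
```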

# For each atom, look up the features of its neighbours
neighbour_atom_features = neighbour_lookup(atoms, edges, include_self=True)
@@ -245,7 +265,9 @@ def call(self, inputs, mask=None):
summed_bond_features = K.sum(bonds, axis=-2)

# Concatenate the summed atom and bond features
summed_features = K.concatenate([summed_atom_features, summed_bond_features], axis=-1)
# summed_features = K.concatenate([summed_atom_features, summed_bond_features], axis=-1)
# Tensorflow concat:
summed_features = tf.concat([summed_atom_features, summed_bond_features], axis=-1)

# For each degree we convolve with a different weight matrix
new_features_by_degree = []
@@ -266,11 +288,11 @@ def call(self, inputs, mask=None):
new_features_by_degree.append(new_masked_features)

# Finally sum the features of all atoms
new_features = layers.merge(new_features_by_degree, mode='sum')
new_features = layers.add(new_features_by_degree)

return new_features
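
Keras 2 dropped the functional `layers.merge` helper in favour of explicit merge layers, so `merge(..., mode='sum')` becomes `layers.add`. A minimal sketch of the equivalence:

```python
from keras import layers

a = layers.Input(shape=(8,))
b = layers.Input(shape=(8,))
s = layers.add([a, b])  # Keras 1: layers.merge([a, b], mode='sum')
```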

def get_output_shape_for(self, inputs_shape):
def compute_output_shape(self, inputs_shape):

# Import dimensions
(max_atoms, max_degree, num_atom_features, num_bond_features,
@@ -370,28 +392,29 @@ def __init__(self, inner_layer_arg, **kwargs):
# Case 1: Check if inner_layer_arg is fp_length
if isinstance(inner_layer_arg, (int, long)):
self.fp_length = inner_layer_arg
dense_layer_kwargs, kwargs = filter_func_args(layers.Dense.__init__,
kwargs, overrule_args=['name'])
self.create_inner_layer_fn = lambda: layers.Dense(self.fp_length, **dense_layer_kwargs)
# we assume all the kwargs should be passed to the Dense layer
# dense_layer_kwargs, kwargs = filter_func_args(layers.Dense.__init__,
# kwargs, overrule_args=['name'])
self.create_inner_layer_fn = lambda: layers.Dense(self.fp_length, **kwargs) # dense_layer_kwargs)

# Case 2: Check if an initialised keras layer is given
elif isinstance(inner_layer_arg, layers.Layer):
assert inner_layer_arg.built == False, 'When initialising with a keras layer, it cannot be built.'
_, self.fp_length = inner_layer_arg.get_output_shape_for((None, None))
_, self.fp_length = inner_layer_arg.compute_output_shape((None, 1))
self.create_inner_layer_fn = lambda: inner_layer_arg

# Case 3: Check if a function is provided that returns a initialised keras layer
elif callable(inner_layer_arg):
example_instance = inner_layer_arg()
assert isinstance(example_instance, layers.Layer), 'When initialising with a function, the function has to return a keras layer'
assert example_instance.built == False, 'When initialising with a keras layer, it cannot be built.'
_, self.fp_length = example_instance.get_output_shape_for((None, None))
_, self.fp_length = example_instance.compute_output_shape((None, 1))
self.create_inner_layer_fn = inner_layer_arg

else:
raise ValueError('NeuralGraphOutput has to be initialised with 1) an int fp_length, 2) a keras layer instance, or 3) a function returning a keras layer instance.')

super(NeuralGraphOutput, self).__init__(**kwargs)
super(NeuralGraphOutput, self).__init__() # Keras2: all the kwargs will be passed to the Dense layer only

def build(self, inputs_shape):

@@ -430,14 +453,18 @@ def call(self, inputs, mask=None):
# to create a general atom mask (unused atoms are 0 padded)
# We have to use the edge vector for this, because in theory, a convolution
# could lead to a zero vector for an atom that is present in the molecule
atom_degrees = K.sum(K.not_equal(edges, -1), axis=-1, keepdims=True)
# atom_degrees = K.sum(K.not_equal(edges, -1), axis=-1, keepdims=True)
# backend cast to floatx:
atom_degrees = K.sum(K.cast(K.not_equal(edges, -1), K.floatx()), axis=-1, keepdims=True)
general_atom_mask = K.cast(K.not_equal(atom_degrees, 0), K.floatx())

# Sum the edge features for each atom
summed_bond_features = K.sum(bonds, axis=-2)

# Concatenate the summed atom and bond features
atoms_bonds_features = K.concatenate([atoms, summed_bond_features], axis=-1)
# atoms_bonds_features = K.concatenate([atoms, summed_bond_features], axis=-1)
# Tensorflow concat:
atoms_bonds_features = tf.concat([atoms, summed_bond_features], axis=-1)

# Compute fingerprint
atoms_bonds_features._keras_shape = (None, max_atoms, num_atom_features+num_bond_features)
@@ -451,7 +478,7 @@ def call(self, inputs, mask=None):

return final_fp_out

def get_output_shape_for(self, inputs_shape):
def compute_output_shape(self, inputs_shape):

# Import dimensions
(max_atoms, max_degree, num_atom_features, num_bond_features,
@@ -504,12 +531,14 @@ def call(self, inputs, mask=None):
# Take max along `degree` axis (2) to get max of neighbours and self
max_features = K.max(neighbour_atom_features, axis=2)

atom_degrees = K.sum(K.not_equal(edges, -1), axis=-1, keepdims=True)
# atom_degrees = K.sum(K.not_equal(edges, -1), axis=-1, keepdims=True)
# backend cast to floatx:
atom_degrees = K.sum(K.cast(K.not_equal(edges, -1), K.floatx()), axis=-1, keepdims=True)
general_atom_mask = K.cast(K.not_equal(atom_degrees, 0), K.floatx())

return max_features * general_atom_mask

def get_output_shape_for(self, inputs_shape):
def compute_output_shape(self, inputs_shape):

# Only returns `atoms` tensor
return inputs_shape[0]