Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion GCN/GCN_explainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
import torch.nn.functional as F
import plotly.graph_objs as go
import networkx as nx
import numpy as np


def get_node_importance(model, data):
"""
Compute saliency maps by calculating gradients with respect to the input nodes.
Expand All @@ -19,6 +19,7 @@ def get_node_importance(model, data):

return importance_scores


def visualize_graph_with_importance_interactive(data, importance_scores, cell_type_decoder, threshold='auto', output_html="graph.html"):
"""
Visualizes the graph with node importance and exports it as an interactive HTML.
Expand Down Expand Up @@ -99,6 +100,8 @@ def visualize_graph_with_importance_interactive(data, importance_scores, cell_ty
)
fig.write_html(output_html)
print(f"Interactive graph saved to {output_html}")


'''
## Example usage, once you have a trained model:
with open('mel_alpha_0.01_buffer_0.pickle', 'rb') as p:
Expand Down
30 changes: 13 additions & 17 deletions GCN/data_process.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,25 @@
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_max_pool
from torch_geometric.data import Data, Dataset
import pandas as pd
import numpy as np
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.preprocessing import OneHotEncoder
from scipy.spatial import Delaunay
import matplotlib.pyplot as plt
from tqdm import tqdm
import itertools
import networkx as nx
from scipy.spatial import ConvexHull
import alphashape
from shapely.geometry import Point
import pickle
import sys
import warnings

warnings.filterwarnings("ignore")


class CellGraphDataset(Dataset):
def __init__(self, csv_path, groups, max_distance=100, transform=None, pre_transform=None,
def __init__(self, csv_path, groups, max_distance=100, transform=None, pre_transform=None,
cells_to_filter=None, alpha=0.01, buffer_value=0):
super().__init__(transform, pre_transform)
self.df = pd.read_csv(csv_path)
self.df = self.df[self.df['Group'].isin(groups)]
self.patient_fov_combos = self.df.groupby(['patient number', 'fov', 'Group']).size().reset_index()[['patient number', 'fov', 'Group']]
self.patient_fov_combos = self.df.groupby(['patient number', 'fov', 'Group']).size().reset_index()[
['patient number', 'fov', 'Group']]
self.pred_encoder = OneHotEncoder(sparse=False).fit(self.df[['pred']])
self.group_encoder = {group: i for i, group in enumerate(self.df['Group'].unique())}
self.max_distance = max_distance
Expand All @@ -40,13 +34,13 @@ def __getitem__(self, idx):
patient = self.patient_fov_combos.iloc[idx]['patient number']
fov = self.patient_fov_combos.iloc[idx]['fov']
group = self.patient_fov_combos.iloc[idx]['Group']

data_p = self.df[(self.df['patient number'] == patient) & (self.df['fov'] == fov)]

### Filter out the specified cell types
if self.cells_to_filter:
data_p = data_p[~data_p['pred'].isin(self.cells_to_filter)]

# Create initial graph
G = nx.Graph()
for i, row in data_p.iterrows():
Expand All @@ -61,7 +55,7 @@ def __getitem__(self, idx):

edges = []
for i in range(len(coords)):
i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i+1]]
i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i + 1]]
for j in i_neigh:
if np.linalg.norm(coords[i] - coords[j]) <= self.max_distance:
edges.append([i, j])
Expand Down Expand Up @@ -118,6 +112,8 @@ def filter_nodes_by_label(self, G, label):
filtered_subgraph = G.subgraph(nodes_to_remove).copy()
G.remove_nodes_from(nodes_to_remove)
return filtered_subgraph, G


'''
# Example usage
def preprocess_dataset(dataset):
Expand All @@ -132,4 +128,4 @@ def preprocess_dataset(dataset):
preprocessed_data = preprocess_dataset(dataset)
with open(r"mel_alpha_0.01_buffer_0.pickle", "wb") as output_file:
pickle.dump(preprocessed_data, output_file)
'''
'''
13 changes: 7 additions & 6 deletions GCN/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_max_pool
from torch_geometric.data import Data, Dataset
from sklearn.model_selection import StratifiedKFold, LeaveOneOut, train_test_split
from sklearn.metrics import roc_auc_score
import numpy as np
from tqdm import tqdm
import pickle
from model import GCN_Model


def train(model, data_dict, train_indices, optimizer, device):
model.train()
total_loss = 0
Expand All @@ -22,6 +18,7 @@ def train(model, data_dict, train_indices, optimizer, device):
total_loss += loss.item()
return total_loss / len(train_indices)


def evaluate(model, data_dict, val_indices, device):
model.eval()
y_true, y_pred = [], []
Expand All @@ -36,6 +33,7 @@ def evaluate(model, data_dict, val_indices, device):
val_loss = F.binary_cross_entropy(y_pred, y_true)
return val_loss.item(), y_true.numpy(), y_pred.numpy()


def train_and_validate(data_dict, model, optimizer, train_idx, val_idx, device, epochs=50):
best_val_loss = float('inf')
best_model_state = None
Expand All @@ -48,6 +46,7 @@ def train_and_validate(data_dict, model, optimizer, train_idx, val_idx, device,

return best_model_state


def cross_validation(preprocessed_data, group1, group2, cv_type="3-fold", epochs=500, lr=4e-3, test_size=0.2):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ys = [preprocessed_data[i].y for i in preprocessed_data]
Expand Down Expand Up @@ -76,6 +75,8 @@ def cross_validation(preprocessed_data, group1, group2, cv_type="3-fold", epochs
test_loss, y_true, y_pred = evaluate(model, preprocessed_data, test_indices, device)

return np.array(y_true), np.array(y_pred), model


'''
### Example usage, run after creating a pickle dataset :

Expand All @@ -89,4 +90,4 @@ def cross_validation(preprocessed_data, group1, group2, cv_type="3-fold", epochs
# Calculate and print AUC score on the test set
auc = roc_auc_score(y_true.flatten(), y_pred.flatten())
print(f"AUC Score on Test Set: {auc:.4f}")
'''
'''
10 changes: 2 additions & 8 deletions GCN/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_max_pool
from torch_geometric.data import Data, Dataset
from sklearn.model_selection import StratifiedKFold, LeaveOneOut, train_test_split
from sklearn.metrics import roc_auc_score
import numpy as np
from tqdm import tqdm
import pickle


class GCNBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels):
Expand All @@ -21,6 +15,7 @@ def forward(self, x, edge_index):
x = self.dropout(x)
return x


class GCN_Model(torch.nn.Module):
def __init__(self, input_size, hidden_channels=[75, 150, 50], output_size=1):
super().__init__()
Expand All @@ -42,4 +37,3 @@ def forward(self, data):
x = self.output_layer(x)
x = torch.sigmoid(x)
return x

3 changes: 1 addition & 2 deletions auxiliary/plugin_clean_tumor_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from scipy.spatial import ConvexHull, Delaunay
import numpy as np
import alphashape
from shapely.geometry import Polygon
from shapely.geometry import Point


Expand Down Expand Up @@ -92,7 +91,7 @@ def remove_nodes_inside_alpha_shape(G, alpha_shapes):
Removes nodes from the graph that are inside the given alpha shape.

:param G: The original graph.
:param alpha_shape: A Shapely polygon representing the alpha shape.
:param alpha_shapes: A Shapely polygon representing the alpha shape.
:return: The modified graph with nodes removed.
"""
positions = nx.get_node_attributes(G, 'pos')
Expand Down
Loading