From bada67d0edb04c2a8f71dcc0e125c7bec88f2ac6 Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 10:14:56 +0300 Subject: [PATCH 01/10] cleaned imports and some other small things --- cism/cism.py | 3 +- cism/graph/create_formatted_graph.py | 113 ++++++++++------------ cism/graph/plugin_clean_tumor_clusters.py | 3 +- 3 files changed, 55 insertions(+), 64 deletions(-) diff --git a/cism/cism.py b/cism/cism.py index f31713a..17071ff 100644 --- a/cism/cism.py +++ b/cism/cism.py @@ -11,7 +11,6 @@ from joblib.externals.loky import set_loky_pickler import os import enum -# import modin.pandas as pd import pandas as pd import networkx as nx import numpy as np @@ -502,7 +501,7 @@ def __init__(self, tissue_state_func=tissue_state_func) self.common_cells = common_cells_type - def get_patients_class(self, classes: list=None) -> pd.DataFrame: + def get_patients_class(self, classes: list = None) -> pd.DataFrame: if classes is None: classes = self.patient_class_df.patient_class.unique() exist_patients = self.cism.motifs_dataset.Patient_uId.unique() diff --git a/cism/graph/create_formatted_graph.py b/cism/graph/create_formatted_graph.py index e29e715..5a4ac49 100644 --- a/cism/graph/create_formatted_graph.py +++ b/cism/graph/create_formatted_graph.py @@ -1,14 +1,10 @@ -import cv2 import torch import numpy as np -from torch_geometric.data import Data import torch_geometric -import networkx as nx import matplotlib.pyplot as plt from torch_geometric.data import Data import networkx as nx -from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay -import pandas as pd +from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay, distance import matplotlib.cm as cm import matplotlib.colors as colors from cism.graph.plugin_clean_tumor_clusters import process_graph @@ -19,61 +15,59 @@ Date: 16/12/2023 ''' -class GraphBuilder(): - def __init__(self, cells_csv, + + +class GraphBuilder: + def __init__(self, cells_csv, common_cells_mapper, colnames_mapper_dict): - #### sepcific csv col names - self.cols_mapper = colnames_mapper_dict - self.cell_type_col = colnames_mapper_dict['cell_types'] - self.patients_col = colnames_mapper_dict['patient_id'] - #### the cells data csv - self.cells = cells_csv - #### the mapper between cell-type -> common cell type - self.cells_mapper = common_cells_mapper - self.cells['common_cell_types'] = self.cells[self.cell_type_col].apply(lambda x: self.cells_mapper[x]) - #### cell type 2 index - common_cell_types = np.unique(self.cells['common_cell_types']) - self.common_cell_type_mapper = {v : k for k,v in enumerate(common_cell_types)} - - - + # specific csv col names + self.cols_mapper = colnames_mapper_dict + self.cell_type_col = colnames_mapper_dict['cell_types'] + self.patients_col = colnames_mapper_dict['patient_id'] + # cells data csv + self.cells = cells_csv + # the mapper between cell-type -> common cell type + self.cells_mapper = common_cells_mapper + self.cells['common_cell_types'] = self.cells[self.cell_type_col].apply(lambda x: self.cells_mapper[x]) + # cell type 2 index + common_cell_types = np.unique(self.cells['common_cell_types']) + self.common_cell_type_mapper = {v: k for k, v in enumerate(common_cell_types)} + def build_graph(self, path_to_output_dir, max_distance, exclude_cell_type: str, removed_cluster_buffer=0, removed_cluster_alpha=0.01): - from scipy.spatial import distance - - ## for every patient and fov get extract the neighbors + # for every patient and fov get extract the neighbors for point in self.cells[self.patients_col].unique(): for fov in self.cells['fov'].unique(): data_p = self.cells[(self.cells[self.patients_col] == point) & (self.cells['fov'] == fov)] if len(data_p) < 1: - ## FOV and patient mismatch, skipping + # FOV and patient mismatch, skipping continue - coords = [(i, j) for i,j in zip(data_p['centroid-0'], data_p['centroid-1'])] + coords = [(i, j) for i, j in zip(data_p['centroid-0'], data_p['centroid-1'])] points = np.array(coords) indptr_neigh, neighbours = Delaunay(points).vertex_neighbor_vertices edge = [] node_ft = [] - for i,row in enumerate(data_p.iterrows()): - i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i+1]] - node_ft.append(self.common_cell_type_mapper[row[1]['common_cell_types']]) - for cell in i_neigh: - pair = np.array([i, cell]) - edge.append(pair) + for i, row in enumerate(data_p.iterrows()): + i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i + 1]] + node_ft.append(self.common_cell_type_mapper[row[1]['common_cell_types']]) + for cell in i_neigh: + pair = np.array([i, cell]) + edge.append(pair) edges = np.asarray(edge).T edge_index = torch.tensor(edges, dtype=torch.long) - x = torch.tensor(np.array(node_ft).reshape(-1,1), dtype=torch.float) - left_cell = edge_index.T[:,0].data.tolist() - right_cell = edge_index.T[:,1].data.tolist() - input_list = edge_index.T[:,0].data.tolist() + x = torch.tensor(np.array(node_ft).reshape(-1, 1), dtype=torch.float) + left_cell = edge_index.T[:, 0].data.tolist() + right_cell = edge_index.T[:, 1].data.tolist() + input_list = edge_index.T[:, 0].data.tolist() unique_values = sorted(list(set(input_list))) value_map = {unique_values[i]: i for i in range(len(unique_values))} output_list = [value_map[value] for value in input_list] - mapper = {input_list[i] : output_list[i] for i in range(len(input_list))} + mapper = {input_list[i]: output_list[i] for i in range(len(input_list))} G = nx.Graph() for left, right in edges.T: @@ -92,20 +86,19 @@ def build_graph(self, if len(clusters) > 0: print(f"Found at least one cluster to remove - Patient_{point}_FOV{fov}.txt") - with open(f"{path_to_output_dir}Patient_{point}_FOV{fov}.txt", "w") as file1: - for idx,(left,right) in enumerate(zip(left_cell, right_cell)): + for idx, (left, right) in enumerate(zip(left_cell, right_cell)): if (left not in G.nodes()) or (right not in G.nodes()): continue if (max_distance is None) or distance.euclidean(coords[left], coords[right]) <= max_distance: file1.write(f'{left} {right} {int(x[mapper[left]])} {int(x[mapper[right]])}\n') print(self.common_cell_type_mapper) - def visualize_voronoi(self, patient, fov = None): + def visualize_voronoi(self, patient, fov=None): patients_col = self.cols_mapper['patient_id'] - self.data_p = self.cells[(self.cells[patients_col ] == patient) & (self.cells['fov'] == fov)] - coords = [(i, j) for i,j in zip(self.data_p['centroid-0'], self.data_p['centroid-1'])] - self.idx2cell_type = {k : v for k,v in enumerate(self.data_p['common_cell_types'])} + self.data_p = self.cells[(self.cells[patients_col] == patient) & (self.cells['fov'] == fov)] + coords = [(i, j) for i, j in zip(self.data_p['centroid-0'], self.data_p['centroid-1'])] + self.idx2cell_type = {k: v for k, v in enumerate(self.data_p['common_cell_types'])} self.points = np.array(coords) indptr_neigh, neighbours = Delaunay(self.points).vertex_neighbor_vertices vor = Voronoi(self.points) @@ -113,10 +106,10 @@ def visualize_voronoi(self, patient, fov = None): plt.gca().invert_yaxis() plt.show() plt.clf() - plt.triplot(self.points[:,0], self.points[:,1], Delaunay(self.points).simplices) - plt.plot(self.points[:,0], self.points[:,1], 'o') - for i,idx in enumerate(self.idx2cell_type): - plt.text(self.points[i,0], self.points[i,1], self.idx2cell_type[idx]) + plt.triplot(self.points[:, 0], self.points[:, 1], Delaunay(self.points).simplices) + plt.plot(self.points[:, 0], self.points[:, 1], 'o') + for i, idx in enumerate(self.idx2cell_type): + plt.text(self.points[i, 0], self.points[i, 1], self.idx2cell_type[idx]) #plt.savefig('delaunay.png', dpi = 300) plt.gca().invert_yaxis() plt.show() @@ -126,21 +119,21 @@ def visualize_graph(self): indptr_neigh, neighbours = Delaunay(self.points).vertex_neighbor_vertices edge = [] node_ft = [] - for i,row in enumerate(self.data_p.iterrows()): - i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i+1]] - node_ft.append(self.common_cell_type_mapper[row[1]['common_cell_types']]) - for cell in i_neigh: - pair = np.array([i, cell]) - edge.append(pair) + for i, row in enumerate(self.data_p.iterrows()): + i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i + 1]] + node_ft.append(self.common_cell_type_mapper[row[1]['common_cell_types']]) + for cell in i_neigh: + pair = np.array([i, cell]) + edge.append(pair) edges = np.asarray(edge).T edge_index = torch.tensor(edges, dtype=torch.long) - x = torch.tensor(np.array(node_ft).reshape(-1,1), dtype=torch.float) + x = torch.tensor(np.array(node_ft).reshape(-1, 1), dtype=torch.float) data = Data(x=x, edge_index=edge_index.contiguous()) - g1 = torch_geometric.utils.to_networkx(data,to_undirected=True,node_attrs=['x']) + g1 = torch_geometric.utils.to_networkx(data, to_undirected=True, node_attrs=['x']) COLOR_SCHEME = "Paired" - nx.draw_networkx(g1, node_color=node_ft, node_size = 100, - with_labels = False, cmap = COLOR_SCHEME) - ax=plt.gca() + nx.draw_networkx(g1, node_color=node_ft, node_size=100, + with_labels=False, cmap=COLOR_SCHEME) + ax = plt.gca() norm = colors.Normalize(vmin=np.min(node_ft), vmax=np.max(node_ft)) mappable = cm.ScalarMappable(norm=norm, cmap=COLOR_SCHEME) mappable.set_array([]) @@ -148,4 +141,4 @@ def visualize_graph(self): nx.draw_networkx(g1, node_color=node_ft, node_size=100, with_labels=False, cmap=COLOR_SCHEME) plt.colorbar(mappable) plt.axis('off') - plt.show() \ No newline at end of file + plt.show() diff --git a/cism/graph/plugin_clean_tumor_clusters.py b/cism/graph/plugin_clean_tumor_clusters.py index a5b597a..f647082 100644 --- a/cism/graph/plugin_clean_tumor_clusters.py +++ b/cism/graph/plugin_clean_tumor_clusters.py @@ -2,7 +2,6 @@ from scipy.spatial import ConvexHull, Delaunay import numpy as np import alphashape -from shapely.geometry import Polygon from shapely.geometry import Point @@ -92,7 +91,7 @@ def remove_nodes_inside_alpha_shape(G, alpha_shapes): Removes nodes from the graph that are inside the given alpha shape. :param G: The original graph. - :param alpha_shape: A Shapely polygon representing the alpha shape. + :param alpha_shapes: A Shapely polygon representing the alpha shape. :return: The modified graph with nodes removed. """ positions = nx.get_node_attributes(G, 'pos') From a6c9471e504e33fef2d477e1addbd14f4601f1ba Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 10:30:15 +0300 Subject: [PATCH 02/10] cleaned imports and some other small things in GCN files --- GCN/GCN_explainer.py | 5 ++++- GCN/data_process.py | 30 +++++++++++++----------------- GCN/main.py | 13 +++++++------ GCN/model.py | 10 ++-------- 4 files changed, 26 insertions(+), 32 deletions(-) diff --git a/GCN/GCN_explainer.py b/GCN/GCN_explainer.py index a342009..d489f68 100644 --- a/GCN/GCN_explainer.py +++ b/GCN/GCN_explainer.py @@ -1,9 +1,9 @@ -import torch import torch.nn.functional as F import plotly.graph_objs as go import networkx as nx import numpy as np + def get_node_importance(model, data): """ Compute saliency maps by calculating gradients with respect to the input nodes. @@ -19,6 +19,7 @@ def get_node_importance(model, data): return importance_scores + def visualize_graph_with_importance_interactive(data, importance_scores, cell_type_decoder, threshold='auto', output_html="graph.html"): """ Visualizes the graph with node importance and exports it as an interactive HTML. @@ -99,6 +100,8 @@ def visualize_graph_with_importance_interactive(data, importance_scores, cell_ty ) fig.write_html(output_html) print(f"Interactive graph saved to {output_html}") + + ''' ## Example usage, once you have a trained model: with open('mel_alpha_0.01_buffer_0.pickle', 'rb') as p: diff --git a/GCN/data_process.py b/GCN/data_process.py index 3e438d9..a594cc1 100644 --- a/GCN/data_process.py +++ b/GCN/data_process.py @@ -1,31 +1,25 @@ import torch -import torch.nn.functional as F -from torch_geometric.nn import GCNConv, global_max_pool from torch_geometric.data import Data, Dataset import pandas as pd import numpy as np -from sklearn.model_selection import LeaveOneOut -from sklearn.metrics import roc_auc_score, roc_curve from sklearn.preprocessing import OneHotEncoder from scipy.spatial import Delaunay -import matplotlib.pyplot as plt -from tqdm import tqdm -import itertools import networkx as nx -from scipy.spatial import ConvexHull import alphashape from shapely.geometry import Point -import pickle -import sys import warnings + warnings.filterwarnings("ignore") + + class CellGraphDataset(Dataset): - def __init__(self, csv_path, groups, max_distance=100, transform=None, pre_transform=None, + def __init__(self, csv_path, groups, max_distance=100, transform=None, pre_transform=None, cells_to_filter=None, alpha=0.01, buffer_value=0): super().__init__(transform, pre_transform) self.df = pd.read_csv(csv_path) self.df = self.df[self.df['Group'].isin(groups)] - self.patient_fov_combos = self.df.groupby(['patient number', 'fov', 'Group']).size().reset_index()[['patient number', 'fov', 'Group']] + self.patient_fov_combos = self.df.groupby(['patient number', 'fov', 'Group']).size().reset_index()[ + ['patient number', 'fov', 'Group']] self.pred_encoder = OneHotEncoder(sparse=False).fit(self.df[['pred']]) self.group_encoder = {group: i for i, group in enumerate(self.df['Group'].unique())} self.max_distance = max_distance @@ -40,13 +34,13 @@ def __getitem__(self, idx): patient = self.patient_fov_combos.iloc[idx]['patient number'] fov = self.patient_fov_combos.iloc[idx]['fov'] group = self.patient_fov_combos.iloc[idx]['Group'] - + data_p = self.df[(self.df['patient number'] == patient) & (self.df['fov'] == fov)] - + ### Filter out the specified cell types if self.cells_to_filter: data_p = data_p[~data_p['pred'].isin(self.cells_to_filter)] - + # Create initial graph G = nx.Graph() for i, row in data_p.iterrows(): @@ -61,7 +55,7 @@ def __getitem__(self, idx): edges = [] for i in range(len(coords)): - i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i+1]] + i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i + 1]] for j in i_neigh: if np.linalg.norm(coords[i] - coords[j]) <= self.max_distance: edges.append([i, j]) @@ -118,6 +112,8 @@ def filter_nodes_by_label(self, G, label): filtered_subgraph = G.subgraph(nodes_to_remove).copy() G.remove_nodes_from(nodes_to_remove) return filtered_subgraph, G + + ''' # Example usage def preprocess_dataset(dataset): @@ -132,4 +128,4 @@ def preprocess_dataset(dataset): preprocessed_data = preprocess_dataset(dataset) with open(r"mel_alpha_0.01_buffer_0.pickle", "wb") as output_file: pickle.dump(preprocessed_data, output_file) -''' \ No newline at end of file +''' diff --git a/GCN/main.py b/GCN/main.py index 60c3c7a..509b0b1 100644 --- a/GCN/main.py +++ b/GCN/main.py @@ -1,14 +1,10 @@ import torch import torch.nn.functional as F -from torch_geometric.nn import GCNConv, global_max_pool -from torch_geometric.data import Data, Dataset from sklearn.model_selection import StratifiedKFold, LeaveOneOut, train_test_split -from sklearn.metrics import roc_auc_score import numpy as np -from tqdm import tqdm -import pickle from model import GCN_Model + def train(model, data_dict, train_indices, optimizer, device): model.train() total_loss = 0 @@ -22,6 +18,7 @@ def train(model, data_dict, train_indices, optimizer, device): total_loss += loss.item() return total_loss / len(train_indices) + def evaluate(model, data_dict, val_indices, device): model.eval() y_true, y_pred = [], [] @@ -36,6 +33,7 @@ def evaluate(model, data_dict, val_indices, device): val_loss = F.binary_cross_entropy(y_pred, y_true) return val_loss.item(), y_true.numpy(), y_pred.numpy() + def train_and_validate(data_dict, model, optimizer, train_idx, val_idx, device, epochs=50): best_val_loss = float('inf') best_model_state = None @@ -48,6 +46,7 @@ def train_and_validate(data_dict, model, optimizer, train_idx, val_idx, device, return best_model_state + def cross_validation(preprocessed_data, group1, group2, cv_type="3-fold", epochs=500, lr=4e-3, test_size=0.2): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') ys = [preprocessed_data[i].y for i in preprocessed_data] @@ -76,6 +75,8 @@ def cross_validation(preprocessed_data, group1, group2, cv_type="3-fold", epochs test_loss, y_true, y_pred = evaluate(model, preprocessed_data, test_indices, device) return np.array(y_true), np.array(y_pred), model + + ''' ### Example usage, run after creating a pickle dataset : @@ -89,4 +90,4 @@ def cross_validation(preprocessed_data, group1, group2, cv_type="3-fold", epochs # Calculate and print AUC score on the test set auc = roc_auc_score(y_true.flatten(), y_pred.flatten()) print(f"AUC Score on Test Set: {auc:.4f}") -''' \ No newline at end of file +''' diff --git a/GCN/model.py b/GCN/model.py index c1cdf53..a6b4a4b 100644 --- a/GCN/model.py +++ b/GCN/model.py @@ -1,12 +1,6 @@ import torch -import torch.nn.functional as F from torch_geometric.nn import GCNConv, global_max_pool -from torch_geometric.data import Data, Dataset -from sklearn.model_selection import StratifiedKFold, LeaveOneOut, train_test_split -from sklearn.metrics import roc_auc_score -import numpy as np -from tqdm import tqdm -import pickle + class GCNBlock(torch.nn.Module): def __init__(self, in_channels, out_channels): @@ -21,6 +15,7 @@ def forward(self, x, edge_index): x = self.dropout(x) return x + class GCN_Model(torch.nn.Module): def __init__(self, input_size, hidden_channels=[75, 150, 50], output_size=1): super().__init__() @@ -42,4 +37,3 @@ def forward(self, data): x = self.output_layer(x) x = torch.sigmoid(x) return x - From 64968dc128052964646b25d901d7e6436a3760f8 Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 10:30:29 +0300 Subject: [PATCH 03/10] cleaned imports and some other small things in the helpers.py --- cism/helpers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cism/helpers.py b/cism/helpers.py index f7e9471..2bcc7cc 100644 --- a/cism/helpers.py +++ b/cism/helpers.py @@ -2,9 +2,7 @@ import pickle from pathlib import Path -from tqdm.autonotebook import tqdm import subprocess -#import modin.pandas as pd import pandas as pd import networkx as nx import os From 08f69fbf90e041a6ea61bc472327283696c2d434 Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 10:31:00 +0300 Subject: [PATCH 04/10] cleaned imports and some other small things in the auxiliary files --- auxiliary/plugin_clean_tumor_clusters.py | 3 +- auxiliary/plugin_utils.py | 163 ++++++++++++----------- 2 files changed, 84 insertions(+), 82 deletions(-) diff --git a/auxiliary/plugin_clean_tumor_clusters.py b/auxiliary/plugin_clean_tumor_clusters.py index a5b597a..f647082 100644 --- a/auxiliary/plugin_clean_tumor_clusters.py +++ b/auxiliary/plugin_clean_tumor_clusters.py @@ -2,7 +2,6 @@ from scipy.spatial import ConvexHull, Delaunay import numpy as np import alphashape -from shapely.geometry import Polygon from shapely.geometry import Point @@ -92,7 +91,7 @@ def remove_nodes_inside_alpha_shape(G, alpha_shapes): Removes nodes from the graph that are inside the given alpha shape. :param G: The original graph. - :param alpha_shape: A Shapely polygon representing the alpha shape. + :param alpha_shapes: A Shapely polygon representing the alpha shape. :return: The modified graph with nodes removed. """ positions = nx.get_node_attributes(G, 'pos') diff --git a/auxiliary/plugin_utils.py b/auxiliary/plugin_utils.py index 6b7f99d..01719ab 100644 --- a/auxiliary/plugin_utils.py +++ b/auxiliary/plugin_utils.py @@ -9,7 +9,7 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg motif = { - ('A', 'B'): {}, + ('A', 'B'): {}, ('A', 'C'): {}, ('B', 'C'): {}, 'A': {'type': 9}, @@ -17,32 +17,36 @@ 'C': {'type': 11} } + def rotate_point(points, angle_degrees): angle_rad = np.radians(angle_degrees) rotation_matrix = np.array([ [np.cos(angle_rad), -np.sin(angle_rad)], [np.sin(angle_rad), np.cos(angle_rad)] ]) - + # Check if points is a single point or a list of points if isinstance(points[0], (int, float)): return np.dot(rotation_matrix, points) else: return [np.dot(rotation_matrix, point) for point in points] -def fig_to_np_array(fig): +def fig_to_np_array(fig): canvas = FigureCanvasAgg(fig) canvas.draw() - + image_data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8) return image_data.reshape(canvas.get_width_height()[::-1] + (3,)) + def convert_dict(input_dict): new_dict = {} for key, value in input_dict.items(): new_dict[value['name']] = key return new_dict + + def parse_motif_input(input_str, cell_type_dict): # Convert cell type names to indices type_to_index = convert_dict(cell_type_dict) @@ -77,13 +81,15 @@ def parse_motif_input(input_str, cell_type_dict): motifs.append(motif) return motifs + + def parse_motif_input2(input_str, cell_type_dict): # Convert cell type names to indices type_to_index = convert_dict(cell_type_dict) lines = input_str.strip().split("\n") motif = {} - + for line in lines: line = line.strip() if "->" in line: @@ -102,13 +108,14 @@ def parse_motif_input2(input_str, cell_type_dict): motif[node] = {"type": type_to_index[type_val]} else: raise ValueError(f"Unknown cell type: {type_val}") - + return motif + def parse_motif_input1(input_str): lines = input_str.strip().split("\n") motif = {} - + for line in lines: line = line.strip() if "->" in line: @@ -122,30 +129,31 @@ def parse_motif_input1(input_str): elif ".type =" in line: node, type_val = map(str.strip, line.split(".type =")) motif[node] = {"type": int(type_val.strip('"'))} - + return motif -def create_tagged_image(cell_data, image, mapper_dict, patinet_number, fov = None): - # Load the CSV data - if fov is None: - cell_data_specific_p = cell_data[cell_data[mapper_dict['patints_col']] == patinet_number] - else: - cell_data_specific_p = cell_data[(cell_data[mapper_dict['patints_col']] == patinet_number) \ - & (cell_data[mapper_dict['fov_col']] == fov)] - segmented_image_array = np.array(image) - # Create a mapping from cell type to integer - cell_types = cell_data_specific_p[mapper_dict['cell_types_col']].unique() - cell_type_to_int = {cell_type: idx + 1 for idx, cell_type in enumerate(cell_types)} +def create_tagged_image(cell_data, image, mapper_dict, patinet_number, fov=None): + # Load the CSV data + if fov is None: + cell_data_specific_p = cell_data[cell_data[mapper_dict['patints_col']] == patinet_number] + else: + cell_data_specific_p = cell_data[(cell_data[mapper_dict['patints_col']] == patinet_number) \ + & (cell_data[mapper_dict['fov_col']] == fov)] - # Create a new image array for cell types - cell_type_image_array = np.zeros_like(segmented_image_array) - # Populate the new image array based on cell type mapping - for _, row in cell_data_specific_p.iterrows(): - cell_label = row[mapper_dict['cell_index_col']] - cell_type_int = cell_type_to_int[row[mapper_dict['cell_types_col']]] - cell_type_image_array[segmented_image_array == cell_label] = cell_type_int - return cell_type_image_array.astype(np.uint8) + segmented_image_array = np.array(image) + # Create a mapping from cell type to integer + cell_types = cell_data_specific_p[mapper_dict['cell_types_col']].unique() + cell_type_to_int = {cell_type: idx + 1 for idx, cell_type in enumerate(cell_types)} + + # Create a new image array for cell types + cell_type_image_array = np.zeros_like(segmented_image_array) + # Populate the new image array based on cell type mapping + for _, row in cell_data_specific_p.iterrows(): + cell_label = row[mapper_dict['cell_index_col']] + cell_type_int = cell_type_to_int[row[mapper_dict['cell_types_col']]] + cell_type_image_array[segmented_image_array == cell_label] = cell_type_int + return cell_type_image_array.astype(np.uint8) def find_motifs(G_full, motifs): @@ -154,7 +162,7 @@ def find_motifs(G_full, motifs): for motif in motifs: motif_graph = nx.Graph() - + # Adding edges and nodes based on motif input for key, value in motif.items(): if isinstance(key, tuple): # Edges @@ -177,11 +185,9 @@ def find_motifs(G_full, motifs): return all_subgraphs - - def find_motifs1(G_full, motif): motif_graph = nx.Graph() - + # Adding edges and nodes based on motif input for key, value in motif.items(): if isinstance(key, tuple): # Edges @@ -204,28 +210,27 @@ def find_motifs1(G_full, motif): def generate_cell_type_structure_from_tagged(mapper): - num_unique_types = len(mapper) - - colormap = cm.get_cmap('tab20c', num_unique_types) - colors = [colormap(i) for i in range(num_unique_types)] - - # Background and Unknown are hardcoded - cell_type_structure = { - 0: {'name': 'Background', 'color': 'black'}, - #1: {'name': 'Unknown', 'color': 'black'} - } - - # Populate the dictionary with unique cell types and their colors - for i, cell_type in enumerate(mapper.keys(), start=1): - cell_type_structure[cell_type] = {'name': mapper[cell_type], 'color': colors[i-1]} - - return cell_type_structure + num_unique_types = len(mapper) + + colormap = cm.get_cmap('tab20c', num_unique_types) + colors = [colormap(i) for i in range(num_unique_types)] + + # Background and Unknown are hardcoded + cell_type_structure = { + 0: {'name': 'Background', 'color': 'black'}, + # 1: {'name': 'Unknown', 'color': 'black'} + } + + # Populate the dictionary with unique cell types and their colors + for i, cell_type in enumerate(mapper.keys(), start=1): + cell_type_structure[cell_type] = {'name': mapper[cell_type], 'color': colors[i - 1]} + + return cell_type_structure # Load the image def build_graph(image, list_of_cells_to_exclude=[]): - - image_full = image#[500:1500, 500:1500] + image_full = image #[500:1500, 500:1500] image_original = copy.deepcopy(image_full) # Preprocess the image based on the provided code #image_full[image_full == 4] = 17 @@ -233,7 +238,7 @@ def build_graph(image, list_of_cells_to_exclude=[]): #image_full = np.where(image_full > 3, image_full - 1, image_full) if len(list_of_cells_to_exclude) > 0: image_full = np.where(np.isin(image_full, list_of_cells_to_exclude), 50, image_full) - exc = 50 #if len(list_of_cells_to_exclude) > None0 else + exc = 50 #if len(list_of_cells_to_exclude) > None0 else # Extract cell contours and centroid coordinates cnts_full = cv2.findContours(image_full, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cnts_full = cnts_full[0] if len(cnts_full) == 2 else cnts_full[1] @@ -246,7 +251,7 @@ def build_graph(image, list_of_cells_to_exclude=[]): coords_full.append([int(x), int(y)]) if image_full[int(y), int(x)] == exc: point2cell_full[idx] = exc - p2c[idx] = image_original[int(y), int(x)]#0 + p2c[idx] = image_original[int(y), int(x)] #0 else: point2cell_full[idx] = image_full[int(y), int(x)] p2c[idx] = image_original[int(y), int(x)] @@ -256,12 +261,12 @@ def build_graph(image, list_of_cells_to_exclude=[]): # Generate edges and node features edges_full = [] node_ft_full = [] - + for i in range(len(coords_full)): if point2cell_full[i] == exc: continue else: - i_neigh = neighbours_full[indptr_neigh_full[i]:indptr_neigh_full[i+1]] + i_neigh = neighbours_full[indptr_neigh_full[i]:indptr_neigh_full[i + 1]] node_ft_full.append(point2cell_full[i]) for cell in i_neigh: if point2cell_full[cell] == exc: @@ -270,13 +275,12 @@ def build_graph(image, list_of_cells_to_exclude=[]): edges_full.append(pair) edges_full = np.asarray(edges_full).T - G_full = nx.Graph() for left, right in edges_full.T: # Add nodes with their cell types G_full.add_node(left, cell_type=point2cell_full[left]) G_full.add_node(right, cell_type=point2cell_full[right]) - + # Add the edge G_full.add_edge(left, right) @@ -286,10 +290,12 @@ def build_graph(image, list_of_cells_to_exclude=[]): return image_original, G_full, coords_full, point2cell_full, p2c + # Define the node matcher function def node_matcher(node1, node2): return node1['cell_type'] == node2['type'] + # Convert the motif dictionary into a graph structure #def find_motifs(G_full, motif = motif): # motif_graph_8 = nx.Graph() @@ -298,16 +304,14 @@ def node_matcher(node1, node2): # if isinstance(attr, dict) and 'type' in attr: # motif_graph_8.add_node(node, type=attr['type']) - # Search for the motif in the graph +# Search for the motif in the graph # subgraphs_8 = nx.algorithms.isomorphism.GraphMatcher(G_full, motif_graph_8, node_match=node_matcher).subgraph_isomorphisms_iter() - # Convert subgraphs to a list and filter +# Convert subgraphs to a list and filter # subgraphs_8_list = list(subgraphs_8) # return subgraphs_8_list - - def vis_graph_and_motifs(coords_full, subgraphs_8_list, image_original, @@ -321,17 +325,16 @@ def vis_graph_and_motifs(coords_full, # Visualize the corrected graph overlay fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8)) - - # Display the image with Voronoi overlay colors = ['black'] + [cell_types[key]['color'] for key in cell_types.keys()] - custom_cmap = ListedColormap(colors) - print(custom_cmap) + custom_cmap = ListedColormap(colors) + print(custom_cmap) ax1.imshow(image_original, cmap=custom_cmap, origin='upper') colors = [] for cell_type, attributes in cell_types.items(): - ax2.plot([], [], 'o', color=attributes['color'], label=f"Cell Type {attributes['name']} ({cell_type})", markersize=10) + ax2.plot([], [], 'o', color=attributes['color'], label=f"Cell Type {attributes['name']} ({cell_type})", + markersize=10) ax2.legend(loc="upper left") ax2.axis('off') #voronoi_plot_2d(vor_full, ax=ax, show_vertices=False, line_colors='black', line_width=0.5, line_alpha=0.6, point_size=2) @@ -341,9 +344,10 @@ def vis_graph_and_motifs(coords_full, motif_nodes = [node for subgraph in subgraphs_8_list for node in subgraph.keys()] node_colors = ['white' if node in motif_nodes else 'none' for node in G_full.nodes()] - nx.draw_networkx(G_full, pos=corrected_pos, node_size=5, with_labels=False, node_color=node_colors, edge_color='white', ax=ax1) + nx.draw_networkx(G_full, pos=corrected_pos, node_size=5, with_labels=False, node_color=node_colors, + edge_color='white', ax=ax1) for sg in subgraphs_8_list: - bb = {v : k for k,v in sg.items()} + bb = {v: k for k, v in sg.items()} motif_nodes = [bb['A'], bb['B'], bb['C']] motif_edges = [(motif_nodes[i], motif_nodes[j]) for i, j in [(0, 1), (0, 2), (1, 2)]] nx.draw_networkx_edges(G_full, pos=corrected_pos, edgelist=motif_edges, edge_color='red', width=2.5, ax=ax1) @@ -360,12 +364,12 @@ def generate_color_map(cell_types): """ unique_cell_types = np.unique(cell_types) num_colors = len(unique_cell_types) - + # Generate a color palette with as many unique colors as there are cell types color_palette = plt.cm.tab20.colors + plt.cm.tab20c.colors # Combine two color palettes to get more unique colors colors = color_palette * (num_colors // len(color_palette) + 1) # Repeat palette if more colors are needed colors = colors[:num_colors] - + return {cell_type: colors[i] for i, cell_type in enumerate(unique_cell_types)} @@ -375,21 +379,21 @@ def generate_cell_type_structure(cell_types): """ unique_cell_types = np.unique(cell_types) num_unique_types = len(unique_cell_types) - + # Generate a wide range of unique colors using a colormap colormap = cm.get_cmap('tab20c', num_unique_types) colors = [colormap(i) for i in range(num_unique_types)] - + # Background and Unknown are hardcoded cell_type_structure = { #0: {'name': 'Background', 'color': 'black'}, #1: {'name': 'Unknown', 'color': 'black'} } - + # Populate the dictionary with unique cell types and their colors for i, cell_type in enumerate(unique_cell_types, start=0): cell_type_structure[i] = {'name': cell_type, 'color': colors[i]} - + return cell_type_structure @@ -401,24 +405,23 @@ def build_cell_graph(data, filter_out_cell_type_cluster=None, buffer=0, alpha=0.01): - from scipy.spatial import distance """Build a graph from cell coordinates and cell types.""" coords = data[['centroid-0', 'centroid-1']].values - cell_types = data['common_cell_types'].values + cell_types = data['common_cell_types'].values cells_idx = data['cell_id'] include_indices = [i for i, ctype in enumerate(cell_types) if ctype not in exclude] coords_included = coords[include_indices] - #cell_types = cell_types[include_indices] - + # cell_types = cell_types[include_indices] + idx2cell = {idx: cell_type for idx, cell_type in enumerate(cell_types)} cell_type_to_index = {v['name']: k for k, v in tnbc_cells_type.items()} - + # Use the above mapping to generate the desired dictionary p2c = {cell_idx: cell_type_to_index[cell_type] for cell_idx, cell_type in idx2cell.items()} - + points = np.array(coords) indptr_neigh, neighbours = Delaunay(points).vertex_neighbor_vertices edges = [] @@ -426,7 +429,7 @@ def build_cell_graph(data, for i, idx in enumerate(coords): if tnbc_cells_type[p2c[i]]['name'] in exclude: continue - i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i+1]] + i_neigh = neighbours[indptr_neigh[i]:indptr_neigh[i + 1]] for cell in i_neigh: if tnbc_cells_type[p2c[cell]]['name'] in exclude: continue @@ -450,7 +453,7 @@ def build_cell_graph(data, cell_type_to_index[filter_out_cell_type_cluster], buffer=buffer, alpha=alpha) - + return G, clusters, p2c, coords, coords_included From 888804201232297752e12398474e10a962f4337c Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 10:31:14 +0300 Subject: [PATCH 05/10] cleaned imports and some other small things in the pairwise files --- pairwise/common.py | 2 +- pairwise/pairwise_model.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/pairwise/common.py b/pairwise/common.py index 02cd118..02c4ed6 100644 --- a/pairwise/common.py +++ b/pairwise/common.py @@ -4,4 +4,4 @@ class Columns(enum.Enum): PATIENT_CLASS = 'class' PATIENT_CLASS_ID = 'patient_class_id' - PAIRWISE_COUNT = 'pairwise_count' \ No newline at end of file + PAIRWISE_COUNT = 'pairwise_count' diff --git a/pairwise/pairwise_model.py b/pairwise/pairwise_model.py index 2f7b8a1..2d6fcbc 100644 --- a/pairwise/pairwise_model.py +++ b/pairwise/pairwise_model.py @@ -1,8 +1,5 @@ -import enum import itertools - import numpy as np -import networkx as nx import pandas as pd import shap from sklearn.ensemble import RandomForestClassifier @@ -81,7 +78,6 @@ def get_cell_type_count_from_classes(self, return result - def analyze(self, full_graph_df: pd.DataFrame, cells_type: dict, From 7e6a2a7b6fd197c61a6d45565a5f9a92f3c8b8b9 Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 10:31:27 +0300 Subject: [PATCH 06/10] added missing packages from the requirements.txt file --- requirements.txt | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3e2da81..f1992d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,9 +18,9 @@ defusedxml==0.7.1 dill==0.3.6 executing==1.2.0 fastjsonschema==2.16.2 -fonttools==4.38.0 +fonttools==4.43.0 fqdn==1.5.1 -idna==3.4 +idna==3.7 importlib-metadata==6.8.0 importlib-resources==5.12.0 ipykernel==6.21.1 @@ -90,8 +90,8 @@ soupsieve==2.3.2.post1 stack-data==0.6.2 terminado==0.17.1 tinycss2==1.2.1 -tornado==6.3.3 -tqdm==4.64.1 +tornado==6.4.2 +tqdm==4.66.3 traitlets==5.9.0 uri-template==1.2.0 wcwidth==0.2.6 @@ -99,11 +99,18 @@ webcolors==1.12 webencodings==0.5.1 websocket-client==1.5.1 widgetsnbextension==4.0.5 -zipp==3.17.0 -opencv-python==4.8.0.76 -torch==2.1.2 +zipp==3.19.1 +opencv-python==4.8.1.78 +torch==2.5.0 torchvision==0.16.2 torch-geometric==2.4.0 alphashape==1.3.1 shap==0.44.0 pyarrow==17.0.0 + +torch_geometric~=2.4.0 +scikit-learn~=1.6.1 +scipy~=1.15.2 +shapely~=2.0.7 +dotmotif~=0.15.0 +plotly~=6.0.1 \ No newline at end of file From a451780926b6b3b824d1d42cf6839c941ec28808 Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 12:09:51 +0300 Subject: [PATCH 07/10] changed some of the packages to be in the version as I cloned the code --- requirements.txt | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index f1992d5..3019b54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,9 +18,9 @@ defusedxml==0.7.1 dill==0.3.6 executing==1.2.0 fastjsonschema==2.16.2 -fonttools==4.43.0 +fonttools==4.38.0 fqdn==1.5.1 -idna==3.7 +idna==3.4 importlib-metadata==6.8.0 importlib-resources==5.12.0 ipykernel==6.21.1 @@ -90,8 +90,8 @@ soupsieve==2.3.2.post1 stack-data==0.6.2 terminado==0.17.1 tinycss2==1.2.1 -tornado==6.4.2 -tqdm==4.66.3 +tornado==6.3.3 +tqdm==4.64.1 traitlets==5.9.0 uri-template==1.2.0 wcwidth==0.2.6 @@ -99,15 +99,14 @@ webcolors==1.12 webencodings==0.5.1 websocket-client==1.5.1 widgetsnbextension==4.0.5 -zipp==3.19.1 -opencv-python==4.8.1.78 -torch==2.5.0 +zipp==3.17.0 +opencv-python==4.8.0.76 +torch==2.1.2 torchvision==0.16.2 torch-geometric==2.4.0 alphashape==1.3.1 shap==0.44.0 pyarrow==17.0.0 - torch_geometric~=2.4.0 scikit-learn~=1.6.1 scipy~=1.15.2 From 901252afab33b89516af59438f614c2b8ae91a96 Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 12:25:00 +0300 Subject: [PATCH 08/10] changed PEP 8 things to look better --- cism/cism.py | 187 ++++++++++++++++++++++++++++----------------------- 1 file changed, 104 insertions(+), 83 deletions(-) diff --git a/cism/cism.py b/cism/cism.py index 17071ff..a53cfe7 100644 --- a/cism/cism.py +++ b/cism/cism.py @@ -72,24 +72,24 @@ def add_dataset(self, dataset_folder, dataset_type, dataset_name, **kwargs) -> N for r, d, files in os.walk(self.network_dataset_root_path + dataset_folder): Parallel(n_jobs=n_jobs, prefer=prefer)(delayed(self._analyze_dataset)( - file=file, - output_dir=self.fanmod_output_root_path + dataset_folder, - cache_dir=self.fanmod_cache_root_path + dataset_folder, - force_run_fanmod=force_run_fanmod, - raw_data_folder=self.network_dataset_root_path + dataset_folder, - force_parse=force_parse, - enable_parse=False) for file in tqdm(files)) + file=file, + output_dir=self.fanmod_output_root_path + dataset_folder, + cache_dir=self.fanmod_cache_root_path + dataset_folder, + force_run_fanmod=force_run_fanmod, + raw_data_folder=self.network_dataset_root_path + dataset_folder, + force_parse=force_parse, + enable_parse=False) for file in tqdm(files)) result_list = Parallel(n_jobs=n_jobs, prefer=prefer, return_as="generator")(delayed(self._analyze_dataset)( - file=file, - output_dir=self.fanmod_output_root_path + dataset_folder, - cache_dir=self.fanmod_cache_root_path + dataset_folder, - force_run_fanmod=False, - raw_data_folder=self.network_dataset_root_path + dataset_folder, - force_parse=force_parse, - enable_parse=True, - p_value=p_value, - quantile_threshold=quantile_threshold) for file in tqdm(files)) + file=file, + output_dir=self.fanmod_output_root_path + dataset_folder, + cache_dir=self.fanmod_cache_root_path + dataset_folder, + force_run_fanmod=False, + raw_data_folder=self.network_dataset_root_path + dataset_folder, + force_parse=force_parse, + enable_parse=True, + p_value=p_value, + quantile_threshold=quantile_threshold) for file in tqdm(files)) motifs_dataset = pd.concat(result_list, ignore_index=True) motifs_dataset['FOV'] = motifs_dataset['FOV'].astype('category') @@ -103,7 +103,8 @@ def add_dataset(self, dataset_folder, dataset_type, dataset_name, **kwargs) -> N motifs_dataset['Patient_uId'] = motifs_dataset['Patient_uId'].astype('category') motifs_dataset['FOV'] = motifs_dataset['FOV'].astype('category') - self.motifs_dataset = motifs_dataset if self.motifs_dataset is None else pd.concat([self.motifs_dataset, motifs_dataset]) + self.motifs_dataset = motifs_dataset if self.motifs_dataset is None else pd.concat( + [self.motifs_dataset, motifs_dataset]) self.motifs_dataset.reset_index(drop=True) def motif_dataset(self) -> pd.DataFrame: @@ -133,7 +134,8 @@ def _analyze_dataset(self, **kwargs): class AnalyzeMotifsResult: def __init__(self, analyze_results: list, patients_ids: list, labels: list): self.results = pd.DataFrame(analyze_results, - columns=['TP', 'TN', 'FN', 'FP', 'cFeatures', 'prob', 'class', 'pred_class', 'classes', 'contributions', 'shape_values'], + columns=['TP', 'TN', 'FN', 'FP', 'cFeatures', 'prob', 'class', 'pred_class', + 'classes', 'contributions', 'shape_values'], index=patients_ids) self.labels = labels @@ -196,6 +198,7 @@ def __init__(self, if len(labels) != 2: raise Exception("currently, we support only binary classification") + class HardDiscriminativeFC(FeatureConfiguration): def __init__(self, labels: list, @@ -508,7 +511,7 @@ def get_patients_class(self, classes: list = None) -> pd.DataFrame: return self.patient_class_df[self.patient_class_df.patient_class.isin(classes) & self.patient_class_df.index.isin(exist_patients)] - def discover(self, extract_by: DiscriminativeFeatureKey, classes: list=None): + def discover(self, extract_by: DiscriminativeFeatureKey, classes: list = None): patient_class_dict = self.patient_class_df[DiscriminativeMotifs.PATIENT_CLASS].to_dict() self.cism.motifs_dataset[DiscriminativeMotifs.PATIENT_CLASS] = self.cism.motifs_dataset.Patient_uId.transform( lambda row: patient_class_dict[row]).astype('category') @@ -582,17 +585,17 @@ def __get_motif_dataset_of_patient(motifs_dataset: pd.DataFrame, raw_get_features_result = Parallel(n_jobs=n_jobs, verbose=0, prefer='threads')( delayed(self._get_features)( - x_data=__get_motif_dataset_of_patient(self.cism.motifs_dataset, - unique_patients_ids, - trial_i, - True), - x_test=__get_motif_dataset_of_patient(self.cism.motifs_dataset, - unique_patients_ids, - trial_i, - False), - test_patient_uid=unique_patients_ids[trial_i], - feature_conf=feature_conf, - patient_class_df=local_patient_class) for trial_i in tqdm(range(len(unique_patients_ids)))) + x_data=__get_motif_dataset_of_patient(self.cism.motifs_dataset, + unique_patients_ids, + trial_i, + True), + x_test=__get_motif_dataset_of_patient(self.cism.motifs_dataset, + unique_patients_ids, + trial_i, + False), + test_patient_uid=unique_patients_ids[trial_i], + feature_conf=feature_conf, + patient_class_df=local_patient_class) for trial_i in tqdm(range(len(unique_patients_ids)))) records = [{'test_patient_id': test_patient_id, 'features': features} for test_patient_id, features, _ in raw_get_features_result] @@ -604,10 +607,11 @@ def __get_motif_dataset_of_patient(motifs_dataset: pd.DataFrame, return records - def analyze_motifs(self, feature_conf: FeatureConfiguration, exclude_patients: list, **kwargs) -> AnalyzeMotifsResult: + def analyze_motifs(self, feature_conf: FeatureConfiguration, exclude_patients: list, + **kwargs) -> AnalyzeMotifsResult: n_jobs = kwargs.setdefault("n_jobs", 8) prefer = kwargs.setdefault("prefer", 'processes') - random_state = kwargs.setdefault("random_state", np.random.RandomState()) + random_state = kwargs.setdefault("random_state", np.random.RandomState()) rand_patient_class = kwargs.setdefault("rand_patient_class", False) rand_motifs = kwargs.setdefault("rand_motifs", False) @@ -617,7 +621,8 @@ def analyze_motifs(self, feature_conf: FeatureConfiguration, exclude_patients: l for patient_id in exclude_patients: unique_patients_ids = np.delete(unique_patients_ids, np.where(np.array(unique_patients_ids) == patient_id)) - local_patient_class = self.patient_class_df[self.patient_class_df.patient_class.isin(feature_conf.labels)].copy() + local_patient_class = self.patient_class_df[ + self.patient_class_df.patient_class.isin(feature_conf.labels)].copy() if rand_patient_class: local_patient_class[DiscriminativeMotifs.PATIENT_CLASS] = random_state.permutation( @@ -655,7 +660,8 @@ def __get_motif_dataset_of_patient(motifs_dataset: pd.DataFrame, test_patient_uid=unique_patients_ids[trial_i], feature_conf=feature_conf, random_state=random_state, - patient_class_df=local_patient_class) for trial_i in tqdm(range(len(unique_patients_ids)))) + patient_class_df=local_patient_class) for trial_i in + tqdm(range(len(unique_patients_ids)))) return AnalyzeMotifsResult(analyze_results=raw_analyze_result, patients_ids=unique_patients_ids, @@ -689,12 +695,12 @@ def __get_motif_dataset_of_patient(motifs_dataset: pd.DataFrame, unique_patients_ids, trial_i, True), - x_test=__get_motif_dataset_of_patient(self.cism.motifs_dataset, - unique_patients_ids, - trial_i, - False), - test_patient_uid=unique_patients_ids[trial_i], - feature_conf=feature_conf) for trial_i in tqdm(range(len(unique_patients_ids)))) + x_test=__get_motif_dataset_of_patient(self.cism.motifs_dataset, + unique_patients_ids, + trial_i, + False), + test_patient_uid=unique_patients_ids[trial_i], + feature_conf=feature_conf) for trial_i in tqdm(range(len(unique_patients_ids)))) records = [{'test_patient_id': test_patient_id, 'features': features} for test_patient_id, features, _ in raw_get_features_result] @@ -702,7 +708,8 @@ def __get_motif_dataset_of_patient(motifs_dataset: pd.DataFrame, return pd.DataFrame.from_records(records) @staticmethod - def _load_tissue_state(tissue_state_csv_path: str, tissue_state_to_string: dict[int, str], tissue_state_func=None) -> pd.DataFrame: + def _load_tissue_state(tissue_state_csv_path: str, tissue_state_to_string: dict[int, str], + tissue_state_func=None) -> pd.DataFrame: patient_class_df = pd.read_csv(tissue_state_csv_path, index_col=0, names=['patient_class_id']) if tissue_state_func: @@ -748,21 +755,26 @@ def _extract_discriminative(self, .groupby(discriminative_feature_key, observed=True)[discriminative_group_key] .nunique().reset_index()) one_class_data = (df_copy[df_copy[discriminative_feature_key] - .isin(special_hash_group[special_hash_group[discriminative_group_key] == 1][discriminative_feature_key]) + .isin( + special_hash_group[special_hash_group[discriminative_group_key] == 1][discriminative_feature_key]) & (df_copy.nunique_colors >= min_nunique_colors)].copy()) - one_class_data.loc[:, DiscriminativeMotifs.PATIENT_COUNT_KEY] = one_class_data.groupby(discriminative_feature_key, observed=True)['Patient_uId'].transform('nunique') + one_class_data.loc[:, DiscriminativeMotifs.PATIENT_COUNT_KEY] = \ + one_class_data.groupby(discriminative_feature_key, observed=True)['Patient_uId'].transform('nunique') one_class_num_motifs = (one_class_data[(one_class_data[DiscriminativeMotifs.PATIENT_COUNT_KEY] > min_patients)] - .sort_values('Freq', ascending=False).drop_duplicates(subset=[discriminative_feature_key]).shape[0]) + .sort_values('Freq', ascending=False).drop_duplicates(subset=[discriminative_feature_key]).shape[0]) unique_classes = df_copy.groupby(discriminative_group_key, observed=True).Patient_uId.nunique() for class_index in range(len(unique_classes.index)): class_name = unique_classes.index[class_index] - one_class_data.loc[one_class_data[discriminative_group_key] == class_name, DiscriminativeMotifs.PATIENT_PERCENTAGE_KEY] = ( - one_class_data.loc[one_class_data[discriminative_group_key] == class_name, DiscriminativeMotifs.PATIENT_COUNT_KEY] / + one_class_data.loc[ + one_class_data[discriminative_group_key] == class_name, DiscriminativeMotifs.PATIENT_PERCENTAGE_KEY] = ( + one_class_data.loc[one_class_data[ + discriminative_group_key] == class_name, DiscriminativeMotifs.PATIENT_COUNT_KEY] / unique_classes[class_name]) - return one_class_data[(one_class_data[DiscriminativeMotifs.PATIENT_COUNT_KEY] >= min_patients)], one_class_num_motifs + return one_class_data[ + (one_class_data[DiscriminativeMotifs.PATIENT_COUNT_KEY] >= min_patients)], one_class_num_motifs @staticmethod def _sort_features_single(data: pd.DataFrame, motif_ids, labels: list) -> list: @@ -784,10 +796,10 @@ def _sort_features_single(data: pd.DataFrame, motif_ids, labels: list) -> list: group_b_size = group_b.Patient.nunique() results.append(pd.DataFrame([{'ID': motif_id, - 'wasserstein_distance': wd_score, - 'group_a_size': group_a_size, - 'group_b_size': group_b_size, - 'group_size_max': max(group_a_size, group_b_size)}])) + 'wasserstein_distance': wd_score, + 'group_a_size': group_a_size, + 'group_b_size': group_b_size, + 'group_size_max': max(group_a_size, group_b_size)}])) return results @@ -796,7 +808,8 @@ def _sort_features(data: pd.DataFrame, labels: list, n_jobs: int = 16) -> pd.Dat wd_results = [] wd_results = Parallel(n_jobs=n_jobs, prefer='processes', return_as="generator")( delayed(TissueStateDiscriminativeMotifs._sort_features_single)( - data=data[data.ID.isin(motif_ids)][['ID', 'Freq', DiscriminativeMotifs.PATIENT_CLASS, 'Patient']].copy(), + data=data[data.ID.isin(motif_ids)][ + ['ID', 'Freq', DiscriminativeMotifs.PATIENT_CLASS, 'Patient']].copy(), motif_ids=motif_ids, labels=labels) for motif_ids in np.array_split(data.ID.unique(), n_jobs)) @@ -818,7 +831,8 @@ def _extract_features(one_class_data: pd.DataFrame, feature_conf: FeatureConfigu if isinstance(feature_conf, SoftDiscriminativeFC): sort_by = 'wasserstein_distance' - wd_results = TissueStateDiscriminativeMotifs._sort_features(data=one_class_data_filter, labels=feature_conf.labels) + wd_results = TissueStateDiscriminativeMotifs._sort_features(data=one_class_data_filter, + labels=feature_conf.labels) one_class_data_filter = pd.merge(one_class_data_filter, wd_results, on='ID') one_class_data_filter = one_class_data_filter[one_class_data_filter['wasserstein_distance'] >= 0.05] one_class_data_filter = one_class_data_filter.sort_values(by=sort_by, ascending=False) @@ -880,11 +894,11 @@ def _extract_features(one_class_data: pd.DataFrame, feature_conf: FeatureConfigu return unique_motifs_colors, unique_motifs, c_features def _get_features(self, - x_data: pd.DataFrame, - x_test: pd.DataFrame, - test_patient_uid: str, - feature_conf: FeatureConfiguration, - patient_class_df: pd.DataFrame = None): + x_data: pd.DataFrame, + x_test: pd.DataFrame, + test_patient_uid: str, + feature_conf: FeatureConfiguration, + patient_class_df: pd.DataFrame = None): local_patient_class = self.patient_class_df.copy() if patient_class_df is None else patient_class_df.copy() @@ -898,31 +912,32 @@ def _get_features(self, .astype({DiscriminativeMotifs.PATIENT_CLASS: 'category'})) if isinstance(feature_conf, TopNFC): - unique_motifs = x_data.groupby('ID', observed=True)['Freq'].mean().reset_index().sort_values('Freq', ascending=False).head( + unique_motifs = x_data.groupby('ID', observed=True)['Freq'].mean().reset_index().sort_values('Freq', + ascending=False).head( feature_conf.top_n).ID.tolist() unique_motifs_colors = [] c_features = len(unique_motifs_colors) + len(unique_motifs) elif isinstance(feature_conf, HardDiscriminativeFC): one_class_data, _ = self._extract_discriminative( - x_data=x_data, - discriminative_group_key=DiscriminativeMotifs.PATIENT_CLASS, - discriminative_feature_key=feature_conf.extract_by.value, - common_cell_type=list(self.common_cells.keys()), - min_nunique_colors=1, - min_patients=1, - p_value=0.05) + x_data=x_data, + discriminative_group_key=DiscriminativeMotifs.PATIENT_CLASS, + discriminative_feature_key=feature_conf.extract_by.value, + common_cell_type=list(self.common_cells.keys()), + min_nunique_colors=1, + min_patients=1, + p_value=0.05) unique_motifs_colors, unique_motifs, c_features = self._extract_features(one_class_data=one_class_data, feature_conf=feature_conf) elif isinstance(feature_conf, SoftDiscriminativeFC): unique_classes = x_data.groupby(DiscriminativeMotifs.PATIENT_CLASS, observed=True).Patient_uId.nunique() count_by_id = (x_data.groupby(['ID', DiscriminativeMotifs.PATIENT_CLASS], observed=True)['Patient_uId'] - .agg('nunique').reset_index()) + .agg('nunique').reset_index()) count_by_id = pd.merge(count_by_id, unique_classes.reset_index().rename({'Patient_uId': 'group_number'}, axis=1), left_on='patient_class', right_on='patient_class') - count_by_id['Patient_uId'] = count_by_id['Patient_uId']/count_by_id['group_number'] + count_by_id['Patient_uId'] = count_by_id['Patient_uId'] / count_by_id['group_number'] count_by_id = (count_by_id.groupby('ID', observed=True)['Patient_uId'].agg('max') .reset_index().rename({'Patient_uId': DiscriminativeMotifs.PATIENT_PERCENTAGE_KEY}, axis=1)) @@ -955,21 +970,23 @@ def _validate(self, patient_class_dict = local_patient_class[DiscriminativeMotifs.PATIENT_CLASS].to_dict() patient_class_df = pd.DataFrame.from_dict(patient_class_dict, - orient='index', - columns=[DiscriminativeMotifs.PATIENT_CLASS]) + orient='index', + columns=[DiscriminativeMotifs.PATIENT_CLASS]) x_data = (pd.merge(x_data.drop([DiscriminativeMotifs.PATIENT_CLASS], axis=1, errors='ignore'), - patient_class_df, - left_on='Patient_uId', right_index=True) - .astype({DiscriminativeMotifs.PATIENT_CLASS:'category'})) + patient_class_df, + left_on='Patient_uId', right_index=True) + .astype({DiscriminativeMotifs.PATIENT_CLASS: 'category'})) - if (feature_conf.cell_type_composition_patient_map is not None) and (feature_conf.motifs_patient_map is not None): + if (feature_conf.cell_type_composition_patient_map is not None) and ( + feature_conf.motifs_patient_map is not None): unique_motifs_colors = feature_conf.cell_type_composition_patient_map[test_patient_uid] unique_motifs = feature_conf.motifs_patient_map[test_patient_uid] c_features = len(unique_motifs_colors) + len(unique_motifs) elif isinstance(feature_conf, TopNFC): - unique_motifs = x_data.groupby('ID', observed=True)['Freq'].mean().reset_index().sort_values('Freq', ascending=False).head( + unique_motifs = x_data.groupby('ID', observed=True)['Freq'].mean().reset_index().sort_values('Freq', + ascending=False).head( feature_conf.top_n).ID.tolist() unique_motifs_colors = [] c_features = len(unique_motifs_colors) + len(unique_motifs) @@ -1032,9 +1049,11 @@ def _validate(self, one_class_data_index = one_class_data.set_index(['Patient_uId', 'ID']) one_class_data_index.sort_index(inplace=True) for patient_class in patient_classes: - for patient_uId in one_class_data[one_class_data[DiscriminativeMotifs.PATIENT_CLASS] == patient_class]['Patient_uId'].unique(): + for patient_uId in one_class_data[one_class_data[DiscriminativeMotifs.PATIENT_CLASS] == patient_class][ + 'Patient_uId'].unique(): vector_dict = defaultdict() - self._add_cell_type_composition_freq_feature(one_class_data_color_index, unique_motifs_colors, vector_dict, patient_uId) + self._add_cell_type_composition_freq_feature(one_class_data_color_index, unique_motifs_colors, + vector_dict, patient_uId) self._add_motif_freq_feature(one_class_data_index, unique_motifs, vector_dict, @@ -1057,7 +1076,8 @@ def _validate(self, vector_dict = defaultdict() x_test_color_index = x_test.set_index(['Patient_uId', 'colors_vec_hash']) x_test_color_index.sort_index(inplace=True) - self._add_cell_type_composition_freq_feature(x_test_color_index, unique_motifs_colors, vector_dict, test_patient_uid) + self._add_cell_type_composition_freq_feature(x_test_color_index, unique_motifs_colors, vector_dict, + test_patient_uid) x_test_index = x_test.set_index(['Patient_uId', 'ID']) x_test_index.sort_index(inplace=True) self._add_motif_freq_feature(x_test_index, @@ -1122,8 +1142,8 @@ def _add_cell_type_composition_freq_feature(one_class_data: pd.DataFrame, # here we calculate the probability of getting the cell identity composition total_count = one_class_data.loc[(patient_uId, motif_color_id)].Count.sum() total_sub_graphs = (one_class_data.loc[(patient_uId, motif_color_id)].drop_duplicates('FOV') - .apply(lambda row: row['Count']/row['Freq'], axis=1).sum()) - feature_set[motif_color_id] = total_count/total_sub_graphs + .apply(lambda row: row['Count'] / row['Freq'], axis=1).sum()) + feature_set[motif_color_id] = total_count / total_sub_graphs else: feature_set[motif_color_id] = 0 @@ -1158,8 +1178,10 @@ def _add_motif_freq_feature(one_class_data_index: pd.DataFrame, patient_uId: str, fuzzy_match_map: dict): for motif_id in unique_motifs: - if (len(fuzzy_match_map) > 0) and (patient_uId, one_class_data_index.index.get_level_values('ID').isin(fuzzy_match_map[motif_id])) in one_class_data_index.index: - filter_data = one_class_data_index.loc[[(patient_uId, one_class_data_index.index.get_level_values('ID').isin(fuzzy_match_map[motif_id]))]] + if (len(fuzzy_match_map) > 0) and (patient_uId, one_class_data_index.index.get_level_values('ID').isin( + fuzzy_match_map[motif_id])) in one_class_data_index.index: + filter_data = one_class_data_index.loc[ + [(patient_uId, one_class_data_index.index.get_level_values('ID').isin(fuzzy_match_map[motif_id]))]] elif (len(fuzzy_match_map) == 0) and (patient_uId, motif_id) in one_class_data_index.index: filter_data = one_class_data_index.loc[[(patient_uId, motif_id)]] else: @@ -1170,4 +1192,3 @@ def _add_motif_freq_feature(one_class_data_index: pd.DataFrame, total_sub_graphs = ( filter_data.drop_duplicates('FOV').apply(lambda row: row['Count'] / row['Freq'], axis=1).sum()) feature_set[motif_id] = total_count / total_sub_graphs - From 4874c832165360fcc8af0955d614c56c99718e8e Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 12:25:21 +0300 Subject: [PATCH 09/10] added another missing package and rearranged order to be alphabetical --- requirements.txt | 58 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3019b54..b268310 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +alphashape==1.3.1 anyio==3.6.2 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 @@ -7,7 +8,12 @@ attrs==22.2.0 backcall==0.2.0 beautifulsoup4==4.11.2 bleach==6.0.0 +certifi==2025.1.31 cffi==1.15.1 +charset-normalizer==3.4.1 +click==8.1.8 +click-log==0.4.0 +cloudpickle==3.1.1 colorama==0.4.6 comm==0.1.2 contourpy==1.1.1 @@ -16,10 +22,15 @@ debugpy==1.6.6 decorator==5.1.1 defusedxml==0.7.1 dill==0.3.6 +dotmotif==0.15.0 executing==1.2.0 +fastcluster==1.2.6 fastjsonschema==2.16.2 +filelock==3.17.0 fonttools==4.38.0 fqdn==1.5.1 +fsspec==2025.2.0 +grandiso==2.2.0 idna==3.4 importlib-metadata==6.8.0 importlib-resources==5.12.0 @@ -34,19 +45,23 @@ joblib==1.4.2 jsonpointer==2.3 jsonschema==4.17.3 jupyter==1.0.0 -jupyter-client==8.0.2 jupyter-console==6.5.0 -jupyter-core==5.2.0 jupyter-events==0.6.3 -jupyter-server==2.2.1 -jupyter-server-terminals==0.4.4 +jupyter_client==8.0.2 +jupyter_core==5.2.0 +jupyter_server==2.2.1 +jupyter_server_terminals==0.4.4 jupyterlab-pygments==0.2.2 jupyterlab-widgets==3.0.5 kiwisolver==1.4.4 +lark-parser==0.12.0 +llvmlite==0.44.0 MarkupSafe==2.1.3 matplotlib==3.7.0 matplotlib-inline==0.1.6 mistune==2.0.5 +mpmath==1.3.0 +narwhals==1.33.0 nbclassic==0.5.1 nbclient==0.7.2 nbconvert==7.2.9 @@ -54,8 +69,10 @@ nbformat==5.7.3 nest-asyncio==1.5.6 networkx==3.1 notebook==6.5.2 -notebook-shim==0.2.2 +notebook_shim==0.2.2 +numba==0.61.0 numpy==1.25.2 +opencv-python==4.8.0.76 packaging==23.2 pandas==2.2.2 pandocfilters==1.5.0 @@ -63,10 +80,12 @@ parso==0.8.3 pickleshare==0.7.5 Pillow==10.1.0 platformdirs==3.0.0 +plotly==6.0.1 prometheus-client==0.16.0 prompt-toolkit==3.0.36 psutil==5.9.6 pure-eval==0.2.2 +pyarrow==17.0.0 pycparser==2.21 Pygments==2.14.0 pyparsing==3.0.9 @@ -80,36 +99,39 @@ PyYAML==6.0.1 pyzmq==25.0.0 qtconsole==5.4.0 QtPy==2.3.0 +requests==2.32.3 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 +Rtree==1.3.0 +scikit-learn==1.6.1 +scipy==1.15.2 seaborn==0.13.0 Send2Trash==1.8.0 +shap==0.44.0 +shapely==2.0.7 six==1.16.0 +slicer==0.0.7 sniffio==1.3.0 soupsieve==2.3.2.post1 stack-data==0.6.2 +sympy==1.13.1 terminado==0.17.1 +threadpoolctl==3.5.0 tinycss2==1.2.1 +torch==2.1.2 +torch_geometric==2.4.0 +torchvision==0.16.2 tornado==6.3.3 tqdm==4.64.1 traitlets==5.9.0 +trimesh==4.6.2 +typing_extensions==4.12.2 +tzdata==2025.1 uri-template==1.2.0 +urllib3==2.3.0 wcwidth==0.2.6 webcolors==1.12 webencodings==0.5.1 websocket-client==1.5.1 widgetsnbextension==4.0.5 zipp==3.17.0 -opencv-python==4.8.0.76 -torch==2.1.2 -torchvision==0.16.2 -torch-geometric==2.4.0 -alphashape==1.3.1 -shap==0.44.0 -pyarrow==17.0.0 -torch_geometric~=2.4.0 -scikit-learn~=1.6.1 -scipy~=1.15.2 -shapely~=2.0.7 -dotmotif~=0.15.0 -plotly~=6.0.1 \ No newline at end of file From 9f5d070a55abff6651dc23ba705360e39d125eea Mon Sep 17 00:00:00 2001 From: Barak Milshtein Date: Wed, 2 Apr 2025 14:21:50 +0300 Subject: [PATCH 10/10] removed another unused import --- cism/evaluate_aux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cism/evaluate_aux.py b/cism/evaluate_aux.py index acdb880..6c2d273 100644 --- a/cism/evaluate_aux.py +++ b/cism/evaluate_aux.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt from sklearn.metrics import precision_recall_curve, auc -from sklearn.metrics import roc_curve, roc_auc_score +from sklearn.metrics import roc_curve def get_metrics(df: pd.DataFrame):