Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 265 additions & 0 deletions evalv_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
from scipy.spatial.distance import cdist
import numpy as np
import os
from scipy.ndimage import center_of_mass
import sys
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from matplotlib_scalebar.scalebar import ScaleBar
import tifffile
import csv

def get_dist_threshold(gt_cells, scales):
scaled_gt_cells = gt_cells.copy()
scaled_gt_cells = scaled_gt_cells * scales
# print(scales, scaled_gt_cells)

dist = cdist(scaled_gt_cells, scaled_gt_cells, metric='euclidean')
np.fill_diagonal(dist, np.inf)

min_dist = np.min(dist, axis=1)
min_dist.sort()
dist_threshold = np.mean(min_dist[:5]) / 2

return dist_threshold

def evaluate(predicted_cell, gt_cell, threshold):
"""
predicted_cell: n x 3
gt_cell: m x 3

If there is no ground-truth cell within a valid distance, the cell prediction is counted as an FP
If there are one or more ground-truth cells within a valid distance, the cell prediction is counted as a TP.
The remaining ground-truth cells that are not matched with any cell prediction are counted as FN.

return precision, recall, f1 score
"""
dist = cdist(predicted_cell, gt_cell, metric='euclidean')
n_pred, n_gt = dist.shape
assert(n_pred != 0 and n_gt != 0)
bool_mask = (dist <= threshold)
tp, fp = 0, 0

for i in range(len(predicted_cell)):
neighbors = bool_mask[i].nonzero()[0]

if len(neighbors) == 0:
fp += 1
else:
gt_idx = min(neighbors, key=lambda j: dist[i, j])
tp += 1
bool_mask[:, gt_idx] = False

precision = tp / (tp + fp + 1e-7)
recall = tp / (n_gt + 1e-7)
f1 = 2 * precision * recall / (precision + recall + 1e-7)

return precision, recall, f1


def mask_to_centroids(mask):
#mask = np.transpose(mask, (1, 2, 0))
print("here")
labels = np.unique(mask)
centroids = []

for label in labels:
if label == 0:
continue
centroids.append(center_of_mass(mask == label))

return np.array(centroids)

def visualize_labels(image, mask, predicted_cell, gt_cell, results_path):
fig, axs = plt.subplots(nrows=1, ncols=3)
axs[0].imshow(np.max(image, axis=2))
axs[0].scatter(gt_cell[:, 1], gt_cell[:, 0], s=2, alpha=0.5, c='white', marker='x')

axs[1].imshow(np.max(mask, axis=0))

# predicted_cell = predicted_cell[np.abs(predicted_cell[:, 2] - z) <= 0.5]
# gt_cell = gt_cell[gt_cell[:, 2] == z]

axs[2].imshow(np.max(image, axis=2))
axs[2].scatter(predicted_cell[:, 1], predicted_cell[:, 0], s=2, alpha=0.5, c='white')
plt.savefig(results_path, dpi=300)
plt.close()


def visualize_labels_figure(image, mask, predicted_cell, gt_cell, results_path, threshold=6):
fig, axs = plt.subplots(nrows=1, ncols=3)
axs[0].imshow(np.max(image, axis=2))
axs[0].scatter(gt_cell[:, 1], gt_cell[:, 0], s=2, alpha=0.5, c='white', marker='x')

axs[1].imshow(np.max(mask, axis=0))

# predicted_cell = predicted_cell[np.abs(predicted_cell[:, 2] - z) <= 0.5]
# gt_cell = gt_cell[gt_cell[:, 2] == z]

axs[2].imshow(np.max(image, axis=2))

dist = cdist(predicted_cell, gt_cell, metric='euclidean')
n_pred, n_gt = dist.shape

bool_mask = (dist <= threshold)
tp, fp, fn = [], [], []
matched_gt_indices = set()

for i in range(len(predicted_cell)):
neighbors = bool_mask[i].nonzero()[0]

if len(neighbors) == 0:
fp.append(predicted_cell[i])
else:
gt_idx = min(neighbors, key=lambda j: dist[i, j])
tp.append(predicted_cell[i])
matched_gt_indices.add(gt_idx)
bool_mask[:, gt_idx] = False

for j in range(len(gt_cell)):
if j not in matched_gt_indices:
fn.append(gt_cell[j])

tp = np.array(tp)
fp = np.array(fp)
fn = np.array(fn)

axs[2].scatter(tp[:, 1], tp[:, 0], s=3, alpha=0.5, c='white')
axs[2].scatter(fp[:, 1], fp[:, 0], s=3, alpha=0.5, c='red')
axs[2].scatter(fn[:, 1], fn[:, 0], s=3, alpha=0.5, c='yellow')
axs[0].axis('off')
axs[1].axis('off')
axs[2].axis('off')

add_manual_scale_bar(axs[2], bar_length_px=740, bar_height_px=10, bar_color='white', bar_text='20 µm', bar_text_color='white')

plt.savefig(results_path, dpi=300)
plt.close()
exit()

def add_manual_scale_bar(ax, bar_length_px, bar_height_px, bar_color, bar_text, bar_text_color, location='lower right', padding=10):
# Get the dimensions of the image
y_lim = ax.get_ylim()
x_lim = ax.get_xlim()

# Calculate the position of the scale bar
if location == 'lower right':
x_start = x_lim[1] - bar_length_px - padding
y_start = y_lim[0] + padding
elif location == 'lower left':
x_start = x_lim[0] + padding
y_start = y_lim[0] + padding
elif location == 'upper right':
x_start = x_lim[1] - bar_length_px - padding
y_start = y_lim[1] - bar_height_px - padding
elif location == 'upper left':
x_start = x_lim[0] + padding
y_start = y_lim[1] - bar_height_px - padding
else:
raise ValueError("Invalid location argument. Choose from 'lower right', 'lower left', 'upper right', 'upper left'.")

# Draw the scale bar
rect = plt.Rectangle((x_start, y_start), bar_length_px, bar_height_px, linewidth=1, edgecolor=bar_color, facecolor=bar_color)
ax.add_patch(rect)

# Add the text label
ax.text(x_start + bar_length_px / 2, y_start - bar_height_px - padding, bar_text, color=bar_text_color,
ha='center', va='top', fontsize=10)

def get_color(RGB, pred_labels):
pred_labels = pred_labels.astype(int)
colors = np.zeros((len(pred_labels), 3))
for i in range(len(pred_labels)):
colors[i] = RGB[pred_labels[i][0], pred_labels[i][1], pred_labels[i][2]]
return colors

if __name__ == "__main__":
gt_folder = []
predicted_folder = []

sessions = ['fold_0' , 'fold_1', 'fold_2', 'fold_3', 'fold_4' ]
##Get the predicted images from fold_n folder



##Use image name to locate gt center filer
##Use original center images
##Use mask_to_centroids function to convert predicted heatmaps to centers
metrics = []
for session in sessions:
gt_folder_path = "/Users/bhavikagopalani/Downloads/Boston/evaluate/nej/nejatbakhsh20/tiff/000541" #center file
results_folder_path = f"/Users/bhavikagopalani/Downloads/Boston/evaluate/nej/predictions/{session}/000541" #heatmaps


pred = []

for file in os.listdir(results_folder_path):
if file.endswith('_im_points.npy'):
base_name = file.replace('_im_points.npy', '')
target_name = f"{base_name}_heatmaps_points.npy"
target_path = os.path.join(gt_folder_path, target_name)


# pred_heatmap = tifffile.imread(f'{results_folder_path}/{file}')
# #gt_heatmap = tifffile.imread(target_path)
# #print("pred_heatmap", pred_heatmap.shape)
# pred_labels = mask_to_centroids(pred_heatmap)
# #gt_labels = mask_to_centroids(gt_heatmap)
# #print("pred_labels after transform", pred_labels.shape)
# #print("gt_labels", gt_labels.shape)

# x_idx, y_idx, z_idx = np.where(gt_labels != 0)
# gt_labels = np.column_stack((x_idx, y_idx, z_idx))
# #print("gt_labels after transform", gt_labels.shape)
# # print(gt_labels[:5])
# # print(pred_labels[:5])
pred_centers = np.load(f'{results_folder_path}/{file}')
gt_centers = np.load(target_path)
total_gt = gt_centers.shape[0]
total_pred = pred_centers.shape[0]

# print(pred_centers.shape)
# print(gt_centers.shape)
thres = get_dist_threshold(gt_centers, 1)
print("threshold = ", thres)
scale = [0.27, 0.27, 1.5] #EY
#scale = [0.21, 0.21, 0.75] #NP
precision, recall, f1 = evaluate(pred_centers*scale, gt_centers*scale, threshold= 3)
print(precision, recall, f1)
metrics.append((base_name, precision, recall, f1, total_gt, total_pred))

if metrics:
# Extract data for plotting
image_names = [m[0] for m in metrics]
precisions = [m[1] for m in metrics]
recalls = [m[2] for m in metrics]
f1s = [m[3] for m in metrics]

# Create an index for the x-axis
x_vals = range(len(metrics))

# --- Plot all three metrics on one figure ---
plt.figure()
plt.plot(x_vals, precisions, marker='o', label='Precision')
plt.plot(x_vals, recalls, marker='o', label='Recall')
plt.plot(x_vals, f1s, marker='o', label='F1 Score')

plt.xlabel('Image Index')
plt.ylabel('Metric Value')
plt.title(f'Metrics per Image')
plt.legend()
plt.show()

csv_filename = f"/Users/bhavikagopalani/Downloads/Boston/evaluate/nej/metrics_nj_3.csv"
csv_path = os.path.join(results_folder_path, csv_filename)

with open(csv_path, mode='w', newline='') as csvfile:
writer = csv.writer(csvfile)
# Write a header row
writer.writerow(["Image Name", "Precision", "Recall", "F1 Score", "Total GT cells", "Total Predicted Cells"])
# Write each row of metrics
for row in metrics:
writer.writerow(row)