diff --git a/evalv_.py b/evalv_.py new file mode 100644 index 0000000..c19e114 --- /dev/null +++ b/evalv_.py @@ -0,0 +1,265 @@ +from scipy.spatial.distance import cdist +import numpy as np +import os +from scipy.ndimage import center_of_mass +import sys +import matplotlib.pyplot as plt +import pandas as pd +import cv2 +from matplotlib_scalebar.scalebar import ScaleBar +import tifffile +import csv + +def get_dist_threshold(gt_cells, scales): + scaled_gt_cells = gt_cells.copy() + scaled_gt_cells = scaled_gt_cells * scales + # print(scales, scaled_gt_cells) + + dist = cdist(scaled_gt_cells, scaled_gt_cells, metric='euclidean') + np.fill_diagonal(dist, np.inf) + + min_dist = np.min(dist, axis=1) + min_dist.sort() + dist_threshold = np.mean(min_dist[:5]) / 2 + + return dist_threshold + +def evaluate(predicted_cell, gt_cell, threshold): + """ + predicted_cell: n x 3 + gt_cell: m x 3 + + If there is no ground-truth cell within a valid distance, the cell prediction is counted as an FP + If there are one or more ground-truth cells within a valid distance, the cell prediction is counted as a TP. + The remaining ground-truth cells that are not matched with any cell prediction are counted as FN. + + return precision, recall, f1 score + """ + dist = cdist(predicted_cell, gt_cell, metric='euclidean') + n_pred, n_gt = dist.shape + assert(n_pred != 0 and n_gt != 0) + bool_mask = (dist <= threshold) + tp, fp = 0, 0 + + for i in range(len(predicted_cell)): + neighbors = bool_mask[i].nonzero()[0] + + if len(neighbors) == 0: + fp += 1 + else: + gt_idx = min(neighbors, key=lambda j: dist[i, j]) + tp += 1 + bool_mask[:, gt_idx] = False + + precision = tp / (tp + fp + 1e-7) + recall = tp / (n_gt + 1e-7) + f1 = 2 * precision * recall / (precision + recall + 1e-7) + + return precision, recall, f1 + + +def mask_to_centroids(mask): + #mask = np.transpose(mask, (1, 2, 0)) + print("here") + labels = np.unique(mask) + centroids = [] + + for label in labels: + if label == 0: + continue + centroids.append(center_of_mass(mask == label)) + + return np.array(centroids) + +def visualize_labels(image, mask, predicted_cell, gt_cell, results_path): + fig, axs = plt.subplots(nrows=1, ncols=3) + axs[0].imshow(np.max(image, axis=2)) + axs[0].scatter(gt_cell[:, 1], gt_cell[:, 0], s=2, alpha=0.5, c='white', marker='x') + + axs[1].imshow(np.max(mask, axis=0)) + + # predicted_cell = predicted_cell[np.abs(predicted_cell[:, 2] - z) <= 0.5] + # gt_cell = gt_cell[gt_cell[:, 2] == z] + + axs[2].imshow(np.max(image, axis=2)) + axs[2].scatter(predicted_cell[:, 1], predicted_cell[:, 0], s=2, alpha=0.5, c='white') + plt.savefig(results_path, dpi=300) + plt.close() + + +def visualize_labels_figure(image, mask, predicted_cell, gt_cell, results_path, threshold=6): + fig, axs = plt.subplots(nrows=1, ncols=3) + axs[0].imshow(np.max(image, axis=2)) + axs[0].scatter(gt_cell[:, 1], gt_cell[:, 0], s=2, alpha=0.5, c='white', marker='x') + + axs[1].imshow(np.max(mask, axis=0)) + + # predicted_cell = predicted_cell[np.abs(predicted_cell[:, 2] - z) <= 0.5] + # gt_cell = gt_cell[gt_cell[:, 2] == z] + + axs[2].imshow(np.max(image, axis=2)) + + dist = cdist(predicted_cell, gt_cell, metric='euclidean') + n_pred, n_gt = dist.shape + + bool_mask = (dist <= threshold) + tp, fp, fn = [], [], [] + matched_gt_indices = set() + + for i in range(len(predicted_cell)): + neighbors = bool_mask[i].nonzero()[0] + + if len(neighbors) == 0: + fp.append(predicted_cell[i]) + else: + gt_idx = min(neighbors, key=lambda j: dist[i, j]) + tp.append(predicted_cell[i]) + matched_gt_indices.add(gt_idx) + bool_mask[:, gt_idx] = False + + for j in range(len(gt_cell)): + if j not in matched_gt_indices: + fn.append(gt_cell[j]) + + tp = np.array(tp) + fp = np.array(fp) + fn = np.array(fn) + + axs[2].scatter(tp[:, 1], tp[:, 0], s=3, alpha=0.5, c='white') + axs[2].scatter(fp[:, 1], fp[:, 0], s=3, alpha=0.5, c='red') + axs[2].scatter(fn[:, 1], fn[:, 0], s=3, alpha=0.5, c='yellow') + axs[0].axis('off') + axs[1].axis('off') + axs[2].axis('off') + + add_manual_scale_bar(axs[2], bar_length_px=740, bar_height_px=10, bar_color='white', bar_text='20 µm', bar_text_color='white') + + plt.savefig(results_path, dpi=300) + plt.close() + exit() + +def add_manual_scale_bar(ax, bar_length_px, bar_height_px, bar_color, bar_text, bar_text_color, location='lower right', padding=10): + # Get the dimensions of the image + y_lim = ax.get_ylim() + x_lim = ax.get_xlim() + + # Calculate the position of the scale bar + if location == 'lower right': + x_start = x_lim[1] - bar_length_px - padding + y_start = y_lim[0] + padding + elif location == 'lower left': + x_start = x_lim[0] + padding + y_start = y_lim[0] + padding + elif location == 'upper right': + x_start = x_lim[1] - bar_length_px - padding + y_start = y_lim[1] - bar_height_px - padding + elif location == 'upper left': + x_start = x_lim[0] + padding + y_start = y_lim[1] - bar_height_px - padding + else: + raise ValueError("Invalid location argument. Choose from 'lower right', 'lower left', 'upper right', 'upper left'.") + + # Draw the scale bar + rect = plt.Rectangle((x_start, y_start), bar_length_px, bar_height_px, linewidth=1, edgecolor=bar_color, facecolor=bar_color) + ax.add_patch(rect) + + # Add the text label + ax.text(x_start + bar_length_px / 2, y_start - bar_height_px - padding, bar_text, color=bar_text_color, + ha='center', va='top', fontsize=10) + +def get_color(RGB, pred_labels): + pred_labels = pred_labels.astype(int) + colors = np.zeros((len(pred_labels), 3)) + for i in range(len(pred_labels)): + colors[i] = RGB[pred_labels[i][0], pred_labels[i][1], pred_labels[i][2]] + return colors + +if __name__ == "__main__": + gt_folder = [] + predicted_folder = [] + + sessions = ['fold_0' , 'fold_1', 'fold_2', 'fold_3', 'fold_4' ] + ##Get the predicted images from fold_n folder + + + + ##Use image name to locate gt center filer + ##Use original center images + ##Use mask_to_centroids function to convert predicted heatmaps to centers + metrics = [] + for session in sessions: + gt_folder_path = "/Users/bhavikagopalani/Downloads/Boston/evaluate/nej/nejatbakhsh20/tiff/000541" #center file + results_folder_path = f"/Users/bhavikagopalani/Downloads/Boston/evaluate/nej/predictions/{session}/000541" #heatmaps + + + pred = [] + + for file in os.listdir(results_folder_path): + if file.endswith('_im_points.npy'): + base_name = file.replace('_im_points.npy', '') + target_name = f"{base_name}_heatmaps_points.npy" + target_path = os.path.join(gt_folder_path, target_name) + + + # pred_heatmap = tifffile.imread(f'{results_folder_path}/{file}') + # #gt_heatmap = tifffile.imread(target_path) + # #print("pred_heatmap", pred_heatmap.shape) + # pred_labels = mask_to_centroids(pred_heatmap) + # #gt_labels = mask_to_centroids(gt_heatmap) + # #print("pred_labels after transform", pred_labels.shape) + # #print("gt_labels", gt_labels.shape) + + # x_idx, y_idx, z_idx = np.where(gt_labels != 0) + # gt_labels = np.column_stack((x_idx, y_idx, z_idx)) + # #print("gt_labels after transform", gt_labels.shape) + # # print(gt_labels[:5]) + # # print(pred_labels[:5]) + pred_centers = np.load(f'{results_folder_path}/{file}') + gt_centers = np.load(target_path) + total_gt = gt_centers.shape[0] + total_pred = pred_centers.shape[0] + + # print(pred_centers.shape) + # print(gt_centers.shape) + thres = get_dist_threshold(gt_centers, 1) + print("threshold = ", thres) + scale = [0.27, 0.27, 1.5] #EY + #scale = [0.21, 0.21, 0.75] #NP + precision, recall, f1 = evaluate(pred_centers*scale, gt_centers*scale, threshold= 3) + print(precision, recall, f1) + metrics.append((base_name, precision, recall, f1, total_gt, total_pred)) + + if metrics: + # Extract data for plotting + image_names = [m[0] for m in metrics] + precisions = [m[1] for m in metrics] + recalls = [m[2] for m in metrics] + f1s = [m[3] for m in metrics] + + # Create an index for the x-axis + x_vals = range(len(metrics)) + + # --- Plot all three metrics on one figure --- + plt.figure() + plt.plot(x_vals, precisions, marker='o', label='Precision') + plt.plot(x_vals, recalls, marker='o', label='Recall') + plt.plot(x_vals, f1s, marker='o', label='F1 Score') + + plt.xlabel('Image Index') + plt.ylabel('Metric Value') + plt.title(f'Metrics per Image') + plt.legend() + plt.show() + + csv_filename = f"/Users/bhavikagopalani/Downloads/Boston/evaluate/nej/metrics_nj_3.csv" + csv_path = os.path.join(results_folder_path, csv_filename) + + with open(csv_path, mode='w', newline='') as csvfile: + writer = csv.writer(csvfile) + # Write a header row + writer.writerow(["Image Name", "Precision", "Recall", "F1 Score", "Total GT cells", "Total Predicted Cells"]) + # Write each row of metrics + for row in metrics: + writer.writerow(row) + + \ No newline at end of file