diff --git a/CHANGELOG.md b/CHANGELOG.md index 968e2c2..aff581f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,22 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## v0.1.1 - [2025.11.15] + +Efficiency update - only cells with overlaps are taken into consideration when calculating IoU matrices. Mask comparisons are not calculated through a direct comparison, but optimized through lookup tables. + +### `Added` +- `plot_target_size` parameter for plotting - inputs will be "downscaled" to match the target size for plotting efficiency + +### `Fixed` +- memory issues for large input masks with many cells +- Dockerfile copies the repo to pip install instead of pointing to main + +### `Removed` +- graph construction with networkx, replaced with functions +- cost matrix metric currently doesn't affect the execution - only IoU is used. To be updated soon. + ## v0.1.0 - [2025.11.12] First release of Segobe - a tool for object matching, segmentation error counting and metric evaluation. diff --git a/Dockerfile b/Dockerfile index aa8d661..f1dbacd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,3 +18,8 @@ USER root # Ensure micromamba binaries are in PATH ENV PATH="$MAMBA_ROOT_PREFIX/bin:$PATH" +# Copy the rest of the current directory into /app inside the container +WORKDIR /app +COPY . . + +RUN micromamba run -n base pip install . diff --git a/README.md b/README.md index cb77bbb..8875b7d 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Designed for cell segmentation evaluation, it can handle large batches of sample ## Installation -Option 1. Install directory from the repository (recommended for development) +### Option 1. Install directly from the repository (recommended for development) If you plan to develop or modify Segobe, install it in editable mode: ```bash # Clone the repository @@ -24,7 +24,7 @@ pip install -e . 
``` > The -e flag (editable mode) means the changes to the source code are immediately reflected without reinstalling. -Option 2. Install directly from GitHub +### Option 2. Install directly from GitHub Once the repository is made public, users can install it directly via URL: ```bash pip install git+https://github.com/schapirolabor/segobe.git @@ -50,7 +50,6 @@ segobe \ --iou_threshold 0.5 \ --graph_iou_threshold 0.1 \ --unmatched_cost 0.4 \ - --cost_matrix_metric 'iou' \ --save_plots ``` @@ -62,8 +61,9 @@ segobe \ | | --iou_threshold | IoU threshold for cell matching (0-1, default: 0.5). Match is true if pair is selected with linear_sum_assignment and IoU above this threshold. | | | --graph_iou_threshold | Graph IoU threshold for error detection (0-1, default: 0.1). Minimal IoU for cells to be considered 'connected'. | | | --unmatched_cost | Cost for unmatched objects in the cost matrix (0-1, default: 0.4) | -| | --cost_matrix_metric | Specify which metric should be used for cost matrix construction (default: 'iou', other options 'dice', 'moc' - see details [here](docs/detailed_overview.md)) | +| | --cost_matrix_metric | Specify which metric should be used for cost matrix construction (default: 'iou', other options 'dice', 'moc' - see details [here](docs/detailed_overview.md)) `note that only IoU is currently supported` | | | --save_plots | Boolean specifying whether plots (barplot grouped by category and row-specific error overview) are saved | +| | --plot_target_size | Size in pixels of the plot error types subfigures. If the inputs are larger, they will be approximately downsampled by a scale factor. If that scale factor is larger than 4, boundaries will not be drawn. (default: 600) | | | --version | Prints tool version. 
| ### Input format @@ -110,7 +110,6 @@ Captured metrics in the `metrics.csv` for each input CSV `row`: * Splits: counts and dictionary of matched predictions to GTs * Merges: counts and dictionary of matched GTs to predictions * Catastrophes: counts and dictionary of groups of GTs and predictions involved -* IoU graph: constructed graph of cell overlaps * True postives: counts * False positives: counts * False negatives: counts diff --git a/environment.yml b/environment.yml index b7237ac..607ff8e 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,6 @@ dependencies: - scipy - scikit-image - matplotlib - - networkx - tifffile - git - pip diff --git a/pyproject.toml b/pyproject.toml index bcf7968..78153e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ dependencies = [ "pandas", "matplotlib", "scipy", - "networkx", "scikit-image", "tifffile" ] diff --git a/segobe/CLI.py b/segobe/CLI.py index 9d1fa15..434d6c1 100644 --- a/segobe/CLI.py +++ b/segobe/CLI.py @@ -59,6 +59,12 @@ def get_args(): help="Metric used for cost matrix calculation (default: iou)", choices=["iou", "dice", "moc"], ) + parser.add_argument( + "--plot_target_size", + type=int, + default=600, + help="Target size to which large input images will be downsampled for error type plots.", + ) parser.add_argument( "--save_plots", action="store_true", @@ -84,8 +90,12 @@ def main(): # Read input CSV df = pd.read_csv(args.input_csv) + required_columns = {"sampleID", "ref_mask", "eval_mask", "category"} + if not required_columns.issubset(df.columns): + raise ValueError( + f"Input CSV must contain columns: {required_columns}, got {df.columns}" + ) - print(args.cost_matrix_metric) # Run batch evaluation batch_eval = SegmentationEvaluationBatch( df, @@ -136,6 +146,8 @@ def main(): f"{args.basename}_{row['sampleID']}_{row['category']}_error_types.png", ), suptitle=True, + legend=False, + target_size=args.plot_target_size ) diff --git a/segobe/__init__.py b/segobe/__init__.py index 
4f24f18..67ca226 100644 --- a/segobe/__init__.py +++ b/segobe/__init__.py @@ -2,7 +2,7 @@ from .plotter import plot_error_types, plot_barplot from .utils import filter_mask_by_ids -__version__ = "0.1.0" +__version__ = "0.1.1" __all__ = [ "SegmentationEvaluator", "SegmentationEvaluationBatch", diff --git a/segobe/evaluator.py b/segobe/evaluator.py index feb5698..da4f9a9 100644 --- a/segobe/evaluator.py +++ b/segobe/evaluator.py @@ -2,189 +2,324 @@ import numpy as np import pandas as pd import tifffile -import networkx as nx +from collections import Counter, defaultdict, deque from scipy.stats import hmean from scipy.optimize import linear_sum_assignment class SegmentationEvaluator: - """Evaluate segmentation between a reference (GT) and evaluation (target) mask.""" + """ + Memory-efficient cell matcher. + Handles merges, splits, catastrophes. + Ensures unmatched cells appear in IoU/Dice/MOC matrices. + """ def __init__( self, - gt_mask, - pred_mask, - iou_threshold=0.5, - graph_iou_threshold=0.1, - unmatched_cost=0.4, - cost_matrix_metric="iou", + gt_mask: np.ndarray, + pred_mask: np.ndarray, + iou_threshold: float = 0.5, + graph_iou_threshold: float = 0.1, + unmatched_cost: float = 0.4, ): + assert gt_mask.shape == pred_mask.shape self.gt = gt_mask self.pred = pred_mask - self.gt_ids = np.unique(self.gt)[1:] - self.pred_ids = np.unique(self.pred)[1:] - self.n_gt = len(self.gt_ids) - self.n_pred = len(self.pred_ids) - self.n = self.n_gt + self.n_pred - self.iou_threshold = iou_threshold self.graph_iou_threshold = graph_iou_threshold self.unmatched_cost = unmatched_cost - self.cost_matrix_metric = cost_matrix_metric - self.get_metric_matrix() - - def get_metric_matrix(self): - iou_matrix = np.zeros((len(self.gt_ids), len(self.pred_ids))) - dice_matrix = np.zeros((len(self.gt_ids), len(self.pred_ids))) - moc_matrix = np.zeros((len(self.gt_ids), len(self.pred_ids))) - iou_graph = nx.Graph() - for i, gt_id in enumerate(self.gt_ids): - gt_labmask = self.gt == gt_id - for 
j, pred_id in enumerate(self.pred_ids): - pred_labmask = self.pred == pred_id - intersection = np.logical_and(gt_labmask, pred_labmask).sum() - union = np.logical_or(gt_labmask, pred_labmask).sum() - gt_size = gt_labmask.sum() - pred_size = pred_labmask.sum() - iou = intersection / union if union > 0 else 0 - dice = ( - (2 * intersection) / (gt_size + pred_size) - if (gt_size + pred_size) > 0 - else 0 - ) - moc = ( - (intersection / gt_size + intersection / pred_size) / 2 - if gt_size > 0 and pred_size > 0 - else 0 - ) - iou_matrix[i, j] = iou - dice_matrix[i, j] = dice - moc_matrix[i, j] = moc - if iou > self.graph_iou_threshold: - iou_graph.add_edge(f"gt_{gt_id}", f"pred_{pred_id}", weight=1 - iou) - self.iou_matrix = iou_matrix - self.dice_matrix = dice_matrix - self.moc_matrix = moc_matrix - self.iou_graph = iou_graph - - def calculate_cost_matrices(self): - self.cost_matrix_iou = self.construct_cost_matrix(self.iou_matrix) - self.cost_matrix_dice = self.construct_cost_matrix(self.dice_matrix) - self.cost_matrix_moc = self.construct_cost_matrix(self.moc_matrix) - - def construct_cost_matrix(self, metric_matrix): - cost_matrix = np.ones((self.n, self.n)) - cost_matrix[: self.n_gt, : self.n_pred] = ( - 1 - metric_matrix - ) # A: Top-left block = real costs - cost_matrix[self.n_gt :, self.n_pred :] = ( - 1 - metric_matrix - ).T # D: Bottom-right = transpose of A - top_right = self.unmatched_cost * np.eye(self.n_gt) + (1 - np.eye(self.n_gt)) - cost_matrix[: self.n_gt, self.n_pred :] = ( - top_right # B: Top-right = unmatched GTs - ) - bottom_left = self.unmatched_cost * np.eye(self.n_pred) + ( - 1 - np.eye(self.n_pred) - ) - cost_matrix[self.n_gt :, : self.n_pred] = ( - bottom_left # C: Bottom-left = unmatched preds - ) - return cost_matrix - - def specify_cost_matrix(self): - if self.cost_matrix_metric == "iou": - return self.construct_cost_matrix(self.iou_matrix) - elif self.cost_matrix_metric == "dice": - return self.construct_cost_matrix(self.dice_matrix) - 
elif self.cost_matrix_metric == "moc": - return self.construct_cost_matrix(self.moc_matrix) + + # label sets + self.gt_ids = np.unique(self.gt) + self.pred_ids = np.unique(self.pred) + self.gt_ids = self.gt_ids[self.gt_ids != 0] + self.pred_ids = self.pred_ids[self.pred_ids != 0] + + # Compute cell sizes + gt_counts = np.bincount(self.gt.ravel()) + pred_counts = np.bincount(self.pred.ravel()) + self.gt_sizes = {int(i): int(gt_counts[i]) for i in self.gt_ids} + self.pred_sizes = {int(i): int(pred_counts[i]) for i in self.pred_ids} + + # Compute sparse intersections + self.intersections = self._compute_intersections(self.gt, self.pred) + + # Build pair_df (all intersecting object pairs) + rows = [] + for (g, p), inter in self.intersections.items(): + gsz = self.gt_sizes[g] + psz = self.pred_sizes[p] + union = gsz + psz - inter + iou = inter / union if union > 0 else 0.0 + dice = (2 * inter) / (gsz + psz) + moc = 0.5 * (inter / gsz + inter / psz) + + rows.append((g, p, inter, iou, dice, moc)) + + if rows: + self.pair_df = pd.DataFrame( + rows, columns=["gt", "pred", "intersection", "iou", "dice", "moc"] + ) else: - raise ValueError("Metric must be one of 'iou', 'dice', or 'moc'.") + self.pair_df = pd.DataFrame( + columns=["gt", "pred", "intersection", "iou", "dice", "moc"] + ) + + def _compute_intersections(self, gt: np.ndarray, pred: np.ndarray) -> Counter: + """Get intersections through encoding""" + max_pred = np.int64(pred.max() + 1) # int64 scalar + + # multiply only in int64 temporarily + joint = np.multiply(gt, max_pred, dtype=np.int64) # result is int64 + joint += pred # pred is int32, addition is safe + + vals, counts = np.unique(joint, return_counts=True) + gt_ids = vals // max_pred + pred_ids = vals % max_pred + + c = Counter() + for g, p, cnt in zip(gt_ids, pred_ids, counts): + if g != 0 and p != 0: + c[(int(g), int(p))] = int(cnt) + return c + + # + # Get adjacency graph + # TODO: can we filter out cells below graph threshold already? 
+ def _adjacency_graph(self): + adj_gt_to_pred = defaultdict(set) + adj_pred_to_gt = defaultdict(set) + for row in self.pair_df.itertuples(index=False): + if row.iou > 0: + adj_gt_to_pred[int(row.gt)].add(int(row.pred)) + adj_pred_to_gt[int(row.pred)].add(int(row.gt)) + return adj_gt_to_pred, adj_pred_to_gt + + # + # Get connected components (not using networkx) + # + def _connected_components_prematching(self, adj_gt2p, adj_p2gt): + visited_g = set() + visited_p = set() + + for g0 in adj_gt2p.keys(): + if g0 in visited_g: + continue + + comp_g = set() + comp_p = set() + q = deque([("g", g0)]) + visited_g.add(g0) + + while q: + t, x = q.popleft() + + if t == "g": + comp_g.add(x) + for p in adj_gt2p[x]: + if p not in visited_p: + visited_p.add(p) + q.append(("p", p)) + + else: # p → g + comp_p.add(x) + for g in adj_p2gt[x]: + if g not in visited_g: + visited_g.add(g) + q.append(("g", g)) + + yield comp_g, comp_p + + def _connected_components_postmatching(self, adj_gt_to_preds, adj_pred_to_gts): + """ + Return connected components of a bipartite graph: + - left nodes: GT labels + - right nodes: Pred labels + + adj_gt_to_preds[g] = set of pred IDs connected to g + adj_pred_to_gts[p] = set of gt IDs connected to p + + Returns a list of tuples: (set_of_gt_nodes, set_of_pred_nodes) + """ + visited_gt = set() + visited_pred = set() + components = [] + + # Iterate over GT nodes + for g in adj_gt_to_preds.keys(): + if g in visited_gt: + continue + + stack_gt = [g] + comp_gt = set() + comp_pred = set() + + while stack_gt: + cg = stack_gt.pop() + if cg in visited_gt: + continue + visited_gt.add(cg) + comp_gt.add(cg) + + # Visit preds connected to this gt + for p in adj_gt_to_preds.get(cg, []): + if p not in visited_pred: + # mark pred + visited_pred.add(p) + comp_pred.add(p) + # explore gt neighbors from that pred + for ng in adj_pred_to_gts.get(p, []): + if ng not in visited_gt: + stack_gt.append(ng) + + if comp_gt or comp_pred: + components.append((comp_gt, comp_pred)) + + 
# Also handle not yet reached preds + for p in adj_pred_to_gts.keys(): + if p in visited_pred: + continue + stack_pred = [p] + comp_gt = set() + comp_pred = set() + + while stack_pred: + cp = stack_pred.pop() + if cp in visited_pred: + continue + visited_pred.add(cp) + comp_pred.add(cp) + + for g in adj_pred_to_gts.get(cp, []): + if g not in visited_gt: + visited_gt.add(g) + comp_gt.add(g) + for np in adj_gt_to_preds.get(g, []): + if np not in visited_pred: + stack_pred.append(np) + + if comp_gt or comp_pred: + components.append((comp_gt, comp_pred)) + + return components + + # + # Evaluation + # def evaluate(self): + adj_gt2p, adj_p2gt = self._adjacency_graph() - # Construct cost matrix - cost_matrix = self.specify_cost_matrix() - - # Solve assignment - order_res = linear_sum_assignment(cost_matrix) - order_mat = np.zeros_like(cost_matrix) - order_mat[order_res] = 1 - - row_ind, col_ind = np.nonzero(order_mat[: self.n_gt, : self.n_pred]) - - matched_pairs, iou_list, dice_list = [], [], [] - matched_gt, matched_pred = set(), set() - - for i, j in zip(row_ind, col_ind): - if self.iou_matrix[i, j] >= self.iou_threshold: - gt_id, pred_id = self.gt_ids[i], self.pred_ids[j] - matched_pairs.append((gt_id, pred_id)) - matched_gt.add(gt_id) - matched_pred.add(pred_id) - - gt_mask = self.gt == gt_id - pred_mask = self.pred == pred_id - intersection = np.logical_and(gt_mask, pred_mask).sum() - dice = ( - (2 * intersection) / (gt_mask.sum() + pred_mask.sum()) - if (gt_mask.sum() + pred_mask.sum()) > 0 - else 0 - ) + if self.pair_df.empty: + return self._empty_result() - iou_list.append(self.iou_matrix[i, j]) - dice_list.append(dice) + # Lookup dicts + iou_lookup = { + (int(r.gt), int(r.pred)): float(r.iou) + for r in self.pair_df.itertuples(index=False) + } + dice_lookup = { + (int(r.gt), int(r.pred)): float(r.dice) + for r in self.pair_df.itertuples(index=False) + } + # moc_lookup = { (int(r.gt), int(r.pred)): float(r.moc) for r in self.pair_df.itertuples(index=False) } # 
not to be used as a metric, just for matching as it allows for smaller objects to match to larger ones - # Error classification - splits, merges, catastrophes = 0, 0, 0 - split_details, merge_details, catastrophe_details = [], [], [] + matched_pairs = [] + matched_gt = set() + matched_pred = set() + iou_list = [] + dice_list = [] - # Remove matched nodes - self.iou_graph.remove_nodes_from( - [ - n - for n in self.iou_graph.nodes - if (n.startswith("pred_") and int(n.split("_")[1]) in matched_pred) - or (n.startswith("gt_") and int(n.split("_")[1]) in matched_gt) - ] - ) + # Linear sum assignment per connected component - not across all possible matches + for comp_g, comp_p in self._connected_components_prematching( + adj_gt2p, adj_p2gt + ): + gl = sorted(comp_g) + pl = sorted(comp_p) + ng, npred = len(gl), len(pl) + + # build local IoU matrix + M = np.zeros((ng, npred), float) + D = np.zeros((ng, npred), float) + for i, g in enumerate(gl): + for j, p in enumerate(pl): + M[i, j] = iou_lookup.get((g, p), 0.0) + D[i, j] = dice_lookup.get((g, p), 0.0) + + nloc = ng + npred + C = np.ones((nloc, nloc), float) + C[:ng, :npred] = 1.0 - M + C[ng:, npred:] = (1.0 - M).T + C[:ng, npred:] = self.unmatched_cost * np.eye(ng) + (1.0 - np.eye(ng)) + C[ng:, :npred] = self.unmatched_cost * np.eye(npred) + (1.0 - np.eye(npred)) + + ridx, cidx = linear_sum_assignment(C) + + for r, c in zip(ridx, cidx): + if r < ng and c < npred: + g = gl[r] + p = pl[c] + iou_val = M[r, c] + if iou_val >= self.iou_threshold: + matched_pairs.append((g, p)) + matched_gt.add(g) + matched_pred.add(p) + iou_list.append(iou_val) + dice_list.append(D[r, c]) - for component in nx.connected_components(self.iou_graph): - gts = {n for n in component if n.startswith("gt_")} - preds = {n for n in component if n.startswith("pred_")} - ng, np_ = len(gts), len(preds) - gt_nodes = [int(g.split("_")[1]) for g in gts] - pred_nodes = [int(p.split("_")[1]) for p in preds] + # Get connected components for unmatched cells to 
construct segmentation error cases + adj_gt_to_preds_unmatched = defaultdict(set) + adj_pred_to_gts_unmatched = defaultdict(set) + for row in self.pair_df.itertuples(index=False): + if row.iou <= self.graph_iou_threshold: + continue + g, p = int(row.gt), int(row.pred) + if g in matched_gt or p in matched_pred: + continue + adj_gt_to_preds_unmatched[g].add(p) + adj_pred_to_gts_unmatched[p].add(g) - if ng == 1 and np_ > 1: + splits, merges, catastrophes = 0, 0, 0 + split_details, merge_details, catastrophe_details = [], [], [] + + for comp_gt, comp_pred in self._connected_components_postmatching( + adj_gt_to_preds_unmatched, adj_pred_to_gts_unmatched + ): + ngc, npc = len(comp_gt), len(comp_pred) + if ngc == 1 and npc > 1: splits += 1 - split_details.append({"gt": gt_nodes[0], "preds": sorted(pred_nodes)}) - elif ng > 1 and np_ == 1: + split_details.append( + {"gt": sorted(list(comp_gt))[0], "preds": sorted(list(comp_pred))} + ) + elif ngc > 1 and npc == 1: merges += 1 - merge_details.append({"pred": pred_nodes[0], "gts": sorted(gt_nodes)}) - elif ng > 1 and np_ > 1: + merge_details.append( + {"pred": sorted(list(comp_pred))[0], "gts": sorted(list(comp_gt))} + ) + elif ngc > 1 and npc > 1: catastrophes += 1 catastrophe_details.append( - {"gts": sorted(gt_nodes), "preds": sorted(pred_nodes)} + {"gts": sorted(list(comp_gt)), "preds": sorted(list(comp_pred))} ) - # Compute metrics + all_gt_set = set(map(int, self.gt_ids)) + all_pred_set = set(map(int, self.pred_ids)) + fp_preds = list(all_pred_set - matched_pred) + fn_gts = list(all_gt_set - matched_gt) + tp = len(matched_pairs) - fp = len(self.pred_ids) - len(matched_pred) - fn = len(self.gt_ids) - len(matched_gt) - precision = tp / (tp + fp) if tp + fp > 0 else 0 - recall = tp / (tp + fn) if tp + fn > 0 else 0 - f1 = hmean([precision, recall]) if precision > 0 and recall > 0 else 0 - - fp_preds = set(self.pred_ids) - set(matched_pred) - fn_gts = set(self.gt_ids) - set(matched_gt) - ttp_gts = set(matched_gt) - 
ttp_preds = set(matched_pred) + fp = len(fp_preds) + fn = len(fn_gts) + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + f1 = hmean([precision, recall]) if precision > 0 and recall > 0 else 0.0 + + # final metrics dict matching your original names self.metrics = { - "iou_mean": np.mean(iou_list) if iou_list else 0, + "iou_mean": float(np.mean(iou_list)) if iou_list else 0.0, "iou_list": iou_list, - "dice_mean": np.mean(dice_list) if dice_list else 0, + "dice_mean": float(np.mean(dice_list)) if dice_list else 0.0, "dice_list": dice_list, "precision": precision, "recall": recall, @@ -195,18 +330,56 @@ def evaluate(self): "split_details": split_details, "merge_details": merge_details, "catastrophe_details": catastrophe_details, - "iou_graph": self.iou_graph, "true_positives": tp, "false_positives": fp, "false_negatives": fn, - "n_gt_labels": self.n_gt, - "n_pred_labels": self.n_pred, - "FP_list": list(fp_preds), - "FN_list": list(fn_gts), - "TTP_gt": list(ttp_gts), - "TTP_preds": list(ttp_preds), - "total_cells": len(self.gt_ids), - "total_pred_cells": len(self.pred_ids), + "FP_list": fp_preds, + "FN_list": fn_gts, + "TTP_gt": list(matched_gt), + "TTP_preds": list(matched_pred), + "total_cells": int(len(self.gt_ids)), + "total_pred_cells": int(len(self.pred_ids)), + } + + return self.metrics + + def _empty_result(self): + """Return a full metric dictionary when there are no overlaps between GT and predictions.""" + n_gt = len(self.gt_ids) + n_pred = len(self.pred_ids) + + # True positives = 0 because no overlaps + tp = 0 + fp = n_pred + fn = n_gt + + precision = 0.0 + recall = 0.0 + f1 = 0.0 + + self.metrics = { + "iou_mean": 0.0, + "iou_list": [], + "dice_mean": 0.0, + "dice_list": [], + "precision": precision, + "recall": recall, + "f1_score": f1, + "splits": 0, + "merges": 0, + "catastrophes": 0, + "split_details": [], + "merge_details": [], + 
"catastrophe_details": [], + "true_positives": tp, + "false_positives": fp, + "false_negatives": fn, + "FP_list": list(self.pred_ids), + "FN_list": list(self.gt_ids), + "TTP_gt": [], + "TTP_preds": [], + "total_cells": n_gt, + "total_pred_cells": n_pred, } return self.metrics @@ -217,18 +390,16 @@ class SegmentationEvaluationBatch: def __init__( self, df, - plotting=False, iou_threshold=0.5, graph_iou_threshold=0.1, unmatched_cost=0.4, - cost_matrix_metric="iou", + # cost_matrix_metric="iou", ): self.df = df.copy() self.iou_threshold = iou_threshold self.graph_iou_threshold = graph_iou_threshold self.unmatched_cost = unmatched_cost - self.plotting = plotting - self.cost_matrix_metric = cost_matrix_metric + # self.cost_matrix_metric = cost_matrix_metric def run(self): results = [] @@ -242,7 +413,6 @@ def run(self): self.iou_threshold, self.graph_iou_threshold, self.unmatched_cost, - self.cost_matrix_metric, ) metrics = evaluator.evaluate() results.append({**row, **metrics}) diff --git a/segobe/plotter.py b/segobe/plotter.py index 7b65f9f..1a1c238 100644 --- a/segobe/plotter.py +++ b/segobe/plotter.py @@ -70,6 +70,7 @@ def plot_error_types( save_path=None, suptitle=False, legend=False, + target_size=600, ): # Extract categories @@ -84,6 +85,14 @@ def plot_error_types( cat_gt = [c["gts"] for c in metrics["catastrophe_details"]] cat_preds = [c["preds"] for c in metrics["catastrophe_details"]] + scale = max(gt_mask.shape) // target_size + if scale < 1: + scale = 1 + skip_boundaries = scale > 4 + + plot_gt = gt_mask[::scale, ::scale] + plot_pred = pred_mask[::scale, ::scale] + # Colors blue = "#1f77b4" yellow = "#ffdb58" @@ -91,29 +100,29 @@ def plot_error_types( # Categories: (Title, GT Mask, Pred Mask) categories = [ - ("Ground Truth", filter_mask_by_ids(gt_mask, np.unique(gt_mask)[1:]), None), - ("Prediction", None, filter_mask_by_ids(pred_mask, np.unique(pred_mask)[1:])), + ("Ground Truth", filter_mask_by_ids(plot_gt, np.unique(plot_gt)[1:]), None), + ("Prediction", 
None, filter_mask_by_ids(plot_pred, np.unique(plot_pred)[1:])), ( "True Positives", - filter_mask_by_ids(gt_mask, tp_gt), - filter_mask_by_ids(pred_mask, tp_preds), + filter_mask_by_ids(plot_gt, tp_gt), + filter_mask_by_ids(plot_pred, tp_preds), ), - ("False Negatives", filter_mask_by_ids(gt_mask, fn_list), None), - ("False Positives", None, filter_mask_by_ids(pred_mask, fp_list)), + ("False Negatives", filter_mask_by_ids(plot_gt, fn_list), None), + ("False Positives", None, filter_mask_by_ids(plot_pred, fp_list)), ( "Merges", - filter_mask_by_ids(gt_mask, merge_gt), - filter_mask_by_ids(pred_mask, merge_preds), + filter_mask_by_ids(plot_gt, merge_gt), + filter_mask_by_ids(plot_pred, merge_preds), ), ( "Splits", - filter_mask_by_ids(gt_mask, split_gt), - filter_mask_by_ids(pred_mask, split_preds), + filter_mask_by_ids(plot_gt, split_gt), + filter_mask_by_ids(plot_pred, split_preds), ), ( "Catastrophes", - filter_mask_by_ids(gt_mask, cat_gt), - filter_mask_by_ids(pred_mask, cat_preds), + filter_mask_by_ids(plot_gt, cat_gt), + filter_mask_by_ids(plot_pred, cat_preds), ), ] @@ -140,7 +149,7 @@ def plot_error_types( ax.set_title(title, fontsize=14) # Black background - ax.imshow(np.zeros_like(gt_mask, dtype=np.uint8), cmap="gray") + ax.imshow(np.zeros_like(plot_gt, dtype=np.uint8), cmap="gray") # Overlap: green where both gt and pred are present if gt is not None and pred is not None: @@ -163,27 +172,28 @@ def plot_error_types( cmap=ListedColormap([yellow]), alpha=0.8, ) - - ax.contour( - segmentation.find_boundaries(gt, mode="outer"), - colors=blue, - linewidths=0.5, - ) - ax.contour( - segmentation.find_boundaries(pred, mode="outer"), - colors=yellow, - linewidths=0.5, - ) + if not skip_boundaries: + ax.contour( + segmentation.find_boundaries(gt, mode="outer"), + colors=blue, + linewidths=0.5, + ) + ax.contour( + segmentation.find_boundaries(pred, mode="outer"), + colors=yellow, + linewidths=0.5, + ) elif gt is not None: ax.imshow( np.ma.masked_where(gt == 0, gt), 
cmap=ListedColormap([blue]), alpha=0.8 ) - ax.contour( - segmentation.find_boundaries(gt, mode="outer"), - colors=blue, - linewidths=0.5, - ) + if not skip_boundaries: + ax.contour( + segmentation.find_boundaries(gt, mode="outer"), + colors=blue, + linewidths=0.5, + ) elif pred is not None: ax.imshow( @@ -191,11 +201,12 @@ def plot_error_types( cmap=ListedColormap([yellow]), alpha=0.8, ) - ax.contour( - segmentation.find_boundaries(pred, mode="outer"), - colors=yellow, - linewidths=0.5, - ) + if not skip_boundaries: + ax.contour( + segmentation.find_boundaries(pred, mode="outer"), + colors=yellow, + linewidths=0.5, + ) fig.subplots_adjust(left=0.01, right=0.99, top=0.85, bottom=0.05, wspace=0.05) # Legend