diff --git a/examples/run_app_copick.py b/examples/run_app_copick.py
new file mode 100644
index 0000000..f1c5b2c
--- /dev/null
+++ b/examples/run_app_copick.py
@@ -0,0 +1,670 @@
+"""Example of using CellCanvas to pick particles on a surface.
+
+To use:
+1. Update the CopickRootFSSpec.from_file path at the bottom of this script
+   to point to your copick config file
+2. Run the script to launch CellCanvas
+3. Paint/predict until you're happy with the result. The seeded labels are:
+   - 1: background (including inside the capsules)
+   - 2: membrane
+   - 3: spike proteins
+3b. You might want to switch the image layer into the plane
+    depiction before doing the instance segmentation.
+    Sometimes I have trouble manipulating the plane after
+    the instance segmentation - need to look into this.
+4. Once you're happy with the prediction, click the "instance segmentation" tab
+5. Set the label value to 2. This will extract the membrane and
+   make instances via connected components.
+6. Remove the small objects. Suggested threshold: 100
+7. Alt + left mouse button to select an instance to modify.
+   Once selected, you can dilate, erode, etc. to smooth it.
+8. With the segment still selected, you can then mesh it
+   using the mesh widget. You can play with the smoothing parameters.
+9. If the mesh looks good, switch to the "geometry" tab.
+   Select the mesh and start surfing!
+"""
+from collections import defaultdict
+import os
+import numpy as np
+import napari
+import cellcanvas
+from cellcanvas._app.main_app import CellCanvasApp, QtCellCanvas
+from cellcanvas.data.data_manager import DataManager
+from cellcanvas.data.data_set import DataSet
+from napari.qt.threading import thread_worker
+
+import sys
+import logging
+import json
+import copick
+from copick.impl.filesystem import CopickRootFSSpec
+import zarr
+
+from qtpy.QtWidgets import QTreeWidget, QTreeWidgetItem, QVBoxLayout, QWidget, QComboBox, QPushButton, QLabel
+from qtpy.QtCore import Qt
+import glob  # For pattern matching of file names
+
+from sklearn.ensemble import RandomForestClassifier
+
+from cellcanvas.semantic.segmentation_manager import (
+    SemanticSegmentationManager,
+)
+from cellcanvas.utils import get_active_button_color
+
+import dask.array as da
+
+class NapariCopickExplorer(QWidget):
+    def __init__(self, viewer: napari.Viewer, root):
+        super().__init__()
+        self.viewer = viewer
+        self.root = root
+        self.selected_run = None
+        self.cell_canvas_app = None
+
+        layout = QVBoxLayout()
+        self.setLayout(layout)
+
+        self._init_logging()
+
+        # Adding new buttons for "Fit on all" and "Predict for all"
+        self.fit_all_button = QPushButton("Fit on all")
+        self.fit_all_button.clicked.connect(self.fit_on_all)
+        layout.addWidget(self.fit_all_button)
+
+        self.predict_all_button = QPushButton("Predict for all")
+        self.predict_all_button.clicked.connect(self.predict_for_all)
+        layout.addWidget(self.predict_all_button)
+
+        # Dropdowns for each data layer
+        self.dropdowns = {}
+        self.layer_buttons = {}
+        for layer in ["image", "features", "painting", "prediction"]:
+            # Make layer button
+            button = QPushButton(f"Select {layer.capitalize()} Layer")
+            button.clicked.connect(lambda checked, layer=layer: self.activate_layer(layer))
+            layout.addWidget(button)
+            self.layer_buttons[layer] = button
+            # Make layer selection dropdown
+            self.dropdowns[layer] = QComboBox()
+            layout.addWidget(self.dropdowns[layer])
+
+        # Button to update CellCanvas with the selected dataset
+        self.update_button = QPushButton("Initialize/Update CellCanvas")
+        self.update_button.clicked.connect(self.initialize_or_update_cell_canvas)
+        layout.addWidget(self.update_button)
+
+        self.tree = QTreeWidget()
+        self.tree.setHeaderLabel("Copick Runs")
+        self.tree.itemClicked.connect(self.on_run_clicked)
+        layout.addWidget(self.tree)
+
+        self.populate_tree()
+
+        # Monkeypatch
+        cellcanvas.utils.get_labels_colormap = self.get_copick_colormap
+
+    def get_copick_colormap(self):
+        """Return a colormap for distinct label colors based on the pickable objects."""
+        colormap = {obj.label: np.array(obj.color)/255.0 for obj in self.root.config.pickable_objects}
+        colormap[None] = np.array([1, 1, 1, 1])
+        colormap[9] = np.array([0, 1, 1, 1])
+        return colormap
+
+    def get_voxel_spacing(self):
+        return 10
+
+    def _init_logging(self):
+        self.logger = logging.getLogger("cellcanvas")
+        self.logger.setLevel(logging.DEBUG)
+        streamHandler = logging.StreamHandler(sys.stdout)
+        formatter = logging.Formatter(
+            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+        )
+        streamHandler.setFormatter(formatter)
+        self.logger.addHandler(streamHandler)
+
+    def populate_tree(self):
+        for run in self.root.runs:
+            run_item = QTreeWidgetItem(self.tree, [run.name])
+            run_item.setData(0, Qt.UserRole, run)
+
+            for category in ["segmentations", "meshes", "picks", "voxel_spacings"]:
+                category_item = QTreeWidgetItem(run_item, [category])
+                items = getattr(run, category)
+                for item in items:
+                    if category == "picks":
+                        item_name = item.pickable_object_name
+                    else:
+                        item_name = getattr(item, 'name', 'Unnamed')
+
+                    child_item = QTreeWidgetItem(category_item, [item_name])
+                    child_item.setData(0, Qt.UserRole, item)
+
+                    # list tomograms
+                    if category == "voxel_spacings":
+                        for tomogram in item.tomograms:
+                            tomo_item = QTreeWidgetItem(child_item, [f"Tomogram: {tomogram.tomo_type}"])
+                            tomo_item.setData(0, Qt.UserRole, tomogram)
+
+    def activate_layer(self, layer):
+        print(f"Activating layer {layer}")
+        if layer == "image":
+            layer = self.cell_canvas_app.semantic_segmentor.data_layer
+        elif layer == "painting":
+            layer = self.cell_canvas_app.semantic_segmentor.painting_layer
+        elif layer == "prediction":
+            layer = self.cell_canvas_app.semantic_segmentor.prediction_layer
+        else:
+            return
+        layer.visible = True
+        layer.editable = True
+        self.viewer.layers.selection.active = layer
+
+    def get_complete_data_manager(self, all_pairs=False):
+        datasets = []
+        for run in self.root.runs:
+            run_dir = run.static_path
+            voxel_spacing_dir = self.get_default_voxel_spacing_directory(run_dir)
+            segmentation_dir = self.get_segmentations_directory(run_dir)
+
+            if not voxel_spacing_dir:
+                print(f"No Voxel Spacing directory found for run {run.name}.")
+                continue
+
+            os.makedirs(segmentation_dir, exist_ok=True)
+
+            voxel_spacing = self.get_voxel_spacing()
+
+            # Reused paths for all datasets in a run
+            painting_path = os.path.join(segmentation_dir, f'{voxel_spacing:.3f}_cellcanvas-painting_0_all-multilabel.zarr')
+            prediction_path = os.path.join(segmentation_dir, f'{voxel_spacing:.3f}_cellcanvas-prediction_0_all-multilabel.zarr')
+
+            zarr_datasets = glob.glob(os.path.join(voxel_spacing_dir, "*.zarr"))
+            image_feature_pairs = {}
+
+            # Locate all images and corresponding features
+            for dataset_path in zarr_datasets:
+                dataset_name = os.path.basename(dataset_path)
+                if dataset_name.endswith(".zarr") and not dataset_name.endswith("_features.zarr"):
+                    base_image_name = dataset_name.replace(".zarr", "")
+                    # Find corresponding feature files
+                    feature_files = [path for path in zarr_datasets if base_image_name in path and "_features.zarr" in path]
+                    for feature_path in feature_files:
+                        features_base_name = os.path.basename(feature_path).replace("_features.zarr", "")
+                        # Check if the image base name matches the start of the feature base name
+                        if features_base_name.startswith(base_image_name):
+                            image_feature_pairs[features_base_name] = {
+                                'image': os.path.join(dataset_path, "0"),  # Assuming highest resolution
+                                'features': feature_path
+                            }
+
+            # Handle either all pairs or only those specified by the configuration
+            config_path = os.path.join(run_dir, "dataset_config.json")
+            if os.path.exists(config_path):
+                with open(config_path, 'r') as file:
+                    config = json.load(file)
+                if 'painting' in config:
+                    painting_path = os.path.join(segmentation_dir, config['painting'])
+                if 'prediction' in config:
+                    prediction_path = os.path.join(segmentation_dir, config['prediction'])
+
+            if os.path.exists(config_path) and not all_pairs:
+                with open(config_path, 'r') as file:
+                    config = json.load(file)
+                image_path = os.path.join(voxel_spacing_dir, config['image'])
+                features_path = os.path.join(voxel_spacing_dir, config['features'])
+                if 'painting' in config:
+                    painting_path = os.path.join(segmentation_dir, config['painting'])
+                if 'prediction' in config:
+                    prediction_path = os.path.join(segmentation_dir, config['prediction'])
+
+                # Load dataset with specific config paths
+                dataset = DataSet.from_paths(
+                    image_path=image_path,
+                    features_path=features_path,
+                    labels_path=painting_path,
+                    segmentation_path=prediction_path,
+                    make_missing_datasets=True
+                )
+                datasets.append(dataset)
+            else:
+                # Load all available pairs
+                for base_name, paths in image_feature_pairs.items():
+                    dataset = DataSet.from_paths(
+                        image_path=paths['image'],
+                        features_path=paths['features'],
+                        labels_path=painting_path,
+                        segmentation_path=prediction_path,
+                        make_missing_datasets=True
+                    )
+                    datasets.append(dataset)
+
+            print(f"Loaded datasets for run {run.name}")
+
+        return DataManager(datasets=datasets)
+
+    # Only train on config pairs
+    # def get_complete_data_manager(self, all_pairs=False):
+    #     datasets = []
+    #     for run in self.root.runs:
+    #         run_dir = run.static_path
+    #         config_path = os.path.join(run_dir, "dataset_config.json")
+
+    #         voxel_spacing_dir = self.get_default_voxel_spacing_directory(run_dir)
+    #         segmentation_dir = self.get_segmentations_directory(run_dir)
+
+    #         if not voxel_spacing_dir:
+    #             print(f"No Voxel Spacing directory found for run {run.name}.")
+    #             continue
+
+    #         os.makedirs(segmentation_dir, exist_ok=True)
+
+    #         if os.path.exists(config_path):
+    #             with open(config_path, 'r') as file:
+    #                 config = json.load(file)
+    #             image_path = os.path.join(voxel_spacing_dir, config['image'])
+    #             features_path = os.path.join(voxel_spacing_dir, config['features'])
+    #             painting_path = os.path.join(segmentation_dir, config['painting'])
+    #             prediction_path = os.path.join(segmentation_dir, config['prediction'])
+    #         else:
+    #             # Existing logic to find paths
+    #             voxel_spacing = self.get_voxel_spacing()
+
+    #             zarr_datasets = glob.glob(os.path.join(voxel_spacing_dir, "*.zarr"))
+    #             image_path = None
+    #             features_path = None
+    #             painting_path = os.path.join(segmentation_dir, f'{voxel_spacing:.3f}_cellcanvas-painting_0_all-multilabel.zarr')
+    #             prediction_path = os.path.join(segmentation_dir, f'{voxel_spacing:.3f}_cellcanvas-prediction_0_all-multilabel.zarr')
+
+    #             for dataset_path in zarr_datasets:
+    #                 dataset_name = os.path.basename(dataset_path).lower()
+    #                 if "_features.zarr" in dataset_name:
+    #                     features_path = dataset_path
+    #                 elif "painting" in dataset_name:
+    #                     painting_path = dataset_path
+    #                 elif "prediction" in dataset_name:
+    #                     prediction_path = dataset_path
+    #                 else:
+    #                     # TODO hard coded to use highest resolution
+    #                     image_path = os.path.join(dataset_path, "0")
+
+    #         # Save paths to JSON
+    #         config = {
+    #             'image': os.path.relpath(image_path, voxel_spacing_dir),
+    #             'features': os.path.relpath(features_path, voxel_spacing_dir),
+    #             'painting': os.path.relpath(painting_path, segmentation_dir),
+    #             'prediction': os.path.relpath(prediction_path, segmentation_dir)
+    #         }
+    #         with open(config_path, 'w') as file:
+    #             json.dump(config, file)
+
+    #         print(f"Fitting on paths:")
+    #         print(f"Image: {image_path}")
+    #         print(f"Features: {features_path}")
+    #         print(f"Painting: {painting_path}")
+    #         print(f"Prediction: {prediction_path}")
+
+    #         # Load dataset with paths
+    #         if image_path and features_path:
+    #             dataset = DataSet.from_paths(
+    #                 image_path=image_path,
+    #                 features_path=features_path,
+    #                 labels_path=painting_path,
+    #                 segmentation_path=prediction_path,
+    #                 make_missing_datasets=True
+    #             )
+    #             datasets.append(dataset)
+
+    #     return DataManager(datasets=datasets)
+
+    def get_default_voxel_spacing_directory(self, static_path):
+        # Find VoxelSpacing directories, assuming a hard coded match for now
+        voxel_spacing = self.get_voxel_spacing()
+        voxel_spacing_dirs = glob.glob(os.path.join(static_path, f'VoxelSpacing{voxel_spacing:.3f}'))
+        if voxel_spacing_dirs:
+            return voxel_spacing_dirs[0]
+        return None
+
+    def get_segmentations_directory(self, static_path):
+        segmentation_dir = os.path.join(static_path, "Segmentations")
+        return segmentation_dir
+
+    def change_button_color(self, button, color):
+        button.setStyleSheet(f"background-color: {color};")
+
+    def reset_button_color(self, button):
+        self.change_button_color(button, "")
+
+    def fit_on_all(self):
+        if not self.cell_canvas_app:
+            print("Initialize cell canvas first")
+            return
+
+        print("Fitting the model on all datasets.")
+
+        self.change_button_color(
+            self.fit_all_button, get_active_button_color()
+        )
+
+        self.model_fit_worker = self.threaded_fit_on_all()
+        self.model_fit_worker.returned.connect(self.on_model_fit_completed)
+        self.model_fit_worker.start()
+
+    @thread_worker
+    def threaded_fit_on_all(self):
+        # Fit model on all pairs
+        data_manager = self.get_complete_data_manager(all_pairs=True)
+
+        clf = RandomForestClassifier(
+            n_estimators=50,
+            n_jobs=-1,
+            max_depth=10,
+            max_samples=0.05,
+        )
+
+        segmentation_manager = SemanticSegmentationManager(
+            data=data_manager, model=clf
+        )
+        segmentation_manager.fit()
+
+        return segmentation_manager
+
+    def on_model_fit_completed(self, segmentation_manager):
+        self.logger.debug("on_model_fit_completed")
+
+        self.cell_canvas_app.semantic_segmentor.segmentation_manager = segmentation_manager
+
+        # Reset color
+        self.reset_button_color(self.fit_all_button)
+
+    def predict_for_all(self):
+        if not self.cell_canvas_app:
+            print("Initialize cell canvas first")
+            return
+
+        print("Running predictions for all datasets.")
+
+        self.change_button_color(
+            self.predict_all_button, get_active_button_color()
+        )
+
+        self.predict_worker = self.threaded_predict_for_all()
+        self.predict_worker.returned.connect(self.on_predict_completed)
+        self.predict_worker.start()
+
+    def on_predict_completed(self, result):
+        self.logger.debug("on_predict_completed")
+
+        # Reset color
+        self.reset_button_color(self.predict_all_button)
+
+    @thread_worker
+    def threaded_predict_for_all(self):
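+        # NOTE: the chunkwise loop below assumes the features are laid out as
+        # (feature, z, y, x). Each chunk is pulled into memory, classified,
+        # and the labels are written straight back into the segmentation zarr,
+        # so peak memory stays near one feature chunk plus the fitted model.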
print("Running predictions on all datasets.") + + # Check if segmentation manager is properly initialized + if not hasattr(self.cell_canvas_app.semantic_segmentor, 'segmentation_manager') or self.cell_canvas_app.semantic_segmentor.segmentation_manager is None: + print("Segmentation manager is not initialized.") + return + + # Retrieve the complete data manager that includes all runs + data_manager = self.get_complete_data_manager() + + # Iterate through each dataset within the data manager + for dataset in data_manager.datasets: + dataset_features = da.asarray(dataset.concatenated_features) + chunk_shape = dataset_features.chunksize + shape = dataset_features.shape + dtype = dataset_features.dtype + + # Iterate over chunks + for z in range(0, shape[1], chunk_shape[1]): + for y in range(0, shape[2], chunk_shape[2]): + for x in range(0, shape[3], chunk_shape[3]): + # Compute the slice for the current chunk + # in feature,z,y,x order + chunk_slice = ( + slice(None), + slice(z, min(z + chunk_shape[1], shape[1])), + slice(y, min(y + chunk_shape[2], shape[2])), + slice(x, min(x + chunk_shape[3], shape[3])), + ) + print(f"Predicting on chunk {chunk_slice}") + + # Extract the current chunk + chunk = dataset_features[chunk_slice].compute() + + # Predict on the chunk (adding 1 to each prediction) + predicted_chunk = self.cell_canvas_app.semantic_segmentor.segmentation_manager.predict(chunk) + 1 + + # Write the prediction to the corresponding region in the Zarr array + dataset.segmentation[chunk_slice[1:]] = predicted_chunk + + print(f"Predictions written") + + def on_run_clicked(self, item, column): + data = item.data(0, Qt.UserRole) + if not isinstance(data, copick.impl.filesystem.CopickRunFSSpec): + self.on_item_clicked(item, column) + return + + self.selected_run = data + static_path = self.selected_run.static_path + self.logger.info(f"Selected {static_path}") + + # Clear existing items + for dropdown in self.dropdowns.values(): + dropdown.clear() + + # Define directories + voxel_spacing_dirs = glob.glob(os.path.join(static_path, "VoxelSpacing10*")) + segmentation_dir = self.get_segmentations_directory(static_path) + os.makedirs(segmentation_dir, exist_ok=True) + + # Initialize dictionary to hold default selections from config + default_selections = {} + + # Check for config file and load selections if present + config_path = os.path.join(static_path, "dataset_config.json") + if os.path.exists(config_path): + with open(config_path, 'r') as file: + config = json.load(file) + default_selections = { + 'image': os.path.join(voxel_spacing_dirs[0], config.get('image')), + 'features': os.path.join(voxel_spacing_dirs[0], config.get('features')), + 'painting': os.path.join(segmentation_dir, config.get('painting')), + 'prediction': os.path.join(segmentation_dir, config.get('prediction')) + } + + # Helper function to add items if not already in dropdown + def add_item_if_not_exists(dropdown, item_name, item_data): + if dropdown.findData(item_data) == -1: + dropdown.addItem(item_name, item_data) + + # Load all zarr datasets from voxel spacing directories + if voxel_spacing_dirs: + for voxel_spacing_dir in voxel_spacing_dirs: + zarr_datasets = glob.glob(os.path.join(voxel_spacing_dir, "*.zarr")) + for dataset_path in zarr_datasets: + dataset_name = os.path.basename(dataset_path) + if "_features.zarr" in dataset_name.lower(): + add_item_if_not_exists(self.dropdowns["features"], dataset_name, dataset_path) + else: + add_item_if_not_exists(self.dropdowns["image"], dataset_name + "/0", dataset_path + "/0") + + # 
Load all zarr datasets from segmentation directory
+        zarr_datasets = glob.glob(os.path.join(segmentation_dir, "*.zarr"))
+        for dataset_path in zarr_datasets:
+            dataset_name = os.path.basename(dataset_path)
+            if "painting" not in dataset_name.lower():
+                add_item_if_not_exists(self.dropdowns["prediction"], dataset_name, dataset_path)
+            if "prediction" not in dataset_name.lower():
+                add_item_if_not_exists(self.dropdowns["painting"], dataset_name, dataset_path)
+
+        # Set default selections in dropdowns if specified in the config
+        for key, dropdown in self.dropdowns.items():
+            if default_selections.get(key):
+                index = dropdown.findData(default_selections[key])
+                if index != -1:
+                    dropdown.setCurrentIndex(index)
+
+    def on_item_clicked(self, item, column):
+        data = item.data(0, Qt.UserRole)
+        if data:
+            if isinstance(data, copick.impl.filesystem.CopickPicksFSSpec):
+                self.open_picks(data)
+            elif isinstance(data, copick.impl.filesystem.CopickTomogramFSSpec):
+                self.open_tomogram(data)
+            elif isinstance(data, copick.models.CopickSegmentation):
+                self.open_labels(data)
+
+    def open_picks(self, picks):
+        with open(picks.path, 'r') as f:
+            points_data = json.load(f)
+
+        # Extracting points locations
+        points_locations = [
+            [point['location']['z'], point['location']['y'], point['location']['x']]
+            for point in points_data['points']
+        ]
+
+        # TODO hard coded scaling
+        points_array = np.array(points_locations) / 10
+
+        # Adding the points layer to the viewer, using the pickable_object_name as the layer name
+        pickable_object = [obj for obj in self.root.config.pickable_objects if obj.name == picks.pickable_object_name][0]
+        self.viewer.add_points(points_array, name=picks.pickable_object_name, size=25, out_of_slice_display=True, face_color=np.array(pickable_object.color)/255.0)
+
+    def open_tomogram(self, tomogram):
+        zarr_store = zarr.open(tomogram.zarr(), mode='r')
+        print(f"open_tomogram {tomogram.zarr()}")
+        # TODO extract scale/transform info
+
+        # TODO scale is hard coded to 10 here
+        self.viewer.add_image(zarr_store[0], name=f"Tomogram: {tomogram.tomo_type}")
+
+    def open_labels(self, tomogram):
+        zarr_store = zarr.open(tomogram.zarr(), mode='r')
+        print(f"open_labels {tomogram.zarr()}")
+        # TODO extract scale/transform info
+
+        # TODO scale is hard coded to 10 here
+        self.viewer.add_image(zarr_store[0], name=f"Tomogram: {tomogram.name}")
+
+    def initialize_or_update_cell_canvas(self):
+        # Collect paths from dropdowns
+        paths = {layer: dropdown.currentText() for layer, dropdown in self.dropdowns.items()}
+
+        if not paths["image"] or not paths["features"]:
+            print("Please ensure image and feature paths are selected before initializing/updating CellCanvas.")
+            return
+
+        run_dir = self.selected_run.static_path
+        segmentation_dir = self.get_segmentations_directory(self.selected_run.static_path)
+        voxel_spacing_dir = self.get_default_voxel_spacing_directory(self.selected_run.static_path)
+
+        voxel_spacing = self.get_voxel_spacing()
+
+        # Ensure segmentations directory exists
+        os.makedirs(segmentation_dir, exist_ok=True)
+
+        default_painting_path = os.path.join(segmentation_dir, f'{voxel_spacing:.3f}_cellcanvas-painting_0_all-multilabel.zarr')
+        default_prediction_path = os.path.join(segmentation_dir, f'{voxel_spacing:.3f}_cellcanvas-prediction_0_all-multilabel.zarr')
+
+        painting_path = default_painting_path if not paths["painting"] else os.path.join(segmentation_dir, paths["painting"])
+        prediction_path = default_prediction_path if not paths["prediction"] else os.path.join(segmentation_dir, paths["prediction"])
+        image_path = os.path.join(voxel_spacing_dir, paths['image'])
+        features_path = os.path.join(voxel_spacing_dir, paths["features"])
+
+        # TODO note this is hard coded to use the highest resolution of a multiscale zarr
+        print("Opening paths:")
+        print(f"Image: {image_path}")
+        print(f"Features: {features_path}")
+        print(f"Painting: {painting_path}")
+        print(f"Prediction: {prediction_path}")
+        try:
+            dataset = DataSet.from_paths(
+                image_path=image_path,
+                features_path=features_path,
+                labels_path=painting_path,
+                segmentation_path=prediction_path,
+                make_missing_datasets=True,
+            )
+        except FileNotFoundError as e:
+            print(f"File not found: {e}", file=sys.stderr)
+            return
+
+        config_path = os.path.join(run_dir, "dataset_config.json")
+
+        config = {
+            'image': os.path.relpath(image_path, voxel_spacing_dir),
+            'features': os.path.relpath(features_path, voxel_spacing_dir),
+            'painting': os.path.relpath(painting_path, segmentation_dir),
+            'prediction': os.path.relpath(prediction_path, segmentation_dir)
+        }
+
+        with open(config_path, 'w') as file:
+            json.dump(config, file)
+
+        data_manager = DataManager(datasets=[dataset])
+
+        if not self.cell_canvas_app:
+            self.cell_canvas_app = CellCanvasApp(data=data_manager, viewer=self.viewer, verbose=True)
+            cell_canvas_widget = QtCellCanvas(app=self.cell_canvas_app)
+            self.viewer.window.add_dock_widget(cell_canvas_widget)
+        else:
+            # Update existing CellCanvasApp's data manager
+            self.cell_canvas_app.update_data_manager(data_manager)
+
+        # TODO this has multiple copick specific hardcoded hacks
+
+        # TODO hardcoded scale factor
+        # self.viewer.layers['Image'].scale = (10, 10, 10)
+
+        # Set colormap
+        # painting_layer.colormap.color_dict
+        # self.app.painting_labels
+        self.cell_canvas_app.semantic_segmentor.set_colormap(self.get_copick_colormap())
+        self.cell_canvas_app.semantic_segmentor.painting_labels = [obj.label for obj in self.root.config.pickable_objects] + [9]
+        self.cell_canvas_app.semantic_segmentor.widget.class_labels_mapping = {obj.label: obj.name for obj in self.root.config.pickable_objects}
+
+        self.cell_canvas_app.semantic_segmentor.widget.class_labels_mapping[9] = 'background'
+        self.cell_canvas_app.semantic_segmentor.widget.setupLegend()
+
+if __name__ == "__main__":
+    # Project root
+    root = CopickRootFSSpec.from_file("/Volumes/kish@CZI.T7/demo_project/copick_config_kyle.json")
+    # root = CopickRootFSSpec.from_file("/Volumes/kish@CZI.T7/chlamy_copick/copick_config_kyle.json")
+
+    ## Root API
+    root.config  # CopickConfig object
+    root.runs  # List of run objects (lazy loading from filesystem location(s))
+
+    viewer = napari.Viewer()
+
+    # Hide layer list and controls
+    # viewer.window.qt_viewer.dockLayerList.setVisible(False)
+    # viewer.window.qt_viewer.dockLayerControls.setVisible(False)
+
+    copick_explorer_widget = NapariCopickExplorer(viewer, root)
+    viewer.window.add_dock_widget(copick_explorer_widget, name="Copick Explorer", area="left")
+
+    # napari.run()
+
+# TODO finish making the prediction computation more lazy
+# the strategy should be to start computing labels chunkwise
+# on the zarr itself
+
+# TODO check scaling between picks and zarrs
+
+# TODO check why painting doesn't work when using proper scaling
+
+# TODO add proper colormap and legend support
+# - override exclusion of non-zero labels
+# - consistent colormap in the charts
+# - consistent colormap in the painted part of the labels image
diff --git a/src/cellcanvas/_app/main_app.py
b/src/cellcanvas/_app/main_app.py
index d814a87..b7aedc2 100644
--- a/src/cellcanvas/_app/main_app.py
+++ b/src/cellcanvas/_app/main_app.py
@@ -30,6 +30,11 @@ def __init__(
             extra_logging=self.verbose,
         )
 
+    def update_data_manager(self, data: DataManager):
+        self.data = data
+        self.semantic_segmentor.update_data_manager(data)
+
     @property
     def mode(self) -> AppMode:
         return self._mode
diff --git a/src/cellcanvas/_copick/__init__.py b/src/cellcanvas/_copick/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/cellcanvas/_copick/widget.py b/src/cellcanvas/_copick/widget.py
new file mode 100644
index 0000000..c531552
--- /dev/null
+++ b/src/cellcanvas/_copick/widget.py
@@ -0,0 +1,584 @@
+"""Example of using CellCanvas to pick particles on a surface.
+
+To use:
+1. Update the CopickRootFSSpec.from_file path at the bottom of this file
+   to point to your copick config file
+2. Run the script to launch CellCanvas
+3. Paint/predict until you're happy with the result. The seeded labels are:
+   - 1: background (including inside the capsules)
+   - 2: membrane
+   - 3: spike proteins
+3b. You might want to switch the image layer into the plane
+    depiction before doing the instance segmentation.
+    Sometimes I have trouble manipulating the plane after
+    the instance segmentation - need to look into this.
+4. Once you're happy with the prediction, click the "instance segmentation" tab
+5. Set the label value to 2. This will extract the membrane and
+   make instances via connected components.
+6. Remove the small objects. Suggested threshold: 100
+7. Alt + left mouse button to select an instance to modify.
+   Once selected, you can dilate, erode, etc. to smooth it.
+8. With the segment still selected, you can then mesh it
+   using the mesh widget. You can play with the smoothing parameters.
+9. If the mesh looks good, switch to the "geometry" tab.
+   Select the mesh and start surfing!
+""" +from collections import defaultdict +import os +import numpy as np +import napari +import cellcanvas +from cellcanvas._app.main_app import CellCanvasApp, QtCellCanvas +from cellcanvas.data.data_manager import DataManager +from cellcanvas.data.data_set import DataSet +from napari.qt.threading import thread_worker + +import sys +import logging +import json +import copick +from copick.impl.filesystem import CopickRootFSSpec +import zarr + +from qtpy.QtWidgets import QTreeWidget, QTreeWidgetItem, QVBoxLayout, QWidget, QComboBox, QPushButton, QLabel +from qtpy.QtCore import Qt +import glob # For pattern matching of file names + +from sklearn.ensemble import RandomForestClassifier + +from cellcanvas.semantic.segmentation_manager import ( + SemanticSegmentationManager, +) +from cellcanvas.utils import get_active_button_color + +import dask.array as da + +import napari +from qtpy.QtWidgets import QTreeWidget, QTreeWidgetItem, QVBoxLayout, QWidget +from qtpy.QtCore import Qt + +class NapariCopickExplorer(QWidget): + def __init__(self, viewer: napari.Viewer, root): + super().__init__() + self.viewer = viewer + self.root = root + self.selected_run = None + self.cell_canvas_app = None + + layout = QVBoxLayout() + self.setLayout(layout) + + self._init_logging() + + # Adding new buttons for "Fit on all" and "Predict for all" + self.fit_all_button = QPushButton("Fit on all") + self.fit_all_button.clicked.connect(self.fit_on_all) + layout.addWidget(self.fit_all_button) + + self.predict_all_button = QPushButton("Predict for all") + self.predict_all_button.clicked.connect(self.predict_for_all) + layout.addWidget(self.predict_all_button) + + # Dropdowns for each data layer + self.dropdowns = {} + self.layer_buttons = {} + for layer in ["image", "features", "painting", "prediction"]: + # Make layer button + button = QPushButton(f"Select {layer.capitalize()} Layer") + button.clicked.connect(lambda checked, layer=layer: self.activate_layer(layer)) + layout.addWidget(button) + self.layer_buttons[layer] = button + # Make layer selection dropdown + self.dropdowns[layer] = QComboBox() + layout.addWidget(self.dropdowns[layer]) + + # Button to update CellCanvas with the selected dataset + self.update_button = QPushButton("Initialize/Update CellCanvas") + self.update_button.clicked.connect(self.initialize_or_update_cell_canvas) + layout.addWidget(self.update_button) + + self.tree = QTreeWidget() + self.tree.setHeaderLabel("Copick Runs") + self.tree.itemClicked.connect(self.on_run_clicked) + layout.addWidget(self.tree) + + self.populate_tree() + + # Monkeypatch + cellcanvas.utils.get_labels_colormap = self.get_copick_colormap + + def get_copick_colormap(self): + """Return a colormap for distinct label colors based on the pickable objects.""" + colormap = {obj.label: np.array(obj.color)/255.0 for obj in self.root.config.pickable_objects} + colormap[None] = np.array([1, 1, 1, 1]) + return colormap + + def get_voxel_spacing(self): + return 10 + + def _init_logging(self): + self.logger = logging.getLogger("cellcanvas") + self.logger.setLevel(logging.DEBUG) + streamHandler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + streamHandler.setFormatter(formatter) + self.logger.addHandler(streamHandler) + + def populate_tree(self): + self.tree.clear() # Clear existing items if repopulating + for run in self.root.runs: + run_item = QTreeWidgetItem(self.tree, [run.name]) + run_item.setData(0, Qt.UserRole, run) + 
run_item.setChildIndicatorPolicy(QTreeWidgetItem.ShowIndicator) + + def setup_signals(self): + self.tree.itemExpanded.connect(self.on_item_expanded) + + def on_item_expanded(self, item): + # Check if the item has already been populated + if not hasattr(item, 'is_populated'): + run = item.data(0, Qt.UserRole) + if isinstance(run, copick.models.CopickRun): + self.populate_run(item, run) + item.is_populated = True # Mark as populated + + def populate_run(self, run_item, run): + for category in ["segmentations", "meshes", "picks", "voxel_spacings"]: + category_item = QTreeWidgetItem(run_item, [category]) + items = getattr(run, category, []) + for item in items: + if category == "picks": + item_name = item.pickable_object_name + else: + item_name = getattr(item, 'name', 'Unnamed') + child_item = QTreeWidgetItem(category_item, [item_name]) + child_item.setData(0, Qt.UserRole, item) + + if category == "voxel_spacings": + for tomogram in item.tomograms: + tomo_item = QTreeWidgetItem(child_item, [f"Tomogram: {tomogram.tomo_type}"]) + tomo_item.setData(0, Qt.UserRole, tomogram) + + + def activate_layer(self, layer): + print(f"Activating layer {layer}") + if layer == "image": + layer = self.cell_canvas_app.semantic_segmentor.data_layer + elif layer == "painting": + layer = self.cell_canvas_app.semantic_segmentor.painting_layer + elif layer == "prediction": + layer = self.cell_canvas_app.semantic_segmentor.prediction_layer + else: + return + layer.visible = True + layer.editable = True + self.viewer.layers.selection.active = layer + + def get_complete_data_manager(self, all_pairs=False): + datasets = [] + for run in self.root.runs: + run_dir = run.static_path + overlay_path = run.overlay_path + + voxel_spacing_dir = self.get_default_voxel_spacing_directory(run) + segmentation_dir = self.get_segmentations_directory(run) + + if not voxel_spacing_dir: + print(f"No Voxel Spacing directory found for run {run.name}.") + continue + + os.makedirs(segmentation_dir, exist_ok=True) + + voxel_spacing = self.get_voxel_spacing() + + # Reused paths for all datasets in a run + painting_path = self.get_default_painting_path(segmentation_dir, voxel_spacing) + prediction_path = self.get_default_prediction_path(segmentation_dir, voxel_spacing) + + zarr_datasets = glob.glob(os.path.join(voxel_spacing_dir, "*.zarr")) + image_feature_pairs = {} + + # Locate all images and corresponding features + for dataset_path in zarr_datasets: + dataset_name = os.path.basename(dataset_path) + if dataset_name.endswith(".zarr") and not dataset_name.endswith("_features.zarr"): + base_image_name = dataset_name.replace(".zarr", "") + # Find corresponding feature files + feature_files = [path for path in zarr_datasets if base_image_name in path and "_features.zarr" in path] + for feature_path in feature_files: + features_base_name = os.path.basename(feature_path).replace("_features.zarr", "") + # Check if the image base name matches the start of the feature base name + if features_base_name.startswith(base_image_name): + image_feature_pairs[features_base_name] = { + 'image': os.path.join(dataset_path, "0"), # Assuming highest resolution + 'features': feature_path + } + + # Handle either all pairs or only those specified by the configuration + config_path = self.get_config_path(run.static_path) + if os.path.exists(config_path): + with open(config_path, 'r') as file: + config = json.load(file) + if 'painting' in config: + painting_path = os.path.join(segmentation_dir, config['painting']) + if 'prediction' in config: + prediction_path = 
os.path.join(segmentation_dir, config['prediction'])
+
+            if os.path.exists(config_path) and not all_pairs:
+                with open(config_path, 'r') as file:
+                    config = json.load(file)
+                image_path = os.path.join(voxel_spacing_dir, config['image'])
+                features_path = os.path.join(voxel_spacing_dir, config['features'])
+                if 'painting' in config:
+                    painting_path = os.path.join(segmentation_dir, config['painting'])
+                if 'prediction' in config:
+                    prediction_path = os.path.join(segmentation_dir, config['prediction'])
+
+                # Load dataset with specific config paths
+                dataset = DataSet.from_paths(
+                    image_path=image_path,
+                    features_path=features_path,
+                    labels_path=painting_path,
+                    segmentation_path=prediction_path,
+                    make_missing_datasets=True
+                )
+                datasets.append(dataset)
+            else:
+                # Load all available pairs
+                for base_name, paths in image_feature_pairs.items():
+                    dataset = DataSet.from_paths(
+                        image_path=paths['image'],
+                        features_path=paths['features'],
+                        labels_path=painting_path,
+                        segmentation_path=prediction_path,
+                        make_missing_datasets=True
+                    )
+                    datasets.append(dataset)
+
+            print(f"Loaded datasets for run {run.name}")
+
+        return DataManager(datasets=datasets)
+
+    def get_default_voxel_spacing_directory(self, run):
+        # Find VoxelSpacing directories, assuming a hard coded match for now
+        voxel_spacing = self.get_voxel_spacing()
+        voxel_spacing_dirs = glob.glob(os.path.join(run.static_path, f'VoxelSpacing{voxel_spacing:.3f}'))
+        if voxel_spacing_dirs:
+            return voxel_spacing_dirs[0]
+        return None
+
+    def get_segmentations_directory(self, run):
+        segmentation_dir = os.path.join(run.overlay_path, "Segmentations")
+        return segmentation_dir
+
+    # Default multilabel zarr names used by get_complete_data_manager above;
+    # these mirror the defaults in examples/run_app_copick.py.
+    def get_default_painting_path(self, segmentation_dir, voxel_spacing):
+        return os.path.join(segmentation_dir, f'{voxel_spacing:.3f}_cellcanvas-painting_0_all-multilabel.zarr')
+
+    def get_default_prediction_path(self, segmentation_dir, voxel_spacing):
+        return os.path.join(segmentation_dir, f'{voxel_spacing:.3f}_cellcanvas-prediction_0_all-multilabel.zarr')
+
+    def change_button_color(self, button, color):
+        button.setStyleSheet(f"background-color: {color};")
+
+    def reset_button_color(self, button):
+        self.change_button_color(button, "")
+
+    def fit_on_all(self):
+        if not self.cell_canvas_app:
+            print("Initialize cell canvas first")
+            return
+
+        print("Fitting the model on all datasets.")
+
+        self.change_button_color(
+            self.fit_all_button, get_active_button_color()
+        )
+
+        self.model_fit_worker = self.threaded_fit_on_all()
+        self.model_fit_worker.returned.connect(self.on_model_fit_completed)
+        self.model_fit_worker.start()
+
+    @thread_worker
+    def threaded_fit_on_all(self):
+        # Fit model on all pairs
+        data_manager = self.get_complete_data_manager(all_pairs=True)
+
+        clf = RandomForestClassifier(
+            n_estimators=50,
+            n_jobs=-1,
+            max_depth=10,
+            max_samples=0.05,
+        )
+
+        segmentation_manager = SemanticSegmentationManager(
+            data=data_manager, model=clf
+        )
+        segmentation_manager.fit()
+
+        return segmentation_manager
+
+    def on_model_fit_completed(self, segmentation_manager):
+        self.logger.debug("on_model_fit_completed")
+
+        self.cell_canvas_app.semantic_segmentor.segmentation_manager = segmentation_manager
+
+        # Reset color
+        self.reset_button_color(self.fit_all_button)
+
+    def predict_for_all(self):
+        if not self.cell_canvas_app:
+            print("Initialize cell canvas first")
+            return
+
+        print("Running predictions for all datasets.")
+
+        self.change_button_color(
+            self.predict_all_button, get_active_button_color()
+        )
+
+        self.predict_worker = self.threaded_predict_for_all()
+        self.predict_worker.returned.connect(self.on_predict_completed)
+        self.predict_worker.start()
+
+    def on_predict_completed(self, result):
+        self.logger.debug("on_predict_completed")
+
+        # Reset color
+        self.reset_button_color(self.predict_all_button)
+
+    @thread_worker
+    def threaded_predict_for_all(self):
+        print("Running predictions 
on all datasets.") + + # Check if segmentation manager is properly initialized + if not hasattr(self.cell_canvas_app.semantic_segmentor, 'segmentation_manager') or self.cell_canvas_app.semantic_segmentor.segmentation_manager is None: + print("Segmentation manager is not initialized.") + return + + # Retrieve the complete data manager that includes all runs + data_manager = self.get_complete_data_manager() + + # Iterate through each dataset within the data manager + for dataset in data_manager.datasets: + dataset_features = da.asarray(dataset.concatenated_features) + chunk_shape = dataset_features.chunksize + shape = dataset_features.shape + dtype = dataset_features.dtype + + # Iterate over chunks + for z in range(0, shape[1], chunk_shape[1]): + for y in range(0, shape[2], chunk_shape[2]): + for x in range(0, shape[3], chunk_shape[3]): + # Compute the slice for the current chunk + # in feature,z,y,x order + chunk_slice = ( + slice(None), + slice(z, min(z + chunk_shape[1], shape[1])), + slice(y, min(y + chunk_shape[2], shape[2])), + slice(x, min(x + chunk_shape[3], shape[3])), + ) + print(f"Predicting on chunk {chunk_slice}") + + # Extract the current chunk + chunk = dataset_features[chunk_slice].compute() + + # Predict on the chunk (adding 1 to each prediction) + predicted_chunk = self.cell_canvas_app.semantic_segmentor.segmentation_manager.predict(chunk) + 1 + + # Write the prediction to the corresponding region in the Zarr array + dataset.segmentation[chunk_slice[1:]] = predicted_chunk + + print(f"Predictions written") + + def get_painting_segmentation_name(self): + return "cellcanvas-painting" + + def get_prediction_segmentation_name(self): + return "cellcanvas-prediction" + + def on_run_clicked(self, item, column): + data = item.data(0, Qt.UserRole) + if not isinstance(data, copick.impl.filesystem.CopickRunFSSpec): + self.on_item_clicked(item, column) + return + + self.selected_run = data + static_path = self.selected_run.static_path + overlay_path = self.selected_run.overlay_path + self.logger.info(f"Selected static path: {static_path} overlay path: {overlay_path}") + + # Clear existing items + for dropdown in self.dropdowns.values(): + dropdown.clear() + + voxel_spacing = self.selected_run.get_voxel_spacing(self.get_voxel_spacing()) + if not voxel_spacing: + print("Voxel spacing does not exist.") + return + + # features = self.selected_run.get_voxel_spacing(10).tomograms[0].get_features("cellcanvas01") + + # Define directories + voxel_spacing_dirs = voxel_spacing.static_path + + # Helper function to add items if not already in dropdown + def add_item_if_not_exists(dropdown, item_name, item_data): + if dropdown.findData(item_data) == -1: + dropdown.addItem(item_name, item_data) + + # Load image/tomograms + tomograms = voxel_spacing.tomograms + for tomogram in tomograms: + add_item_if_not_exists(self.dropdowns["image"], + tomogram.tomo_type, + tomogram) + + # Load features + for tomogram in tomograms: + features = tomogram.features + if features: + feature = features[0] + add_item_if_not_exists(self.dropdowns["features"], + tomogram.tomo_type, + feature) + + # Painting + painting_seg = self.selected_run.get_segmentations(user_id=self.root.user_id, is_multilabel=True, name=self.get_painting_segmentation_name(), voxel_size=10) + if not painting_seg: + # Create seg + painting_seg = self.selected_run.new_segmentation(10, self.get_painting_segmentation_name(), self.get_session_id(), True, user_id=self.root.user_id) + else: + painting_seg = painting_seg[0] + 
add_item_if_not_exists(self.dropdowns["painting"], painting_seg.name, painting_seg) + + # Prediction + prediction_seg = self.selected_run.get_segmentations(user_id=self.root.user_id, is_multilabel=True, name=self.get_prediction_segmentation_name(), voxel_size=10) + if not prediction_seg: + # Create seg + prediction_seg = self.selected_run.new_segmentation(10, self.get_prediction_segmentation_name(), self.get_session_id(), True, user_id=self.root.user_id) + else: + prediction_seg = prediction_seg[0] + add_item_if_not_exists(self.dropdowns["prediction"], prediction_seg.name, prediction_seg) + + def on_item_clicked(self, item, column): + data = item.data(0, Qt.UserRole) + if data: + if isinstance(data, copick.impl.filesystem.CopickPicksFSSpec): + self.open_picks(data) + elif isinstance(data, copick.impl.filesystem.CopickTomogramFSSpec): + self.open_tomogram(data) + elif isinstance(data, copick.models.CopickSegmentation): + self.open_labels(data) + + def open_picks(self, picks): + with open(picks.path, 'r') as f: + points_data = json.load(f) + + # Extracting points locations + points_locations = [ + [point['location']['z'], point['location']['y'], point['location']['x']] + for point in points_data['points'] + ] + + # TODO hard coded scaling + points_array = np.array(points_locations) / self.get_voxel_spacing() + + # Adding the points layer to the viewer, using the pickable_object_name as the layer name + pickable_object = [obj for obj in self.root.config.pickable_objects if obj.name == picks.pickable_object_name][0] + self.viewer.add_points(points_array, name=picks.pickable_object_name, size=25, out_of_slice_display=True, face_color=np.array(pickable_object.color)/255.0) + + def open_tomogram(self, tomogram): + zarr_store = zarr.open(tomogram.zarr(), mode='r') + print(f"open_tomogram {tomogram.zarr()}") + # TODO extract scale/transform info + + # TODO scale is hard coded to 10 here + self.viewer.add_image(zarr_store[0], name=f"Tomogram: {tomogram.tomo_type}") + + def open_labels(self, tomogram): + zarr_store = zarr.open(tomogram.zarr(), mode='r') + print(f"open_labels {tomogram.zarr()}") + # TODO extract scale/transform info + + # TODO scale is hard coded to 10 here + self.viewer.add_image(zarr_store[0], name=f"Tomogram: {tomogram.name}") + + def get_config_path(self, run_dir): + return os.path.join(run_dir, f"{self.get_user_id()}_config.json") + + def get_session_id(self): + return 17 + + def get_user_id(self): + return self.root.user_id + + def initialize_or_update_cell_canvas(self): + # Collect paths from dropdowns + paths = {layer: dropdown.currentData() for layer, dropdown in self.dropdowns.items()} + + if not paths["image"] or not paths["features"]: + print("Please ensure image and feature paths are selected before initializing/updating CellCanvas.") + return + + run_dir = self.selected_run.static_path + overlay_path = self.selected_run.overlay_path + + segmentation_dir = self.get_segmentations_directory(self.selected_run) + voxel_spacing_dir = self.get_default_voxel_spacing_directory(self.selected_run) + + voxel_spacing = self.get_voxel_spacing() + + # Ensure segmentations directory exists + # os.makedirs(segmentation_dir, exist_ok=True) + + # TODO note this is hard coded to use the highest resolution of a multiscale zarr + print(f"Opening paths:") + print(f"Image: {paths['image']}") + print(f"Features: {paths['features']}") + print(f"Painting: {paths['painting']}") + print(f"Prediction: {paths['prediction']}") + dataset = DataSet.from_stores( + image_store=paths['image'].zarr(), + 
features_store=paths['features'].zarr(), + labels_store=paths['painting'].zarr(), + segmentation_store=paths['prediction'].zarr(), + ) + + data_manager = DataManager(datasets=[dataset]) + + if not self.cell_canvas_app: + self.cell_canvas_app = CellCanvasApp(data=data_manager, viewer=self.viewer, verbose=True) + cell_canvas_widget = QtCellCanvas(app=self.cell_canvas_app) + self.viewer.window.add_dock_widget(cell_canvas_widget) + else: + # Update existing CellCanvasApp's data manager + self.cell_canvas_app.update_data_manager(data_manager) + + # TODO this has multiple copick specific hardcoded hacks + + # TODO hardcoded scale factor + # self.viewer.layers['Image'].scale = (10, 10, 10) + + # Set colormap + # painting_layer.colormap.color_dict + # self.app.painting_labels + self.cell_canvas_app.semantic_segmentor.set_colormap(self.get_copick_colormap()) + self.cell_canvas_app.semantic_segmentor.painting_labels = [obj.label for obj in self.root.config.pickable_objects] + self.cell_canvas_app.semantic_segmentor.widget.class_labels_mapping = {obj.label: obj.name for obj in self.root.config.pickable_objects} + +# self.cell_canvas_app.semantic_segmentor.widget.class_labels_mapping[9] = 'background' + self.cell_canvas_app.semantic_segmentor.widget.setupLegend() + +if __name__ == "__main__": + # Project root + + # root = CopickRootFSSpec.from_file("/Volumes/kish@CZI.T7/demo_project/copick_config_kyle.json") + # root = CopickRootFSSpec.from_file("/Volumes/kish@CZI.T7/chlamy_copick/copick_config_kyle.json") + root = CopickRootFSSpec.from_file("/Volumes/kish@CZI.T7/demo_project/copick_config_pickathon.json") + + viewer = napari.Viewer() + + # Hide layer list and controls + # viewer.window.qt_viewer.dockLayerList.setVisible(False) + # viewer.window.qt_viewer.dockLayerControls.setVisible(False) + + copick_explorer_widget = NapariCopickExplorer(viewer, root) + viewer.window.add_dock_widget(copick_explorer_widget, name="Copick Explorer", area="left") + + # napari.run() + diff --git a/src/cellcanvas/data/data_manager.py b/src/cellcanvas/data/data_manager.py index ddd0cfa..e114816 100644 --- a/src/cellcanvas/data/data_manager.py +++ b/src/cellcanvas/data/data_manager.py @@ -3,6 +3,7 @@ import numpy as np from napari.utils.events.containers import SelectableEventedList from zarr import Array +import dask.array as da from cellcanvas.data.data_set import DataSet @@ -15,6 +16,7 @@ def __init__(self, datasets: Optional[List[DataSet]] = None): datasets = [datasets] self.datasets = SelectableEventedList(datasets) + # Normal version def get_training_data(self) -> Tuple[Array, Array]: """Get the pixel-wise semantic segmentation training data for datasets. 
@@ -30,23 +32,26 @@ def get_training_data(self) -> Tuple[Array, Array]: features = [] labels = [] for dataset in self.datasets: - # get the features and labels - # todo make lazier - dataset_features = np.asarray(dataset.concatenated_features) - dataset_labels = np.asarray(dataset.labels) - - # reshape the data - dataset_labels = dataset_labels.flatten() - reshaped_features = dataset_features.reshape( - -1, dataset_features.shape[-1] - ) - - # Filter features where labels are greater than 0 - valid_labels = dataset_labels > 0 - filtered_features = reshaped_features[valid_labels, :] - filtered_labels = dataset_labels[valid_labels] - 1 # Adjust labels + dataset_features = da.asarray(dataset.concatenated_features) + dataset_labels = da.asarray(dataset.labels) + # Flatten labels for boolean indexing + flattened_labels = dataset_labels.flatten() + + # Compute valid_indices based on labels > 0 + valid_indices = da.nonzero(flattened_labels > 0)[0].compute() + + # Flatten only the spatial dimensions of the dataset_features while preserving the feature dimension + c, h, w, d = dataset_features.shape + reshaped_features = dataset_features.reshape(c, h * w * d) + + # We need to apply valid_indices for each feature dimension separately + filtered_features_list = [da.take(reshaped_features[i, :], valid_indices, axis=0) for i in range(c)] + filtered_features = da.stack(filtered_features_list, axis=1) + + # Adjust labels + filtered_labels = flattened_labels[valid_indices] - 1 features.append(filtered_features) labels.append(filtered_labels) - - return np.concatenate(features), np.concatenate(labels) + + return da.concatenate(features), da.concatenate(labels) diff --git a/src/cellcanvas/data/data_set.py b/src/cellcanvas/data/data_set.py index f61b423..7021215 100644 --- a/src/cellcanvas/data/data_set.py +++ b/src/cellcanvas/data/data_set.py @@ -6,6 +6,8 @@ import zarr from zarr import Array +from ome_zarr.io import ZarrLocation +from ome_zarr.reader import Multiscales @dataclass class DataSet: @@ -62,7 +64,11 @@ def from_paths( dimension_separator=".", ) else: - labels = zarr.open(labels_path, "a") + if Multiscales.matches(ZarrLocation(labels_path)): + labels = zarr.open(os.path.join(labels_path, "0"), + "a") + else: + labels = zarr.open(labels_path, "a") # get the segmentation if (not os.path.isdir(segmentation_path)) and make_missing_datasets: @@ -83,3 +89,57 @@ def from_paths( labels=labels, segmentation=segmentation, ) + + @classmethod + def from_stores( + cls, + image_store, + features_store, + labels_store, + segmentation_store, + ): + """Create a DataSet from a set of paths. 
+
+        todo: add ability to create missing labels/segmentations
+        """
+
+        # TODO rewrite this to copy everything to be local
+
+        # get the image
+        # TODO fix hardcoded scale for pickathon
+        image = zarr.open(zarr.storage.LRUStoreCache(image_store, None), "r")["0"]
+
+        # get the features
+        features = {"features": zarr.open(zarr.storage.LRUStoreCache(features_store, None), "r")}
+
+        group_name = "labels"
+
+        # get the labels
+        labels = zarr.open_group(zarr.storage.LRUStoreCache(labels_store, None),
+                                 mode="a")
+        if group_name in labels:
+            labels = labels[group_name]
+        else:
+            labels = labels.create_dataset(group_name,
+                                           shape=image.shape,
+                                           dtype="i4")
+
+        # get the segmentation
+        segmentation = zarr.open_group(zarr.storage.LRUStoreCache(segmentation_store, None),
+                                       mode="a")
+        if group_name in segmentation:
+            segmentation = segmentation[group_name]
+        else:
+            segmentation = segmentation.create_dataset(group_name,
+                                                       shape=image.shape,
+                                                       dtype="i4")
+
+        # TODO start a background thread that triggers downloads of the zarrs
+
+        return cls(
+            image=image,
+            features=features,
+            labels=labels,
+            segmentation=segmentation,
+        )
+
diff --git a/src/cellcanvas/semantic/_embedding_segmentor.py b/src/cellcanvas/semantic/_embedding_segmentor.py
index c00e37a..0015c5f 100644
--- a/src/cellcanvas/semantic/_embedding_segmentor.py
+++ b/src/cellcanvas/semantic/_embedding_segmentor.py
@@ -7,6 +7,7 @@
 import matplotlib.pyplot as plt
 import napari
 import numpy as np
+import dask.array as da
 import toolz as tz
 import zarr
 from matplotlib.backends.backend_qt5agg import (
@@ -18,7 +19,7 @@
 from napari.qt.threading import thread_worker
 from napari.utils import DirectLabelColormap
 from psygnal import debounced
-from qtpy.QtCore import Qt
+from qtpy.QtCore import Qt, Signal, Slot
 from qtpy.QtGui import QColor, QPainter, QPixmap
 from qtpy.QtWidgets import (
     QCheckBox,
@@ -34,6 +35,7 @@
     QVBoxLayout,
     QWidget,
 )
+from qtpy import QtCore, QtWidgets
 from sklearn.cross_decomposition import PLSRegression
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.utils.class_weight import compute_class_weight
@@ -46,6 +48,8 @@
 )
 from cellcanvas.utils import get_labels_colormap, paint_maker
 
+import xgboost as xgb
+
 ACTIVE_BUTTON_COLOR = "#AF8B38"
 
@@ -59,11 +63,23 @@ def __init__(
         self.extra_logging = extra_logging
         self.data = data_manager
 
-        clf = RandomForestClassifier(
-            n_estimators=50,
-            n_jobs=-1,
-            max_depth=10,
-            max_samples=0.05,
+        self.colormap = get_labels_colormap()
+        # clf = RandomForestClassifier(
+        #     n_estimators=25,
+        #     n_jobs=-1,
+        #     max_depth=10,
+        #     max_samples=0.05,
+        #     max_features='sqrt',
+        #     class_weight='balanced'
+        # )
+
+        clf = xgb.XGBClassifier(
+            objective='multi:softmax',
+            num_class=10,  # Specify number of classes if using softmax
+            n_estimators=200,
+            max_depth=20,
+            learning_rate=0.1,
+            # NOTE: scale_pos_weight='balanced' was removed here: XGBoost
+            # expects a numeric negative/positive ratio (binary tasks only),
+            # so the string raises at fit time. Handle class imbalance via
+            # sample weights passed to fit() instead.
         )
         self.segmentation_manager = SemanticSegmentationManager(
             data=self.data, model=clf
@@ -90,7 +106,7 @@ def __init__(
         # self.logger.info(f"zarr_path: {zarr_path}")
         self._add_threading_workers()
-        self._init_viewer_layers()
+        self.update_data_manager(self.data)
         self._add_widget()
 
         self.model = None
@@ -98,6 +114,27 @@
         self.start_computing_embedding_plot()
         self.update_class_distribution_charts()
 
+    def set_colormap(self, colormap):
+        self.colormap = colormap
+
+        self.prediction_layer.colormap = DirectLabelColormap(color_dict=colormap)
+        self.painting_layer.colormap = DirectLabelColormap(color_dict=colormap)
+        self.update_class_distribution_charts()
+
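+    # Usage sketch (hypothetical label ids): the colormap is a plain dict
+    # mapping label -> RGBA in [0, 1], with None as the fallback entry, e.g.
+    #     segmentor.set_colormap({2: np.array([1.0, 0.0, 0.0, 1.0]),
+    #                             None: np.array([1.0, 1.0, 1.0, 1.0])})
+    # get_copick_colormap() in the copick widget builds exactly this shape.
+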
+ def update_data_manager(self, data: DataManager): + self.data = data + self.segmentation_manager.update_data_manager(data) + + # get the image and features + # todo this is temporarily assuming a single dataset + # need to generalize + self.image_data = self.data.datasets[0].image + self.features = self.data.datasets[0].features + + # TODO remove old layers + self.viewer.layers.clear() + self._init_viewer_layers() + def reshape_features(self, arr): return arr.reshape(-1, arr.shape[-1]) @@ -112,6 +149,8 @@ def _init_viewer_layers(self): self.data_layer = self.viewer.add_image( self.image_data, name="Image", projection_mode="mean" ) + self.data_layer._keep_auto_contrast = True + self.data_layer.refresh() # self.prediction_data = zarr.open( # f"{self.zarr_path}/prediction", # mode="a", @@ -125,7 +164,7 @@ def _init_viewer_layers(self): name="Prediction", scale=self.data_layer.scale, opacity=0.1, - colormap=DirectLabelColormap(color_dict=get_labels_colormap()), + colormap=DirectLabelColormap(color_dict=self.colormap), ) # self.painting_data = zarr.open( @@ -136,11 +175,12 @@ def _init_viewer_layers(self): # dimension_separator=".", # ) self.painting_data = self.data.datasets[0].labels + # .data.astype("i4") self.painting_layer = self.viewer.add_labels( self.painting_data, name="Painting", scale=self.data_layer.scale, - colormap=DirectLabelColormap(color_dict=get_labels_colormap()), + colormap=DirectLabelColormap(color_dict=self.colormap), ) # Set up painting logging @@ -183,7 +223,7 @@ def _connect_events(self): listener.connect( debounced( ensure_main_thread(on_data_change_handler), - timeout=1000, + timeout=5000, ) ) @@ -209,7 +249,9 @@ def on_data_change(self, event, app): self.corner_pixels = self.viewer.layers["Image"].corner_pixels # TODO check if this is stalling things - self.painting_labels, self.painting_counts = np.unique( + # TODO recheck this after copick + # self.painting_labels, self.painting_counts = np.unique( + _, self.painting_counts = np.unique( self.painting_data[:], return_counts=True ) @@ -220,7 +262,7 @@ def on_data_change(self, event, app): self.update_class_distribution_charts() # Update projection - self.start_computing_embedding_plot() + # self.start_computing_embedding_plot() self.widget.setupLegend() @@ -238,7 +280,9 @@ def threaded_on_data_change( self.logger.info(f"Labels data has changed! {event}") # noqa: G004 # Update stats - self.painting_labels, self.painting_counts = np.unique( + # TODO check after copick + # self.painting_labels, self.painting_counts = np.unique( + _, self.painting_counts = np.unique( self.painting_data[:], return_counts=True ) @@ -250,9 +294,7 @@ def threaded_on_data_change( self.start_prediction() def get_model_type(self): - if not self.model_type: - self.model_type = self.widget.model_dropdown.currentText() - return self.model_type + return "Random Forest" def get_corner_pixels(self): if self.corner_pixels is None: @@ -320,13 +362,14 @@ def update_model(self, model_type): if filtered_labels.size == 0: self.logger.info("No labels present. 
Skipping model update.") return None - + # Calculate class weights unique_labels = np.unique(filtered_labels) class_weights = compute_class_weight( "balanced", classes=unique_labels, y=filtered_labels ) weight_dict = dict(zip(unique_labels, class_weights)) + self.logger.info(f"Class balance calculated {class_weights}") # Apply weights # sample_weights = np.vectorize(weight_dict.get)(filtered_labels) @@ -334,13 +377,15 @@ def update_model(self, model_type): # Model fitting if model_type == "Random Forest": clf = RandomForestClassifier( - n_estimators=50, + n_estimators=100, n_jobs=-1, - max_depth=10, + max_depth=15, max_samples=0.05, class_weight=weight_dict, ) self.segmentation_manager.model = clf + # self.segmentation_manager.fit() + self.logger.info(f"Starting model fitting") self.segmentation_manager.fit() return self.segmentation_manager.model elif model_type == "XGBoost": @@ -358,26 +403,49 @@ def update_model(self, model_type): raise ValueError(f"Unsupported model type: {model_type}") def predict(self): - # We shift labels + 1 because background is 0 and has special meaning - # prediction = ( - # future.predict_segmenter( - # features.reshape(-1, features.shape[-1]), model - # ).reshape(features.shape[:-1]) - # + 1 - # ) - prediction = ( - self.segmentation_manager.predict( - np.asarray(self.data.datasets[0].concatenated_features) - ) - + 1 - ) - - # Compute stats in thread too - prediction_labels, prediction_counts = np.unique( - prediction, return_counts=True - ) - - return (prediction, prediction_labels, prediction_counts) + dataset_features = da.asarray(self.data.datasets[0].concatenated_features) + chunk_shape = dataset_features.chunksize + shape = dataset_features.shape + dtype = dataset_features.dtype + + # Placeholder for aggregated labels and counts + all_labels = [] + all_counts = [] + + # Iterate over chunks + for z in range(0, shape[1], chunk_shape[1]): + for y in range(0, shape[2], chunk_shape[2]): + for x in range(0, shape[3], chunk_shape[3]): + # Compute the slice for the current chunk + # in feature,z,y,x order + chunk_slice = ( + slice(None), + slice(z, min(z + chunk_shape[1], shape[1])), + slice(y, min(y + chunk_shape[2], shape[2])), + slice(x, min(x + chunk_shape[3], shape[3])), + ) + print(f"Predicting on chunk {chunk_slice}") + + # Extract the current chunk + chunk = dataset_features[chunk_slice].compute() + + # Predict on the chunk (adding 1 to each prediction) + predicted_chunk = self.segmentation_manager.predict(chunk) + 1 + + # Write the prediction to the corresponding region in the Zarr array + self.prediction_data[chunk_slice[1:]] = predicted_chunk + + # Aggregate labels and counts + labels, counts = np.unique(predicted_chunk, return_counts=True) + all_labels.append(labels) + all_counts.append(counts) + + # Combine all_labels and all_counts + unique_labels, inverse = np.unique(np.concatenate(all_labels), return_inverse=True) + total_counts = np.bincount(inverse, weights=np.concatenate(all_counts)) + + # Now, self.prediction_data should contain the predicted labels + return self.prediction_data, unique_labels, total_counts @thread_worker def prediction_thread(self): @@ -402,6 +470,8 @@ def start_prediction(self): # features = self.get_features() + # TODO use a yielded connect worker + self.prediction_worker = self.prediction_thread() self.prediction_worker.returned.connect(self.on_prediction_completed) self.prediction_worker.start() @@ -414,9 +484,6 @@ def on_prediction_completed(self, result): self.prediction_labels = prediction_labels 
@@ -402,6 +470,8 @@ def start_prediction(self):
         # features = self.get_features()

+        # TODO use a yielded connect worker
+
         self.prediction_worker = self.prediction_thread()
         self.prediction_worker.returned.connect(self.on_prediction_completed)
         self.prediction_worker.start()
@@ -414,9 +484,6 @@ def on_prediction_completed(self, result):
         self.prediction_labels = prediction_labels
         self.prediction_counts = prediction_counts

-        self.get_prediction_layer().data = self.prediction_data.reshape(
-            self.get_prediction_layer().data.shape
-        )
         self.get_prediction_layer().refresh()

         self.update_class_distribution_charts()
@@ -442,7 +509,6 @@ def start_model_fit(self):
         self.model_fit_worker = self.model_fit_thread(self.get_model_type())
         self.model_fit_worker.returned.connect(self.on_model_fit_completed)

-        # TODO update UI to indicate that model training has started
         self.model_fit_worker.start()

     def on_model_fit_completed(self, model):
@@ -468,16 +534,23 @@ def update_class_distribution_charts(self):
             else 1
         )

-        painting_counts = (
-            self.painting_counts
-            if self.painting_counts is not None
-            else np.array([0])
-        )
-        painting_labels = (
-            self.painting_labels
-            if self.painting_labels is not None
-            else np.array([0])
-        )
+        # Initialize counts for all labels in painting_labels with zero
+        if self.painting_labels is not None:
+            unique_labels = np.unique(self.painting_labels)
+            painting_counts_dict = {label: 0 for label in unique_labels}
+        else:
+            unique_labels = np.array([0])
+            painting_counts_dict = {0: 0}
+
+        # Update counts from existing painting_counts if available
+        if self.painting_counts is not None and self.painting_labels is not None:
+            for label, count in zip(self.painting_labels, self.painting_counts):
+                painting_counts_dict[label] = count
+
+        # Create arrays from the dictionary
+        painting_labels = np.array(list(painting_counts_dict.keys()))
+        painting_counts = np.array(list(painting_counts_dict.values()))
+
         prediction_counts = (
             self.prediction_counts
             if self.prediction_counts is not None
@@ -502,9 +575,6 @@ def update_class_distribution_charts(self):
         self.logger.info(
             f"image layer: contrast_limits = {self.viewer.layers['Image'].contrast_limits}, opacity = {self.viewer.layers['Image'].opacity}, gamma = {self.viewer.layers['Image'].gamma}"  # noqa G004
         )
-        self.logger.info(
-            f"Current model type: {self.widget.model_dropdown.currentText()}"  # noqa G004
-        )

         # Calculate percentages instead of raw counts
         painting_percentages = (painting_counts / total_pixels) * 100
@@ -547,7 +617,7 @@ def update_class_distribution_charts(self):
         # Example class to color mapping
         class_color_mapping = {
             label: f"#{int(rgba[0] * 255):02x}{int(rgba[1] * 255):02x}{int(rgba[2] * 255):02x}"
-            for label, rgba in get_labels_colormap().items()
+            for label, rgba in self.colormap.items()
         }

         self.widget.figure.clear()
@@ -672,6 +742,8 @@ def update_class_distribution_charts(self):
     def compute_embedding_projection(self):
         # Filter out entries where the label is 0
         filtered_features, filtered_labels = self.data.get_training_data()
+        filtered_features = filtered_features.compute()
+        filtered_labels = filtered_labels.compute()

         # label values are offset by 1 for training,
         # undo the offset.
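# Editor's note: the zero-fill logic added to update_class_distribution_charts
# above can be condensed into one helper; a sketch of the same idea
# (`complete_counts` is an editor's helper, not CellCanvas API):
import numpy as np

def complete_counts(labels, counts):
    """Return aligned (labels, counts) arrays, counting missing labels as 0."""
    if labels is None:
        return np.array([0]), np.array([0])
    counts_dict = dict.fromkeys(np.unique(labels), 0)
    if counts is not None:
        counts_dict.update(zip(labels, counts))
    keys = sorted(counts_dict)
    return np.array(keys), np.array([counts_dict[k] for k in keys])

# usage sketch:
# painting_labels, painting_counts = complete_counts(
#     self.painting_labels, self.painting_counts
# )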
@@ -733,7 +805,7 @@ def create_embedding_plot(self, result):
             label: "#{:02x}{:02x}{:02x}".format(
                 int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)
             )
-            for label, rgba in get_labels_colormap().items()
+            for label, rgba in self.colormap.items()
         }

         # Convert filtered_labels to a list of colors for each point
@@ -840,14 +912,23 @@ def paint_thread(self, lasso_path, target_label):
         # Update the painting data
         self.painting_data[z, y, x] = target_label

-        if self.extra_logging:
-            self.logger.info(
-                f"lasso paint: label = {target_label}, indices = {paint_indices}"  # noqa G004
-            )
+        # if self.extra_logging:
+        #     self.logger.info(
+        #         f"lasso paint: label = {target_label}, indices = {paint_indices}"  # noqa G004
+        #     )

         # print(f"Painted {np.sum(contained)} pixels with label {target_label}")

+class ClickableLabel(QLabel):
+    clicked = Signal(int)  # Emits the label ID
+
+    def __init__(self, label_id, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.label_id = label_id
+
+    def mousePressEvent(self, event):
+        self.clicked.emit(self.label_id)
+
 class EmbeddingPaintingWidget(QWidget):
     def __init__(self, app, parent=None):
         super().__init__(parent=parent)
@@ -860,22 +941,6 @@ def initUI(self):
         self.legend_placeholder_index = 0

-        # Settings Group
-        settings_group = QGroupBox("Settings")
-        settings_layout = QVBoxLayout()
-
-        model_layout = QHBoxLayout()
-        model_label = QLabel("Select Model")
-        self.model_dropdown = QComboBox()
-        self.model_dropdown.addItems(["Random Forest", "XGBoost"])
-        model_layout.addWidget(model_label)
-        model_layout.addWidget(self.model_dropdown)
-        settings_layout.addLayout(model_layout)
-
-        self.add_features_button = QPushButton("Add Features")
-        self.add_features_button.clicked.connect(self.add_features)
-        settings_layout.addWidget(self.add_features_button)
-
         thickness_layout = QHBoxLayout()
         thickness_label = QLabel("Adjust Slice Thickness")
         self.thickness_slider = QSlider(Qt.Horizontal)
@@ -885,14 +950,11 @@ def initUI(self):
         self.thickness_slider.setValue(10)
         thickness_layout.addWidget(thickness_label)
         thickness_layout.addWidget(self.thickness_slider)
-        settings_layout.addLayout(thickness_layout)
-
+        main_layout.addLayout(thickness_layout)
+
         # Update layer contrast limits after thick slices has effect
         self.app.viewer.layers["Image"].reset_contrast_limits()

-        settings_group.setLayout(settings_layout)
-        main_layout.addWidget(settings_group)
-
         # Controls Group
         controls_group = QGroupBox("Controls")
         controls_layout = QVBoxLayout()
@@ -915,10 +977,24 @@ def initUI(self):
         live_pred_layout.addWidget(self.live_pred_button)
         controls_layout.addLayout(live_pred_layout)

+        # Connect checkbox signals to actions
+        self.live_fit_checkbox.stateChanged.connect(self.on_live_fit_changed)
+        self.live_pred_checkbox.stateChanged.connect(self.on_live_pred_changed)
+
+        # Connect button clicks to actions
+        self.live_fit_button.clicked.connect(self.app.start_model_fit)
+        self.live_pred_button.clicked.connect(self.app.start_prediction)
+
+        # Export model
         self.export_model_button = QPushButton("Export Model")
         controls_layout.addWidget(self.export_model_button)
         self.export_model_button.clicked.connect(self.export_model)

+        # Import model
+        self.import_model_button = QPushButton("Import Model")
+        controls_layout.addWidget(self.import_model_button)
+        self.import_model_button.clicked.connect(self.import_model)
+
         controls_group.setLayout(controls_layout)
         main_layout.addWidget(controls_group)
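# Editor's note: ClickableLabel above relies on qtpy's Signal, but this diff
# does not appear to add `Signal` to main_app.py's imports (the file only
# imports Qt from qtpy.QtCore), so the class as patched would likely raise a
# NameError at import time. The missing import, plus a minimal wiring sketch
# (the lambda target is illustrative only):
from qtpy.QtCore import Signal  # needed by ClickableLabel

# swatch = ClickableLabel(2)
# swatch.clicked.connect(lambda label_id: print(f"activate label {label_id}"))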
@@ -948,19 +1024,16 @@ def initUI(self):
         self.embedding_canvas = FigureCanvas(self.embedding_figure)
         self.stats_summary_layout.addWidget(self.embedding_canvas)

+        # Create a button for computing the embedding plot
+        self.compute_embedding_button = QPushButton("Compute Embedding Plot")
+        self.compute_embedding_button.clicked.connect(self.app.start_computing_embedding_plot)
+        self.stats_summary_layout.addWidget(self.compute_embedding_button)
+
         stats_summary_group.setLayout(self.stats_summary_layout)
         main_layout.addWidget(stats_summary_group)

         self.setLayout(main_layout)

-        # Connect checkbox signals to actions
-        self.live_fit_checkbox.stateChanged.connect(self.on_live_fit_changed)
-        self.live_pred_checkbox.stateChanged.connect(self.on_live_pred_changed)
-
-        # Connect button clicks to actions
-        self.live_fit_button.clicked.connect(self.app.start_model_fit)
-        self.live_pred_button.clicked.connect(self.app.start_prediction)
-
     def add_features(self):
         zarr_path = QFileDialog.getExistingDirectory(self, "Select Directory")
@@ -992,6 +1065,23 @@ def export_model(self):
                 self, "Model Export", "No model available to export."
             )

+    def import_model(self):
+        filePath, _ = QFileDialog.getOpenFileName(
+            self, "Open Model", "", "Joblib Files (*.joblib)"
+        )
+        if filePath:
+            try:
+                model = joblib.load(filePath)
+                self.app.model = model
+                QMessageBox.information(
+                    self, "Model Import", "Model imported successfully!"
+                )
+                print(f"Loaded model file from: {filePath}")
+            except Exception as e:
+                QMessageBox.warning(
+                    self, "Model Import", f"Failed to import model. Error: {str(e)}"
+                )
+
     def change_embedding_label_color(self, color):
         """Change the background color of the embedding label."""
         self.embedding_label.setStyleSheet(f"background-color: {color};")
@@ -1040,7 +1130,7 @@ def setupLegend(self):
             color = painting_layer.colormap.color_dict[label_id]

             # Create a QLabel for color swatch
-            color_swatch = QLabel()
+            color_swatch = ClickableLabel(label_id)
             pixmap = QPixmap(16, 16)
             if color is None:
                 pixmap.fill()
             else:
                 pixmap.fill(QColor(*[int(c * 255) for c in color]))
             color_swatch.setPixmap(pixmap)
+            color_swatch.clicked.connect(self.activateLabel)

             # Update the mapping with new classes or use the existing name
             if label_id not in self.class_labels_mapping:
@@ -1061,7 +1152,7 @@ def setupLegend(self):
             label_edit = QLineEdit(label_name)

             # Highlight the label if it is currently being used
-            if label_id == painting_layer._selected_label:
+            if label_id == painting_layer.selected_label:
                 self.highlightLabel(label_edit)

             # Save changes to class labels back to the mapping
@@ -1083,6 +1174,18 @@ def setupLegend(self):
             self.legend_placeholder_index, self.legend_group
         )

+    def activateLabel(self, current_label_id):
+        painting_layer = self.app.get_painting_layer()
+        painting_layer.selected_label = current_label_id
+
+        for label_id, label_edit in self.label_edits.items():
+            if label_id == current_label_id:
+                self.highlightLabel(label_edit)
+            else:
+                self.removeHighlightLabel(label_edit)
+
+        self.app.viewer.layers.selection.active = painting_layer
+
     def updateLegendHighlighting(self, selected_label_event):
         """Update highlighting of legend"""
         current_label_id = selected_label_event.source._selected_label
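# Editor's note: import_model above assumes `import joblib` somewhere in
# main_app.py (this diff does not add it), and it stores the loaded model on
# `self.app.model` while fitting uses `segmentation_manager.model`; the two
# attributes may need to be reconciled. A minimal save/load round trip using
# the same RandomForest settings as update_model (the file name is
# illustrative):
import joblib
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=100, n_jobs=-1, max_depth=15)
# ... fit clf on (features, labels) before exporting ...
joblib.dump(clf, "cellcanvas_model.joblib")
restored = joblib.load("cellcanvas_model.joblib")
assert restored.get_params()["n_estimators"] == 100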
diff --git a/src/cellcanvas/semantic/segmentation_manager.py b/src/cellcanvas/semantic/segmentation_manager.py
index 87b47ac..09e5d67 100644
--- a/src/cellcanvas/semantic/segmentation_manager.py
+++ b/src/cellcanvas/semantic/segmentation_manager.py
@@ -1,10 +1,14 @@
 from typing import Protocol
+import sys
+import logging

 import numpy as np
+import dask.array as da
+from dask import delayed
 from sklearn.exceptions import NotFittedError

 from cellcanvas.data.data_manager import DataManager
-
+from tqdm import tqdm

 class SegmentationModel(Protocol):
     """Protocol for semantic segmentation models that are
@@ -20,13 +24,34 @@ def __init__(self, data: DataManager, model: SegmentationModel):
         self.data = data
         self.model = model
+        self._init_logging()
+
+    def _init_logging(self):
+        self.logger = logging.getLogger("cellcanvas")
+        self.logger.setLevel(logging.DEBUG)
+        streamHandler = logging.StreamHandler(sys.stdout)
+        formatter = logging.Formatter(
+            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+        )
+        streamHandler.setFormatter(formatter)
+        self.logger.addHandler(streamHandler)
+
+    def update_data_manager(self, data: DataManager):
+        self.data = data
+
     def fit(self):
         """Fit the model using the data in the data manager."""
+        self.logger.info("Starting to fit")
+
         # Get training data from the data manager
         features, labels = self.data.get_training_data()
-        self.model.fit(features, labels)
+        features_computed, labels_computed = features.compute(), labels.compute()
+
+        self.logger.info("Starting the actual model fit")

-    def predict(self, feature_image: np.ndarray):
+        self.model.fit(features_computed, labels_computed)
+
+    def predict(self, feature_image):
         """Predict using the trained model.

         Parameters
@@ -39,7 +64,8 @@ def predict(self, feature_image):
         predicted_labels : Array
             The prediction of class.
         """
-        features = feature_image.reshape((-1, feature_image.shape[-1]))
+        c, z, y, x = feature_image.shape
+        features = feature_image.transpose(1, 2, 3, 0).reshape(-1, c)

         try:
             predicted_labels = self.model.predict(features)
@@ -49,4 +75,5 @@ def predict(self, feature_image):
                 "for example with the `fit_segmenter` function."
             ) from None

-        return predicted_labels.reshape(feature_image.shape[:-1])
+        return predicted_labels.reshape(feature_image.shape[1:])
+
diff --git a/src/cellcanvas/utils.py b/src/cellcanvas/utils.py
index fbc7211..0041ae2 100644
--- a/src/cellcanvas/utils.py
+++ b/src/cellcanvas/utils.py
@@ -3,6 +3,9 @@
     sphere_indices,
 )

+from qtpy.QtWidgets import (QApplication, QGroupBox, QVBoxLayout, QHBoxLayout,
+                            QLabel, QComboBox, QPushButton, QWidget, QCheckBox)
+from qtpy.QtCore import Slot, Qt

 def get_labels_colormap():
     """Return a colormap for distinct label colors based on:
@@ -71,10 +74,14 @@ def paint(self, coord, new_label, refresh=True):
             int
         )

-        logger.info("paint: label = %s, indices = %s", new_label, mask_indices)
+        # logger.info("paint: label = %s, indices = %s", new_label, mask_indices)

         self._paint_indices(
             mask_indices, new_label, shape, dims_to_paint, slice_coord, refresh
         )

         return paint
+
+
+def get_active_button_color():
+    return "#AF8B38"
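# Editor's note: why the transpose in SegmentationManager.predict() above --
# scikit-learn expects a (n_samples, n_features) matrix, while the feature
# image is stored channel-first as (feature, z, y, x). A self-contained numpy
# check of the round trip (shapes are arbitrary example values):
import numpy as np

c, z, y, x = 4, 3, 5, 6
feature_image = np.random.rand(c, z, y, x)

# (c, z, y, x) -> (n_voxels, c): voxels become rows, features become columns
features = feature_image.transpose(1, 2, 3, 0).reshape(-1, c)
assert features.shape == (z * y * x, c)

# per-voxel predictions fold back into the (z, y, x) volume
labels = np.zeros(features.shape[0])
assert labels.reshape(feature_image.shape[1:]).shape == (z, y, x)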