From 0f70d62cda731560844cdf4dbcb218f792ec219c Mon Sep 17 00:00:00 2001 From: "[shankell212]" <[shankell212@gmail.com]> Date: Fri, 9 Jan 2026 10:02:22 -0500 Subject: [PATCH] updating plotprobe GUI to include tstat thresholding --- src/cedalion/vis/misc/plot_probe_gui.py | 172 +++++++++++++++++++----- 1 file changed, 141 insertions(+), 31 deletions(-) diff --git a/src/cedalion/vis/misc/plot_probe_gui.py b/src/cedalion/vis/misc/plot_probe_gui.py index 30f5a363..7b0bf258 100644 --- a/src/cedalion/vis/misc/plot_probe_gui.py +++ b/src/cedalion/vis/misc/plot_probe_gui.py @@ -16,7 +16,7 @@ import numpy as np from matplotlib.backends.backend_qtagg import FigureCanvas from matplotlib.backends.backend_qtagg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.backends.qt_compat import QtWidgets +from matplotlib.backends.qt_compat import QtWidgets, QtGui from matplotlib.figure import Figure import cedalion @@ -26,10 +26,11 @@ class _MAIN_GUI(QtWidgets.QMainWindow): - def __init__(self, snirfData=None, geo2d=None, geo3d=None): + def __init__(self, snirfData=None, stderr=None, geo2d=None, geo3d=None): # Initialize super().__init__() self.snirfData = snirfData + self.stderr = stderr self.geo2d = geo2d self.geo3d = geo3d @@ -43,7 +44,7 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None): window_layout.setSpacing(10) # Set Minimum Size - self.setMinimumSize(1000, 850) + self.setMinimumSize(800, 600) # Set Window Title self.setWindowTitle("Plot Probe") @@ -139,6 +140,28 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None): ## Add Prune Channels control_panel_layout.addWidget(prune_channels, stretch=1) + # Create t-stat thresh control - only if standard error is provided + if self.stderr is not None: + tstat_control = QtWidgets.QGroupBox("T-Stat Threshold") + tstat_control_layout = QtWidgets.QVBoxLayout() + tstat_control_layout.setSpacing(10) + tstat_control.setLayout(tstat_control_layout) + + ## Set up T-stat threshold controller + tstat_threshold_layout = QtWidgets.QHBoxLayout() + self.tstat_threshold = QtWidgets.QDoubleSpinBox() + self.tstat_threshold.setValue(0) # Default: no threshold + self.tstat_threshold.setRange(-100, 100) + self.tstat_threshold.setSingleStep(0.5) + self.tstat_threshold.setDecimals(2) + self.tstat_threshold.valueChanged.connect(self._tstat_threshold_changed) + tstat_threshold_layout.addWidget(QtWidgets.QLabel("Threshold")) + tstat_threshold_layout.addWidget(self.tstat_threshold) + tstat_control_layout.addLayout(tstat_threshold_layout) + + ## Add T-stat control + control_panel_layout.addWidget(tstat_control, stretch=1) + # Create Probe Control probe_control = QtWidgets.QGroupBox("Probe") probe_control_layout = QtWidgets.QVBoxLayout() @@ -186,7 +209,7 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None): # control_panel_layout.addWidget(ref_point,stretch=1) # Create button action for opening file - open_btn = QtWidgets.QAction("Open...", self) + open_btn = QtGui.QAction("Open...", self) open_btn.setStatusTip("Open SNIRF file") open_btn.triggered.connect(self._open_dialog) @@ -199,9 +222,10 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None): file_menu.addAction(open_btn) if self.snirfData is not None: + time_dim = 'reltime' if 'reltime' in self.snirfData.dims else 'time' # Detect time dimension if np.shape(self.snirfData)[1] != len(self.snirfData.channel): self.snirfData = self.snirfData.transpose( - "trial_type", "channel", "chromo", "reltime" + "trial_type", "channel", "chromo", time_dim ) self.sPos = self.geo2d.sel( @@ -272,6 +296,25 @@ def _init_calc(self): self.fade_factor = 0.3 ##### Connect? self.lineWidth = 0.7 ##### Connect? + if 'reltime' in self.snirfData.dims: # handle 'time' or 'reltime' + self.time_dim = 'reltime' + elif 'time' in self.snirfData.dims: + self.time_dim = 'time' + else: + raise ValueError("Data must have either 'time' or 'reltime' dimension") + + + # T-stat calculation + self.tstat_thresh = 0 # Initialize threshold + if self.stderr is not None: + # Calculate t-statistic: mean / stderr + self.tstat = self.snirfData / self.stderr + # Get max absolute t-stat per channel across time for thresholding + self.tstat_max = np.abs(self.tstat).max(dim=self.time_dim) # Max across time + else: + self.tstat = None + self.tstat_max = None + self.conditions.clear() self.opt2circ.setChecked(False) self.measline.setChecked(False) @@ -354,14 +397,21 @@ def _init_calc(self): # Extract time information try: - self.t = self.snirfData.time.values - except Exception: - pass - - try: - self.t = self.snirfData.reltime.values + self.t = self.snirfData[self.time_dim].values # CHANGE to use self.time_dim except Exception: - pass + # Fallback to trying both + try: + self.t = self.snirfData.time.values + except Exception: + self.t = self.snirfData.reltime.values + # try: + # self.t = self.snirfData.time.values + # except Exception: + # pass + # try: + # self.t = self.snirfData.reltime.values + # except Exception: + # pass self.minT = min(self.t) self.maxT = max(self.t) @@ -424,37 +474,86 @@ def _init_calc(self): self._draw_hrf() self.conditions.setCurrentRow(0) + # def _change_hrf_vis(self): # orig + # for i_con in range(self.trial_types): + # if i_con == self.conditions.currentRow(): + # for i_ch in range(self.channels): + # if ( + # self.chan_dist[i_ch] >= self.channel_min_dist + # and self.chan_dist[i_ch] <= self.ssFadeThres + # ): + # for i_col in range(self.chromophores): + # self.hrf[ + # i_con * self.channels * self.chromophores + # + i_ch * self.chromophores + # + i_col + # ].set_color(self.chrom[i_col] + [self.fade_factor]) + # elif ( + # self.chan_dist[i_ch] >= self.ssFadeThres + # and self.chan_dist[i_ch] <= self.channel_max_dist + # ): + # for i_col in range(self.chromophores): + # self.hrf[ + # i_con * self.channels * self.chromophores + # + i_ch * self.chromophores + # + i_col + # ].set_color(self.chrom[i_col] + [1]) + # else: + # for i_col in range(self.chromophores): + # self.hrf[ + # i_con * self.channels * self.chromophores + # + i_ch * self.chromophores + # + i_col + # ].set_color(self.chrom[i_col] + [0]) + # else: + # for i_ch in range(self.channels): + # for i_col in range(self.chromophores): + # self.hrf[ + # i_con * self.channels * self.chromophores + # + i_ch * self.chromophores + # + i_col + # ].set_color(self.chrom[i_col] + [0]) + + # self._ax.figure.canvas.draw() + def _change_hrf_vis(self): for i_con in range(self.trial_types): if i_con == self.conditions.currentRow(): for i_ch in range(self.channels): + # Check if channel meets t-stat threshold + meets_tstat = True + if self.tstat_max is not None: + # Check if ANY chromophore meets threshold for this channel + meets_tstat = any( + self.tstat_max.sel(trial_type=self.snirfData.trial_type[i_con]).values[i_ch, i_col] >= self.tstat_thresh + for i_col in range(self.chromophores) + ) + + # Determine alpha based on distance and t-stat if ( self.chan_dist[i_ch] >= self.channel_min_dist and self.chan_dist[i_ch] <= self.ssFadeThres ): - for i_col in range(self.chromophores): - self.hrf[ - i_con * self.channels * self.chromophores - + i_ch * self.chromophores - + i_col - ].set_color(self.chrom[i_col] + [self.fade_factor]) + base_alpha = self.fade_factor elif ( self.chan_dist[i_ch] >= self.ssFadeThres and self.chan_dist[i_ch] <= self.channel_max_dist ): - for i_col in range(self.chromophores): - self.hrf[ - i_con * self.channels * self.chromophores - + i_ch * self.chromophores - + i_col - ].set_color(self.chrom[i_col] + [1]) + base_alpha = 1 else: - for i_col in range(self.chromophores): - self.hrf[ - i_con * self.channels * self.chromophores - + i_ch * self.chromophores - + i_col - ].set_color(self.chrom[i_col] + [0]) + base_alpha = 0 + + # Apply t-stat threshold: fade if doesn't meet threshold + if not meets_tstat and base_alpha > 0: + base_alpha = base_alpha * 0.5 # Further fade channels below threshold + + # Set color for each chromophore + for i_col in range(self.chromophores): + self.hrf[ + i_con * self.channels * self.chromophores + + i_ch * self.chromophores + + i_col + ].set_color(self.chrom[i_col] + [base_alpha]) else: for i_ch in range(self.channels): for i_col in range(self.chromophores): @@ -491,6 +590,8 @@ def _toggle_circles(self): if self.opt2circ.isChecked(): self.src_optodes.set_color([1, 0, 0]) self.det_optodes.set_color([0, 0, 1]) + self.src_optodes.set_markersize(3) # ADD THIS - Make circles smaller + self.det_optodes.set_markersize(3) # ADD THIS - Make circles smaller for idx, source in enumerate(self.sPos.label): self.src_label[idx].set_color([1, 0, 0, 0]) @@ -499,6 +600,8 @@ def _toggle_circles(self): else: self.src_optodes.set_color([1, 0, 0, 0]) self.det_optodes.set_color([0, 0, 1, 0]) + self.src_optodes.set_markersize(5) # ADD THIS - Reset to original size + self.det_optodes.set_markersize(5) # ADD THIS - Reset to original size for idx, source in enumerate(self.sPos.label): self.src_label[idx].set_color([1, 0, 0, 1]) @@ -567,6 +670,11 @@ def _ssfade_changed(self, i): self.ssFadeThres = i self._change_hrf_vis() + def _tstat_threshold_changed(self, i): + # Pass the new t-stat threshold and update HRF visibility + self.tstat_thresh = i + self._change_hrf_vis() + def _draw_hrf(self): print("Plotting Optodes!") t0 = time.time() @@ -656,6 +764,7 @@ def run_vis( blockaverage: cdt.NDTimeSeries, geo2d: cdt.LabeledPoints, geo3d: cdt.LabeledPoints, + stderr: cdt.NDTimeSeries = None, # optional standerr input ): """Opens the visualization GUI. @@ -666,6 +775,7 @@ def run_vis( """ app = QtWidgets.QApplication(sys.argv) - main_gui = _MAIN_GUI(snirfData=blockaverage, geo2d=geo2d, geo3d=geo3d) + #main_gui = _MAIN_GUI(snirfData=blockaverage, geo2d=geo2d, geo3d=geo3d) + main_gui = _MAIN_GUI(snirfData=blockaverage, stderr=stderr, geo2d=geo2d, geo3d=geo3d) main_gui.show() sys.exit(app.exec())