Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 141 additions & 31 deletions src/cedalion/vis/misc/plot_probe_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
from matplotlib.backends.backend_qtagg import FigureCanvas
from matplotlib.backends.backend_qtagg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.backends.qt_compat import QtWidgets
from matplotlib.backends.qt_compat import QtWidgets, QtGui
from matplotlib.figure import Figure

import cedalion
Expand All @@ -26,10 +26,11 @@


class _MAIN_GUI(QtWidgets.QMainWindow):
def __init__(self, snirfData=None, geo2d=None, geo3d=None):
def __init__(self, snirfData=None, stderr=None, geo2d=None, geo3d=None):
# Initialize
super().__init__()
self.snirfData = snirfData
self.stderr = stderr
self.geo2d = geo2d
self.geo3d = geo3d

Expand All @@ -43,7 +44,7 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None):
window_layout.setSpacing(10)

# Set Minimum Size
self.setMinimumSize(1000, 850)
self.setMinimumSize(800, 600)

# Set Window Title
self.setWindowTitle("Plot Probe")
Expand Down Expand Up @@ -139,6 +140,28 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None):
## Add Prune Channels
control_panel_layout.addWidget(prune_channels, stretch=1)

# Create t-stat thresh control - only if standard error is provided
if self.stderr is not None:
tstat_control = QtWidgets.QGroupBox("T-Stat Threshold")
tstat_control_layout = QtWidgets.QVBoxLayout()
tstat_control_layout.setSpacing(10)
tstat_control.setLayout(tstat_control_layout)

## Set up T-stat threshold controller
tstat_threshold_layout = QtWidgets.QHBoxLayout()
self.tstat_threshold = QtWidgets.QDoubleSpinBox()
self.tstat_threshold.setValue(0) # Default: no threshold
self.tstat_threshold.setRange(-100, 100)
self.tstat_threshold.setSingleStep(0.5)
self.tstat_threshold.setDecimals(2)
self.tstat_threshold.valueChanged.connect(self._tstat_threshold_changed)
tstat_threshold_layout.addWidget(QtWidgets.QLabel("Threshold"))
tstat_threshold_layout.addWidget(self.tstat_threshold)
tstat_control_layout.addLayout(tstat_threshold_layout)

## Add T-stat control
control_panel_layout.addWidget(tstat_control, stretch=1)

# Create Probe Control
probe_control = QtWidgets.QGroupBox("Probe")
probe_control_layout = QtWidgets.QVBoxLayout()
Expand Down Expand Up @@ -186,7 +209,7 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None):
# control_panel_layout.addWidget(ref_point,stretch=1)

# Create button action for opening file
open_btn = QtWidgets.QAction("Open...", self)
open_btn = QtGui.QAction("Open...", self)
open_btn.setStatusTip("Open SNIRF file")
open_btn.triggered.connect(self._open_dialog)

Expand All @@ -199,9 +222,10 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None):
file_menu.addAction(open_btn)

if self.snirfData is not None:
time_dim = 'reltime' if 'reltime' in self.snirfData.dims else 'time' # Detect time dimension
if np.shape(self.snirfData)[1] != len(self.snirfData.channel):
self.snirfData = self.snirfData.transpose(
"trial_type", "channel", "chromo", "reltime"
"trial_type", "channel", "chromo", time_dim
)

self.sPos = self.geo2d.sel(
Expand Down Expand Up @@ -272,6 +296,25 @@ def _init_calc(self):
self.fade_factor = 0.3 ##### Connect?
self.lineWidth = 0.7 ##### Connect?

if 'reltime' in self.snirfData.dims: # handle 'time' or 'reltime'
self.time_dim = 'reltime'
elif 'time' in self.snirfData.dims:
self.time_dim = 'time'
else:
raise ValueError("Data must have either 'time' or 'reltime' dimension")


# T-stat calculation
self.tstat_thresh = 0 # Initialize threshold
if self.stderr is not None:
# Calculate t-statistic: mean / stderr
self.tstat = self.snirfData / self.stderr
# Get max absolute t-stat per channel across time for thresholding
self.tstat_max = np.abs(self.tstat).max(dim=self.time_dim) # Max across time
else:
self.tstat = None
self.tstat_max = None

self.conditions.clear()
self.opt2circ.setChecked(False)
self.measline.setChecked(False)
Expand Down Expand Up @@ -354,14 +397,21 @@ def _init_calc(self):

# Extract time information
try:
self.t = self.snirfData.time.values
except Exception:
pass

try:
self.t = self.snirfData.reltime.values
self.t = self.snirfData[self.time_dim].values # CHANGE to use self.time_dim
except Exception:
pass
# Fallback to trying both
try:
self.t = self.snirfData.time.values
except Exception:
self.t = self.snirfData.reltime.values
# try:
# self.t = self.snirfData.time.values
# except Exception:
# pass
# try:
# self.t = self.snirfData.reltime.values
# except Exception:
# pass

self.minT = min(self.t)
self.maxT = max(self.t)
Expand Down Expand Up @@ -424,37 +474,86 @@ def _init_calc(self):
self._draw_hrf()
self.conditions.setCurrentRow(0)

# def _change_hrf_vis(self): # orig
# for i_con in range(self.trial_types):
# if i_con == self.conditions.currentRow():
# for i_ch in range(self.channels):
# if (
# self.chan_dist[i_ch] >= self.channel_min_dist
# and self.chan_dist[i_ch] <= self.ssFadeThres
# ):
# for i_col in range(self.chromophores):
# self.hrf[
# i_con * self.channels * self.chromophores
# + i_ch * self.chromophores
# + i_col
# ].set_color(self.chrom[i_col] + [self.fade_factor])
# elif (
# self.chan_dist[i_ch] >= self.ssFadeThres
# and self.chan_dist[i_ch] <= self.channel_max_dist
# ):
# for i_col in range(self.chromophores):
# self.hrf[
# i_con * self.channels * self.chromophores
# + i_ch * self.chromophores
# + i_col
# ].set_color(self.chrom[i_col] + [1])
# else:
# for i_col in range(self.chromophores):
# self.hrf[
# i_con * self.channels * self.chromophores
# + i_ch * self.chromophores
# + i_col
# ].set_color(self.chrom[i_col] + [0])
# else:
# for i_ch in range(self.channels):
# for i_col in range(self.chromophores):
# self.hrf[
# i_con * self.channels * self.chromophores
# + i_ch * self.chromophores
# + i_col
# ].set_color(self.chrom[i_col] + [0])

# self._ax.figure.canvas.draw()

def _change_hrf_vis(self):
for i_con in range(self.trial_types):
if i_con == self.conditions.currentRow():
for i_ch in range(self.channels):
# Check if channel meets t-stat threshold
meets_tstat = True
if self.tstat_max is not None:
# Check if ANY chromophore meets threshold for this channel
meets_tstat = any(
self.tstat_max.sel(trial_type=self.snirfData.trial_type[i_con]).values[i_ch, i_col] >= self.tstat_thresh
for i_col in range(self.chromophores)
)

# Determine alpha based on distance and t-stat
if (
self.chan_dist[i_ch] >= self.channel_min_dist
and self.chan_dist[i_ch] <= self.ssFadeThres
):
for i_col in range(self.chromophores):
self.hrf[
i_con * self.channels * self.chromophores
+ i_ch * self.chromophores
+ i_col
].set_color(self.chrom[i_col] + [self.fade_factor])
base_alpha = self.fade_factor
elif (
self.chan_dist[i_ch] >= self.ssFadeThres
and self.chan_dist[i_ch] <= self.channel_max_dist
):
for i_col in range(self.chromophores):
self.hrf[
i_con * self.channels * self.chromophores
+ i_ch * self.chromophores
+ i_col
].set_color(self.chrom[i_col] + [1])
base_alpha = 1
else:
for i_col in range(self.chromophores):
self.hrf[
i_con * self.channels * self.chromophores
+ i_ch * self.chromophores
+ i_col
].set_color(self.chrom[i_col] + [0])
base_alpha = 0

# Apply t-stat threshold: fade if doesn't meet threshold
if not meets_tstat and base_alpha > 0:
base_alpha = base_alpha * 0.5 # Further fade channels below threshold

# Set color for each chromophore
for i_col in range(self.chromophores):
self.hrf[
i_con * self.channels * self.chromophores
+ i_ch * self.chromophores
+ i_col
].set_color(self.chrom[i_col] + [base_alpha])
else:
for i_ch in range(self.channels):
for i_col in range(self.chromophores):
Expand Down Expand Up @@ -491,6 +590,8 @@ def _toggle_circles(self):
if self.opt2circ.isChecked():
self.src_optodes.set_color([1, 0, 0])
self.det_optodes.set_color([0, 0, 1])
self.src_optodes.set_markersize(3) # ADD THIS - Make circles smaller
self.det_optodes.set_markersize(3) # ADD THIS - Make circles smaller

for idx, source in enumerate(self.sPos.label):
self.src_label[idx].set_color([1, 0, 0, 0])
Expand All @@ -499,6 +600,8 @@ def _toggle_circles(self):
else:
self.src_optodes.set_color([1, 0, 0, 0])
self.det_optodes.set_color([0, 0, 1, 0])
self.src_optodes.set_markersize(5) # ADD THIS - Reset to original size
self.det_optodes.set_markersize(5) # ADD THIS - Reset to original size

for idx, source in enumerate(self.sPos.label):
self.src_label[idx].set_color([1, 0, 0, 1])
Expand Down Expand Up @@ -567,6 +670,11 @@ def _ssfade_changed(self, i):
self.ssFadeThres = i
self._change_hrf_vis()

def _tstat_threshold_changed(self, i):
# Pass the new t-stat threshold and update HRF visibility
self.tstat_thresh = i
self._change_hrf_vis()

def _draw_hrf(self):
print("Plotting Optodes!")
t0 = time.time()
Expand Down Expand Up @@ -656,6 +764,7 @@ def run_vis(
blockaverage: cdt.NDTimeSeries,
geo2d: cdt.LabeledPoints,
geo3d: cdt.LabeledPoints,
stderr: cdt.NDTimeSeries = None, # optional standerr input
):
"""Opens the visualization GUI.

Expand All @@ -666,6 +775,7 @@ def run_vis(
"""

app = QtWidgets.QApplication(sys.argv)
main_gui = _MAIN_GUI(snirfData=blockaverage, geo2d=geo2d, geo3d=geo3d)
#main_gui = _MAIN_GUI(snirfData=blockaverage, geo2d=geo2d, geo3d=geo3d)
main_gui = _MAIN_GUI(snirfData=blockaverage, stderr=stderr, geo2d=geo2d, geo3d=geo3d)
main_gui.show()
sys.exit(app.exec())
Loading