Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 157 additions & 114 deletions om_py_plot_surfs.py → plot_brain.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
# xvfb-run --server-args="-screen 0 1024x768x24" python om_py_plot_surfs.py
import sip
sip.setapi('QDate', 2)
sip.setapi('QString', 2)
sip.setapi('QTextStream', 2)
sip.setapi('QTime', 2)
sip.setapi('QUrl', 2)
sip.setapi('QVariant', 2)
sip.setapi('QDateTime', 2)
import matplotlib.pyplot as plt
import os
import numpy as np
import nibabel as nb
import nibabel.gifti as gifti
from mayavi import mlab
from tvtk.api import tvtk
import math
import argparse
# To run, execute these steps:
# srun -p om_interactive -N1 -c2 --mem=8G --pty bash
# source activate mathiasg_vd_env
# module add openmind/xvfb-fix/0.1
# export QT_API=pyqt
# python plot_brain.py -i <input> -o <output> -c <path to conte atlas> -r <path to resting atlas>

# https://github.com/cgoldberg/xvfbwrapper

def rotation_matrix(axis, theta):
"""
Return the rotation matrix associated with counterclockwise rotation about
the given axis by theta radians.
"""
"""

import numpy as np
import math

axis = np.asarray(axis)
theta = np.asarray(theta)
axis = axis/math.sqrt(np.dot(axis, axis))
Expand All @@ -33,67 +27,142 @@ def rotation_matrix(axis, theta):
[2*(bc-ad), aa+cc-bb-dd, 2*(cd+ab)],
[2*(bd+ac), 2*(cd-ab), aa+dd-bb-cc]])

def useZstat(zstat, file_path_name_save, file_path_conte, file_path_name_resting_atlas):
"""Plot and save the image.

Arguments
---------
zstat : string
Full file path and name to nii to plot.

file_path_name_save : string
Full file path and name to png output. Output dir will be created if it doesn't exist.

file_path_conte : string
Full file path to Conte atlas

file_path_name_resting_atlas : string

Returns
-------
None. Normal error message:
pixdim[1,2,3] should be non-zero; setting 0 dims to 1
plot_brain.py: Fatal IO error: client killed

Example
-------
python plot_brain.py -i /groupAnalysis/l2/zstat1_threshold.nii.gz -o /plots/l2test.png -c /git/bp2/32k_ConteAtlas_v2 -r rfMRI_REST1_LR_Atlas.dtseries.nii

MIT OM Specific Tip
-------------------
Call this function from a shell script to run headerless BUT requires:
source activate mathiasg_vd_env
export QT_API=pyqt
module add openmind/xvfb-fix/0.1

#file_path_name=$1
#file_path_name_save=$2
#file_path_conte=$3
#file_path_name_resting_atlas=$4
python plot_brain.py \
-i $1 \
-o $2 \
-c $3 \
-r $4

"""

import matplotlib.pyplot as plt
import os
from glob import glob
import numpy as np
import nibabel as nb
import nibabel.gifti as gifti

# Crucial: xvfb must be imported and started before importing mayavi
from xvfbwrapper import Xvfb
print('XVb pre')
vdisplay = Xvfb()
vdisplay.start()

print('pre maya')
# Crashes on this line if run with plain python (not xvfb-run ... python) and if xvfbwrapper is after it.
from mayavi import mlab
print('post maya')
from tvtk.api import tvtk
print('post tvtk')
import math

def make_plot(stat, task, contrast, num, outdir, inflated,
split_brain, dual_split, threshold,
display_threshold, atlas_dir):
# load from here until can include
try:
img = nb.load('/om/user/mathiasg/rfMRI_REST1_LR_Atlas.dtseries.nii')
except:
print('File missing - message mathiasg@mit.edu')
raise FileNotFoundError
mim=img.header.matrix.mims[1]
try:
bm1 = mim.brain_models[0]
lidx = bm1.vertex_indices.indices
bm2 = mim.brain_models[1]
ridx = bm1.surface_number_of_vertices + bm2.vertex_indices.indices
except AttributeError: #older citfi version
bm1 = mim.brainModels[0]
lidx = bm1.vertexIndices.indices
bm2 = mim.brainModels[1]
ridx = bm1.surfaceNumberOfVertices + bm2.vertexIndices.indices
print('display')
mlab.options.offscreen = True #offscreen window for rendering

img = nb.load(file_path_name_resting_atlas)
#img = nb.load('/Users/MathiasMacbook/Desktop/rfMRI_REST1_LR_Atlas.dtseries.nii')
mim = img.header.matrix.mims[1]
#for idx, bm in enumerate(mim.brainModels):
# print((idx, bm.indexOffset, bm.brainStructure))
bm1 = mim.brainModels[0]
lidx = bm1.vertexIndices.indices
bm2 = mim.brainModels[1]
ridx = bm1.surfaceNumberOfVertices + bm2.vertexIndices.indices
bidx = np.concatenate((lidx, ridx))

axis = [0, 0, 1]
theta = np.pi
try:
surf = gifti.read(os.path.join(atlas_dir,'Conte69.L.midthickness.32k_fs_LR.surf.gii'))
except:
print('Atlas not found - pass in path with flag -a')
raise FileNotFoundError

inflated = True
split_brain = True

surf = gifti.read(file_path_conte + '/Conte69.L.midthickness.32k_fs_LR.surf.gii')
verts_L_data = surf.darrays[0].data
faces_L_data = surf.darrays[1].data
surf = gifti.read(os.path.join(atlas_dir,'Conte69.R.midthickness.32k_fs_LR.surf.gii'))

surf = gifti.read(file_path_conte + '/Conte69.R.midthickness.32k_fs_LR.surf.gii')
verts_R_data = surf.darrays[0].data
faces_R_data = surf.darrays[1].data

if inflated:
surf = gifti.read(os.path.join(atlas_dir,'Conte69.L.inflated.32k_fs_LR.surf.gii'))
surf = gifti.read(file_path_conte + '/Conte69.L.inflated.32k_fs_LR.surf.gii')
verts_L_display = surf.darrays[0].data
faces_L_display = surf.darrays[1].data
surf = gifti.read(os.path.join(atlas_dir,'Conte69.R.inflated.32k_fs_LR.surf.gii'))
surf = gifti.read(file_path_conte + '/Conte69.R.inflated.32k_fs_LR.surf.gii')
verts_R_display = surf.darrays[0].data
faces_R_display = surf.darrays[1].data
else:
verts_L_display = verts_L_data.copy()
verts_R_display = verts_R_data.copy()
faces_L_display = faces_L_data.copy()
faces_R_display = faces_R_data.copy()

verts_L_display[:, 0] -= max(verts_L_display[:, 0])
verts_R_display[:, 0] -= min(verts_R_display[:, 0])
verts_L_display[:, 1] -= (max(verts_L_display[:, 1]) + 1)
verts_R_display[:, 1] -= (max(verts_R_display[:, 1]) + 1)

faces = np.vstack((faces_L_display, verts_L_display.shape[0] + faces_R_display))

if split_brain:
verts2 = rotation_matrix(axis, theta).dot(verts_R_display.T).T
else:
verts_L_display[:, 1] -= np.mean(verts_L_display[:, 1])
verts_R_display[:, 1] -= np.mean(verts_R_display[:, 1])
verts2 = verts_R_display

verts_rot = np.vstack((verts_L_display, verts2))
verts = np.vstack((verts_L_data, verts_R_data))
#load stat
img = nb.load(stat)
#print verts.shape
#print faces.shape

if not os.path.exists(os.path.split(file_path_name_save)[0]):
os.makedirs(os.path.split(file_path_name_save)[0])

print('use zstat')
img = nb.load(zstat)
print('loaded img')

threshold = 2.3 # 1000, lower limit
display_threshold = 6 #8000, upper limit

data = img.get_data()
aff = img.affine
indices = np.round((np.linalg.pinv(aff).dot(np.hstack((verts,
Expand All @@ -102,11 +171,13 @@ def make_plot(stat, task, contrast, num, outdir, inflated,
scalars2[np.abs(scalars2) < threshold] = 0.
scalars = np.zeros(verts.shape[0])
scalars[bidx] = scalars2[bidx]

negative = positive = False
if np.any(scalars < 0):
negative = True
if np.any(scalars > 0):
positive = True

nlabels = 2
vmin = 0
vmax = 0
Expand All @@ -117,20 +188,24 @@ def make_plot(stat, task, contrast, num, outdir, inflated,
vmin = -maxval
vmax = maxval
nlabels = 3
vmin = -display_threshold
vmax = display_threshold
vmin = -display_threshold ######
vmax = display_threshold ######
elif negative:
vmin = scalars.min()
if vmin < -display_threshold:
vmin = -display_threshold
vmax = 0
vmin = -display_threshold
vmin = -display_threshold ######
elif positive:
vmax = scalars.max()
if vmax > display_threshold:
vmax = display_threshold
vmin = 0
vmax = display_threshold
vmax = display_threshold ######
#print zstat

dual_split = True

fig1 = mlab.figure(1, bgcolor=(0, 0, 0))
mlab.clf()
mesh = tvtk.PolyData(points=verts_rot, polys=faces)
Expand All @@ -148,6 +223,7 @@ def make_plot(stat, task, contrast, num, outdir, inflated,
surf2 = mlab.pipeline.surface(mesh2, colormap='autumn', vmin=vmin, vmax=vmax)
colorbar = mlab.colorbar(surf, nb_labels=nlabels) #, orientation='vertical')
lut = surf.module_manager.scalar_lut_manager.lut.table.to_array()

if negative and positive:
half_index = lut.shape[0] / 2
index = int(half_index * threshold / vmax)
Expand All @@ -163,6 +239,7 @@ def make_plot(stat, task, contrast, num, outdir, inflated,
lut[:index, :] = 192
lut[index:, :] = 255 * plt.cm.autumn(np.linspace(0, 255, lut.shape[0] - index).astype(int))
lut[:, -1] = 255

surf.module_manager.scalar_lut_manager.lut.table = lut
if dual_split:
surf2.module_manager.scalar_lut_manager.lut.table = lut
Expand All @@ -172,6 +249,7 @@ def make_plot(stat, task, contrast, num, outdir, inflated,
surf.module_manager.scalar_lut_manager.show_scalar_bar = True
surf.module_manager.scalar_lut_manager.show_legend = True
mlab.draw()

translate = [0, 0, 0]
if inflated:
zoom = -700
Expand All @@ -186,67 +264,32 @@ def make_plot(stat, task, contrast, num, outdir, inflated,
zoom = -750
else:
zoom = -570
mlab.view(0, 90.0, zoom, translate)
if not os.path.exists(outdir):
os.makedirs(outdir)
#os.chdir(outdir)
outname = '%s-%s-%s.png' % (task,contrast,num)
mlab.savefig(os.path.join(outdir,outname), figure=fig1, magnification=5)

def plot_stats(base, tasks, outdir, atlas_dir, inflated=True,
split_brain=True, dual_split=True, threshold=2.3,
display_threshold=6):
for task in tasks:
print('-----%s-----'%task)
taskdir = os.path.join(base,task)
for contrast in os.listdir(taskdir):
cons = os.path.join(taskdir,contrast)
for x in os.listdir(cons):
subpath = os.path.join(cons,x,'zstat1.nii.gz')
print("Converting:\n" + subpath)
make_plot(subpath, task, contrast, x[-1], outdir,
inflated, split_brain, dual_split,
threshold, display_threshold, atlas_dir)
print("Finished!\n")

#mlab.view(0, 90.0, zoom, translate)
mlab.view(9, 90.0)

print(file_path_name_save)

mlab.savefig(file_path_name_save, figure=fig1, magnification=5)

vdisplay.stop()

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data_dir',
dest='data_dir',
required=True,
help='''location of the data to plot''')
parser.add_argument('-t', '--tasks', dest='tasks', required=True,
type=str, nargs='+', help='''list of tasks to get
contrasts''')
parser.add_argument('-o', '--outdir', dest='outputdir',
default=os.getcwd(),
help='''output directory for resulting images''')
parser.add_argument('-a', '--atlas', dest='atlas_dir',
default=os.path.abspath('32k_ConteAtlas_v2'),
help='''brain atlas directory, default cwd''')
parser.add_argument('-i', '--inflated', dest='inflated',
default=True, action='store_false',
help='''disable inflated brain image''')
parser.add_argument('-s', '--split', dest='split_brain',
default=True, action='store_false',
help='''disable split brain image''')
parser.add_argument('-ss', '--duosplit', dest='dual_split',
default=True, action='store_false',
help='''disable dualsplit brain image''')
parser.add_argument('-th', '--threshold', dest='threshold',
type=float, default=2.3,
help='''set threshold value (default=2.3) - must
be a float''')
parser.add_argument('-dt', '--displaythresh', dest='display_threshold',
type=int, default=6,
help='''set min/max for thresholded values - must
be an int''')
import argparse
parser = argparse.ArgumentParser(prog='plot_brain.py',
description=__doc__)
parser.add_argument('-i', '--input', required=True, help='input full file path and name .nii.gz')
parser.add_argument('-o', '--output', required=True, help='output full file pand and name .png')
parser.add_argument('-c', '--file_path_conte', required=True, help='file path to conte atlas folder')
parser.add_argument('-r', '--file_path_name_resting_atlas', required=True, help='resting atlas nii')
args = parser.parse_args()
plot_stats(args.data_dir, args.tasks,
os.path.abspath(args.outputdir),
atlas_dir=os.path.abspath('32k_ConteAtlas_v2'),
inflated=args.inflated,
split_brain=args.split_brain,
dual_split=args.dual_split,
threshold=args.threshold,
display_threshold=args.display_threshold)

useZstat(args.input, args.output, args.file_path_conte, args.file_path_name_resting_atlas)