diff --git a/tvb_scripts/datatypes/base.py b/tvb_scripts/datatypes/base.py
index 57da232..c85d012 100644
--- a/tvb_scripts/datatypes/base.py
+++ b/tvb_scripts/datatypes/base.py
@@ -1,7 +1,9 @@
from copy import deepcopy
-from tvb_scripts.utils.log_error_utils import warning
+
+from tvb.basic.neotraits.api import HasTraits
+
from tvb_scripts.utils.data_structures_utils import labels_to_inds
-from tvb.basic.neotraits.api import HasTraits, Attr
+from tvb_scripts.utils.log_error_utils import warning
class BaseModel(HasTraits):
@@ -40,4 +42,3 @@ def from_tvb_file(cls, filepath, **kwargs):
@staticmethod
def labels2inds(all_labels, labels):
return labels_to_inds(all_labels, labels)
-
diff --git a/tvb_scripts/datatypes/connectivity.py b/tvb_scripts/datatypes/connectivity.py
index dd84778..c7b9f7c 100644
--- a/tvb_scripts/datatypes/connectivity.py
+++ b/tvb_scripts/datatypes/connectivity.py
@@ -1,8 +1,9 @@
# coding=utf-8
-from tvb_scripts.datatypes.base import BaseModel
from tvb.datatypes.connectivity import Connectivity as TVBConnectivity
+from tvb_scripts.datatypes.base import BaseModel
+
class ConnectivityH5Field(object):
WEIGHTS = "weights"
@@ -30,4 +31,4 @@ def centers(self):
# A usefull method for addressing subsets of the connectome by label:
def get_regions_inds_by_labels(self, labels):
- return self.labels2inds(self.region_labels, labels)
+ return self.labels2inds(self.region_labels, labels)
diff --git a/tvb_scripts/datatypes/head.py b/tvb_scripts/datatypes/head.py
index b587348..ae75b70 100644
--- a/tvb_scripts/datatypes/head.py
+++ b/tvb_scripts/datatypes/head.py
@@ -1,17 +1,17 @@
# coding=utf-8
-from six import string_types
from collections import OrderedDict
-from tvb_scripts.utils.log_error_utils import initialize_logger, raise_value_error, warning
-from tvb_scripts.utils.data_structures_utils import isequal_string, is_integer
-from tvb_scripts.datatypes.connectivity import Connectivity
-from tvb_scripts.datatypes.sensors import SensorTypes, SensorTypesNames, SensorTypesToProjectionDict
-
-from tvb.datatypes.surfaces import CorticalSurface
+from six import string_types
from tvb.datatypes.cortex import Cortex
-from tvb.datatypes.sensors import Sensors
from tvb.datatypes.projections import ProjectionMatrix
+from tvb.datatypes.sensors import Sensors
+from tvb.datatypes.surfaces import CorticalSurface
+
+from tvb_scripts.datatypes.connectivity import Connectivity
+from tvb_scripts.datatypes.sensors import SensorTypes, SensorTypesNames, SensorTypesToProjectionDict
+from tvb_scripts.utils.data_structures_utils import isequal_string, is_integer
+from tvb_scripts.utils.log_error_utils import initialize_logger, raise_value_error, warning
class Head(object):
diff --git a/tvb_scripts/datatypes/sensors.py b/tvb_scripts/datatypes/sensors.py
index bfaebae..96991cc 100644
--- a/tvb_scripts/datatypes/sensors.py
+++ b/tvb_scripts/datatypes/sensors.py
@@ -2,21 +2,20 @@
from enum import Enum
import numpy as np
-
-from tvb_scripts.utils.log_error_utils import warning
-from tvb_scripts.utils.data_structures_utils import ensure_list, \
- labels_to_inds, monopolar_to_bipolar, split_string_text_numbers
-from tvb_scripts.datatypes.base import BaseModel
-
from tvb.basic.neotraits.api import Attr, NArray
+from tvb.datatypes.projections import \
+ ProjectionSurfaceEEG, ProjectionSurfaceMEG, ProjectionSurfaceSEEG
+from tvb.datatypes.sensors import EEG_POLYMORPHIC_IDENTITY, MEG_POLYMORPHIC_IDENTITY, \
+ INTERNAL_POLYMORPHIC_IDENTITY
from tvb.datatypes.sensors import Sensors as TVBSensors
from tvb.datatypes.sensors import SensorsEEG as TVBSensorsEEG
-from tvb.datatypes.sensors import SensorsMEG as TVBSensorsMEG
from tvb.datatypes.sensors import SensorsInternal as TVBSensorsInternal
-from tvb.datatypes.sensors import EEG_POLYMORPHIC_IDENTITY, MEG_POLYMORPHIC_IDENTITY, \
- INTERNAL_POLYMORPHIC_IDENTITY
-from tvb.datatypes.projections import \
- ProjectionSurfaceEEG, ProjectionSurfaceMEG, ProjectionSurfaceSEEG
+from tvb.datatypes.sensors import SensorsMEG as TVBSensorsMEG
+
+from tvb_scripts.datatypes.base import BaseModel
+from tvb_scripts.utils.data_structures_utils import ensure_list, \
+ labels_to_inds, monopolar_to_bipolar, split_string_text_numbers
+from tvb_scripts.utils.log_error_utils import warning
class SensorTypes(Enum):
@@ -28,7 +27,6 @@ class SensorTypes(Enum):
SensorTypesNames = [getattr(SensorTypes, stype).value for stype in SensorTypes.__members__]
-
SensorTypesToProjectionDict = {"EEG": ProjectionSurfaceEEG,
"MEG": ProjectionSurfaceMEG,
"SEEG": ProjectionSurfaceSEEG,
@@ -51,8 +49,8 @@ class Sensors(TVBSensors, BaseModel):
def configure(self, remove_leading_zeros_from_labels=False):
if len(self.labels) > 0:
- if remove_leading_zeros_from_labels:
- self.remove_leading_zeros_from_labels()
+ if remove_leading_zeros_from_labels:
+ self.remove_leading_zeros_from_labels()
self.configure()
def sensor_label_to_index(self, labels):
@@ -86,7 +84,7 @@ def get_bipolar_sensors(self, sensors_inds=None):
class SensorsEEG(Sensors, TVBSensorsEEG):
- pass
+ pass
class SensorsMEG(Sensors, TVBSensorsMEG):
diff --git a/tvb_scripts/datatypes/surface.py b/tvb_scripts/datatypes/surface.py
index 557db3b..bc5c85b 100644
--- a/tvb_scripts/datatypes/surface.py
+++ b/tvb_scripts/datatypes/surface.py
@@ -1,17 +1,17 @@
# coding=utf-8
import numpy as np
-
-from tvb_scripts.datatypes.base import BaseModel
from tvb.basic.neotraits.api import NArray, Attr
-from tvb.datatypes.surfaces import Surface as TVBSurface
-from tvb.datatypes.surfaces import WhiteMatterSurface as TVBWhiteMatterSurface
-from tvb.datatypes.surfaces import CorticalSurface as TVBCorticalSurface
-from tvb.datatypes.surfaces import SkinAir as TVBSkinAir
from tvb.datatypes.surfaces import BrainSkull as TVBBrainSkull
-from tvb.datatypes.surfaces import SkullSkin as TVBSkullSkin
+from tvb.datatypes.surfaces import CorticalSurface as TVBCorticalSurface
from tvb.datatypes.surfaces import EEGCap as TVBEEGCap
from tvb.datatypes.surfaces import FaceSurface as TVBFaceSurface
+from tvb.datatypes.surfaces import SkinAir as TVBSkinAir
+from tvb.datatypes.surfaces import SkullSkin as TVBSkullSkin
+from tvb.datatypes.surfaces import Surface as TVBSurface
+from tvb.datatypes.surfaces import WhiteMatterSurface as TVBWhiteMatterSurface
+
+from tvb_scripts.datatypes.base import BaseModel
class SurfaceH5Field(object):
@@ -23,7 +23,6 @@ class SurfaceH5Field(object):
class Surface(TVBSurface, BaseModel):
-
vox2ras = NArray(
dtype=np.float,
label="vox2ras", default=None, required=False,
@@ -52,9 +51,9 @@ def get_vertex_areas(self):
return vertex_areas
def add_vertices_and_triangles(self, new_vertices, new_triangles,
- new_vertex_normals=np.array([]), new_triangle_normals=np.array([])):
+ new_vertex_normals=np.array([]), new_triangle_normals=np.array([])):
self.triangles = np.array(self.triangles.tolist() +
- (new_triangles + self.number_of_vertices).tolist())
+ (new_triangles + self.number_of_vertices).tolist())
self.vertices = np.array(self.vertices.tolist() + new_vertices.tolist())
self.vertex_normals = np.array(self.vertex_normals.tolist() + new_vertex_normals.tolist())
self.triangle_normals = np.array(self.triangle_normals.tolist() + new_triangle_normals.tolist())
diff --git a/tvb_scripts/datatypes/time_series.py b/tvb_scripts/datatypes/time_series.py
index a17d21d..5f97990 100644
--- a/tvb_scripts/datatypes/time_series.py
+++ b/tvb_scripts/datatypes/time_series.py
@@ -1,22 +1,23 @@
# -*- coding: utf-8 -*-
-from six import string_types
-from enum import Enum
from copy import deepcopy
+from enum import Enum
import numpy
-from tvb_scripts.utils.log_error_utils import initialize_logger, warning
-from tvb_scripts.utils.data_structures_utils import ensure_list, is_integer, monopolar_to_bipolar
+from six import string_types
from tvb.basic.neotraits.api import List, Attr
from tvb.basic.profile import TvbProfile
+from tvb.datatypes.sensors import Sensors, SensorsEEG, SensorsMEG, SensorsInternal
from tvb.datatypes.time_series import TimeSeries as TimeSeriesTVB
-from tvb.datatypes.time_series import TimeSeriesRegion as TimeSeriesRegionTVB
from tvb.datatypes.time_series import TimeSeriesEEG as TimeSeriesEEGTVB
from tvb.datatypes.time_series import TimeSeriesMEG as TimeSeriesMEGTVB
+from tvb.datatypes.time_series import TimeSeriesRegion as TimeSeriesRegionTVB
from tvb.datatypes.time_series import TimeSeriesSEEG as TimeSeriesSEEGTVB
from tvb.datatypes.time_series import TimeSeriesSurface as TimeSeriesSurfaceTVB
from tvb.datatypes.time_series import TimeSeriesVolume as TimeSeriesVolumeTVB
-from tvb.datatypes.sensors import Sensors, SensorsEEG, SensorsMEG, SensorsInternal
+
+from tvb_scripts.utils.data_structures_utils import ensure_list, is_integer, monopolar_to_bipolar
+from tvb_scripts.utils.log_error_utils import initialize_logger, warning
TvbProfile.set_profile(TvbProfile.LIBRARY_PROFILE)
@@ -252,16 +253,16 @@ def slice_data_across_dimension_by_index(self, indices, dimension, **kwargs):
def slice_data_across_dimension_by_label(self, labels, dimension, **kwargs):
dim_index = self.get_dimension_index(dimension)
return self.slice_data_across_dimension_by_index(
- self._get_index_of_label(labels,
- self.get_dimension_name(dim_index)),
- dim_index, **kwargs)
+ self._get_index_of_label(labels,
+ self.get_dimension_name(dim_index)),
+ dim_index, **kwargs)
def slice_data_across_dimension_by_slice(self, slice_arg, dimension, **kwargs):
dim_index = self.get_dimension_index(dimension)
return self.slice_data_across_dimension_by_index(
- self._slice_to_indices(
- self._process_slice(slice_arg, dim_index), dim_index),
- dim_index, **kwargs)
+ self._slice_to_indices(
+ self._process_slice(slice_arg, dim_index), dim_index),
+ dim_index, **kwargs)
def _index_or_label_or_slice(self, inputs):
inputs = ensure_list(inputs)
@@ -589,7 +590,6 @@ def SEEGsensor_labels(self):
TimeSeriesMEG.__name__: TimeSeriesMEG,
TimeSeriesSEEG.__name__: TimeSeriesSEEG}
-
if __name__ == "__main__":
kwargs = {"data": numpy.ones((4, 2, 10, 1)), "start_time": 0.0,
"labels_dimensions": {LABELS_ORDERING[1]: ["x", "y"]}}
diff --git a/tvb_scripts/datatypes/time_series_xarray.py b/tvb_scripts/datatypes/time_series_xarray.py
index 5c5a3e7..db08090 100644
--- a/tvb_scripts/datatypes/time_series_xarray.py
+++ b/tvb_scripts/datatypes/time_series_xarray.py
@@ -35,11 +35,13 @@
"""
from copy import deepcopy
-from six import string_types
-import xarray as xr
+
import numpy as np
-from tvb.datatypes import sensors, surfaces, volumes, region_mapping, connectivity
+import xarray as xr
+from six import string_types
from tvb.basic.neotraits.api import HasTraits, Attr, List, narray_summary_info
+from tvb.datatypes import sensors, surfaces, volumes, region_mapping, connectivity
+
from tvb_scripts.datatypes.time_series import TimeSeries as TimeSeriesTVB
from tvb_scripts.utils.data_structures_utils import is_integer
@@ -358,7 +360,7 @@ def duplicate(self, **kwargs):
def _assert_array_indices(self, slice_tuple):
if is_integer(slice_tuple) or isinstance(slice_tuple, string_types):
- return ([slice_tuple], )
+ return ([slice_tuple],)
else:
if isinstance(slice_tuple, slice):
slice_tuple = (slice_tuple,)
@@ -547,9 +549,9 @@ def plot_timeseries(self, **kwargs):
outputs.append(self[:, var].plot_timeseries(**kwargs))
return outputs
if np.any([s < 2 for s in self.shape[1:]]):
- if self.shape[1] == 1: # only one variable
+ if self.shape[1] == 1: # only one variable
figname = kwargs.pop("figname", "%s" % (self.title + "Time Series")) + ": " \
- + self.labels_dimensions[self.labels_ordering[1]][0]
+ + self.labels_dimensions[self.labels_ordering[1]][0]
kwargs["figname"] = figname
return self.plot_line(**kwargs)
else:
diff --git a/tvb_scripts/io/edf.py b/tvb_scripts/io/edf.py
index 3b3037d..2220e6b 100644
--- a/tvb_scripts/io/edf.py
+++ b/tvb_scripts/io/edf.py
@@ -2,9 +2,9 @@
import numpy as np
+from tvb_scripts.datatypes.time_series import TimeSeries, TimeSeriesDimensions
from tvb_scripts.utils.log_error_utils import initialize_logger
from tvb_scripts.utils.data_structures_utils import ensure_string
-from tvb_scripts.model.timeseries import Timeseries, TimeseriesDimensions
def read_edf_with_mne(path, exclude_channels):
@@ -79,5 +79,5 @@ def read_edf_to_Timeseries(path, sensors, rois_selection=None, label_strip_fun=N
data, times, rois, rois_inds, rois_lbls = \
read_edf(path, sensors, rois_selection, label_strip_fun, time_unit)
- return Timeseries(data, time=times, labels_dimensions={TimeseriesDimensions.SPACE.value: rois_lbls},
+ return TimeSeries(data, time=times, labels_dimensions={TimeSeriesDimensions.SPACE.value: rois_lbls},
sample_period=np.mean(np.diff(times)), sample_period_unit=time_unit, **kwargs)
diff --git a/tvb_scripts/io/h5_reader_base.py b/tvb_scripts/io/h5_reader_base.py
index 40e4dcf..420383f 100644
--- a/tvb_scripts/io/h5_reader_base.py
+++ b/tvb_scripts/io/h5_reader_base.py
@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
import os
+
import h5py
-from tvb_scripts.utils.log_error_utils import initialize_logger
from tvb_scripts.io.h5_writer import H5Writer
+from tvb_scripts.utils.log_error_utils import initialize_logger
class H5ReaderBase(object):
@@ -33,6 +34,7 @@ def _log_success(self, name, path=None):
class H5GroupHandlers(object):
+ H5_SUBTYPE_ATTRIBUTE = H5Writer().H5_SUBTYPE_ATTRIBUTE
def read_dictionary_from_group(self, group, type=None):
dictionary = dict()
@@ -41,6 +43,6 @@ def read_dictionary_from_group(self, group, type=None):
for attr in group.attrs.keys():
dictionary.update({attr: group.attrs[attr]})
if type is None:
- type = group.attrs[H5_SUBTYPE_ATTRIBUTE]
+ type = group.attrs[self.H5_SUBTYPE_ATTRIBUTE]
else:
return dictionary
diff --git a/tvb_scripts/plot/base_plotter.py b/tvb_scripts/plot/base_plotter.py
deleted file mode 100644
index 4398cdc..0000000
--- a/tvb_scripts/plot/base_plotter.py
+++ /dev/null
@@ -1,449 +0,0 @@
-# coding=utf-8
-
-import os
-import numpy
-
-from tvb_scripts.config import Config, FiguresConfig
-
-import matplotlib
-matplotlib.use(FiguresConfig().MATPLOTLIB_BACKEND)
-from matplotlib import pyplot
-pyplot.rcParams["font.size"] = FiguresConfig.FONTSIZE
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-
-from tvb_scripts.utils.log_error_utils import initialize_logger, warning
-from tvb_scripts.utils.data_structures_utils import ensure_list, generate_region_labels
-
-
-class BasePlotter(object):
-
- def __init__(self, config=None):
- self.config = config or Config()
- self.logger = initialize_logger(self.__class__.__name__, self.config.out.FOLDER_LOGS)
- self.print_regions_indices = True
-
- def _check_show(self):
- if self.config.figures.SHOW_FLAG:
- # mp.use('TkAgg')
- pyplot.ion()
- pyplot.show()
- else:
- # mp.use('Agg')
- pyplot.ioff()
- pyplot.close()
-
- @staticmethod
- def _figure_filename(fig=pyplot.gcf(), figure_name=None):
- if figure_name is None:
- figure_name = fig.get_label()
- figure_name = figure_name.replace(": ", "_").replace(" ", "_").replace("\t", "_").replace(",", "")
- return figure_name
-
- def _save_figure(self, fig=pyplot.gcf(), figure_name=None):
- if self.config.figures.SAVE_FLAG:
- figure_name = self._figure_filename(fig, figure_name)
- figure_name = figure_name[:numpy.min([100, len(figure_name)])] + '.' + self.config.figures.FIG_FORMAT
- figure_dir = self.config.out.FOLDER_FIGURES
- if not (os.path.isdir(figure_dir)):
- os.mkdir(figure_dir)
- pyplot.savefig(os.path.join(figure_dir, figure_name))
-
- @staticmethod
- def rect_subplot_shape(n, mode="col"):
- nj = int(numpy.ceil(numpy.sqrt(n)))
- ni = int(numpy.ceil(1.0 * n / nj))
- if mode.find("row") >= 0:
- return nj, ni
- else:
- return ni, nj
-
- def plot_vector(self, vector, labels, subplot, title, show_y_labels=True, indices_red=None, sharey=None):
- ax = pyplot.subplot(subplot, sharey=sharey)
- pyplot.title(title)
- n_vector = labels.shape[0]
- y_ticks = numpy.array(range(n_vector), dtype=numpy.int32)
- color = 'k'
- colors = numpy.repeat([color], n_vector)
- coldif = False
- if indices_red is not None:
- colors[indices_red] = 'r'
- coldif = True
- if len(vector.shape) == 1:
- ax.barh(y_ticks, vector, color=colors, align='center')
- else:
- ax.barh(y_ticks, vector[0, :], color=colors, align='center')
- # ax.invert_yaxis()
- ax.grid(True, color='grey')
- ax.set_yticks(y_ticks)
- if show_y_labels:
- region_labels = generate_region_labels(n_vector, labels, ". ", self.print_regions_indices)
- ax.set_yticklabels(region_labels)
- if coldif:
- labels = ax.yaxis.get_ticklabels()
- for ids in indices_red:
- labels[ids].set_color('r')
- ax.yaxis.set_ticklabels(labels)
- else:
- ax.set_yticklabels([])
- ax.autoscale(tight=True)
- if sharey is None:
- ax.invert_yaxis()
- return ax
-
- def plot_vector_violin(self, dataset, vector=[], lines=[], labels=[], subplot=111, title="", violin_flag=True,
- colormap="YlOrRd", show_y_labels=True, indices_red=None, sharey=None):
- ax = pyplot.subplot(subplot, sharey=sharey)
- pyplot.title(title)
- n_violins = dataset.shape[1]
- y_ticks = numpy.array(range(n_violins), dtype=numpy.int32)
- # the vector plot
- coldif = False
- if indices_red is None:
- indices_red = []
- if violin_flag:
- # the violin plot
- colormap = matplotlib.cm.ScalarMappable(cmap=pyplot.set_cmap(colormap))
- colormap = colormap.to_rgba(numpy.mean(dataset, axis=0), alpha=0.75)
- violin_parts = ax.violinplot(dataset, y_ticks, vert=False, widths=0.9,
- showmeans=True, showmedians=True, showextrema=True)
- violin_parts['cmeans'].set_color("k")
- violin_parts['cmins'].set_color("b")
- violin_parts['cmaxes'].set_color("b")
- violin_parts['cbars'].set_color("b")
- violin_parts['cmedians'].set_color("b")
- for ii in range(len(violin_parts['bodies'])):
- violin_parts['bodies'][ii].set_color(numpy.reshape(colormap[ii], (1, 4)))
- violin_parts['bodies'][ii]._alpha = 0.75
- violin_parts['bodies'][ii]._edgecolors = numpy.reshape(colormap[ii], (1, 4))
- violin_parts['bodies'][ii]._facecolors = numpy.reshape(colormap[ii], (1, 4))
- else:
- colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
- n_samples = dataset.shape[0]
- for ii in range(n_violins):
- for jj in range(n_samples):
- ax.plot(dataset[jj, ii], y_ticks[ii], "D",
- mfc=colorcycle[jj%n_samples], mec=colorcycle[jj%n_samples], ms=20)
- color = 'k'
- colors = numpy.repeat([color], n_violins)
- if indices_red is not None:
- colors[indices_red] = 'r'
- coldif = True
- if len(vector) == n_violins:
- for ii in range(n_violins):
- ax.plot(vector[ii], y_ticks[ii], '*', mfc=colors[ii], mec=colors[ii], ms=10)
- if len(lines) == 2 and lines[0].shape[0] == n_violins and lines[1].shape[0] == n_violins:
- for ii in range(n_violins):
- yy = (y_ticks[ii] - 0.45*lines[1][ii]/numpy.max(lines[1][ii]))\
- * numpy.ones(numpy.array(lines[0][ii]).shape)
- ax.plot(lines[0][ii], yy, '--', color=colors[ii])
-
- ax.grid(True, color='grey')
- ax.set_yticks(y_ticks)
- if show_y_labels:
- region_labels = generate_region_labels(n_violins, labels, ". ", self.print_regions_indices)
- ax.set_yticklabels(region_labels)
- if coldif:
- labels = ax.yaxis.get_ticklabels()
- for ids in indices_red:
- labels[ids].set_color('r')
- ax.yaxis.set_ticklabels(labels)
- else:
- ax.set_yticklabels([])
- if sharey is None:
- ax.invert_yaxis()
- ax.autoscale()
- return ax
-
- def _plot_matrix(self, matrix, xlabels, ylabels, subplot=111, title="", show_x_labels=True, show_y_labels=True,
- x_ticks=numpy.array([]), y_ticks=numpy.array([]), indices_red_x=None, indices_red_y=None,
- sharex=None, sharey=None, cmap='autumn_r', vmin=None, vmax=None):
- ax = pyplot.subplot(subplot, sharex=sharex, sharey=sharey)
- pyplot.title(title)
- nx, ny = matrix.shape
- indices_red = [indices_red_x, indices_red_y]
- ticks = [x_ticks, y_ticks]
- labels = [xlabels, ylabels]
- nticks = []
- for ii, (n, tick) in enumerate(zip([nx, ny], ticks)):
- if len(tick) == 0:
- ticks[ii] = numpy.array(range(n), dtype=numpy.int32)
- nticks.append(len(ticks[ii]))
- cmap = pyplot.set_cmap(cmap)
- img = pyplot.imshow(matrix[ticks[0]][:, ticks[1]].T, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='none')
- pyplot.grid(True, color='black')
- for ii, (xy, tick, ntick, ind_red, show, lbls, rot) in enumerate(zip(["x", "y"], ticks, nticks, indices_red,
- [show_x_labels, show_y_labels], labels, [90, 0])):
- if show:
- labels[ii] = generate_region_labels(len(tick), numpy.array(lbls)[tick], ". ", self.print_regions_indices, tick)
- # labels[ii] = numpy.array(["%d. %s" % l for l in zip(tick, lbls[tick])])
- getattr(pyplot, xy + "ticks")(numpy.array(range(ntick)), labels[ii], rotation=rot)
- else:
- labels[ii] = numpy.array(["%d." % l for l in tick])
- getattr(pyplot, xy + "ticks")(numpy.array(range(ntick)), labels[ii])
- if ind_red is not None:
- tck = tick.tolist()
- ticklabels = getattr(ax, xy + "axis").get_ticklabels()
- for iidx, indr in enumerate(ind_red):
- try:
- ticklabels[tck.index(indr)].set_color('r')
- except:
- pass
- getattr(ax, xy + "axis").set_ticklabels(ticklabels)
- ax.autoscale(tight=True)
- divider = make_axes_locatable(ax)
- cax1 = divider.append_axes("right", size="5%", pad=0.05)
- pyplot.colorbar(img, cax=cax1) # fraction=0.046, pad=0.04) #fraction=0.15, shrink=1.0
- return ax, cax1
-
- def plot_regions2regions(self, adj, labels, subplot, title, show_x_labels=True, show_y_labels=True,
- x_ticks=numpy.array([]), y_ticks=numpy.array([]), indices_red_x=None, indices_red_y=None,
- sharex=None, sharey=None, cmap='autumn_r', vmin=None, vmax=None):
- return self._plot_matrix(adj, labels, labels, subplot, title, show_x_labels, show_y_labels,
- x_ticks, y_ticks, indices_red_x, indices_red_y, sharex, sharey, cmap, vmin, vmax)
-
- def _set_axis_labels(self, fig, sub, n_regions, region_labels, indices2emphasize, color='k', position='left'):
- y_ticks = range(n_regions)
- # region_labels = numpy.array(["%d. %s" % l for l in zip(y_ticks, region_labels)])
- region_labels = generate_region_labels(len(y_ticks), region_labels, ". ", self.print_regions_indices, y_ticks)
- big_ax = fig.add_subplot(sub, frameon=False)
- if position == 'right':
- big_ax.yaxis.tick_right()
- big_ax.yaxis.set_label_position("right")
- big_ax.set_yticks(y_ticks)
- big_ax.set_yticklabels(region_labels, color='k')
- if not (color == 'k'):
- labels = big_ax.yaxis.get_ticklabels()
- for idx in indices2emphasize:
- labels[idx].set_color(color)
- big_ax.yaxis.set_ticklabels(labels)
- big_ax.invert_yaxis()
- big_ax.axes.get_xaxis().set_visible(False)
- # TODO: find out what is the next line about and why it fails...
- # big_ax.axes.set_facecolor('none')
-
- def plot_in_columns(self, data_dict_list, labels, width_ratios=[], left_ax_focus_indices=[],
- right_ax_focus_indices=[], description="", title="", figure_name=None,
- figsize=None, **kwargs):
- if not isinstance(figsize, (tuple, list)):
- figsize = self.config.figures.VERY_LARGE_SIZE
- fig = pyplot.figure(title, frameon=False, figsize=figsize)
- fig.suptitle(description)
- n_subplots = len(data_dict_list)
- if not width_ratios:
- width_ratios = numpy.ones((n_subplots,)).tolist()
- matplotlib.gridspec.GridSpec(1, n_subplots, width_ratios=width_ratios)
- if 10 > n_subplots > 0:
- subplot_ind0 = 100 + 10 * n_subplots
- else:
- raise ValueError("\nSubplots' number " + str(n_subplots) + " is not between 1 and 9!")
- n_regions = len(labels)
- subplot_ind = subplot_ind0
- ax = None
- ax0 = None
- for iS, data_dict in enumerate(data_dict_list):
- subplot_ind += 1
- data = data_dict["data"]
- focus_indices = data_dict.get("focus_indices")
- if subplot_ind == 0:
- if not left_ax_focus_indices:
- left_ax_focus_indices = focus_indices
- else:
- ax0 = ax
- if data_dict.get("plot_type") == "vector_violin":
- ax = self.plot_vector_violin(data_dict.get("data_samples", []), data, [],
- labels, subplot_ind, data_dict["name"],
- colormap=kwargs.get("colormap", "YlOrRd"), show_y_labels=False,
- indices_red=focus_indices, sharey=ax0)
- elif data_dict.get("plot_type") == "regions2regions":
- # TODO: find a more general solution, in case we don't want to apply focus indices to x_ticks
- ax = self.plot_regions2regions(data, labels, subplot_ind, data_dict["name"], x_ticks=focus_indices,
- show_x_labels=True, show_y_labels=False, indices_red_x=focus_indices,
- sharey=ax0)
- else:
- ax = self.plot_vector(data, labels, subplot_ind, data_dict["name"], show_y_labels=False,
- indices_red=focus_indices, sharey=ax0)
- if right_ax_focus_indices == []:
- right_ax_focus_indices = focus_indices
- self._set_axis_labels(fig, 121, n_regions, labels, left_ax_focus_indices, 'r')
- self._set_axis_labels(fig, 122, n_regions, labels, right_ax_focus_indices, 'r', 'right')
- self._save_figure(pyplot.gcf(), figure_name)
- self._check_show()
- return fig
-
- #TODO: name is too generic
- def plots(self, data_dict, shape=None, transpose=False, skip=0, xlabels={}, xscales={}, yscales={}, title='Plots',
- lgnd={}, figure_name=None, figsize=None):
- if not isinstance(figsize, (tuple, list)):
- figsize = self.config.figures.VERY_LARGE_SIZE
- if shape is None:
- shape = (1, len(data_dict))
- fig, axes = pyplot.subplots(shape[0], shape[1], figsize=figsize)
- fig.set_label(title)
- for i, key in enumerate(data_dict.keys()):
- ind = numpy.unravel_index(i, shape)
- if transpose:
- axes[ind].plot(data_dict[key].T[skip:])
- else:
- axes[ind].plot(data_dict[key][skip:])
- axes[ind].set_xscale(xscales.get(key, "linear"))
- axes[ind].set_yscale(yscales.get(key, "linear"))
- axes[ind].set_xlabel(xlabels.get(key, ""))
- axes[ind].set_ylabel(key)
- this_legend = lgnd.get(key, None)
- if this_legend is not None:
- axes[ind].legend(this_legend)
- fig.tight_layout()
- self._save_figure(fig, figure_name)
- self._check_show()
- return fig, axes
-
- def pair_plots(self, data, keys, diagonal_plots={}, transpose=False, skip=0,
- title='Pair plots', legend_prefix="", subtitles=None, figure_name=None, figsize=None):
-
- def confirm_y_coordinate(data, ymax):
- data = list(data)
- data.append(ymax)
- return tuple(data)
-
- if not isinstance(figsize, (tuple, list)):
- figsize = self.config.figures.VERY_LARGE_SIZE
-
- if subtitles is None:
- subtitles = keys
- data = ensure_list(data)
- n = len(keys)
- fig, axes = pyplot.subplots(n, n, figsize=figsize)
- fig.set_label(title)
- colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
- for i, key_i in enumerate(keys):
- for j, key_j in enumerate(keys):
- for datai in data:
- if transpose:
- di = datai[key_i].T[skip:]
- else:
- di = datai[key_i][skip:]
- if i == j:
- if di.shape[0] > 1:
- hist_data = axes[i, j].hist(di, int(numpy.round(numpy.sqrt(len(di)))), log=True)[0]
- if i == 0 and len(di.shape) > 1 and di.shape[1] > 1:
- axes[i, j].legend([legend_prefix + str(ii + 1) for ii in range(di.shape[1])])
- y_max = numpy.array(hist_data).max()
- # The mean line
- axes[i, j].vlines(di.mean(axis=0), 0, y_max, color=colorcycle, linestyle='dashed',
- linewidth=1)
- else:
- # This is for the case of only 1 sample (optimization)
- y_max = 1.0
- for ii in range(di.shape[1]):
- axes[i, j].plot(di[0, ii], y_max, "D", color=colorcycle[ii%di.shape[1]], markersize=20,
- label=legend_prefix + str(ii + 1))
- if i == 0 and len(di.shape) > 1 and di.shape[1] > 1:
- axes[i, j].legend()
- # Plot a line (or marker) in the same axis
- diag_line_plot = ensure_list(diagonal_plots.get(key_i, ((), ()))[0])
- if len(diag_line_plot) in [1, 2]:
- if len(diag_line_plot) == 1 :
- diag_line_plot = confirm_y_coordinate(diag_line_plot, y_max)
- else:
- diag_line_plot[1] = diag_line_plot[1]/numpy.max(diag_line_plot[1])*y_max
- if len(ensure_list(diag_line_plot[0])) == 1:
- axes[i, j].plot(diag_line_plot[0], diag_line_plot[1], "o", mfc="k", mec="k",
- markersize=10)
- else:
- axes[i, j].plot(diag_line_plot[0], diag_line_plot[1], color='k',
- linestyle="dashed", linewidth=1)
- # Plot a marker in the same axis
- diag_marker_plot = ensure_list(diagonal_plots.get(key_i, ((), ()))[1])
- if len(diag_marker_plot) in [1, 2]:
- if len(diag_marker_plot) == 1:
- diag_marker_plot = confirm_y_coordinate(diag_marker_plot, y_max)
- axes[i, j].plot(diag_marker_plot[0], diag_marker_plot[1], "*", color='k', markersize=10)
- axes[i, j].autoscale()
- axes[i, j].set_ylim([0, 1.1*y_max])
-
- else:
- if transpose:
- dj = datai[key_j].T[skip:]
- else:
- dj = datai[key_j][skip:]
- axes[i, j].plot(dj, di, '.')
- if i == 0:
- axes[i, j].set_title(subtitles[j])
- if j == 0:
- axes[i, j].set_ylabel(key_i)
- fig.tight_layout()
- self._save_figure(fig, figure_name)
- self._check_show()
- return fig, axes
-
- def plot_bars(self, data, ax=None, fig=None, title="", group_names=[], legend_prefix="", figsize=None):
-
- def barlabel(ax, rects, positions):
- """
- Attach a text label on each bar displaying its height
- """
- for rect, pos in zip(rects, positions):
- height = rect.get_height()
- if pos < 0:
- y = -height
- pos = 0.75 * pos
- else:
- y = height
- pos = 0.25 * pos
- ax.text(rect.get_x() + rect.get_width() / 2., pos, '%0.2f' % y,
- color="k", ha='center', va='bottom', rotation=90)
- if fig is None:
- if not isinstance(figsize, (tuple, list)):
- figsize = self.config.figures.VERY_LARGE_SIZE
- fig, ax = pyplot.subplots(1, 1, figsize=figsize)
- show_and_save = True
- else:
- show_and_save = False
- if ax is None:
- ax = pyplot.gca()
- if isinstance(data, (list, tuple)): # If, there are many groups, data is a list:
- # Fill in with nan in case that not all groups have the same number of elements
- from itertools import izip_longest
- data = numpy.array(list(izip_longest(*ensure_list(data), fillvalue=numpy.nan))).T
- elif data.ndim == 1: # This is the case where there is only one group...
- data = numpy.expand_dims(data, axis=1).T
- n_groups, n_elements = data.shape
- posmax = numpy.nanmax(data)
- negmax = numpy.nanmax(-(-data))
- n_groups_names = len(group_names)
- if n_groups_names != n_groups:
- if n_groups_names != 0:
- warning("Ignoring group_names because their number (" + str(n_groups_names) +
- ") is not equal to the number of groups (" + str(n_groups) + ")!")
- group_names = n_groups * [""]
- colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
- n_colors = len(colorcycle)
- x_inds = numpy.arange(n_groups)
- width = 0.9 / n_elements
- elements = []
- for iE in range(n_elements):
- elements.append(ax.bar(x_inds + iE*width, data[:, iE], width, color=colorcycle[iE % n_colors]))
- positions = numpy.array([negmax if d < 0 else posmax for d in data[:, iE]])
- positions[numpy.logical_or(numpy.isnan(positions), numpy.isinf(numpy.abs(positions)))] = 0.0
- barlabel(ax, elements[-1], positions)
- if n_elements > 1:
- legend = [legend_prefix+str(ii) for ii in range(1, n_elements+1)]
- ax.legend(tuple([element[0] for element in elements]), tuple(legend))
- ax.set_xticks(x_inds + n_elements*width/2)
- ax.set_xticklabels(tuple(group_names))
- ax.set_title(title)
- ax.autoscale() # tight=True
- ax.set_xlim([-1.05*width, n_groups*1.05])
- if show_and_save:
- fig.tight_layout()
- self._save_figure(fig)
- self._check_show()
- return fig, ax
-
- def tvb_plot(self, plot_fun_name, *args, **kwargs):
- import tvb.simulator.plot.tools as TVB_plot_tools
- getattr(TVB_plot_tools, plot_fun_name)(*args, **kwargs)
- fig = pyplot.gcf()
- self._save_figure(fig)
- self._check_show()
- return fig
diff --git a/tvb_scripts/config.py b/tvb_scripts/plot/config.py
similarity index 64%
rename from tvb_scripts/config.py
rename to tvb_scripts/plot/config.py
index 572b49f..3a16396 100644
--- a/tvb_scripts/config.py
+++ b/tvb_scripts/plot/config.py
@@ -1,12 +1,43 @@
# -*- coding: utf-8 -*-
+#
+#
+# TheVirtualBrain-Scientific Package. This package holds all simulators, and
+# analysers necessary to run brain-simulations. You can use it stand alone or
+# in conjunction with TheVirtualBrain-Framework Package. See content of the
+# documentation-folder for more details. See also http://www.thevirtualbrain.org
+#
+# (c) 2012-2020, Baycrest Centre for Geriatric Care ("Baycrest") and others
+#
+# This program is free software: you can redistribute it and/or modify it under the
+# terms of the GNU General Public License as published by the Free Software Foundation,
+# either version 3 of the License, or (at your option) any later version.
+# This program is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
+# PARTICULAR PURPOSE. See the GNU General Public License for more details.
+# You should have received a copy of the GNU General Public License along with this
+# program. If not, see .
+#
+#
+# CITATION:
+# When using The Virtual Brain for scientific publications, please cite it as follows:
+#
+# Paula Sanz Leon, Stuart A. Knock, M. Marmaduke Woodman, Lia Domide,
+# Jochen Mersmann, Anthony R. McIntosh, Viktor Jirsa (2013)
+# The Virtual Brain: a simulator of primate brain network dynamics.
+# Frontiers in Neuroinformatics (7:10. doi: 10.3389/fninf.2013.00010)
+#
+#
+
+"""
+.. moduleauthor:: Dionysios Perdikis
+.. moduleauthor:: Gabriel Florea
+"""
import os
-import numpy
from datetime import datetime
-from tvb.basic.profile import TvbProfile
-
-TvbProfile.set_profile(TvbProfile.LIBRARY_PROFILE)
+import numpy
+from tvb.simulator.plot.config import FiguresConfig
class GenericConfig(object):
@@ -78,18 +109,7 @@ def FOLDER_RES(self):
if not (os.path.isdir(folder)):
os.makedirs(folder)
if self.subfolder is not None:
- os.path.join(folder, self.subfolder)
- return folder
-
- @property
- def FOLDER_FIGURES(self):
- folder = os.path.join(self._out_base, "figs")
- if self._separate_by_run:
- folder = folder + datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')
- if not (os.path.isdir(folder)):
- os.makedirs(folder)
- if self.subfolder is not None:
- os.path.join(folder, self.subfolder)
+ folder = os.path.join(folder, self.subfolder)
return folder
@property
@@ -97,33 +117,6 @@ def FOLDER_TEMP(self):
return os.path.join(self._out_base, "temp")
-class FiguresConfig(object):
- VERY_LARGE_SIZE = (40, 20)
- VERY_LARGE_PORTRAIT = (30, 50)
- SUPER_LARGE_SIZE = (80, 40)
- LARGE_SIZE = (20, 15)
- SMALL_SIZE = (15, 10)
- NOTEBOOK_SIZE = (20, 10)
- FIG_FORMAT = 'png'
- SAVE_FLAG = True
- SHOW_FLAG = False
- MOUSE_HOOVER = False
- MATPLOTLIB_BACKEND = "Agg" # "Qt4Agg"
- FONTSIZE = 10
- SMALL_FONTSIZE = 8
- LARGE_FONTSIZE = 12
-
- def largest_size(self):
- import sys
- if 'IPython' not in sys.modules:
- return self.LARGE_SIZE
- from IPython import get_ipython
- if getattr(get_ipython(), 'kernel', None) is not None:
- return self.NOTEBOOK_SIZE
- else:
- return self.LARGE_SIZE
-
-
class CalculusConfig(object):
# Normalization configuration
WEIGHTS_NORM_PERCENT = 99
@@ -139,12 +132,12 @@ class CalculusConfig(object):
class Config(object):
generic = GenericConfig()
- figures = FiguresConfig()
calcul = CalculusConfig()
def __init__(self, head_folder=None, raw_data_folder=None, output_base=None, separate_by_run=False):
self.input = InputConfig(head_folder, raw_data_folder)
self.out = OutputConfig(output_base, separate_by_run)
+ self.figures = FiguresConfig(output_base, separate_by_run)
CONFIGURED = Config()
diff --git a/tvb_scripts/plot/head_plotter.py b/tvb_scripts/plot/head_plotter.py
deleted file mode 100644
index 5b3181c..0000000
--- a/tvb_scripts/plot/head_plotter.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# coding=utf-8
-
-from matplotlib import pyplot
-
-import numpy
-
-from tvb_scripts.utils.computations_utils import compute_in_degree
-from tvb_scripts.datatypes.sensors import Sensors
-from tvb_scripts.plot.base_plotter import BasePlotter
-
-from tvb.datatypes.projections import ProjectionMatrix
-
-
-class HeadPlotter(BasePlotter):
-
- def __init__(self, config=None):
- super(HeadPlotter, self).__init__(config)
-
- def _plot_connectivity(self, connectivity, figure_name='Connectivity'):
- pyplot.figure(figure_name + str(connectivity.number_of_regions), self.config.figures.VERY_LARGE_SIZE)
- axes = []
- axes.append(self.plot_regions2regions(connectivity.normalized_weights,
- connectivity.region_labels, 121, "normalised weights"))
- axes.append(self.plot_regions2regions(connectivity.tract_lengths,
- connectivity.region_labels, 122, "tract lengths"))
- self._save_figure(None, figure_name.replace(" ", "_").replace("\t", "_"))
- self._check_show()
- return pyplot.gcf(), tuple(axes)
-
- def _plot_connectivity_stats(self, connectivity, figsize=None, figure_name='HeadStats '):
- if not isinstance(figsize, (list, tuple)):
- figsize = self.config.figures.VERY_LARGE_SIZE
- pyplot.figure("Head stats " + str(connectivity.number_of_regions), figsize=figsize)
- areas_flag = len(connectivity.areas) == len(connectivity.region_labels)
- axes=[]
- axes.append(self.plot_vector(compute_in_degree(connectivity.normalized_weights), connectivity.region_labels,
- 111 + 10 * areas_flag, "w in-degree"))
- if len(connectivity.areas) == len(connectivity.region_labels):
- axes.append(self.plot_vector(connectivity.areas, connectivity.region_labels, 122, "region areas"))
- self._save_figure(None, figure_name.replace(" ", "").replace("\t", ""))
- self._check_show()
- return pyplot.gcf(), tuple(axes)
-
- def _plot_sensors(self, sensors, projection, region_labels, count=1):
- figure, ax, cax = self._plot_projection(sensors, projection, region_labels,
- title="%d - %s - Projection" % (count, sensors.sensors_type))
- count += 1
- return count, figure, ax, cax
-
- def _plot_projection(self, sensors, projection, region_labels, figure=None, title="Projection",
- show_x_labels=True, show_y_labels=True, x_ticks=numpy.array([]), y_ticks=numpy.array([]),
- figsize=None):
- if not isinstance(figsize, (list, tuple)):
- figsize = self.config.figures.VERY_LARGE_SIZE
- if not (isinstance(figure, pyplot.Figure)):
- figure = pyplot.figure(title, figsize=figsize)
- ax, cax1 = self._plot_matrix(projection, sensors.labels, region_labels, 111, title,
- show_x_labels, show_y_labels, x_ticks, y_ticks)
- self._save_figure(None, title)
- self._check_show()
- return figure, ax, cax1
-
- def plot_head(self, head, plot_stats=False, plot_sensors=True):
- output = []
- output.append(self._plot_connectivity(head.connectivity))
- if plot_stats:
- output.append(self._plot_connectivity_stats(head.connectivity))
- if plot_sensors:
- count = 1
- for s_type, sensors_set in head.sensors.items():
- for sensor, projection in sensors_set.items():
- if isinstance(sensor, Sensors) and isinstance(projection, ProjectionMatrix):
- count, figure, ax, cax = \
- self._plot_sensors(sensor, projection.projection_data,
- head.connectivity.region_labels, count)
- output.append((figure, ax, cax))
- return tuple(output)
-
- def plot_tvb_connectivity(self, connectivity, num="weights", order_by=None, plot_hinton=False, plot_tracts=True,
- **kwargs):
- """
- A 2D plot for visualizing the Connectivity.weights matrix
- """
- figsize = kwargs.pop("figsize", self.config.figures.LARGE_SIZE)
- fontsize = kwargs.pop("fontsize", self.config.figures.SMALL_FONTSIZE)
-
- labels = connectivity.region_labels
- plot_title = connectivity.__class__.__name__
-
- if order_by is None:
- order = numpy.arange(connectivity.number_of_regions)
- else:
- order = numpy.argsort(order_by)
- if order.shape[0] != connectivity.number_of_regions:
- self.logger.error("Ordering vector doesn't have length number_of_regions")
- self.logger.error("Check ordering length and that connectivity is configured")
- return
-
- # Assumes order is shape (number_of_regions, )
- order_rows = order[:, numpy.newaxis]
- order_columns = order_rows.T
-
- if plot_hinton:
- from tvb.simulator.plot.tools import hinton_diagram
- weights_axes = hinton_diagram(connectivity.weights[order_rows, order_columns], num)
- weights_figure = None
- else:
- # weights matrix
- weights_figure = pyplot.figure(num="Connectivity weights", figsize=figsize)
- weights_axes = weights_figure.gca()
- wimg = weights_axes.matshow(connectivity.weights[order_rows, order_columns])
- weights_figure.colorbar(wimg)
-
- weights_axes.set_title(plot_title)
-
- weights_axes.set_yticks(numpy.arange(connectivity.number_of_regions))
- weights_axes.set_yticklabels(list(labels[order]), fontsize=self.config.figures.FONTSIZE)
-
- weights_axes.set_xticks(numpy.arange(connectivity.number_of_regions))
- weights_axes.set_xticklabels(list(labels[order]), fontsize=fontsize, rotation=90)
-
- self._save_figure(weights_figure, plot_title)
- self._check_show()
-
- if plot_tracts:
- # tract lengths matrix
- tracts_figure = pyplot.figure(num="Tracts' lengths", figsize=figsize)
- tracts_axes = tracts_figure.gca()
- timg = tracts_axes.matshow(connectivity.tract_lengths[order_rows, order_columns])
- tracts_axes.set_title("Tracts' lengths")
- tracts_figure.colorbar(timg)
- tracts_axes.set_yticks(numpy.arange(connectivity.number_of_regions))
- tracts_axes.set_yticklabels(list(labels[order]), fontsize=fontsize)
-
- tracts_axes.set_xticks(numpy.arange(connectivity.number_of_regions))
- tracts_axes.set_xticklabels(list(labels[order]), fontsize=fontsize, rotation=90)
-
- self._save_figure(tracts_figure)
- self._check_show()
- return weights_figure, weights_axes, tracts_figure, tracts_axes
-
- else:
- return weights_figure, weights_axes
diff --git a/tvb_scripts/plot/plotter.py b/tvb_scripts/plot/plotter.py
index 90d9c9e..fa73542 100644
--- a/tvb_scripts/plot/plotter.py
+++ b/tvb_scripts/plot/plotter.py
@@ -1,63 +1,18 @@
# -*- coding: utf-8 -*-
+from tvb.simulator.plot.plotter import Plotter as TVBPlotter
-from tvb_scripts.plot.base_plotter import BasePlotter
-from tvb_scripts.plot.head_plotter import HeadPlotter
from tvb_scripts.plot.time_series_plotter import TimeSeriesPlotter
-class Plotter(object):
-
- def __init__(self, config=None):
- self.config = config
-
- @property
- def base(self):
- return BasePlotter(self.config)
-
- def tvb_plot(self, plot_fun_name, *args, **kwargs):
- return BasePlotter(self.config).tvb_plot(plot_fun_name, *args, **kwargs)
-
+class Plotter(TVBPlotter):
def plot_head(self, head):
- return HeadPlotter(self.config).plot_head(head)
-
- def plot_tvb_connectivity(self, *args, **kwargs):
- return HeadPlotter(self.config).plot_tvb_connectivity(*args, **kwargs)
-
- def plot_ts(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_ts(*args, **kwargs)
-
- def plot_ts_raster(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_ts_raster(*args, **kwargs)
-
- def plot_ts_trajectories(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_ts_trajectories(*args, **kwargs)
-
- def plot_tvb_timeseries(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_tvb_time_series(*args, **kwargs)
+ return self.plot_head_tvb(head.connectivity, head.sensors)
def plot_timeseries(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_time_series(*args, **kwargs)
-
- def plot_raster(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_raster(*args, **kwargs)
-
- def plot_trajectories(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_trajectories(*args, **kwargs)
+ return TimeSeriesPlotter(self.config).plot_tvb_time_series(*args, **kwargs)
def plot_timeseries_interactive(self, *args, **kwargs):
return TimeSeriesPlotter(self.config).plot_time_series_interactive(*args, **kwargs)
- def plot_tvb_timeseries_interactive(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_tvb_time_series_interactive(*args, **kwargs)
-
- def plot_power_spectra_interactive(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_power_spectra_interactive(*args, **kwargs)
-
- def plot_tvb_power_spectra_interactive(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_tvb_power_spectra_interactive(*args, **kwargs)
-
- def plot_ts_spectral_analysis_raster(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_spectral_analysis_raster(self, *args, **kwargs)
-
def plot_spectral_analysis_raster(self, *args, **kwargs):
- return TimeSeriesPlotter(self.config).plot_spectral_analysis_raster(self, *args, **kwargs)
+ return TimeSeriesPlotter(self.config).plot_spectral_analysis_raster(*args, **kwargs)
diff --git a/tvb_scripts/plot/time_series_interactive_plotter.py b/tvb_scripts/plot/time_series_interactive_plotter.py
deleted file mode 100644
index c5dd2b7..0000000
--- a/tvb_scripts/plot/time_series_interactive_plotter.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# coding=utf-8
-import numpy
-
-from tvb_scripts.utils.log_error_utils import initialize_logger
-from tvb_scripts.utils.data_structures_utils import ensure_list, rotate_n_list_elements
-from tvb.simulator.plot.timeseries_interactive import \
- TimeSeriesInteractive, pylab, time_series_datatypes, BACKGROUNDCOLOUR, EDGECOLOUR
-
-from matplotlib.pyplot import rcParams
-
-
-LOG = initialize_logger(__name__)
-
-
-class TimeseriesInteractivePlotter(TimeSeriesInteractive):
-
- def create_figure(self, **kwargs):
- """ Create the figure and time-series axes. """
- # time_series_type = self.time_series.__class__.__name__
- figsize = kwargs.pop("figsize", (14, 8))
- facecolor = kwargs.pop("facecolor", BACKGROUNDCOLOUR)
- edgecolor = kwargs.pop("edgecolor", EDGECOLOUR)
- try:
- figure_window_title = "Interactive time series: " # + time_series_type
- num = kwargs.pop("figname", kwargs.get("num", figure_window_title))
- # pylab.close(figure_window_title)
- self.its_fig = pylab.figure(num=num,
- figsize=figsize,
- facecolor=facecolor,
- edgecolor=edgecolor)
- except ValueError:
- LOG.info("My life would be easier if you'd update your PyLab...")
- figure_number = 42
- pylab.close(figure_number)
- self.its_fig = pylab.figure(num=figure_number,
- figsize=figsize,
- facecolor=facecolor,
- edgecolor=edgecolor)
-
- self.ts_ax = self.its_fig.add_axes([0.1, 0.1, 0.85, 0.85])
-
- self.whereami_ax = self.its_fig.add_axes([0.1, 0.95, 0.85, 0.025],
- facecolor=facecolor)
- self.whereami_ax.set_axis_off()
- if hasattr(self.whereami_ax, 'autoscale'):
- self.whereami_ax.autoscale(enable=True, axis='both', tight=True)
- self.whereami_ax.plot(self.time_view,
- numpy.zeros((len(self.time_view),)),
- color="0.3", linestyle="--")
- self.hereiam = self.whereami_ax.plot(self.time_view,
- numpy.zeros((len(self.time_view),)),
- 'b-', linewidth=4)
-
- def plot_time_series(self, **kwargs):
- """ Plot a view on the timeseries. """
- # Set title and axis labels
- #time_series_type = self.time_series.__class__.__name__
- #self.ts_ax.set(title = time_series_type)
- #self.ts_ax.set(xlabel = "Time (%s)" % self.units)
-
- # This assumes shape => (time, space)
- step = self.scaling * self.peak_to_peak
- if step == 0:
- offset = 0.0
- else: #NOTE: specifying step in arange is faster, but it fence-posts.
- offset = numpy.arange(0, self.nsrs) * step
- if hasattr(self.ts_ax, 'autoscale'):
- self.ts_ax.autoscale(enable=True, axis='both', tight=True)
-
- self.ts_ax.set_yticks(offset)
- self.ts_ax.set_yticklabels(self.labels, fontsize=10)
- #import pdb; pdb.set_trace()
-
- #Light gray guidelines
- self.ts_ax.plot([self.nsrs*[self.time[self.time_view[0]]],
- self.nsrs*[self.time[self.time_view[-1]]]],
- numpy.vstack(2*(offset,)), "0.85")
-
- # Determine colors and linestyles for each variable of the Timeseries
- linestyle = ensure_list(kwargs.pop("linestyle", "-"))
- colors = kwargs.pop("linestyle", None)
- if colors is not None:
- colors = ensure_list(colors)
- if self.data.shape[1] > 1:
- linestyle = rotate_n_list_elements(linestyle, self.data.shape[1])
- if not isinstance(colors, list):
- colors = (rcParams['axes.prop_cycle']).by_key()['color']
- colors = rotate_n_list_elements(colors, self.data.shape[1])
- else:
- # If no color,
- # or a color sequence is given in the input
- # but there is only one variable to plot,
- # choose the black color
- if colors is None or len(colors) > 1:
- colors = ["k"]
- linestyle = linestyle[:1]
-
- # Determine the alpha value depending on the number of modes/samples of the Timeseries
- alpha = 1.0
- if len(self.data.shape) > 3 and self.data.shape[3] > 1:
- alpha /= self.data.shape[3]
-
- # Plot the timeseries per variable and sample
- self.ts_view = []
- for i_var in range(self.data.shape[1]):
- for ii in range(self.data.shape[3]):
- # Plot the timeseries
- self.ts_view.append(self.ts_ax.plot(self.time[self.time_view],
- offset + self.data[self.time_view, i_var, :, ii],
- alpha=alpha, color=colors[i_var], linestyle=linestyle[i_var],
- **kwargs))
-
- self.hereiam[0].remove()
- self.hereiam = self.whereami_ax.plot(self.time_view,
- numpy.zeros((len(self.time_view),)),
- 'b-', linewidth=4)
-
- pylab.draw()
-
- def show(self, block=True, **kwargs):
- """ Generate the interactive time-series figure. """
- time_series_type = self.time_series.__class__.__name__
- msg = "Generating an interactive time-series plot for %s"
- if isinstance(self.time_series, time_series_datatypes.TimeSeriesSurface):
- LOG.warning("Intended for region and sensors, not surfaces.")
- LOG.info(msg % time_series_type)
-
- # Make the figure:
- self.create_figure()
-
- # Selectors
- # self.add_mode_selector()
-
- # Sliders
- self.add_window_length_slider()
- self.add_scaling_slider()
- # self.add_time_slider()
-
- # time-view buttons
- self.add_step_back_button()
- self.add_step_forward_button()
- self.add_big_step_back_button()
- self.add_big_step_forward_button()
- self.add_start_button()
- self.add_end_button()
-
- # Plot timeseries
- self.plot_time_series()
-
- pylab.show(block=block, **kwargs)
diff --git a/tvb_scripts/plot/time_series_plotter.py b/tvb_scripts/plot/time_series_plotter.py
index def9181..b35aa12 100644
--- a/tvb_scripts/plot/time_series_plotter.py
+++ b/tvb_scripts/plot/time_series_plotter.py
@@ -1,371 +1,12 @@
# -*- coding: utf-8 -*-
-from six import string_types
-import matplotlib
-from matplotlib import pyplot, gridspec
-from matplotlib.colors import Normalize
import numpy
-from tvb_scripts.utils.log_error_utils import warning, raise_value_error
-from tvb_scripts.utils.data_structures_utils import ensure_list, isequal_string, generate_region_labels, ensure_string
-from tvb_scripts.utils.time_series_utils import time_spectral_analysis
-from tvb_scripts.plot.base_plotter import BasePlotter
-from tvb.datatypes.time_series import TimeSeries as TimeSeriesTVB
-from tvb_scripts.datatypes.time_series import TimeSeries
-
-
-def assert_time(time, n_times, time_unit="ms", logger=None):
- if time_unit.find("ms"):
- dt = 0.001
- else:
- dt = 1.0
- try:
- time = numpy.array(time).flatten()
- n_time = len(time)
- if n_time > n_times:
- # self.logger.warning("Input time longer than data time points! Removing redundant tail time points!")
- time = time[:n_times]
- elif n_time < n_times:
- # self.logger.warning("Input time shorter than data time points! "
- # "Extending tail time points with the same average time step!")
- if n_time > 1:
- dt = numpy.mean(numpy.diff(time))
- n_extra_points = n_times - n_time
- start_time_point = time[-1] + dt
- end_time_point = start_time_point + n_extra_points * dt
- time = numpy.concatenate([time, numpy.arange(start_time_point, end_time_point, dt)])
- except:
- if logger:
- logger.warning("Setting a default time step vector manually! Input time: " + str(time))
- time = numpy.arange(0, n_times * dt, dt)
- return time
-
-
-class TimeSeriesPlotter(BasePlotter):
- linestyle = "-"
- linewidth = 1
- marker = None
- markersize = 2
- markerfacecolor = None
- tick_font_size = 12
- print_ts_indices = True
-
- def __init__(self, config=None):
- super(TimeSeriesPlotter, self).__init__(config)
- self.interactive_plotter = None
- self.print_ts_indices = self.print_regions_indices
- self.HighlightingDataCursor = lambda *args, **kwargs: None
- if matplotlib.get_backend() in matplotlib.rcsetup.interactive_bk and self.config.figures.MOUSE_HOOVER:
- try:
- from mpldatacursor import HighlightingDataCursor
- self.HighlightingDataCursor = HighlightingDataCursor
- except ImportError:
- self.config.figures.MOUSE_HOOVER = False
- # self.logger.warning("Importing mpldatacursor failed! No highlighting functionality in plots!")
- else:
- # self.logger.warning("Noninteractive matplotlib backend! No highlighting functionality in plots!")
- self.config.figures.MOUSE_HOOVER = False
-
- @property
- def line_format(self):
- return {"linestyle": self.linestyle, "linewidth": self.linewidth,
- "marker": self.marker, "markersize": self.markersize, "markerfacecolor": self.markerfacecolor}
-
- def _ts_plot(self, time, n_vars, nTS, n_times, time_unit, subplots, offset=0.0, data_lims=[]):
-
- time_unit = ensure_string(time_unit)
- data_fun = lambda data, time, icol: (data[icol], time, icol)
-
- def plot_ts(x, iTS, colors, labels):
- x, time, ivar = x
- time = assert_time(time, len(x[:, iTS]), time_unit, self.logger)
- try:
- return pyplot.plot(time, x[:, iTS], color=colors[iTS], label=labels[iTS], **self.line_format)
- except:
- self.logger.warning("Cannot convert labels' strings for line labels!")
- return pyplot.plot(time, x[:, iTS], color=colors[iTS], label=str(iTS), **self.line_format)
-
- def plot_ts_raster(x, iTS, colors, labels, offset):
- x, time, ivar = x
- time = assert_time(time, len(x[:, iTS]), time_unit, self.logger)
- try:
- return pyplot.plot(time, -x[:, iTS] + (offset * iTS + x[:, iTS].mean()), color=colors[iTS],
- label=labels[iTS], **self.line_format)
- except:
- self.logger.warning("Cannot convert labels' strings for line labels!")
- return pyplot.plot(time, -x[:, iTS] + offset * iTS, color=colors[iTS],
- label=str(iTS), **self.line_format)
-
- def axlabels_ts(labels, n_rows, irow, iTS):
- if irow == n_rows:
- pyplot.gca().set_xlabel("Time (" + time_unit + ")")
- if n_rows > 1:
- try:
- pyplot.gca().set_ylabel(str(iTS) + "." + labels[iTS])
- except:
- self.logger.warning("Cannot convert labels' strings for y axis labels!")
- pyplot.gca().set_ylabel(str(iTS))
-
- def axlimits_ts(data_lims, time, icol):
- pyplot.gca().set_xlim([time[0], time[-1]])
- if n_rows > 1:
- pyplot.gca().set_ylim([data_lims[icol][0], data_lims[icol][1]])
- else:
- pyplot.autoscale(enable=True, axis='y', tight=True)
+from tvb.simulator.plot.time_series_plotter import TimeSeriesPlotter as TVBTimeSeriesPlotter
- def axYticks(labels, nTS, offsets=offset):
- pyplot.gca().set_yticks((offset * numpy.array([list(range(nTS))]).flatten()).tolist())
- try:
- pyplot.gca().set_yticklabels(labels.flatten().tolist())
- except:
- labels = generate_region_labels(nTS, [], "", True)
- self.logger.warning("Cannot convert region labels' strings for y axis ticks!")
-
- if offset > 0.0:
- plot_lines = lambda x, iTS, colors, labels: \
- plot_ts_raster(x, iTS, colors, labels, offset)
- else:
- plot_lines = lambda x, iTS, colors, labels: \
- plot_ts(x, iTS, colors, labels)
- this_axYticks = lambda labels, nTS: axYticks(labels, nTS, offset)
- if subplots:
- n_rows = nTS
- def_alpha = 1.0
- else:
- n_rows = 1
- def_alpha = 0.5
- subtitle_col = lambda subtitle: pyplot.gca().set_title(subtitle)
- subtitle = lambda iTS, labels: None
- projection = None
- axlabels = lambda labels, vars, n_vars, n_rows, irow, iTS: axlabels_ts(labels, n_rows, irow, iTS)
- axlimits = lambda data_lims, time, n_vars, icol: axlimits_ts(data_lims, time, icol)
- loopfun = lambda nTS, n_rows, icol: list(range(nTS))
- return data_fun, time, plot_lines, projection, n_rows, n_vars, def_alpha, loopfun, \
- subtitle, subtitle_col, axlabels, axlimits, this_axYticks
-
- def _trajectories_plot(self, n_dims, nTS, nSamples, subplots):
- data_fun = lambda data, time, icol: data
-
- def plot_traj_2D(x, iTS, colors, labels):
- x, y = x
- try:
- return pyplot.plot(x[:, iTS], y[:, iTS], color=colors[iTS], label=labels[iTS], **self.line_format)
- except:
- self.logger.warning("Cannot convert labels' strings for line labels!")
- return pyplot.plot(x[:, iTS], y[:, iTS], color=colors[iTS], label=str(iTS), **self.line_format)
-
- def plot_traj_3D(x, iTS, colors, labels):
- x, y, z = x
- try:
- return pyplot.plot(x[:, iTS], y[:, iTS], z[:, iTS], color=colors[iTS],
- label=labels[iTS], **self.line_format)
- except:
- self.logger.warning("Cannot convert labels' strings for line labels!")
- return pyplot.plot(x[:, iTS], y[:, iTS], z[:, iTS], color=colors[iTS],
- label=str(iTS), **self.line_format)
-
- def subtitle_traj(labels, iTS):
- try:
- if self.print_ts_indices:
- pyplot.gca().set_title(str(iTS) + "." + labels[iTS])
- else:
- pyplot.gca().set_title(labels[iTS])
- except:
- self.logger.warning("Cannot convert labels' strings for subplot titles!")
- pyplot.gca().set_title(str(iTS))
-
- def axlabels_traj(vars, n_vars):
- pyplot.gca().set_xlabel(vars[0])
- pyplot.gca().set_ylabel(vars[1])
- if n_vars > 2:
- pyplot.gca().set_zlabel(vars[2])
-
- def axlimits_traj(data_lims, n_vars):
- pyplot.gca().set_xlim([data_lims[0][0], data_lims[0][1]])
- pyplot.gca().set_ylim([data_lims[1][0], data_lims[1][1]])
- if n_vars > 2:
- pyplot.gca().set_zlim([data_lims[2][0], data_lims[2][1]])
-
- if n_dims == 2:
- plot_lines = lambda x, iTS, colors, labels: \
- plot_traj_2D(x, iTS, colors, labels)
- projection = None
- elif n_dims == 3:
- plot_lines = lambda x, iTS, colors, labels: \
- plot_traj_3D(x, iTS, colors, labels)
- projection = '3d'
- else:
- raise_value_error("Data dimensions are neigher 2D nor 3D!, but " + str(n_dims) + "D!")
- n_rows = 1
- n_cols = 1
- if subplots is None:
- # if nSamples > 1:
- n_rows = int(numpy.floor(numpy.sqrt(nTS)))
- n_cols = int(numpy.ceil((1.0 * nTS) / n_rows))
- elif isinstance(subplots, (list, tuple)):
- n_rows = subplots[0]
- n_cols = subplots[1]
- if n_rows * n_cols < nTS:
- raise_value_error("Not enough subplots for all time series:"
- "\nn_rows * n_cols = product(subplots) = product(" + str(subplots) + " = "
- + str(n_rows * n_cols) + "!")
- if n_rows * n_cols > 1:
- def_alpha = 0.5
- subtitle = lambda labels, iTS: subtitle_traj(labels, iTS)
- subtitle_col = lambda subtitles, icol: None
- else:
- def_alpha = 1.0
- subtitle = lambda labels, iTS: None
- subtitle_col = lambda subtitles, icol: pyplot.gca().set_title(pyplot.gcf().title)
- axlabels = lambda labels, vars, n_vars, n_rows, irow, iTS: axlabels_traj(vars, n_vars)
- axlimits = lambda data_lims, time, n_vars, icol: axlimits_traj(data_lims, n_vars)
- loopfun = lambda nTS, n_rows, icol: list(range(icol, nTS, n_rows))
- return data_fun, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
- subtitle, subtitle_col, axlabels, axlimits
-
- # TODO: refactor to not have the plot commands here
- def plot_ts(self, data, time=None, var_labels=[], mode="ts", subplots=None, special_idx=[],
- subtitles=[], labels=[], offset=0.5, time_unit="ms",
- title='Time series', figure_name=None, figsize=None):
- if not isinstance(figsize, (list, tuple)):
- figsize = self.config.figures.LARGE_SIZE
- if isinstance(data, dict):
- var_labels = data.keys()
- data = data.values()
- elif isinstance(data, numpy.ndarray):
- if len(data.shape) < 3:
- if len(data.shape) < 2:
- data = numpy.expand_dims(data, 1)
- data = numpy.expand_dims(data, 2)
- data = [data]
- else:
- # Assuming a structure of Time X Space X Variables X Samples
- data = [data[:, :, iv].squeeze() for iv in range(data.shape[2])]
- elif isinstance(data, (list, tuple)):
- data = ensure_list(data)
- else:
- raise_value_error("Input timeseries: %s \n" "is not on of one of the following types: "
- "[numpy.ndarray, dict, list, tuple]" % str(data))
- n_vars = len(data)
- data_lims = []
- for id, d in enumerate(data):
- if isequal_string(mode, "raster"):
- data[id] = (d - d.mean(axis=0))
- drange = numpy.max(data[id].max(axis=0) - data[id].min(axis=0))
- data[id] = data[id] / drange # zscore(d, axis=None)
- data_lims.append([d.min(), d.max()])
- data_shape = data[0].shape
- if len(data_shape) == 1:
- n_times = data_shape[0]
- nTS = 1
- for iV in range(n_vars):
- data[iV] = data[iV][:, numpy.newaxis]
- else:
- n_times, nTS = data_shape[:2]
- if len(data_shape) > 2:
- nSamples = data_shape[2]
- else:
- nSamples = 1
- if special_idx is None:
- special_idx = []
- n_special_idx = len(special_idx)
- if len(subtitles) == 0:
- subtitles = var_labels
- if isinstance(labels, list) and len(labels) == n_vars:
- labels = [generate_region_labels(nTS, label, ". ", self.print_ts_indices) for label in labels]
- else:
- labels = [generate_region_labels(nTS, labels, ". ", self.print_ts_indices) for _ in range(n_vars)]
- if isequal_string(mode, "traj"):
- data_fun, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
- subtitle, subtitle_col, axlabels, axlimits = \
- self._trajectories_plot(n_vars, nTS, nSamples, subplots)
- else:
- if isequal_string(mode, "raster"):
- data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
- subtitle, subtitle_col, axlabels, axlimits, axYticks = \
- self._ts_plot(time, n_vars, nTS, n_times, time_unit, 0, offset, data_lims)
-
- else:
- data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \
- subtitle, subtitle_col, axlabels, axlimits, axYticks = \
- self._ts_plot(time, n_vars, nTS, n_times, time_unit, ensure_list(subplots)[0])
- alpha_ratio = 1.0 / nSamples
- alphas = numpy.maximum(numpy.array([def_alpha] * nTS) * alpha_ratio, 0.1)
- alphas[special_idx] = numpy.maximum(alpha_ratio, 0.1)
- if isequal_string(mode, "traj") and (n_cols * n_rows > 1):
- colors = numpy.zeros((nTS, 4))
- colors[special_idx] = \
- numpy.array([numpy.array([1.0, 0, 0, 1.0]) for _ in range(n_special_idx)]).reshape((n_special_idx, 4))
- else:
- cmap = matplotlib.cm.get_cmap('jet')
- colors = numpy.array([cmap(0.5 * iTS / nTS) for iTS in range(nTS)])
- colors[special_idx] = \
- numpy.array([cmap(1.0 - 0.25 * iTS / nTS) for iTS in range(n_special_idx)]).reshape((n_special_idx, 4))
- colors[:, 3] = alphas
- lines = []
- pyplot.figure(title, figsize=figsize)
- axes = []
- for icol in range(n_cols):
- if n_rows == 1:
- # If there are no more rows, create axis, and set its limits, labels and possible subtitle
- axes += ensure_list(pyplot.subplot(n_rows, n_cols, icol + 1, projection=projection))
- axlimits(data_lims, time, n_vars, icol)
- axlabels(labels[icol % n_vars], var_labels, n_vars, n_rows, 1, 0)
- pyplot.gca().set_title(subtitles[icol])
- for iTS in loopfun(nTS, n_rows, icol):
- if n_rows > 1:
- # If there are more rows, create axes, and set their limits, labels and possible subtitles
- axes += ensure_list(pyplot.subplot(n_rows, n_cols, iTS + 1, projection=projection))
- axlimits(data_lims, time, n_vars, icol)
- subtitle(labels[icol % n_vars], iTS)
- axlabels(labels[icol % n_vars], var_labels, n_vars, n_rows, (iTS % n_rows) + 1, iTS)
- lines += ensure_list(plot_lines(data_fun(data, time, icol), iTS, colors, labels[icol % n_vars]))
- if isequal_string(mode, "raster"): # set yticks as labels if this is a raster plot
- axYticks(labels[icol % n_vars], nTS)
- yticklabels = pyplot.gca().yaxis.get_ticklabels()
- self.tick_font_size = numpy.minimum(self.tick_font_size,
- int(numpy.round(self.tick_font_size * 100.0 / nTS)))
- for iTS in range(nTS):
- yticklabels[iTS].set_fontsize(self.tick_font_size)
- if iTS in special_idx:
- yticklabels[iTS].set_color(colors[iTS, :3].tolist() + [1])
- pyplot.gca().yaxis.set_ticklabels(yticklabels)
- pyplot.gca().invert_yaxis()
-
- if self.config.figures.MOUSE_HOOVER:
- for line in lines:
- self.HighlightingDataCursor(line, formatter='{label}'.format, bbox=dict(fc='white'),
- arrowprops=dict(arrowstyle='simple', fc='white', alpha=0.5))
-
- self._save_figure(pyplot.gcf(), figure_name)
- self._check_show()
- return pyplot.gcf(), axes, lines
-
- def plot_ts_raster(self, data, time, var_labels=[], time_unit="ms", special_idx=[],
- title='Raster plot', subtitles=[], labels=[], offset=0.5, figure_name=None, figsize=None):
- if not isinstance(figsize, (list, tuple)):
- figsize = self.config.figures.VERY_LARGE_SIZE
- return self.plot_ts(data, time, var_labels, "raster", None, special_idx, subtitles, labels, offset, time_unit,
- title, figure_name, figsize)
-
- def plot_ts_trajectories(self, data, var_labels=[], subtitles=None, special_idx=[], labels=[],
- title='State space trajectories', figure_name=None, figsize=None):
- if not isinstance(figsize, (list, tuple)):
- figsize = self.config.figures.LARGE_SIZE
- return self.plot_ts(data, [], var_labels, "traj", subtitles, special_idx, labels=labels, title=title,
- figure_name=figure_name, figsize=figsize)
+from tvb_scripts.datatypes.time_series import TimeSeries
- def plot_tvb_time_series(self, time_series, mode="ts", subplots=None, special_idx=[], subtitles=[],
- offset=0.5, title=None, figure_name=None, figsize=None):
- if not isinstance(figsize, (list, tuple)):
- figsize = self.config.figures.LARGE_SIZE
- if title is None:
- title = time_series.title
- variables_labels = time_series.labels_dimensions.get(time_series.labels_ordering[1], [])
- space_labels = time_series.labels_dimensions.get(time_series.labels_ordering[2], [])
- return self.plot_ts(numpy.swapaxes(time_series.data, 1, 2), time_series.time, variables_labels,
- mode, subplots, special_idx, subtitles, space_labels,
- offset, time_series.time_unit, title, figure_name, figsize)
+class TimeSeriesPlotter(TVBTimeSeriesPlotter):
def plot_time_series(self, time_series, mode="ts", subplots=None, special_idx=[], subtitles=[],
offset=0.5, title=None, figure_name=None, figsize=None, **kwargs):
if isinstance(time_series, TimeSeries):
@@ -375,211 +16,25 @@ def plot_time_series(self, time_series, mode="ts", subplots=None, special_idx=[]
time_series.time, time_series.variables_labels,
mode, subplots, special_idx, subtitles, time_series.space_labels,
offset, time_series.time_unit, title, figure_name, figsize)
- elif isinstance(time_series, TimeSeriesTVB):
- self.plot_tvb_time_series(time_series, mode, subplots, special_idx,
- subtitles, offset, title, figure_name, figsize)
- elif isinstance(time_series, (numpy.ndarray, dict, list, tuple)):
- time = kwargs.get("time", None)
- time_unit = kwargs.get("time_unit", "ms")
- labels = kwargs.get("labels", [])
- var_labels = kwargs.get("var_labels", [])
- if title is None:
- title = "Time Series"
- return self.plot_ts(time_series, time=time, mode=mode, time_unit=time_unit,
- labels=labels, var_labels=var_labels, subplots=subplots, special_idx=special_idx,
- subtitles=subtitles, offset=offset, title=title, figure_name=figure_name,
- figsize=figsize)
else:
- raise_value_error("Input time_series: %s \n" "is not on of one of the following types: "
- "[TimeSeries (tvb-scripts), TimeSeries (TVB), numpy.ndarray, dict]" % str(time_series))
-
- def plot_raster(self, time_series, subplots=None, special_idx=[], subtitles=[],
- offset=0.5, title=None, figure_name=None, figsize=None, **kwargs):
- return self.plot_time_series(time_series, "raster", subplots, special_idx,
- subtitles, offset, title, figure_name, figsize, **kwargs)
-
- def plot_trajectories(self, time_series, subplots=None, special_idx=[], subtitles=[],
- offset=0.5, title=None, figure_name=None, figsize=None, **kwargs):
- return self.plot_time_series(time_series, "traj", subplots, special_idx,
- subtitles, offset, title, figure_name, figsize, **kwargs)
-
- @staticmethod
- def plot_tvb_time_series_interactive(time_series, first_n=-1, **kwargs):
- from tvb_scripts.plot.time_series_interactive_plotter import TimeseriesInteractivePlotter
- interactive_plotter = TimeseriesInteractivePlotter(time_series=time_series, first_n=first_n)
- interactive_plotter.configure()
- block = kwargs.pop("block", True)
- interactive_plotter.show(block=block, **kwargs)
+ super(TimeSeriesPlotter, self).plot_time_series(time_series, mode, subplots, special_idx, subtitles, offset,
+ title, figure_name, figsize, **kwargs)
def plot_time_series_interactive(self, time_series, first_n=-1, **kwargs):
- if isinstance(time_series, TimeSeriesTVB):
- self.plot_tvb_time_series_interactive(time_series, first_n, **kwargs)
- elif isinstance(time_series, TimeSeries):
+ if isinstance(time_series, TimeSeries):
self.plot_tvb_time_series_interactive(time_series._tvb, first_n, **kwargs)
- elif isinstance(time_series, numpy.ndarray):
- self.plot_tvb_time_series_interactive(TimeSeries(data=time_series), first_n, **kwargs)
- elif isinstance(time_series, (list, tuple)):
- self.plot_tvb_time_series_interactive(TimeSeries(data=TimeSeriesTVB(data=numpy.stack(time_series, axis=1))),
- first_n, **kwargs)
- elif isinstance(time_series, dict):
- ts = numpy.stack(time_series.values(), axis=1)
- time_series = TimeSeriesTVB(data=ts, labels_dimensions={"State Variable": time_series.keys()})
- self.plot_tvb_time_series_interactive(time_series, first_n, **kwargs)
else:
- raise_value_error("Input time_series: %s \n" "is not on of one of the following types: "
- "[TimeSeries (tvb-scripts), TimeSeriesTVB (TVB), numpy.ndarray, dict, list, tuple]" %
- str(time_series))
-
- @staticmethod
- def plot_tvb_power_spectra_interactive(time_series, spectral_props, **kwargs):
- from tvb.simulator.plot.power_spectra_interactive import PowerSpectraInteractive
- interactive_plotters = PowerSpectraInteractive(time_series=time_series, **spectral_props)
- interactive_plotters.configure()
- block = kwargs.pop("block", True)
- interactive_plotters.show(blocl=block, **kwargs)
-
- def plot_power_spectra_interactive(self, time_series, spectral_props, **kwargs):
- self.plot_tvb_power_spectra_interactive(self, time_series._tvb, spectral_props, **kwargs)
-
- # TODO: refactor to not have the plot commands here
- def _plot_ts_spectral_analysis_raster(self, data, time=None, var_label="", time_unit="ms",
- freq=None, spectral_options={}, special_idx=[], labels=[],
- title='Spectral Analysis', figure_name=None, figsize=None):
- if not isinstance(figsize, (list, tuple)):
- figsize = self.config.figures.VERY_LARGE_SIZE
- if len(data.shape) == 1:
- n_times = data.shape[0]
- nS = 1
- else:
- n_times, nS = data.shape[:2]
- time = assert_time(time, n_times, time_unit, self.logger)
- if not isinstance(time_unit, string_types):
- time_unit = list(time_unit)[0]
- time_unit = ensure_string(time_unit)
- if time_unit in ("ms", "msec"):
- fs = 1000.0
- else:
- fs = 1.0
- fs = fs / numpy.mean(numpy.diff(time))
- n_special_idx = len(special_idx)
- if n_special_idx > 0:
- data = data[:, special_idx]
- nS = data.shape[1]
- if len(labels) > n_special_idx:
- labels = numpy.array([str(ilbl) + ". " + str(labels[ilbl]) for ilbl in special_idx])
- elif len(labels) == n_special_idx:
- labels = numpy.array([str(ilbl) + ". " + str(label) for ilbl, label in zip(special_idx, labels)])
- else:
- labels = numpy.array([str(ilbl) for ilbl in special_idx])
- else:
- if len(labels) != nS:
- labels = numpy.array([str(ilbl) for ilbl in range(nS)])
- if nS > 20:
- warning("It is not possible to plot spectral analysis plots for more than 20 signals!")
- return
-
- log_norm = spectral_options.get("log_norm", False)
- mode = spectral_options.get("mode", "psd")
- psd_label = mode
- if log_norm:
- psd_label = "log" + psd_label
- stf, time, freq, psd = time_spectral_analysis(data, fs,
- freq=freq,
- mode=mode,
- nfft=spectral_options.get("nfft"),
- window=spectral_options.get("window", 'hanning'),
- nperseg=spectral_options.get("nperseg", int(numpy.round(fs / 4))),
- detrend=spectral_options.get("detrend", 'constant'),
- noverlap=spectral_options.get("noverlap"),
- f_low=spectral_options.get("f_low", 10.0),
- log_scale=spectral_options.get("log_scale", False))
- min_val = numpy.min(stf.flatten())
- max_val = numpy.max(stf.flatten())
- if nS > 2:
- figsize = self.config.figures.VERY_LARGE_SIZE
- if len(var_label):
- title += ": " % var_label
- fig = pyplot.figure(title, figsize=figsize)
- fig.suptitle(title)
- gs = gridspec.GridSpec(nS, 23)
- ax = numpy.empty((nS, 2), dtype="O")
- img = numpy.empty((nS,), dtype="O")
- line = numpy.empty((nS,), dtype="O")
- for iS in range(nS, -1, -1):
- if iS < nS - 1:
- ax[iS, 0] = pyplot.subplot(gs[iS, :20], sharex=ax[iS, 0])
- ax[iS, 1] = pyplot.subplot(gs[iS, 20:22], sharex=ax[iS, 1], sharey=ax[iS, 0])
- else:
- # TODO: find and correct bug here
- ax[iS, 0] = pyplot.subplot(gs[iS, :20])
- ax[iS, 1] = pyplot.subplot(gs[iS, 20:22], sharey=ax[iS, 0])
- img[iS] = ax[iS, 0].imshow(numpy.squeeze(stf[:, :, iS]).T, cmap=pyplot.set_cmap('jet'),
- interpolation='none',
- norm=Normalize(vmin=min_val, vmax=max_val), aspect='auto', origin='lower',
- extent=(time.min(), time.max(), freq.min(), freq.max()))
- # img[iS].clim(min_val, max_val)
- ax[iS, 0].set_title(labels[iS])
- ax[iS, 0].set_ylabel("Frequency (Hz)")
- line[iS] = ax[iS, 1].plot(psd[:, iS], freq, 'k', label=labels[iS])
- pyplot.setp(ax[iS, 1].get_yticklabels(), visible=False)
- # ax[iS, 1].yaxis.tick_right()
- # ax[iS, 1].yaxis.set_ticks_position('both')
- if iS == (nS - 1):
- ax[iS, 0].set_xlabel("Time (" + time_unit + ")")
-
- ax[iS, 1].set_xlabel(psd_label)
- else:
- pyplot.setp(ax[iS, 0].get_xticklabels(), visible=False)
- pyplot.setp(ax[iS, 1].get_xticklabels(), visible=False)
- ax[iS, 0].autoscale(tight=True)
- ax[iS, 1].autoscale(tight=True)
- # make a color bar
- cax = pyplot.subplot(gs[:, 22])
- pyplot.colorbar(img[0], cax=pyplot.subplot(gs[:, 22])) # fraction=0.046, pad=0.04) #fraction=0.15, shrink=1.0
- cax.set_title(psd_label)
- self._save_figure(pyplot.gcf(), figure_name)
- self._check_show()
- return fig, ax, img, line, time, freq, stf, psd
-
- def plot_ts_spectral_analysis_raster(self, data, time=None, time_unit="ms", freq=None, spectral_options={},
- special_idx=[], labels=[], title='Spectral Analysis', figure_name=None,
- figsize=None):
- if isinstance(data, dict):
- var_labels = data.keys()
- data = data.values()
- else:
- var_labels = []
- if isinstance(data, (list, tuple)):
- data = data[0]
- elif isinstance(data, numpy.ndarray) and data.ndim > 2:
- # Assuming a structure of Time X Space X Variables X Samples
- if data.ndim > 3:
- data = data[:, :, :, 0]
- data = [data[:, :, iv].squeeze() for iv in range(data.shape[2])]
- if len(var_labels) == 0:
- var_labels = [""] * len(data)
- for d, var_label in zip(data, var_labels):
- self._plot_ts_spectral_analysis_raster(d, time, var_label, time_unit, freq, spectral_options,
- special_idx, labels, title, figure_name, figsize)
+ super(TimeSeriesPlotter, self).plot_time_series_interactive(time_series, first_n, **kwargs)
def plot_spectral_analysis_raster(self, time_series, freq=None, spectral_options={},
special_idx=[], labels=[], title='Spectral Analysis', figure_name=None,
figsize=None, **kwargs):
if isinstance(time_series, TimeSeries):
return self.plot_ts_spectral_analysis_raster(numpy.swapaxes(time_series._tvb.data, 1, 2).squeeze(),
- time_series.time, time_series.time_unit, freq, spectral_options,
- special_idx, labels, title, figure_name, figsize)
- elif isinstance(time_series, TimeSeriesTVB):
- return self.plot_ts_spectral_analysis_raster(numpy.swapaxes(time_series.data, 1, 2).squeeze(),
- time_series.time, time_series.time_unit, freq, spectral_options,
+ time_series.time, time_series.time_unit, freq,
+ spectral_options,
special_idx, labels, title, figure_name, figsize)
- elif isinstance(time_series, (numpy.ndarray, dict, list, tuple)):
- time = kwargs.get("time", None)
- return self.plot_ts_spectral_analysis_raster(time_series, time=time, freq=freq,
- spectral_options=spectral_options, special_idx=special_idx,
- labels=labels, title=title, figure_name=figure_name,
- figsize=figsize)
else:
- raise_value_error("Input time_series: %s \n"
- "is not on of one of the following types: "
- "[TimeSeries (tvb-scripts), TimeSeries (TVB), numpy.ndarray, dict]" % str(time_series))
+ super(TimeSeriesPlotter, self).plot_spectral_analysis_raster(time_series, freq, spectral_options,
+ special_idx, labels, title, figure_name,
+ figsize, **kwargs)
diff --git a/tvb_scripts/tests/base.py b/tvb_scripts/tests/base.py
index d3ea776..34c2506 100644
--- a/tvb_scripts/tests/base.py
+++ b/tvb_scripts/tests/base.py
@@ -2,7 +2,8 @@
import os
import numpy
-from tvb_scripts.config import Config
+from tvb.simulator.plot.config import Config
+
from tvb_scripts.io.h5_reader import H5Reader
from tvb_scripts.datatypes.connectivity import Connectivity
from tvb_scripts.datatypes.sensors import Sensors
diff --git a/tvb_scripts/utils/computations_utils.py b/tvb_scripts/utils/computations_utils.py
index 2255206..b8b03b8 100644
--- a/tvb_scripts/utils/computations_utils.py
+++ b/tvb_scripts/utils/computations_utils.py
@@ -1,14 +1,13 @@
# coding=utf-8
# Some math tools
from itertools import product
-from sklearn.cluster import AgglomerativeClustering
import numpy as np
+from sklearn.cluster import AgglomerativeClustering
+from tvb.simulator.plot.config import CalculusConfig, FiguresConfig
-from tvb_scripts.config import FiguresConfig, CalculusConfig
-from tvb_scripts.utils.log_error_utils import initialize_logger, warning
from tvb_scripts.utils.data_structures_utils import is_integer
-
+from tvb_scripts.utils.log_error_utils import initialize_logger, warning
logger = initialize_logger(__name__)
@@ -84,7 +83,7 @@ def select_greater_values_array_inds(values, threshold=None, percentile=None, nv
def select_greater_values_2Darray_inds(values, threshold=None, percentile=None, nvals=None, verbose=False):
return np.unravel_index(
- select_greater_values_array_inds(values.flatten(), threshold, percentile, nvals, verbose), values.shape)
+ select_greater_values_array_inds(values.flatten(), threshold, percentile, nvals, verbose), values.shape)
def select_by_hierarchical_group_metric_clustering(distance, disconnectivity=np.array([]), metric=None,
@@ -102,10 +101,10 @@ def select_by_hierarchical_group_metric_clustering(distance, disconnectivity=np.
# ... at least members_per_group elements...
n_select = np.minimum(members_per_group, len(cluster_inds))
if metric is not None and len(metric) == distance.shape[0]:
- #...optionally according to some metric
+ # ...optionally according to some metric
inds_select = np.argsort(metric[cluster_inds])[-n_select:]
else:
- #...otherwise, randomly
+ # ...otherwise, randomly
inds_select = range(n_select)
selection.append(cluster_inds[inds_select])
return np.unique(np.hstack(selection)).tolist()
@@ -187,7 +186,7 @@ def onclick(self, event):
def spikes_events_to_time_index(spike_time, time):
if spike_time < time[0] or spike_time > time[-1]:
warning("Spike time is outside the input time vector!")
- return np.argmin(np.abs(time-spike_time))
+ return np.argmin(np.abs(time - spike_time))
def compute_spikes_counts(spikes_times, time):
diff --git a/tvb_scripts/utils/data_structures_utils.py b/tvb_scripts/utils/data_structures_utils.py
index 21e4a48..c551b9e 100644
--- a/tvb_scripts/utils/data_structures_utils.py
+++ b/tvb_scripts/utils/data_structures_utils.py
@@ -1,14 +1,15 @@
# coding=utf-8
-# Data structure manipulations and conversions
-from six import string_types
import re
-import numpy as np
from collections import OrderedDict
from copy import deepcopy
-from tvb_scripts.utils.log_error_utils import warning, raise_value_error, raise_import_error, initialize_logger
-from tvb_scripts.config import CalculusConfig
+import numpy as np
+# Data structure manipulations and conversions
+from six import string_types
+from tvb.simulator.plot.config import CalculusConfig
+
+from tvb_scripts.utils.log_error_utils import warning, raise_value_error, raise_import_error, initialize_logger
logger = initialize_logger(__name__)
@@ -280,7 +281,7 @@ def ensure_list(arg):
arg = [arg]
elif hasattr(arg, "__iter__"):
arg = list(arg)
- else: # if not iterable
+ else: # if not iterable
arg = [arg]
except: # if not iterable
arg = [arg]
@@ -331,7 +332,7 @@ def get_list_or_tuple_item_safely(obj, key):
def delete_list_items_by_indices(lin, inds, start_ind=0):
lout = []
for ind, l in enumerate(lin):
- if ind+start_ind not in inds:
+ if ind + start_ind not in inds:
lout.append(l)
return lout
@@ -359,6 +360,7 @@ def rotate_n_list_elements(lst, n):
old_lst = old_lst[1:] + old_lst[:1]
return lst
+
def linear_index_to_coordinate_tuples(linear_index, shape):
if len(linear_index) > 0:
coordinates_tuple = np.unravel_index(linear_index, shape)
@@ -382,8 +384,8 @@ def find_labels_inds(labels, keys, modefun="find", two_way_search=False, break_a
for key in keys:
for label in labels:
if modefun(label, key):
- inds.append(labels.index(label))
- counts += 1
+ inds.append(labels.index(label))
+ counts += 1
if counts >= break_after:
return inds
return inds
@@ -807,4 +809,3 @@ def property_to_fun(property):
return property
else:
return lambda *args, **kwargs: property
-
diff --git a/tvb_scripts/utils/log_error_utils.py b/tvb_scripts/utils/log_error_utils.py
index 2c8b2b6..718911e 100644
--- a/tvb_scripts/utils/log_error_utils.py
+++ b/tvb_scripts/utils/log_error_utils.py
@@ -1,11 +1,12 @@
# coding=utf-8
# Logs and errors
+import logging
import os
import sys
-import logging
from logging.handlers import TimedRotatingFileHandler
-from tvb_scripts.config import OutputConfig
+
+from tvb.simulator.plot.config import OutputConfig
def initialize_logger(name, target_folder=OutputConfig().FOLDER_LOGS):