diff --git a/tvb_scripts/datatypes/base.py b/tvb_scripts/datatypes/base.py index 57da232..c85d012 100644 --- a/tvb_scripts/datatypes/base.py +++ b/tvb_scripts/datatypes/base.py @@ -1,7 +1,9 @@ from copy import deepcopy -from tvb_scripts.utils.log_error_utils import warning + +from tvb.basic.neotraits.api import HasTraits + from tvb_scripts.utils.data_structures_utils import labels_to_inds -from tvb.basic.neotraits.api import HasTraits, Attr +from tvb_scripts.utils.log_error_utils import warning class BaseModel(HasTraits): @@ -40,4 +42,3 @@ def from_tvb_file(cls, filepath, **kwargs): @staticmethod def labels2inds(all_labels, labels): return labels_to_inds(all_labels, labels) - diff --git a/tvb_scripts/datatypes/connectivity.py b/tvb_scripts/datatypes/connectivity.py index dd84778..c7b9f7c 100644 --- a/tvb_scripts/datatypes/connectivity.py +++ b/tvb_scripts/datatypes/connectivity.py @@ -1,8 +1,9 @@ # coding=utf-8 -from tvb_scripts.datatypes.base import BaseModel from tvb.datatypes.connectivity import Connectivity as TVBConnectivity +from tvb_scripts.datatypes.base import BaseModel + class ConnectivityH5Field(object): WEIGHTS = "weights" @@ -30,4 +31,4 @@ def centers(self): # A usefull method for addressing subsets of the connectome by label: def get_regions_inds_by_labels(self, labels): - return self.labels2inds(self.region_labels, labels) + return self.labels2inds(self.region_labels, labels) diff --git a/tvb_scripts/datatypes/head.py b/tvb_scripts/datatypes/head.py index b587348..ae75b70 100644 --- a/tvb_scripts/datatypes/head.py +++ b/tvb_scripts/datatypes/head.py @@ -1,17 +1,17 @@ # coding=utf-8 -from six import string_types from collections import OrderedDict -from tvb_scripts.utils.log_error_utils import initialize_logger, raise_value_error, warning -from tvb_scripts.utils.data_structures_utils import isequal_string, is_integer -from tvb_scripts.datatypes.connectivity import Connectivity -from tvb_scripts.datatypes.sensors import SensorTypes, SensorTypesNames, SensorTypesToProjectionDict - -from tvb.datatypes.surfaces import CorticalSurface +from six import string_types from tvb.datatypes.cortex import Cortex -from tvb.datatypes.sensors import Sensors from tvb.datatypes.projections import ProjectionMatrix +from tvb.datatypes.sensors import Sensors +from tvb.datatypes.surfaces import CorticalSurface + +from tvb_scripts.datatypes.connectivity import Connectivity +from tvb_scripts.datatypes.sensors import SensorTypes, SensorTypesNames, SensorTypesToProjectionDict +from tvb_scripts.utils.data_structures_utils import isequal_string, is_integer +from tvb_scripts.utils.log_error_utils import initialize_logger, raise_value_error, warning class Head(object): diff --git a/tvb_scripts/datatypes/sensors.py b/tvb_scripts/datatypes/sensors.py index bfaebae..96991cc 100644 --- a/tvb_scripts/datatypes/sensors.py +++ b/tvb_scripts/datatypes/sensors.py @@ -2,21 +2,20 @@ from enum import Enum import numpy as np - -from tvb_scripts.utils.log_error_utils import warning -from tvb_scripts.utils.data_structures_utils import ensure_list, \ - labels_to_inds, monopolar_to_bipolar, split_string_text_numbers -from tvb_scripts.datatypes.base import BaseModel - from tvb.basic.neotraits.api import Attr, NArray +from tvb.datatypes.projections import \ + ProjectionSurfaceEEG, ProjectionSurfaceMEG, ProjectionSurfaceSEEG +from tvb.datatypes.sensors import EEG_POLYMORPHIC_IDENTITY, MEG_POLYMORPHIC_IDENTITY, \ + INTERNAL_POLYMORPHIC_IDENTITY from tvb.datatypes.sensors import Sensors as TVBSensors from tvb.datatypes.sensors import SensorsEEG as TVBSensorsEEG -from tvb.datatypes.sensors import SensorsMEG as TVBSensorsMEG from tvb.datatypes.sensors import SensorsInternal as TVBSensorsInternal -from tvb.datatypes.sensors import EEG_POLYMORPHIC_IDENTITY, MEG_POLYMORPHIC_IDENTITY, \ - INTERNAL_POLYMORPHIC_IDENTITY -from tvb.datatypes.projections import \ - ProjectionSurfaceEEG, ProjectionSurfaceMEG, ProjectionSurfaceSEEG +from tvb.datatypes.sensors import SensorsMEG as TVBSensorsMEG + +from tvb_scripts.datatypes.base import BaseModel +from tvb_scripts.utils.data_structures_utils import ensure_list, \ + labels_to_inds, monopolar_to_bipolar, split_string_text_numbers +from tvb_scripts.utils.log_error_utils import warning class SensorTypes(Enum): @@ -28,7 +27,6 @@ class SensorTypes(Enum): SensorTypesNames = [getattr(SensorTypes, stype).value for stype in SensorTypes.__members__] - SensorTypesToProjectionDict = {"EEG": ProjectionSurfaceEEG, "MEG": ProjectionSurfaceMEG, "SEEG": ProjectionSurfaceSEEG, @@ -51,8 +49,8 @@ class Sensors(TVBSensors, BaseModel): def configure(self, remove_leading_zeros_from_labels=False): if len(self.labels) > 0: - if remove_leading_zeros_from_labels: - self.remove_leading_zeros_from_labels() + if remove_leading_zeros_from_labels: + self.remove_leading_zeros_from_labels() self.configure() def sensor_label_to_index(self, labels): @@ -86,7 +84,7 @@ def get_bipolar_sensors(self, sensors_inds=None): class SensorsEEG(Sensors, TVBSensorsEEG): - pass + pass class SensorsMEG(Sensors, TVBSensorsMEG): diff --git a/tvb_scripts/datatypes/surface.py b/tvb_scripts/datatypes/surface.py index 557db3b..bc5c85b 100644 --- a/tvb_scripts/datatypes/surface.py +++ b/tvb_scripts/datatypes/surface.py @@ -1,17 +1,17 @@ # coding=utf-8 import numpy as np - -from tvb_scripts.datatypes.base import BaseModel from tvb.basic.neotraits.api import NArray, Attr -from tvb.datatypes.surfaces import Surface as TVBSurface -from tvb.datatypes.surfaces import WhiteMatterSurface as TVBWhiteMatterSurface -from tvb.datatypes.surfaces import CorticalSurface as TVBCorticalSurface -from tvb.datatypes.surfaces import SkinAir as TVBSkinAir from tvb.datatypes.surfaces import BrainSkull as TVBBrainSkull -from tvb.datatypes.surfaces import SkullSkin as TVBSkullSkin +from tvb.datatypes.surfaces import CorticalSurface as TVBCorticalSurface from tvb.datatypes.surfaces import EEGCap as TVBEEGCap from tvb.datatypes.surfaces import FaceSurface as TVBFaceSurface +from tvb.datatypes.surfaces import SkinAir as TVBSkinAir +from tvb.datatypes.surfaces import SkullSkin as TVBSkullSkin +from tvb.datatypes.surfaces import Surface as TVBSurface +from tvb.datatypes.surfaces import WhiteMatterSurface as TVBWhiteMatterSurface + +from tvb_scripts.datatypes.base import BaseModel class SurfaceH5Field(object): @@ -23,7 +23,6 @@ class SurfaceH5Field(object): class Surface(TVBSurface, BaseModel): - vox2ras = NArray( dtype=np.float, label="vox2ras", default=None, required=False, @@ -52,9 +51,9 @@ def get_vertex_areas(self): return vertex_areas def add_vertices_and_triangles(self, new_vertices, new_triangles, - new_vertex_normals=np.array([]), new_triangle_normals=np.array([])): + new_vertex_normals=np.array([]), new_triangle_normals=np.array([])): self.triangles = np.array(self.triangles.tolist() + - (new_triangles + self.number_of_vertices).tolist()) + (new_triangles + self.number_of_vertices).tolist()) self.vertices = np.array(self.vertices.tolist() + new_vertices.tolist()) self.vertex_normals = np.array(self.vertex_normals.tolist() + new_vertex_normals.tolist()) self.triangle_normals = np.array(self.triangle_normals.tolist() + new_triangle_normals.tolist()) diff --git a/tvb_scripts/datatypes/time_series.py b/tvb_scripts/datatypes/time_series.py index a17d21d..5f97990 100644 --- a/tvb_scripts/datatypes/time_series.py +++ b/tvb_scripts/datatypes/time_series.py @@ -1,22 +1,23 @@ # -*- coding: utf-8 -*- -from six import string_types -from enum import Enum from copy import deepcopy +from enum import Enum import numpy -from tvb_scripts.utils.log_error_utils import initialize_logger, warning -from tvb_scripts.utils.data_structures_utils import ensure_list, is_integer, monopolar_to_bipolar +from six import string_types from tvb.basic.neotraits.api import List, Attr from tvb.basic.profile import TvbProfile +from tvb.datatypes.sensors import Sensors, SensorsEEG, SensorsMEG, SensorsInternal from tvb.datatypes.time_series import TimeSeries as TimeSeriesTVB -from tvb.datatypes.time_series import TimeSeriesRegion as TimeSeriesRegionTVB from tvb.datatypes.time_series import TimeSeriesEEG as TimeSeriesEEGTVB from tvb.datatypes.time_series import TimeSeriesMEG as TimeSeriesMEGTVB +from tvb.datatypes.time_series import TimeSeriesRegion as TimeSeriesRegionTVB from tvb.datatypes.time_series import TimeSeriesSEEG as TimeSeriesSEEGTVB from tvb.datatypes.time_series import TimeSeriesSurface as TimeSeriesSurfaceTVB from tvb.datatypes.time_series import TimeSeriesVolume as TimeSeriesVolumeTVB -from tvb.datatypes.sensors import Sensors, SensorsEEG, SensorsMEG, SensorsInternal + +from tvb_scripts.utils.data_structures_utils import ensure_list, is_integer, monopolar_to_bipolar +from tvb_scripts.utils.log_error_utils import initialize_logger, warning TvbProfile.set_profile(TvbProfile.LIBRARY_PROFILE) @@ -252,16 +253,16 @@ def slice_data_across_dimension_by_index(self, indices, dimension, **kwargs): def slice_data_across_dimension_by_label(self, labels, dimension, **kwargs): dim_index = self.get_dimension_index(dimension) return self.slice_data_across_dimension_by_index( - self._get_index_of_label(labels, - self.get_dimension_name(dim_index)), - dim_index, **kwargs) + self._get_index_of_label(labels, + self.get_dimension_name(dim_index)), + dim_index, **kwargs) def slice_data_across_dimension_by_slice(self, slice_arg, dimension, **kwargs): dim_index = self.get_dimension_index(dimension) return self.slice_data_across_dimension_by_index( - self._slice_to_indices( - self._process_slice(slice_arg, dim_index), dim_index), - dim_index, **kwargs) + self._slice_to_indices( + self._process_slice(slice_arg, dim_index), dim_index), + dim_index, **kwargs) def _index_or_label_or_slice(self, inputs): inputs = ensure_list(inputs) @@ -589,7 +590,6 @@ def SEEGsensor_labels(self): TimeSeriesMEG.__name__: TimeSeriesMEG, TimeSeriesSEEG.__name__: TimeSeriesSEEG} - if __name__ == "__main__": kwargs = {"data": numpy.ones((4, 2, 10, 1)), "start_time": 0.0, "labels_dimensions": {LABELS_ORDERING[1]: ["x", "y"]}} diff --git a/tvb_scripts/datatypes/time_series_xarray.py b/tvb_scripts/datatypes/time_series_xarray.py index 5c5a3e7..db08090 100644 --- a/tvb_scripts/datatypes/time_series_xarray.py +++ b/tvb_scripts/datatypes/time_series_xarray.py @@ -35,11 +35,13 @@ """ from copy import deepcopy -from six import string_types -import xarray as xr + import numpy as np -from tvb.datatypes import sensors, surfaces, volumes, region_mapping, connectivity +import xarray as xr +from six import string_types from tvb.basic.neotraits.api import HasTraits, Attr, List, narray_summary_info +from tvb.datatypes import sensors, surfaces, volumes, region_mapping, connectivity + from tvb_scripts.datatypes.time_series import TimeSeries as TimeSeriesTVB from tvb_scripts.utils.data_structures_utils import is_integer @@ -358,7 +360,7 @@ def duplicate(self, **kwargs): def _assert_array_indices(self, slice_tuple): if is_integer(slice_tuple) or isinstance(slice_tuple, string_types): - return ([slice_tuple], ) + return ([slice_tuple],) else: if isinstance(slice_tuple, slice): slice_tuple = (slice_tuple,) @@ -547,9 +549,9 @@ def plot_timeseries(self, **kwargs): outputs.append(self[:, var].plot_timeseries(**kwargs)) return outputs if np.any([s < 2 for s in self.shape[1:]]): - if self.shape[1] == 1: # only one variable + if self.shape[1] == 1: # only one variable figname = kwargs.pop("figname", "%s" % (self.title + "Time Series")) + ": " \ - + self.labels_dimensions[self.labels_ordering[1]][0] + + self.labels_dimensions[self.labels_ordering[1]][0] kwargs["figname"] = figname return self.plot_line(**kwargs) else: diff --git a/tvb_scripts/io/edf.py b/tvb_scripts/io/edf.py index 3b3037d..2220e6b 100644 --- a/tvb_scripts/io/edf.py +++ b/tvb_scripts/io/edf.py @@ -2,9 +2,9 @@ import numpy as np +from tvb_scripts.datatypes.time_series import TimeSeries, TimeSeriesDimensions from tvb_scripts.utils.log_error_utils import initialize_logger from tvb_scripts.utils.data_structures_utils import ensure_string -from tvb_scripts.model.timeseries import Timeseries, TimeseriesDimensions def read_edf_with_mne(path, exclude_channels): @@ -79,5 +79,5 @@ def read_edf_to_Timeseries(path, sensors, rois_selection=None, label_strip_fun=N data, times, rois, rois_inds, rois_lbls = \ read_edf(path, sensors, rois_selection, label_strip_fun, time_unit) - return Timeseries(data, time=times, labels_dimensions={TimeseriesDimensions.SPACE.value: rois_lbls}, + return TimeSeries(data, time=times, labels_dimensions={TimeSeriesDimensions.SPACE.value: rois_lbls}, sample_period=np.mean(np.diff(times)), sample_period_unit=time_unit, **kwargs) diff --git a/tvb_scripts/io/h5_reader_base.py b/tvb_scripts/io/h5_reader_base.py index 40e4dcf..420383f 100644 --- a/tvb_scripts/io/h5_reader_base.py +++ b/tvb_scripts/io/h5_reader_base.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- import os + import h5py -from tvb_scripts.utils.log_error_utils import initialize_logger from tvb_scripts.io.h5_writer import H5Writer +from tvb_scripts.utils.log_error_utils import initialize_logger class H5ReaderBase(object): @@ -33,6 +34,7 @@ def _log_success(self, name, path=None): class H5GroupHandlers(object): + H5_SUBTYPE_ATTRIBUTE = H5Writer().H5_SUBTYPE_ATTRIBUTE def read_dictionary_from_group(self, group, type=None): dictionary = dict() @@ -41,6 +43,6 @@ def read_dictionary_from_group(self, group, type=None): for attr in group.attrs.keys(): dictionary.update({attr: group.attrs[attr]}) if type is None: - type = group.attrs[H5_SUBTYPE_ATTRIBUTE] + type = group.attrs[self.H5_SUBTYPE_ATTRIBUTE] else: return dictionary diff --git a/tvb_scripts/plot/base_plotter.py b/tvb_scripts/plot/base_plotter.py deleted file mode 100644 index 4398cdc..0000000 --- a/tvb_scripts/plot/base_plotter.py +++ /dev/null @@ -1,449 +0,0 @@ -# coding=utf-8 - -import os -import numpy - -from tvb_scripts.config import Config, FiguresConfig - -import matplotlib -matplotlib.use(FiguresConfig().MATPLOTLIB_BACKEND) -from matplotlib import pyplot -pyplot.rcParams["font.size"] = FiguresConfig.FONTSIZE -from mpl_toolkits.axes_grid1 import make_axes_locatable - -from tvb_scripts.utils.log_error_utils import initialize_logger, warning -from tvb_scripts.utils.data_structures_utils import ensure_list, generate_region_labels - - -class BasePlotter(object): - - def __init__(self, config=None): - self.config = config or Config() - self.logger = initialize_logger(self.__class__.__name__, self.config.out.FOLDER_LOGS) - self.print_regions_indices = True - - def _check_show(self): - if self.config.figures.SHOW_FLAG: - # mp.use('TkAgg') - pyplot.ion() - pyplot.show() - else: - # mp.use('Agg') - pyplot.ioff() - pyplot.close() - - @staticmethod - def _figure_filename(fig=pyplot.gcf(), figure_name=None): - if figure_name is None: - figure_name = fig.get_label() - figure_name = figure_name.replace(": ", "_").replace(" ", "_").replace("\t", "_").replace(",", "") - return figure_name - - def _save_figure(self, fig=pyplot.gcf(), figure_name=None): - if self.config.figures.SAVE_FLAG: - figure_name = self._figure_filename(fig, figure_name) - figure_name = figure_name[:numpy.min([100, len(figure_name)])] + '.' + self.config.figures.FIG_FORMAT - figure_dir = self.config.out.FOLDER_FIGURES - if not (os.path.isdir(figure_dir)): - os.mkdir(figure_dir) - pyplot.savefig(os.path.join(figure_dir, figure_name)) - - @staticmethod - def rect_subplot_shape(n, mode="col"): - nj = int(numpy.ceil(numpy.sqrt(n))) - ni = int(numpy.ceil(1.0 * n / nj)) - if mode.find("row") >= 0: - return nj, ni - else: - return ni, nj - - def plot_vector(self, vector, labels, subplot, title, show_y_labels=True, indices_red=None, sharey=None): - ax = pyplot.subplot(subplot, sharey=sharey) - pyplot.title(title) - n_vector = labels.shape[0] - y_ticks = numpy.array(range(n_vector), dtype=numpy.int32) - color = 'k' - colors = numpy.repeat([color], n_vector) - coldif = False - if indices_red is not None: - colors[indices_red] = 'r' - coldif = True - if len(vector.shape) == 1: - ax.barh(y_ticks, vector, color=colors, align='center') - else: - ax.barh(y_ticks, vector[0, :], color=colors, align='center') - # ax.invert_yaxis() - ax.grid(True, color='grey') - ax.set_yticks(y_ticks) - if show_y_labels: - region_labels = generate_region_labels(n_vector, labels, ". ", self.print_regions_indices) - ax.set_yticklabels(region_labels) - if coldif: - labels = ax.yaxis.get_ticklabels() - for ids in indices_red: - labels[ids].set_color('r') - ax.yaxis.set_ticklabels(labels) - else: - ax.set_yticklabels([]) - ax.autoscale(tight=True) - if sharey is None: - ax.invert_yaxis() - return ax - - def plot_vector_violin(self, dataset, vector=[], lines=[], labels=[], subplot=111, title="", violin_flag=True, - colormap="YlOrRd", show_y_labels=True, indices_red=None, sharey=None): - ax = pyplot.subplot(subplot, sharey=sharey) - pyplot.title(title) - n_violins = dataset.shape[1] - y_ticks = numpy.array(range(n_violins), dtype=numpy.int32) - # the vector plot - coldif = False - if indices_red is None: - indices_red = [] - if violin_flag: - # the violin plot - colormap = matplotlib.cm.ScalarMappable(cmap=pyplot.set_cmap(colormap)) - colormap = colormap.to_rgba(numpy.mean(dataset, axis=0), alpha=0.75) - violin_parts = ax.violinplot(dataset, y_ticks, vert=False, widths=0.9, - showmeans=True, showmedians=True, showextrema=True) - violin_parts['cmeans'].set_color("k") - violin_parts['cmins'].set_color("b") - violin_parts['cmaxes'].set_color("b") - violin_parts['cbars'].set_color("b") - violin_parts['cmedians'].set_color("b") - for ii in range(len(violin_parts['bodies'])): - violin_parts['bodies'][ii].set_color(numpy.reshape(colormap[ii], (1, 4))) - violin_parts['bodies'][ii]._alpha = 0.75 - violin_parts['bodies'][ii]._edgecolors = numpy.reshape(colormap[ii], (1, 4)) - violin_parts['bodies'][ii]._facecolors = numpy.reshape(colormap[ii], (1, 4)) - else: - colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color'] - n_samples = dataset.shape[0] - for ii in range(n_violins): - for jj in range(n_samples): - ax.plot(dataset[jj, ii], y_ticks[ii], "D", - mfc=colorcycle[jj%n_samples], mec=colorcycle[jj%n_samples], ms=20) - color = 'k' - colors = numpy.repeat([color], n_violins) - if indices_red is not None: - colors[indices_red] = 'r' - coldif = True - if len(vector) == n_violins: - for ii in range(n_violins): - ax.plot(vector[ii], y_ticks[ii], '*', mfc=colors[ii], mec=colors[ii], ms=10) - if len(lines) == 2 and lines[0].shape[0] == n_violins and lines[1].shape[0] == n_violins: - for ii in range(n_violins): - yy = (y_ticks[ii] - 0.45*lines[1][ii]/numpy.max(lines[1][ii]))\ - * numpy.ones(numpy.array(lines[0][ii]).shape) - ax.plot(lines[0][ii], yy, '--', color=colors[ii]) - - ax.grid(True, color='grey') - ax.set_yticks(y_ticks) - if show_y_labels: - region_labels = generate_region_labels(n_violins, labels, ". ", self.print_regions_indices) - ax.set_yticklabels(region_labels) - if coldif: - labels = ax.yaxis.get_ticklabels() - for ids in indices_red: - labels[ids].set_color('r') - ax.yaxis.set_ticklabels(labels) - else: - ax.set_yticklabels([]) - if sharey is None: - ax.invert_yaxis() - ax.autoscale() - return ax - - def _plot_matrix(self, matrix, xlabels, ylabels, subplot=111, title="", show_x_labels=True, show_y_labels=True, - x_ticks=numpy.array([]), y_ticks=numpy.array([]), indices_red_x=None, indices_red_y=None, - sharex=None, sharey=None, cmap='autumn_r', vmin=None, vmax=None): - ax = pyplot.subplot(subplot, sharex=sharex, sharey=sharey) - pyplot.title(title) - nx, ny = matrix.shape - indices_red = [indices_red_x, indices_red_y] - ticks = [x_ticks, y_ticks] - labels = [xlabels, ylabels] - nticks = [] - for ii, (n, tick) in enumerate(zip([nx, ny], ticks)): - if len(tick) == 0: - ticks[ii] = numpy.array(range(n), dtype=numpy.int32) - nticks.append(len(ticks[ii])) - cmap = pyplot.set_cmap(cmap) - img = pyplot.imshow(matrix[ticks[0]][:, ticks[1]].T, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='none') - pyplot.grid(True, color='black') - for ii, (xy, tick, ntick, ind_red, show, lbls, rot) in enumerate(zip(["x", "y"], ticks, nticks, indices_red, - [show_x_labels, show_y_labels], labels, [90, 0])): - if show: - labels[ii] = generate_region_labels(len(tick), numpy.array(lbls)[tick], ". ", self.print_regions_indices, tick) - # labels[ii] = numpy.array(["%d. %s" % l for l in zip(tick, lbls[tick])]) - getattr(pyplot, xy + "ticks")(numpy.array(range(ntick)), labels[ii], rotation=rot) - else: - labels[ii] = numpy.array(["%d." % l for l in tick]) - getattr(pyplot, xy + "ticks")(numpy.array(range(ntick)), labels[ii]) - if ind_red is not None: - tck = tick.tolist() - ticklabels = getattr(ax, xy + "axis").get_ticklabels() - for iidx, indr in enumerate(ind_red): - try: - ticklabels[tck.index(indr)].set_color('r') - except: - pass - getattr(ax, xy + "axis").set_ticklabels(ticklabels) - ax.autoscale(tight=True) - divider = make_axes_locatable(ax) - cax1 = divider.append_axes("right", size="5%", pad=0.05) - pyplot.colorbar(img, cax=cax1) # fraction=0.046, pad=0.04) #fraction=0.15, shrink=1.0 - return ax, cax1 - - def plot_regions2regions(self, adj, labels, subplot, title, show_x_labels=True, show_y_labels=True, - x_ticks=numpy.array([]), y_ticks=numpy.array([]), indices_red_x=None, indices_red_y=None, - sharex=None, sharey=None, cmap='autumn_r', vmin=None, vmax=None): - return self._plot_matrix(adj, labels, labels, subplot, title, show_x_labels, show_y_labels, - x_ticks, y_ticks, indices_red_x, indices_red_y, sharex, sharey, cmap, vmin, vmax) - - def _set_axis_labels(self, fig, sub, n_regions, region_labels, indices2emphasize, color='k', position='left'): - y_ticks = range(n_regions) - # region_labels = numpy.array(["%d. %s" % l for l in zip(y_ticks, region_labels)]) - region_labels = generate_region_labels(len(y_ticks), region_labels, ". ", self.print_regions_indices, y_ticks) - big_ax = fig.add_subplot(sub, frameon=False) - if position == 'right': - big_ax.yaxis.tick_right() - big_ax.yaxis.set_label_position("right") - big_ax.set_yticks(y_ticks) - big_ax.set_yticklabels(region_labels, color='k') - if not (color == 'k'): - labels = big_ax.yaxis.get_ticklabels() - for idx in indices2emphasize: - labels[idx].set_color(color) - big_ax.yaxis.set_ticklabels(labels) - big_ax.invert_yaxis() - big_ax.axes.get_xaxis().set_visible(False) - # TODO: find out what is the next line about and why it fails... - # big_ax.axes.set_facecolor('none') - - def plot_in_columns(self, data_dict_list, labels, width_ratios=[], left_ax_focus_indices=[], - right_ax_focus_indices=[], description="", title="", figure_name=None, - figsize=None, **kwargs): - if not isinstance(figsize, (tuple, list)): - figsize = self.config.figures.VERY_LARGE_SIZE - fig = pyplot.figure(title, frameon=False, figsize=figsize) - fig.suptitle(description) - n_subplots = len(data_dict_list) - if not width_ratios: - width_ratios = numpy.ones((n_subplots,)).tolist() - matplotlib.gridspec.GridSpec(1, n_subplots, width_ratios=width_ratios) - if 10 > n_subplots > 0: - subplot_ind0 = 100 + 10 * n_subplots - else: - raise ValueError("\nSubplots' number " + str(n_subplots) + " is not between 1 and 9!") - n_regions = len(labels) - subplot_ind = subplot_ind0 - ax = None - ax0 = None - for iS, data_dict in enumerate(data_dict_list): - subplot_ind += 1 - data = data_dict["data"] - focus_indices = data_dict.get("focus_indices") - if subplot_ind == 0: - if not left_ax_focus_indices: - left_ax_focus_indices = focus_indices - else: - ax0 = ax - if data_dict.get("plot_type") == "vector_violin": - ax = self.plot_vector_violin(data_dict.get("data_samples", []), data, [], - labels, subplot_ind, data_dict["name"], - colormap=kwargs.get("colormap", "YlOrRd"), show_y_labels=False, - indices_red=focus_indices, sharey=ax0) - elif data_dict.get("plot_type") == "regions2regions": - # TODO: find a more general solution, in case we don't want to apply focus indices to x_ticks - ax = self.plot_regions2regions(data, labels, subplot_ind, data_dict["name"], x_ticks=focus_indices, - show_x_labels=True, show_y_labels=False, indices_red_x=focus_indices, - sharey=ax0) - else: - ax = self.plot_vector(data, labels, subplot_ind, data_dict["name"], show_y_labels=False, - indices_red=focus_indices, sharey=ax0) - if right_ax_focus_indices == []: - right_ax_focus_indices = focus_indices - self._set_axis_labels(fig, 121, n_regions, labels, left_ax_focus_indices, 'r') - self._set_axis_labels(fig, 122, n_regions, labels, right_ax_focus_indices, 'r', 'right') - self._save_figure(pyplot.gcf(), figure_name) - self._check_show() - return fig - - #TODO: name is too generic - def plots(self, data_dict, shape=None, transpose=False, skip=0, xlabels={}, xscales={}, yscales={}, title='Plots', - lgnd={}, figure_name=None, figsize=None): - if not isinstance(figsize, (tuple, list)): - figsize = self.config.figures.VERY_LARGE_SIZE - if shape is None: - shape = (1, len(data_dict)) - fig, axes = pyplot.subplots(shape[0], shape[1], figsize=figsize) - fig.set_label(title) - for i, key in enumerate(data_dict.keys()): - ind = numpy.unravel_index(i, shape) - if transpose: - axes[ind].plot(data_dict[key].T[skip:]) - else: - axes[ind].plot(data_dict[key][skip:]) - axes[ind].set_xscale(xscales.get(key, "linear")) - axes[ind].set_yscale(yscales.get(key, "linear")) - axes[ind].set_xlabel(xlabels.get(key, "")) - axes[ind].set_ylabel(key) - this_legend = lgnd.get(key, None) - if this_legend is not None: - axes[ind].legend(this_legend) - fig.tight_layout() - self._save_figure(fig, figure_name) - self._check_show() - return fig, axes - - def pair_plots(self, data, keys, diagonal_plots={}, transpose=False, skip=0, - title='Pair plots', legend_prefix="", subtitles=None, figure_name=None, figsize=None): - - def confirm_y_coordinate(data, ymax): - data = list(data) - data.append(ymax) - return tuple(data) - - if not isinstance(figsize, (tuple, list)): - figsize = self.config.figures.VERY_LARGE_SIZE - - if subtitles is None: - subtitles = keys - data = ensure_list(data) - n = len(keys) - fig, axes = pyplot.subplots(n, n, figsize=figsize) - fig.set_label(title) - colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color'] - for i, key_i in enumerate(keys): - for j, key_j in enumerate(keys): - for datai in data: - if transpose: - di = datai[key_i].T[skip:] - else: - di = datai[key_i][skip:] - if i == j: - if di.shape[0] > 1: - hist_data = axes[i, j].hist(di, int(numpy.round(numpy.sqrt(len(di)))), log=True)[0] - if i == 0 and len(di.shape) > 1 and di.shape[1] > 1: - axes[i, j].legend([legend_prefix + str(ii + 1) for ii in range(di.shape[1])]) - y_max = numpy.array(hist_data).max() - # The mean line - axes[i, j].vlines(di.mean(axis=0), 0, y_max, color=colorcycle, linestyle='dashed', - linewidth=1) - else: - # This is for the case of only 1 sample (optimization) - y_max = 1.0 - for ii in range(di.shape[1]): - axes[i, j].plot(di[0, ii], y_max, "D", color=colorcycle[ii%di.shape[1]], markersize=20, - label=legend_prefix + str(ii + 1)) - if i == 0 and len(di.shape) > 1 and di.shape[1] > 1: - axes[i, j].legend() - # Plot a line (or marker) in the same axis - diag_line_plot = ensure_list(diagonal_plots.get(key_i, ((), ()))[0]) - if len(diag_line_plot) in [1, 2]: - if len(diag_line_plot) == 1 : - diag_line_plot = confirm_y_coordinate(diag_line_plot, y_max) - else: - diag_line_plot[1] = diag_line_plot[1]/numpy.max(diag_line_plot[1])*y_max - if len(ensure_list(diag_line_plot[0])) == 1: - axes[i, j].plot(diag_line_plot[0], diag_line_plot[1], "o", mfc="k", mec="k", - markersize=10) - else: - axes[i, j].plot(diag_line_plot[0], diag_line_plot[1], color='k', - linestyle="dashed", linewidth=1) - # Plot a marker in the same axis - diag_marker_plot = ensure_list(diagonal_plots.get(key_i, ((), ()))[1]) - if len(diag_marker_plot) in [1, 2]: - if len(diag_marker_plot) == 1: - diag_marker_plot = confirm_y_coordinate(diag_marker_plot, y_max) - axes[i, j].plot(diag_marker_plot[0], diag_marker_plot[1], "*", color='k', markersize=10) - axes[i, j].autoscale() - axes[i, j].set_ylim([0, 1.1*y_max]) - - else: - if transpose: - dj = datai[key_j].T[skip:] - else: - dj = datai[key_j][skip:] - axes[i, j].plot(dj, di, '.') - if i == 0: - axes[i, j].set_title(subtitles[j]) - if j == 0: - axes[i, j].set_ylabel(key_i) - fig.tight_layout() - self._save_figure(fig, figure_name) - self._check_show() - return fig, axes - - def plot_bars(self, data, ax=None, fig=None, title="", group_names=[], legend_prefix="", figsize=None): - - def barlabel(ax, rects, positions): - """ - Attach a text label on each bar displaying its height - """ - for rect, pos in zip(rects, positions): - height = rect.get_height() - if pos < 0: - y = -height - pos = 0.75 * pos - else: - y = height - pos = 0.25 * pos - ax.text(rect.get_x() + rect.get_width() / 2., pos, '%0.2f' % y, - color="k", ha='center', va='bottom', rotation=90) - if fig is None: - if not isinstance(figsize, (tuple, list)): - figsize = self.config.figures.VERY_LARGE_SIZE - fig, ax = pyplot.subplots(1, 1, figsize=figsize) - show_and_save = True - else: - show_and_save = False - if ax is None: - ax = pyplot.gca() - if isinstance(data, (list, tuple)): # If, there are many groups, data is a list: - # Fill in with nan in case that not all groups have the same number of elements - from itertools import izip_longest - data = numpy.array(list(izip_longest(*ensure_list(data), fillvalue=numpy.nan))).T - elif data.ndim == 1: # This is the case where there is only one group... - data = numpy.expand_dims(data, axis=1).T - n_groups, n_elements = data.shape - posmax = numpy.nanmax(data) - negmax = numpy.nanmax(-(-data)) - n_groups_names = len(group_names) - if n_groups_names != n_groups: - if n_groups_names != 0: - warning("Ignoring group_names because their number (" + str(n_groups_names) + - ") is not equal to the number of groups (" + str(n_groups) + ")!") - group_names = n_groups * [""] - colorcycle = pyplot.rcParams['axes.prop_cycle'].by_key()['color'] - n_colors = len(colorcycle) - x_inds = numpy.arange(n_groups) - width = 0.9 / n_elements - elements = [] - for iE in range(n_elements): - elements.append(ax.bar(x_inds + iE*width, data[:, iE], width, color=colorcycle[iE % n_colors])) - positions = numpy.array([negmax if d < 0 else posmax for d in data[:, iE]]) - positions[numpy.logical_or(numpy.isnan(positions), numpy.isinf(numpy.abs(positions)))] = 0.0 - barlabel(ax, elements[-1], positions) - if n_elements > 1: - legend = [legend_prefix+str(ii) for ii in range(1, n_elements+1)] - ax.legend(tuple([element[0] for element in elements]), tuple(legend)) - ax.set_xticks(x_inds + n_elements*width/2) - ax.set_xticklabels(tuple(group_names)) - ax.set_title(title) - ax.autoscale() # tight=True - ax.set_xlim([-1.05*width, n_groups*1.05]) - if show_and_save: - fig.tight_layout() - self._save_figure(fig) - self._check_show() - return fig, ax - - def tvb_plot(self, plot_fun_name, *args, **kwargs): - import tvb.simulator.plot.tools as TVB_plot_tools - getattr(TVB_plot_tools, plot_fun_name)(*args, **kwargs) - fig = pyplot.gcf() - self._save_figure(fig) - self._check_show() - return fig diff --git a/tvb_scripts/config.py b/tvb_scripts/plot/config.py similarity index 64% rename from tvb_scripts/config.py rename to tvb_scripts/plot/config.py index 572b49f..3a16396 100644 --- a/tvb_scripts/config.py +++ b/tvb_scripts/plot/config.py @@ -1,12 +1,43 @@ # -*- coding: utf-8 -*- +# +# +# TheVirtualBrain-Scientific Package. This package holds all simulators, and +# analysers necessary to run brain-simulations. You can use it stand alone or +# in conjunction with TheVirtualBrain-Framework Package. See content of the +# documentation-folder for more details. See also http://www.thevirtualbrain.org +# +# (c) 2012-2020, Baycrest Centre for Geriatric Care ("Baycrest") and others +# +# This program is free software: you can redistribute it and/or modify it under the +# terms of the GNU General Public License as published by the Free Software Foundation, +# either version 3 of the License, or (at your option) any later version. +# This program is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A +# PARTICULAR PURPOSE. See the GNU General Public License for more details. +# You should have received a copy of the GNU General Public License along with this +# program. If not, see . +# +# +# CITATION: +# When using The Virtual Brain for scientific publications, please cite it as follows: +# +# Paula Sanz Leon, Stuart A. Knock, M. Marmaduke Woodman, Lia Domide, +# Jochen Mersmann, Anthony R. McIntosh, Viktor Jirsa (2013) +# The Virtual Brain: a simulator of primate brain network dynamics. +# Frontiers in Neuroinformatics (7:10. doi: 10.3389/fninf.2013.00010) +# +# + +""" +.. moduleauthor:: Dionysios Perdikis +.. moduleauthor:: Gabriel Florea +""" import os -import numpy from datetime import datetime -from tvb.basic.profile import TvbProfile - -TvbProfile.set_profile(TvbProfile.LIBRARY_PROFILE) +import numpy +from tvb.simulator.plot.config import FiguresConfig class GenericConfig(object): @@ -78,18 +109,7 @@ def FOLDER_RES(self): if not (os.path.isdir(folder)): os.makedirs(folder) if self.subfolder is not None: - os.path.join(folder, self.subfolder) - return folder - - @property - def FOLDER_FIGURES(self): - folder = os.path.join(self._out_base, "figs") - if self._separate_by_run: - folder = folder + datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M') - if not (os.path.isdir(folder)): - os.makedirs(folder) - if self.subfolder is not None: - os.path.join(folder, self.subfolder) + folder = os.path.join(folder, self.subfolder) return folder @property @@ -97,33 +117,6 @@ def FOLDER_TEMP(self): return os.path.join(self._out_base, "temp") -class FiguresConfig(object): - VERY_LARGE_SIZE = (40, 20) - VERY_LARGE_PORTRAIT = (30, 50) - SUPER_LARGE_SIZE = (80, 40) - LARGE_SIZE = (20, 15) - SMALL_SIZE = (15, 10) - NOTEBOOK_SIZE = (20, 10) - FIG_FORMAT = 'png' - SAVE_FLAG = True - SHOW_FLAG = False - MOUSE_HOOVER = False - MATPLOTLIB_BACKEND = "Agg" # "Qt4Agg" - FONTSIZE = 10 - SMALL_FONTSIZE = 8 - LARGE_FONTSIZE = 12 - - def largest_size(self): - import sys - if 'IPython' not in sys.modules: - return self.LARGE_SIZE - from IPython import get_ipython - if getattr(get_ipython(), 'kernel', None) is not None: - return self.NOTEBOOK_SIZE - else: - return self.LARGE_SIZE - - class CalculusConfig(object): # Normalization configuration WEIGHTS_NORM_PERCENT = 99 @@ -139,12 +132,12 @@ class CalculusConfig(object): class Config(object): generic = GenericConfig() - figures = FiguresConfig() calcul = CalculusConfig() def __init__(self, head_folder=None, raw_data_folder=None, output_base=None, separate_by_run=False): self.input = InputConfig(head_folder, raw_data_folder) self.out = OutputConfig(output_base, separate_by_run) + self.figures = FiguresConfig(output_base, separate_by_run) CONFIGURED = Config() diff --git a/tvb_scripts/plot/head_plotter.py b/tvb_scripts/plot/head_plotter.py deleted file mode 100644 index 5b3181c..0000000 --- a/tvb_scripts/plot/head_plotter.py +++ /dev/null @@ -1,143 +0,0 @@ -# coding=utf-8 - -from matplotlib import pyplot - -import numpy - -from tvb_scripts.utils.computations_utils import compute_in_degree -from tvb_scripts.datatypes.sensors import Sensors -from tvb_scripts.plot.base_plotter import BasePlotter - -from tvb.datatypes.projections import ProjectionMatrix - - -class HeadPlotter(BasePlotter): - - def __init__(self, config=None): - super(HeadPlotter, self).__init__(config) - - def _plot_connectivity(self, connectivity, figure_name='Connectivity'): - pyplot.figure(figure_name + str(connectivity.number_of_regions), self.config.figures.VERY_LARGE_SIZE) - axes = [] - axes.append(self.plot_regions2regions(connectivity.normalized_weights, - connectivity.region_labels, 121, "normalised weights")) - axes.append(self.plot_regions2regions(connectivity.tract_lengths, - connectivity.region_labels, 122, "tract lengths")) - self._save_figure(None, figure_name.replace(" ", "_").replace("\t", "_")) - self._check_show() - return pyplot.gcf(), tuple(axes) - - def _plot_connectivity_stats(self, connectivity, figsize=None, figure_name='HeadStats '): - if not isinstance(figsize, (list, tuple)): - figsize = self.config.figures.VERY_LARGE_SIZE - pyplot.figure("Head stats " + str(connectivity.number_of_regions), figsize=figsize) - areas_flag = len(connectivity.areas) == len(connectivity.region_labels) - axes=[] - axes.append(self.plot_vector(compute_in_degree(connectivity.normalized_weights), connectivity.region_labels, - 111 + 10 * areas_flag, "w in-degree")) - if len(connectivity.areas) == len(connectivity.region_labels): - axes.append(self.plot_vector(connectivity.areas, connectivity.region_labels, 122, "region areas")) - self._save_figure(None, figure_name.replace(" ", "").replace("\t", "")) - self._check_show() - return pyplot.gcf(), tuple(axes) - - def _plot_sensors(self, sensors, projection, region_labels, count=1): - figure, ax, cax = self._plot_projection(sensors, projection, region_labels, - title="%d - %s - Projection" % (count, sensors.sensors_type)) - count += 1 - return count, figure, ax, cax - - def _plot_projection(self, sensors, projection, region_labels, figure=None, title="Projection", - show_x_labels=True, show_y_labels=True, x_ticks=numpy.array([]), y_ticks=numpy.array([]), - figsize=None): - if not isinstance(figsize, (list, tuple)): - figsize = self.config.figures.VERY_LARGE_SIZE - if not (isinstance(figure, pyplot.Figure)): - figure = pyplot.figure(title, figsize=figsize) - ax, cax1 = self._plot_matrix(projection, sensors.labels, region_labels, 111, title, - show_x_labels, show_y_labels, x_ticks, y_ticks) - self._save_figure(None, title) - self._check_show() - return figure, ax, cax1 - - def plot_head(self, head, plot_stats=False, plot_sensors=True): - output = [] - output.append(self._plot_connectivity(head.connectivity)) - if plot_stats: - output.append(self._plot_connectivity_stats(head.connectivity)) - if plot_sensors: - count = 1 - for s_type, sensors_set in head.sensors.items(): - for sensor, projection in sensors_set.items(): - if isinstance(sensor, Sensors) and isinstance(projection, ProjectionMatrix): - count, figure, ax, cax = \ - self._plot_sensors(sensor, projection.projection_data, - head.connectivity.region_labels, count) - output.append((figure, ax, cax)) - return tuple(output) - - def plot_tvb_connectivity(self, connectivity, num="weights", order_by=None, plot_hinton=False, plot_tracts=True, - **kwargs): - """ - A 2D plot for visualizing the Connectivity.weights matrix - """ - figsize = kwargs.pop("figsize", self.config.figures.LARGE_SIZE) - fontsize = kwargs.pop("fontsize", self.config.figures.SMALL_FONTSIZE) - - labels = connectivity.region_labels - plot_title = connectivity.__class__.__name__ - - if order_by is None: - order = numpy.arange(connectivity.number_of_regions) - else: - order = numpy.argsort(order_by) - if order.shape[0] != connectivity.number_of_regions: - self.logger.error("Ordering vector doesn't have length number_of_regions") - self.logger.error("Check ordering length and that connectivity is configured") - return - - # Assumes order is shape (number_of_regions, ) - order_rows = order[:, numpy.newaxis] - order_columns = order_rows.T - - if plot_hinton: - from tvb.simulator.plot.tools import hinton_diagram - weights_axes = hinton_diagram(connectivity.weights[order_rows, order_columns], num) - weights_figure = None - else: - # weights matrix - weights_figure = pyplot.figure(num="Connectivity weights", figsize=figsize) - weights_axes = weights_figure.gca() - wimg = weights_axes.matshow(connectivity.weights[order_rows, order_columns]) - weights_figure.colorbar(wimg) - - weights_axes.set_title(plot_title) - - weights_axes.set_yticks(numpy.arange(connectivity.number_of_regions)) - weights_axes.set_yticklabels(list(labels[order]), fontsize=self.config.figures.FONTSIZE) - - weights_axes.set_xticks(numpy.arange(connectivity.number_of_regions)) - weights_axes.set_xticklabels(list(labels[order]), fontsize=fontsize, rotation=90) - - self._save_figure(weights_figure, plot_title) - self._check_show() - - if plot_tracts: - # tract lengths matrix - tracts_figure = pyplot.figure(num="Tracts' lengths", figsize=figsize) - tracts_axes = tracts_figure.gca() - timg = tracts_axes.matshow(connectivity.tract_lengths[order_rows, order_columns]) - tracts_axes.set_title("Tracts' lengths") - tracts_figure.colorbar(timg) - tracts_axes.set_yticks(numpy.arange(connectivity.number_of_regions)) - tracts_axes.set_yticklabels(list(labels[order]), fontsize=fontsize) - - tracts_axes.set_xticks(numpy.arange(connectivity.number_of_regions)) - tracts_axes.set_xticklabels(list(labels[order]), fontsize=fontsize, rotation=90) - - self._save_figure(tracts_figure) - self._check_show() - return weights_figure, weights_axes, tracts_figure, tracts_axes - - else: - return weights_figure, weights_axes diff --git a/tvb_scripts/plot/plotter.py b/tvb_scripts/plot/plotter.py index 90d9c9e..fa73542 100644 --- a/tvb_scripts/plot/plotter.py +++ b/tvb_scripts/plot/plotter.py @@ -1,63 +1,18 @@ # -*- coding: utf-8 -*- +from tvb.simulator.plot.plotter import Plotter as TVBPlotter -from tvb_scripts.plot.base_plotter import BasePlotter -from tvb_scripts.plot.head_plotter import HeadPlotter from tvb_scripts.plot.time_series_plotter import TimeSeriesPlotter -class Plotter(object): - - def __init__(self, config=None): - self.config = config - - @property - def base(self): - return BasePlotter(self.config) - - def tvb_plot(self, plot_fun_name, *args, **kwargs): - return BasePlotter(self.config).tvb_plot(plot_fun_name, *args, **kwargs) - +class Plotter(TVBPlotter): def plot_head(self, head): - return HeadPlotter(self.config).plot_head(head) - - def plot_tvb_connectivity(self, *args, **kwargs): - return HeadPlotter(self.config).plot_tvb_connectivity(*args, **kwargs) - - def plot_ts(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_ts(*args, **kwargs) - - def plot_ts_raster(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_ts_raster(*args, **kwargs) - - def plot_ts_trajectories(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_ts_trajectories(*args, **kwargs) - - def plot_tvb_timeseries(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_tvb_time_series(*args, **kwargs) + return self.plot_head_tvb(head.connectivity, head.sensors) def plot_timeseries(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_time_series(*args, **kwargs) - - def plot_raster(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_raster(*args, **kwargs) - - def plot_trajectories(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_trajectories(*args, **kwargs) + return TimeSeriesPlotter(self.config).plot_tvb_time_series(*args, **kwargs) def plot_timeseries_interactive(self, *args, **kwargs): return TimeSeriesPlotter(self.config).plot_time_series_interactive(*args, **kwargs) - def plot_tvb_timeseries_interactive(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_tvb_time_series_interactive(*args, **kwargs) - - def plot_power_spectra_interactive(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_power_spectra_interactive(*args, **kwargs) - - def plot_tvb_power_spectra_interactive(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_tvb_power_spectra_interactive(*args, **kwargs) - - def plot_ts_spectral_analysis_raster(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_spectral_analysis_raster(self, *args, **kwargs) - def plot_spectral_analysis_raster(self, *args, **kwargs): - return TimeSeriesPlotter(self.config).plot_spectral_analysis_raster(self, *args, **kwargs) + return TimeSeriesPlotter(self.config).plot_spectral_analysis_raster(*args, **kwargs) diff --git a/tvb_scripts/plot/time_series_interactive_plotter.py b/tvb_scripts/plot/time_series_interactive_plotter.py deleted file mode 100644 index c5dd2b7..0000000 --- a/tvb_scripts/plot/time_series_interactive_plotter.py +++ /dev/null @@ -1,150 +0,0 @@ -# coding=utf-8 -import numpy - -from tvb_scripts.utils.log_error_utils import initialize_logger -from tvb_scripts.utils.data_structures_utils import ensure_list, rotate_n_list_elements -from tvb.simulator.plot.timeseries_interactive import \ - TimeSeriesInteractive, pylab, time_series_datatypes, BACKGROUNDCOLOUR, EDGECOLOUR - -from matplotlib.pyplot import rcParams - - -LOG = initialize_logger(__name__) - - -class TimeseriesInteractivePlotter(TimeSeriesInteractive): - - def create_figure(self, **kwargs): - """ Create the figure and time-series axes. """ - # time_series_type = self.time_series.__class__.__name__ - figsize = kwargs.pop("figsize", (14, 8)) - facecolor = kwargs.pop("facecolor", BACKGROUNDCOLOUR) - edgecolor = kwargs.pop("edgecolor", EDGECOLOUR) - try: - figure_window_title = "Interactive time series: " # + time_series_type - num = kwargs.pop("figname", kwargs.get("num", figure_window_title)) - # pylab.close(figure_window_title) - self.its_fig = pylab.figure(num=num, - figsize=figsize, - facecolor=facecolor, - edgecolor=edgecolor) - except ValueError: - LOG.info("My life would be easier if you'd update your PyLab...") - figure_number = 42 - pylab.close(figure_number) - self.its_fig = pylab.figure(num=figure_number, - figsize=figsize, - facecolor=facecolor, - edgecolor=edgecolor) - - self.ts_ax = self.its_fig.add_axes([0.1, 0.1, 0.85, 0.85]) - - self.whereami_ax = self.its_fig.add_axes([0.1, 0.95, 0.85, 0.025], - facecolor=facecolor) - self.whereami_ax.set_axis_off() - if hasattr(self.whereami_ax, 'autoscale'): - self.whereami_ax.autoscale(enable=True, axis='both', tight=True) - self.whereami_ax.plot(self.time_view, - numpy.zeros((len(self.time_view),)), - color="0.3", linestyle="--") - self.hereiam = self.whereami_ax.plot(self.time_view, - numpy.zeros((len(self.time_view),)), - 'b-', linewidth=4) - - def plot_time_series(self, **kwargs): - """ Plot a view on the timeseries. """ - # Set title and axis labels - #time_series_type = self.time_series.__class__.__name__ - #self.ts_ax.set(title = time_series_type) - #self.ts_ax.set(xlabel = "Time (%s)" % self.units) - - # This assumes shape => (time, space) - step = self.scaling * self.peak_to_peak - if step == 0: - offset = 0.0 - else: #NOTE: specifying step in arange is faster, but it fence-posts. - offset = numpy.arange(0, self.nsrs) * step - if hasattr(self.ts_ax, 'autoscale'): - self.ts_ax.autoscale(enable=True, axis='both', tight=True) - - self.ts_ax.set_yticks(offset) - self.ts_ax.set_yticklabels(self.labels, fontsize=10) - #import pdb; pdb.set_trace() - - #Light gray guidelines - self.ts_ax.plot([self.nsrs*[self.time[self.time_view[0]]], - self.nsrs*[self.time[self.time_view[-1]]]], - numpy.vstack(2*(offset,)), "0.85") - - # Determine colors and linestyles for each variable of the Timeseries - linestyle = ensure_list(kwargs.pop("linestyle", "-")) - colors = kwargs.pop("linestyle", None) - if colors is not None: - colors = ensure_list(colors) - if self.data.shape[1] > 1: - linestyle = rotate_n_list_elements(linestyle, self.data.shape[1]) - if not isinstance(colors, list): - colors = (rcParams['axes.prop_cycle']).by_key()['color'] - colors = rotate_n_list_elements(colors, self.data.shape[1]) - else: - # If no color, - # or a color sequence is given in the input - # but there is only one variable to plot, - # choose the black color - if colors is None or len(colors) > 1: - colors = ["k"] - linestyle = linestyle[:1] - - # Determine the alpha value depending on the number of modes/samples of the Timeseries - alpha = 1.0 - if len(self.data.shape) > 3 and self.data.shape[3] > 1: - alpha /= self.data.shape[3] - - # Plot the timeseries per variable and sample - self.ts_view = [] - for i_var in range(self.data.shape[1]): - for ii in range(self.data.shape[3]): - # Plot the timeseries - self.ts_view.append(self.ts_ax.plot(self.time[self.time_view], - offset + self.data[self.time_view, i_var, :, ii], - alpha=alpha, color=colors[i_var], linestyle=linestyle[i_var], - **kwargs)) - - self.hereiam[0].remove() - self.hereiam = self.whereami_ax.plot(self.time_view, - numpy.zeros((len(self.time_view),)), - 'b-', linewidth=4) - - pylab.draw() - - def show(self, block=True, **kwargs): - """ Generate the interactive time-series figure. """ - time_series_type = self.time_series.__class__.__name__ - msg = "Generating an interactive time-series plot for %s" - if isinstance(self.time_series, time_series_datatypes.TimeSeriesSurface): - LOG.warning("Intended for region and sensors, not surfaces.") - LOG.info(msg % time_series_type) - - # Make the figure: - self.create_figure() - - # Selectors - # self.add_mode_selector() - - # Sliders - self.add_window_length_slider() - self.add_scaling_slider() - # self.add_time_slider() - - # time-view buttons - self.add_step_back_button() - self.add_step_forward_button() - self.add_big_step_back_button() - self.add_big_step_forward_button() - self.add_start_button() - self.add_end_button() - - # Plot timeseries - self.plot_time_series() - - pylab.show(block=block, **kwargs) diff --git a/tvb_scripts/plot/time_series_plotter.py b/tvb_scripts/plot/time_series_plotter.py index def9181..b35aa12 100644 --- a/tvb_scripts/plot/time_series_plotter.py +++ b/tvb_scripts/plot/time_series_plotter.py @@ -1,371 +1,12 @@ # -*- coding: utf-8 -*- -from six import string_types -import matplotlib -from matplotlib import pyplot, gridspec -from matplotlib.colors import Normalize import numpy -from tvb_scripts.utils.log_error_utils import warning, raise_value_error -from tvb_scripts.utils.data_structures_utils import ensure_list, isequal_string, generate_region_labels, ensure_string -from tvb_scripts.utils.time_series_utils import time_spectral_analysis -from tvb_scripts.plot.base_plotter import BasePlotter -from tvb.datatypes.time_series import TimeSeries as TimeSeriesTVB -from tvb_scripts.datatypes.time_series import TimeSeries - - -def assert_time(time, n_times, time_unit="ms", logger=None): - if time_unit.find("ms"): - dt = 0.001 - else: - dt = 1.0 - try: - time = numpy.array(time).flatten() - n_time = len(time) - if n_time > n_times: - # self.logger.warning("Input time longer than data time points! Removing redundant tail time points!") - time = time[:n_times] - elif n_time < n_times: - # self.logger.warning("Input time shorter than data time points! " - # "Extending tail time points with the same average time step!") - if n_time > 1: - dt = numpy.mean(numpy.diff(time)) - n_extra_points = n_times - n_time - start_time_point = time[-1] + dt - end_time_point = start_time_point + n_extra_points * dt - time = numpy.concatenate([time, numpy.arange(start_time_point, end_time_point, dt)]) - except: - if logger: - logger.warning("Setting a default time step vector manually! Input time: " + str(time)) - time = numpy.arange(0, n_times * dt, dt) - return time - - -class TimeSeriesPlotter(BasePlotter): - linestyle = "-" - linewidth = 1 - marker = None - markersize = 2 - markerfacecolor = None - tick_font_size = 12 - print_ts_indices = True - - def __init__(self, config=None): - super(TimeSeriesPlotter, self).__init__(config) - self.interactive_plotter = None - self.print_ts_indices = self.print_regions_indices - self.HighlightingDataCursor = lambda *args, **kwargs: None - if matplotlib.get_backend() in matplotlib.rcsetup.interactive_bk and self.config.figures.MOUSE_HOOVER: - try: - from mpldatacursor import HighlightingDataCursor - self.HighlightingDataCursor = HighlightingDataCursor - except ImportError: - self.config.figures.MOUSE_HOOVER = False - # self.logger.warning("Importing mpldatacursor failed! No highlighting functionality in plots!") - else: - # self.logger.warning("Noninteractive matplotlib backend! No highlighting functionality in plots!") - self.config.figures.MOUSE_HOOVER = False - - @property - def line_format(self): - return {"linestyle": self.linestyle, "linewidth": self.linewidth, - "marker": self.marker, "markersize": self.markersize, "markerfacecolor": self.markerfacecolor} - - def _ts_plot(self, time, n_vars, nTS, n_times, time_unit, subplots, offset=0.0, data_lims=[]): - - time_unit = ensure_string(time_unit) - data_fun = lambda data, time, icol: (data[icol], time, icol) - - def plot_ts(x, iTS, colors, labels): - x, time, ivar = x - time = assert_time(time, len(x[:, iTS]), time_unit, self.logger) - try: - return pyplot.plot(time, x[:, iTS], color=colors[iTS], label=labels[iTS], **self.line_format) - except: - self.logger.warning("Cannot convert labels' strings for line labels!") - return pyplot.plot(time, x[:, iTS], color=colors[iTS], label=str(iTS), **self.line_format) - - def plot_ts_raster(x, iTS, colors, labels, offset): - x, time, ivar = x - time = assert_time(time, len(x[:, iTS]), time_unit, self.logger) - try: - return pyplot.plot(time, -x[:, iTS] + (offset * iTS + x[:, iTS].mean()), color=colors[iTS], - label=labels[iTS], **self.line_format) - except: - self.logger.warning("Cannot convert labels' strings for line labels!") - return pyplot.plot(time, -x[:, iTS] + offset * iTS, color=colors[iTS], - label=str(iTS), **self.line_format) - - def axlabels_ts(labels, n_rows, irow, iTS): - if irow == n_rows: - pyplot.gca().set_xlabel("Time (" + time_unit + ")") - if n_rows > 1: - try: - pyplot.gca().set_ylabel(str(iTS) + "." + labels[iTS]) - except: - self.logger.warning("Cannot convert labels' strings for y axis labels!") - pyplot.gca().set_ylabel(str(iTS)) - - def axlimits_ts(data_lims, time, icol): - pyplot.gca().set_xlim([time[0], time[-1]]) - if n_rows > 1: - pyplot.gca().set_ylim([data_lims[icol][0], data_lims[icol][1]]) - else: - pyplot.autoscale(enable=True, axis='y', tight=True) +from tvb.simulator.plot.time_series_plotter import TimeSeriesPlotter as TVBTimeSeriesPlotter - def axYticks(labels, nTS, offsets=offset): - pyplot.gca().set_yticks((offset * numpy.array([list(range(nTS))]).flatten()).tolist()) - try: - pyplot.gca().set_yticklabels(labels.flatten().tolist()) - except: - labels = generate_region_labels(nTS, [], "", True) - self.logger.warning("Cannot convert region labels' strings for y axis ticks!") - - if offset > 0.0: - plot_lines = lambda x, iTS, colors, labels: \ - plot_ts_raster(x, iTS, colors, labels, offset) - else: - plot_lines = lambda x, iTS, colors, labels: \ - plot_ts(x, iTS, colors, labels) - this_axYticks = lambda labels, nTS: axYticks(labels, nTS, offset) - if subplots: - n_rows = nTS - def_alpha = 1.0 - else: - n_rows = 1 - def_alpha = 0.5 - subtitle_col = lambda subtitle: pyplot.gca().set_title(subtitle) - subtitle = lambda iTS, labels: None - projection = None - axlabels = lambda labels, vars, n_vars, n_rows, irow, iTS: axlabels_ts(labels, n_rows, irow, iTS) - axlimits = lambda data_lims, time, n_vars, icol: axlimits_ts(data_lims, time, icol) - loopfun = lambda nTS, n_rows, icol: list(range(nTS)) - return data_fun, time, plot_lines, projection, n_rows, n_vars, def_alpha, loopfun, \ - subtitle, subtitle_col, axlabels, axlimits, this_axYticks - - def _trajectories_plot(self, n_dims, nTS, nSamples, subplots): - data_fun = lambda data, time, icol: data - - def plot_traj_2D(x, iTS, colors, labels): - x, y = x - try: - return pyplot.plot(x[:, iTS], y[:, iTS], color=colors[iTS], label=labels[iTS], **self.line_format) - except: - self.logger.warning("Cannot convert labels' strings for line labels!") - return pyplot.plot(x[:, iTS], y[:, iTS], color=colors[iTS], label=str(iTS), **self.line_format) - - def plot_traj_3D(x, iTS, colors, labels): - x, y, z = x - try: - return pyplot.plot(x[:, iTS], y[:, iTS], z[:, iTS], color=colors[iTS], - label=labels[iTS], **self.line_format) - except: - self.logger.warning("Cannot convert labels' strings for line labels!") - return pyplot.plot(x[:, iTS], y[:, iTS], z[:, iTS], color=colors[iTS], - label=str(iTS), **self.line_format) - - def subtitle_traj(labels, iTS): - try: - if self.print_ts_indices: - pyplot.gca().set_title(str(iTS) + "." + labels[iTS]) - else: - pyplot.gca().set_title(labels[iTS]) - except: - self.logger.warning("Cannot convert labels' strings for subplot titles!") - pyplot.gca().set_title(str(iTS)) - - def axlabels_traj(vars, n_vars): - pyplot.gca().set_xlabel(vars[0]) - pyplot.gca().set_ylabel(vars[1]) - if n_vars > 2: - pyplot.gca().set_zlabel(vars[2]) - - def axlimits_traj(data_lims, n_vars): - pyplot.gca().set_xlim([data_lims[0][0], data_lims[0][1]]) - pyplot.gca().set_ylim([data_lims[1][0], data_lims[1][1]]) - if n_vars > 2: - pyplot.gca().set_zlim([data_lims[2][0], data_lims[2][1]]) - - if n_dims == 2: - plot_lines = lambda x, iTS, colors, labels: \ - plot_traj_2D(x, iTS, colors, labels) - projection = None - elif n_dims == 3: - plot_lines = lambda x, iTS, colors, labels: \ - plot_traj_3D(x, iTS, colors, labels) - projection = '3d' - else: - raise_value_error("Data dimensions are neigher 2D nor 3D!, but " + str(n_dims) + "D!") - n_rows = 1 - n_cols = 1 - if subplots is None: - # if nSamples > 1: - n_rows = int(numpy.floor(numpy.sqrt(nTS))) - n_cols = int(numpy.ceil((1.0 * nTS) / n_rows)) - elif isinstance(subplots, (list, tuple)): - n_rows = subplots[0] - n_cols = subplots[1] - if n_rows * n_cols < nTS: - raise_value_error("Not enough subplots for all time series:" - "\nn_rows * n_cols = product(subplots) = product(" + str(subplots) + " = " - + str(n_rows * n_cols) + "!") - if n_rows * n_cols > 1: - def_alpha = 0.5 - subtitle = lambda labels, iTS: subtitle_traj(labels, iTS) - subtitle_col = lambda subtitles, icol: None - else: - def_alpha = 1.0 - subtitle = lambda labels, iTS: None - subtitle_col = lambda subtitles, icol: pyplot.gca().set_title(pyplot.gcf().title) - axlabels = lambda labels, vars, n_vars, n_rows, irow, iTS: axlabels_traj(vars, n_vars) - axlimits = lambda data_lims, time, n_vars, icol: axlimits_traj(data_lims, n_vars) - loopfun = lambda nTS, n_rows, icol: list(range(icol, nTS, n_rows)) - return data_fun, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \ - subtitle, subtitle_col, axlabels, axlimits - - # TODO: refactor to not have the plot commands here - def plot_ts(self, data, time=None, var_labels=[], mode="ts", subplots=None, special_idx=[], - subtitles=[], labels=[], offset=0.5, time_unit="ms", - title='Time series', figure_name=None, figsize=None): - if not isinstance(figsize, (list, tuple)): - figsize = self.config.figures.LARGE_SIZE - if isinstance(data, dict): - var_labels = data.keys() - data = data.values() - elif isinstance(data, numpy.ndarray): - if len(data.shape) < 3: - if len(data.shape) < 2: - data = numpy.expand_dims(data, 1) - data = numpy.expand_dims(data, 2) - data = [data] - else: - # Assuming a structure of Time X Space X Variables X Samples - data = [data[:, :, iv].squeeze() for iv in range(data.shape[2])] - elif isinstance(data, (list, tuple)): - data = ensure_list(data) - else: - raise_value_error("Input timeseries: %s \n" "is not on of one of the following types: " - "[numpy.ndarray, dict, list, tuple]" % str(data)) - n_vars = len(data) - data_lims = [] - for id, d in enumerate(data): - if isequal_string(mode, "raster"): - data[id] = (d - d.mean(axis=0)) - drange = numpy.max(data[id].max(axis=0) - data[id].min(axis=0)) - data[id] = data[id] / drange # zscore(d, axis=None) - data_lims.append([d.min(), d.max()]) - data_shape = data[0].shape - if len(data_shape) == 1: - n_times = data_shape[0] - nTS = 1 - for iV in range(n_vars): - data[iV] = data[iV][:, numpy.newaxis] - else: - n_times, nTS = data_shape[:2] - if len(data_shape) > 2: - nSamples = data_shape[2] - else: - nSamples = 1 - if special_idx is None: - special_idx = [] - n_special_idx = len(special_idx) - if len(subtitles) == 0: - subtitles = var_labels - if isinstance(labels, list) and len(labels) == n_vars: - labels = [generate_region_labels(nTS, label, ". ", self.print_ts_indices) for label in labels] - else: - labels = [generate_region_labels(nTS, labels, ". ", self.print_ts_indices) for _ in range(n_vars)] - if isequal_string(mode, "traj"): - data_fun, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \ - subtitle, subtitle_col, axlabels, axlimits = \ - self._trajectories_plot(n_vars, nTS, nSamples, subplots) - else: - if isequal_string(mode, "raster"): - data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \ - subtitle, subtitle_col, axlabels, axlimits, axYticks = \ - self._ts_plot(time, n_vars, nTS, n_times, time_unit, 0, offset, data_lims) - - else: - data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun, \ - subtitle, subtitle_col, axlabels, axlimits, axYticks = \ - self._ts_plot(time, n_vars, nTS, n_times, time_unit, ensure_list(subplots)[0]) - alpha_ratio = 1.0 / nSamples - alphas = numpy.maximum(numpy.array([def_alpha] * nTS) * alpha_ratio, 0.1) - alphas[special_idx] = numpy.maximum(alpha_ratio, 0.1) - if isequal_string(mode, "traj") and (n_cols * n_rows > 1): - colors = numpy.zeros((nTS, 4)) - colors[special_idx] = \ - numpy.array([numpy.array([1.0, 0, 0, 1.0]) for _ in range(n_special_idx)]).reshape((n_special_idx, 4)) - else: - cmap = matplotlib.cm.get_cmap('jet') - colors = numpy.array([cmap(0.5 * iTS / nTS) for iTS in range(nTS)]) - colors[special_idx] = \ - numpy.array([cmap(1.0 - 0.25 * iTS / nTS) for iTS in range(n_special_idx)]).reshape((n_special_idx, 4)) - colors[:, 3] = alphas - lines = [] - pyplot.figure(title, figsize=figsize) - axes = [] - for icol in range(n_cols): - if n_rows == 1: - # If there are no more rows, create axis, and set its limits, labels and possible subtitle - axes += ensure_list(pyplot.subplot(n_rows, n_cols, icol + 1, projection=projection)) - axlimits(data_lims, time, n_vars, icol) - axlabels(labels[icol % n_vars], var_labels, n_vars, n_rows, 1, 0) - pyplot.gca().set_title(subtitles[icol]) - for iTS in loopfun(nTS, n_rows, icol): - if n_rows > 1: - # If there are more rows, create axes, and set their limits, labels and possible subtitles - axes += ensure_list(pyplot.subplot(n_rows, n_cols, iTS + 1, projection=projection)) - axlimits(data_lims, time, n_vars, icol) - subtitle(labels[icol % n_vars], iTS) - axlabels(labels[icol % n_vars], var_labels, n_vars, n_rows, (iTS % n_rows) + 1, iTS) - lines += ensure_list(plot_lines(data_fun(data, time, icol), iTS, colors, labels[icol % n_vars])) - if isequal_string(mode, "raster"): # set yticks as labels if this is a raster plot - axYticks(labels[icol % n_vars], nTS) - yticklabels = pyplot.gca().yaxis.get_ticklabels() - self.tick_font_size = numpy.minimum(self.tick_font_size, - int(numpy.round(self.tick_font_size * 100.0 / nTS))) - for iTS in range(nTS): - yticklabels[iTS].set_fontsize(self.tick_font_size) - if iTS in special_idx: - yticklabels[iTS].set_color(colors[iTS, :3].tolist() + [1]) - pyplot.gca().yaxis.set_ticklabels(yticklabels) - pyplot.gca().invert_yaxis() - - if self.config.figures.MOUSE_HOOVER: - for line in lines: - self.HighlightingDataCursor(line, formatter='{label}'.format, bbox=dict(fc='white'), - arrowprops=dict(arrowstyle='simple', fc='white', alpha=0.5)) - - self._save_figure(pyplot.gcf(), figure_name) - self._check_show() - return pyplot.gcf(), axes, lines - - def plot_ts_raster(self, data, time, var_labels=[], time_unit="ms", special_idx=[], - title='Raster plot', subtitles=[], labels=[], offset=0.5, figure_name=None, figsize=None): - if not isinstance(figsize, (list, tuple)): - figsize = self.config.figures.VERY_LARGE_SIZE - return self.plot_ts(data, time, var_labels, "raster", None, special_idx, subtitles, labels, offset, time_unit, - title, figure_name, figsize) - - def plot_ts_trajectories(self, data, var_labels=[], subtitles=None, special_idx=[], labels=[], - title='State space trajectories', figure_name=None, figsize=None): - if not isinstance(figsize, (list, tuple)): - figsize = self.config.figures.LARGE_SIZE - return self.plot_ts(data, [], var_labels, "traj", subtitles, special_idx, labels=labels, title=title, - figure_name=figure_name, figsize=figsize) +from tvb_scripts.datatypes.time_series import TimeSeries - def plot_tvb_time_series(self, time_series, mode="ts", subplots=None, special_idx=[], subtitles=[], - offset=0.5, title=None, figure_name=None, figsize=None): - if not isinstance(figsize, (list, tuple)): - figsize = self.config.figures.LARGE_SIZE - if title is None: - title = time_series.title - variables_labels = time_series.labels_dimensions.get(time_series.labels_ordering[1], []) - space_labels = time_series.labels_dimensions.get(time_series.labels_ordering[2], []) - return self.plot_ts(numpy.swapaxes(time_series.data, 1, 2), time_series.time, variables_labels, - mode, subplots, special_idx, subtitles, space_labels, - offset, time_series.time_unit, title, figure_name, figsize) +class TimeSeriesPlotter(TVBTimeSeriesPlotter): def plot_time_series(self, time_series, mode="ts", subplots=None, special_idx=[], subtitles=[], offset=0.5, title=None, figure_name=None, figsize=None, **kwargs): if isinstance(time_series, TimeSeries): @@ -375,211 +16,25 @@ def plot_time_series(self, time_series, mode="ts", subplots=None, special_idx=[] time_series.time, time_series.variables_labels, mode, subplots, special_idx, subtitles, time_series.space_labels, offset, time_series.time_unit, title, figure_name, figsize) - elif isinstance(time_series, TimeSeriesTVB): - self.plot_tvb_time_series(time_series, mode, subplots, special_idx, - subtitles, offset, title, figure_name, figsize) - elif isinstance(time_series, (numpy.ndarray, dict, list, tuple)): - time = kwargs.get("time", None) - time_unit = kwargs.get("time_unit", "ms") - labels = kwargs.get("labels", []) - var_labels = kwargs.get("var_labels", []) - if title is None: - title = "Time Series" - return self.plot_ts(time_series, time=time, mode=mode, time_unit=time_unit, - labels=labels, var_labels=var_labels, subplots=subplots, special_idx=special_idx, - subtitles=subtitles, offset=offset, title=title, figure_name=figure_name, - figsize=figsize) else: - raise_value_error("Input time_series: %s \n" "is not on of one of the following types: " - "[TimeSeries (tvb-scripts), TimeSeries (TVB), numpy.ndarray, dict]" % str(time_series)) - - def plot_raster(self, time_series, subplots=None, special_idx=[], subtitles=[], - offset=0.5, title=None, figure_name=None, figsize=None, **kwargs): - return self.plot_time_series(time_series, "raster", subplots, special_idx, - subtitles, offset, title, figure_name, figsize, **kwargs) - - def plot_trajectories(self, time_series, subplots=None, special_idx=[], subtitles=[], - offset=0.5, title=None, figure_name=None, figsize=None, **kwargs): - return self.plot_time_series(time_series, "traj", subplots, special_idx, - subtitles, offset, title, figure_name, figsize, **kwargs) - - @staticmethod - def plot_tvb_time_series_interactive(time_series, first_n=-1, **kwargs): - from tvb_scripts.plot.time_series_interactive_plotter import TimeseriesInteractivePlotter - interactive_plotter = TimeseriesInteractivePlotter(time_series=time_series, first_n=first_n) - interactive_plotter.configure() - block = kwargs.pop("block", True) - interactive_plotter.show(block=block, **kwargs) + super(TimeSeriesPlotter, self).plot_time_series(time_series, mode, subplots, special_idx, subtitles, offset, + title, figure_name, figsize, **kwargs) def plot_time_series_interactive(self, time_series, first_n=-1, **kwargs): - if isinstance(time_series, TimeSeriesTVB): - self.plot_tvb_time_series_interactive(time_series, first_n, **kwargs) - elif isinstance(time_series, TimeSeries): + if isinstance(time_series, TimeSeries): self.plot_tvb_time_series_interactive(time_series._tvb, first_n, **kwargs) - elif isinstance(time_series, numpy.ndarray): - self.plot_tvb_time_series_interactive(TimeSeries(data=time_series), first_n, **kwargs) - elif isinstance(time_series, (list, tuple)): - self.plot_tvb_time_series_interactive(TimeSeries(data=TimeSeriesTVB(data=numpy.stack(time_series, axis=1))), - first_n, **kwargs) - elif isinstance(time_series, dict): - ts = numpy.stack(time_series.values(), axis=1) - time_series = TimeSeriesTVB(data=ts, labels_dimensions={"State Variable": time_series.keys()}) - self.plot_tvb_time_series_interactive(time_series, first_n, **kwargs) else: - raise_value_error("Input time_series: %s \n" "is not on of one of the following types: " - "[TimeSeries (tvb-scripts), TimeSeriesTVB (TVB), numpy.ndarray, dict, list, tuple]" % - str(time_series)) - - @staticmethod - def plot_tvb_power_spectra_interactive(time_series, spectral_props, **kwargs): - from tvb.simulator.plot.power_spectra_interactive import PowerSpectraInteractive - interactive_plotters = PowerSpectraInteractive(time_series=time_series, **spectral_props) - interactive_plotters.configure() - block = kwargs.pop("block", True) - interactive_plotters.show(blocl=block, **kwargs) - - def plot_power_spectra_interactive(self, time_series, spectral_props, **kwargs): - self.plot_tvb_power_spectra_interactive(self, time_series._tvb, spectral_props, **kwargs) - - # TODO: refactor to not have the plot commands here - def _plot_ts_spectral_analysis_raster(self, data, time=None, var_label="", time_unit="ms", - freq=None, spectral_options={}, special_idx=[], labels=[], - title='Spectral Analysis', figure_name=None, figsize=None): - if not isinstance(figsize, (list, tuple)): - figsize = self.config.figures.VERY_LARGE_SIZE - if len(data.shape) == 1: - n_times = data.shape[0] - nS = 1 - else: - n_times, nS = data.shape[:2] - time = assert_time(time, n_times, time_unit, self.logger) - if not isinstance(time_unit, string_types): - time_unit = list(time_unit)[0] - time_unit = ensure_string(time_unit) - if time_unit in ("ms", "msec"): - fs = 1000.0 - else: - fs = 1.0 - fs = fs / numpy.mean(numpy.diff(time)) - n_special_idx = len(special_idx) - if n_special_idx > 0: - data = data[:, special_idx] - nS = data.shape[1] - if len(labels) > n_special_idx: - labels = numpy.array([str(ilbl) + ". " + str(labels[ilbl]) for ilbl in special_idx]) - elif len(labels) == n_special_idx: - labels = numpy.array([str(ilbl) + ". " + str(label) for ilbl, label in zip(special_idx, labels)]) - else: - labels = numpy.array([str(ilbl) for ilbl in special_idx]) - else: - if len(labels) != nS: - labels = numpy.array([str(ilbl) for ilbl in range(nS)]) - if nS > 20: - warning("It is not possible to plot spectral analysis plots for more than 20 signals!") - return - - log_norm = spectral_options.get("log_norm", False) - mode = spectral_options.get("mode", "psd") - psd_label = mode - if log_norm: - psd_label = "log" + psd_label - stf, time, freq, psd = time_spectral_analysis(data, fs, - freq=freq, - mode=mode, - nfft=spectral_options.get("nfft"), - window=spectral_options.get("window", 'hanning'), - nperseg=spectral_options.get("nperseg", int(numpy.round(fs / 4))), - detrend=spectral_options.get("detrend", 'constant'), - noverlap=spectral_options.get("noverlap"), - f_low=spectral_options.get("f_low", 10.0), - log_scale=spectral_options.get("log_scale", False)) - min_val = numpy.min(stf.flatten()) - max_val = numpy.max(stf.flatten()) - if nS > 2: - figsize = self.config.figures.VERY_LARGE_SIZE - if len(var_label): - title += ": " % var_label - fig = pyplot.figure(title, figsize=figsize) - fig.suptitle(title) - gs = gridspec.GridSpec(nS, 23) - ax = numpy.empty((nS, 2), dtype="O") - img = numpy.empty((nS,), dtype="O") - line = numpy.empty((nS,), dtype="O") - for iS in range(nS, -1, -1): - if iS < nS - 1: - ax[iS, 0] = pyplot.subplot(gs[iS, :20], sharex=ax[iS, 0]) - ax[iS, 1] = pyplot.subplot(gs[iS, 20:22], sharex=ax[iS, 1], sharey=ax[iS, 0]) - else: - # TODO: find and correct bug here - ax[iS, 0] = pyplot.subplot(gs[iS, :20]) - ax[iS, 1] = pyplot.subplot(gs[iS, 20:22], sharey=ax[iS, 0]) - img[iS] = ax[iS, 0].imshow(numpy.squeeze(stf[:, :, iS]).T, cmap=pyplot.set_cmap('jet'), - interpolation='none', - norm=Normalize(vmin=min_val, vmax=max_val), aspect='auto', origin='lower', - extent=(time.min(), time.max(), freq.min(), freq.max())) - # img[iS].clim(min_val, max_val) - ax[iS, 0].set_title(labels[iS]) - ax[iS, 0].set_ylabel("Frequency (Hz)") - line[iS] = ax[iS, 1].plot(psd[:, iS], freq, 'k', label=labels[iS]) - pyplot.setp(ax[iS, 1].get_yticklabels(), visible=False) - # ax[iS, 1].yaxis.tick_right() - # ax[iS, 1].yaxis.set_ticks_position('both') - if iS == (nS - 1): - ax[iS, 0].set_xlabel("Time (" + time_unit + ")") - - ax[iS, 1].set_xlabel(psd_label) - else: - pyplot.setp(ax[iS, 0].get_xticklabels(), visible=False) - pyplot.setp(ax[iS, 1].get_xticklabels(), visible=False) - ax[iS, 0].autoscale(tight=True) - ax[iS, 1].autoscale(tight=True) - # make a color bar - cax = pyplot.subplot(gs[:, 22]) - pyplot.colorbar(img[0], cax=pyplot.subplot(gs[:, 22])) # fraction=0.046, pad=0.04) #fraction=0.15, shrink=1.0 - cax.set_title(psd_label) - self._save_figure(pyplot.gcf(), figure_name) - self._check_show() - return fig, ax, img, line, time, freq, stf, psd - - def plot_ts_spectral_analysis_raster(self, data, time=None, time_unit="ms", freq=None, spectral_options={}, - special_idx=[], labels=[], title='Spectral Analysis', figure_name=None, - figsize=None): - if isinstance(data, dict): - var_labels = data.keys() - data = data.values() - else: - var_labels = [] - if isinstance(data, (list, tuple)): - data = data[0] - elif isinstance(data, numpy.ndarray) and data.ndim > 2: - # Assuming a structure of Time X Space X Variables X Samples - if data.ndim > 3: - data = data[:, :, :, 0] - data = [data[:, :, iv].squeeze() for iv in range(data.shape[2])] - if len(var_labels) == 0: - var_labels = [""] * len(data) - for d, var_label in zip(data, var_labels): - self._plot_ts_spectral_analysis_raster(d, time, var_label, time_unit, freq, spectral_options, - special_idx, labels, title, figure_name, figsize) + super(TimeSeriesPlotter, self).plot_time_series_interactive(time_series, first_n, **kwargs) def plot_spectral_analysis_raster(self, time_series, freq=None, spectral_options={}, special_idx=[], labels=[], title='Spectral Analysis', figure_name=None, figsize=None, **kwargs): if isinstance(time_series, TimeSeries): return self.plot_ts_spectral_analysis_raster(numpy.swapaxes(time_series._tvb.data, 1, 2).squeeze(), - time_series.time, time_series.time_unit, freq, spectral_options, - special_idx, labels, title, figure_name, figsize) - elif isinstance(time_series, TimeSeriesTVB): - return self.plot_ts_spectral_analysis_raster(numpy.swapaxes(time_series.data, 1, 2).squeeze(), - time_series.time, time_series.time_unit, freq, spectral_options, + time_series.time, time_series.time_unit, freq, + spectral_options, special_idx, labels, title, figure_name, figsize) - elif isinstance(time_series, (numpy.ndarray, dict, list, tuple)): - time = kwargs.get("time", None) - return self.plot_ts_spectral_analysis_raster(time_series, time=time, freq=freq, - spectral_options=spectral_options, special_idx=special_idx, - labels=labels, title=title, figure_name=figure_name, - figsize=figsize) else: - raise_value_error("Input time_series: %s \n" - "is not on of one of the following types: " - "[TimeSeries (tvb-scripts), TimeSeries (TVB), numpy.ndarray, dict]" % str(time_series)) + super(TimeSeriesPlotter, self).plot_spectral_analysis_raster(time_series, freq, spectral_options, + special_idx, labels, title, figure_name, + figsize, **kwargs) diff --git a/tvb_scripts/tests/base.py b/tvb_scripts/tests/base.py index d3ea776..34c2506 100644 --- a/tvb_scripts/tests/base.py +++ b/tvb_scripts/tests/base.py @@ -2,7 +2,8 @@ import os import numpy -from tvb_scripts.config import Config +from tvb.simulator.plot.config import Config + from tvb_scripts.io.h5_reader import H5Reader from tvb_scripts.datatypes.connectivity import Connectivity from tvb_scripts.datatypes.sensors import Sensors diff --git a/tvb_scripts/utils/computations_utils.py b/tvb_scripts/utils/computations_utils.py index 2255206..b8b03b8 100644 --- a/tvb_scripts/utils/computations_utils.py +++ b/tvb_scripts/utils/computations_utils.py @@ -1,14 +1,13 @@ # coding=utf-8 # Some math tools from itertools import product -from sklearn.cluster import AgglomerativeClustering import numpy as np +from sklearn.cluster import AgglomerativeClustering +from tvb.simulator.plot.config import CalculusConfig, FiguresConfig -from tvb_scripts.config import FiguresConfig, CalculusConfig -from tvb_scripts.utils.log_error_utils import initialize_logger, warning from tvb_scripts.utils.data_structures_utils import is_integer - +from tvb_scripts.utils.log_error_utils import initialize_logger, warning logger = initialize_logger(__name__) @@ -84,7 +83,7 @@ def select_greater_values_array_inds(values, threshold=None, percentile=None, nv def select_greater_values_2Darray_inds(values, threshold=None, percentile=None, nvals=None, verbose=False): return np.unravel_index( - select_greater_values_array_inds(values.flatten(), threshold, percentile, nvals, verbose), values.shape) + select_greater_values_array_inds(values.flatten(), threshold, percentile, nvals, verbose), values.shape) def select_by_hierarchical_group_metric_clustering(distance, disconnectivity=np.array([]), metric=None, @@ -102,10 +101,10 @@ def select_by_hierarchical_group_metric_clustering(distance, disconnectivity=np. # ... at least members_per_group elements... n_select = np.minimum(members_per_group, len(cluster_inds)) if metric is not None and len(metric) == distance.shape[0]: - #...optionally according to some metric + # ...optionally according to some metric inds_select = np.argsort(metric[cluster_inds])[-n_select:] else: - #...otherwise, randomly + # ...otherwise, randomly inds_select = range(n_select) selection.append(cluster_inds[inds_select]) return np.unique(np.hstack(selection)).tolist() @@ -187,7 +186,7 @@ def onclick(self, event): def spikes_events_to_time_index(spike_time, time): if spike_time < time[0] or spike_time > time[-1]: warning("Spike time is outside the input time vector!") - return np.argmin(np.abs(time-spike_time)) + return np.argmin(np.abs(time - spike_time)) def compute_spikes_counts(spikes_times, time): diff --git a/tvb_scripts/utils/data_structures_utils.py b/tvb_scripts/utils/data_structures_utils.py index 21e4a48..c551b9e 100644 --- a/tvb_scripts/utils/data_structures_utils.py +++ b/tvb_scripts/utils/data_structures_utils.py @@ -1,14 +1,15 @@ # coding=utf-8 -# Data structure manipulations and conversions -from six import string_types import re -import numpy as np from collections import OrderedDict from copy import deepcopy -from tvb_scripts.utils.log_error_utils import warning, raise_value_error, raise_import_error, initialize_logger -from tvb_scripts.config import CalculusConfig +import numpy as np +# Data structure manipulations and conversions +from six import string_types +from tvb.simulator.plot.config import CalculusConfig + +from tvb_scripts.utils.log_error_utils import warning, raise_value_error, raise_import_error, initialize_logger logger = initialize_logger(__name__) @@ -280,7 +281,7 @@ def ensure_list(arg): arg = [arg] elif hasattr(arg, "__iter__"): arg = list(arg) - else: # if not iterable + else: # if not iterable arg = [arg] except: # if not iterable arg = [arg] @@ -331,7 +332,7 @@ def get_list_or_tuple_item_safely(obj, key): def delete_list_items_by_indices(lin, inds, start_ind=0): lout = [] for ind, l in enumerate(lin): - if ind+start_ind not in inds: + if ind + start_ind not in inds: lout.append(l) return lout @@ -359,6 +360,7 @@ def rotate_n_list_elements(lst, n): old_lst = old_lst[1:] + old_lst[:1] return lst + def linear_index_to_coordinate_tuples(linear_index, shape): if len(linear_index) > 0: coordinates_tuple = np.unravel_index(linear_index, shape) @@ -382,8 +384,8 @@ def find_labels_inds(labels, keys, modefun="find", two_way_search=False, break_a for key in keys: for label in labels: if modefun(label, key): - inds.append(labels.index(label)) - counts += 1 + inds.append(labels.index(label)) + counts += 1 if counts >= break_after: return inds return inds @@ -807,4 +809,3 @@ def property_to_fun(property): return property else: return lambda *args, **kwargs: property - diff --git a/tvb_scripts/utils/log_error_utils.py b/tvb_scripts/utils/log_error_utils.py index 2c8b2b6..718911e 100644 --- a/tvb_scripts/utils/log_error_utils.py +++ b/tvb_scripts/utils/log_error_utils.py @@ -1,11 +1,12 @@ # coding=utf-8 # Logs and errors +import logging import os import sys -import logging from logging.handlers import TimedRotatingFileHandler -from tvb_scripts.config import OutputConfig + +from tvb.simulator.plot.config import OutputConfig def initialize_logger(name, target_folder=OutputConfig().FOLDER_LOGS):