Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 30 additions & 32 deletions scientific_library/tvb/simulator/plot/base_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,29 @@
"""

import os
import numpy
from tvb.basic.logger.builder import get_logger
from tvb.simulator.plot.config import CONFIGURED

import matplotlib
matplotlib.use(CONFIGURED.MATPLOTLIB_BACKEND)

import numpy
from matplotlib import pyplot
pyplot.rcParams["font.size"] = CONFIGURED.FONTSIZE

from mpl_toolkits.axes_grid1 import make_axes_locatable
from tvb.simulator.plot.utils import ensure_list, generate_region_labels
from tvb.basic.logger.builder import get_logger
from tvb.simulator.plot.config import FiguresConfig, Config
from tvb.simulator.plot.utils import generate_region_labels, ensure_list

matplotlib.use(FiguresConfig().MATPLOTLIB_BACKEND)

pyplot.rcParams["font.size"] = FiguresConfig.FONTSIZE


class BasePlotter(object):

def __init__(self, config=CONFIGURED):
self.config = config
def __init__(self, config=None):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous code allowed more unity in the configurations. Currently, the fresh Config() in case of missing init param is always "fresh"

self.config = config or Config()
self.logger = get_logger(self.__class__.__name__)
self.print_regions_indices = True

def _check_show(self):
if self.config.SHOW_FLAG:
if self.config.figures.SHOW_FLAG:
# mp.use('TkAgg')
pyplot.ion()
pyplot.show()
Expand All @@ -66,25 +66,23 @@ def _check_show(self):
pyplot.close()

@staticmethod
def _figure_filename(fig=None, figure_name=None):
if fig is None:
fig = pyplot.gcf()
def _figure_filename(fig=pyplot.gcf(), figure_name=None):
if figure_name is None:
figure_name = fig.get_label()
figure_name = figure_name.replace(": ", "_").replace(" ", "_").replace("\t", "_").replace(",", "")
return figure_name

def _save_figure(self, fig, figure_name=None):
if self.config.SAVE_FLAG:
def _save_figure(self, fig=pyplot.gcf(), figure_name=None):
if self.config.figures.SAVE_FLAG:
figure_name = self._figure_filename(fig, figure_name)
figure_name = figure_name[:numpy.min([100, len(figure_name)])] + '.' + self.config.FIG_FORMAT
figure_dir = self.config.FOLDER_FIGURES
figure_name = figure_name[:numpy.min([100, len(figure_name)])] + '.' + self.config.figures.FIG_FORMAT
figure_dir = self.config.out.FOLDER_FIGURES
if not (os.path.isdir(figure_dir)):
os.mkdir(figure_dir)
pyplot.savefig(os.path.join(figure_dir, figure_name))

@staticmethod
def rect_subplot_shape(self, n, mode="col"):
def rect_subplot_shape(n, mode="col"):
nj = int(numpy.ceil(numpy.sqrt(n)))
ni = int(numpy.ceil(1.0 * n / nj))
if mode.find("row") >= 0:
Expand All @@ -96,7 +94,7 @@ def plot_vector(self, vector, labels, subplot, title, show_y_labels=True, indice
ax = pyplot.subplot(subplot, sharey=sharey)
pyplot.title(title)
n_vector = labels.shape[0]
y_ticks = numpy.array(list(range(n_vector)), dtype=numpy.int32)
y_ticks = numpy.array(range(n_vector), dtype=numpy.int32)
color = 'k'
colors = numpy.repeat([color], n_vector)
coldif = False
Expand Down Expand Up @@ -130,7 +128,7 @@ def plot_vector_violin(self, dataset, vector=[], lines=[], labels=[], subplot=11
ax = pyplot.subplot(subplot, sharey=sharey)
pyplot.title(title)
n_violins = dataset.shape[1]
y_ticks = numpy.array(list(range(n_violins)), dtype=numpy.int32)
y_ticks = numpy.array(range(n_violins), dtype=numpy.int32)
# the vector plot
coldif = False
if indices_red is None:
Expand Down Expand Up @@ -201,7 +199,7 @@ def _plot_matrix(self, matrix, xlabels, ylabels, subplot=111, title="", show_x_l
nticks = []
for ii, (n, tick) in enumerate(zip([nx, ny], ticks)):
if len(tick) == 0:
ticks[ii] = numpy.array(list(range(n)), dtype=numpy.int32)
ticks[ii] = numpy.array(range(n), dtype=numpy.int32)
nticks.append(len(ticks[ii]))
cmap = pyplot.set_cmap(cmap)
img = pyplot.imshow(matrix[ticks[0]][:, ticks[1]].T, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='none')
Expand All @@ -213,10 +211,10 @@ def _plot_matrix(self, matrix, xlabels, ylabels, subplot=111, title="", show_x_l
labels[ii] = generate_region_labels(len(tick), numpy.array(lbls)[tick], ". ",
self.print_regions_indices, tick)
# labels[ii] = numpy.array(["%d. %s" % l for l in zip(tick, lbls[tick])])
getattr(pyplot, xy + "ticks")(numpy.array(list(range(ntick))), labels[ii], rotation=rot)
getattr(pyplot, xy + "ticks")(numpy.array(range(ntick)), labels[ii], rotation=rot)
else:
labels[ii] = numpy.array(["%d." % l for l in tick])
getattr(pyplot, xy + "ticks")(numpy.array(list(range(ntick))), labels[ii])
getattr(pyplot, xy + "ticks")(numpy.array(range(ntick)), labels[ii])
if ind_red is not None:
tck = tick.tolist()
ticklabels = getattr(ax, xy + "axis").get_ticklabels()
Expand All @@ -239,7 +237,7 @@ def plot_regions2regions(self, adj, labels, subplot, title, show_x_labels=True,
x_ticks, y_ticks, indices_red_x, indices_red_y, sharex, sharey, cmap, vmin, vmax)

def _set_axis_labels(self, fig, sub, n_regions, region_labels, indices2emphasize, color='k', position='left'):
y_ticks = list(range(n_regions))
y_ticks = range(n_regions)
# region_labels = numpy.array(["%d. %s" % l for l in zip(y_ticks, region_labels)])
region_labels = generate_region_labels(len(y_ticks), region_labels, ". ", self.print_regions_indices, y_ticks)
big_ax = fig.add_subplot(sub, frameon=False)
Expand All @@ -262,7 +260,7 @@ def plot_in_columns(self, data_dict_list, labels, width_ratios=[], left_ax_focus
right_ax_focus_indices=[], description="", title="", figure_name=None,
figsize=None, **kwargs):
if not isinstance(figsize, (tuple, list)):
figsize = self.config.VERY_LARGE_SIZE
figsize = self.config.figures.VERY_LARGE_SIZE
fig = pyplot.figure(title, frameon=False, figsize=figsize)
fig.suptitle(description)
n_subplots = len(data_dict_list)
Expand Down Expand Up @@ -311,7 +309,7 @@ def plot_in_columns(self, data_dict_list, labels, width_ratios=[], left_ax_focus
def plots(self, data_dict, shape=None, transpose=False, skip=0, xlabels={}, xscales={}, yscales={}, title='Plots',
lgnd={}, figure_name=None, figsize=None):
if not isinstance(figsize, (tuple, list)):
figsize = self.config.VERY_LARGE_SIZE
figsize = self.config.figures.VERY_LARGE_SIZE
if shape is None:
shape = (1, len(data_dict))
fig, axes = pyplot.subplots(shape[0], shape[1], figsize=figsize)
Expand Down Expand Up @@ -343,7 +341,7 @@ def confirm_y_coordinate(data, ymax):
return tuple(data)

if not isinstance(figsize, (tuple, list)):
figsize = self.config.VERY_LARGE_SIZE
figsize = self.config.figures.VERY_LARGE_SIZE

if subtitles is None:
subtitles = keys
Expand Down Expand Up @@ -433,7 +431,7 @@ def barlabel(ax, rects, positions):

if fig is None:
if not isinstance(figsize, (tuple, list)):
figsize = self.config.VERY_LARGE_SIZE
figsize = self.config.figures.VERY_LARGE_SIZE
fig, ax = pyplot.subplots(1, 1, figsize=figsize)
show_and_save = True
else:
Expand All @@ -442,8 +440,8 @@ def barlabel(ax, rects, positions):
ax = pyplot.gca()
if isinstance(data, (list, tuple)): # If, there are many groups, data is a list:
# Fill in with nan in case that not all groups have the same number of elements
from itertools import zip_longest
data = numpy.array(list(zip_longest(*ensure_list(data), fillvalue=numpy.nan))).T
from itertools import izip_longest
data = numpy.array(list(izip_longest(*ensure_list(data), fillvalue=numpy.nan))).T
elif data.ndim == 1: # This is the case where there is only one group...
data = numpy.expand_dims(data, axis=1).T
n_groups, n_elements = data.shape
Expand Down Expand Up @@ -479,7 +477,7 @@ def barlabel(ax, rects, positions):
self._check_show()
return fig, ax

def plot(self, plot_fun_name, *args, **kwargs):
def tvb_plot(self, plot_fun_name, *args, **kwargs):
import tvb.simulator.plot.tools as TVB_plot_tools
getattr(TVB_plot_tools, plot_fun_name)(*args, **kwargs)
fig = pyplot.gcf()
Expand Down
134 changes: 104 additions & 30 deletions scientific_library/tvb/simulator/plot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,47 +34,63 @@
"""

import os
import numpy as np
from datetime import datetime

import numpy


class GenericConfig(object):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should not be here!!!!

_module_path = os.path.dirname(__file__)

# Identify and choose the Simulator, or data folder type to read.
MODE_H5 = "H5"
MODE_TVB = "TVB"


class InputConfig(object):
_base_input = os.getcwd()

@property
def HEAD(self):
if self._head_folder is not None:
return self._head_folder

# or else, try to find tvb_data module
try:
import tvb_data
return os.path.dirname(tvb_data.__file__)
except ImportError:
return self._base_input

@property
def IS_TVB_MODE(self):
"""Identify and choose the Input data type to use"""
return self._data_mode == GenericConfig.MODE_TVB

@property
def RAW_DATA_FOLDER(self):
if self._raw_data is not None:
return self._raw_data

return os.path.join(self._base_input, "data", "raw")

def __init__(self, head_folder=None, raw_folder=None, data_mode=GenericConfig.MODE_TVB):
self._head_folder = head_folder
self._raw_data = raw_folder
self._data_mode = data_mode

class FiguresConfig(object):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have here under tvb.simulator.PLOT just a configuration for Figures. This is the single purpose for this module

VERY_LARGE_SIZE = (40, 20)
VERY_LARGE_PORTRAIT = (30, 50)
SUPER_LARGE_SIZE = (80, 40)
LARGE_SIZE = (20, 15)
SMALL_SIZE = (15, 10)
NOTEBOOK_SIZE = (10, 7)
FIG_FORMAT = 'png'
SAVE_FLAG = True
SHOW_FLAG = False
MOUSE_HOOVER = False
MATPLOTLIB_BACKEND = "Agg" # "Qt4Agg"
FONTSIZE = 10
WEIGHTS_NORM_PERCENT = 99
MAX_SINGLE_VALUE = np.finfo("single").max
MAX_INT_VALUE = np.iinfo(np.int64).max

class OutputConfig(object):
subfolder = None

def __init__(self, out_base=None, separate_by_run=False):
print(out_base)
print(separate_by_run)
"""
:param work_folder: Base folder where logs/figures/results should be kept
:param separate_by_run: Set TRUE, when you want logs/results/figures to be in different files / each run
"""
self._out_base = out_base or os.path.join(os.getcwd(), "outputs")
self._separate_by_run = separate_by_run

def largest_size(self):
import sys
if 'IPython' not in sys.modules:
return self.LARGE_SIZE
from IPython import get_ipython
if getattr(get_ipython(), 'kernel', None) is not None:
return self.NOTEBOOK_SIZE
else:
return self.LARGE_SIZE

@property
def FOLDER_LOGS(self):
folder = os.path.join(self._out_base, "logs")
Expand All @@ -91,6 +107,8 @@ def FOLDER_RES(self):
folder = folder + datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')
if not (os.path.isdir(folder)):
os.makedirs(folder)
if self.subfolder is not None:
folder = os.path.join(folder, self.subfolder)
return folder

@property
Expand All @@ -100,7 +118,63 @@ def FOLDER_FIGURES(self):
folder = folder + datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')
if not (os.path.isdir(folder)):
os.makedirs(folder)
if self.subfolder is not None:
os.path.join(folder, self.subfolder)
return folder

@property
def FOLDER_TEMP(self):
return os.path.join(self._out_base, "temp")


class FiguresConfig(object):
VERY_LARGE_SIZE = (40, 20)
VERY_LARGE_PORTRAIT = (30, 50)
SUPER_LARGE_SIZE = (80, 40)
LARGE_SIZE = (20, 15)
SMALL_SIZE = (15, 10)
NOTEBOOK_SIZE = (20, 10)
FIG_FORMAT = 'png'
SAVE_FLAG = True
SHOW_FLAG = False
MOUSE_HOOVER = False
MATPLOTLIB_BACKEND = "Agg" # "Qt4Agg"
FONTSIZE = 10
SMALL_FONTSIZE = 8
LARGE_FONTSIZE = 12

def largest_size(self):
import sys
if 'IPython' not in sys.modules:
return self.LARGE_SIZE
from IPython import get_ipython
if getattr(get_ipython(), 'kernel', None) is not None:
return self.NOTEBOOK_SIZE
else:
return self.LARGE_SIZE


class CalculusConfig(object):
# Normalization configuration
WEIGHTS_NORM_PERCENT = 99

# If True a plot will be generated to choose the number of eigenvalues to keep
INTERACTIVE_ELBOW_POINT = False

MIN_SINGLE_VALUE = numpy.finfo("single").min
MAX_SINGLE_VALUE = numpy.finfo("single").max
MAX_INT_VALUE = numpy.iinfo(numpy.int64).max
MIN_INT_VALUE = numpy.iinfo(numpy.int64).max


class Config(object):
generic = GenericConfig()
figures = FiguresConfig()
calcul = CalculusConfig()

def __init__(self, head_folder=None, raw_data_folder=None, output_base=None, separate_by_run=False):
self.input = InputConfig(head_folder, raw_data_folder)
self.out = OutputConfig(output_base, separate_by_run)


CONFIGURED = FiguresConfig()
CONFIGURED = Config()
Loading