11import numpy as np
22import plotly .graph_objects as go
3-
3+ from mpl_toolkits .axes_grid1 import make_axes_locatable
4+ import matplotlib .pyplot as plt
45import maxplotlib .subfigure .tikz_figure as tf
56
67class Node :
@@ -44,7 +45,7 @@ def __init__(self, **kwargs):
4445 self ._description = kwargs .get ("description" , None )
4546 self ._label = kwargs .get ("label" , None )
4647 self ._grid = kwargs .get ("grid" , False )
47- self ._legend = kwargs .get ("legend" , True )
48+ self ._legend = kwargs .get ("legend" , False )
4849
4950 self ._xlabel = kwargs .get ("xlabel" , None )
5051 self ._ylabel = kwargs .get ("ylabel" , None )
@@ -69,6 +70,13 @@ def __init__(self, **kwargs):
6970 def add_caption (self , caption ):
7071 self ._caption = caption
7172
73+ def _add (self , obj , layer ):
74+ self .line_data .append (obj )
75+ if layer in self .layered_line_data :
76+ self .layered_line_data [layer ].append (obj )
77+ else :
78+ self .layered_line_data [layer ] = [obj ]
79+
7280 def add_line (self , x_data , y_data , layer = 0 , plot_type = 'plot' , ** kwargs ):
7381 """
7482 Add a line to the plot.
@@ -86,11 +94,7 @@ def add_line(self, x_data, y_data, layer=0, plot_type='plot', **kwargs):
8694 "plot_type" : plot_type ,
8795 "kwargs" : kwargs ,
8896 }
89- self .line_data .append (ld )
90- if layer in self .layered_line_data :
91- self .layered_line_data [layer ].append (ld )
92- else :
93- self .layered_line_data [layer ] = [ld ]
97+ self ._add (ld , layer )
9498
9599 def add_imshow (self , data , layer = 0 , plot_type = 'imshow' , ** kwargs ):
96100 ld = {
@@ -99,11 +103,25 @@ def add_imshow(self, data, layer=0, plot_type='imshow', **kwargs):
99103 "plot_type" : plot_type ,
100104 "kwargs" : kwargs ,
101105 }
102- self .line_data .append (ld )
103- if layer in self .layered_line_data :
104- self .layered_line_data [layer ].append (ld )
105- else :
106- self .layered_line_data [layer ] = [ld ]
106+ self ._add (ld , layer )
107+
108+ def add_patch (self , patch , layer = 0 , plot_type = 'patch' , ** kwargs ):
109+ ld = {
110+ "patch" : patch ,
111+ "layer" : layer ,
112+ "plot_type" : plot_type ,
113+ "kwargs" : kwargs ,
114+ }
115+ self ._add (ld , layer )
116+
117+ def add_colorbar (self , label = "" , layer = 0 , plot_type = 'colorbar' , ** kwargs ):
118+ cb = {
119+ "label" : label ,
120+ "layer" : layer ,
121+ "plot_type" : plot_type ,
122+ "kwargs" : kwargs ,
123+ }
124+ self ._add (cb , layer )
107125
108126 @property
109127 def layers (self ):
@@ -136,12 +154,18 @@ def plot_matplotlib(self, ax, layers=None):
136154 ** line ["kwargs" ],
137155 )
138156 elif line ["plot_type" ] == "imshow" :
139- ax .imshow (
157+ im = ax .imshow (
140158 line ["data" ],
141159 ** line ["kwargs" ],
142160 )
143- # if self._caption:
144- # ax.set_title(self._caption)
161+ elif line ["plot_type" ] == "patch" :
162+ ax .add_patch (line ["patch" ],
163+ ** line ["kwargs" ],
164+ )
165+ elif line ["plot_type" ] == "colorbar" :
166+ divider = make_axes_locatable (ax )
167+ cax = divider .append_axes ("right" , size = "5%" , pad = 0.05 )
168+ plt .colorbar (im , cax = cax , label = "Potential (V)" )
145169 if self ._title :
146170 ax .set_title (self ._title )
147171 if self ._label :
0 commit comments