33# Document Author
44# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
55import warnings
6- import copy
7- import pickle
86import numpy as np
97import matplotlib .pyplot as plt
108from matplotlib .colors import rgb2hex
@@ -202,19 +200,19 @@ def _gen_sample_recursion(self,node,x):
202200 else :
203201 return self ._gen_sample_recursion (node .children [x [- node .depth - 1 ]],x )
204202
205- def _visualize_model_recursion (self ,tree_graph ,node :_Node ,node_id ,parent_id ,sibling_num ,p_v ):
203+ def _visualize_model_recursion (self ,tree_graph ,node :_Node ,node_id ,parent_id ,sibling_num ,p_s ):
206204 tmp_id = node_id
207- tmp_p_v = p_v
205+ tmp_p_s = p_s
208206
209207 # add node information
210- label_string = f'h_g={ node .h_g :.2f} \\ lp_v= { tmp_p_v :.2f} \\ ltheta_vec\\ l='
208+ label_string = f'h_g={ node .h_g :.2f} \\ lp_s= { tmp_p_s :.2f} \\ ltheta_vec\\ l='
211209 if node .leaf :
212210 label_string += f'{ np .array2string (node .theta_vec ,precision = 2 ,max_line_width = 11 )} \\ l'
213211 else :
214212 label_string += 'None\\ l'
215213
216- tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_v ))} ' )
217- if tmp_p_v > 0.65 :
214+ tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_s ))} ' )
215+ if tmp_p_s > 0.65 :
218216 tree_graph .node (name = f'{ tmp_id } ' ,fontcolor = 'white' )
219217
220218 # add edge information
@@ -223,7 +221,7 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
223221
224222 if node .leaf != True :
225223 for i in range (self .c_k ):
226- node_id = self ._visualize_model_recursion (tree_graph ,node .children [i ],node_id + 1 ,tmp_id ,i ,tmp_p_v * node .h_g )
224+ node_id = self ._visualize_model_recursion (tree_graph ,node .children [i ],node_id + 1 ,tmp_id ,i ,tmp_p_s * node .h_g )
227225
228226 return node_id
229227
@@ -307,7 +305,7 @@ def set_params(self,root=None):
307305 if root is not None :
308306 if type (root ) is not _Node :
309307 raise (ParameterFormatError (
310- "root must be an instance of metatree ._Node"
308+ "root must be an instance of contexttree ._Node"
311309 ))
312310 self ._set_params_recursion (self .root ,root )
313311 return self
@@ -794,7 +792,7 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
794792 Loss function underlying the Bayes risk function, by default ``\" 0-1\" ``.
795793 This function supports only ``\" 0-1\" ``.
796794 visualize : bool, optional
797- If ``True``, the estimated metatree will be visualized, by default ``True``.
795+ If ``True``, the estimated context tree model will be visualized, by default ``True``.
798796 This visualization requires ``graphviz``.
799797 filename : str, optional
800798 Filename for saving the figure, by default ``None``
@@ -804,8 +802,8 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
804802
805803 Returns
806804 -------
807- map_root : metatree ._Node
808- The root node of the estimated meta- tree
805+ map_root : contexttree ._Node
806+ The root node of the estimated context tree model
809807 that also contains the estimated parameters in each node.
810808
811809 See Also
@@ -824,32 +822,34 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
824822 import graphviz
825823 tree_graph = graphviz .Digraph (filename = filename ,format = format )
826824 tree_graph .attr ("node" ,shape = "box" ,fontname = "helvetica" ,style = "rounded,filled" )
827- self ._visualize_model_recursion (tree_graph , map_root , 0 , None , None , 1.0 , True )
825+ self ._visualize_model_recursion (tree_graph , map_root , 0 , None , None , 1.0 , True , False )
828826 # Can we show the image on the console without saving the file?
829827 tree_graph .view ()
830828 return map_root
831829 else :
832830 raise (CriteriaError ("Unsupported loss function! "
833831 + "This function supports only \" 0-1\" ." ))
834832
835- def _visualize_model_recursion (self ,tree_graph ,node :_Node ,node_id ,parent_id ,sibling_num ,p_v ,map_tree ):
833+ def _visualize_model_recursion (self ,tree_graph ,node :_Node ,node_id ,parent_id ,sibling_num ,p_s ,map_tree , h_params ):
836834 tmp_id = node_id
837- tmp_p_v = p_v
835+ tmp_p_s = p_s
838836
839837 # add node information
840- label_string = f'hn_g={ node .h_g :.2f} \\ lp_v= { tmp_p_v :.2f} \\ ltheta_vec \\ l= '
838+ label_string = f'hn_g={ node .h_g :.2f} \\ lp_s= { tmp_p_s :.2f} \\ l '
841839 if map_tree and not node .leaf :
842- label_string += 'None\\ l'
840+ label_string += 'theta_vec \\ l= None\\ l'
843841 else :
844- if np .all (node .h_beta_vec > 1 ):
842+ if h_params :
843+ label_string += f'hn_beta_vec\\ l={ np .array2string (node .h_beta_vec ,precision = 2 ,max_line_width = 11 )} \\ l'
844+ elif np .all (node .h_beta_vec > 1 ):
845845 theta_vec_hat = (node .h_beta_vec - 1 ) / (np .sum (node .h_beta_vec ) - self .c_k )
846- label_string += f'{ np .array2string (theta_vec_hat ,precision = 2 ,max_line_width = 11 )} \\ l'
846+ label_string += f'theta_vec \\ l= { np .array2string (theta_vec_hat ,precision = 2 ,max_line_width = 11 )} \\ l'
847847 else :
848- warnings .warn ("MAP estimate of theta_vec doesn't exist for the current h_beta_vec ." ,ResultWarning )
849- label_string += 'None\\ l'
848+ warnings .warn ("MAP estimate of theta_vec doesn't exist for the current hn_beta_vec ." ,ResultWarning )
849+ label_string += 'theta_vec \\ l= None\\ l'
850850
851- tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_v ))} ' )
852- if tmp_p_v > 0.65 :
851+ tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_s ))} ' )
852+ if tmp_p_s > 0.65 :
853853 tree_graph .node (name = f'{ tmp_id } ' ,fontcolor = 'white' )
854854
855855 # add edge information
@@ -858,29 +858,31 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
858858
859859 for i in range (self .c_k ):
860860 if node .children [i ] is not None :
861- node_id = self ._visualize_model_recursion (tree_graph ,node .children [i ],node_id + 1 ,tmp_id ,i ,tmp_p_v * node .h_g ,map_tree )
861+ node_id = self ._visualize_model_recursion (tree_graph ,node .children [i ],node_id + 1 ,tmp_id ,i ,tmp_p_s * node .h_g ,map_tree , h_params )
862862
863863 return node_id
864864
865- def _visualize_model_recursion_none (self ,tree_graph ,depth ,node_id ,parent_id ,sibling_num ,p_v ):
865+ def _visualize_model_recursion_none (self ,tree_graph ,depth ,node_id ,parent_id ,sibling_num ,p_s , h_params ):
866866 tmp_id = node_id
867- tmp_p_v = p_v
867+ tmp_p_s = p_s
868868
869869 # add node information
870870 if depth == self .c_d_max :
871871 label_string = 'hn_g=0.0\\ l'
872872 else :
873873 label_string = f'hn_g={ self .hn_g :.2f} \\ l'
874- label_string += f'p_v={ tmp_p_v :.2f} \\ ltheta_vec\\ l='
875- if np .all (self .hn_beta_vec > 1 ):
874+ label_string += f'p_s={ tmp_p_s :.2f} \\ l'
875+ if h_params :
876+ label_string += f'hn_beta_vec\\ l={ np .array2string (self .hn_beta_vec ,precision = 2 ,max_line_width = 11 )} \\ l'
877+ elif np .all (self .hn_beta_vec > 1 ):
876878 theta_vec_hat = (self .hn_beta_vec - 1 ) / (np .sum (self .hn_beta_vec ) - self .c_k )
877- label_string += f'{ np .array2string (theta_vec_hat ,precision = 2 ,max_line_width = 11 )} \\ l'
879+ label_string += f'theta_vec \\ l= { np .array2string (theta_vec_hat ,precision = 2 ,max_line_width = 11 )} \\ l'
878880 else :
879- warnings .warn ("MAP estimate of theta_vec doesn't exist for the current h_beta_vec ." ,ResultWarning )
880- label_string += 'None\\ l'
881+ warnings .warn ("MAP estimate of theta_vec doesn't exist for the current hn_beta_vec ." ,ResultWarning )
882+ label_string += 'theta_vec \\ l= None\\ l'
881883
882- tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_v ))} ' )
883- if tmp_p_v > 0.65 :
884+ tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_s ))} ' )
885+ if tmp_p_s > 0.65 :
884886 tree_graph .node (name = f'{ tmp_id } ' ,fontcolor = 'white' )
885887
886888 # add edge information
@@ -889,11 +891,11 @@ def _visualize_model_recursion_none(self,tree_graph,depth,node_id,parent_id,sibl
889891
890892 if depth < self .c_d_max :
891893 for i in range (self .c_k ):
892- node_id = self ._visualize_model_recursion_none (tree_graph ,depth + 1 ,node_id + 1 ,tmp_id ,i ,tmp_p_v * self .hn_g )
894+ node_id = self ._visualize_model_recursion_none (tree_graph ,depth + 1 ,node_id + 1 ,tmp_id ,i ,tmp_p_s * self .hn_g , h_params )
893895
894896 return node_id
895897
896- def visualize_posterior (self ,filename = None ,format = None ):
898+ def visualize_posterior (self ,filename = None ,format = None , h_params = False ):
897899 """Visualize the posterior distribution for the parameter.
898900
899901 This method requires ``graphviz``.
@@ -904,13 +906,16 @@ def visualize_posterior(self,filename=None,format=None):
904906 Filename for saving the figure, by default ``None``
905907 format : str, optional
906908 Rendering output format (``\" pdf\" ``, ``\" png\" ``, ...).
909+ h_params : bool, optional
910+ If ``True``, hyperparameters at each node will be visualized.
911+ if ``False``, estimated parameters at each node will be visulaized.
907912
908913 Examples
909914 --------
910915 >>> from bayesml import contexttree
911916 >>> gen_model = contexttree.GenModel(c_k=2,c_d_max=3,h_g=0.75)
912917 >>> gen_model.gen_params()
913- >>> x = gen_model.gen_sample(50 )
918+ >>> x = gen_model.gen_sample(500 )
914919 >>> learn_model = contexttree.LearnModel(c_k=2,c_d_max=3,h0_g=0.75)
915920 >>> learn_model.update_posterior(x)
916921 >>> learn_model.visualize_posterior()
@@ -926,9 +931,9 @@ def visualize_posterior(self,filename=None,format=None):
926931 tree_graph = graphviz .Digraph (filename = filename ,format = format )
927932 tree_graph .attr ("node" ,shape = "box" ,fontname = "helvetica" ,style = "rounded,filled" )
928933 if self .hn_root is None :
929- self ._visualize_model_recursion_none (tree_graph , 0 , 0 , None , None , 1.0 , False )
934+ self ._visualize_model_recursion_none (tree_graph , 0 , 0 , None , None , 1.0 , h_params )
930935 else :
931- self ._visualize_model_recursion (tree_graph , self .hn_root , 0 , None , None , 1.0 , False )
936+ self ._visualize_model_recursion (tree_graph , self .hn_root , 0 , None , None , 1.0 , False , h_params )
932937 # Can we show the image on the console without saving the file?
933938 tree_graph .view ()
934939 except ImportError as e :
0 commit comments