@@ -3605,8 +3605,8 @@ function plot_solution(𝓂::ℳ,
36053605
36063606 var_state_range = hcat (var_state_range... )
36073607
3608- variable_output = Dict ()
3609- has_impact = Dict ()
3608+ variable_output = []
3609+ has_impact = []
36103610
36113611 for k in vars_to_plot
36123612 idx = indexin ([k], 𝓂. var)
@@ -3636,7 +3636,7 @@ function plot_solution(𝓂::ℳ,
36363636 :variable_output => variable_output,
36373637 :has_impact => has_impact,
36383638 :vars_to_plot => vars_to_plot,
3639- :full_SS_current => full_SS_current[indexin (vars_to_plot, 𝓂. var)],
3639+ :full_SS_current => full_SS_current[indexin (sort ( vcat (state, vars_to_plot)) , 𝓂. var)],
36403640 :algorithm_label => labels[algorithm][1 ],
36413641 :ss_label => labels[algorithm][2 ],
36423642 :rename_dictionary => processed_rename_dictionary)
@@ -3675,9 +3675,9 @@ function _plot_solution_from_container(;
36753675 model_name = first_container[:model_name ]
36763676
36773677 # Collect all unique states from containers
3678- joint_states = OrderedSet {Symbol } ()
3678+ joint_states = OrderedSet {String } ()
36793679 for container in solution_active_plot_container
3680- push! (joint_states, container[:state ])
3680+ push! (joint_states, string ( apply_custom_name .( container[:state ], Ref ( Dict (container[ :rename_dictionary ])))) )
36813681 end
36823682
36833683 gr_back = StatsPlots. backend () == StatsPlots. Plots. GRBackend ()
@@ -3749,6 +3749,8 @@ function _plot_solution_from_container(;
37493749 model_names = unique (model_names)
37503750
37513751 for model in model_names
3752+ # println(grouped_by_model[model])
3753+ # println(typeof(grouped_by_model[model][:has_impact]))
37523754 if length (grouped_by_model[model]) > 1
37533755 diffdict_grouped = compare_args_and_kwargs (grouped_by_model[model])
37543756 diffdict = merge_by_runid (diffdict, diffdict_grouped)
@@ -3880,26 +3882,29 @@ function _plot_solution_from_container(;
38803882 end
38813883
38823884 # Collect all variables to plot across all containers
3883- all_vars = OrderedSet {Symbol } ()
3885+ all_vars = OrderedSet {String } ()
38843886 for container in solution_active_plot_container
3885- foreach (v -> push! (all_vars, v), container[:vars_to_plot ])
3887+ foreach (v -> push! (all_vars, v), string .( apply_custom_name .( container[:vars_to_plot ], Ref ( Dict (container[ :rename_dictionary ])))) )
38863888 end
38873889
38883890 return_plots = []
38893891
38903892 # Loop over each state (similar to how plot_irf loops over shocks)
38913893 for state in joint_states
38923894 # Filter containers for this state
3893- state_containers = [c for c in solution_active_plot_container if c[:state ] == state]
3895+ state_containers = [c for c in solution_active_plot_container if string ( apply_custom_name .( c[:state ], Ref ( Dict (c[ :rename_dictionary ])))) == state]
38943896
38953897 # Determine which variables have impact in at least one container for this state
38963898 vars_with_impact = []
3897- for var in all_vars
3899+ for var in setdiff ( all_vars, joint_states)
38983900 has_any_impact = false
38993901 for container in state_containers
3900- if haskey (container[:has_impact ], var) && container[:has_impact ][var]
3901- has_any_impact = true
3902- break
3902+ for (k,v) in Dict (container[:has_impact ])
3903+ k_trans = string (apply_custom_name (k, (Dict (container[:rename_dictionary ]))))
3904+ if k_trans == var && v
3905+ has_any_impact = true
3906+ break
3907+ end
39033908 end
39043909 end
39053910 if has_any_impact
@@ -3915,33 +3920,76 @@ function _plot_solution_from_container(;
39153920 # Plot each variable for this state
39163921 for k in vars_with_impact
39173922 Pl = StatsPlots. plot ()
3918-
3923+
3924+
39193925 # Plot line for each container with this state
39203926 for (i, container) in enumerate (solution_active_plot_container)
3921- if container[:state ] == state && haskey (container[:variable_output ], k) && container[:has_impact ][k]
3927+ # return the key that corresponds to k in the original variable_output dictionary
3928+ original_k_variable_output = nothing
3929+ for key in keys (Dict (container[:variable_output ]))
3930+ if string (apply_custom_name (key, (Dict (container[:rename_dictionary ])))) == k
3931+ original_k_variable_output = key
3932+ break
3933+ end
3934+ end
3935+
3936+ # return the key that corresponds to k in the original has_impact dictionary
3937+ original_k_has_impact = nothing
3938+ for key in keys (Dict (container[:has_impact ]))
3939+ if string (apply_custom_name (key, (Dict (container[:rename_dictionary ])))) == k
3940+ original_k_has_impact = key
3941+ break
3942+ end
3943+ end
3944+
3945+ if string (apply_custom_name .(container[:state ], Ref (Dict (container[:rename_dictionary ])))) == state && ! isnothing (original_k_variable_output) && ! isnothing (original_k_has_impact)
3946+ # Create concatenated transformed variable names for indexing
3947+ concat_trans_vars = string .(apply_custom_name .(sort (vcat (container[:vars_to_plot ], container[:state ])), Ref (Dict (container[:rename_dictionary ]))))
3948+
39223949 # Find state index in vars_to_plot
3923- state_idx = findfirst (== (state), container[ :vars_to_plot ] )
3950+ state_idx = findfirst (== (state), concat_trans_vars )
39243951 if ! isnothing (state_idx)
39253952 state_ss = container[:full_SS_current ][state_idx]
39263953 else
39273954 state_ss = 0.0 # fallback
39283955 end
3929-
3956+
39303957 StatsPlots. plot! (container[:state_range ] .+ state_ss,
3931- container[:variable_output ][k ][1 ,:],
3932- ylabel = replace_indices_in_symbol (k )* " ₍₀₎" ,
3933- xlabel = replace_indices_in_symbol (state)* " ₍₋₁₎" ,
3958+ Dict ( container[:variable_output ])[original_k_variable_output ][1 ,:],
3959+ ylabel = replace_indices_in_symbol (Symbol (k) )* " ₍₀₎" ,
3960+ xlabel = replace_indices_in_symbol (Symbol ( state) )* " ₍₋₁₎" ,
39343961 color = pal[mod1 (i, length (pal))],
39353962 label = " " )
39363963 end
39373964 end
39383965
39393966 # Plot SS markers for each container with this state
39403967 for (i, container) in enumerate (solution_active_plot_container)
3941- if container[:state ] == state && haskey (container[:variable_output ], k) && container[:has_impact ][k]
3968+ # return the key that corresponds to k in the original variable_output dictionary
3969+ original_k_variable_output = nothing
3970+ for key in keys (Dict (container[:variable_output ]))
3971+ if string (apply_custom_name (key, (Dict (container[:rename_dictionary ])))) == k
3972+ original_k_variable_output = key
3973+ break
3974+ end
3975+ end
3976+
3977+ # return the key that corresponds to k in the original has_impact dictionary
3978+ original_k_has_impact = nothing
3979+ for key in keys (Dict (container[:has_impact ]))
3980+ if string (apply_custom_name (key, (Dict (container[:rename_dictionary ])))) == k
3981+ original_k_has_impact = key
3982+ break
3983+ end
3984+ end
3985+
3986+ if string (apply_custom_name .(container[:state ], Ref (Dict (container[:rename_dictionary ])))) == state && ! isnothing (original_k_variable_output) && ! isnothing (original_k_has_impact)
3987+ # Create concatenated transformed variable names for indexing
3988+ concat_trans_vars = string .(apply_custom_name .(sort (vcat (container[:vars_to_plot ], container[:state ])), Ref (Dict (container[:rename_dictionary ]))))
3989+
39423990 # Get state and variable indices
3943- state_idx = findfirst (== (state), container[ :vars_to_plot ] )
3944- var_idx = findfirst (== (k), container[ :vars_to_plot ] )
3991+ state_idx = findfirst (== (state), concat_trans_vars )
3992+ var_idx = findfirst (== (k), concat_trans_vars )
39453993
39463994 if ! isnothing (state_idx) && ! isnothing (var_idx)
39473995 state_ss = container[:full_SS_current ][state_idx]
@@ -4258,8 +4306,8 @@ function plot_solution!(𝓂::ℳ,
42584306
42594307 var_state_range = hcat (var_state_range... )
42604308
4261- variable_output = Dict ()
4262- has_impact = Dict ()
4309+ variable_output = []
4310+ has_impact = []
42634311
42644312 for k in vars_to_plot
42654313 idx = indexin ([k], 𝓂. var)
@@ -4289,7 +4337,7 @@ function plot_solution!(𝓂::ℳ,
42894337 :variable_output => variable_output,
42904338 :has_impact => has_impact,
42914339 :vars_to_plot => vars_to_plot,
4292- :full_SS_current => full_SS_current[indexin (vars_to_plot, 𝓂. var)],
4340+ :full_SS_current => full_SS_current[indexin (sort ( vcat (state, vars_to_plot)) , 𝓂. var)],
42934341 :algorithm_label => labels[algorithm][1 ],
42944342 :ss_label => labels[algorithm][2 ],
42954343 :rename_dictionary => processed_rename_dictionary)
0 commit comments