Skip to content

Commit 256169d

Browse files
committed
Refactor plot_solution and _plot_solution_from_container to use arrays instead of dictionaries for variable_output and has_impact
1 parent 9cdc8a9 commit 256169d

File tree

1 file changed

+73
-25
lines changed

1 file changed

+73
-25
lines changed

ext/StatsPlotsExt.jl

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3605,8 +3605,8 @@ function plot_solution(𝓂::ℳ,
36053605

36063606
var_state_range = hcat(var_state_range...)
36073607

3608-
variable_output = Dict()
3609-
has_impact = Dict()
3608+
variable_output = []
3609+
has_impact = []
36103610

36113611
for k in vars_to_plot
36123612
idx = indexin([k], 𝓂.var)
@@ -3636,7 +3636,7 @@ function plot_solution(𝓂::ℳ,
36363636
:variable_output => variable_output,
36373637
:has_impact => has_impact,
36383638
:vars_to_plot => vars_to_plot,
3639-
:full_SS_current => full_SS_current[indexin(vars_to_plot, 𝓂.var)],
3639+
:full_SS_current => full_SS_current[indexin(sort(vcat(state, vars_to_plot)), 𝓂.var)],
36403640
:algorithm_label => labels[algorithm][1],
36413641
:ss_label => labels[algorithm][2],
36423642
:rename_dictionary => processed_rename_dictionary)
@@ -3675,9 +3675,9 @@ function _plot_solution_from_container(;
36753675
model_name = first_container[:model_name]
36763676

36773677
# Collect all unique states from containers
3678-
joint_states = OrderedSet{Symbol}()
3678+
joint_states = OrderedSet{String}()
36793679
for container in solution_active_plot_container
3680-
push!(joint_states, container[:state])
3680+
push!(joint_states, string(apply_custom_name.(container[:state], Ref(Dict(container[:rename_dictionary])))))
36813681
end
36823682

36833683
gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend()
@@ -3749,6 +3749,8 @@ function _plot_solution_from_container(;
37493749
model_names = unique(model_names)
37503750

37513751
for model in model_names
3752+
# println(grouped_by_model[model])
3753+
# println(typeof(grouped_by_model[model][:has_impact]))
37523754
if length(grouped_by_model[model]) > 1
37533755
diffdict_grouped = compare_args_and_kwargs(grouped_by_model[model])
37543756
diffdict = merge_by_runid(diffdict, diffdict_grouped)
@@ -3880,26 +3882,29 @@ function _plot_solution_from_container(;
38803882
end
38813883

38823884
# Collect all variables to plot across all containers
3883-
all_vars = OrderedSet{Symbol}()
3885+
all_vars = OrderedSet{String}()
38843886
for container in solution_active_plot_container
3885-
foreach(v -> push!(all_vars, v), container[:vars_to_plot])
3887+
foreach(v -> push!(all_vars, v), string.(apply_custom_name.(container[:vars_to_plot], Ref(Dict(container[:rename_dictionary])))))
38863888
end
38873889

38883890
return_plots = []
38893891

38903892
# Loop over each state (similar to how plot_irf loops over shocks)
38913893
for state in joint_states
38923894
# Filter containers for this state
3893-
state_containers = [c for c in solution_active_plot_container if c[:state] == state]
3895+
state_containers = [c for c in solution_active_plot_container if string(apply_custom_name.(c[:state], Ref(Dict(c[:rename_dictionary])))) == state]
38943896

38953897
# Determine which variables have impact in at least one container for this state
38963898
vars_with_impact = []
3897-
for var in all_vars
3899+
for var in setdiff(all_vars, joint_states)
38983900
has_any_impact = false
38993901
for container in state_containers
3900-
if haskey(container[:has_impact], var) && container[:has_impact][var]
3901-
has_any_impact = true
3902-
break
3902+
for (k,v) in Dict(container[:has_impact])
3903+
k_trans = string(apply_custom_name(k, (Dict(container[:rename_dictionary]))))
3904+
if k_trans == var && v
3905+
has_any_impact = true
3906+
break
3907+
end
39033908
end
39043909
end
39053910
if has_any_impact
@@ -3915,33 +3920,76 @@ function _plot_solution_from_container(;
39153920
# Plot each variable for this state
39163921
for k in vars_with_impact
39173922
Pl = StatsPlots.plot()
3918-
3923+
3924+
39193925
# Plot line for each container with this state
39203926
for (i, container) in enumerate(solution_active_plot_container)
3921-
if container[:state] == state && haskey(container[:variable_output], k) && container[:has_impact][k]
3927+
# return the key that corresponds to k in the original variable_output dictionary
3928+
original_k_variable_output = nothing
3929+
for key in keys(Dict(container[:variable_output]))
3930+
if string(apply_custom_name(key, (Dict(container[:rename_dictionary])))) == k
3931+
original_k_variable_output = key
3932+
break
3933+
end
3934+
end
3935+
3936+
# return the key that corresponds to k in the original has_impact dictionary
3937+
original_k_has_impact = nothing
3938+
for key in keys(Dict(container[:has_impact]))
3939+
if string(apply_custom_name(key, (Dict(container[:rename_dictionary])))) == k
3940+
original_k_has_impact = key
3941+
break
3942+
end
3943+
end
3944+
3945+
if string(apply_custom_name.(container[:state], Ref(Dict(container[:rename_dictionary])))) == state && !isnothing(original_k_variable_output) && !isnothing(original_k_has_impact)
3946+
# Create concatenated transformed variable names for indexing
3947+
concat_trans_vars = string.(apply_custom_name.(sort(vcat(container[:vars_to_plot], container[:state])), Ref(Dict(container[:rename_dictionary]))))
3948+
39223949
# Find state index in vars_to_plot
3923-
state_idx = findfirst(==(state), container[:vars_to_plot])
3950+
state_idx = findfirst(==(state), concat_trans_vars)
39243951
if !isnothing(state_idx)
39253952
state_ss = container[:full_SS_current][state_idx]
39263953
else
39273954
state_ss = 0.0 # fallback
39283955
end
3929-
3956+
39303957
StatsPlots.plot!(container[:state_range] .+ state_ss,
3931-
container[:variable_output][k][1,:],
3932-
ylabel = replace_indices_in_symbol(k)*"₍₀₎",
3933-
xlabel = replace_indices_in_symbol(state)*"₍₋₁₎",
3958+
Dict(container[:variable_output])[original_k_variable_output][1,:],
3959+
ylabel = replace_indices_in_symbol(Symbol(k))*"₍₀₎",
3960+
xlabel = replace_indices_in_symbol(Symbol(state))*"₍₋₁₎",
39343961
color = pal[mod1(i, length(pal))],
39353962
label = "")
39363963
end
39373964
end
39383965

39393966
# Plot SS markers for each container with this state
39403967
for (i, container) in enumerate(solution_active_plot_container)
3941-
if container[:state] == state && haskey(container[:variable_output], k) && container[:has_impact][k]
3968+
# return the key that corresponds to k in the original variable_output dictionary
3969+
original_k_variable_output = nothing
3970+
for key in keys(Dict(container[:variable_output]))
3971+
if string(apply_custom_name(key, (Dict(container[:rename_dictionary])))) == k
3972+
original_k_variable_output = key
3973+
break
3974+
end
3975+
end
3976+
3977+
# return the key that corresponds to k in the original has_impact dictionary
3978+
original_k_has_impact = nothing
3979+
for key in keys(Dict(container[:has_impact]))
3980+
if string(apply_custom_name(key, (Dict(container[:rename_dictionary])))) == k
3981+
original_k_has_impact = key
3982+
break
3983+
end
3984+
end
3985+
3986+
if string(apply_custom_name.(container[:state], Ref(Dict(container[:rename_dictionary])))) == state && !isnothing(original_k_variable_output) && !isnothing(original_k_has_impact)
3987+
# Create concatenated transformed variable names for indexing
3988+
concat_trans_vars = string.(apply_custom_name.(sort(vcat(container[:vars_to_plot], container[:state])), Ref(Dict(container[:rename_dictionary]))))
3989+
39423990
# Get state and variable indices
3943-
state_idx = findfirst(==(state), container[:vars_to_plot])
3944-
var_idx = findfirst(==(k), container[:vars_to_plot])
3991+
state_idx = findfirst(==(state), concat_trans_vars)
3992+
var_idx = findfirst(==(k), concat_trans_vars)
39453993

39463994
if !isnothing(state_idx) && !isnothing(var_idx)
39473995
state_ss = container[:full_SS_current][state_idx]
@@ -4258,8 +4306,8 @@ function plot_solution!(𝓂::ℳ,
42584306

42594307
var_state_range = hcat(var_state_range...)
42604308

4261-
variable_output = Dict()
4262-
has_impact = Dict()
4309+
variable_output = []
4310+
has_impact = []
42634311

42644312
for k in vars_to_plot
42654313
idx = indexin([k], 𝓂.var)
@@ -4289,7 +4337,7 @@ function plot_solution!(𝓂::ℳ,
42894337
:variable_output => variable_output,
42904338
:has_impact => has_impact,
42914339
:vars_to_plot => vars_to_plot,
4292-
:full_SS_current => full_SS_current[indexin(vars_to_plot, 𝓂.var)],
4340+
:full_SS_current => full_SS_current[indexin(sort(vcat(state, vars_to_plot)), 𝓂.var)],
42934341
:algorithm_label => labels[algorithm][1],
42944342
:ss_label => labels[algorithm][2],
42954343
:rename_dictionary => processed_rename_dictionary)

0 commit comments

Comments
 (0)