Skip to content

Commit 00b2f8b

Browse files
committed
dual axis moves inside subplot func; add stack option; twinx simplified for plot_irf
1 parent c8f406b commit 00b2f8b

File tree

3 files changed

+205
-67
lines changed

3 files changed

+205
-67
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3434
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
3535
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
3636
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
37+
Showoff = "992d4aef-0814-514b-bc4d-f2e9a6c4116f"
3738
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3839
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3940
Subscripts = "2b7f82d5-8785-4f63-971e-f18ddbeb808e"
@@ -92,6 +93,7 @@ Random = "1"
9293
RecursiveFactorization = "0.2"
9394
Reexport = "1"
9495
RuntimeGeneratedFunctions = "0.5"
96+
Showoff = "1"
9597
SparseArrays = "1"
9698
SpecialFunctions = "2"
9799
StatsPlots = "0.15"

ext/StatsPlotsExt.jl

Lines changed: 151 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const irf_active_plot_container = Dict[]
99
const model_estimates_active_plot_container = Dict[]
1010

1111
import StatsPlots
12+
import Showoff
1213
import DataStructures: OrderedSet
1314
import SparseArrays: SparseMatrixCSC
1415
import NLopt
@@ -925,12 +926,10 @@ function plot_irf(𝓂::ℳ;
925926
for i in 1:length(var_idx)
926927
SS = reference_steady_state[var_idx[i]]
927928

928-
can_dual_axis = gr_back && all((Y[i,:,shock] .+ SS) .> eps(Float32)) && (SS > eps(Float32))
929-
930929
if !(all(isapprox.(Y[i,:,shock],0,atol = eps(Float32))))
931930
variable_name = replace_indices_in_symbol(𝓂.timings.var[var_idx[i]])
932931

933-
push!(pp, plot_irf_subplot(Y[i,:,shock], SS, variable_name, can_dual_axis))
932+
push!(pp, plot_irf_subplot(Y[i,:,shock], SS, variable_name, gr_back))
934933

935934
if !(plot_count % plots_per_page == 0)
936935
plot_count += 1
@@ -1003,15 +1002,20 @@ function plot_irf(𝓂::ℳ;
10031002
end
10041003

10051004

1006-
function plot_irf_subplot(irf_data::AbstractVector{S}, steady_state::S, variable_name::String, can_dual_axis::Bool) where S <: AbstractFloat
1005+
function plot_irf_subplot(irf_data::AbstractVector{S}, steady_state::S, variable_name::String, gr_back::Bool) where S <: AbstractFloat
10071006
p = StatsPlots.plot(irf_data .+ steady_state,
10081007
title = variable_name,
10091008
ylabel = "Level",
10101009
label = "")
10111010

1011+
can_dual_axis = gr_back && all((irf_data .+ steady_state) .> eps(Float32)) && (steady_state > eps(Float32))
1012+
1013+
lo, hi = StatsPlots.ylims(p)
1014+
10121015
if can_dual_axis
10131016
StatsPlots.plot!(StatsPlots.twinx(),
1014-
100*((irf_data .+ steady_state) ./ steady_state .- 1),
1017+
# 100*((irf_data .+ steady_state) ./ steady_state .- 1),
1018+
ylims = (100 * (lo / steady_state - 1), 100 * (hi / steady_state - 1)),
10151019
ylabel = LaTeXStrings.L"\% \Delta",
10161020
label = "")
10171021
end
@@ -1022,20 +1026,27 @@ function plot_irf_subplot(irf_data::AbstractVector{S}, steady_state::S, variable
10221026
return p
10231027
end
10241028

1025-
function plot_irf_subplot(irf_data::Vector{<:AbstractVector{S}}, steady_state::Vector{S}, variable_name::String, can_dual_axis::Bool, same_ss::Bool; pal::StatsPlots.ColorPalette = StatsPlots.palette(:auto)) where S <: AbstractFloat
1029+
function plot_irf_subplot(::Val{:compare}, irf_data::Vector{<:AbstractVector{S}}, steady_state::Vector{S}, variable_name::String, gr_back::Bool, same_ss::Bool; pal::StatsPlots.ColorPalette = StatsPlots.palette(:auto)) where S <: AbstractFloat
10261030
plot_dat = []
1027-
plot_dat_dual = []
1031+
plot_ss = 0
10281032

10291033
pal_val = Int[]
10301034

10311035
stst = 1.0
10321036

1037+
can_dual_axis = gr_back
1038+
1039+
for (y, ss) in zip(irf_data, steady_state)
1040+
can_dual_axis = can_dual_axis && all((y .+ ss) .> eps(Float32)) && (ss > eps(Float32))
1041+
end
1042+
10331043
for (i,(y, ss)) in enumerate(zip(irf_data, steady_state))
10341044
if !isnan(ss)
10351045
stst = ss
1046+
10361047
if can_dual_axis && same_ss
10371048
push!(plot_dat, y .+ ss)
1038-
push!(plot_dat_dual, 100 * ((y .+ ss) ./ ss .- 1))
1049+
plot_ss = ss
10391050
else
10401051
if same_ss
10411052
push!(plot_dat, y .+ ss)
@@ -1053,9 +1064,12 @@ function plot_irf_subplot(irf_data::Vector{<:AbstractVector{S}}, steady_state::V
10531064
color = pal[pal_val]',
10541065
label = "")
10551066

1067+
lo, hi = StatsPlots.ylims(p)
1068+
10561069
if can_dual_axis && same_ss
10571070
StatsPlots.plot!(StatsPlots.twinx(),
1058-
plot_dat_dual,
1071+
ylims = (100 * (lo / plot_ss - 1), 100 * (hi / plot_ss - 1)),
1072+
# plot_dat_dual,
10591073
ylabel = LaTeXStrings.L"\% \Delta",
10601074
color = pal[pal_val]',
10611075
label = "")
@@ -1068,6 +1082,85 @@ function plot_irf_subplot(irf_data::Vector{<:AbstractVector{S}}, steady_state::V
10681082
end
10691083

10701084

1085+
function plot_irf_subplot(::Val{:stack}, irf_data::Vector{<:AbstractVector{S}}, steady_state::Vector{S}, variable_name::String, gr_back::Bool, same_ss::Bool; pal::StatsPlots.ColorPalette = StatsPlots.palette(:auto)) where S <: AbstractFloat
1086+
plot_dat = []
1087+
plot_ss = 0
1088+
plot_dat_dual = []
1089+
1090+
pal_val = Int[]
1091+
1092+
stst = 1.0
1093+
1094+
can_dual_axis = gr_back
1095+
1096+
for (y, ss) in zip(irf_data, steady_state)
1097+
if !isnan(ss)
1098+
can_dual_axis = can_dual_axis && all((y .+ ss) .> eps(Float32)) && (ss > eps(Float32))
1099+
end
1100+
end
1101+
1102+
for (i,(y, ss)) in enumerate(zip(irf_data, steady_state))
1103+
if !isnan(ss)
1104+
stst = ss
1105+
1106+
push!(plot_dat, y)
1107+
1108+
if can_dual_axis && same_ss
1109+
plot_ss = ss
1110+
push!(plot_dat_dual, 100 * ((y .+ ss) ./ ss .- 1))
1111+
else
1112+
if same_ss
1113+
plot_ss = ss
1114+
end
1115+
end
1116+
push!(pal_val, i)
1117+
end
1118+
end
1119+
1120+
# find maximum length
1121+
maxlen = maximum(length.(plot_dat))
1122+
1123+
# pad shorter vectors with 0
1124+
padded = [vcat(collect(v), fill(0, maxlen - length(v))) for v in plot_dat]
1125+
1126+
# now you can hcat
1127+
plot_data = reduce(hcat, padded)
1128+
1129+
p = StatsPlots.groupedbar(typeof(plot_data) <: AbstractVector ? hcat(plot_data) : plot_data,
1130+
title = variable_name,
1131+
bar_position = :stack,
1132+
linecolor = :transparent,
1133+
color = pal[pal_val]',
1134+
label = "")
1135+
1136+
# Get the current y limits
1137+
lo, hi = StatsPlots.ylims(p)
1138+
1139+
# Compute nice ticks on the shifted range
1140+
ticks_shifted, _ = StatsPlots.optimize_ticks(lo + plot_ss, hi + plot_ss, k_min = 4, k_max = 8)
1141+
1142+
labels = Showoff.showoff(ticks_shifted, :auto)
1143+
# Map tick positions back by subtracting the offset, keep shifted labels
1144+
yticks_positions = ticks_shifted .- plot_ss
1145+
1146+
StatsPlots.plot!(p; yticks = (yticks_positions, labels))
1147+
1148+
if can_dual_axis && same_ss
1149+
StatsPlots.plot!(
1150+
StatsPlots.twinx(),
1151+
ylims = (100 * ((lo + plot_ss) / plot_ss - 1), 100 * ((hi + plot_ss) / plot_ss - 1))
1152+
)
1153+
end
1154+
1155+
StatsPlots.hline!(can_dual_axis && same_ss ? [0 0] : [0],
1156+
color = :black,
1157+
ylabel = same_ss ? ["Level" LaTeXStrings.L"\% \Delta"] : "abs. " * LaTeXStrings.L"\Delta" ,
1158+
label = "")
1159+
1160+
1161+
return p
1162+
end
1163+
10711164
function plot_irf!(𝓂::ℳ;
10721165
periods::Int = 40,
10731166
shocks::Union{Symbol_input,String_input,Matrix{Float64},KeyedArray{Float64}} = :all_excluding_obc,
@@ -1084,6 +1177,7 @@ function plot_irf!(𝓂::ℳ;
10841177
generalised_irf::Bool = false,
10851178
initial_state::Union{Vector{Vector{Float64}},Vector{Float64}} = [0.0],
10861179
ignore_obc::Bool = false,
1180+
plot_type::Symbol = :compare,
10871181
plot_attributes::Dict = Dict(),
10881182
verbose::Bool = false,
10891183
tol::Tolerances = Tolerances(),
@@ -1092,6 +1186,8 @@ function plot_irf!(𝓂::ℳ;
10921186
lyapunov_algorithm::Symbol = :doubling)
10931187
# @nospecialize # reduce compile time
10941188

1189+
@assert plot_type [:compare, :stack] "plot_type must be either :compare or :stack"
1190+
10951191
opts = merge_calculation_options(tol = tol, verbose = verbose,
10961192
quadratic_matrix_equation_algorithm = quadratic_matrix_equation_algorithm,
10971193
sylvester_algorithm² = isa(sylvester_algorithm, Symbol) ? sylvester_algorithm : sylvester_algorithm[1],
@@ -1388,7 +1484,20 @@ function plot_irf!(𝓂::ℳ;
13881484
:shock_idx => shock_idx,
13891485
:var_idx => var_idx)
13901486

1391-
push!(irf_active_plot_container, args_and_kwargs)
1487+
no_duplicate = all(
1488+
!(all((
1489+
get(dict, :parameters, nothing) == args_and_kwargs[:parameters],
1490+
get(dict, :shock_names, nothing) == args_and_kwargs[:shock_names],
1491+
get(dict, :initial_state, nothing) == args_and_kwargs[:initial_state],
1492+
all(get(dict, k, nothing) == args_and_kwargs[k] for k in keys(args_and_kwargs_names))
1493+
)))
1494+
for dict in irf_active_plot_container
1495+
)# "New plot must be different from previous plot. Use the version without ! to plot."
1496+
1497+
if no_duplicate push!(irf_active_plot_container, args_and_kwargs)
1498+
else
1499+
@info "Plot with same parameters already exists. Using previous plot data to create plot."
1500+
end
13921501

13931502
# 1. Keep only certain keys from each dictionary
13941503
reduced_vector = [
@@ -1422,9 +1531,8 @@ function plot_irf!(𝓂::ℳ;
14221531
diffdict = merge_by_runid(diffdict, diffdict_grouped)
14231532
end
14241533
end
1425-
1426-
@assert haskey(diffdict, :parameters) || haskey(diffdict, :shock_names) || haskey(diffdict, :initial_state) ||
1427-
any(haskey.(Ref(diffdict), keys(args_and_kwargs_names))) "New plot must be different from previous plot. Use the version without ! to plot."
1534+
1535+
# @assert haskey(diffdict, :parameters) || haskey(diffdict, :shock_names) || haskey(diffdict, :initial_state) || any(haskey.(Ref(diffdict), keys(args_and_kwargs_names))) "New plot must be different from previous plot. Use the version without ! to plot."
14281536

14291537
annotate_ss = Vector{Pair{String, Any}}[]
14301538

@@ -1496,12 +1604,22 @@ function plot_irf!(𝓂::ℳ;
14961604
single_shock_per_irf = true
14971605

14981606
for (i,k) in enumerate(irf_active_plot_container)
1499-
StatsPlots.plot!(legend_plot,
1500-
fill(0,1,1),
1501-
legend_title = length(annotate_diff_input) > 2 ? nothing : annotate_diff_input[2][1],
1502-
framestyle = :none,
1503-
legend = :inside,
1504-
label = length(annotate_diff_input) > 2 ? i : annotate_diff_input[2][2][i] isa String ? annotate_diff_input[2][2][i] : String(Symbol(annotate_diff_input[2][2][i])))
1607+
if plot_type == :stack
1608+
StatsPlots.bar!(legend_plot,
1609+
fill(0,1,1),
1610+
legend_title = length(annotate_diff_input) > 2 ? nothing : annotate_diff_input[2][1],
1611+
framestyle = :none,
1612+
legend = :inside,
1613+
linecolor = :transparent,
1614+
label = length(annotate_diff_input) > 2 ? i : annotate_diff_input[2][2][i] isa String ? annotate_diff_input[2][2][i] : String(Symbol(annotate_diff_input[2][2][i])))
1615+
elseif plot_type == :compare
1616+
StatsPlots.plot!(legend_plot,
1617+
fill(0,1,1),
1618+
legend_title = length(annotate_diff_input) > 2 ? nothing : annotate_diff_input[2][1],
1619+
framestyle = :none,
1620+
legend = :inside,
1621+
label = length(annotate_diff_input) > 2 ? i : annotate_diff_input[2][2][i] isa String ? annotate_diff_input[2][2][i] : String(Symbol(annotate_diff_input[2][2][i])))
1622+
end
15051623

15061624
push!(joint_shocks, k[:shock_names]...)
15071625
push!(joint_variables, k[:variable_names]...)
@@ -1523,11 +1641,9 @@ function plot_irf!(𝓂::ℳ;
15231641
pane = 1
15241642
plot_count = 1
15251643
joint_non_zero_variables = []
1526-
can_dual_axiss = Bool[]
15271644

15281645
for var in joint_variables
15291646
not_zero_in_any_irf = false
1530-
can_dual_axis = gr_back
15311647

15321648
for k in irf_active_plot_container
15331649
var_idx = findfirst(==(var), k[:variable_names])
@@ -1542,27 +1658,18 @@ function plot_irf!(𝓂::ℳ;
15421658
not_zero_in_any_irf = not_zero_in_any_irf || true
15431659
# break # If any irf data is not approximately zero, we set the flag to true.
15441660
end
1545-
1546-
SS = k[:reference_steady_state][var_idx]
1547-
1548-
if all((k[:plot_data][var_idx,:,shock_idx] .+ SS) .> eps(Float32)) && (SS > eps(Float32))
1549-
can_dual_axis = can_dual_axis && true
1550-
else
1551-
can_dual_axis = can_dual_axis && false
1552-
end
15531661
end
15541662
end
15551663

15561664
if not_zero_in_any_irf
15571665
push!(joint_non_zero_variables, var)
1558-
push!(can_dual_axiss, can_dual_axis)
15591666
else
15601667
# If all irf data for this variable and shock is approximately zero, we skip this subplot.
15611668
n_subplots -= 1
15621669
end
15631670
end
15641671

1565-
for (var, can_dual_axis) in zip(joint_non_zero_variables, can_dual_axiss)
1672+
for var in joint_non_zero_variables
15661673
SSs = eltype(irf_active_plot_container[1][:reference_steady_state])[]
15671674
Ys = AbstractVector{eltype(irf_active_plot_container[1][:plot_data])}[]
15681675

@@ -1588,10 +1695,11 @@ function plot_irf!(𝓂::ℳ;
15881695
same_ss = false
15891696
end
15901697

1591-
push!(pp, plot_irf_subplot( Ys,
1698+
push!(pp, plot_irf_subplot(Val(plot_type),
1699+
Ys,
15921700
SSs,
15931701
var,
1594-
can_dual_axis,
1702+
gr_back,
15951703
same_ss))
15961704

15971705
if !(plot_count % plots_per_page == 0)
@@ -2109,9 +2217,17 @@ function plot_conditional_variance_decomposition(𝓂::ℳ;
21092217

21102218
for k in vars_to_plot
21112219
if gr_back
2112-
push!(pp,StatsPlots.groupedbar(fevds(k,:,:)', title = replace_indices_in_symbol(k), bar_position = :stack, legend = :none))
2220+
push!(pp,StatsPlots.groupedbar(fevds(k,:,:)',
2221+
title = replace_indices_in_symbol(k),
2222+
bar_position = :stack,
2223+
linecolor = :transparent,
2224+
legend = :none))
21132225
else
2114-
push!(pp,StatsPlots.groupedbar(fevds(k,:,:)', title = replace_indices_in_symbol(k), bar_position = :stack, label = reshape(string.(replace_indices_in_symbol.(shocks_to_plot)),1,length(shocks_to_plot))))
2226+
push!(pp,StatsPlots.groupedbar(fevds(k,:,:)',
2227+
title = replace_indices_in_symbol(k),
2228+
bar_position = :stack,
2229+
linecolor = :transparent,
2230+
label = reshape(string.(replace_indices_in_symbol.(shocks_to_plot)),1,length(shocks_to_plot))))
21152231
end
21162232

21172233
if !(plot_count % plots_per_page == 0)

0 commit comments

Comments
 (0)