Skip to content

Commit cd0bc5a

Browse files
committed
enhance plot functionality with parameterized IRF plots and improved data visualization
1 parent b5fddba commit cd0bc5a

File tree

2 files changed

+165
-10
lines changed

2 files changed

+165
-10
lines changed

src/plotting.jl

Lines changed: 111 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,9 @@ function plot_irf!(𝓂::ℳ;
13021302

13031303
param_nms = diffdict[:parameters]|>keys|>collect|>sort
13041304

1305-
annotate_ss = Pair{String,Any}[]
1305+
annotate_ss = [Pair{String,Any}[]]
1306+
1307+
annotate_ss_page = Pair{String,Any}[]
13061308

13071309
annotate_params = Pair{String,Any}[]
13081310

@@ -1339,7 +1341,7 @@ function plot_irf!(𝓂::ℳ;
13391341
SSs = [k[:reference_steady_state][var_idx[i]] for k in irf_active_plot_container]
13401342

13411343
if maximum(SSs) - minimum(SSs) > 1e-10
1342-
push!(annotate_ss, String(variable_name) => SSs)
1344+
push!(annotate_ss_page, String(variable_name) => minimal_sigfig_strings(SSs))
13431345
end
13441346

13451347
push!(pp, plot_irf_subplot( [k[:plot_data][i,:,shock] for k in irf_active_plot_container],
@@ -1367,7 +1369,7 @@ function plot_irf!(𝓂::ℳ;
13671369
end
13681370

13691371
ppp = StatsPlots.plot(pp...; attributes...)
1370-
1372+
13711373
annotate_ss_plot = plot_df(annotate_ss)
13721374

13731375
ppp2 = StatsPlots.plot(annotate_params_plot, annotate_ss_plot; attributes...)
@@ -1389,12 +1391,17 @@ function plot_irf!(𝓂::ℳ;
13891391
end
13901392

13911393
pane += 1
1392-
1393-
annotate_ss = Pair{String,Any}[]
1394+
1395+
push!(annotate_ss, annotate_ss_page)
1396+
1397+
annotate_ss_page = Pair{String,Any}[]
13941398

13951399
pp = []
13961400
end
13971401
end
1402+
1403+
push!(annotate_ss, annotate_ss_page)
1404+
13981405
end
13991406

14001407
if length(pp) > 0
@@ -1414,9 +1421,21 @@ function plot_irf!(𝓂::ℳ;
14141421

14151422
ppp = StatsPlots.plot(pp...; attributes...)
14161423

1417-
annotate_ss_plot = plot_df(annotate_ss)
1424+
if length(annotate_ss[pane-1]) > 0
1425+
annotate_ss_plot = plot_df(annotate_ss[pane-1])
1426+
1427+
ppp2 = StatsPlots.plot(annotate_params_plot, annotate_ss_plot; attributes...)
14181428

1419-
ppp2 = StatsPlots.plot(annotate_params_plot, annotate_ss_plot; attributes...)
1429+
p = StatsPlots.plot(ppp,
1430+
ppp2,
1431+
layout = StatsPlots.grid(2, 1, heights = [0.8, 0.2]),
1432+
plot_title = "Model: "*𝓂.model_name*" " * shock_dir * shock_string *" ("*string(pane)*"/"*string(Int(ceil(n_subplots/plots_per_page)))*")";
1433+
attributes_redux...)
1434+
else
1435+
p = StatsPlots.plot(ppp,
1436+
plot_title = "Model: "*𝓂.model_name*" " * shock_dir * shock_string *" ("*string(pane)*"/"*string(Int(ceil(n_subplots/plots_per_page)))*")";
1437+
attributes_redux...)
1438+
end
14201439

14211440
p = StatsPlots.plot(ppp,
14221441
ppp2,
@@ -1439,6 +1458,91 @@ function plot_irf!(𝓂::ℳ;
14391458
return return_plots
14401459
end
14411460

1461+
function minimal_sigfig_strings(v::AbstractVector{<:Real};
1462+
min_sig::Int = 3, n::Int = 10, dup_tol::Float64 = 1e-13)
1463+
1464+
idx = collect(eachindex(v))
1465+
finite_mask = map(x -> isfinite(x) && x != 0, v)
1466+
work_idx = filter(i -> finite_mask[i], idx)
1467+
sorted_idx = sort(work_idx, by = i -> v[i])
1468+
mwork = length(sorted_idx)
1469+
1470+
# Gaps to nearest neighbour
1471+
gaps = Dict{Int,Float64}()
1472+
for (k, i) in pairs(sorted_idx)
1473+
x = float(v[i])
1474+
if mwork == 1
1475+
gaps[i] = Inf
1476+
elseif k == 1
1477+
gaps[i] = abs(v[sorted_idx[k+1]] - x)
1478+
elseif k == mwork
1479+
gaps[i] = abs(x - v[sorted_idx[k-1]])
1480+
else
1481+
g1 = abs(x - v[sorted_idx[k-1]])
1482+
g2 = abs(v[sorted_idx[k+1]] - x)
1483+
gaps[i] = min(g1, g2)
1484+
end
1485+
end
1486+
1487+
# Duplicate clusters (within dup_tol)
1488+
duplicate = Dict{Int,Bool}()
1489+
k = 1
1490+
while k <= mwork
1491+
i = sorted_idx[k]
1492+
cluster = [i]
1493+
x = v[i]
1494+
j = k + 1
1495+
while j <= mwork && abs(v[sorted_idx[j]] - x) <= dup_tol
1496+
push!(cluster, sorted_idx[j])
1497+
j += 1
1498+
end
1499+
isdup = length(cluster) > 1
1500+
for c in cluster
1501+
duplicate[c] = isdup
1502+
end
1503+
k = j
1504+
end
1505+
1506+
# Required significant digits for distinction
1507+
req_sig = Dict{Int,Int}()
1508+
for i in sorted_idx
1509+
if duplicate[i]
1510+
req_sig[i] = min_sig # will apply rule anyway
1511+
else
1512+
x = float(v[i])
1513+
g = gaps[i]
1514+
if g == 0.0
1515+
req_sig[i] = min_sig
1516+
else
1517+
m = floor(log10(abs(x))) + 1
1518+
s = max(min_sig, ceil(Int, m - log10(g)))
1519+
# Apply rule: if they differ only after more than n sig digits
1520+
if s > n
1521+
req_sig[i] = min_sig
1522+
else
1523+
req_sig[i] = s
1524+
end
1525+
end
1526+
end
1527+
end
1528+
1529+
# Format output
1530+
out = Vector{String}(undef, length(v))
1531+
for i in eachindex(v)
1532+
x = v[i]
1533+
if !(isfinite(x)) || x == 0
1534+
# For zero or non finite just echo (rule does not change them)
1535+
out[i] = string(x)
1536+
elseif haskey(req_sig, i)
1537+
s = req_sig[i]
1538+
out[i] = string(round(x, sigdigits = s))
1539+
else
1540+
# Non finite or zero already handled; fallback
1541+
out[i] = string(x)
1542+
end
1543+
end
1544+
return out
1545+
end
14421546

14431547

14441548
function plot_df(plot_vector::Vector{Pair{String,Any}})

test/fix_combined_plots.jl

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,61 @@ end;
1616
β = 0.95
1717
end;
1818

19-
plot_irf(RBC)
19+
plot_irf(RBC, parameters = [:std_z => 0.01, => 0.95, => 0.2])
2020

21-
MacroModelling.plot_irf!(RBC, parameters = :std_z => 0.012)
21+
MacroModelling.plot_irf!(RBC, parameters = [:std_z => 0.012, => 0.95, => 0.75])
22+
23+
MacroModelling.plot_irf!(RBC, parameters = [:std_z => 0.01, => 0.957, => 0.5])
2224

2325
MacroModelling.irf_active_plot_container
2426

25-
MacroModelling.compare_args_and_kwargs(MacroModelling.irf_active_plot_container)
27+
diffdict = MacroModelling.compare_args_and_kwargs(MacroModelling.irf_active_plot_container)
28+
29+
using StatsPlots, DataFrames
30+
using Plots
31+
32+
diffdict[:parameters]
33+
mapreduce((x, y) -> x y, diffdict[:parameters])
34+
35+
df = diffdict[:parameters]|>DataFrame
36+
param_nms = diffdict[:parameters]|>keys|>collect|>sort
37+
38+
plot_vector = Pair{String,Any}[]
39+
for param in param_nms
40+
push!(plot_vector, String(param) => diffdict[:parameters][param])
41+
end
42+
43+
pushfirst!(plot_vector, "Plot index" => 1:length(diffdict[:parameters][param_nms[1]]))
44+
45+
46+
function plot_df(plot_vector::Vector{Pair{String,Any}})
47+
# Determine dimensions from plot_vector
48+
ncols = length(plot_vector)
49+
nrows = length(plot_vector[1].second)
50+
51+
bg_matrix = ones(nrows + 1, ncols)
52+
bg_matrix[1, :] .= 0.35 # Header row
53+
for i in 3:2:nrows+1
54+
bg_matrix[i, :] .= 0.85
55+
end
56+
57+
# draw the "cells"
58+
df_plot = heatmap(bg_matrix;
59+
c = cgrad([:lightgrey, :white]), # Color gradient for background
60+
yflip = true,
61+
tick=:none,
62+
legend=false,
63+
framestyle = :none, # Keep the outer box
64+
cbar=false)
65+
66+
# overlay the header and numeric values
67+
for j in 1:ncols
68+
annotate!(df_plot, j, 1, text(plot_vector[j].first, :center, 8)) # Header
69+
for i in 1:nrows
70+
annotate!(df_plot, j, i+1, text(string(plot_vector[j].second[i]), :center, 8))
71+
end
72+
end
73+
return df_plot
74+
end
75+
76+
plot_df(plot_vector)

0 commit comments

Comments
 (0)