@@ -11,7 +11,7 @@ struct _MeanPlot; c; val; end
11
11
struct _DensityPlot; c; val; end
12
12
struct _HistogramPlot; c; val; end
13
13
struct _AutocorPlot; lags; val; end
14
- struct _ViolinPlot; parameters ; val; total_chains ; end
14
+ struct _ViolinPlot; par ; val; end
15
15
16
16
# define alias functions for old syntax
17
17
const translationdict = Dict (
@@ -33,7 +33,9 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
33
33
colordim = :chain ,
34
34
barbounds = (- Inf , Inf ),
35
35
maxlag = nothing ,
36
- append_chains = false
36
+ append_chains = false ,
37
+ sections = chains. name_map[:parameters ],
38
+ combined = true
37
39
)
38
40
st = get (plotattributes, :seriestype , :traceplot )
39
41
c = append_chains || st == :pooleddensity ? pool_chain (chains) : chains
@@ -72,6 +74,39 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
72
74
else
73
75
range (c), val
74
76
end
77
+
78
+ total_chains = i
79
+ if st == :violinplot
80
+ n_iter, n_par, n_chains = size (chains)
81
+ if combined
82
+ colordim := :chain
83
+ par = string .(reshape (repeat (sections, inner = n_iter), n_iter, n_par))[:,i]
84
+ val = Array (chains)[:,i]
85
+ _ViolinPlot (par, val)
86
+ elseif combined == false
87
+ if colordim == :chain
88
+ par_names = [" $(sections[i]) .Chain $j " for i in 1 : n_par, j in 1 : n_chains]
89
+ pars = string .(reshape (repeat (vec (par_names), inner = n_iter), (n_iter, n_par, n_chains)))
90
+ val = chains. value[:,i,:]
91
+ par = pars[:,i,:]
92
+ elseif colordim == :parameter
93
+ par_vec = repeat (sections, inner = n_iter)
94
+ pars = string .(reshape (repeat (par_vec, n_chains, 1 ), (n_iter, n_par, n_chains)))
95
+ val = chains. value[:,:,i]
96
+ par = pars[:,:,i]
97
+ label --> string .(names (c))
98
+ else
99
+ throw (ArgumentError (" `colordim` must be one of `:chain` or `:parameter`" ))
100
+ end
101
+ _ViolinPlot (par, val)
102
+ else
103
+ throw (ArgumentError (" In `ViolinPlots` `Chains` can be combined or separated " ))
104
+ end
105
+ elseif st ∈ supportedplots
106
+ translationdict[st](c, val)
107
+ else
108
+ range (c), val
109
+ end
75
110
end
76
111
77
112
@recipe function f (p:: _DensityPlot )
@@ -188,59 +223,17 @@ end
188
223
RecipesBase. recipetype (:cornerplot , vcat (ar... ))
189
224
end
190
225
191
- @recipe function f (
192
- chains:: Chains ;
193
- sections:: Vector{Symbol} = chains. name_map[:parameters ],
194
- combined = true
195
- )
196
-
197
- st = get (plotattributes, :seriestype , :traceplot )
198
- total_chains = 0
199
- if st == :violinplot
200
- if combined
201
- n_iter, n_parameters = size (Array (chains))
202
- parameters = string .(repeat (sections, inner = n_iter))
203
- val = vec (Array (chains))
204
- total_chains = Integer (size (chains. value. data)[3 ])
205
- _ViolinPlot (parameters, val, total_chains)
206
- elseif combined == false
207
- n_parameters = length (sections)
208
- chain_arr = Array (chains, append_chains = false )
209
- val_vec = [chain_arr[j][:,i]
210
- for i in 1 : n_parameters
211
- for j in 1 : length (chain_arr)]
212
- n_iter = length (val_vec[1 ])
213
- total_chains = length (val_vec)
214
- val = zeros (Float64, n_iter, total_chains)
215
- for i in 1 : total_chains
216
- val[:,i] = val_vec[:][i]
217
- end
218
- val = vec (val)
219
- parameters_names = [" param $(sections[i]) .Chain $j "
220
- for i in 1 : n_parameters
221
- for j in 1 : length (chain_arr)]
222
- parameters = string .(repeat (parameters_names, inner = n_iter))
223
- _ViolinPlot (parameters, val, total_chains)
224
- else
225
- error (" Symbol names are interpreted as parameter names, only compatible with " ,
226
- " `colordim = :chain`" )
227
- end
228
- end
229
- end
230
-
231
226
@recipe function f (p:: _ViolinPlot )
232
227
@series begin
233
228
seriestype := :violin
234
- xaxis --> " Parameter"
235
- size --> (200 * p. total_chains, 500 )
236
- p. parameters, p. val
229
+ p. par, p. val
237
230
end
238
231
239
232
@series begin
240
233
seriestype := :boxplot
241
234
bar_width --> 0.1
242
235
linewidth --> 2
243
236
fillalpha --> 0.8
244
- p. parameters , p. val
237
+ p. par , p. val
245
238
end
246
239
end
0 commit comments