Skip to content

Commit 0bdcdfd

Browse files
Fic a Warning messaage
1 parent 01d56d5 commit 0bdcdfd

File tree

1 file changed

+39
-46
lines changed

1 file changed

+39
-46
lines changed

src/plot.jl

+39-46
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ struct _MeanPlot; c; val; end
1111
struct _DensityPlot; c; val; end
1212
struct _HistogramPlot; c; val; end
1313
struct _AutocorPlot; lags; val; end
14-
struct _ViolinPlot; parameters; val; total_chains; end
14+
struct _ViolinPlot; par; val; end
1515

1616
# define alias functions for old syntax
1717
const translationdict = Dict(
@@ -33,7 +33,9 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
3333
colordim = :chain,
3434
barbounds = (-Inf, Inf),
3535
maxlag = nothing,
36-
append_chains = false
36+
append_chains = false,
37+
sections = chains.name_map[:parameters],
38+
combined = true
3739
)
3840
st = get(plotattributes, :seriestype, :traceplot)
3941
c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains
@@ -72,6 +74,39 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
7274
else
7375
range(c), val
7476
end
77+
78+
total_chains = i
79+
if st == :violinplot
80+
n_iter, n_par, n_chains = size(chains)
81+
if combined
82+
colordim := :chain
83+
par = string.(reshape(repeat(sections, inner = n_iter), n_iter, n_par))[:,i]
84+
val = Array(chains)[:,i]
85+
_ViolinPlot(par, val)
86+
elseif combined == false
87+
if colordim == :chain
88+
par_names = ["$(sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains]
89+
pars = string.(reshape(repeat(vec(par_names), inner = n_iter), (n_iter, n_par, n_chains)))
90+
val = chains.value[:,i,:]
91+
par = pars[:,i,:]
92+
elseif colordim == :parameter
93+
par_vec = repeat(sections, inner = n_iter)
94+
pars = string.(reshape(repeat(par_vec, n_chains, 1), (n_iter, n_par, n_chains)))
95+
val = chains.value[:,:,i]
96+
par = pars[:,:,i]
97+
label --> string.(names(c))
98+
else
99+
throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
100+
end
101+
_ViolinPlot(par, val)
102+
else
103+
throw(ArgumentError("In `ViolinPlots` `Chains` can be combined or separated "))
104+
end
105+
elseif st supportedplots
106+
translationdict[st](c, val)
107+
else
108+
range(c), val
109+
end
75110
end
76111

77112
@recipe function f(p::_DensityPlot)
@@ -188,59 +223,17 @@ end
188223
RecipesBase.recipetype(:cornerplot, vcat(ar...))
189224
end
190225

191-
@recipe function f(
192-
chains::Chains;
193-
sections::Vector{Symbol} = chains.name_map[:parameters],
194-
combined = true
195-
)
196-
197-
st = get(plotattributes, :seriestype, :traceplot)
198-
total_chains = 0
199-
if st == :violinplot
200-
if combined
201-
n_iter, n_parameters = size(Array(chains))
202-
parameters = string.(repeat(sections, inner = n_iter))
203-
val = vec(Array(chains))
204-
total_chains = Integer(size(chains.value.data)[3])
205-
_ViolinPlot(parameters, val, total_chains)
206-
elseif combined == false
207-
n_parameters = length(sections)
208-
chain_arr = Array(chains, append_chains = false)
209-
val_vec = [chain_arr[j][:,i]
210-
for i in 1:n_parameters
211-
for j in 1:length(chain_arr)]
212-
n_iter = length(val_vec[1])
213-
total_chains = length(val_vec)
214-
val = zeros(Float64, n_iter, total_chains)
215-
for i in 1:total_chains
216-
val[:,i] = val_vec[:][i]
217-
end
218-
val = vec(val)
219-
parameters_names = ["param $(sections[i]).Chain $j"
220-
for i in 1:n_parameters
221-
for j in 1:length(chain_arr)]
222-
parameters = string.(repeat(parameters_names, inner = n_iter))
223-
_ViolinPlot(parameters, val, total_chains)
224-
else
225-
error("Symbol names are interpreted as parameter names, only compatible with ",
226-
"`colordim = :chain`")
227-
end
228-
end
229-
end
230-
231226
@recipe function f(p::_ViolinPlot)
232227
@series begin
233228
seriestype := :violin
234-
xaxis --> "Parameter"
235-
size --> (200*p.total_chains, 500)
236-
p.parameters, p.val
229+
p.par, p.val
237230
end
238231

239232
@series begin
240233
seriestype := :boxplot
241234
bar_width --> 0.1
242235
linewidth --> 2
243236
fillalpha --> 0.8
244-
p.parameters, p.val
237+
p.par, p.val
245238
end
246239
end

0 commit comments

Comments
 (0)