As discussed, there is an odd artefact: the fit becomes good almost immediately across the whole series except at the beginning, which can take up to 20 times as many iterations to converge (this seems to be linked mostly to the array size).
Result with 20 iterations (observe the beginning):
Result with 365 iterations (= N, the number of data points):
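One way to put a number on "the beginning takes longer to converge" is to keep the `aux` posteriors from every iteration instead of only the last one and, for each data index, find the first iteration from which the posterior mean stays within a tolerance of its final value. The helper below is a hypothetical sketch (`settling_iteration` is not a ReactiveMP function, and the tolerance `1e-3` is an arbitrary choice); it works on plain vectors of means, so it runs without ReactiveMP.

```julia
# Hypothetical helper (not part of ReactiveMP): for each data index, return the
# first iteration from which the posterior mean stays within `tol` of its final value.
function settling_iteration(means_per_iter::Vector{Vector{Float64}}; tol::Float64 = 1e-3)
    L = length(means_per_iter)
    final = means_per_iter[end]
    return map(eachindex(final)) do i
        it = L
        # walk backwards while the previous iteration is still within tolerance
        while it > 1 && abs(means_per_iter[it - 1][i] - final[i]) <= tol
            it -= 1
        end
        it
    end
end

# Toy input: 5 iterations, 3 data points; the last point converges slowest
means = [
    [1.0, 0.0, 0.0],
    [1.0, 0.5, 0.2],
    [1.0, 1.0, 0.5],
    [1.0, 1.0, 0.9],
    [1.0, 1.0, 1.0],
]
println(settling_iteration(means))  # [1, 3, 5]
```

To apply it to the MWE below, the assumption is that the inference is rerun with `returnvars = (aux = KeepEach(),)` (I believe `KeepEach()` is accepted by `returnvars` alongside `KeepLast()`, but have not checked it against this version) and that `[mean.(q) for q in results.posteriors[:aux]]` is passed in. Plotting the result against the data index would then show whether the first indices settle much later than the rest.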
MWE for a simple curve-fitting scenario
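Written out directly from the code, the scenario is as follows, with $b_0$ standing for `intercept` and $d$ for `dim_x` (the number of spline columns). Note that the variable named `sigma` enters `NormalMeanPrecision`, so it is the noise precision $\tau$ rather than a standard deviation. Since $2\pi x_t / p = 2\pi t / 180$, the seasonal term has a period of 180 time steps.

```math
\begin{aligned}
x_t &= t/N, \qquad t = 1,\dots,N,\quad N = 365,\\
y_t &= \text{offset} + \sin\!\left(\frac{2\pi x_t}{p}\right) + \text{growth}\cdot x_t + \varepsilon_t,
\qquad p = \frac{180}{N},\quad \varepsilon_t \sim \mathcal{N}(0,\, 0.05^2),\\
\tau &\sim \mathrm{Gamma}(\text{shape}=1,\ \text{rate}=1),\qquad
b_0 \sim \mathcal{N}(0,\, 2),\qquad
\beta \sim \mathcal{N}(0,\, I_d),\\
\mathrm{aux}_t &= b_0 + X_{t,:}\,\beta,\qquad
y_t \sim \mathcal{N}\!\left(\mathrm{aux}_t,\, \tau^{-1}\right),\\
q(\mathrm{aux}, \tau) &= q(\mathrm{aux})\,q(\tau).
\end{aligned}
```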
using Rocket, ReactiveMP, GraphPPL
using Distributions, LinearAlgebra, Plots
import Splines2

# Generate data to fit
time_max = 365
noise_scale = 0.05
time_index = collect(1:time_max) / time_max   # time rescaled to (0, 1]
p = 180 / time_max                            # 180-day period on the rescaled axis
growth = 2
offset = 0
y = offset .+ sin.(time_index * 2π / p) .+ growth .* time_index .+ rand(Normal(0, noise_scale), time_max)
plot(y)
# Generate splines to approximate the function
# Note: the boundary knots must lie slightly outside the data range; otherwise a row of X
# can be all zeros, which breaks the backward pass
X = Splines2.bs(time_index, df = 10, boundary_knots = (-0.01, 1.01));
# Build the model
@model function linreg(X, n, dim_x)
    T = Float64
    y = datavar(T, n)
    aux = randomvar(n)
    # Note: `sigma` is used as the noise precision (see NormalMeanPrecision below)
    sigma ~ GammaShapeRate(1.0, 1.0)
    intercept ~ NormalMeanVariance(0.0, 2.0)
    beta ~ MvNormalMeanPrecision(zeros(dim_x), diageye(dim_x))
    for i in 1:n
        aux[i] ~ intercept + dot(X[i, :], beta)
        y[i] ~ NormalMeanPrecision(aux[i], sigma)
    end
    return beta, aux, y
end
# Mean-field factorisation between the latent means and the noise precision
constraints = @constraints begin
    q(aux, sigma) = q(aux)q(sigma)
end
# Run inference
@time results = inference(
    model = Model(linreg, X, size(X)...),
    data = (y = y,),
    constraints = constraints,
    initmessages = (intercept = vague(NormalMeanVariance),),
    initmarginals = (sigma = GammaShapeRate(1.0, 1.0),),
    returnvars = (sigma = KeepLast(), beta = KeepLast(), aux = KeepLast()),
    iterations = 20,
    warn = true,
    free_energy = true
)
# Plot results
# Note: observe the divergence in the first 50 data points;
# it disappears as the number of iterations increases.
# The ribbon is 1/sqrt(E[sigma]), the noise standard deviation implied by the mean precision
plot(mean.(results.posteriors[:aux]), ribbon = (results.posteriors[:sigma] |> mean |> inv |> sqrt), label = "Fitted")
plot!(y, label = "Observed data")
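The note about boundary knots can be reproduced without Splines2 or ReactiveMP. The sketch below evaluates a clamped cubic B-spline basis with the Cox-de Boor recursion. Assuming Splines2.jl follows R's `splines::bs` in dropping the first basis column when `intercept = false` (its default, as far as I know), it shows that placing the boundary knots exactly at the data range leaves the first row of X all zeros, while padding them to (-0.01, 1.01) as in the MWE does not. The functions `bspline_basis`, `bs_matrix`, `zero_rows` and the interior knots `[0.25, 0.5, 0.75]` are hypothetical choices for illustration, not what `Splines2.bs(..., df = 10)` computes internally.

```julia
# Cox-de Boor recursion: i-th B-spline basis function of degree k on knot vector t, at x
function bspline_basis(i::Int, k::Int, t::Vector{Float64}, x::Float64)
    if k == 0
        # half-open intervals, except the last non-empty one is closed at the right boundary
        return (t[i] <= x < t[i + 1]) || (x == t[end] && t[i] < t[i + 1] == t[end]) ? 1.0 : 0.0
    end
    left = 0.0
    if t[i + k] != t[i]
        left = (x - t[i]) / (t[i + k] - t[i]) * bspline_basis(i, k - 1, t, x)
    end
    right = 0.0
    if t[i + k + 1] != t[i + 1]
        right = (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * bspline_basis(i + 1, k - 1, t, x)
    end
    return left + right
end

# Clamped knot vector with boundary knots (a, b); column 1 is dropped when intercept = false
function bs_matrix(x::Vector{Float64}, a::Float64, b::Float64, interior::Vector{Float64};
                   degree::Int = 3, intercept::Bool = false)
    t = vcat(fill(a, degree + 1), interior, fill(b, degree + 1))
    nbasis = length(t) - degree - 1
    B = [bspline_basis(j, degree, t, xi) for xi in x, j in 1:nbasis]
    return intercept ? B : B[:, 2:end]
end

# Indices of rows that contain only zeros
zero_rows(B) = findall(r -> all(iszero, B[r, :]), 1:size(B, 1))

x = collect(1:365) / 365          # same grid as time_index in the MWE
interior = [0.25, 0.5, 0.75]      # hypothetical interior knots, for illustration only

X_tight  = bs_matrix(x, x[1], x[end], interior)   # boundary knots at the data range
X_padded = bs_matrix(x, -0.01, 1.01, interior)    # padded boundary knots, as in the MWE

println(zero_rows(X_tight))   # [1]: the first observation gets an all-zero row
println(zero_rows(X_padded))  # empty: every row has a non-zero entry
```

The zero row appears because, at the left boundary knot, the first (dropped) basis function equals 1 and all others equal 0; moving the boundary knot outside the data avoids ever evaluating exactly at it.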