From 6e2df03415e270e56ae72dded9c8bec224fa09bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 28 Nov 2023 16:28:09 +0100 Subject: [PATCH] Improve type stability of expectation --- src/expectation.jl | 4 ++-- test/expectation.jl | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/expectation.jl b/src/expectation.jl index d0611e5..65a9fce 100644 --- a/src/expectation.jl +++ b/src/expectation.jl @@ -1,6 +1,6 @@ -function _expectation(μ::Measure, p::_APL, f) +function _expectation(μ::Measure{S}, p::_APL{T}, f) where {S,T} i = 1 - s = 0 + s = zero(MA.promote_operation(*, S, T)) for t in MP.terms(p) while i <= length(μ.x) && MP.monomial(t) != μ.x[i] i += 1 diff --git a/test/expectation.jl b/test/expectation.jl index 1c185a6..cdc8e73 100644 --- a/test/expectation.jl +++ b/test/expectation.jl @@ -3,10 +3,10 @@ p = x[3] - 2x[1] * x[2]^2 + 3x[3] * x[1] - 5x[1]^3 v = (1, 2, 3) m = dirac(monomials(p), x => v) - @test MultivariateMoments.expectation(m, p) == + @test (@inferred MultivariateMoments.expectation(m, p)) == p(x => v) == - MultivariateMoments.expectation(p, m) + (@inferred MultivariateMoments.expectation(p, m)) @test_throws ErrorException dot(x[1] * x[2] * x[3], m) - @test dot(0.5 * x[1] * x[2]^2, m) == 2.0 - @test dot(m, x[1] * x[3]) == 3 + @test (@inferred dot(0.5 * x[1] * x[2]^2, m)) == 2.0 + @test (@inferred dot(m, x[1] * x[3])) == 3 end