From 2b0e3043f3878b48f457b3e243445aa7dc0f3e1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hakkel=20Tam=C3=A1s?= Date: Thu, 30 Apr 2020 18:48:26 +0200 Subject: [PATCH] non-mutating function bug --- src/getPlan.jl | 43 +++++++++++++++++++++++-------------------- test/recycle.jl | 14 +++++++------- test/show.jl | 2 +- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/src/getPlan.jl b/src/getPlan.jl index 0e5100d..b1c5cd7 100644 --- a/src/getPlan.jl +++ b/src/getPlan.jl @@ -87,24 +87,24 @@ macro createReturnValue(left, right, op) rBuf = right == :buffer ? :buffer : :(rBuffer.buffer) name = left == :buffer ? :(buffer.name) : :(lBuffer.name) esc(quote - if FO.left.ismutating && !FO.right.ismutating - if $op == + - ((buffer, x) -> begin - inBuff.buffer .= x; - broadcast!(+, $lBuf, rFunc($rBuf, inBuff.buffer), lFunc($lBuf, inBuff.buffer)) - end, lBuffer, "($(inBuff.name) .= $inside; broadcast!(+, $($name), $rText, $lText))") - else - ((buffer, x) -> begin - inBuff.buffer .= x; - broadcast!((x,y) -> y-x, $lBuf, rFunc($rBuf, inBuff.buffer), lFunc($lBuf, inBuff.buffer)) - end, lBuffer, "($(inBuff.name) .= $inside; broadcast!((x,y) -> y-x, , $($name), $rText, $lText))") - end - else + #if FO.left.ismutating && !FO.right.ismutating + # if $op == + + # ((buffer, x) -> begin + # inBuff.buffer .= x; + # broadcast!(+, $lBuf, rFunc($rBuf, inBuff.buffer), lFunc($lBuf, inBuff.buffer)) + # end, lBuffer, "($(inBuff.name) .= $inside; broadcast!(+, $($name), $rText, $lText))") + # else + # ((buffer, x) -> begin + # inBuff.buffer .= x; + # broadcast!((x,y) -> y-x, $lBuf, rFunc($rBuf, inBuff.buffer), lFunc($lBuf, inBuff.buffer)) + # end, lBuffer, "($(inBuff.name) .= $inside; broadcast!((x,y) -> y-x, , $($name), $rText, $lText))") + # end + #else ((buffer, x) -> begin inBuff.buffer .= x; broadcast!($op, $lBuf, lFunc($lBuf, inBuff.buffer), rFunc($rBuf, inBuff.buffer)) end, lBuffer, "($(inBuff.name) .= $inside; broadcast!($op, $($name), $lText, $rText))") - end + #end end) end @@ -118,8 +118,8 @@ function getPlanAddSub(FO::FunctionOperatorComposite, buffer::Buffer, adjoint::B lFunc, lBuffer, lText = getPlan(FO.left, buffer, adjoint ⊻ FO.left.adjoint, inBuff.name, storage) lBuffer.available = false inBuff.available = true - rBuffer = FO.left.ismutating && FO.right.ismutating ? - newBuffer(eltype(FO), FO.outDims, storage) : lBuffer + rBuffer = #FO.left.ismutating && FO.right.ismutating ? + newBuffer(eltype(FO), FO.outDims, storage) #: lBuffer rFunc, rBuffer, rText = getPlan(FO.right, rBuffer, adjoint ⊻ FO.right.adjoint, inBuff.name, storage) lBuffer.available = buffer.available = true if rBuffer.name == buffer.name == lBuffer.name @@ -154,15 +154,18 @@ function getPlan(FO::FunctionOperator, buffer::Buffer, adjoint::Bool, inside::St # FunctionOperatorComposite object, but leaves intact its descendants. # Thus, sometimes inDims and outDims are already switched, but sometimes not, # that's why we need this awkward expression: adjoint != FO.adjoint ? FO.inDims : FO.outDims - size(buffer.buffer) != (adjoint != FO.adjoint ? FO.inDims : FO.outDims) && - (buffer = newBuffer(eltype(FO), (adjoint != FO.adjoint ? FO.inDims : FO.outDims), storage)) + (inDims, outDims) = adjoint != FO.adjoint ? (FO.outDims, FO.inDims) : (FO.inDims, FO.outDims) + size(buffer.buffer) != outDims && (buffer = newBuffer(eltype(FO), outDims, storage)) if checkTwoInputs(FO.forw) text = FO.scaling ? "broadcast!(*, $(buffer.name), $(adjoint ? conj(FO.getScale()) : FO.getScale()), $inside)" : FO.name*(adjoint ? ".backw" : ".forw")*"($(buffer.name), $inside)" (adjoint ? (b,x) -> begin FO.backw(b,x); b; end : (b,x) -> begin FO.forw(b,x); b; end, buffer, text) else - text = FO.name*(adjoint ? ".backw" : ".forw")*"($inside)" - (adjoint ? (b, x) -> reshape(FO.backw(x), FO.inDims) : (b, x) -> reshape(FO.forw(x), FO.outDims), buffer, text) + text = buffer.name*" .= "*FO.name*(adjoint ? ".backw" : ".forw")*"($inside)" + (adjoint ? + (b, x) -> b .= reshape(FO.backw(x), size(b)) : + (b, x) -> b .= reshape(FO.forw(x), size(b)), + buffer, text) end end \ No newline at end of file diff --git a/test/recycle.jl b/test/recycle.jl index 7f6a405..4c425e7 100644 --- a/test/recycle.jl +++ b/test/recycle.jl @@ -1,4 +1,4 @@ -using FunctionOperators, MacroTools, OffsetArrays, LinearAlgebra, Test, BenchmarkTools +using FunctionOperators, MacroTools, LinearAlgebra, Test, BenchmarkTools d = rand(10, 3, 2)im Ω = FunctionOperator{Complex{Float64}}(x -> x, x -> x, (10, 3, 2), (10, 3, 2)) @@ -50,8 +50,8 @@ function AL_2(d::Array{Complex{Float64}, 3}, # measurement data V₂ = zeros(Complex{Float64}, size(L)) Z_scaler = 1 ./ (reshape(samp, size(samp)..., 1) .+ δ₁) - cost_vec = OffsetVector{Float64}(undef, 0:N) - cost_vec[0] = cost(L, S, d, scale_L*λ_L, scale_S*λ_S) + cost_vec = Vector{Float64}(undef, N+1) + cost_vec[1] = cost(L, S, d, scale_L*λ_L, scale_S*λ_S) for k in 1:N Z = Z_scaler .* (Ω' * d + δ₁*(Q * C * X - V₁)) @@ -63,7 +63,7 @@ function AL_2(d::Array{Complex{Float64}, 3}, # measurement data V₂ += X - L - S end - cost_vec[k] = cost(L, S, d, scale_L*λ_L, scale_S*λ_S) + cost_vec[k+1] = cost(L, S, d, scale_L*λ_L, scale_S*λ_S) end L + S, cost_vec @@ -100,8 +100,8 @@ function AL_2_recycle(d::Array{Complex{Float64}, 3}, # measurement data V₂ = zeros(Complex{Float64}, size(L)) Z_scaler = 1 ./ (reshape(samp, size(samp)..., 1) .+ δ₁) - cost_vec = OffsetVector{Float64}(undef, 0:N) - cost_vec[0] = cost_recycle(L, S, d, scale_L*λ_L, scale_S*λ_S) + cost_vec = Vector{Float64}(undef, N+1) + cost_vec[1] = cost_recycle(L, S, d, scale_L*λ_L, scale_S*λ_S) @recycle for k in 1:N Z = Z_scaler .* (Ω' * d + δ₁*(Q * C * X - V₁)) @@ -113,7 +113,7 @@ function AL_2_recycle(d::Array{Complex{Float64}, 3}, # measurement data V₂ += X - L - S end - cost_vec[k] = cost_recycle(L, S, d, scale_L*λ_L, scale_S*λ_S) + cost_vec[k+1] = cost_recycle(L, S, d, scale_L*λ_L, scale_S*λ_S) end L + S, cost_vec diff --git a/test/show.jl b/test/show.jl index 9e08624..9132861 100644 --- a/test/show.jl +++ b/test/show.jl @@ -7,6 +7,6 @@ using FunctionOperators, Test @test sprint(show, Op₁ * Op₁) == "FunctionOperatorComposite{Float64}(Op₁ * Op₁, (10, 10), (10, 10), no plan)" combined = Op₁ * Op₁ combined * ones(10, 10) - @test sprint(show, combined) == "FunctionOperatorComposite{Float64}(Op₁ * Op₁, (10, 10), (10, 10), buffer1 .= Op₁.forw(buffer1 .= Op₁.forw(x)))" + @test sprint(show, combined) == "FunctionOperatorComposite{Float64}(Op₁ * Op₁, (10, 10), (10, 10), buffer0 .= Op₁.forw(buffer0 .= Op₁.forw(x)))" @test sprint(show, "text/plain", Op₁ * Op₁, context = :module=>@__MODULE__) == "FunctionOperatorComposite with eltype Float64\n Name: Op₁ * Op₁\n Input dimensions: (10, 10)\n Output dimensions: (10, 10)\n Plan: no plan" end