Skip to content

Commit

Permalink
non-mutating function bug
Browse files Browse the repository at this point in the history
  • Loading branch information
hakkelt committed Apr 30, 2020
1 parent e072fe1 commit 2b0e304
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 28 deletions.
43 changes: 23 additions & 20 deletions src/getPlan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,24 @@ macro createReturnValue(left, right, op)
rBuf = right == :buffer ? :buffer : :(rBuffer.buffer)
name = left == :buffer ? :(buffer.name) : :(lBuffer.name)
esc(quote
if FO.left.ismutating && !FO.right.ismutating
if $op == +
((buffer, x) -> begin
inBuff.buffer .= x;
broadcast!(+, $lBuf, rFunc($rBuf, inBuff.buffer), lFunc($lBuf, inBuff.buffer))
end, lBuffer, "($(inBuff.name) .= $inside; broadcast!(+, $($name), $rText, $lText))")
else
((buffer, x) -> begin
inBuff.buffer .= x;
broadcast!((x,y) -> y-x, $lBuf, rFunc($rBuf, inBuff.buffer), lFunc($lBuf, inBuff.buffer))
end, lBuffer, "($(inBuff.name) .= $inside; broadcast!((x,y) -> y-x, , $($name), $rText, $lText))")
end
else
#if FO.left.ismutating && !FO.right.ismutating
# if $op == +
# ((buffer, x) -> begin
# inBuff.buffer .= x;
# broadcast!(+, $lBuf, rFunc($rBuf, inBuff.buffer), lFunc($lBuf, inBuff.buffer))
# end, lBuffer, "($(inBuff.name) .= $inside; broadcast!(+, $($name), $rText, $lText))")
# else
# ((buffer, x) -> begin
# inBuff.buffer .= x;
# broadcast!((x,y) -> y-x, $lBuf, rFunc($rBuf, inBuff.buffer), lFunc($lBuf, inBuff.buffer))
# end, lBuffer, "($(inBuff.name) .= $inside; broadcast!((x,y) -> y-x, , $($name), $rText, $lText))")
# end
#else
((buffer, x) -> begin
inBuff.buffer .= x;
broadcast!($op, $lBuf, lFunc($lBuf, inBuff.buffer), rFunc($rBuf, inBuff.buffer))
end, lBuffer, "($(inBuff.name) .= $inside; broadcast!($op, $($name), $lText, $rText))")
end
#end
end)
end

Expand All @@ -118,8 +118,8 @@ function getPlanAddSub(FO::FunctionOperatorComposite, buffer::Buffer, adjoint::B
lFunc, lBuffer, lText = getPlan(FO.left, buffer, adjoint FO.left.adjoint, inBuff.name, storage)
lBuffer.available = false
inBuff.available = true
rBuffer = FO.left.ismutating && FO.right.ismutating ?
newBuffer(eltype(FO), FO.outDims, storage) : lBuffer
rBuffer = #FO.left.ismutating && FO.right.ismutating ?
newBuffer(eltype(FO), FO.outDims, storage) #: lBuffer
rFunc, rBuffer, rText = getPlan(FO.right, rBuffer, adjoint FO.right.adjoint, inBuff.name, storage)
lBuffer.available = buffer.available = true
if rBuffer.name == buffer.name == lBuffer.name
Expand Down Expand Up @@ -154,15 +154,18 @@ function getPlan(FO::FunctionOperator, buffer::Buffer, adjoint::Bool, inside::St
# FunctionOperatorComposite object, but leaves intact its descendants.
# Thus, sometimes inDims and outDims are already switched, but sometimes not,
# that's why we need this awkward expression: adjoint != FO.adjoint ? FO.inDims : FO.outDims
size(buffer.buffer) != (adjoint != FO.adjoint ? FO.inDims : FO.outDims) &&
(buffer = newBuffer(eltype(FO), (adjoint != FO.adjoint ? FO.inDims : FO.outDims), storage))
(inDims, outDims) = adjoint != FO.adjoint ? (FO.outDims, FO.inDims) : (FO.inDims, FO.outDims)
size(buffer.buffer) != outDims && (buffer = newBuffer(eltype(FO), outDims, storage))
if checkTwoInputs(FO.forw)
text = FO.scaling ?
"broadcast!(*, $(buffer.name), $(adjoint ? conj(FO.getScale()) : FO.getScale()), $inside)" :
FO.name*(adjoint ? ".backw" : ".forw")*"($(buffer.name), $inside)"
(adjoint ? (b,x) -> begin FO.backw(b,x); b; end : (b,x) -> begin FO.forw(b,x); b; end, buffer, text)
else
text = FO.name*(adjoint ? ".backw" : ".forw")*"($inside)"
(adjoint ? (b, x) -> reshape(FO.backw(x), FO.inDims) : (b, x) -> reshape(FO.forw(x), FO.outDims), buffer, text)
text = buffer.name*" .= "*FO.name*(adjoint ? ".backw" : ".forw")*"($inside)"
(adjoint ?
(b, x) -> b .= reshape(FO.backw(x), size(b)) :
(b, x) -> b .= reshape(FO.forw(x), size(b)),
buffer, text)
end
end
14 changes: 7 additions & 7 deletions test/recycle.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using FunctionOperators, MacroTools, OffsetArrays, LinearAlgebra, Test, BenchmarkTools
using FunctionOperators, MacroTools, LinearAlgebra, Test, BenchmarkTools

d = rand(10, 3, 2)im
Ω = FunctionOperator{Complex{Float64}}(x -> x, x -> x, (10, 3, 2), (10, 3, 2))
Expand Down Expand Up @@ -50,8 +50,8 @@ function AL_2(d::Array{Complex{Float64}, 3}, # measurement data
V₂ = zeros(Complex{Float64}, size(L))
Z_scaler = 1 ./ (reshape(samp, size(samp)..., 1) .+ δ₁)

cost_vec = OffsetVector{Float64}(undef, 0:N)
cost_vec[0] = cost(L, S, d, scale_L*λ_L, scale_S*λ_S)
cost_vec = Vector{Float64}(undef, N+1)
cost_vec[1] = cost(L, S, d, scale_L*λ_L, scale_S*λ_S)

for k in 1:N
Z = Z_scaler .*' * d + δ₁*(Q * C * X - V₁))
Expand All @@ -63,7 +63,7 @@ function AL_2(d::Array{Complex{Float64}, 3}, # measurement data
V₂ += X - L - S
end

cost_vec[k] = cost(L, S, d, scale_L*λ_L, scale_S*λ_S)
cost_vec[k+1] = cost(L, S, d, scale_L*λ_L, scale_S*λ_S)
end

L + S, cost_vec
Expand Down Expand Up @@ -100,8 +100,8 @@ function AL_2_recycle(d::Array{Complex{Float64}, 3}, # measurement data
V₂ = zeros(Complex{Float64}, size(L))
Z_scaler = 1 ./ (reshape(samp, size(samp)..., 1) .+ δ₁)

cost_vec = OffsetVector{Float64}(undef, 0:N)
cost_vec[0] = cost_recycle(L, S, d, scale_L*λ_L, scale_S*λ_S)
cost_vec = Vector{Float64}(undef, N+1)
cost_vec[1] = cost_recycle(L, S, d, scale_L*λ_L, scale_S*λ_S)

@recycle for k in 1:N
Z = Z_scaler .*' * d + δ₁*(Q * C * X - V₁))
Expand All @@ -113,7 +113,7 @@ function AL_2_recycle(d::Array{Complex{Float64}, 3}, # measurement data
V₂ += X - L - S
end

cost_vec[k] = cost_recycle(L, S, d, scale_L*λ_L, scale_S*λ_S)
cost_vec[k+1] = cost_recycle(L, S, d, scale_L*λ_L, scale_S*λ_S)
end

L + S, cost_vec
Expand Down
2 changes: 1 addition & 1 deletion test/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ using FunctionOperators, Test
@test sprint(show, Op₁ * Op₁) == "FunctionOperatorComposite{Float64}(Op₁ * Op₁, (10, 10), (10, 10), no plan)"
combined = Op₁ * Op₁
combined * ones(10, 10)
@test sprint(show, combined) == "FunctionOperatorComposite{Float64}(Op₁ * Op₁, (10, 10), (10, 10), buffer1 .= Op₁.forw(buffer1 .= Op₁.forw(x)))"
@test sprint(show, combined) == "FunctionOperatorComposite{Float64}(Op₁ * Op₁, (10, 10), (10, 10), buffer0 .= Op₁.forw(buffer0 .= Op₁.forw(x)))"
@test sprint(show, "text/plain", Op₁ * Op₁, context = :module=>@__MODULE__) == "FunctionOperatorComposite with eltype Float64\n Name: Op₁ * Op₁\n Input dimensions: (10, 10)\n Output dimensions: (10, 10)\n Plan: no plan"
end

0 comments on commit 2b0e304

Please sign in to comment.