Skip to content

Commit 86884cf

Browse files
authored
Merge pull request #15 from JuliaImageRecon/nh/normalOp
Add parent-type as parameter of normal operator
2 parents e54bba2 + 8177479 commit 86884cf

File tree

8 files changed

+37
-20
lines changed

8 files changed

+37
-20
lines changed

.github/workflows/Breakage.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
with:
2929
version: 1
3030
arch: x64
31-
- uses: actions/cache@v1
31+
- uses: actions/cache@v4
3232
env:
3333
cache-name: cache-artifacts
3434
with:

ext/LinearOperatorNFFTExt/NFFTOp.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1::matT, xL2::m
128128
, shape, W, fftplan, ifftplan, λ, xL1, xL2)
129129
end
130130

131-
function NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=opEye(eltype(nfft), size(nfft, 1), S= LinearOperators.storage_type(nfft)); kwargs...) where {T}
131+
function NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=nothing; kwargs...) where {T}
132132
shape = nfft.plan.N
133133

134134
tmpVec = similar(nfft.Mv5, (2 .* shape)...)
@@ -147,7 +147,13 @@ function NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=opEye(eltype(nfft), size(nfft,
147147
precompute=NFFT.POLYNOMIAL, fftflags=FFTW.ESTIMATE, blocking=true)
148148
tmpOnes = similar(tmpVec, size(nfft.plan.k, 2))
149149
tmpOnes .= one(T)
150-
eigMat = adjoint(p) * ( W * tmpOnes)
150+
151+
if !isnothing(W)
152+
eigMat = adjoint(p) * ( W * tmpOnes)
153+
else
154+
eigMat = adjoint(p) * (tmpOnes)
155+
end
156+
151157
λ = fftplan * fftshift(eigMat)
152158

153159
xL1 = tmpVec
@@ -156,7 +162,7 @@ function NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=opEye(eltype(nfft), size(nfft,
156162
return NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2)
157163
end
158164

159-
function LinearOperatorCollection.normalOperator(S::NFFTOpImpl{T}, W = opEye(eltype(S), size(S, 1), S= LinearOperators.storage_type(S)); copyOpsFn = copy, kwargs...) where T
165+
function LinearOperatorCollection.normalOperator(S::NFFTOpImpl{T}, W = nothing; copyOpsFn = copy, kwargs...) where T
160166
if S.toeplitz
161167
return NFFTToeplitzNormalOp(S,W; kwargs...)
162168
else

src/DiagOp.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,18 @@ function diagNormOpProd!(y, normalOps, idx, x)
134134
return y
135135
end
136136

137-
function LinearOperatorCollection.normalOperator(diag::DiagOp, W=opEye(eltype(diag), size(diag,1), S = LinearOperators.storage_type(diag)); copyOpsFn = copy, kwargs...)
138-
T = promote_type(eltype(diag), eltype(W))
139-
S = promote_type(LinearOperators.storage_type(diag), LinearOperators.storage_type(W))
137+
function LinearOperatorCollection.normalOperator(diag::DiagOp, W=nothing; copyOpsFn = copy, kwargs...)
138+
if !isnothing(W)
139+
T = promote_type(eltype(diag), eltype(W))
140+
S = promote_type(LinearOperators.storage_type(diag), LinearOperators.storage_type(W))
141+
else
142+
T = eltype(diag)
143+
S = LinearOperators.storage_type(diag)
144+
end
140145
isconcretetype(S) || throw(LinearOperatorException("Storage types cannot be promoted to a concrete type"))
141146
tmp = S(undef, diag.nrow)
142147
tmp .= one(eltype(diag))
143-
weights = W*tmp
148+
weights = isnothing(W) ? tmp : W * tmp
144149

145150

146151
if diag.equalOps

src/LinearOperatorCollection.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ abstract type DCTOp{T} <: AbstractLinearOperatorFromCollection{T} end
3939
abstract type DSTOp{T} <: AbstractLinearOperatorFromCollection{T} end
4040
abstract type NFFTOp{T} <: AbstractLinearOperatorFromCollection{T} end
4141
abstract type SamplingOp{T} <: AbstractLinearOperatorFromCollection{T} end
42-
abstract type NormalOp{T} <: AbstractLinearOperatorFromCollection{T} end
42+
abstract type NormalOp{T,S} <: AbstractLinearOperatorFromCollection{T} end
4343
abstract type GradientOp{T} <: AbstractLinearOperatorFromCollection{T} end
4444
abstract type RadonOp{T} <: AbstractLinearOperatorFromCollection{T} end
4545

src/NormalOp.jl

+14-8
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,15 @@ Computes `adjoint(parent) * weights * parent`.
1616
* `weights` - Optional weights for normal operator. Must already be of form `weights = adjoint.(w) .* w`
1717
1818
"""
19-
function LinearOperatorCollection.NormalOp(::Type{T}; parent, weights = opEye(eltype(parent), size(parent, 1), S = storage_type(parent))) where T <: Number
19+
function LinearOperatorCollection.NormalOp(::Type{T}; parent, weights = nothing) where T <: Number
2020
return NormalOp(T, parent, weights)
2121
end
2222

23-
function NormalOp(::Type{T}, parent, ::Nothing) where T
24-
weights = opEye(eltype(parent), size(parent, 1), S = storage_type(parent))
25-
return NormalOp(T, parent, weights)
26-
end
2723
NormalOp(::Union{Type{T}, Type{Complex{T}}}, parent, weights::AbstractVector{T}) where T = NormalOp(T, parent, WeightingOp(weights))
2824

29-
NormalOp(::Union{Type{T}, Type{Complex{T}}}, parent, weights::AbstractLinearOperator{T}; kwargs...) where T = NormalOpImpl(parent, weights)
25+
NormalOp(::Union{Type{T}, Type{Complex{T}}}, parent, weights; kwargs...) where T = NormalOpImpl(parent, weights)
3026

31-
mutable struct NormalOpImpl{T,S,D,V} <: NormalOp{T}
27+
mutable struct NormalOpImpl{T,S,D,V} <: NormalOp{T, S}
3228
nrow :: Int
3329
ncol :: Int
3430
symmetric :: Bool
@@ -56,13 +52,23 @@ function NormalOpImpl(parent, weights)
5652
tmp = S(undef, size(parent, 1))
5753
return NormalOpImpl(parent, weights, tmp)
5854
end
55+
function NormalOpImpl(parent, weights::Nothing)
56+
S = storage_type(parent)
57+
tmp = S(undef, size(parent, 1))
58+
return NormalOpImpl(parent, weights, tmp)
59+
end
5960

6061
function NormalOpImpl(parent, weights, tmp)
6162
function produ!(y, parent, weights, tmp, x)
6263
mul!(tmp, parent, x)
6364
mul!(tmp, weights, tmp) # This can be dangerous. We might need to create two tmp vectors
6465
return mul!(y, adjoint(parent), tmp)
6566
end
67+
function produ!(y, parent, weights::Nothing, tmp, x)
68+
mul!(tmp, parent, x)
69+
return mul!(y, adjoint(parent), tmp)
70+
end
71+
6672

6773
return NormalOpImpl{eltype(parent), typeof(parent), typeof(weights), typeof(tmp)}(size(parent,2), size(parent,2), false, false
6874
, (res,x) -> produ!(res, parent, weights, tmp, x)
@@ -81,6 +87,6 @@ end
8187
8288
Constructs a normal operator of the parent in an opinionated way, i.e. it tries to apply optimisations to the resulting operator.
8389
"""
84-
function normalOperator(parent, weights=opEye(eltype(parent), size(parent, 1), S= storage_type(parent)); kwargs...)
90+
function normalOperator(parent, weights=nothing; kwargs...)
8591
return NormalOp(eltype(storage_type((parent))); parent = parent, weights = weights)
8692
end

src/ProdOp.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ end
126126
Fuses weights of `ẀeightingOp` by computing `adjoint.(weights) .* weights`
127127
"""
128128
normalOperator(S::ProdOp{T, <:WeightingOp, matT}; kwargs...) where {T, matT} = normalOperator(S.B, WeightingOp(adjoint.(S.A.weights) .* S.A.weights); kwargs...)
129-
function normalOperator(S::ProdOp, W=opEye(eltype(S),size(S,1), S = storage_type(S)); kwargs...)
129+
function normalOperator(S::ProdOp, W=nothing; kwargs...)
130130
arrayType = storage_type(S)
131131
tmp = arrayType(undef, size(S.A, 2))
132132
return ProdNormalOp(S.B, normalOperator(S.A, W; kwargs...), tmp)

test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using RadonKA
99
using JLArrays
1010

1111
areTypesDefined = @isdefined arrayTypes
12-
arrayTypes = areTypesDefined ? arrayTypes : [Array, JLArray]
12+
arrayTypes = areTypesDefined ? arrayTypes : [Array] #, JLArray]
1313

1414
@testset "LinearOperatorCollection" begin
1515
include("testNormalOp.jl")

test/testOperators.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ function testDiagOp(N=32,K=2;arrayType = Array)
378378

379379
@testset "Weighted Diag Normal" begin
380380
w = rand(eltype(op1), size(op1, 1))
381-
wop = WeightingOp(w)
381+
wop = WeightingOp(arrayType(w))
382382
prod1 = ProdOp(wop, op1)
383383
prod2 = ProdOp(wop, op2)
384384
prod3 = ProdOp(wop, op3)

0 commit comments

Comments
 (0)