Skip to content

Commit

Permalink
[TypeParameterAccessors] similartype
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 1, 2024
1 parent 8bb156a commit 11a5c98
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 66 deletions.
10 changes: 0 additions & 10 deletions NDTensors/src/abstractarray/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,3 @@ TODO: Use `Accessors.jl` notation:
function TypeParameterAccessors.set_ndims(numbertype::Type{<:Number}, ndims)
return numbertype
end

"""
`set_indstype` should be overloaded for
types with structured dimensions,
like `OffsetArrays` or named indices
(such as ITensors).
"""
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
return set_ndims(arraytype, length(dims))
end
56 changes: 1 addition & 55 deletions NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Base: DimOrInd, Dims, OneTo
using .TypeParameterAccessors: IsWrappedArray, unwrap_array_type, set_eltype

## Custom `NDTensors.similar` implementation.
Expand Down Expand Up @@ -96,58 +97,3 @@ end
# Use the `size` to determine the dimensions
# NDTensors.similar
similar(array::AbstractArray) = NDTensors.similar(typeof(array), size(array))

## similartype

function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple)
return similartype(similartype(arraytype, eltype), dims)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_eltype(arraytype, eltype)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, dims::Tuple
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_indstype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray}, dims::DimOrInd...)
return similartype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray})
return similartype(arraytype, eltype(arraytype))
end

## Wrapped arrays
@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), eltype)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, dims::Tuple
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), dims)
end

# This is for uniform `Diag` storage which uses
# a Number as the data type.
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function similartype(numbertype::Type{<:Number})
return numbertype
end

# Instances
function similartype(array::AbstractArray, eltype::Type, dims...)
return similartype(typeof(array), eltype, dims...)
end
similartype(array::AbstractArray, eltype::Type) = similartype(typeof(array), eltype)
similartype(array::AbstractArray, dims...) = similartype(typeof(array), dims...)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include("set_parameters.jl")
include("specify_parameters.jl")
include("default_parameters.jl")
include("base/abstractarray.jl")
include("base/similartype.jl")
include("base/array.jl")
include("base/linearalgebra.jl")
include("base/stridedviews.jl")
Expand Down
62 changes: 62 additions & 0 deletions NDTensors/src/lib/TypeParameterAccessors/src/base/similartype.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
`set_indstype` should be overloaded for
types with structured dimensions,
like `OffsetArrays` or named indices
(such as ITensors).
"""
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
return set_ndims(arraytype, length(dims))
end

function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple)
return similartype(similartype(arraytype, eltype), dims)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_eltype(arraytype, eltype)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, dims::Tuple
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_indstype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray}, dims::Base.DimOrInd...)
return similartype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray})
return similartype(arraytype, eltype(arraytype))
end

## Wrapped arrays
@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), eltype)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, dims::Tuple
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), dims)
end

# This is for uniform `Diag` storage which uses
# a Number as the data type.
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function similartype(numbertype::Type{<:Number})
return numbertype
end

# Instances
function similartype(array::AbstractArray, eltype::Type, dims...)
return similartype(typeof(array), eltype, dims...)
end
similartype(array::AbstractArray, eltype::Type) = similartype(typeof(array), eltype)
similartype(array::AbstractArray, dims...) = similartype(typeof(array), dims...)
1 change: 1 addition & 0 deletions NDTensors/src/lib/TypeParameterAccessors/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ using Test: @testset
include("test_defaults.jl")
include("test_custom_types.jl")
include("test_wrappers.jl")
include("test_similartype.jl")
end
end
15 changes: 15 additions & 0 deletions NDTensors/src/lib/TypeParameterAccessors/test/test_similartype.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@eval module $(gensym())
using Test: @test, @test_broken, @testset
using LinearAlgebra: Adjoint
using NDTensors.TypeParameterAccessors: similartype
@testset "TypeParameterAccessors similartype" begin
@test similartype(Array, Float64, (2, 2)) == Matrix{Float64}
# TODO: Is this a good definition? Probably it should be left unspecified.
@test similartype(Array) == Array{Any}
@test similartype(Array, Float64) == Array{Float64}
@test similartype(Array, (2, 2)) == Matrix
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, (2, 2, 2)) ==
Array{Float64,3}
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64) == Matrix{Float64}
end
end
2 changes: 1 addition & 1 deletion NDTensors/src/tensor/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end

# TODO: Modify the `storagetype` according to `inds`, such as the dimensions?
# TODO: Make a version that accepts `indstype::Type`?
function set_indstype(tensortype::Type{<:Tensor}, inds::Tuple)
function TypeParameterAccessors.set_indstype(tensortype::Type{<:Tensor}, inds::Tuple)
return Tensor{eltype(tensortype),length(inds),storagetype(tensortype),typeof(inds)}
end

Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/tensor/similar.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .TypeParameterAccessors: set_indstype

# NDTensors.similar
similar(tensor::Tensor) = setstorage(tensor, similar(storage(tensor)))

Expand Down

0 comments on commit 11a5c98

Please sign in to comment.