Update TypeParameterAccessors to v0.2
mtfishman committed Dec 17, 2024
1 parent 0efe21e · commit 6991ac8
Showing 9 changed files with 32 additions and 25 deletions.
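
Every change below is the same mechanical rename: TypeParameterAccessors v0.2 drops the singular accessors (`type_parameter`, `set_type_parameter`, `specify_type_parameter`, `default_type_parameter`) in favor of plural forms that take either a single position or a tuple of positions. A minimal sketch of the new call shapes, assuming the plural functions accept the same `(type, position)` arguments as the singular ones they replace (the results in comments are assumptions, not output recorded in this commit):

    using TypeParameterAccessors: set_type_parameters, type_parameters

    # Single-position calls mirror the old singular API:
    type_parameters(Vector{Float64}, eltype)               # -> Float64 (assumed)
    set_type_parameters(Vector{Float64}, eltype, Float32)  # -> Vector{Float32} (assumed)

    # A tuple of positions reads several parameters at once:
    type_parameters(Matrix{Float64}, (eltype, ndims))      # -> (Float64, 2) (assumed)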
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
@@ -93,7 +93,7 @@ StridedViews = "0.2.2, 0.3"
 TBLIS = "0.2"
 TimerOutputs = "0.5.5"
 TupleTools = "1.2.0"
-TypeParameterAccessors = "0.1"
+TypeParameterAccessors = "0.2"
 VectorInterface = "0.4.2, 0.5"
 cuTENSOR = "2"
 julia = "1.10"
10 changes: 4 additions & 6 deletions NDTensors/ext/NDTensorsAMDGPUExt/adapt.jl
@@ -2,16 +2,14 @@ using NDTensors: NDTensors, EmptyStorage, adapt_storagetype, emptytype
 using NDTensors.AMDGPUExtensions: AMDGPUExtensions, ROCArrayAdaptor
 using NDTensors.GPUArraysCoreExtensions: storagemode
 using NDTensors.TypeParameterAccessors:
-  default_type_parameter,
-  set_type_parameter,
-  set_type_parameters,
-  type_parameter,
-  type_parameters
+  default_type_parameters, set_type_parameters, type_parameters
 using Adapt: Adapt, adapt
 using AMDGPU: AMDGPU, ROCArray, ROCVector
 using Functors: fmap
 
-function AMDGPUExtensions.roc(xs; storagemode=default_type_parameter(ROCArray, storagemode))
+function AMDGPUExtensions.roc(
+  xs; storagemode=default_type_parameters(ROCArray, storagemode)
+)
   return fmap(x -> adapt(ROCArrayAdaptor{storagemode}(), x), xs)
 end
4 changes: 2 additions & 2 deletions NDTensors/ext/NDTensorsCUDAExt/adapt.jl
@@ -5,9 +5,9 @@ using NDTensors: NDTensors, EmptyStorage, adapt_storagetype, emptytype
 using NDTensors.CUDAExtensions: CUDAExtensions, CuArrayAdaptor
 using NDTensors.GPUArraysCoreExtensions: storagemode
 using NDTensors.TypeParameterAccessors:
-  default_type_parameter, set_type_parameters, type_parameters
+  default_type_parameters, set_type_parameters, type_parameters
 
-function CUDAExtensions.cu(xs; storagemode=default_type_parameter(CuArray, storagemode))
+function CUDAExtensions.cu(xs; storagemode=default_type_parameters(CuArray, storagemode))
   return fmap(x -> adapt(CuArrayAdaptor{storagemode}(), x), xs)
 end
11 changes: 7 additions & 4 deletions NDTensors/src/abstractarray/generic_array_constructors.jl
@@ -1,5 +1,8 @@
 using TypeParameterAccessors:
-  unwrap_array_type, specify_default_type_parameters, type_parameter
+  unwrap_array_type,
+  specify_default_type_parameters,
+  specify_type_parameters,
+  type_parameters
 
 # Convert to Array, avoiding copying if possible
 array(a::AbstractArray) = a
@@ -8,9 +11,9 @@ vector(a::AbstractVector) = a
 
 ## Warning: to use these functions it is necessary to define `TypeParameterAccessors.position(::Type{<:YourArrayType}, ::typeof(ndims))`
 # Implementation, catches if `ndims(arraytype) != length(dims)`.
-## TODO convert ndims to `type_parameter(::, typeof(ndims))`
+## TODO convert ndims to `type_parameters(::, typeof(ndims))`
 function generic_randn(arraytype::Type{<:AbstractArray}, dims...; rng=Random.default_rng())
-  arraytype_specified = specify_type_parameter(
+  arraytype_specified = specify_type_parameters(
     unwrap_array_type(arraytype), ndims, length(dims)
   )
   arraytype_specified = specify_default_type_parameters(arraytype_specified)
@@ -27,7 +30,7 @@ end
 
 # Implementation, catches if `ndims(arraytype) != length(dims)`.
 function generic_zeros(arraytype::Type{<:AbstractArray}, dims...)
-  arraytype_specified = specify_type_parameter(
+  arraytype_specified = specify_type_parameters(
     unwrap_array_type(arraytype), ndims, length(dims)
   )
   arraytype_specified = specify_default_type_parameters(arraytype_specified)
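
The warning in this file asks downstream array types to define `TypeParameterAccessors.position` for `ndims`; the `Exposed` change later in this commit shows the same pattern. A hedged sketch for a custom array type (the type name and its parameter layout are invented for illustration):

    using TypeParameterAccessors: TypeParameterAccessors, Position

    struct MyDiskArray{T,N} <: AbstractArray{T,N} end  # hypothetical type

    # Map the named positions `eltype` and `ndims` onto the type parameters so
    # that generic_randn/generic_zeros can call specify_type_parameters on it:
    TypeParameterAccessors.position(::Type{<:MyDiskArray}, ::typeof(eltype)) = Position(1)
    TypeParameterAccessors.position(::Type{<:MyDiskArray}, ::typeof(ndims)) = Position(2)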
6 changes: 3 additions & 3 deletions NDTensors/src/adapt.jl
@@ -27,11 +27,11 @@ double_precision(x) = fmap(x -> adapt(double_precision(eltype(x)), x), x)
 # Used to adapt `EmptyStorage` types
 #
 
-using TypeParameterAccessors: specify_type_parameter, specify_type_parameters
+using TypeParameterAccessors: specify_type_parameters
 function adapt_storagetype(to::Type{<:AbstractVector}, x::Type{<:TensorStorage})
-  return set_datatype(x, specify_type_parameter(to, eltype, eltype(x)))
+  return set_datatype(x, specify_type_parameters(to, eltype, eltype(x)))
 end
 
 function adapt_storagetype(to::Type{<:AbstractArray}, x::Type{<:TensorStorage})
-  return set_datatype(x, specify_type_parameter(to, (ndims, eltype), (1, eltype(x))))
+  return set_datatype(x, specify_type_parameters(to, (ndims, eltype), (1, eltype(x))))
 end
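
The second method above pins `ndims` to 1 and propagates the element type in one call. A minimal sketch of that tuple-position form on plain `Array` types, assuming `specify_type_parameters` fills unset parameters and leaves already-specified ones alone (results are assumptions):

    using TypeParameterAccessors: specify_type_parameters

    specify_type_parameters(Array, (ndims, eltype), (1, Float64))  # -> Vector{Float64} (assumed)
    specify_type_parameters(Vector{Float32}, eltype, Float64)      # -> Vector{Float32}, eltype already set (assumed)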
13 changes: 9 additions & 4 deletions NDTensors/src/dense/generic_array_constructors.jl
@@ -1,9 +1,10 @@
 using TypeParameterAccessors:
-  default_type_parameter,
+  default_type_parameters,
   parenttype,
   set_eltype,
   specify_default_type_parameters,
-  type_parameter
+  specify_type_parameters,
+  type_parameters
 ##TODO replace randn in ITensors with generic_randn
 ## and replace zeros with generic_zeros
 
@@ -12,7 +13,9 @@ using TypeParameterAccessors:
 
 function generic_randn(StoreT::Type{<:Dense}, dims::Integer; rng=Random.default_rng())
   StoreT = specify_default_type_parameters(StoreT)
-  DataT = specify_type_parameter(type_parameter(StoreT, parenttype), eltype, eltype(StoreT))
+  DataT = specify_type_parameters(
+    type_parameters(StoreT, parenttype), eltype, eltype(StoreT)
+  )
   @assert eltype(StoreT) == eltype(DataT)
 
   data = generic_randn(DataT, dims; rng=rng)
@@ -22,7 +25,9 @@ end
 
 function generic_zeros(StoreT::Type{<:Dense}, dims::Integer)
   StoreT = specify_default_type_parameters(StoreT)
-  DataT = specify_type_parameter(type_parameter(StoreT, parenttype), eltype, eltype(StoreT))
+  DataT = specify_type_parameters(
+    type_parameters(StoreT, parenttype), eltype, eltype(StoreT)
+  )
   @assert eltype(StoreT) == eltype(DataT)
 
   data = generic_zeros(DataT, dims)
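
For context on the `DataT` computation above: `type_parameters(StoreT, parenttype)` extracts the storage's data type, and the surrounding `specify_type_parameters(..., eltype, ...)` pins its element type. A hedged sketch of the intermediate values, assuming the `Dense{ElT,DataT}` parameter order (results are assumptions):

    using NDTensors: Dense
    using TypeParameterAccessors: parenttype, specify_type_parameters, type_parameters

    StoreT = Dense{Float64,Vector{Float64}}
    DataT = type_parameters(StoreT, parenttype)             # -> Vector{Float64} (assumed)
    specify_type_parameters(DataT, eltype, eltype(StoreT))  # -> Vector{Float64} (assumed)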
4 changes: 2 additions & 2 deletions NDTensors/src/lib/Expose/src/exposed.jl
@@ -1,5 +1,5 @@
 using TypeParameterAccessors:
-  TypeParameterAccessors, unwrap_array_type, parameter, parenttype, type_parameter
+  TypeParameterAccessors, unwrap_array_type, parenttype, type_parameters
 struct Exposed{Unwrapped,Object}
   object::Object
 end
@@ -9,7 +9,7 @@ expose(object) = Exposed{unwrap_array_type(object),typeof(object)}(object)
 unexpose(E::Exposed) = E.object
 
 ## TODO remove TypeParameterAccessors when SetParameters is removed
-TypeParameterAccessors.parenttype(type::Type{<:Exposed}) = parameter(type, parenttype)
+TypeParameterAccessors.parenttype(type::Type{<:Exposed}) = type_parameters(type, parenttype)
 function TypeParameterAccessors.position(::Type{<:Exposed}, ::typeof(parenttype))
   return TypeParameterAccessors.Position(1)
 end
6 changes: 3 additions & 3 deletions
@@ -1,15 +1,15 @@
 using ..Expose: Exposed, unexpose
-using TypeParameterAccessors: TypeParameterAccessors, type_parameter, set_type_parameter
+using TypeParameterAccessors: TypeParameterAccessors, type_parameters, set_type_parameters
 
 function storagemode(object)
   return storagemode(typeof(object))
 end
 function storagemode(type::Type)
-  return type_parameter(type, storagemode)
+  return type_parameters(type, storagemode)
 end
 
 function set_storagemode(type::Type, param)
-  return set_type_parameter(type, storagemode, param)
+  return set_type_parameters(type, storagemode, param)
 end
 
 function cpu end
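
`storagemode` and `set_storagemode` above are thin wrappers over the renamed accessors. A hypothetical illustration with an invented adaptor type whose first parameter is its storage mode, mirroring the `CuArrayAdaptor`/`ROCArrayAdaptor` pattern used elsewhere in this commit (results in comments are assumptions):

    using TypeParameterAccessors: TypeParameterAccessors, Position

    struct FakeAdaptor{StorageMode} end  # hypothetical

    # Tell TypeParameterAccessors which parameter `storagemode` refers to:
    function TypeParameterAccessors.position(::Type{<:FakeAdaptor}, ::typeof(storagemode))
      return Position(1)
    end

    storagemode(FakeAdaptor{:unified})               # -> :unified (assumed)
    set_storagemode(FakeAdaptor{:unified}, :device)  # -> FakeAdaptor{:device} (assumed)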
1 change: 1 addition & 0 deletions NDTensors/test/Project.toml
@@ -20,6 +20,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
 TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
