diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index ef3023e0b2..e933b479be 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -93,7 +93,7 @@ StridedViews = "0.2.2, 0.3" TBLIS = "0.2" TimerOutputs = "0.5.5" TupleTools = "1.2.0" -TypeParameterAccessors = "0.1" +TypeParameterAccessors = "0.2" VectorInterface = "0.4.2, 0.5" cuTENSOR = "2" julia = "1.10" diff --git a/NDTensors/ext/NDTensorsAMDGPUExt/adapt.jl b/NDTensors/ext/NDTensorsAMDGPUExt/adapt.jl index 8ef943d674..929790f165 100644 --- a/NDTensors/ext/NDTensorsAMDGPUExt/adapt.jl +++ b/NDTensors/ext/NDTensorsAMDGPUExt/adapt.jl @@ -2,16 +2,14 @@ using NDTensors: NDTensors, EmptyStorage, adapt_storagetype, emptytype using NDTensors.AMDGPUExtensions: AMDGPUExtensions, ROCArrayAdaptor using NDTensors.GPUArraysCoreExtensions: storagemode using NDTensors.TypeParameterAccessors: - default_type_parameter, - set_type_parameter, - set_type_parameters, - type_parameter, - type_parameters + default_type_parameters, set_type_parameters, type_parameters using Adapt: Adapt, adapt using AMDGPU: AMDGPU, ROCArray, ROCVector using Functors: fmap -function AMDGPUExtensions.roc(xs; storagemode=default_type_parameter(ROCArray, storagemode)) +function AMDGPUExtensions.roc( + xs; storagemode=default_type_parameters(ROCArray, storagemode) +) return fmap(x -> adapt(ROCArrayAdaptor{storagemode}(), x), xs) end diff --git a/NDTensors/ext/NDTensorsCUDAExt/adapt.jl b/NDTensors/ext/NDTensorsCUDAExt/adapt.jl index c47a9408be..02e8a41b4b 100644 --- a/NDTensors/ext/NDTensorsCUDAExt/adapt.jl +++ b/NDTensors/ext/NDTensorsCUDAExt/adapt.jl @@ -5,9 +5,9 @@ using NDTensors: NDTensors, EmptyStorage, adapt_storagetype, emptytype using NDTensors.CUDAExtensions: CUDAExtensions, CuArrayAdaptor using NDTensors.GPUArraysCoreExtensions: storagemode using NDTensors.TypeParameterAccessors: - default_type_parameter, set_type_parameters, type_parameters + default_type_parameters, set_type_parameters, type_parameters -function CUDAExtensions.cu(xs; storagemode=default_type_parameter(CuArray, storagemode)) +function CUDAExtensions.cu(xs; storagemode=default_type_parameters(CuArray, storagemode)) return fmap(x -> adapt(CuArrayAdaptor{storagemode}(), x), xs) end diff --git a/NDTensors/src/abstractarray/generic_array_constructors.jl b/NDTensors/src/abstractarray/generic_array_constructors.jl index ff14298e33..8912c81efa 100644 --- a/NDTensors/src/abstractarray/generic_array_constructors.jl +++ b/NDTensors/src/abstractarray/generic_array_constructors.jl @@ -1,5 +1,8 @@ using TypeParameterAccessors: - unwrap_array_type, specify_default_type_parameters, type_parameter + unwrap_array_type, + specify_default_type_parameters, + specify_type_parameters, + type_parameters # Convert to Array, avoiding copying if possible array(a::AbstractArray) = a @@ -8,9 +11,9 @@ vector(a::AbstractVector) = a ## Warning to use these functions it is necessary to define `TypeParameterAccessors.position(::Type{<:YourArrayType}, ::typeof(ndims)))` # Implementation, catches if `ndims(arraytype) != length(dims)`. -## TODO convert ndims to `type_parameter(::, typeof(ndims))` +## TODO convert ndims to `type_parameters(::, typeof(ndims))` function generic_randn(arraytype::Type{<:AbstractArray}, dims...; rng=Random.default_rng()) - arraytype_specified = specify_type_parameter( + arraytype_specified = specify_type_parameters( unwrap_array_type(arraytype), ndims, length(dims) ) arraytype_specified = specify_default_type_parameters(arraytype_specified) @@ -27,7 +30,7 @@ end # Implementation, catches if `ndims(arraytype) != length(dims)`. function generic_zeros(arraytype::Type{<:AbstractArray}, dims...) - arraytype_specified = specify_type_parameter( + arraytype_specified = specify_type_parameters( unwrap_array_type(arraytype), ndims, length(dims) ) arraytype_specified = specify_default_type_parameters(arraytype_specified) diff --git a/NDTensors/src/adapt.jl b/NDTensors/src/adapt.jl index b40abe692c..fb68b40186 100644 --- a/NDTensors/src/adapt.jl +++ b/NDTensors/src/adapt.jl @@ -27,11 +27,11 @@ double_precision(x) = fmap(x -> adapt(double_precision(eltype(x)), x), x) # Used to adapt `EmptyStorage` types # -using TypeParameterAccessors: specify_type_parameter, specify_type_parameters +using TypeParameterAccessors: specify_type_parameters function adapt_storagetype(to::Type{<:AbstractVector}, x::Type{<:TensorStorage}) - return set_datatype(x, specify_type_parameter(to, eltype, eltype(x))) + return set_datatype(x, specify_type_parameters(to, eltype, eltype(x))) end function adapt_storagetype(to::Type{<:AbstractArray}, x::Type{<:TensorStorage}) - return set_datatype(x, specify_type_parameter(to, (ndims, eltype), (1, eltype(x)))) + return set_datatype(x, specify_type_parameters(to, (ndims, eltype), (1, eltype(x)))) end diff --git a/NDTensors/src/dense/generic_array_constructors.jl b/NDTensors/src/dense/generic_array_constructors.jl index 685a6b0a2a..c9ef550e63 100644 --- a/NDTensors/src/dense/generic_array_constructors.jl +++ b/NDTensors/src/dense/generic_array_constructors.jl @@ -1,9 +1,10 @@ using TypeParameterAccessors: - default_type_parameter, + default_type_parameters, parenttype, set_eltype, specify_default_type_parameters, - type_parameter + specify_type_parameters, + type_parameters ##TODO replace randn in ITensors with generic_randn ## and replace zeros with generic_zeros @@ -12,7 +13,9 @@ using TypeParameterAccessors: function generic_randn(StoreT::Type{<:Dense}, dims::Integer; rng=Random.default_rng()) StoreT = specify_default_type_parameters(StoreT) - DataT = specify_type_parameter(type_parameter(StoreT, parenttype), eltype, eltype(StoreT)) + DataT = specify_type_parameters( + type_parameters(StoreT, parenttype), eltype, eltype(StoreT) + ) @assert eltype(StoreT) == eltype(DataT) data = generic_randn(DataT, dims; rng=rng) @@ -22,7 +25,9 @@ end function generic_zeros(StoreT::Type{<:Dense}, dims::Integer) StoreT = specify_default_type_parameters(StoreT) - DataT = specify_type_parameter(type_parameter(StoreT, parenttype), eltype, eltype(StoreT)) + DataT = specify_type_parameters( + type_parameters(StoreT, parenttype), eltype, eltype(StoreT) + ) @assert eltype(StoreT) == eltype(DataT) data = generic_zeros(DataT, dims) diff --git a/NDTensors/src/lib/Expose/src/exposed.jl b/NDTensors/src/lib/Expose/src/exposed.jl index 486914fb30..b57f19e34c 100644 --- a/NDTensors/src/lib/Expose/src/exposed.jl +++ b/NDTensors/src/lib/Expose/src/exposed.jl @@ -1,5 +1,5 @@ using TypeParameterAccessors: - TypeParameterAccessors, unwrap_array_type, parameter, parenttype, type_parameter + TypeParameterAccessors, unwrap_array_type, parenttype, type_parameters struct Exposed{Unwrapped,Object} object::Object end @@ -9,7 +9,7 @@ expose(object) = Exposed{unwrap_array_type(object),typeof(object)}(object) unexpose(E::Exposed) = E.object ## TODO remove TypeParameterAccessors when SetParameters is removed -TypeParameterAccessors.parenttype(type::Type{<:Exposed}) = parameter(type, parenttype) +TypeParameterAccessors.parenttype(type::Type{<:Exposed}) = type_parameters(type, parenttype) function TypeParameterAccessors.position(::Type{<:Exposed}, ::typeof(parenttype)) return TypeParameterAccessors.Position(1) end diff --git a/NDTensors/src/lib/GPUArraysCoreExtensions/src/gpuarrayscore.jl b/NDTensors/src/lib/GPUArraysCoreExtensions/src/gpuarrayscore.jl index 206667e38d..0042a05dab 100644 --- a/NDTensors/src/lib/GPUArraysCoreExtensions/src/gpuarrayscore.jl +++ b/NDTensors/src/lib/GPUArraysCoreExtensions/src/gpuarrayscore.jl @@ -1,15 +1,15 @@ using ..Expose: Exposed, unexpose -using TypeParameterAccessors: TypeParameterAccessors, type_parameter, set_type_parameter +using TypeParameterAccessors: TypeParameterAccessors, type_parameters, set_type_parameters function storagemode(object) return storagemode(typeof(object)) end function storagemode(type::Type) - return type_parameter(type, storagemode) + return type_parameters(type, storagemode) end function set_storagemode(type::Type, param) - return set_type_parameter(type, storagemode, param) + return set_type_parameters(type, storagemode, param) end function cpu end diff --git a/NDTensors/test/Project.toml b/NDTensors/test/Project.toml index 79a4eabfd4..53951dfa18 100644 --- a/NDTensors/test/Project.toml +++ b/NDTensors/test/Project.toml @@ -20,6 +20,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat]