Skip to content

Commit

Permalink
Fix unsafe_wrap of a view. (#452)
Browse files Browse the repository at this point in the history
And some other small fixes.
  • Loading branch information
christiangnrd committed Oct 17, 2024
1 parent 5cf9372 commit f042969
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion perf/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ const m = 512
const n = 1000

for (S, smname) in [(Metal.PrivateStorage,"private"), (Metal.SharedStorage,"shared")]
group = addgroup!(SUITE, "$smname array")
local group = addgroup!(SUITE, "$smname array")

# generate some arrays
cpu_mat = rand(rng, Float32, m, n)
Expand Down
2 changes: 1 addition & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ Base.unsafe_convert(::Type{MTL.MTLBuffer}, A::PermutedDimsArray) =
## unsafe_wrap

function Base.unsafe_wrap(::Type{<:Array}, arr::MtlArray{T,N}, dims=size(arr); own=false) where {T,N}
return unsafe_wrap(Array{T,N}, arr.data[], dims; own)
return unsafe_wrap(Array{T,N}, pointer(arr), dims; own)
end

function Base.unsafe_wrap(t::Type{<:Array{T}}, buf::MTLBuffer, dims; own=false) where T
Expand Down
2 changes: 1 addition & 1 deletion src/device/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ Base.show(io::IO, a::MtlDeviceArray) =
Base.show(io::IO, mime::MIME"text/plain", a::MtlDeviceArray) = show(io, a)

@inline function Base.unsafe_view(A::MtlDeviceVector{T}, I::Vararg{Base.ViewIndex,1}) where {T}
ptr = pointer(A) + (I[1].start-1)*sizeof(T)
ptr = pointer(A, I[1].start)
len = I[1].stop - I[1].start + 1
return MtlDeviceArray(len, ptr)
end
Expand Down
9 changes: 9 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,15 @@ end
arr2 .+= 1;
@test all(arr2 .== 2)
@test all(marr2 .== 2)

@testset "Issue #451" begin
a = mtl(reshape(Float32.(1:60), 5,4,3);storage=Metal.SharedStorage)
view_a = @view a[:,1:4,2]
b = copy(unsafe_wrap(Array, view_a))
c = Array(view_a)

@test b == c
end
end

@testset "ReshapedArray" begin
Expand Down

0 comments on commit f042969

Please sign in to comment.