Skip to content

Commit

Permalink
Merge pull request #9 from JuliaGPU/tb/redesign
Browse files Browse the repository at this point in the history
General improvements for use in CUDAnative.jl
  • Loading branch information
MikeInnes authored Oct 23, 2018
2 parents 8bb4c0a + 10a3752 commit 731f701
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 23 deletions.
29 changes: 22 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,38 @@

[![Build Status](https://travis-ci.org/JuliaGPU/Adapt.jl.svg?branch=master)](https://travis-ci.org/JuliaGPU/Adapt.jl)

The `adapt(T, x)` function acts like `convert(T, x)`, but without the restriction of returning a `T`. This allows you to "convert" wrapper types like `Adjoint` to be GPU compatible (for example) without throwing away the wrapper.
The `adapt(T, x)` function acts like `convert(T, x)`, but without the
restriction of returning a `T`. This allows you to "convert" wrapper types like
`Adjoint` to be GPU compatible (for example) without throwing away the wrapper.

e.g.
For example:

```julia
adapt(CuArray, ::Adjoint{Array})::Adjoint{CuArray}
```

New data types like `Adjoint` should overload `adapt(T, ::Adjoint)` (usually just to forward the call to `adapt`).
New wrapper types like `Adjoint` should overload `adapt_structure(T, ::Adjoint)`
(usually just to forward the call to `adapt`):

```julia
adapt(T, x::Adjoint) = Adjoint(adapt(T, parent(x)))
Adapt.adapt_structure(to, x::Adjoint) = Adjoint(adapt(to, parent(x)))
```

New adaptor types like `CuArray` should overload `adapt_` for compatible types.
A similar function, `adapt_storage`, can be used to define the conversion
behavior for the innermost storage types:

```julia
adapt_(::Type{<:CuArray}, xs::AbstractArray) =
isbits(xs) ? xs : convert(CuArray, xs)
adapt_storage(::Type{<:CuArray}, xs::AbstractArray) = convert(CuArray, xs)
```

Implementations of `adapt_storage` will typically be part of libraries that use
Adapt. For example, CuArrays.jl defines methods of
`adapt_storage(::Type{<:CuArray}, ...)` and uses that to convert different kinds
of arrays, while CUDAnative.jl provides implementations of
`adapt_storage(::CUDAnative.Adaptor, ...)` to convert various values to
GPU-compatible alternatives.

Packages that define new wrapper types and want to be compatible with packages
that use Adapt.jl should provide implementations of `adapt_structure` that
preserve the wrapper type. Adapt.jl already provides such methods for array
wrappers that are part of the Julia standard library.
41 changes: 41 additions & 0 deletions appveyor.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
environment:
matrix:
- julia_version: 0.7
- julia_version: 1
- julia_version: nightly

platform:
- x86 # 32-bit
- x64 # 64-bit

matrix:
allow_failures:
- julia_version: nightly

branches:
only:
- master
- /release-.*/

notifications:
- provider: Email
on_build_success: false
on_build_failure: false
on_build_status_changed: false

install:
- ps: iex ((new-object net.webclient).DownloadString("https://raw.githubusercontent.com/JuliaCI/Appveyor.jl/version-1/bin/install.ps1"))

build_script:
- echo "%JL_BUILD_SCRIPT%"
- C:\julia\bin\julia -e "%JL_BUILD_SCRIPT%"

test_script:
- echo "%JL_TEST_SCRIPT%"
- C:\julia\bin\julia -e "%JL_TEST_SCRIPT%"

# # Uncomment to support code coverage upload. Should only be enabled for packages
# # which would have coverage gaps without running on Windows
# on_success:
# - echo "%JL_CODECOV_SCRIPT%"
# - C:\julia\bin\julia -e "%JL_CODECOV_SCRIPT%"
14 changes: 7 additions & 7 deletions src/Adapt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
module Adapt

using LinearAlgebra
export adapt

adapt_(T, x) = x
# external interface
adapt(to, x) = adapt_structure(to, x)

adapt(T, x) = adapt_(T, x)
# interface for libraries to implement
adapt_structure(to, x) = adapt_storage(to, x)
adapt_storage(to, x) = x

# Base integrations

adapt(T, x::Adjoint) = Adjoint(adapt(T, parent(x)))
adapt(T, x::Transpose) = Transpose(adapt(T, parent(x)))
include("base.jl")

end # module
27 changes: 27 additions & 0 deletions src/base.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# predefined adaptors for working with types from the Julia standard library

## Base

adapt_structure(to, xs::Tuple) = Tuple(adapt(to, x) for x in xs)
@generated adapt_structure(to, x::NamedTuple) =
Expr(:tuple, (:($f=adapt(to, x.$f)) for f in fieldnames(x))...)

adapt(to, x::SubArray) = SubArray(adapt(to, parent(x)), parentindices(x))


## LinearAlgebra

import LinearAlgebra: Adjoint, Transpose
adapt_structure(to, x::Adjoint) = Adjoint(adapt(to, parent(x)))
adapt_structure(to, x::Transpose) = Transpose(adapt(to, parent(x)))


## Broadcast

import Base.Broadcast: Broadcasted, Extruded

adapt_structure(to, bc::Broadcasted{Style}) where Style =
Broadcasted{Style}(bc.f, map(arg->adapt(to, arg), bc.args), bc.axes)

adapt_structure(to, ex::Extruded) =
Extruded(adapt(to, ex.x), ex.keeps, ex.defaults)
52 changes: 43 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,51 @@
import Adapt: adapt, adapt_
using Adapt
using Test

# trivial test

struct Matrix
mat::AbstractArray
# custom array type

struct CustomArray{T,N} <: AbstractArray{T,N}
arr::AbstractArray
end

adapt_(::Type{<:Matrix}, xs::AbstractArray) =
Matrix(xs)
CustomArray(x::AbstractArray{T,N}) where {T,N} = CustomArray{T,N}(x)
Adapt.adapt_storage(::Type{<:CustomArray}, xs::AbstractArray) = CustomArray(xs)

Base.size(x::CustomArray, y...) = size(x.arr, y...)
Base.getindex(x::CustomArray, y...) = getindex(x.arr, y...)


const val = CustomArray{Float64,2}(rand(2,2))

# basic adaption
@test adapt(CustomArray, val.arr) == val
@test adapt(CustomArray, val.arr) isa CustomArray

# idempotency
@test adapt(CustomArray, val) == val
@test adapt(CustomArray, val) isa CustomArray

# custom wrapper
struct Wrapper{T}
arr::T
end
Wrapper(x::T) where T = Wrapper{T}(x)
Adapt.adapt_structure(to, xs::Wrapper) = Wrapper(adapt(to, xs.arr))
@test adapt(CustomArray, Wrapper(val.arr)) == Wrapper(val)
@test adapt(CustomArray, Wrapper(val.arr)) isa Wrapper{<:CustomArray}


## base wrappers

@test adapt(CustomArray, (val.arr,)) == (val,)

@test adapt(CustomArray, (a=val.arr,)) == (a=val,)

@test adapt(CustomArray, view(val.arr,:,:)) == view(val,:,:)
@test adapt(CustomArray, view(val.arr,:,:)) isa SubArray{<:Any,<:Any,<:CustomArray}

testmat = [12;34;56;78]

testresult = Matrix(testmat)
using LinearAlgebra

@test adapt(Matrix, testmat) == testresult
@test adapt(CustomArray, val.arr') == val'
@test adapt(CustomArray, val.arr') isa Adjoint{<:Any,<:CustomArray}

0 comments on commit 731f701

Please sign in to comment.