diff --git a/src/NNlib.jl b/src/NNlib.jl index 4a60b4c9a..35d811185 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -3,5 +3,6 @@ module NNlib export σ, relu, softmax include("activation.jl") +include("adapt.jl") end # module diff --git a/src/adapt.jl b/src/adapt.jl new file mode 100644 index 000000000..401092d7f --- /dev/null +++ b/src/adapt.jl @@ -0,0 +1,7 @@ +# This doesn't really belong here, but it's convenient. + +adapt_(T, x) = x + +adapt(T, x) = adapt_(T, x) + +adapt(T, x::RowVector) = RowVector(adapt(T, x.vec))