From e6dd65cfbc2313c7a8584fc26af6fc1ceb9bc31d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 23:24:52 -0400 Subject: [PATCH] fix: urgent patch for reactant breakage --- Project.toml | 4 ++-- src/impl/Impl.jl | 2 +- src/impl/conv.jl | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index d1e4779f..ab9801d9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.1" +version = "1.3.2" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -71,7 +71,7 @@ LinearAlgebra = "1.10" LoopVectorization = "0.12.171" LuxCore = "1" MKL = "0.7" -MLDataDevices = "1.1.1" +MLDataDevices = "1.2" Markdown = "1.10" NNlib = "0.9.24" Octavian = "0.3.28" diff --git a/src/impl/Impl.jl b/src/impl/Impl.jl index c1818c77..8956a639 100644 --- a/src/impl/Impl.jl +++ b/src/impl/Impl.jl @@ -21,7 +21,7 @@ using Random: Random, AbstractRNG, rand! using Statistics: Statistics, mean, var using LuxCore: LuxCore -using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, +using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, XLADevice, AbstractGPUDevice, AbstractDevice using NNlib: NNlib, ConvDims diff --git a/src/impl/conv.jl b/src/impl/conv.jl index f5181b65..3a3d22ee 100644 --- a/src/impl/conv.jl +++ b/src/impl/conv.jl @@ -74,8 +74,8 @@ end conv(x, weight, cdims::ConvDims) = conv(get_device_type((x, weight)), x, weight, cdims) -function conv( - ::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice}}, x′, weight′, cdims::ConvDims) +function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, XLADevice}}, + x′, weight′, cdims::ConvDims) x, weight = get_conv_input_weight(x′, weight′) return NNlib.conv(x, weight, cdims) end