Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

fix for Zygote and ChainRules OneElement #92

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

CarloLucibello
Copy link
Contributor

@CarloLucibello CarloLucibello commented Oct 27, 2024

@@ -48,6 +50,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
AMDGPU = "0.9.6, 1"
Adapt = "4.1"
CUDA = "5.2"
ChainRules = "1.51"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the first version with OneElement

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using MLDataDevices: GPU_DEVICES, CPUDevice

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x
for Dev in GPU_DEVICES
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This need to be done for all the devices. GPU_DEVICES doesn't contain all. that will fix the XLA CI

@@ -48,6 +50,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
AMDGPU = "0.9.6, 1"
Adapt = "4.1"
CUDA = "5.2"
ChainRules = "1.51"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

scalar indexing of gpu array in Zygote gradient
2 participants