From e66aa8cf6645b0931c442a9839536f3d40983524 Mon Sep 17 00:00:00 2001
From: Johnny <johnnynuca14@gmail.com>
Date: Wed, 22 Jan 2025 00:54:30 +0100
Subject: [PATCH] feat: initial support for blackwell
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Thomas Müller <tom@94.me>
---
 .github/workflows/main.yml                     | 10 ++++++++++
 bindings/torch/setup.py                        |  4 +++-
 bindings/torch/tinycudann/modules.py           |  2 +-
 .../scripts/actions/install_cuda_windows.ps1   | 18 +++++++++---------
 4 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 91148b99..dfec01b5 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -15,6 +15,12 @@ jobs:
     strategy:
       matrix:
         include:
+          - os: ubuntu-24.04
+            cuda: "12.8"
+            arch: 120
+          - os: ubuntu-24.04
+            cuda: "12.8"
+            arch: 100
           - os: ubuntu-24.04
             cuda: "12.6"
             arch: 89
@@ -72,6 +78,10 @@ jobs:
     strategy:
       matrix:
         include:
+          - os: windows-2025
+            visual_studio: "Visual Studio 17 2022"
+            cuda: "12.8.0"
+            arch: 120
           - os: windows-2025
             visual_studio: "Visual Studio 17 2022"
             cuda: "12.6.3"
diff --git a/bindings/torch/setup.py b/bindings/torch/setup.py
index 533805e6..f6431c48 100644
--- a/bindings/torch/setup.py
+++ b/bindings/torch/setup.py
@@ -26,8 +26,10 @@ def max_supported_compute_capability(cuda_version):
 		return 80
 	elif cuda_version < parse_version("11.8"):
 		return 86
-	else:
+	elif cuda_version < parse_version("12.8"):
 		return 90
+	else:
+		return 120
 
 # Find version of tinycudann by scraping CMakeLists.txt
 with open(os.path.join(ROOT_DIR, "CMakeLists.txt"), "r") as cmakelists:
diff --git a/bindings/torch/tinycudann/modules.py b/bindings/torch/tinycudann/modules.py
index 6fe4c913..bf84629b 100644
--- a/bindings/torch/tinycudann/modules.py
+++ b/bindings/torch/tinycudann/modules.py
@@ -13,7 +13,7 @@
 
 import torch
 
-ALL_COMPUTE_CAPABILITIES = [20, 21, 30, 35, 37, 50, 52, 53, 60, 61, 62, 70, 72, 75, 80, 86, 87, 89, 90]
+ALL_COMPUTE_CAPABILITIES = [20, 21, 30, 35, 37, 50, 52, 53, 60, 61, 62, 70, 72, 75, 80, 86, 87, 89, 90, 100, 101, 120]
 
 if not torch.cuda.is_available():
 	raise EnvironmentError("Unknown compute capability. Ensure PyTorch with CUDA support is installed.")
diff --git a/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_windows.ps1 b/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_windows.ps1
index a112e8d3..7d5b8e84 100755
--- a/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_windows.ps1
+++ b/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_windows.ps1
@@ -4,7 +4,7 @@
 
 # Dictionary of known cuda versions and thier download URLS, which do not follow a consistent pattern :(
 $CUDA_KNOWN_URLS = @{
-	"8.0.44" = "http://developer.nvidia.com/compute/cuda/8.0/Prod/network_installers/cuda_8.0.44_win10_network-exe";
+    "8.0.44" = "http://developer.nvidia.com/compute/cuda/8.0/Prod/network_installers/cuda_8.0.44_win10_network-exe";
     "8.0.61" = "http://developer.nvidia.com/compute/cuda/8.0/Prod2/network_installers/cuda_8.0.61_win10_network-exe";
     "9.0.176" = "http://developer.nvidia.com/compute/cuda/9.0/Prod/network_installers/cuda_9.0.176_win10_network-exe";
     "9.1.85" = "http://developer.nvidia.com/compute/cuda/9.1/Prod/network_installers/cuda_9.1.85_win10_network";
@@ -25,10 +25,11 @@ $CUDA_KNOWN_URLS = @{
     "11.3.0" = "https://developer.download.nvidia.com/compute/cuda/11.3.0/network_installers/cuda_11.3.0_win10_network.exe";
     "11.3.1" = "https://developer.download.nvidia.com/compute/cuda/11.3.1/network_installers/cuda_11.3.1_win10_network.exe";
     "11.5.0" = "https://developer.download.nvidia.com/compute/cuda/11.5.0/network_installers/cuda_11.5.0_win10_network.exe";
-	"11.5.1" = "https://developer.download.nvidia.com/compute/cuda/11.5.1/network_installers/cuda_11.5.1_windows_network.exe";
+    "11.5.1" = "https://developer.download.nvidia.com/compute/cuda/11.5.1/network_installers/cuda_11.5.1_windows_network.exe";
     "11.8.0" = "https://developer.download.nvidia.com/compute/cuda/11.8.0/network_installers/cuda_11.8.0_windows_network.exe";
     "12.5.0" = "https://developer.download.nvidia.com/compute/cuda/12.5.0/network_installers/cuda_12.5.0_windows_network.exe";
- 	"12.6.3" = "https://developer.download.nvidia.com/compute/cuda/12.6.3/network_installers/cuda_12.6.3_windows_network.exe";
+    "12.6.3" = "https://developer.download.nvidia.com/compute/cuda/12.6.3/network_installers/cuda_12.6.3_windows_network.exe";
+    "12.8.0" = "https://developer.download.nvidia.com/compute/cuda/12.8.0/network_installers/cuda_12.8.0_windows_network.exe";
 }
 
 # @todo - change this to be based on _MSC_VER intead, or invert it to be CUDA keyed instead?
@@ -74,7 +75,7 @@ $CUDA_PATCH=$Matches.patch
 # Exit if visual studio is too new for the cuda version.
 $VISUAL_STUDIO = $env:visual_studio.trim()
 if ($VISUAL_STUDIO.length -ge 4) {
-$VISUAL_STUDIO_YEAR = $VISUAL_STUDIO.Substring($VISUAL_STUDIO.Length-4)
+    $VISUAL_STUDIO_YEAR = $VISUAL_STUDIO.Substring($VISUAL_STUDIO.Length-4)
     if ($VISUAL_STUDIO_YEAR.length -eq 4 -and $VISUAL_STUDIO_MIN_CUDA.containsKey($VISUAL_STUDIO_YEAR)){
         $MINIMUM_CUDA_VERSION = $VISUAL_STUDIO_MIN_CUDA[$VISUAL_STUDIO_YEAR]
         if ([version]$CUDA_VERSION_FULL -lt [version]$MINIMUM_CUDA_VERSION) {
@@ -99,7 +100,7 @@ $CUDA_PACKAGES = ""
 #     }
 # }
 
-Foreach ($package in $CUDA_PACKAGES_IN) {
+foreach ($package in $CUDA_PACKAGES_IN) {
     # Make sure the correct package name is used for nvcc.
     if($package -eq "nvcc" -and [version]$CUDA_VERSION_FULL -lt [version]"9.1"){
         $package="compiler"
@@ -107,7 +108,6 @@ Foreach ($package in $CUDA_PACKAGES_IN) {
         $package="nvcc"
     }
     $CUDA_PACKAGES += " $($package)_$($CUDA_MAJOR).$($CUDA_MINOR)"
-
 }
 echo "$($CUDA_PACKAGES)"
 ## -----------------
@@ -116,9 +116,9 @@ echo "$($CUDA_PACKAGES)"
 
 # Select the download link if known, otherwise have a guess.
 $CUDA_REPO_PKG_REMOTE=""
-if($CUDA_KNOWN_URLS.containsKey($CUDA_VERSION_FULL)){
+if ($CUDA_KNOWN_URLS.containsKey($CUDA_VERSION_FULL)){
     $CUDA_REPO_PKG_REMOTE=$CUDA_KNOWN_URLS[$CUDA_VERSION_FULL]
-} else{
+} else {
     # Guess what the url is given the most recent pattern (at the time of writing, 10.1)
     Write-Output "note: URL for CUDA ${$CUDA_VERSION_FULL} not known, estimating."
     $CUDA_REPO_PKG_REMOTE="http://developer.download.nvidia.com/compute/cuda/$($CUDA_MAJOR).$($CUDA_MINOR)/Prod/network_installers/cuda_$($CUDA_VERSION_FULL)_win10_network.exe"
@@ -133,7 +133,7 @@ $CUDA_REPO_PKG_LOCAL="cuda_$($CUDA_VERSION_FULL)_win10_network.exe"
 # Get CUDA network installer
 Write-Output "Downloading CUDA Network Installer for $($CUDA_VERSION_FULL) from: $($CUDA_REPO_PKG_REMOTE)"
 Invoke-WebRequest $CUDA_REPO_PKG_REMOTE -OutFile $CUDA_REPO_PKG_LOCAL | Out-Null
-if(Test-Path -Path $CUDA_REPO_PKG_LOCAL){
+if (Test-Path -Path $CUDA_REPO_PKG_LOCAL){
     Write-Output "Downloading Complete"
 } else {
     Write-Output "Error: Failed to download $($CUDA_REPO_PKG_LOCAL) from $($CUDA_REPO_PKG_REMOTE)"