Skip to content

Commit

Permalink
AMD GPU working (ROCm 6.x)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfxuan committed Mar 17, 2024
1 parent 9597d60 commit 35c0718
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ docker build \
Note: If you want to use ROCm 6.x, you need to switch to AMD version of pytorch docker as a base layer to build:
```bash
docker build \
-t opensplat:ubuntu-22.04-libtorch-torch-2.1.2-rocm-6.0.2 \
-t opensplat:ubuntu-22.04-libtorch-2.1.2-rocm-6.0.2 \
-f Dockerfile.rocm6 .
```

Expand All @@ -116,10 +116,23 @@ To run on your own data, choose the path to an existing [COLMAP](https://colmap.

There's several parameters you can tune. To view the full list:
```bash
./opensplat --help
```
To train a model with AMD GPU using docker container, you can use the following command as a reference:
1. Launch the docker container with the following command:
```bash
docker run -it -v ~/data:/data --device=/dev/kfd --device=/dev/dri opensplat:ubuntu-22.04-libtorch-2.1.2-rocm-6.0.2 bash
```
2. Inside the docker container, run the following command to train the model:
```bash
export HIP_VISIBLE_DEVICES=0
export HSA_OVERRIDE_GFX_VERSION=10.3.0 # AMD RX 6700 XT workaround
cd /code/build
./opensplat /data/banana -n 2000
```
## Project Goals
We recently released OpenSplat, so there's lots of work to do.
Expand Down
6 changes: 3 additions & 3 deletions vendor/gsplat/reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
#include <hip/hip_cooperative_groups.h>

#define MAX_INIT 0.0
#define WARP_SIZE 32
#define WARP_SIZE 64

namespace cg = cooperative_groups;

__inline__ __device__ float warp_reduce_sum(float val, const int tile) {
for ( int offset = tile / 2; offset > 0; offset /= 2 )
val += __shfl_down(0xffffffff, val, offset);
val += __shfl_down(val, offset);

return val;
}
Expand Down Expand Up @@ -39,7 +39,7 @@ __inline__ __device__ float block_reduce_sum(float val, const int tile) {

__inline__ __device__ float warp_reduce_max(float val, const int tile) {
for (int offset = tile / 2; offset > 0; offset /= 2)
val = max(val, __shfl_xor(0xffffffff, val, offset));
val = max(val, __shfl_xor(val, offset));
return val;
}

Expand Down

0 comments on commit 35c0718

Please sign in to comment.