Skip to content

Commit

Permalink
Merge pull request #154 from newfla/hip_linux
Browse files Browse the repository at this point in the history
feat: ROCm linux support
  • Loading branch information
tazz4843 authored Jun 2, 2024
2 parents e6271bf + ce71477 commit b46876a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ default = []
raw-api = []
coreml = ["whisper-rs-sys/coreml"]
cuda = ["whisper-rs-sys/cuda", "_gpu"]
hipblas = ["whisper-rs-sys/hipblas", "_gpu"]
opencl = ["whisper-rs-sys/opencl"]
openblas = ["whisper-rs-sys/openblas"]
metal = ["whisper-rs-sys/metal", "_gpu"]
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ All disabled by default unless otherwise specified.
**NOTE**: enabling this no longer guarantees semver compliance,
as whisper-rs-sys may be upgraded to a breaking version in a patch release of whisper-rs.
* `cuda`: enable CUDA support. Implicitly enables hidden GPU flag at runtime.
* `hipblas`: enable ROCm/hipBLAS support. Only available on linux. Implicitly enables hidden GPU flag at runtime.
* `opencl`: enable OpenCL support. Upstream whisper.cpp does not treat OpenCL as a GPU, so it is always enabled at
runtime.
* `openblas`: enable OpenBLAS support.
Expand Down
1 change: 1 addition & 0 deletions sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ include = [
[features]
coreml = []
cuda = []
hipblas = []
opencl = []
openblas = []
metal = []
Expand Down
33 changes: 33 additions & 0 deletions sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@ fn main() {
}
}
}
#[cfg(feature = "hipblas")]
{
println!("cargo:rustc-link-lib=hipblas");
println!("cargo:rustc-link-lib=rocblas");
println!("cargo:rustc-link-lib=amdhip64");

cfg_if::cfg_if! {
if #[cfg(target_os = "windows")] {
panic!("Due to a problem with the last revision of the ROCm 5.7 library, it is not possible to compile the library for the windows environment.\nSee https://github.com/ggerganov/whisper.cpp/issues/2202 for more details.")
} else {
println!("cargo:rerun-if-env-changed=HIP_PATH");

let hip_path = match env::var("HIP_PATH") {
Ok(path) =>PathBuf::from(path),
Err(_) => PathBuf::from("/opt/rocm"),
};
let hip_lib_path = hip_path.join("lib");

println!("cargo:rustc-link-search={}",hip_lib_path.display());
}
}
}

println!("cargo:rerun-if-changed=wrapper.h");

let out = PathBuf::from(env::var("OUT_DIR").unwrap());
Expand Down Expand Up @@ -126,6 +149,16 @@ fn main() {
config.define("WHISPER_CUDA", "ON");
}

if cfg!(feature = "hipblas") {
config.define("WHISPER_HIPBLAS", "ON");
config.define("CMAKE_C_COMPILER", "hipcc");
config.define("CMAKE_CXX_COMPILER", "hipcc");
println!("cargo:rerun-if-env-changed=AMDGPU_TARGETS");
if let Ok(gpu_targets) = env::var("AMDGPU_TARGETS") {
config.define("AMDGPU_TARGETS", gpu_targets);
}
}

if cfg!(feature = "openblas") {
config.define("WHISPER_OPENBLAS", "ON");
}
Expand Down

0 comments on commit b46876a

Please sign in to comment.