diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fee2645..01b1b83 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,22 +16,23 @@ jobs: - uses: dtolnay/rust-toolchain@stable with: components: rustfmt + - name: Ubuntu build dependencies + run: apt update && apt install -y clang cmake - name: Check Style run: cargo fmt --check - cargo-test: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - name: Run tests - run: cargo test - - build-no_feature: - runs-on: [ubuntu-latest, windows-latest, macos-latest] + build: + strategy: + fail-fast: false + matrix: + platform: [ubuntu-latest, windows-latest, macos-latest] + runs-on: ${{ matrix.platform }} steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable + - name: Ubuntu build dependencies + if: matrix.platform == 'ubuntu-latest' + run: apt update && apt install -y clang cmake - name: Build run: cargo build @@ -54,9 +55,12 @@ jobs: - uses: dtolnay/rust-toolchain@stable - name: Install cuda toolkit uses: Jimver/cuda-toolkit@v0.2.18 - - name: install msvc deps (windows) + - name: Install Windows dependencies if: matrix.platform == 'windows-2019' uses: ilammy/msvc-dev-cmd@v1 + - name: Ubuntu build dependencies + if: matrix.platform == 'ubuntu-22.04' + run: apt update && apt install -y clang cmake - name: Build env: CUDA_COMPUTE_CAP: "75" diff --git a/.gitignore b/.gitignore index d01bd1a..d102392 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,5 @@ Cargo.lock # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ +bin/act diff --git a/sys/Cargo.toml b/sys/Cargo.toml index 0fa7728..d7f5c7c 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -45,5 +45,4 @@ flashattn = [] [build-dependencies] cmake = "0.1.51" bindgen = "0.70.1" -cfg-if = "1.0.0" fs_extra = "1.3.0" \ No newline at end of file diff --git a/sys/build.rs b/sys/build.rs index f3a49d3..2e23a0b 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -1,6 +1,6 @@ use std::{ env, - fs::{copy, create_dir_all, read_dir}, + fs::{create_dir_all, read_dir}, path::PathBuf, }; @@ -36,24 +36,14 @@ fn main() { // Bindgen let bindings = bindgen::Builder::default().header("wrapper.h"); - let bindings = bindings + bindings .clang_arg("-I./stable-diffusion.cpp") .clang_arg("-I./stable-diffusion.cpp/ggml/include") .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) - .generate(); - - match bindings { - Ok(b) => { - b.write_to_file(out.join("bindings.rs")) - .expect("Couldn't write bindings!"); - } - Err(e) => { - println!("cargo:warning=Unable to generate bindings: {}", e); - println!("cargo:warning=Using bundled bindings.rs, which may be out of date"); - // copy src/bindings.rs to OUT_DIR - copy("src/bindings.rs", out.join("bindings.rs")).expect("Unable to copy bindings.rs"); - } - } + .generate() + .unwrap() + .write_to_file(out.join("bindings.rs")) + .expect("Couldn't write bindings!"); // stop if we're on docs.rs if env::var("DOCS_RS").is_ok() { @@ -81,8 +71,11 @@ fn main() { println!("cargo:rustc-link-search=/opt/cuda/lib64"); println!("cargo:rustc-link-search=/opt/cuda/lib64/stubs"); } - + config.define("SD_CUBLAS", "ON"); + if let Ok(target) = env::var("CUDA_COMPUTE_CAP") { + config.define("CUDA_COMPUTE_CAP", target); + } } #[cfg(feature = "hipblas")] @@ -93,21 +86,20 @@ fn main() { if target.contains("msvc") { panic!("Due to a problem with the last revision of the ROCm 5.7 library, it is not possible to compile the library for the windows environment.\nSee https://github.com/ggerganov/whisper.cpp/issues/2202 for more details.") - } else { println!("cargo:rerun-if-env-changed=HIP_PATH"); let hip_path = match env::var("HIP_PATH") { - Ok(path) =>PathBuf::from(path), + Ok(path) => PathBuf::from(path), Err(_) => PathBuf::from("/opt/rocm"), }; let hip_lib_path = hip_path.join("lib"); - println!("cargo:rustc-link-search={}",hip_lib_path.display()); + println!("cargo:rustc-link-search={}", hip_lib_path.display()); } config.define("SD_HIPBLAS", "ON"); - if let Ok(target) = env::var("AMDGPU_TARGETS") { + if let Ok(target) = env::var("AMDGPU_TARGETS") { config.define("AMDGPU_TARGETS", target); } }