-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
364 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# To get started with Dependabot version updates, you'll need to specify which | ||
# package ecosystems to update and where the package manifests are located. | ||
# Please see the documentation for all configuration options: | ||
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file | ||
|
||
version: 2 | ||
updates: | ||
- package-ecosystem: "cargo" # See documentation for possible values | ||
directory: "/" # Location of package manifests | ||
schedule: | ||
interval: "weekly" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
name: test | ||
|
||
on: | ||
workflow_dispatch: | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
env: | ||
CARGO_TERM_COLOR: always | ||
|
||
jobs: | ||
cargo-fmt: | ||
runs-on: ubuntu-22.04 | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: dtolnay/rust-toolchain@stable | ||
with: | ||
components: rustfmt | ||
- name: Check Style | ||
run: cargo fmt --check | ||
|
||
cargo-test: | ||
runs-on: ubuntu-22.04 | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: dtolnay/rust-toolchain@stable | ||
- name: Run tests | ||
run: cargo test | ||
|
||
build-no_feature: | ||
runs-on: [ubuntu-latest, windows-latest, macos-latest] | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: dtolnay/rust-toolchain@stable | ||
- name: Build | ||
run: cargo build | ||
|
||
build-metal: | ||
runs-on: macos-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: dtolnay/rust-toolchain@stable | ||
- name: Build | ||
run: cargo build --features metal | ||
|
||
build-cuda: | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
platform: [ubuntu-22.04, windows-2019] | ||
runs-on: ${{ matrix.platform }} | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: dtolnay/rust-toolchain@stable | ||
- name: Install cuda toolkit | ||
uses: Jimver/[email protected] | ||
- name: install msvc deps (windows) | ||
if: matrix.platform == 'windows-2019' | ||
uses: ilammy/msvc-dev-cmd@v1 | ||
- name: Build | ||
env: | ||
CUDA_COMPUTE_CAP: "75" | ||
run: cargo build --features cublas | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "sys/stable-diffusion.cpp"] | ||
path = sys/stable-diffusion.cpp | ||
url = https://github.com/leejet/stable-diffusion.cpp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
[workspace] | ||
members = ["sys"] | ||
|
||
[package] | ||
name = "diffusion-rs" | ||
version = "0.1.0" | ||
edition = "2021" | ||
description = "Rust bindings for stable-diffusion.cpp" | ||
license = "MIT" | ||
documentation = "https://docs.rs/diffusion-rs" | ||
repository = "https://github.com/newfla/diffusion-rs" | ||
|
||
[dependencies] | ||
diffusion-rs-sys = { path = "sys", version = "0.1.0" } | ||
|
||
[features] | ||
cublas = ["diffusion-rs-sys/cublas"] | ||
hipblas = ["diffusion-rs-sys/hipblas"] | ||
metal = ["diffusion-rs-sys/metal"] | ||
vulkan = ["diffusion-rs-sys/vulkan"] | ||
sycl = ["diffusion-rs-sys/sycl"] | ||
flashattn = ["diffusion-rs-sys/flashattn"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
# diffusion-rs | ||
Rust bindings to https://github.com/leejet/stable-diffusion.cpp | ||
|
||
# WIP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
pub fn add(left: u64, right: u64) -> u64 { | ||
left + right | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn it_works() { | ||
let result = add(2, 2); | ||
assert_eq!(result, 4); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
[package] | ||
name = "diffusion-rs-sys" | ||
version = "0.1.0" | ||
edition = "2021" | ||
description = "Rust bindings for stable-diffusion.cpp (FFI bindings)" | ||
license = "MIT" | ||
documentation = "https://docs.rs/diffusion-rs-sys" | ||
repository = "https://github.com/newfla/diffusion-rs" | ||
links = "stable-diffusion" | ||
include = [ | ||
"stable-diffusion.cpp/LICENSE", | ||
"stable-diffusion.cpp/CMakeLists.txt", | ||
"stable-diffusion.cpp/stable-diffusion.cpp", | ||
"stable-diffusion.cpp/stable-diffusion.h", | ||
"stable-diffusion.cpp/ggml/src/ggml.c", | ||
"stable-diffusion.cpp/ggml/src/ggml-alloc.c", | ||
"stable-diffusion.cpp/ggml/src/ggml-backend.c", | ||
"stable-diffusion.cpp/ggml/src/ggml-cuda.cu", | ||
"stable-diffusion.cpp/ggml/src/ggml-impl.h", | ||
"stable-diffusion.cpp/ggml/src/ggml-metal.m", | ||
"stable-diffusion.cpp/ggml/src/ggml-metal.metal", | ||
"stable-diffusion.cpp/ggml/src/ggml-quants.h", | ||
"stable-diffusion.cpp/ggml/src/ggml-quants.c", | ||
"stable-diffusion.cpp/ggml/include/ggml.h", | ||
"stable-diffusion.cpp/ggml/include/ggml-alloc.h", | ||
"stable-diffusion.cpp/ggml/include/ggml-backend.h", | ||
"stable-diffusion.cpp/ggml/include/ggml-backend-impl.h", | ||
"stable-diffusion.cpp/ggml/include/ggml-cuda.h", | ||
"stable-diffusion.cpp/ggml/include/ggml-metal.h", | ||
"src/*.rs", | ||
"build.rs", | ||
"wrapper.h", | ||
] | ||
|
||
[dependencies] | ||
|
||
[features] | ||
cublas = [] | ||
hipblas = [] | ||
metal = [] | ||
vulkan = [] | ||
sycl = [] | ||
flashattn = [] | ||
|
||
[build-dependencies] | ||
cmake = "0.1.51" | ||
bindgen = "0.70.1" | ||
cfg-if = "1.0.0" | ||
fs_extra = "1.3.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
use std::{ | ||
env, | ||
fs::{copy, create_dir_all, read_dir}, | ||
path::PathBuf, | ||
}; | ||
|
||
use cmake::Config; | ||
use fs_extra::dir; | ||
|
||
// Heavily ispired by https://github.com/tazz4843/whisper-rs/blob/master/sys/build.rs | ||
|
||
fn main() { | ||
// Link C++ standard library | ||
let target = env::var("TARGET").unwrap(); | ||
if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) { | ||
println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib); | ||
} | ||
|
||
println!("cargo:rerun-if-changed=wrapper.h"); | ||
|
||
// Copy stable-diffusion code into the build script directory | ||
let out = PathBuf::from(env::var("OUT_DIR").unwrap()); | ||
let diffusion_root = out.join("stable-diffusion.cpp/"); | ||
|
||
if !diffusion_root.exists() { | ||
create_dir_all(&diffusion_root).unwrap(); | ||
dir::copy("./stable-diffusion.cpp", &out, &Default::default()).unwrap_or_else(|e| { | ||
panic!( | ||
"Failed to copy stable-diffusion sources into {}: {}", | ||
diffusion_root.display(), | ||
e | ||
) | ||
}); | ||
} | ||
|
||
// Bindgen | ||
let bindings = bindgen::Builder::default().header("wrapper.h"); | ||
|
||
let bindings = bindings | ||
.clang_arg("-I./stable-diffusion.cpp") | ||
.clang_arg("-I./stable-diffusion.cpp/ggml/include") | ||
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) | ||
.generate(); | ||
|
||
match bindings { | ||
Ok(b) => { | ||
b.write_to_file(out.join("bindings.rs")) | ||
.expect("Couldn't write bindings!"); | ||
} | ||
Err(e) => { | ||
println!("cargo:warning=Unable to generate bindings: {}", e); | ||
println!("cargo:warning=Using bundled bindings.rs, which may be out of date"); | ||
// copy src/bindings.rs to OUT_DIR | ||
copy("src/bindings.rs", out.join("bindings.rs")).expect("Unable to copy bindings.rs"); | ||
} | ||
} | ||
|
||
// stop if we're on docs.rs | ||
if env::var("DOCS_RS").is_ok() { | ||
return; | ||
} | ||
|
||
// Configure cmake for building | ||
let mut config = Config::new(&diffusion_root); | ||
|
||
//Enable cmake feature flags | ||
#[cfg(feature = "cublas")] | ||
{ | ||
println!("cargo:rustc-link-lib=cublas"); | ||
println!("cargo:rustc-link-lib=cudart"); | ||
println!("cargo:rustc-link-lib=cublasLt"); | ||
println!("cargo:rustc-link-lib=cuda"); | ||
|
||
if target.contains("msvc") { | ||
let cuda_path = PathBuf::from(env::var("CUDA_PATH").unwrap()).join("lib/x64"); | ||
println!("cargo:rustc-link-search={}", cuda_path.display()); | ||
} else { | ||
println!("cargo:rustc-link-lib=culibos"); | ||
println!("cargo:rustc-link-search=/usr/local/cuda/lib64"); | ||
println!("cargo:rustc-link-search=/usr/local/cuda/lib64/stubs"); | ||
println!("cargo:rustc-link-search=/opt/cuda/lib64"); | ||
println!("cargo:rustc-link-search=/opt/cuda/lib64/stubs"); | ||
} | ||
|
||
config.define("SD_CUBLAS", "ON"); | ||
} | ||
|
||
#[cfg(feature = "hipblas")] | ||
{ | ||
println!("cargo:rustc-link-lib=hipblas"); | ||
println!("cargo:rustc-link-lib=rocblas"); | ||
println!("cargo:rustc-link-lib=amdhip64"); | ||
|
||
if target.contains("msvc") { | ||
panic!("Due to a problem with the last revision of the ROCm 5.7 library, it is not possible to compile the library for the windows environment.\nSee https://github.com/ggerganov/whisper.cpp/issues/2202 for more details.") | ||
|
||
} else { | ||
println!("cargo:rerun-if-env-changed=HIP_PATH"); | ||
|
||
let hip_path = match env::var("HIP_PATH") { | ||
Ok(path) =>PathBuf::from(path), | ||
Err(_) => PathBuf::from("/opt/rocm"), | ||
}; | ||
let hip_lib_path = hip_path.join("lib"); | ||
|
||
println!("cargo:rustc-link-search={}",hip_lib_path.display()); | ||
} | ||
|
||
config.define("SD_HIPBLAS", "ON"); | ||
if let Ok(target) = env::var("AMDGPU_TARGETS") { | ||
config.define("AMDGPU_TARGETS", target); | ||
} | ||
} | ||
|
||
#[cfg(feature = "metal")] | ||
{ | ||
println!("cargo:rustc-link-lib=framework=Foundation"); | ||
println!("cargo:rustc-link-lib=framework=Metal"); | ||
println!("cargo:rustc-link-lib=framework=MetalKit"); | ||
config.define("SD_METAL", "ON"); | ||
} | ||
|
||
#[cfg(feature = "vulkan")] | ||
{ | ||
if target.contains("msvc") { | ||
println!("cargo:rerun-if-env-changed=VULKAN_SDK"); | ||
println!("cargo:rustc-link-lib=vulkan-1"); | ||
let vulkan_path = match env::var("VULKAN_SDK") { | ||
Ok(path) => PathBuf::from(path), | ||
Err(_) => panic!( | ||
"Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set" | ||
), | ||
}; | ||
let vulkan_lib_path = vulkan_path.join("Lib"); | ||
println!("cargo:rustc-link-search={}", vulkan_lib_path.display()); | ||
} else { | ||
println!("cargo:rustc-link-lib=vulkan"); | ||
} | ||
config.define("SD_VULKAN", "ON"); | ||
} | ||
|
||
#[cfg(feature = "sycl")] | ||
{ | ||
config.define("SD_SYCL", "ON"); | ||
panic!("Not yet supported!"); | ||
} | ||
|
||
#[cfg(feature = "flashattn")] | ||
{ | ||
config.define("SD_FLASH_ATTN", "ON"); | ||
panic!("Broken in 2024/09/02 release!"); | ||
} | ||
|
||
config | ||
.profile("Release") | ||
.define("SD_BUILD_SHARED_LIBS", "OFF") | ||
.define("SD_BUILD_EXAMPLES", "OFF") | ||
.very_verbose(true) | ||
.pic(true); | ||
|
||
let destination = config.build(); | ||
|
||
add_link_search_path(&out.join("lib")).unwrap(); | ||
add_link_search_path(&out.join("build")).unwrap(); | ||
|
||
println!("cargo:rustc-link-search=native={}", destination.display()); | ||
println!("cargo:rustc-link-lib=static=stable-diffusion"); | ||
println!("cargo:rustc-link-lib=static=ggml"); | ||
} | ||
|
||
fn add_link_search_path(dir: &std::path::Path) -> std::io::Result<()> { | ||
if dir.is_dir() { | ||
println!("cargo:rustc-link-search={}", dir.display()); | ||
for entry in read_dir(dir)? { | ||
add_link_search_path(&entry?.path())?; | ||
} | ||
} | ||
Ok(()) | ||
} | ||
|
||
// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462 | ||
fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> { | ||
if target.contains("msvc") { | ||
None | ||
} else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") { | ||
Some("c++") | ||
} else if target.contains("android") { | ||
Some("c++_shared") | ||
} else { | ||
Some("stdc++") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#![allow(non_upper_case_globals)] | ||
#![allow(non_camel_case_types)] | ||
#![allow(non_snake_case)] | ||
|
||
include!(concat!(env!("OUT_DIR"), "/bindings.rs")); |
Submodule stable-diffusion.cpp
added at
e410ae
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
#include <stable-diffusion.h> |