Skip to content

Commit

Permalink
feat: binding crate
Browse files Browse the repository at this point in the history
  • Loading branch information
newfla committed Oct 15, 2024
1 parent c59617b commit 51c4859
Show file tree
Hide file tree
Showing 11 changed files with 364 additions and 0 deletions.
11 changes: 11 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file

version: 2
updates:
- package-ecosystem: "cargo" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
64 changes: 64 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
name: test

on:
workflow_dispatch:
pull_request:
branches: [ "main" ]

env:
CARGO_TERM_COLOR: always

jobs:
cargo-fmt:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- name: Check Style
run: cargo fmt --check

cargo-test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Run tests
run: cargo test

build-no_feature:
runs-on: [ubuntu-latest, windows-latest, macos-latest]
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Build
run: cargo build

build-metal:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Build
run: cargo build --features metal

build-cuda:
strategy:
fail-fast: false
matrix:
platform: [ubuntu-22.04, windows-2019]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Install cuda toolkit
uses: Jimver/[email protected]
- name: install msvc deps (windows)
if: matrix.platform == 'windows-2019'
uses: ilammy/msvc-dev-cmd@v1
- name: Build
env:
CUDA_COMPUTE_CAP: "75"
run: cargo build --features cublas

3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "sys/stable-diffusion.cpp"]
path = sys/stable-diffusion.cpp
url = https://github.com/leejet/stable-diffusion.cpp
22 changes: 22 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[workspace]
members = ["sys"]

[package]
name = "diffusion-rs"
version = "0.1.0"
edition = "2021"
description = "Rust bindings for stable-diffusion.cpp"
license = "MIT"
documentation = "https://docs.rs/diffusion-rs"
repository = "https://github.com/newfla/diffusion-rs"

[dependencies]
diffusion-rs-sys = { path = "sys", version = "0.1.0" }

[features]
cublas = ["diffusion-rs-sys/cublas"]
hipblas = ["diffusion-rs-sys/hipblas"]
metal = ["diffusion-rs-sys/metal"]
vulkan = ["diffusion-rs-sys/vulkan"]
sycl = ["diffusion-rs-sys/sycl"]
flashattn = ["diffusion-rs-sys/flashattn"]
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# diffusion-rs
Rust bindings to https://github.com/leejet/stable-diffusion.cpp

# WIP
14 changes: 14 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
pub fn add(left: u64, right: u64) -> u64 {
left + right
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
}
}
49 changes: 49 additions & 0 deletions sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
[package]
name = "diffusion-rs-sys"
version = "0.1.0"
edition = "2021"
description = "Rust bindings for stable-diffusion.cpp (FFI bindings)"
license = "MIT"
documentation = "https://docs.rs/diffusion-rs-sys"
repository = "https://github.com/newfla/diffusion-rs"
links = "stable-diffusion"
include = [
"stable-diffusion.cpp/LICENSE",
"stable-diffusion.cpp/CMakeLists.txt",
"stable-diffusion.cpp/stable-diffusion.cpp",
"stable-diffusion.cpp/stable-diffusion.h",
"stable-diffusion.cpp/ggml/src/ggml.c",
"stable-diffusion.cpp/ggml/src/ggml-alloc.c",
"stable-diffusion.cpp/ggml/src/ggml-backend.c",
"stable-diffusion.cpp/ggml/src/ggml-cuda.cu",
"stable-diffusion.cpp/ggml/src/ggml-impl.h",
"stable-diffusion.cpp/ggml/src/ggml-metal.m",
"stable-diffusion.cpp/ggml/src/ggml-metal.metal",
"stable-diffusion.cpp/ggml/src/ggml-quants.h",
"stable-diffusion.cpp/ggml/src/ggml-quants.c",
"stable-diffusion.cpp/ggml/include/ggml.h",
"stable-diffusion.cpp/ggml/include/ggml-alloc.h",
"stable-diffusion.cpp/ggml/include/ggml-backend.h",
"stable-diffusion.cpp/ggml/include/ggml-backend-impl.h",
"stable-diffusion.cpp/ggml/include/ggml-cuda.h",
"stable-diffusion.cpp/ggml/include/ggml-metal.h",
"src/*.rs",
"build.rs",
"wrapper.h",
]

[dependencies]

[features]
cublas = []
hipblas = []
metal = []
vulkan = []
sycl = []
flashattn = []

[build-dependencies]
cmake = "0.1.51"
bindgen = "0.70.1"
cfg-if = "1.0.0"
fs_extra = "1.3.0"
192 changes: 192 additions & 0 deletions sys/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use std::{
env,
fs::{copy, create_dir_all, read_dir},
path::PathBuf,
};

use cmake::Config;
use fs_extra::dir;

// Heavily ispired by https://github.com/tazz4843/whisper-rs/blob/master/sys/build.rs

fn main() {
// Link C++ standard library
let target = env::var("TARGET").unwrap();
if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) {
println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib);
}

println!("cargo:rerun-if-changed=wrapper.h");

// Copy stable-diffusion code into the build script directory
let out = PathBuf::from(env::var("OUT_DIR").unwrap());
let diffusion_root = out.join("stable-diffusion.cpp/");

if !diffusion_root.exists() {
create_dir_all(&diffusion_root).unwrap();
dir::copy("./stable-diffusion.cpp", &out, &Default::default()).unwrap_or_else(|e| {
panic!(
"Failed to copy stable-diffusion sources into {}: {}",
diffusion_root.display(),
e
)
});
}

// Bindgen
let bindings = bindgen::Builder::default().header("wrapper.h");

let bindings = bindings
.clang_arg("-I./stable-diffusion.cpp")
.clang_arg("-I./stable-diffusion.cpp/ggml/include")
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate();

match bindings {
Ok(b) => {
b.write_to_file(out.join("bindings.rs"))
.expect("Couldn't write bindings!");
}
Err(e) => {
println!("cargo:warning=Unable to generate bindings: {}", e);
println!("cargo:warning=Using bundled bindings.rs, which may be out of date");
// copy src/bindings.rs to OUT_DIR
copy("src/bindings.rs", out.join("bindings.rs")).expect("Unable to copy bindings.rs");
}
}

// stop if we're on docs.rs
if env::var("DOCS_RS").is_ok() {
return;
}

// Configure cmake for building
let mut config = Config::new(&diffusion_root);

//Enable cmake feature flags
#[cfg(feature = "cublas")]
{
println!("cargo:rustc-link-lib=cublas");
println!("cargo:rustc-link-lib=cudart");
println!("cargo:rustc-link-lib=cublasLt");
println!("cargo:rustc-link-lib=cuda");

if target.contains("msvc") {
let cuda_path = PathBuf::from(env::var("CUDA_PATH").unwrap()).join("lib/x64");
println!("cargo:rustc-link-search={}", cuda_path.display());
} else {
println!("cargo:rustc-link-lib=culibos");
println!("cargo:rustc-link-search=/usr/local/cuda/lib64");
println!("cargo:rustc-link-search=/usr/local/cuda/lib64/stubs");
println!("cargo:rustc-link-search=/opt/cuda/lib64");
println!("cargo:rustc-link-search=/opt/cuda/lib64/stubs");
}

config.define("SD_CUBLAS", "ON");
}

#[cfg(feature = "hipblas")]
{
println!("cargo:rustc-link-lib=hipblas");
println!("cargo:rustc-link-lib=rocblas");
println!("cargo:rustc-link-lib=amdhip64");

if target.contains("msvc") {
panic!("Due to a problem with the last revision of the ROCm 5.7 library, it is not possible to compile the library for the windows environment.\nSee https://github.com/ggerganov/whisper.cpp/issues/2202 for more details.")

} else {
println!("cargo:rerun-if-env-changed=HIP_PATH");

let hip_path = match env::var("HIP_PATH") {
Ok(path) =>PathBuf::from(path),
Err(_) => PathBuf::from("/opt/rocm"),
};
let hip_lib_path = hip_path.join("lib");

println!("cargo:rustc-link-search={}",hip_lib_path.display());
}

config.define("SD_HIPBLAS", "ON");
if let Ok(target) = env::var("AMDGPU_TARGETS") {
config.define("AMDGPU_TARGETS", target);
}
}

#[cfg(feature = "metal")]
{
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=Metal");
println!("cargo:rustc-link-lib=framework=MetalKit");
config.define("SD_METAL", "ON");
}

#[cfg(feature = "vulkan")]
{
if target.contains("msvc") {
println!("cargo:rerun-if-env-changed=VULKAN_SDK");
println!("cargo:rustc-link-lib=vulkan-1");
let vulkan_path = match env::var("VULKAN_SDK") {
Ok(path) => PathBuf::from(path),
Err(_) => panic!(
"Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set"
),
};
let vulkan_lib_path = vulkan_path.join("Lib");
println!("cargo:rustc-link-search={}", vulkan_lib_path.display());
} else {
println!("cargo:rustc-link-lib=vulkan");
}
config.define("SD_VULKAN", "ON");
}

#[cfg(feature = "sycl")]
{
config.define("SD_SYCL", "ON");
panic!("Not yet supported!");
}

#[cfg(feature = "flashattn")]
{
config.define("SD_FLASH_ATTN", "ON");
panic!("Broken in 2024/09/02 release!");
}

config
.profile("Release")
.define("SD_BUILD_SHARED_LIBS", "OFF")
.define("SD_BUILD_EXAMPLES", "OFF")
.very_verbose(true)
.pic(true);

let destination = config.build();

add_link_search_path(&out.join("lib")).unwrap();
add_link_search_path(&out.join("build")).unwrap();

println!("cargo:rustc-link-search=native={}", destination.display());
println!("cargo:rustc-link-lib=static=stable-diffusion");
println!("cargo:rustc-link-lib=static=ggml");
}

fn add_link_search_path(dir: &std::path::Path) -> std::io::Result<()> {
if dir.is_dir() {
println!("cargo:rustc-link-search={}", dir.display());
for entry in read_dir(dir)? {
add_link_search_path(&entry?.path())?;
}
}
Ok(())
}

// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462
fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> {
if target.contains("msvc") {
None
} else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") {
Some("c++")
} else if target.contains("android") {
Some("c++_shared")
} else {
Some("stdc++")
}
}
5 changes: 5 additions & 0 deletions sys/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
1 change: 1 addition & 0 deletions sys/stable-diffusion.cpp
Submodule stable-diffusion.cpp added at e410ae
1 change: 1 addition & 0 deletions sys/wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include <stable-diffusion.h>

0 comments on commit 51c4859

Please sign in to comment.