fix: bundle CUDA DLL into the release #62

Open · wants to merge 28 commits into main from 61-bug-cuda-dlls

Changes from 9 commits

Commits (28)
52a48da  fix: bundle CUDA DLL into the release (louisgv, Jul 2, 2023)
43330d4  Merge branch 'main' into 61-bug-cuda-dlls (louisgv, Jul 3, 2023)
4d94a47  Merge branch 'main' into 61-bug-cuda-dlls (louisgv, Jul 14, 2023)
953af83  Merge branch 'main' into 61-bug-cuda-dlls (LLukas22, Jul 18, 2023)
4b71716  Update `rustformers` + check gpu (LLukas22, Jul 18, 2023)
4b8fe59  Set `n_batch` correctly (LLukas22, Jul 18, 2023)
187b135  Copy cuda libraries (LLukas22, Jul 20, 2023)
9343897  reduce feeding delay if gpu is enabled (LLukas22, Jul 21, 2023)
a2a3dbf  Copy `opencl` dlls (LLukas22, Jul 21, 2023)
a8b3bbf  create linux ci (LLukas22, Jul 21, 2023)
21ae9e1  defaults for release infos (LLukas22, Jul 21, 2023)
286574d  Fail if files aren't found (LLukas22, Jul 21, 2023)
86cc051  Add windows build (LLukas22, Jul 21, 2023)
47f9dfc  Macos build (LLukas22, Jul 21, 2023)
7c1f25a  ci bugfixes (LLukas22, Jul 22, 2023)
36e050b  More bugfixes and absolute paths (LLukas22, Jul 22, 2023)
0b26205  Paths .... again (LLukas22, Jul 22, 2023)
cc786f0  Make mac artifacts unique (LLukas22, Jul 22, 2023)
89eb1fa  renable build for windows-cublas (LLukas22, Jul 22, 2023)
0761d79  update character (louisgv, Jul 30, 2023)
7481edf  Slight refactor (louisgv, Aug 1, 2023)
9d23cfd  update character (louisgv, Aug 2, 2023)
5b51725  update llm (louisgv, Aug 2, 2023)
006cd5a  Merge branch 'main' into 61-bug-cuda-dlls (louisgv, Sep 16, 2023)
9b8d16d  fix build script (louisgv, Sep 16, 2023)
bc5edf6  use self-hosted runner for metal (louisgv, Sep 16, 2023)
18f04ed  remove build on push (consume too much compute atm) (louisgv, Sep 16, 2023)
1211cc2  Add todo (louisgv, Sep 16, 2023)
4 changes: 4 additions & 0 deletions .github/workflows/tauri.yml
@@ -13,6 +13,9 @@ jobs:
fail-fast: false
matrix:
platform: [macos-latest, ubuntu-latest, windows-latest, self-hosted]
# include:
# - platform: windows-latest
# args: --config tauri.windows.conf.json

runs-on: ${{ matrix.platform }}

@@ -113,3 +116,4 @@ jobs:
releaseBody: ${{ github.event.release.body }}
releaseId: ${{ github.event.release.id }}
tagName: ${{ github.event.release.tag_name }}
# args: "${{ matrix.args }}"
18 changes: 8 additions & 10 deletions apps/desktop/src-tauri/Cargo.toml
@@ -11,15 +11,14 @@ edition = "2021"

[build-dependencies]
tauri-build = { version = "1.4.0", features = [] }
glob = "0.3"

[dependencies]
llm = { git = "https://github.com/rustformers/llm", branch = "main", package = "llm", features = [
# Peg the llm version here to prevent unwanted breaking changes
llm = { git = "https://github.com/rustformers/llm", rev = "645093e", package = "llm", features = [
"default",
# "cublas",
] }

# llm = { git = "https://github.com/RedBoxing/llm.git", branch = "hf-tokenizer", package = "llm" }

tauri = { version = "1.4.0", features = [
"reqwest-client",
"dialog-confirm",
@@ -76,15 +75,10 @@ blake3 = "1.3.3"
cocoa = "0.24.1"
objc = "0.2.7"

[target.aarch64-apple-darwin.dependencies]
llm = { git = "https://github.com/rustformers/llm", branch = "main", package = "llm", features = [
"default",
"metal",
] }

[target."cfg(target_os = \"linux\")".dependencies]
webkit2gtk = "0.18.2"


[target."cfg(target_os = \"windows\")".dependencies]
webview2-com = "0.19.1"
windows = "0.39.0"
@@ -102,6 +96,10 @@ default = ["custom-protocol"]
# DO NOT remove this
custom-protocol = ["tauri/custom-protocol"]

cublas = ["llm/cublas"]
clblast = ["llm/clblast"]
metal = ["llm/metal"]

[profile.dev.package."*"]
opt-level = 3

132 changes: 109 additions & 23 deletions apps/desktop/src-tauri/build.rs
@@ -1,25 +1,111 @@
// #[cfg(any(target_os = "macos", target_os = "linux"))]
fn main() {
tauri_build::build()
use std::{env, path::{Path, PathBuf}, fs};
extern crate glob;
use glob::glob;

fn main() {
#[cfg(feature = "cublas")]
copy_cuda_dlls();
#[cfg(feature = "clblast")]
copy_opencl_dlls();

tauri_build::build();
}

#[allow(dead_code)]
fn get_build_dir() -> PathBuf {
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let mut build_dir = Path::new(&manifest_dir).join("target");
build_dir.push(env::var("PROFILE").unwrap());
build_dir
}

// #[cfg(target_os = "windows")]
// fn main() {
// let mut windows = tauri_build::WindowsAttributes::new();
// windows = windows.app_manifest(
// r#"
// <assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0">
// <trustInfo xmlns="urn:schemas-microsoft-com:asm.v3">
// <security>
// <requestedPrivileges>
// <requestedExecutionLevel level="requireAdministrator" uiAccess="false" />
// </requestedPrivileges>
// </security>
// </trustInfo>
// </assembly>
// "#,
// );

// tauri_build::try_build(tauri_build::Attributes::new().windows_attributes(windows))
// .expect("failed to run build script")
// }
#[allow(dead_code)]
fn copy_cuda_dlls() {
// Get the directory of the output executable.
let out_dir = get_build_dir();

// Get the CUDA path from the environment variable.
let cuda_env = env::var("CUDA_PATH").expect("CUDA_PATH not found");
let cuda_path = Path::new(&cuda_env);

// Patterns to search for the DLL files.
#[cfg(target_os = "windows")]
let patterns = [
"cublas64_*.dll",
"cublasLt64_*.dll",
"cudart64_*.dll"
];
#[cfg(target_os = "windows")]
let binary_path = cuda_path.join("bin");

#[cfg(target_os = "linux")]
let patterns = [
"libcudart.so",
"libcublasLt.so",
"libcublas.so"
];
#[cfg(target_os = "linux")]
let binary_path = cuda_path.join("lib64");


for pattern in &patterns {
// Construct the full glob pattern.
let full_pattern = format!("{}/{}", binary_path.to_str().unwrap(), pattern);

// Use glob to find the DLL files.
for entry in glob(&full_pattern).expect("Failed to read glob pattern") {
match entry {
Ok(dll_path) => {
// Copy the DLL file to the output directory.
let dll_file_name = dll_path.file_name().unwrap();
let destination = Path::new(&out_dir).join(dll_file_name);
if !destination.exists() {
fs::copy(&dll_path, &destination)
.expect("Failed to copy DLL");
println!("Moved {} to {}", dll_file_name.to_string_lossy(), destination.to_string_lossy());
}

},
Err(e) => panic!("{}", e),
}

}
}
}

#[allow(dead_code)]
fn copy_opencl_dlls() {
// Get the directory of the output executable.
let out_dir = get_build_dir();

let copy_dll = |source: PathBuf| {
let dll_file_name = source.file_name().unwrap();
let destination = Path::new(&out_dir).join(dll_file_name);
if !destination.exists() {
fs::copy(&source, &destination)
.unwrap_or_else(|e| panic!("Failed to copy DLL {}: {}", dll_file_name.to_string_lossy(), e));
println!("Copied {} to {}", dll_file_name.to_string_lossy(), destination.to_string_lossy());
}
};

let clblast_dll;
let opencl_dll;
#[cfg(target_os = "windows")]
{
let clblast_dir = env::var("CLBLAST_PATH").expect("CLBLAST_PATH not found!");
clblast_dll = Path::new(&clblast_dir).join("bin").join("clblast.dll");

let opencl_dir = env::var("OPENCL_PATH").expect("OPENCL_PATH not found!");
opencl_dll = Path::new(&opencl_dir).join("bin").join("OpenCL.dll");
}

#[cfg(target_os = "linux")]
{
let lib_path = Path::new("/usr/lib/x86_64-linux-gnu");
clblast_dll = lib_path.join("libclblast.so");
opencl_dll = lib_path.join("libOpenCL.so");
}

copy_dll(clblast_dll);
copy_dll(opencl_dll);
}
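Note that `glob` returning zero matches is not an error, so a bad pattern or a missing CUDA install would silently copy nothing. A stricter variant in the spirit of commit 286574d ("Fail if files aren't found"); a sketch, not the code this PR ships:

// Sketch: resolve a glob pattern the way copy_cuda_dlls does, but fail the
// build when nothing matches instead of silently copying zero files.
fn require_matches(full_pattern: &str) -> Vec<std::path::PathBuf> {
    let matches: Vec<_> = glob::glob(full_pattern)
        .expect("Failed to read glob pattern")
        .filter_map(Result::ok)
        .collect();
    if matches.is_empty() {
        panic!("No files found for pattern {}", full_pattern);
    }
    matches
}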
3 changes: 1 addition & 2 deletions apps/desktop/src-tauri/src/inference/gpu.rs
@@ -1,7 +1,6 @@
#[tauri::command]
pub async fn check_gpu() -> Result<bool, String> {
// TODO: actually check if Metal is available in the future (?)
if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
if llm::ggml_get_accelerator() != llm::GgmlAccelerator::None {
Ok(true)
} else {
Ok(false)
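Both branches of the rewritten `check_gpu` merely wrap the comparison, so the command reduces to a one-liner; an equivalent sketch (assuming `GgmlAccelerator` implements `PartialEq`, which the `!=` above already requires):

#[tauri::command]
pub async fn check_gpu() -> Result<bool, String> {
    // Any accelerator other than None means GPU inference is available.
    Ok(llm::ggml_get_accelerator() != llm::GgmlAccelerator::None)
}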
25 changes: 17 additions & 8 deletions apps/desktop/src-tauri/src/inference/process.rs
@@ -65,13 +65,7 @@ impl InferenceThreadRequest {
fn get_inference_params(
completion_request: &CompletionRequest,
) -> InferenceParameters {
let n_threads = model::pool::get_n_threads();

let n_batch = if get_use_gpu() { 240 } else { n_threads };

InferenceParameters {
n_threads,
n_batch,
sampler: Arc::new(completion_request.to_top_p_top_k()),
}
}
@@ -95,7 +89,23 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
}
};

let mut session = model.start_session(Default::default());
let n_threads = model::pool::get_n_threads();

// Set the batch size according to the active accelerator.
let backend = llm::ggml_get_accelerator();
let n_batch = match backend {
llm::GgmlAccelerator::Metal => if get_use_gpu() { 1 } else { n_threads }, // 1 is the only supported batch size for Metal
llm::GgmlAccelerator::None => n_threads,
_ => if get_use_gpu() { 512 } else { n_threads },
};

let session_config = llm::InferenceSessionConfig {
n_batch,
n_threads,
..Default::default()
};

let mut session = model.start_session(session_config);

let mut output_request = OutputRequest::default();

Expand All @@ -109,7 +119,6 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {

match session.feed_prompt::<Infallible, Prompt>(
model.as_ref(),
&inference_params,
req.completion_request.prompt.as_str().into(),
&mut output_request,
|t| {
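The batch-size policy introduced in this hunk can also be read as a standalone helper; a sketch of the same logic (hypothetical refactor, not part of the diff):

// Same policy as above: Metal only supports a batch size of 1, CPU-only
// builds batch by thread count, and CUDA/OpenCL use a large batch (512)
// when the GPU is enabled.
fn pick_n_batch(use_gpu: bool, n_threads: usize) -> usize {
    match llm::ggml_get_accelerator() {
        llm::GgmlAccelerator::Metal if use_gpu => 1,
        llm::GgmlAccelerator::None => n_threads,
        _ if use_gpu => 512,
        _ => n_threads,
    }
}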
2 changes: 1 addition & 1 deletion apps/desktop/src/providers/thread.ts
@@ -158,7 +158,7 @@ const useThreadProvider = ({ thread }: { thread: FileInfo }) => {
{
async onComment(comment) {
setStatusMessage(comment)
await wait(42)
await wait(serverConfig.data.useGpu ? 3 : 42)
},
async onData(data) {
const resp = JSON.parse(data) as StreamResponse