From 3e770cf683ed978fbb411a486a0975e1f97c729a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Tue, 26 Nov 2024 13:40:25 +0100 Subject: [PATCH] devices param added --- llama_bringup/launch/base.launch.py | 15 ++++-------- llama_ros/src/llama_utils/llama_params.cpp | 27 +++++++++++++++++++--- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/llama_bringup/launch/base.launch.py b/llama_bringup/launch/base.launch.py index 7557090f..e10bde1f 100644 --- a/llama_bringup/launch/base.launch.py +++ b/llama_bringup/launch/base.launch.py @@ -42,6 +42,7 @@ def run_llama(context: LaunchContext, embedding, reranking): "n_batch": LaunchConfiguration("n_batch", default=2048), "n_ubatch": LaunchConfiguration("n_batch", default=512), # GPU params + "devices": LaunchConfiguration("devices", default="['']"), "n_gpu_layers": LaunchConfiguration("n_gpu_layers", default=0), "split_mode": LaunchConfiguration("split_mode", default="layer"), "main_gpu": LaunchConfiguration("main_gpu", default=0), @@ -92,12 +93,9 @@ def run_llama(context: LaunchContext, embedding, reranking): "n_keep": LaunchConfiguration("n_keep", default=-1), # paths params "model": LaunchConfiguration("model", default=""), - "lora_adapters": ParameterValue( - LaunchConfiguration("lora_adapters", default=[""]), value_type=List[str] - ), - "lora_adapters_scales": ParameterValue( - LaunchConfiguration("lora_adapters_scales", default=[0.0]), - value_type=List[float], + "lora_adapters": LaunchConfiguration("lora_adapters", default="['']"), + "lora_adapters_scales": LaunchConfiguration( + "lora_adapters_scales", default="[0.0]" ), "mmproj": LaunchConfiguration("mmproj", default=""), "numa": LaunchConfiguration("numa", default="none"), @@ -109,10 +107,7 @@ def run_llama(context: LaunchContext, embedding, reranking): "suffix": ParameterValue( LaunchConfiguration("suffix", default=""), value_type=str ), - "stopping_words": ParameterValue( - LaunchConfiguration("stopping_words", default=[""]), - value_type=List[str], - ), + "stopping_words": LaunchConfiguration("stopping_words", default="['']"), "image_prefix": ParameterValue( LaunchConfiguration("image_prefix", default=""), value_type=str ), diff --git a/llama_ros/src/llama_utils/llama_params.cpp b/llama_ros/src/llama_utils/llama_params.cpp index afd214ec..83391928 100644 --- a/llama_ros/src/llama_utils/llama_params.cpp +++ b/llama_ros/src/llama_utils/llama_params.cpp @@ -84,10 +84,10 @@ void llama_utils::declare_llama_params( {"image_suffix", ""}, {"image_text", ""}, }); + node->declare_parameter>( + "devices", std::vector({})); node->declare_parameter>( "lora_adapters", std::vector({})); - node->declare_parameter>("lora_adapters_scales", - std::vector({})); node->declare_parameter>( "stopping_words", std::vector({})); node->declare_parameters("", { @@ -101,6 +101,8 @@ void llama_utils::declare_llama_params( }); node->declare_parameter>("tensor_split", std::vector({0.0})); + node->declare_parameter>("lora_adapters_scales", + std::vector({})); node->declare_parameters("", { {"debug", true}, {"embedding", false}, @@ -125,9 +127,12 @@ struct LlamaParams llama_utils::get_llama_params( int32_t poll; int32_t poll_batch; + std::vector stopping_words; + std::vector lora_adapters; std::vector lora_adapters_scales; - std::vector stopping_words; + + std::vector devices; std::vector tensor_split; std::string cpu_mask; @@ -151,6 +156,7 @@ struct LlamaParams llama_utils::get_llama_params( node->get_parameter("n_batch", params.params.n_batch); node->get_parameter("n_ubatch", params.params.n_ubatch); + node->get_parameter("devices", devices); node->get_parameter("n_gpu_layers", params.params.n_gpu_layers); node->get_parameter("split_mode", split_mode); node->get_parameter("main_gpu", params.params.main_gpu); @@ -232,6 +238,21 @@ struct LlamaParams llama_utils::get_llama_params( params.params.sampling.seed = seed; } + // devices + for (const std::string &d : devices) { + + if (!d.empty()) { + auto *dev = ggml_backend_dev_by_name(d.c_str()); + + if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { + throw std::invalid_argument( + string_format("invalid device: %s", d.c_str())); + } + + params.params.devices.push_back(dev); + } + } + // check threads number if (params.params.cpuparams.n_threads < 0) { params.params.cpuparams.n_threads = cpu_get_num_math();