Commit

devices param added
mgonzs13 committed Nov 26, 2024
1 parent 879a1ab commit 3e770cf
Showing 2 changed files with 29 additions and 13 deletions.
15 changes: 5 additions & 10 deletions llama_bringup/launch/base.launch.py
@@ -42,6 +42,7 @@ def run_llama(context: LaunchContext, embedding, reranking):
         "n_batch": LaunchConfiguration("n_batch", default=2048),
         "n_ubatch": LaunchConfiguration("n_ubatch", default=512),
         # GPU params
+        "devices": LaunchConfiguration("devices", default="['']"),
         "n_gpu_layers": LaunchConfiguration("n_gpu_layers", default=0),
         "split_mode": LaunchConfiguration("split_mode", default="layer"),
         "main_gpu": LaunchConfiguration("main_gpu", default=0),
@@ -92,12 +93,9 @@ def run_llama(context: LaunchContext, embedding, reranking):
         "n_keep": LaunchConfiguration("n_keep", default=-1),
         # paths params
         "model": LaunchConfiguration("model", default=""),
-        "lora_adapters": ParameterValue(
-            LaunchConfiguration("lora_adapters", default=[""]), value_type=List[str]
-        ),
-        "lora_adapters_scales": ParameterValue(
-            LaunchConfiguration("lora_adapters_scales", default=[0.0]),
-            value_type=List[float],
+        "lora_adapters": LaunchConfiguration("lora_adapters", default="['']"),
+        "lora_adapters_scales": LaunchConfiguration(
+            "lora_adapters_scales", default="[0.0]"
         ),
         "mmproj": LaunchConfiguration("mmproj", default=""),
         "numa": LaunchConfiguration("numa", default="none"),
@@ -109,10 +107,7 @@ def run_llama(context: LaunchContext, embedding, reranking):
         "suffix": ParameterValue(
             LaunchConfiguration("suffix", default=""), value_type=str
         ),
-        "stopping_words": ParameterValue(
-            LaunchConfiguration("stopping_words", default=[""]),
-            value_type=List[str],
-        ),
+        "stopping_words": LaunchConfiguration("stopping_words", default="['']"),
         "image_prefix": ParameterValue(
             LaunchConfiguration("image_prefix", default=""), value_type=str
         ),
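
With this change, every list-valued launch argument travels as a YAML-encoded string (the new defaults "['']" and "[0.0]") instead of a typed ParameterValue, so lists can be passed directly on the command line, e.g. ros2 launch llama_bringup base.launch.py devices:="['CUDA0']". Below is a minimal sketch of including base.launch.py with the new arguments from another launch file; the argument values, notably the ggml device name "CUDA0", are illustrative assumptions rather than values taken from this commit:

import os

from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import IncludeLaunchDescription
from launch.launch_description_sources import PythonLaunchDescriptionSource


def generate_launch_description() -> LaunchDescription:
    base_launch = os.path.join(
        get_package_share_directory("llama_bringup"), "launch", "base.launch.py"
    )

    return LaunchDescription([
        IncludeLaunchDescription(
            PythonLaunchDescriptionSource(base_launch),
            launch_arguments={
                # lists are now YAML strings, matching the new defaults
                "devices": "['CUDA0']",
                "lora_adapters_scales": "[0.0]",
                "n_gpu_layers": "33",
            }.items(),
        )
    ])
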
27 changes: 24 additions & 3 deletions llama_ros/src/llama_utils/llama_params.cpp
@@ -84,10 +84,10 @@ void llama_utils::declare_llama_params(
       {"image_suffix", ""},
       {"image_text", "<image>"},
   });
+  node->declare_parameter<std::vector<std::string>>(
+      "devices", std::vector<std::string>({}));
   node->declare_parameter<std::vector<std::string>>(
       "lora_adapters", std::vector<std::string>({}));
-  node->declare_parameter<std::vector<double>>("lora_adapters_scales",
-                                               std::vector<double>({}));
   node->declare_parameter<std::vector<std::string>>(
       "stopping_words", std::vector<std::string>({}));
   node->declare_parameters<float>("", {
@@ -101,6 +101,8 @@
   });
   node->declare_parameter<std::vector<double>>("tensor_split",
                                                std::vector<double>({0.0}));
+  node->declare_parameter<std::vector<double>>("lora_adapters_scales",
+                                               std::vector<double>({}));
   node->declare_parameters<bool>("", {
       {"debug", true},
       {"embedding", false},
@@ -125,9 +127,12 @@ struct LlamaParams llama_utils::get_llama_params(
   int32_t poll;
   int32_t poll_batch;
 
-  std::vector<std::string> stopping_words;
-
   std::vector<std::string> lora_adapters;
   std::vector<double> lora_adapters_scales;
+  std::vector<std::string> stopping_words;
+
+  std::vector<std::string> devices;
   std::vector<double> tensor_split;
 
   std::string cpu_mask;
@@ -151,6 +156,7 @@ struct LlamaParams llama_utils::get_llama_params(
   node->get_parameter("n_batch", params.params.n_batch);
   node->get_parameter("n_ubatch", params.params.n_ubatch);
 
+  node->get_parameter("devices", devices);
   node->get_parameter("n_gpu_layers", params.params.n_gpu_layers);
   node->get_parameter("split_mode", split_mode);
   node->get_parameter("main_gpu", params.params.main_gpu);
@@ -232,6 +238,21 @@ struct LlamaParams llama_utils::get_llama_params(
     params.params.sampling.seed = seed;
   }
 
+  // devices
+  for (const std::string &d : devices) {
+
+    if (!d.empty()) {
+      auto *dev = ggml_backend_dev_by_name(d.c_str());
+
+      if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
+        throw std::invalid_argument(
+            string_format("invalid device: %s", d.c_str()));
+      }
+
+      params.params.devices.push_back(dev);
+    }
+  }
+
   // check threads number
   if (params.params.cpuparams.n_threads < 0) {
     params.params.cpuparams.n_threads = cpu_get_num_math();
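
The new block in get_llama_params() resolves each configured name through ggml's backend device registry and rejects anything that is not a GPU backend, so a typo aborts node startup with the invalid_argument above instead of silently falling back to another device. To see which strings are valid for the devices parameter, the registry can be enumerated directly; a standalone sketch, assuming only a ggml build to compile and link against (not part of this commit):

// Sketch: list the names ggml_backend_dev_by_name() will resolve, and
// which of them pass the GPU-type check added in this commit.
#include <cstdio>

#include "ggml-backend.h"

int main() {
  for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
    ggml_backend_dev_t dev = ggml_backend_dev_get(i);

    bool is_gpu = ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU;

    // e.g. "CUDA0 - NVIDIA GeForce ... (valid for 'devices')"
    printf("%s - %s (%s)\n", ggml_backend_dev_name(dev),
           ggml_backend_dev_description(dev),
           is_gpu ? "valid for 'devices'" : "rejected");
  }
  return 0;
}
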
