From 911db384075907658c5ec3d6d746064bb4ab3dd7 Mon Sep 17 00:00:00 2001
From: anilaltuner
Date: Sun, 12 Jan 2025 23:05:48 +0300
Subject: [PATCH 1/3] vllm executor added

---
 src/api_interface/mod.rs |  1 +
 src/program/executor.rs  | 26 ++++++++++++--------------
 src/program/models.rs    |  7 +++++++
 3 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/src/api_interface/mod.rs b/src/api_interface/mod.rs
index 6b6a670..d272905 100644
--- a/src/api_interface/mod.rs
+++ b/src/api_interface/mod.rs
@@ -1,3 +1,4 @@
 pub mod gem_api;
 pub mod open_router;
 pub mod openai_api;
+pub mod vllm;
diff --git a/src/program/executor.rs b/src/program/executor.rs
index a4d9473..5d4b023 100644
--- a/src/program/executor.rs
+++ b/src/program/executor.rs
@@ -5,6 +5,7 @@ use super::workflow::Workflow;
 use crate::api_interface::gem_api::GeminiExecutor;
 use crate::api_interface::open_router::OpenRouterExecutor;
 use crate::api_interface::openai_api::OpenAIExecutor;
+use crate::api_interface::vllm::VLLMExecutor;
 use crate::memory::types::Entry;
 use crate::memory::{MemoryReturnType, ProgramMemory};
 use crate::program::atomics::MessageInput;
@@ -24,20 +25,9 @@ use base64::prelude::*;
 use log::{debug, error, info, warn};
 use rand::seq::SliceRandom;

-use ollama_rs::{
-    error::OllamaError,
-    generation::chat::request::ChatMessageRequest,
-    generation::chat::ChatMessage,
-    generation::completion::request::GenerationRequest,
-    generation::functions::tools::StockScraper,
-    generation::functions::tools::Tool,
-    generation::functions::{
-        DDGSearcher, FunctionCallRequest, LlamaFunctionCall, OpenAIFunctionCall, Scraper,
-    },
-    generation::options::GenerationOptions,
-    generation::parameters::FormatType,
-    Ollama,
-};
+use ollama_rs::{error::OllamaError, generation::chat::request::ChatMessageRequest, generation::chat::ChatMessage, generation::completion::request::GenerationRequest, generation::functions::tools::StockScraper, generation::functions::tools::Tool, generation::functions::{
+    DDGSearcher, FunctionCallRequest, LlamaFunctionCall, OpenAIFunctionCall, Scraper,
+}, generation::options::GenerationOptions, generation::parameters::FormatType, Ollama};

 fn log_colored(msg: &str) {
     let colors = ["red", "green", "yellow", "blue", "magenta", "cyan"];
@@ -585,6 +575,10 @@ impl Executor {
                     OpenRouterExecutor::new(self.model.to_string(), api_key.clone());
                 openai_executor.generate_text(input, schema).await?
             }
+            ModelProvider::VLLM => {
+                let executor = VLLMExecutor::new(self.model.to_string(), "http://localhost:8000".to_string());
+                executor.generate_text(input, schema).await?
+            }
         };

         Ok(response)
@@ -669,6 +663,10 @@ impl Executor {
                     .function_call(prompt, tools, raw_mode, oai_parser)
                     .await?
             }
+            ModelProvider::VLLM => {
+                let executor = VLLMExecutor::new(self.model.to_string(), "http://localhost:8000".to_string());
+                executor.function_call(prompt, tools, raw_mode, oai_parser).await?
+            }
         };

         Ok(result)
diff --git a/src/program/models.rs b/src/program/models.rs
index 4923f89..5d67382 100644
--- a/src/program/models.rs
+++ b/src/program/models.rs
@@ -197,6 +197,9 @@ pub enum Model {

     #[serde(rename = "openai/o1")]
     OROpenAIO1,
+
+    #[serde(rename = "Qwen/Qwen2.5-1.5B-Instruct")]
+    Qwen25Vllm
 }

 impl Model {
@@ -264,6 +267,8 @@ pub enum ModelProvider {
     Gemini,
     #[serde(rename = "openrouter")]
     OpenRouter,
+    #[serde(rename = "VLLM")]
+    VLLM,
 }

 impl From<Model> for ModelProvider {
@@ -331,6 +336,8 @@ impl From<Model> for ModelProvider {
             Model::ORNemotron70B => ModelProvider::OpenRouter,
             Model::ORNousHermes405B => ModelProvider::OpenRouter,
             Model::OROpenAIO1 => ModelProvider::OpenRouter,
+            //vllm
+            Model::Qwen25Vllm => ModelProvider::VLLM
         }
     }
 }

From 4532978e5487747d414c4ed6c17a5dc3f7a5bcc8 Mon Sep 17 00:00:00 2001
From: anilaltuner
Date: Sun, 12 Jan 2025 23:42:34 +0300
Subject: [PATCH 2/3] vllm class added

---
 src/api_interface/vllm.rs | 209 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 209 insertions(+)
 create mode 100644 src/api_interface/vllm.rs

diff --git a/src/api_interface/vllm.rs b/src/api_interface/vllm.rs
new file mode 100644
index 0000000..58d3fab
--- /dev/null
+++ b/src/api_interface/vllm.rs
@@ -0,0 +1,209 @@
+use crate::program::atomics::MessageInput;
+use ollama_rs::{error::OllamaError, generation::functions::tools::Tool, generation::functions::OpenAIFunctionCall, IntoUrlSealed};
+use openai_dive::v1::api::Client;
+use openai_dive::v1::resources::chat::*;
+use serde_json::{json, Value};
+use std::sync::Arc;
+
+pub struct VLLMExecutor {
+    model: String,
+    base_url: String,
+    client: Client,
+}
+
+impl VLLMExecutor {
+    pub fn new(model: String, base_url: String) -> Self {
+        Self {
+            model,
+            base_url: base_url.clone(),
+            client: Client::new_with_base(base_url.as_str(), "".to_string()),
+        }
+    }
+
+    pub async fn generate_text(
+        &self,
+        input: Vec<MessageInput>,
+        schema: &Option<String>,
+    ) -> Result<String, OllamaError> {
+        let messages: Vec<ChatMessage> = input
+            .into_iter()
+            .map(|msg| match msg.role.as_str() {
+                "user" => ChatMessage::User {
+                    content: ChatMessageContent::Text(msg.content),
+                    name: None,
+                },
+                "assistant" => ChatMessage::Assistant {
+                    content: Some(ChatMessageContent::Text(msg.content)),
+                    tool_calls: None,
+                    name: None,
+                    refusal: None,
+                },
+                "system" => ChatMessage::System {
+                    content: ChatMessageContent::Text(msg.content),
+                    name: None,
+                },
+                _ => ChatMessage::User {
+                    content: ChatMessageContent::Text(msg.content),
+                    name: None,
+                },
+            })
+            .collect();
+
+        let parameters = if let Some(schema_str) = schema {
+            let mut schema_json: Value = serde_json::from_str(schema_str)
+                .map_err(|e| OllamaError::from(format!("Invalid schema JSON: {:?}", e)))?;
+
+            if let Value::Object(ref mut map) = schema_json {
+                map.insert("additionalProperties".to_string(), Value::Bool(false));
+            }
+
+            ChatCompletionParametersBuilder::default()
+                .model(self.model.clone())
+                .messages(messages)
+                .response_format(ChatCompletionResponseFormat::JsonSchema(
+                    JsonSchemaBuilder::default()
+                        .name("structured_output")
+                        .schema(schema_json)
+                        .strict(true)
+                        .build()
+                        .map_err(|e| {
+                            OllamaError::from(format!("Could not build JSON schema: {:?}", e))
+                        })?,
+                ))
+                .build()
+        } else {
+            ChatCompletionParametersBuilder::default()
+                .model(self.model.clone())
+                .messages(messages)
+                .response_format(ChatCompletionResponseFormat::Text)
+                .build()
+        }
+            .map_err(|e| OllamaError::from(format!("Could not build message parameters: {:?}", e)))?;
+
+        let result = self.client.chat().create(parameters).await.map_err(|e| {
+            OllamaError::from(format!("Failed to parse VLLM API response: {:?}", e))
+        })?;
+
+        let message = match &result.choices[0].message {
+            ChatMessage::Assistant { content, .. } => {
+                if let Some(ChatMessageContent::Text(text)) = content {
+                    text.clone()
+                } else {
+                    return Err(OllamaError::from(
+                        "Unexpected message content format".to_string(),
+                    ));
+                }
+            }
+            _ => return Err(OllamaError::from("Unexpected message type".to_string())),
+        };
+
+        Ok(message)
+    }
+
+    pub async fn function_call(
+        &self,
+        prompt: &str,
+        tools: Vec<Arc<dyn Tool>>,
+        raw_mode: bool,
+        oai_parser: Arc<OpenAIFunctionCall>,
+    ) -> Result<String, OllamaError> {
+        let openai_tools: Vec<_> = tools
+            .iter()
+            .map(|tool| ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: ChatCompletionFunction {
+                    name: tool.name().to_lowercase().replace(' ', "_"),
+                    description: Some(tool.description()),
+                    parameters: tool.parameters(),
+                },
+            })
+            .collect();
+
+        let messages = vec![ChatMessage::User {
+            content: ChatMessageContent::Text(prompt.to_string()),
+            name: None,
+        }];
+
+        let parameters = ChatCompletionParametersBuilder::default()
+            .model(self.model.clone())
+            .messages(messages)
+            .tools(openai_tools)
+            .build()
+            .map_err(|e| {
+                OllamaError::from(format!("Could not build message parameters: {:?}", e))
+            })?;
+
+        let result = self.client.chat().create(parameters).await.map_err(|e| {
+            OllamaError::from(format!("Failed to parse VLLM API response: {:?}", e))
+        })?;
+        let message = result.choices[0].message.clone();
+
+        if raw_mode {
+            self.handle_raw_mode(message)
+        } else {
+            self.handle_normal_mode(message, tools, oai_parser).await
+        }
+    }
+
+    fn handle_raw_mode(&self, message: ChatMessage) -> Result<String, OllamaError> {
+        let mut raw_calls = Vec::new();
+
+        if let ChatMessage::Assistant {
+            tool_calls: Some(tool_calls),
+            ..
+        } = message
+        {
+            for tool_call in tool_calls {
+                let call_json = json!({
+                    "name": tool_call.function.name,
+                    "arguments": serde_json::from_str::<Value>(&tool_call.function.arguments)?
+                });
+                raw_calls.push(serde_json::to_string(&call_json)?);
+            }
+        }
+
+        Ok(raw_calls.join("\n\n"))
+    }
+
+    async fn handle_normal_mode(
+        &self,
+        message: ChatMessage,
+        tools: Vec<Arc<dyn Tool>>,
+        oai_parser: Arc<OpenAIFunctionCall>,
+    ) -> Result<String, OllamaError> {
+        let mut results = Vec::<String>::new();
+
+        if let ChatMessage::Assistant {
+            tool_calls: Some(tool_calls),
+            ..
+        } = message
+        {
+            for tool_call in tool_calls {
+                for tool in &tools {
+                    if tool.name().to_lowercase().replace(' ', "_") == tool_call.function.name {
+                        let tool_params: Value =
+                            serde_json::from_str(&tool_call.function.arguments)?;
+                        let res = oai_parser
+                            .function_call_with_history(
+                                tool_call.function.name.clone(),
+                                tool_params,
+                                tool.clone(),
+                            )
+                            .await;
+                        match res {
+                            Ok(result) => results.push(result.message.unwrap().content),
+                            Err(e) => {
+                                return Err(OllamaError::from(format!(
+                                    "Could not generate text: {:?}",
+                                    e
+                                )))
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(results.join("\n"))
+    }
+}

From a469721cd440c925c8cb552446a15149c1890416 Mon Sep 17 00:00:00 2001
From: anilaltuner
Date: Mon, 13 Jan 2025 10:55:12 +0300
Subject: [PATCH 3/3] format fixes

---
 src/api_interface/vllm.rs |  9 +++++----
 src/program/executor.rs   | 27 +++++++++++++++++++++------
 src/program/models.rs     |  4 ++--
 3 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/src/api_interface/vllm.rs b/src/api_interface/vllm.rs
index 58d3fab..94c9b93 100644
--- a/src/api_interface/vllm.rs
+++ b/src/api_interface/vllm.rs
@@ -1,5 +1,8 @@
 use crate::program::atomics::MessageInput;
-use ollama_rs::{error::OllamaError, generation::functions::tools::Tool, generation::functions::OpenAIFunctionCall, IntoUrlSealed};
+use ollama_rs::{
+    error::OllamaError, generation::functions::tools::Tool,
+    generation::functions::OpenAIFunctionCall,
+};
 use openai_dive::v1::api::Client;
 use openai_dive::v1::resources::chat::*;
 use serde_json::{json, Value};
@@ -7,7 +10,6 @@ use std::sync::Arc;

 pub struct VLLMExecutor {
     model: String,
-    base_url: String,
     client: Client,
 }

@@ -15,7 +17,6 @@ impl VLLMExecutor {
     pub fn new(model: String, base_url: String) -> Self {
         Self {
             model,
-            base_url: base_url.clone(),
             client: Client::new_with_base(base_url.as_str(), "".to_string()),
         }
     }
@@ -78,7 +79,7 @@ impl VLLMExecutor {
                 .response_format(ChatCompletionResponseFormat::Text)
                 .build()
         }
-            .map_err(|e| OllamaError::from(format!("Could not build message parameters: {:?}", e)))?;
+        .map_err(|e| OllamaError::from(format!("Could not build message parameters: {:?}", e)))?;

         let result = self.client.chat().create(parameters).await.map_err(|e| {
             OllamaError::from(format!("Failed to parse VLLM API response: {:?}", e))
diff --git a/src/program/executor.rs b/src/program/executor.rs
index 5d4b023..5cb709d 100644
--- a/src/program/executor.rs
+++ b/src/program/executor.rs
@@ -25,9 +25,20 @@ use base64::prelude::*;
 use log::{debug, error, info, warn};
 use rand::seq::SliceRandom;

-use ollama_rs::{error::OllamaError, generation::chat::request::ChatMessageRequest, generation::chat::ChatMessage, generation::completion::request::GenerationRequest, generation::functions::tools::StockScraper, generation::functions::tools::Tool, generation::functions::{
-    DDGSearcher, FunctionCallRequest, LlamaFunctionCall, OpenAIFunctionCall, Scraper,
-}, generation::options::GenerationOptions, generation::parameters::FormatType, Ollama};
+use ollama_rs::{
+    error::OllamaError,
+    generation::chat::request::ChatMessageRequest,
+    generation::chat::ChatMessage,
+    generation::completion::request::GenerationRequest,
+    generation::functions::tools::StockScraper,
+    generation::functions::tools::Tool,
+    generation::functions::{
+        DDGSearcher, FunctionCallRequest, LlamaFunctionCall, OpenAIFunctionCall, Scraper,
+    },
+    generation::options::GenerationOptions,
+    generation::parameters::FormatType,
+    Ollama,
+};

 fn log_colored(msg: &str) {
     let colors = ["red", "green", "yellow", "blue", "magenta", "cyan"];
@@ -576,7 +587,8 @@ impl Executor {
                 openai_executor.generate_text(input, schema).await?
             }
             ModelProvider::VLLM => {
-                let executor = VLLMExecutor::new(self.model.to_string(), "http://localhost:8000".to_string());
+                let executor =
+                    VLLMExecutor::new(self.model.to_string(), "http://localhost:8000".to_string());
                 executor.generate_text(input, schema).await?
             }
         };
@@ -664,8 +676,11 @@ impl Executor {
                     .await?
             }
             ModelProvider::VLLM => {
-                let executor = VLLMExecutor::new(self.model.to_string(), "http://localhost:8000".to_string());
-                executor.function_call(prompt, tools, raw_mode, oai_parser).await?
+                let executor =
+                    VLLMExecutor::new(self.model.to_string(), "http://localhost:8000".to_string());
+                executor
+                    .function_call(prompt, tools, raw_mode, oai_parser)
+                    .await?
             }
         };

diff --git a/src/program/models.rs b/src/program/models.rs
index 5d67382..a088cb0 100644
--- a/src/program/models.rs
+++ b/src/program/models.rs
@@ -199,7 +199,7 @@ pub enum Model {
     OROpenAIO1,

     #[serde(rename = "Qwen/Qwen2.5-1.5B-Instruct")]
-    Qwen25Vllm
+    Qwen25Vllm,
 }

 impl Model {
@@ -337,7 +337,7 @@ impl From<Model> for ModelProvider {
             Model::ORNousHermes405B => ModelProvider::OpenRouter,
             Model::OROpenAIO1 => ModelProvider::OpenRouter,
             //vllm
-            Model::Qwen25Vllm => ModelProvider::VLLM
+            Model::Qwen25Vllm => ModelProvider::VLLM,
         }
     }
 }
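
A minimal usage sketch for the new provider (not part of the patches above). It assumes MessageInput exposes public role and content String fields, that the caller runs inside an async runtime such as tokio, and that an OpenAI-compatible vLLM server is already serving Qwen/Qwen2.5-1.5B-Instruct on http://localhost:8000, the base URL hardcoded in the executor match arms; adjust names and fields to the real crate if they differ.

    use crate::api_interface::vllm::VLLMExecutor;
    use crate::program::atomics::MessageInput;

    // Sketch only: MessageInput's fields and visibility are assumed, not taken from the diff.
    async fn demo_vllm() -> Result<(), Box<dyn std::error::Error>> {
        // Same model name and base URL that the ModelProvider::VLLM match arms pass in.
        let executor = VLLMExecutor::new(
            "Qwen/Qwen2.5-1.5B-Instruct".to_string(),
            "http://localhost:8000".to_string(),
        );

        let input = vec![MessageInput {
            role: "user".to_string(),
            content: "Summarize what vLLM is in one sentence.".to_string(),
        }];

        // None -> plain-text response; Some(schema_json) switches to the JsonSchema response format.
        let reply = executor.generate_text(input, &None).await?;
        println!("{reply}");
        Ok(())
    }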