-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
528482d
commit baf5321
Showing
11 changed files
with
691 additions
and
199 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
use ollama_rs::{ | ||
error::OllamaError, generation::functions::tools::Tool, | ||
generation::functions::OpenAIFunctionCall, | ||
}; | ||
use reqwest::Client; | ||
use serde_json::{json, Value}; | ||
use std::sync::Arc; | ||
|
||
pub struct GeminiExecutor { | ||
model: String, | ||
api_key: String, | ||
client: Client, | ||
} | ||
|
||
impl GeminiExecutor { | ||
pub fn new(model: String, api_key: String) -> Self { | ||
Self { | ||
model, | ||
api_key, | ||
client: Client::new(), | ||
} | ||
} | ||
|
||
pub async fn generate_text(&self, prompt: &str) -> Result<String, OllamaError> { | ||
let url = format!( | ||
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}", | ||
self.model, self.api_key | ||
); | ||
|
||
let body = json!({ | ||
"contents": [{ | ||
"parts": [ | ||
{"text": prompt} | ||
] | ||
}], | ||
"safetySettings": [ | ||
{ | ||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT", | ||
"threshold": "BLOCK_ONLY_HIGH" | ||
} | ||
], | ||
"generationConfig": { | ||
"temperature": 1.0, | ||
"maxOutputTokens": 800, | ||
"topP": 0.8, | ||
"topK": 10 | ||
} | ||
}); | ||
|
||
let response = self | ||
.client | ||
.post(&url) | ||
.header("Content-Type", "application/json") | ||
.json(&body) | ||
.send() | ||
.await | ||
.map_err(|e| OllamaError::from(format!("Gemini API request failed: {:?}", e)))?; | ||
|
||
let response_body: Value = response.json().await.map_err(|e| { | ||
OllamaError::from(format!("Failed to parse Gemini API response: {:?}", e)) | ||
})?; | ||
|
||
self.extract_generated_text(response_body) | ||
} | ||
|
||
fn extract_generated_text(&self, response: Value) -> Result<String, OllamaError> { | ||
response["candidates"][0]["content"]["parts"][0]["text"] | ||
.as_str() | ||
.map(|s| s.to_string()) | ||
.ok_or_else(|| { | ||
OllamaError::from("Failed to extract generated text from response".to_string()) | ||
}) | ||
} | ||
|
||
fn extract_tools(&self, response: Value) -> Result<Value, OllamaError> { | ||
let candidate = response["candidates"] | ||
.get(0) | ||
.ok_or_else(|| OllamaError::from("No candidates found in response".to_string()))?; | ||
|
||
let content = &candidate["content"]["parts"][0]; | ||
|
||
if let Some(function_call) = content.get("functionCall") { | ||
Ok(function_call.clone()) | ||
} else if let Some(text) = content.get("text") { | ||
Ok(json!({"text": text})) | ||
} else { | ||
Err(OllamaError::from("Unexpected response format".to_string())) | ||
} | ||
} | ||
|
||
pub async fn function_call( | ||
&self, | ||
prompt: &str, | ||
tools: Vec<Arc<dyn Tool>>, | ||
raw_mode: bool, | ||
oai_parser: Arc<OpenAIFunctionCall>, | ||
) -> Result<String, OllamaError> { | ||
let url = format!( | ||
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}", | ||
self.model, self.api_key | ||
); | ||
|
||
let function_declarations: Vec<Value> = tools | ||
.iter() | ||
.map(|tool| { | ||
json!({ | ||
"name": tool.name(), | ||
"description": tool.description(), | ||
"parameters": tool.parameters() | ||
}) | ||
}) | ||
.collect(); | ||
|
||
let body = json!({ | ||
"system_instruction": { | ||
"parts": { | ||
"text": "You are a helpful function calling assistant." | ||
} | ||
}, | ||
"tools": {"function_declarations" : function_declarations}, | ||
"tool_config": { | ||
"function_calling_config": {"mode": "ANY"} | ||
}, | ||
"contents": { | ||
"role": "user", | ||
"parts": { | ||
"text": prompt | ||
} | ||
} | ||
}); | ||
|
||
let response = self | ||
.client | ||
.post(&url) | ||
.header("Content-Type", "application/json") | ||
.json(&body) | ||
.send() | ||
.await | ||
.map_err(|e| OllamaError::from(format!("Gemini API request failed: {:?}", e)))?; | ||
|
||
let response_body: Value = response.json().await.map_err(|e| { | ||
OllamaError::from(format!("Failed to parse Gemini API response: {:?}", e)) | ||
})?; | ||
|
||
let tool_call = self.extract_tools(response_body)?; | ||
|
||
for tool in &tools { | ||
if tool.name().to_lowercase().replace(' ', "_") | ||
== tool_call["name"].as_str().unwrap_or("") | ||
{ | ||
if raw_mode { | ||
let raw_result = serde_json::to_string(&tool_call); | ||
return match raw_result { | ||
Ok(raw_call) => Ok(raw_call), | ||
Err(e) => Err(OllamaError::from(format!( | ||
"Raw Call string conversion failed {:?}", | ||
e | ||
))), | ||
}; | ||
} | ||
let res = oai_parser | ||
.function_call_with_history( | ||
tool_call["name"].as_str().unwrap_or("").to_string(), | ||
tool_call["args"].clone(), | ||
tool.clone(), | ||
) | ||
.await; | ||
return match res { | ||
Ok(result) => Ok(result.message.unwrap().content), | ||
Err(e) => Err(OllamaError::from(format!( | ||
"Could not generate text: {:?}", | ||
e | ||
))), | ||
}; | ||
} | ||
} | ||
|
||
Err(OllamaError::from(format!( | ||
"No matching tool found for function: {}", | ||
tool_call["name"].as_str().unwrap_or("") | ||
))) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
pub mod gem_api; | ||
pub mod openai_api; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
use ollama_rs::{ | ||
error::OllamaError, generation::functions::tools::Tool, | ||
generation::functions::OpenAIFunctionCall, | ||
}; | ||
use openai_dive::v1::api::Client; | ||
use openai_dive::v1::resources::chat::*; | ||
use serde_json::{json, Value}; | ||
use std::sync::Arc; | ||
|
||
pub struct OpenAIExecutor { | ||
model: String, | ||
client: Client, | ||
} | ||
|
||
impl OpenAIExecutor { | ||
pub fn new(model: String, api_key: String) -> Self { | ||
Self { | ||
model, | ||
client: Client::new(api_key), | ||
} | ||
} | ||
|
||
pub async fn function_call( | ||
&self, | ||
prompt: &str, | ||
tools: Vec<Arc<dyn Tool>>, | ||
raw_mode: bool, | ||
oai_parser: Arc<OpenAIFunctionCall>, | ||
) -> Result<String, OllamaError> { | ||
let openai_tools: Vec<_> = tools | ||
.iter() | ||
.map(|tool| ChatCompletionTool { | ||
r#type: ChatCompletionToolType::Function, | ||
function: ChatCompletionFunction { | ||
name: tool.name().to_lowercase().replace(' ', "_"), | ||
description: Some(tool.description()), | ||
parameters: tool.parameters(), | ||
}, | ||
}) | ||
.collect(); | ||
|
||
let messages = vec![ChatMessageBuilder::default() | ||
.content(ChatMessageContent::Text(prompt.to_string())) | ||
.build() | ||
.expect("OpenAI function call message build error")]; | ||
|
||
let parameters = ChatCompletionParametersBuilder::default() | ||
.model(self.model.clone()) | ||
.messages(messages) | ||
.tools(openai_tools) | ||
.build() | ||
.expect("Error while building tools."); | ||
|
||
let result = self | ||
.client | ||
.chat() | ||
.create(parameters) | ||
.await | ||
.expect("OpenAI Function call failed"); | ||
let message = result.choices[0].message.clone(); | ||
|
||
if raw_mode { | ||
self.handle_raw_mode(message) | ||
} else { | ||
self.handle_normal_mode(message, tools, oai_parser).await | ||
} | ||
} | ||
|
||
fn handle_raw_mode(&self, message: ChatMessage) -> Result<String, OllamaError> { | ||
let mut raw_calls = Vec::new(); | ||
if let Some(tool_calls) = message.tool_calls { | ||
for tool_call in tool_calls { | ||
let call_json = json!({ | ||
"name": tool_call.function.name, | ||
"arguments": serde_json::from_str::<serde_json::Value>(&tool_call.function.arguments)? | ||
}); | ||
raw_calls.push(serde_json::to_string(&call_json)?); | ||
} | ||
} | ||
Ok(raw_calls.join("\n\n")) | ||
} | ||
|
||
async fn handle_normal_mode( | ||
&self, | ||
message: ChatMessage, | ||
tools: Vec<Arc<dyn Tool>>, | ||
oai_parser: Arc<OpenAIFunctionCall>, | ||
) -> Result<String, OllamaError> { | ||
let mut results = Vec::<String>::new(); | ||
if let Some(tool_calls) = message.tool_calls { | ||
for tool_call in tool_calls { | ||
for tool in &tools { | ||
if tool.name().to_lowercase().replace(' ', "_") == tool_call.function.name { | ||
let tool_params: Value = | ||
serde_json::from_str(&tool_call.function.arguments)?; | ||
let res = oai_parser | ||
.function_call_with_history( | ||
tool_call.function.name.clone(), | ||
tool_params, | ||
tool.clone(), | ||
) | ||
.await; | ||
match res { | ||
Ok(result) => results.push(result.message.unwrap().content), | ||
Err(e) => { | ||
return Err(OllamaError::from(format!( | ||
"Could not generate text: {:?}", | ||
e | ||
))) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
Ok(results.join("\n")) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.