Commit

gemini models added

andthattoo committed Oct 18, 2024
1 parent 528482d commit baf5321
Showing 11 changed files with 691 additions and 199 deletions.
354 changes: 297 additions & 57 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -22,6 +22,7 @@ ollama-rs = { git = "https://github.com/andthattoo/ollama-rs", rev = "00c67cf",
parking_lot = "0.12.2"
langchain-rust = "4.2.0"
openai_dive = "0.5.3"
gem-rs = "0.1.1"
scraper = "0.19.0"
text-splitter = "0.13.1"
search_with_google = "0.5.0"
183 changes: 183 additions & 0 deletions src/api_interface/gem_api.rs
@@ -0,0 +1,183 @@
use ollama_rs::{
    error::OllamaError, generation::functions::tools::Tool,
    generation::functions::OpenAIFunctionCall,
};
use reqwest::Client;
use serde_json::{json, Value};
use std::sync::Arc;

/// Executes text generation and function calling against the Google Gemini
/// `generateContent` REST API.
pub struct GeminiExecutor {
    model: String,
    api_key: String,
    client: Client,
}

impl GeminiExecutor {
    pub fn new(model: String, api_key: String) -> Self {
        Self {
            model,
            api_key,
            client: Client::new(),
        }
    }

    pub async fn generate_text(&self, prompt: &str) -> Result<String, OllamaError> {
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
            self.model, self.api_key
        );

        let body = json!({
            "contents": [{
                "parts": [
                    {"text": prompt}
                ]
            }],
            "safetySettings": [
                {
                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                    "threshold": "BLOCK_ONLY_HIGH"
                }
            ],
            "generationConfig": {
                "temperature": 1.0,
                "maxOutputTokens": 800,
                "topP": 0.8,
                "topK": 10
            }
        });

        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| OllamaError::from(format!("Gemini API request failed: {:?}", e)))?;

        let response_body: Value = response.json().await.map_err(|e| {
            OllamaError::from(format!("Failed to parse Gemini API response: {:?}", e))
        })?;

        self.extract_generated_text(response_body)
    }

    /// Pulls the first candidate's text out of a `generateContent` response.
    fn extract_generated_text(&self, response: Value) -> Result<String, OllamaError> {
        response["candidates"][0]["content"]["parts"][0]["text"]
            .as_str()
            .map(|s| s.to_string())
            .ok_or_else(|| {
                OllamaError::from("Failed to extract generated text from response".to_string())
            })
    }

    /// Returns the first candidate's `functionCall` object, falling back to
    /// its plain `text` part.
    fn extract_tools(&self, response: Value) -> Result<Value, OllamaError> {
        let candidate = response["candidates"]
            .get(0)
            .ok_or_else(|| OllamaError::from("No candidates found in response".to_string()))?;

        let content = &candidate["content"]["parts"][0];

        if let Some(function_call) = content.get("functionCall") {
            Ok(function_call.clone())
        } else if let Some(text) = content.get("text") {
            Ok(json!({"text": text}))
        } else {
            Err(OllamaError::from("Unexpected response format".to_string()))
        }
    }

    pub async fn function_call(
        &self,
        prompt: &str,
        tools: Vec<Arc<dyn Tool>>,
        raw_mode: bool,
        oai_parser: Arc<OpenAIFunctionCall>,
    ) -> Result<String, OllamaError> {
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
            self.model, self.api_key
        );

        let function_declarations: Vec<Value> = tools
            .iter()
            .map(|tool| {
                json!({
                    // Declare under the normalized name so it matches the
                    // lookup against the returned call below.
                    "name": tool.name().to_lowercase().replace(' ', "_"),
                    "description": tool.description(),
                    "parameters": tool.parameters()
                })
            })
            .collect();

        let body = json!({
            "system_instruction": {
                "parts": {
                    "text": "You are a helpful function calling assistant."
                }
            },
            "tools": {"function_declarations": function_declarations},
            "tool_config": {
                "function_calling_config": {"mode": "ANY"}
            },
            "contents": {
                "role": "user",
                "parts": {
                    "text": prompt
                }
            }
        });

        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| OllamaError::from(format!("Gemini API request failed: {:?}", e)))?;

        let response_body: Value = response.json().await.map_err(|e| {
            OllamaError::from(format!("Failed to parse Gemini API response: {:?}", e))
        })?;

        let tool_call = self.extract_tools(response_body)?;

        // Find the registered tool whose normalized name matches the
        // function name returned by the model.
        for tool in &tools {
            if tool.name().to_lowercase().replace(' ', "_")
                == tool_call["name"].as_str().unwrap_or("")
            {
                if raw_mode {
                    // Raw mode: return the tool call as JSON instead of
                    // executing it.
                    let raw_result = serde_json::to_string(&tool_call);
                    return match raw_result {
                        Ok(raw_call) => Ok(raw_call),
                        Err(e) => Err(OllamaError::from(format!(
                            "Raw Call string conversion failed {:?}",
                            e
                        ))),
                    };
                }
                let res = oai_parser
                    .function_call_with_history(
                        tool_call["name"].as_str().unwrap_or("").to_string(),
                        tool_call["args"].clone(),
                        tool.clone(),
                    )
                    .await;
                return match res {
                    Ok(result) => Ok(result.message.unwrap().content),
                    Err(e) => Err(OllamaError::from(format!(
                        "Could not generate text: {:?}",
                        e
                    ))),
                };
            }
        }

        Err(OllamaError::from(format!(
            "No matching tool found for function: {}",
            tool_call["name"].as_str().unwrap_or("")
        )))
    }
}
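
For reference, a minimal caller sketch (not part of the commit): it assumes a tokio runtime, a GEMINI_API_KEY environment variable, and that GeminiExecutor is reachable from the caller's module (the commit keeps api_interface private); the model name and prompt are placeholders.

// Hypothetical caller sketch: plain text generation through the new
// executor. GEMINI_API_KEY, the model name, and the prompt are placeholders.
#[tokio::main]
async fn main() {
    let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
    let executor = GeminiExecutor::new("gemini-1.5-flash".to_string(), api_key);

    match executor.generate_text("Summarize the Rust ownership model.").await {
        Ok(text) => println!("{}", text),
        Err(e) => eprintln!("generation failed: {:?}", e),
    }
}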
2 changes: 2 additions & 0 deletions src/api_interface/mod.rs
@@ -0,0 +1,2 @@
pub mod gem_api;
pub mod openai_api;
118 changes: 118 additions & 0 deletions src/api_interface/openai_api.rs
@@ -0,0 +1,118 @@
use ollama_rs::{
    error::OllamaError, generation::functions::tools::Tool,
    generation::functions::OpenAIFunctionCall,
};
use openai_dive::v1::api::Client;
use openai_dive::v1::resources::chat::*;
use serde_json::{json, Value};
use std::sync::Arc;

/// Executes function calling through the OpenAI chat completions API.
pub struct OpenAIExecutor {
    model: String,
    client: Client,
}

impl OpenAIExecutor {
    pub fn new(model: String, api_key: String) -> Self {
        Self {
            model,
            client: Client::new(api_key),
        }
    }

    pub async fn function_call(
        &self,
        prompt: &str,
        tools: Vec<Arc<dyn Tool>>,
        raw_mode: bool,
        oai_parser: Arc<OpenAIFunctionCall>,
    ) -> Result<String, OllamaError> {
        // Expose each tool as an OpenAI function, normalizing its name.
        let openai_tools: Vec<_> = tools
            .iter()
            .map(|tool| ChatCompletionTool {
                r#type: ChatCompletionToolType::Function,
                function: ChatCompletionFunction {
                    name: tool.name().to_lowercase().replace(' ', "_"),
                    description: Some(tool.description()),
                    parameters: tool.parameters(),
                },
            })
            .collect();

        let messages = vec![ChatMessageBuilder::default()
            .content(ChatMessageContent::Text(prompt.to_string()))
            .build()
            .expect("OpenAI function call message build error")];

        let parameters = ChatCompletionParametersBuilder::default()
            .model(self.model.clone())
            .messages(messages)
            .tools(openai_tools)
            .build()
            .expect("Error while building tools.");

        let result = self
            .client
            .chat()
            .create(parameters)
            .await
            .expect("OpenAI Function call failed");
        let message = result.choices[0].message.clone();

        if raw_mode {
            self.handle_raw_mode(message)
        } else {
            self.handle_normal_mode(message, tools, oai_parser).await
        }
    }

    /// Serializes each returned tool call to JSON without executing it.
    fn handle_raw_mode(&self, message: ChatMessage) -> Result<String, OllamaError> {
        let mut raw_calls = Vec::new();
        if let Some(tool_calls) = message.tool_calls {
            for tool_call in tool_calls {
                let call_json = json!({
                    "name": tool_call.function.name,
                    "arguments": serde_json::from_str::<serde_json::Value>(&tool_call.function.arguments)?
                });
                raw_calls.push(serde_json::to_string(&call_json)?);
            }
        }
        Ok(raw_calls.join("\n\n"))
    }

    /// Executes each returned tool call against the matching tool and
    /// collects the results.
    async fn handle_normal_mode(
        &self,
        message: ChatMessage,
        tools: Vec<Arc<dyn Tool>>,
        oai_parser: Arc<OpenAIFunctionCall>,
    ) -> Result<String, OllamaError> {
        let mut results = Vec::<String>::new();
        if let Some(tool_calls) = message.tool_calls {
            for tool_call in tool_calls {
                for tool in &tools {
                    if tool.name().to_lowercase().replace(' ', "_") == tool_call.function.name {
                        let tool_params: Value =
                            serde_json::from_str(&tool_call.function.arguments)?;
                        let res = oai_parser
                            .function_call_with_history(
                                tool_call.function.name.clone(),
                                tool_params,
                                tool.clone(),
                            )
                            .await;
                        match res {
                            Ok(result) => results.push(result.message.unwrap().content),
                            Err(e) => {
                                return Err(OllamaError::from(format!(
                                    "Could not generate text: {:?}",
                                    e
                                )))
                            }
                        }
                    }
                }
            }
        }
        Ok(results.join("\n"))
    }
}
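
A similar caller sketch for the OpenAI executor in raw mode (not part of the commit): the tool list and OpenAIFunctionCall parser are assumed to be constructed elsewhere in the crate, and the model name and prompt are placeholders.

use std::sync::Arc;

use ollama_rs::error::OllamaError;
use ollama_rs::generation::functions::{tools::Tool, OpenAIFunctionCall};

// Hypothetical caller sketch: with raw_mode = true, the executor returns the
// model's tool calls as JSON strings instead of executing the matching tools.
async fn run_raw_call(
    my_tools: Vec<Arc<dyn Tool>>,
    parser: Arc<OpenAIFunctionCall>,
) -> Result<String, OllamaError> {
    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
    let executor = OpenAIExecutor::new("gpt-4o-mini".to_string(), api_key);
    executor
        .function_call("What is the weather in Berlin?", my_tools, true, parser)
        .await
}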
1 change: 1 addition & 0 deletions src/lib.rs
@@ -23,6 +23,7 @@
//! ```
//! This crate provides a simple execution pipeline and enables users to create and execute workflows with JSON files.
//! Creating specific JSON for your purpose should suffice.
mod api_interface;
mod memory;
mod program;
mod tools;
2 changes: 2 additions & 0 deletions src/program/errors.rs
@@ -43,6 +43,7 @@ pub enum ExecutionError {
    SamplingError,
    InvalidGetAllError,
    UnexpectedOutput,
    Cancelled,
}

impl fmt::Display for CustomError {
@@ -110,6 +111,7 @@ impl fmt::Display for ExecutionError {
                f,
                "Error sampling because value is not get_all compatible (array)"
            ),
            ExecutionError::Cancelled => write!(f, "Execution cancelled"),
        }
    }
}
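
A small consumer sketch (not part of the commit), assuming a hypothetical run_workflow function that returns Result<String, ExecutionError>:

// Hypothetical: any async crate call that can now fail with Cancelled.
async fn report_outcome() {
    match run_workflow().await {
        Ok(output) => println!("{}", output),
        // The new Display arm prints "Execution cancelled".
        Err(ExecutionError::Cancelled) => eprintln!("workflow was cancelled"),
        Err(e) => eprintln!("workflow failed: {}", e),
    }
}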