Skip to content

Commit

Permalink
added new custom tool method.
Browse files Browse the repository at this point in the history
Now custom tools are enum with mode value: http_request, custom.

Custom mode runs on raw_mode() will only return function calls and won't execute the tool.
  • Loading branch information
andthattoo committed Sep 29, 2024
1 parent a3693ad commit 09f8579
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 56 deletions.
23 changes: 19 additions & 4 deletions src/program/atomics.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::program::io::{Input, InputValue, Output};
use crate::ProgramMemory;
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;

pub static R_INPUT: &str = "__input";
Expand All @@ -25,14 +26,28 @@ pub fn in_tools(tools: &Vec<String>) -> bool {
true
}

#[derive(Debug, Deserialize, Clone)]
#[serde(tag = "mode", rename_all = "snake_case")]
pub enum CustomToolModeTemplate {
Custom {
parameters: Value,
},
HttpRequest {
url: String,
method: String,
#[serde(default)]
headers: Option<HashMap<String, String>>,
#[serde(default)]
body: Option<HashMap<String, String>>,
},
}

#[derive(Debug, Deserialize, Clone)]
pub struct CustomToolTemplate {
pub name: String,
pub description: String,
pub url: String,
pub method: String,
pub headers: Option<HashMap<String, String>>,
pub body: Option<HashMap<String, String>>,
#[serde(flatten)]
pub mode: CustomToolModeTemplate,
}

/// Configuration for the workflow
Expand Down
136 changes: 85 additions & 51 deletions src/tools/external.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,52 @@
use crate::program::atomics::CustomToolTemplate;
use crate::program::atomics::{CustomToolModeTemplate, CustomToolTemplate};
use async_trait::async_trait;
use log::info;
use ollama_rs::generation::functions::tools::Tool;
use reqwest::Client;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::env;
use std::error::Error;

pub enum CustomToolMode {
Custom {
parameters: Value,
},
HttpRequest {
url: String,
method: String,
headers: HashMap<String, String>,
body: HashMap<String, String>,
},
}

pub struct CustomTool {
pub name: String,
pub description: String,
pub url: String,
pub method: String,
pub headers: HashMap<String, String>,
pub body: HashMap<String, String>,
pub mode: CustomToolMode,
}

impl CustomTool {
pub fn new_from_template(template: CustomToolTemplate) -> Self {
CustomTool {
name: template.name,
description: template.description,
url: template.url,
method: template.method,
headers: template.headers.unwrap_or_default(),
body: template.body.unwrap_or_default(),
mode: match template.mode {
CustomToolModeTemplate::Custom { parameters } => {
CustomToolMode::Custom { parameters }
}
CustomToolModeTemplate::HttpRequest {
url,
method,
headers,
body,
} => CustomToolMode::HttpRequest {
url,
method,
headers: headers.unwrap_or_default(),
body: body.unwrap_or_default(),
},
},
}
}
}
Expand All @@ -40,55 +62,67 @@ impl Tool for CustomTool {
}

fn parameters(&self) -> Value {
let properties: HashMap<_, _> = self
.body
.keys()
.map(|k| {
(
k.clone(),
json!({ "type": "string", "description": format!("The value for {}", k) }),
)
})
.collect();
match &self.mode {
CustomToolMode::Custom { parameters } => {
info!("helloooo");
parameters.clone()
}
CustomToolMode::HttpRequest { body, .. } => {
let properties: HashMap<_, _> = body.iter().map(|(k, _)| {
(k.clone(), json!({ "type": "string", "description": format!("The value for {}", k) }))
}).collect();

json!({
"type": "object",
"properties": properties,
"required": self.body.keys().collect::<Vec<_>>()
})
json!({
"type": "object",
"properties": properties,
"required": body.keys().collect::<Vec<_>>()
})
}
}
}

async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
let client = Client::new();
let mut request = match self.method.as_str() {
"GET" => client.get(&self.url),
"POST" => client.post(&self.url),
"PUT" => client.put(&self.url),
"DELETE" => client.delete(&self.url),
_ => return Err(Box::from("Unsupported HTTP method")),
};
match &self.mode {
CustomToolMode::Custom { .. } => {
Err("Custom mode can't execute tools, use function_calling_raw".into())
}
CustomToolMode::HttpRequest {
url,
method,
headers,
body: _,
} => {
let client = Client::new();
let mut request = match method.as_str() {
"GET" => client.get(url),
"POST" => client.post(url),
"PUT" => client.put(url),
"DELETE" => client.delete(url),
_ => return Err("Unsupported HTTP method".into()),
};

for (key, value) in &self.headers {
request = request.header(key, value);
}
for (key, value) in headers {
request = request.header(key, value);
}

let token = env::var("API_KEY");
if token.is_ok() {
request = request.header("Authorization", format!("Bearer {}", token.unwrap()));
}
if let Ok(token) = env::var("API_KEY") {
request = request.header("Authorization", format!("Bearer {}", token));
}

if self.method != "GET" {
let body: HashMap<String, String> = input
.as_object()
.unwrap()
.iter()
.map(|(k, v)| (k.clone(), v.as_str().unwrap().to_string()))
.collect();
request = request.json(&body);
}
if method != "GET" {
let body: HashMap<String, String> = input
.as_object()
.unwrap()
.iter()
.map(|(k, v)| (k.clone(), v.as_str().unwrap().to_string()))
.collect();
request = request.json(&body);
}

let response = request.send().await?;
let result = response.text().await?;
Ok(result)
let response = request.send().await?;
let result = response.text().await?;
Ok(result)
}
}
}
}
12 changes: 11 additions & 1 deletion tests/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ const SIMPLE_WORKFLOW_PATH: &str = "./tests/test_workflows/simple.json";
const INSERT_WORKFLOW_PATH: &str = "./tests/test_workflows/insert.json";
const USERS_WORKFLOW_PATH: &str = "./tests/test_workflows/users.json";
const CONTEXT_SIZE_WORKFLOW_PATH: &str = "./tests/test_workflows/context_size.json";
const CUSTOM_TOOL_WORKFLOW_PATH: &str = "./tests/test_workflows/custom_tools.json";
const CUSTOM_TOOL_HTTP_WORKFLOW_PATH: &str = "./tests/test_workflows/custom_tools_http.json";
const CUSTOM_TOOL_WORKFLOW_PATH: &str = "./tests/test_workflows/custom_tool.json";
const CODER_PATH: &str = "./tests/test_workflows/coding.json";

async fn setup_test(model: Model) -> Executor {
dotenv().ok();
Expand Down Expand Up @@ -69,6 +71,8 @@ mod simple_workflow_tests {
"How does reiki work?"
);

workflow_test!(simple_coder, Model::Qwen2_5Coder1_5B, CODER_PATH);

workflow_test!(simple_o1, Model::O1Mini, SIMPLE_WORKFLOW_PATH);
}

Expand Down Expand Up @@ -174,6 +178,12 @@ mod context_size_tests {
mod custom_tool_tests {
use super::*;

workflow_test!(
http_custom_tool_workflow,
Model::Llama3_1_8B,
CUSTOM_TOOL_HTTP_WORKFLOW_PATH
);

workflow_test!(
custom_tool_workflow,
Model::Llama3_1_8B,
Expand Down
49 changes: 49 additions & 0 deletions tests/test_workflows/coding.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"config": {
"max_steps": 50,
"max_time": 200,
"tools": [
"ALL"
]
},
"external_memory": {},
"tasks": [
{
"id": "fibonacci",
"name": "Task",
"description": "Task Description",
"prompt": "Write fibonnaci sequence method with memoization in rust",
"inputs": [],
"operator": "generation",
"outputs": [
{
"type": "push",
"key": "code",
"value": "__result"
}
]
},
{
"id": "_end",
"name": "Task",
"description": "Task Description",
"prompt": "",
"inputs": [],
"operator": "end",
"outputs": []
}
],
"steps": [
{
"source": "fibonacci",
"target": "_end"
}
],
"return_value": {
"input": {
"type": "get_all",
"key": "code"
},
"to_json": true
}
}
63 changes: 63 additions & 0 deletions tests/test_workflows/custom_tool.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"name": "Custom Tool",
"description": "This is a simple workflow for custom tools",
"config":{
"max_steps": 5,
"max_time": 100,
"max_tokens": 1024,
"tools": [],
"custom_tools":[{
"name": "google_search_tool",
"description": "Conducts a web search using a specified search type and returns the results.",
"mode":"custom",
"parameters":{
"type": "object",
"properties": {
"website": {
"type": "string",
"description": "The URL of the website to scrape"
}
},
"required": ["website"]
}
}]
},
"tasks":[
{
"id": "A",
"name": "Get prices",
"description": "Get price feed",
"prompt": "What are the current prices for $APPL?",
"inputs":[],
"operator": "function_calling_raw",
"outputs":[
{
"type": "write",
"key": "prices",
"value": "__result"
}
]
},
{
"id": "__end",
"name": "end",
"description": "End of the task",
"prompt": "End of the task",
"inputs": [],
"operator": "end",
"outputs": []
}
],
"steps":[
{
"source":"A",
"target":"_end"
}
],
"return_value":{
"input":{
"type": "read",
"key": "prices"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"custom_tools":[{
"name": "PriceFeedRequest",
"description": "Fetches price feed from Gemini API",
"mode":"http_request",
"url": "https://api.gemini.com/v1/pricefeed",
"method": "GET"
}]
Expand Down

0 comments on commit 09f8579

Please sign in to comment.