Skip to content

Commit

Permalink
- executer now returns Result<String, ExecutionError> instead of stri…
Browse files Browse the repository at this point in the history
…ngs.

- Task execution logs are more detailed.
- Added new testing macro. Llama3.1_8B_Q8 added to models.
- Check operator now removed
  • Loading branch information
andthattoo committed Sep 2, 2024
1 parent 0dc65a5 commit ba038f7
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 214 deletions.
5 changes: 4 additions & 1 deletion examples/execute_workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@ async fn main() {
let mut memory = ProgramMemory::new();
let input = Entry::try_value_or_str("How would does reiki work?");
let return_value = exe.execute(Some(&input), workflow, &mut memory).await;
println!("{}", return_value);
match return_value {
Ok(value) => println!("{}", value),
Err(err) => eprintln!("Error: {:?}", err),
}
}
9 changes: 5 additions & 4 deletions src/program/atomics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ use crate::ProgramMemory;
pub static R_INPUT: &str = "__input";
pub static R_OUTPUT: &str = "__result";
pub static R_END: &str = "__end";
pub static R_EXPECTED: &str = "__expected";
pub static R_OUTPUTS: &str = "__output";

pub static TOOLS: [&str; 6] = [
"browserless",
Expand Down Expand Up @@ -112,7 +110,6 @@ pub struct Output {
pub enum Operator {
Generation,
FunctionCalling,
Check,
Search,
Sample,
End,
Expand Down Expand Up @@ -252,9 +249,12 @@ pub enum Model {
/// /// [Microsoft's Phi3.5 Mini model](https://ollama.com/library/phi3.5:3.8b-mini-instruct-fp16), 3.8b parameters
#[serde(rename = "phi3.5:3.8b-mini-instruct-fp16")]
Phi3_5MiniFp16,
/// [Ollama's Llama3.1 model](https://ollama.com/library/llama3.1:latest), 8B parameters
/// [Meta's Llama3.1 model](https://ollama.com/library/llama3.1:latest), 8B parameters
#[serde(rename = "llama3.1:latest")]
Llama3_1_8B,
/// [Meta's Llama3.1 model q8](https://ollama.com/library/llama3.1:8b-text-q8_0)
#[serde(rename = "llama3.1:8b-instruct-q8_0")]
Llama3_1_8Bq8,
// OpenAI models
/// [OpenAI's GPT-3.5 Turbo model](https://platform.openai.com/docs/models/gpt-3-5-turbo)
#[serde(rename = "gpt-3.5-turbo")]
Expand Down Expand Up @@ -325,6 +325,7 @@ impl From<Model> for ModelProvider {
Model::Phi3_5Mini => ModelProvider::Ollama,
Model::Phi3_5MiniFp16 => ModelProvider::Ollama,
Model::Llama3_1_8B => ModelProvider::Ollama,
Model::Llama3_1_8Bq8 => ModelProvider::Ollama,
Model::GPT3_5Turbo => ModelProvider::OpenAI,
Model::GPT4Turbo => ModelProvider::OpenAI,
Model::GPT4o => ModelProvider::OpenAI,
Expand Down
43 changes: 43 additions & 0 deletions src/program/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub enum CustomError {
FileSystemError(FileSystemError),
EmbeddingError(EmbeddingError),
ToolError(ToolError),
ExecutionError(ExecutionError),
}

#[allow(dead_code)]
Expand All @@ -31,12 +32,26 @@ pub enum ToolError {
ToolDoesNotExist,
}

#[derive(Debug)]
pub enum ExecutionError {
WorkflowFailed(String),
InvalidInput,
GenerationFailed,
FunctionCallFailed,
VectorSearchFailed,
StringCheckFailed,
SamplingError,
InvalidGetAllError,
UnexpectedOutput,
}

impl fmt::Display for CustomError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
CustomError::FileSystemError(err) => write!(f, "File system error: {}", err),
CustomError::EmbeddingError(err) => write!(f, "Embedding error: {}", err),
CustomError::ToolError(err) => write!(f, "Tool error: {}", err),
CustomError::ExecutionError(err) => write!(f, "Execution error: {}", err),
}
}
}
Expand Down Expand Up @@ -74,10 +89,32 @@ impl fmt::Display for ToolError {
}
}

impl fmt::Display for ExecutionError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ExecutionError::WorkflowFailed(cmd) => write!(f, "Workflow execution failed: {}", cmd),
ExecutionError::InvalidInput => write!(f, "Invalid input provided"),
ExecutionError::UnexpectedOutput => write!(f, "Unexpected output from command"),
ExecutionError::GenerationFailed => write!(f, "Text generation failed"),
ExecutionError::FunctionCallFailed => write!(f, "Function call failed"),
ExecutionError::VectorSearchFailed => write!(f, "Vector search failed"),
ExecutionError::StringCheckFailed => write!(f, "Vector search failed"),
ExecutionError::SamplingError => {
write!(f, "Error sampling because value array is empty")
}
ExecutionError::InvalidGetAllError => write!(
f,
"Error sampling because value is not get_all compatible (array)"
),
}
}
}

impl Error for CustomError {}
impl Error for FileSystemError {}
impl Error for EmbeddingError {}
impl Error for ToolError {}
impl Error for ExecutionError {}

impl From<FileSystemError> for CustomError {
fn from(err: FileSystemError) -> CustomError {
Expand All @@ -96,3 +133,9 @@ impl From<ToolError> for CustomError {
CustomError::ToolError(err)
}
}

impl From<ExecutionError> for CustomError {
fn from(err: ExecutionError) -> CustomError {
CustomError::ExecutionError(err)
}
}
70 changes: 22 additions & 48 deletions src/program/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::atomics::*;
use super::workflow::Workflow;
use crate::memory::types::Entry;
use crate::memory::{MemoryReturnType, ProgramMemory};
use crate::program::errors::ToolError;
use crate::program::errors::{ExecutionError, ToolError};
use crate::tools::{Browserless, CustomTool, Jina, SearchTool};

use rand::Rng;
Expand Down Expand Up @@ -82,7 +82,7 @@ impl Executor {
input: Option<&Entry>,
workflow: Workflow,
memory: &mut ProgramMemory,
) -> String {
) -> Result<String, ExecutionError> {
let config = workflow.get_config();
let max_steps = config.max_steps;
let max_time = config.max_time;
Expand Down Expand Up @@ -113,9 +113,9 @@ impl Executor {
}

if let Some(task) = workflow.get_tasks_by_id(&edge.source) {
let is_done = self.execute_task(task, memory.borrow_mut(), config).await;
let result = self.execute_task(task, memory.borrow_mut(), config).await;

current_step = if is_done {
current_step = if result.is_ok() {
//if there are conditions, check them
if let Some(condition) = &edge.condition {
let value = self.handle_input(&condition.input, memory).await;
Expand Down Expand Up @@ -156,9 +156,11 @@ impl Executor {
}
} else if let Some(fallback) = &edge.fallback {
warn!("[{}] failed, stepping into [{}]", &edge.source, &fallback);
error!("Task execution failed: {}", result.unwrap_err());
workflow.get_step_by_id(fallback)
} else {
warn!("{} failed, halting beacause of no fallback", &edge.source);
error!("Task execution failed: {}", result.unwrap_err());
break;
};
} else {
Expand All @@ -177,7 +179,7 @@ impl Executor {
if rv.to_json.is_some() && rv.to_json.unwrap() {
let res = return_value.to_json();
if let Some(result) = res {
return result;
return Ok(result);
}
}

Expand Down Expand Up @@ -214,18 +216,23 @@ impl Executor {
};
}
}
return_string
Ok(return_string)
}

async fn execute_task(&self, task: &Task, memory: &mut ProgramMemory, config: &Config) -> bool {
async fn execute_task(
&self,
task: &Task,
memory: &mut ProgramMemory,
config: &Config,
) -> Result<(), ExecutionError> {
info!("Executing task: {} with id {}", &task.name, &task.id);
info!("Using operator: {:?}", &task.operator);

let mut input_map: HashMap<String, MemoryReturnType> = HashMap::new();
for input in &task.inputs {
let value = self.handle_input(&input.value, memory).await;
if input.required && value.is_none() {
return false;
return Err(ExecutionError::InvalidInput);
}
input_map.insert(input.name.clone(), value.clone());
}
Expand All @@ -236,7 +243,7 @@ impl Executor {
let result = self.generate_text(&prompt, config).await;
if result.is_err() {
error!("Error generating text: {:?}", result.err().unwrap());
return false;
return Err(ExecutionError::GenerationFailed);
}
debug!("Prompt: {}", &prompt);
log_colored(
Expand All @@ -250,8 +257,8 @@ impl Executor {
info!("Prompt: {}", &prompt);
let result = self.function_call(&prompt, config).await;
if result.is_err() {
error!("Error generating text: {:?}", result.err().unwrap());
return false;
error!("Error function calling: {:?}", result.err().unwrap());
return Err(ExecutionError::FunctionCallFailed);
}
debug!("Prompt: {}", &prompt);
log_colored(
Expand All @@ -260,17 +267,12 @@ impl Executor {
let result_entry = Entry::try_value_or_str(&result.unwrap());
self.handle_output(task, result_entry, memory).await;
}
Operator::Check => {
let input = self.prepare_check(&input_map);
let result = self.check(&input.0, &input.1);
return result;
}
Operator::Search => {
let prompt = self.fill_prompt(&task.prompt, &input_map);
let result = memory.search(&Entry::try_value_or_str(&prompt)).await;
if result.is_none() {
error!("Error searching: {:?}", "No results found");
return false;
return Err(ExecutionError::VectorSearchFailed);
}
log_colored(
format!("Operator: {:?}. Output: {:?}", &task.operator, &result).as_str(),
Expand All @@ -291,13 +293,13 @@ impl Executor {
let v = Vec::<Entry>::from(value.clone());
if !v.is_empty() {
error!("Input for Sample operator cannot be GetAll");
return false;
return Err(ExecutionError::InvalidGetAllError);
} else {
let stack_lookup = value.to_string();
let entry = memory.get_all(&stack_lookup);
if entry.is_none() {
error!("Error sampling: {:?}", key);
return false;
return Err(ExecutionError::SamplingError);
}
let sample = self.sample(&entry.unwrap());
prompt.push_str(&format!(": {}", sample));
Expand All @@ -310,7 +312,7 @@ impl Executor {
Operator::End => {}
};

true
Ok(())
}

fn fill_prompt(
Expand All @@ -326,29 +328,6 @@ impl Executor {
filled_prompt
}

fn prepare_check(&self, input_map: &HashMap<String, MemoryReturnType>) -> (String, String) {
let input = input_map.get(R_OUTPUTS);
let expected = input_map.get(R_EXPECTED);

if let Some(i) = input {
if let Some(e) = expected {
return (
i.to_string()
.trim()
.replace('\n', "")
.to_lowercase()
.clone(),
e.to_string()
.trim()
.replace('\n', "")
.to_lowercase()
.clone(),
);
}
}
("+".to_string(), "-".to_string())
}

fn get_tools(
&self,
tool_names: Vec<String>,
Expand Down Expand Up @@ -587,11 +566,6 @@ impl Executor {
Ok(())
}

#[inline]
fn check(&self, input: &str, expected: &str) -> bool {
input == expected
}

//randomly sample list of entries
fn sample(&self, entries: &[Entry]) -> Entry {
let index = rand::thread_rng().gen_range(0..entries.len());
Expand Down
Loading

0 comments on commit ba038f7

Please sign in to comment.