From c544f98527b7eb0ece2b319953481dd8c4084554 Mon Sep 17 00:00:00 2001 From: erhant Date: Sun, 17 Nov 2024 15:26:12 +0700 Subject: [PATCH] add mutable references to workflow --- src/program/atomics.rs | 25 +++++++++++++++++++++++-- src/program/models.rs | 2 +- src/program/workflow.rs | 12 ++++++++++-- tests/task_test.rs | 15 +++++++++++++++ 4 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 tests/task_test.rs diff --git a/src/program/atomics.rs b/src/program/atomics.rs index d2daba3..002d427 100644 --- a/src/program/atomics.rs +++ b/src/program/atomics.rs @@ -1,6 +1,6 @@ use crate::program::io::{Input, InputValue, Output}; use crate::ProgramMemory; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; @@ -77,9 +77,12 @@ pub enum Operator { End, } -#[derive(Clone, Debug, Deserialize)] +/// A message entry. +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct MessageInput { + /// Role, usually `user`, `assistant` or `system`. pub role: String, + /// Message content. pub content: String, } @@ -103,6 +106,24 @@ pub struct Task { pub schema: Option, } +impl Task { + /// Creates a new chat history entry with the given content for `assistant` role. + pub fn append_assistant_message(&mut self, content: impl Into) { + self.messages.push(MessageInput { + role: "assistant".to_string(), + content: content.into(), + }); + } + + /// Creates a new chat history entry with the given content for `user` role. + pub fn append_user_message(&mut self, content: impl Into) { + self.messages.push(MessageInput { + role: "user".to_string(), + content: content.into(), + }); + } +} + #[derive(Debug, Deserialize)] #[serde(untagged)] pub enum TaskOutputInput { diff --git a/src/program/models.rs b/src/program/models.rs index 4542fb2..7b4c684 100644 --- a/src/program/models.rs +++ b/src/program/models.rs @@ -273,7 +273,7 @@ impl fmt::Display for ModelProvider { mod tests { use super::*; - const MODEL_NAME: &str = "phi3:3.8b"; + const MODEL_NAME: &str = "phi3.5:3.8b"; const PROVIDER_NAME: &str = "openai"; #[test] fn test_model_string_conversion() { diff --git a/src/program/workflow.rs b/src/program/workflow.rs index 0ac9872..17ff79a 100644 --- a/src/program/workflow.rs +++ b/src/program/workflow.rs @@ -100,6 +100,10 @@ impl Workflow { pub fn get_tasks(&self) -> &Vec { &self.tasks } + /// Returns a mutable reference to the tasks of the workflow. + pub fn get_tasks_mut(&mut self) -> &Vec { + &self.tasks + } /// Returns a reference to the steps of the workflow. pub fn get_workflow(&self) -> &Vec { &self.steps @@ -109,8 +113,8 @@ impl Workflow { &self.return_value } /// Returns a reference to the task at the specified index. - pub fn get_step(&self, index: u32) -> Option<&Edge> { - self.steps.get(index as usize) + pub fn get_step(&self, index: usize) -> Option<&Edge> { + self.steps.get(index) } /// Returns a reference to the step for specified task_id. pub fn get_step_by_id(&self, task_id: &str) -> Option<&Edge> { @@ -120,4 +124,8 @@ impl Workflow { pub fn get_tasks_by_id(&self, task_id: &str) -> Option<&Task> { self.tasks.iter().find(|task| task.id == task_id) } + /// Returns a mutable reference to the task at the specified task_id. + pub fn get_tasks_by_id_mut(&mut self, task_id: &str) -> Option<&mut Task> { + self.tasks.iter_mut().find(|task| task.id == task_id) + } } diff --git a/tests/task_test.rs b/tests/task_test.rs new file mode 100644 index 0000000..d6a6690 --- /dev/null +++ b/tests/task_test.rs @@ -0,0 +1,15 @@ +use ollama_workflows::Workflow; + +#[test] +fn test_task_mutable() { + let mut workflow = Workflow::new_from_json("./tests/test_workflows/search.json").unwrap(); + + let task_id = "E"; + let task = workflow.get_tasks_by_id_mut(task_id).unwrap(); + assert_eq!(task.id, task_id); + + assert_eq!(task.messages.len(), 1); + task.append_assistant_message("This is your response."); + task.append_user_message("Thank you."); + assert_eq!(task.messages.len(), 3); +}