Skip to content

Commit

Permalink
add mutable references to workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
erhant committed Nov 17, 2024
1 parent 75ead48 commit c544f98
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
25 changes: 23 additions & 2 deletions src/program/atomics.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::program::io::{Input, InputValue, Output};
use crate::ProgramMemory;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

Expand Down Expand Up @@ -77,9 +77,12 @@ pub enum Operator {
End,
}

#[derive(Clone, Debug, Deserialize)]
/// A message entry.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MessageInput {
/// Role, usually `user`, `assistant` or `system`.
pub role: String,
/// Message content.
pub content: String,
}

Expand All @@ -103,6 +106,24 @@ pub struct Task {
pub schema: Option<String>,
}

impl Task {
/// Creates a new chat history entry with the given content for `assistant` role.
pub fn append_assistant_message(&mut self, content: impl Into<String>) {
self.messages.push(MessageInput {
role: "assistant".to_string(),
content: content.into(),
});
}

/// Creates a new chat history entry with the given content for `user` role.
pub fn append_user_message(&mut self, content: impl Into<String>) {
self.messages.push(MessageInput {
role: "user".to_string(),
content: content.into(),
});
}
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TaskOutputInput {
Expand Down
2 changes: 1 addition & 1 deletion src/program/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl fmt::Display for ModelProvider {
mod tests {
use super::*;

const MODEL_NAME: &str = "phi3:3.8b";
const MODEL_NAME: &str = "phi3.5:3.8b";
const PROVIDER_NAME: &str = "openai";
#[test]
fn test_model_string_conversion() {
Expand Down
12 changes: 10 additions & 2 deletions src/program/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ impl Workflow {
pub fn get_tasks(&self) -> &Vec<Task> {
&self.tasks
}
/// Returns a mutable reference to the tasks of the workflow.
pub fn get_tasks_mut(&mut self) -> &Vec<Task> {
&self.tasks
}
/// Returns a reference to the steps of the workflow.
pub fn get_workflow(&self) -> &Vec<Edge> {
&self.steps
Expand All @@ -109,8 +113,8 @@ impl Workflow {
&self.return_value
}
/// Returns a reference to the task at the specified index.
pub fn get_step(&self, index: u32) -> Option<&Edge> {
self.steps.get(index as usize)
pub fn get_step(&self, index: usize) -> Option<&Edge> {
self.steps.get(index)
}
/// Returns a reference to the step for specified task_id.
pub fn get_step_by_id(&self, task_id: &str) -> Option<&Edge> {
Expand All @@ -120,4 +124,8 @@ impl Workflow {
pub fn get_tasks_by_id(&self, task_id: &str) -> Option<&Task> {
self.tasks.iter().find(|task| task.id == task_id)
}
/// Returns a mutable reference to the task at the specified task_id.
pub fn get_tasks_by_id_mut(&mut self, task_id: &str) -> Option<&mut Task> {
self.tasks.iter_mut().find(|task| task.id == task_id)
}
}
15 changes: 15 additions & 0 deletions tests/task_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use ollama_workflows::Workflow;

#[test]
fn test_task_mutable() {
let mut workflow = Workflow::new_from_json("./tests/test_workflows/search.json").unwrap();

let task_id = "E";
let task = workflow.get_tasks_by_id_mut(task_id).unwrap();
assert_eq!(task.id, task_id);

assert_eq!(task.messages.len(), 1);
task.append_assistant_message("This is your response.");
task.append_user_message("Thank you.");
assert_eq!(task.messages.len(), 3);
}

0 comments on commit c544f98

Please sign in to comment.