diff --git a/examples/data_in_context.rs b/examples/data_in_context.rs new file mode 100644 index 0000000..38369b2 --- /dev/null +++ b/examples/data_in_context.rs @@ -0,0 +1,44 @@ +use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder}; + +#[derive(serde::Deserialize)] +struct Input {} + +async fn execute(context: Context, _input: Input) -> anyhow::Result<()> { + assert_eq!(context.datum::().number, 10); + Ok(()) +} + +struct Datum { + number: usize, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + dotenv::dotenv().ok(); + tracing_subscriber::fmt() + .with_target(false) + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("hatchet_sdk=debug".parse()?), + ) + .init(); + + let client = Client::new()?; + let mut worker = client + .worker("example_data_in_context") + .datum(Datum { number: 10 }) + .build(); + worker.register_workflow( + WorkflowBuilder::default() + .name("example_data_in_context") + .step( + StepBuilder::default() + .name("compute") + .function(&execute) + .build()?, + ) + .build()?, + ); + worker.start().await?; + Ok(()) +} diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 2f56b16..6458204 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -1,4 +1,4 @@ -use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder}; +use hatchet_sdk::{Client, StepBuilder, WorkflowBuilder}; fn fibonacci(n: u32) -> u32 { (1..=n) diff --git a/src/step_function.rs b/src/step_function.rs index 8b6f1dd..a24e222 100644 --- a/src/step_function.rs +++ b/src/step_function.rs @@ -1,8 +1,16 @@ +use std::{ + any::{Any, TypeId}, + collections::HashMap, + sync::Arc, +}; + use futures_util::lock::Mutex; use tracing::info; use crate::worker::{grpc, ServiceWithAuthorization}; +pub(crate) type DataMap = HashMap>; + pub struct Context { workflow_run_id: String, workflow_step_run_id: String, @@ -15,6 +23,7 @@ pub struct Context { >, u16, )>, + data: Arc, } impl Context { @@ -27,14 +36,24 @@ impl Context { ServiceWithAuthorization, >, >, + data: Arc, ) -> Self { Self { workflow_run_id, workflow_service_client_and_spawn_index: Mutex::new((workflow_service_client, 0)), workflow_step_run_id, + data, } } + pub fn datum(&self) -> &D { + let type_id = TypeId::of::(); + self.data + .get(&type_id) + .and_then(|value| value.downcast_ref()) + .unwrap_or_else(|| panic!("could not find an attached datum of the type: {type_id:?}")) + } + pub async fn trigger_workflow( &self, workflow_name: &str, diff --git a/src/worker/listener.rs b/src/worker/listener.rs index b579edc..97b74c0 100644 --- a/src/worker/listener.rs +++ b/src/worker/listener.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use futures_util::FutureExt; use tokio::task::LocalSet; use tonic::IntoRequest; @@ -14,7 +16,7 @@ use super::{ dispatcher_client::DispatcherClient, AssignedAction, StepActionEvent, StepActionEventType, WorkerListenRequest, }, - ListenStrategy, ServiceWithAuthorization, + DataMap, ListenStrategy, ServiceWithAuthorization, }; const DEFAULT_ACTION_LISTENER_RETRY_INTERVAL: std::time::Duration = @@ -63,6 +65,7 @@ async fn handle_start_step_run( worker_id: &str, workflows: &[Workflow], action: AssignedAction, + data: Arc, ) -> crate::InternalResult<()> { let Some(action_callable) = workflows .iter() @@ -101,6 +104,7 @@ async fn handle_start_step_run( workflow_run_id, workflow_step_run_id, workflow_service_client, + data, ); action_callable(context, input.input).await }) @@ -155,7 +159,7 @@ pub(crate) async fn run( workflows: Vec, listener_v2_timeout: Option, mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>, - _heartbeat_interrupt_sender: tokio::sync::mpsc::Sender<()>, + data: Arc, ) -> crate::InternalResult<()> { use futures_util::StreamExt; @@ -253,7 +257,7 @@ pub(crate) async fn run( match action_type { ActionType::StartStepRun => { - handle_start_step_run(&mut dispatcher, workflow_service_client.clone(), &local_set, namespace, worker_id, &workflows, action).await?; + handle_start_step_run(&mut dispatcher, workflow_service_client.clone(), &local_set, namespace, worker_id, &workflows, action, data.clone()).await?; } ActionType::CancelStepRun => { todo!() diff --git a/src/worker/mod.rs b/src/worker/mod.rs index 87fcafe..136a9ea 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -1,6 +1,8 @@ mod heartbeat; mod listener; +use std::sync::Arc; + use grpc::{ CreateWorkflowJobOpts, CreateWorkflowStepOpts, CreateWorkflowVersionOpts, PutWorkflowRequest, WorkerRegisterRequest, WorkerRegisterResponse, WorkflowKind, @@ -8,7 +10,7 @@ use grpc::{ use tonic::transport::Certificate; use tracing::info; -use crate::{client::Environment, ClientTlStrategy, Workflow}; +use crate::{client::Environment, step_function::DataMap, ClientTlStrategy, Workflow}; #[derive(Clone)] pub(crate) struct ServiceWithAuthorization { @@ -55,9 +57,18 @@ pub struct Worker<'a> { environment: &'a super::client::Environment, #[builder(default, setter(skip))] workflows: Vec, + #[builder(default, setter(custom))] + data: DataMap, } impl<'a> WorkerBuilder<'a> { + pub fn datum(mut self, datum: D) -> Self { + self.data + .get_or_insert_default() + .insert(std::any::TypeId::of::(), Box::new(datum)); + self + } + pub fn build(self) -> Worker<'a> { self.build_private().expect("must succeed") } @@ -172,7 +183,7 @@ impl<'a> Worker<'a> { let (listening_interrupt_sender1, listening_interrupt_receiver) = tokio::sync::mpsc::channel(1); let _listening_interrupt_sender2 = listening_interrupt_sender1.clone(); - let heartbeat_interrupt_sender2 = heartbeat_interrupt_sender1.clone(); + let _heartbeat_interrupt_sender2 = heartbeat_interrupt_sender1.clone(); tokio::spawn(async move { tokio::signal::ctrl_c().await.unwrap(); @@ -182,6 +193,14 @@ impl<'a> Worker<'a> { let _ = listening_interrupt_sender1.send(()).await; }); + let Self { + workflows, + data, + name, + max_runs, + environment, + } = self; + let Environment { token, host_port, @@ -195,7 +214,7 @@ impl<'a> Worker<'a> { tls_server_name, namespace, listener_v2_timeout, - } = self.environment; + } = environment; let endpoint = construct_endpoint( tls_server_name.as_deref(), @@ -220,7 +239,9 @@ impl<'a> Worker<'a> { let mut all_actions = vec![]; - for workflow in &self.workflows { + let data = Arc::new(data); + + for workflow in &workflows { let namespaced_workflow_name = format!("{namespace}{workflow_name}", workflow_name = workflow.name); @@ -284,8 +305,8 @@ impl<'a> Worker<'a> { let request = { let mut request: tonic::Request = WorkerRegisterRequest { - worker_name: self.name.clone(), - max_runs: self.max_runs, + worker_name: name, + max_runs, services: vec!["default".to_owned()], actions: all_actions, // FIXME: Implement. @@ -304,7 +325,7 @@ impl<'a> Worker<'a> { futures_util::try_join! { heartbeat::run(dispatcher.clone(), &worker_id, heartbeat_interrupt_receiver), - listener::run(dispatcher, workflow_service_client, namespace, &worker_id, self.workflows, *listener_v2_timeout, listening_interrupt_receiver, heartbeat_interrupt_sender2), + listener::run(dispatcher, workflow_service_client, namespace, &worker_id, workflows, *listener_v2_timeout, listening_interrupt_receiver, data) }?; Ok(())