Add support for attaching data to the context.
jakubadamw committed Nov 2, 2024
1 parent 8e5b486 commit 7daa19b
Showing 5 changed files with 99 additions and 11 deletions.
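In short, the commit gives each worker a type-keyed data map: values are attached at build time through the new WorkerBuilder::datum setter and read back inside a step with Context::datum::<D>(), which looks the value up by its TypeId and panics if nothing of that type was attached. Below is a condensed, commented sketch of the new API, adapted from the example and builder code added in this commit; workflow registration and worker startup are elided, and the build_worker helper exists only for illustration.

use hatchet_sdk::{Client, Context};

// Any `Send + Sync` value can be attached; it is stored and looked up by its
// `TypeId`, so the worker holds at most one value per concrete type.
struct Datum {
    number: usize,
}

#[derive(serde::Deserialize)]
struct Input {}

// Inside a step, `Context::datum::<D>()` returns a reference to the attached
// value and panics if no value of type `D` was attached to the worker.
async fn execute(context: Context, _input: Input) -> anyhow::Result<()> {
    assert_eq!(context.datum::<Datum>().number, 10);
    Ok(())
}

// Attach the value once on the builder; every step run on this worker then
// shares the same `Arc`-backed map. See the full example below for
// registration and startup.
fn build_worker(client: Client) {
    let _worker = client
        .worker("example_data_in_context")
        .datum(Datum { number: 10 })
        .build();
}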
44 changes: 44 additions & 0 deletions examples/data_in_context.rs
@@ -0,0 +1,44 @@
use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder};

#[derive(serde::Deserialize)]
struct Input {}

async fn execute(context: Context, _input: Input) -> anyhow::Result<()> {
assert_eq!(context.datum::<Datum>().number, 10);
Ok(())
}

struct Datum {
number: usize,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenv::dotenv().ok();
tracing_subscriber::fmt()
.with_target(false)
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive("hatchet_sdk=debug".parse()?),
)
.init();

let client = Client::new()?;
let mut worker = client
.worker("example_data_in_context")
.datum(Datum { number: 10 })
.build();
worker.register_workflow(
WorkflowBuilder::default()
.name("example_data_in_context")
.step(
StepBuilder::default()
.name("compute")
.function(&execute)
.build()?,
)
.build()?,
);
worker.start().await?;
Ok(())
}
2 changes: 1 addition & 1 deletion examples/fibonacci.rs
@@ -1,4 +1,4 @@
use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder};
use hatchet_sdk::{Client, StepBuilder, WorkflowBuilder};

fn fibonacci(n: u32) -> u32 {
(1..=n)
19 changes: 19 additions & 0 deletions src/step_function.rs
@@ -1,8 +1,16 @@
use std::{
any::{Any, TypeId},
collections::HashMap,
sync::Arc,
};

use futures_util::lock::Mutex;
use tracing::info;

use crate::worker::{grpc, ServiceWithAuthorization};

pub(crate) type DataMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

pub struct Context {
workflow_run_id: String,
workflow_step_run_id: String,
@@ -15,6 +23,7 @@ pub struct Context {
>,
u16,
)>,
data: Arc<DataMap>,
}

impl Context {
@@ -27,14 +36,24 @@ impl Context {
ServiceWithAuthorization,
>,
>,
data: Arc<DataMap>,
) -> Self {
Self {
workflow_run_id,
workflow_service_client_and_spawn_index: Mutex::new((workflow_service_client, 0)),
workflow_step_run_id,
data,
}
}

pub fn datum<D: std::any::Any + Send + Sync>(&self) -> &D {
let type_id = TypeId::of::<D>();
self.data
.get(&type_id)
.and_then(|value| value.downcast_ref())
.unwrap_or_else(|| panic!("could not find an attached datum of the type: {type_id:?}"))
}

pub async fn trigger_workflow<I: serde::Serialize>(
&self,
workflow_name: &str,
10 changes: 7 additions & 3 deletions src/worker/listener.rs
@@ -1,3 +1,5 @@
use std::sync::Arc;

use futures_util::FutureExt;
use tokio::task::LocalSet;
use tonic::IntoRequest;
@@ -14,7 +16,7 @@ use super::{
dispatcher_client::DispatcherClient, AssignedAction, StepActionEvent, StepActionEventType,
WorkerListenRequest,
},
ListenStrategy, ServiceWithAuthorization,
DataMap, ListenStrategy, ServiceWithAuthorization,
};

const DEFAULT_ACTION_LISTENER_RETRY_INTERVAL: std::time::Duration =
@@ -63,6 +65,7 @@ async fn handle_start_step_run(
worker_id: &str,
workflows: &[Workflow],
action: AssignedAction,
data: Arc<DataMap>,
) -> crate::InternalResult<()> {
let Some(action_callable) = workflows
.iter()
@@ -101,6 +104,7 @@ async fn handle_start_step_run(
workflow_run_id,
workflow_step_run_id,
workflow_service_client,
data,
);
action_callable(context, input.input).await
})
@@ -155,7 +159,7 @@ pub(crate) async fn run(
workflows: Vec<Workflow>,
listener_v2_timeout: Option<u64>,
mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>,
_heartbeat_interrupt_sender: tokio::sync::mpsc::Sender<()>,
data: Arc<DataMap>,
) -> crate::InternalResult<()> {
use futures_util::StreamExt;

@@ -253,7 +257,7 @@ pub(crate) async fn run(

match action_type {
ActionType::StartStepRun => {
handle_start_step_run(&mut dispatcher, workflow_service_client.clone(), &local_set, namespace, worker_id, &workflows, action).await?;
handle_start_step_run(&mut dispatcher, workflow_service_client.clone(), &local_set, namespace, worker_id, &workflows, action, data.clone()).await?;
}
ActionType::CancelStepRun => {
todo!()
35 changes: 28 additions & 7 deletions src/worker/mod.rs
@@ -1,14 +1,16 @@
mod heartbeat;
mod listener;

use std::sync::Arc;

use grpc::{
CreateWorkflowJobOpts, CreateWorkflowStepOpts, CreateWorkflowVersionOpts, PutWorkflowRequest,
WorkerRegisterRequest, WorkerRegisterResponse, WorkflowKind,
};
use tonic::transport::Certificate;
use tracing::info;

use crate::{client::Environment, ClientTlStrategy, Workflow};
use crate::{client::Environment, step_function::DataMap, ClientTlStrategy, Workflow};

#[derive(Clone)]
pub(crate) struct ServiceWithAuthorization {
@@ -55,9 +57,18 @@ pub struct Worker<'a> {
environment: &'a super::client::Environment,
#[builder(default, setter(skip))]
workflows: Vec<Workflow>,
#[builder(default, setter(custom))]
data: DataMap,
}

impl<'a> WorkerBuilder<'a> {
pub fn datum<D: std::any::Any + Send + Sync>(mut self, datum: D) -> Self {
self.data
.get_or_insert_default()
.insert(std::any::TypeId::of::<D>(), Box::new(datum));
self
}

pub fn build(self) -> Worker<'a> {
self.build_private().expect("must succeed")
}
@@ -172,7 +183,7 @@ impl<'a> Worker<'a> {
let (listening_interrupt_sender1, listening_interrupt_receiver) =
tokio::sync::mpsc::channel(1);
let _listening_interrupt_sender2 = listening_interrupt_sender1.clone();
let heartbeat_interrupt_sender2 = heartbeat_interrupt_sender1.clone();
let _heartbeat_interrupt_sender2 = heartbeat_interrupt_sender1.clone();

tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap();
@@ -182,6 +193,14 @@
let _ = listening_interrupt_sender1.send(()).await;
});

let Self {
workflows,
data,
name,
max_runs,
environment,
} = self;

let Environment {
token,
host_port,
@@ -195,7 +214,7 @@
tls_server_name,
namespace,
listener_v2_timeout,
} = self.environment;
} = environment;

let endpoint = construct_endpoint(
tls_server_name.as_deref(),
@@ -220,7 +239,9 @@

let mut all_actions = vec![];

for workflow in &self.workflows {
let data = Arc::new(data);

for workflow in &workflows {
let namespaced_workflow_name =
format!("{namespace}{workflow_name}", workflow_name = workflow.name);

@@ -284,8 +305,8 @@

let request = {
let mut request: tonic::Request<WorkerRegisterRequest> = WorkerRegisterRequest {
worker_name: self.name.clone(),
max_runs: self.max_runs,
worker_name: name,
max_runs,
services: vec!["default".to_owned()],
actions: all_actions,
// FIXME: Implement.
@@ -304,7 +325,7 @@

futures_util::try_join! {
heartbeat::run(dispatcher.clone(), &worker_id, heartbeat_interrupt_receiver),
listener::run(dispatcher, workflow_service_client, namespace, &worker_id, self.workflows, *listener_v2_timeout, listening_interrupt_receiver, heartbeat_interrupt_sender2),
listener::run(dispatcher, workflow_service_client, namespace, &worker_id, workflows, *listener_v2_timeout, listening_interrupt_receiver, data)
}?;

Ok(())
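
Taken together, the data flow introduced here is: WorkerBuilder::datum inserts each value into a DataMap (a HashMap<TypeId, Box<dyn Any + Send + Sync>>), the worker startup path wraps the finished map in an Arc, the listener hands a clone of that Arc to every Context it builds for a step run, and Context::datum::<D>() downcasts the stored box back to &D, panicking when no value of that type was attached. Because the map is keyed by TypeId, attaching a second value of the same concrete type replaces the first.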
