Skip to content

Commit

Permalink
Timestamps + Workflow update (#143)
Browse files Browse the repository at this point in the history
* added timestamps to task results

* update workflows
  • Loading branch information
erhant authored Nov 16, 2024
1 parent 5d5763c commit b0c91d7
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 92 deletions.
146 changes: 86 additions & 60 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default-members = ["compute"]

[workspace.package]
edition = "2021"
version = "0.2.20"
version = "0.2.21"
license = "Apache-2.0"
readme = "README.md"

Expand Down
44 changes: 30 additions & 14 deletions compute/src/handlers/workflow.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::time::Instant;

use async_trait::async_trait;
use dkn_p2p::libp2p::gossipsub::MessageAcceptance;
use dkn_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
use eyre::{eyre, Context, Result};
use libsecp256k1::PublicKey;
use serde::Deserialize;

use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload};
use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload, TaskStats};
use crate::utils::{get_current_time_nanos, DKNMessage};
use crate::DriaComputeNode;

Expand Down Expand Up @@ -38,6 +40,7 @@ impl ComputeHandler for WorkflowHandler {
let task = message
.parse_payload::<TaskRequestPayload<WorkflowPayload>>(true)
.wrap_err("Could not parse workflow task")?;
let mut task_stats = TaskStats::default().record_received_at();

// check if deadline is past or not
let current_time = get_current_time_nanos();
Expand Down Expand Up @@ -90,6 +93,7 @@ impl ComputeHandler for WorkflowHandler {

// execute workflow with cancellation
let exec_result: Result<String>;
let exec_started_at = Instant::now();
tokio::select! {
_ = node.cancellation.cancelled() => {
log::info!("Received cancellation, quitting all tasks.");
Expand All @@ -99,10 +103,10 @@ impl ComputeHandler for WorkflowHandler {
exec_result = exec_result_inner.map_err(|e| eyre!("Execution error: {}", e.to_string()));
}
}
task_stats = task_stats.record_execution_time(exec_started_at);

let (publish_result, acceptance) = match exec_result {
let (message, acceptance) = match exec_result {
Ok(result) => {
log::warn!("Task {} result:", result);
// obtain public key from the payload
let task_public_key_bytes =
hex::decode(&task.public_key).wrap_err("Could not decode public key")?;
Expand All @@ -115,44 +119,56 @@ impl ComputeHandler for WorkflowHandler {
&task_public_key,
&node.config.secret_key,
model_name,
task_stats.record_published_at(),
)?;
let payload_str = serde_json::to_string(&payload)
.wrap_err("Could not serialize response payload")?;

// publish the result
// accept so that if there are others included in filter they can do the task
// prepare signed message
log::debug!(
"Publishing result for task {}\n{}",
task.task_id,
payload_str
);
let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC);
(node.publish(message), MessageAcceptance::Accept)
// accept so that if there are others included in filter they can do the task
(message, MessageAcceptance::Accept)
}
Err(err) => {
// use pretty display string for error logging with causes
let err_string = format!("{:#}", err);
log::error!("Task {} failed: {}", task.task_id, err_string);

// prepare error payload
let error_payload =
TaskErrorPayload::new(task.task_id.clone(), err_string, model_name);
let error_payload = TaskErrorPayload {
task_id: task.task_id.clone(),
error: err_string,
model: model_name,
stats: task_stats.record_published_at(),
};
let error_payload_str = serde_json::to_string(&error_payload)
.wrap_err("Could not serialize error payload")?;

// publish the error result for diagnostics
// ignore just in case, workflow may be bugged
// prepare signed message
let message = DKNMessage::new_signed(
error_payload_str,
Self::RESPONSE_TOPIC,
&node.config.secret_key,
);
(node.publish(message), MessageAcceptance::Ignore)
// ignore just in case, workflow may be bugged
(message, MessageAcceptance::Ignore)
}
};

// if for some reason we couldnt publish the result, publish the error itself so that RPC doesnt hang
if let Err(publish_err) = publish_result {
// try publishing the result

if let Err(publish_err) = node.publish(message) {
let err_msg = format!("Could not publish result: {:?}", publish_err);
log::error!("{}", err_msg);

let payload = serde_json::json!({
"taskId": task.task_id,
"error": err_msg
"error": err_msg,
});
let message = DKNMessage::new_signed(
payload.to_string(),
Expand Down
2 changes: 1 addition & 1 deletion compute/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ impl DriaComputeNode {
} else if std::matches!(topic_str, PingpongHandler::RESPONSE_TOPIC | WorkflowHandler::RESPONSE_TOPIC) {
// since we are responding to these topics, we might receive messages from other compute nodes
// we can gracefully ignore them and propagate them to others
log::debug!("Ignoring message for topic: {}", topic_str);
log::trace!("Ignoring message for topic: {}", topic_str);
self.p2p.validate_message(&message_id, &peer_id, gossipsub::MessageAcceptance::Accept)?;
} else {
// reject this message as its from a foreign topic
Expand Down
14 changes: 4 additions & 10 deletions compute/src/payloads/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};

use super::TaskStats;

/// A task error response.
/// Returning this as the payload helps to debug the errors received at client side.
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -11,14 +13,6 @@ pub struct TaskErrorPayload {
pub error: String,
/// Name of the model that caused the error.
pub model: String,
}

impl TaskErrorPayload {
pub fn new(task_id: String, error: String, model: String) -> Self {
Self {
task_id,
error,
model,
}
}
/// Task statistics.
pub stats: TaskStats,
}
3 changes: 3 additions & 0 deletions compute/src/payloads/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ pub use request::TaskRequestPayload;

mod response;
pub use response::TaskResponsePayload;

mod stats;
pub use stats::TaskStats;
18 changes: 15 additions & 3 deletions compute/src/payloads/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use eyre::Result;
use libsecp256k1::{PublicKey, SecretKey};
use serde::{Deserialize, Serialize};

use super::TaskStats;

/// A computation task is the task of computing a result from a given input. The result is encrypted with the public key of the requester.
/// Plain result is signed by the compute node's private key, and a commitment is computed from the signature and plain result.
///
Expand All @@ -19,6 +21,8 @@ pub struct TaskResponsePayload {
pub ciphertext: String,
/// Name of the model used for this task.
pub model: String,
/// Stats about the task execution.
pub stats: TaskStats,
}

impl TaskResponsePayload {
Expand All @@ -32,6 +36,7 @@ impl TaskResponsePayload {
encrypting_public_key: &PublicKey,
signing_secret_key: &SecretKey,
model: String,
stats: TaskStats,
) -> Result<Self> {
// create the message `task_id || payload`
let mut preimage = Vec::new();
Expand All @@ -47,6 +52,7 @@ impl TaskResponsePayload {
signature,
ciphertext,
model,
stats,
})
}
}
Expand Down Expand Up @@ -74,9 +80,15 @@ mod tests {
let task_id = uuid::Uuid::new_v4().to_string();

// creates a signed and encrypted payload
let payload =
TaskResponsePayload::new(RESULT, &task_id, &task_pk, &signer_sk, MODEL.to_string())
.expect("Should create payload");
let payload = TaskResponsePayload::new(
RESULT,
&task_id,
&task_pk,
&signer_sk,
MODEL.to_string(),
Default::default(),
)
.expect("Should create payload");

// decrypt result and compare it to plaintext
let ciphertext_bytes = hex::decode(payload.ciphertext).unwrap();
Expand Down
55 changes: 55 additions & 0 deletions compute/src/payloads/stats.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use serde::{Deserialize, Serialize};
use std::time::Instant;

use crate::utils::get_current_time_nanos;

/// Timing statistics recorded while handling a single task.
///
/// Carried in the `stats` field of the task response and task error payloads.
/// Field names are serialized in camelCase. Every field is `0` by default and
/// stays `0` until the matching `record_*` method is called.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskStats {
/// Timestamp at which the task was received from network & parsed,
/// in nanoseconds since the Unix epoch.
pub received_at: u128,
/// Timestamp at which the task was published back to network,
/// in nanoseconds since the Unix epoch.
pub published_at: u128,
/// Time taken to execute the task, in nanoseconds.
pub execution_time: u128,
}

impl TaskStats {
    /// Records the current timestamp within `received_at`.
    ///
    /// Consumes `self` and returns the updated stats so that calls can be chained.
    /// The value is whatever `get_current_time_nanos` reports at the time of the call.
    #[must_use]
    pub fn record_received_at(mut self) -> Self {
        self.received_at = get_current_time_nanos();
        self
    }

    /// Records the current timestamp within `published_at`.
    ///
    /// Consumes `self` and returns the updated stats so that calls can be chained.
    /// The value is whatever `get_current_time_nanos` reports at the time of the call.
    #[must_use]
    pub fn record_published_at(mut self) -> Self {
        self.published_at = get_current_time_nanos();
        self
    }

    /// Records the time elapsed since `started_at` within `execution_time`, in nanoseconds.
    ///
    /// `started_at` should be an [`Instant`] captured right before execution began.
    /// A monotonic clock is used here (rather than wall-clock timestamps) so that the
    /// measured duration is not affected by system clock adjustments.
    #[must_use]
    pub fn record_execution_time(mut self, started_at: Instant) -> Self {
        self.execution_time = started_at.elapsed().as_nanos();
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Timestamps start out as zero and become non-zero once they are recorded.
    #[test]
    fn test_stats() {
        let initial = TaskStats::default();
        assert_eq!(initial.received_at, 0);

        // recording the receive time must set `received_at`
        let received = initial.record_received_at();
        assert_ne!(received.received_at, 0);

        // `published_at` is still untouched until it is recorded explicitly
        assert_eq!(received.published_at, 0);
        let published = received.record_published_at();
        assert_ne!(published.published_at, 0);
    }
}
6 changes: 3 additions & 3 deletions compute/src/utils/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ use dkn_p2p::libp2p::{multiaddr::Protocol, Multiaddr};
use port_check::is_port_reachable;
use std::{
net::{Ipv4Addr, SocketAddrV4},
time::{Duration, SystemTime},
time::SystemTime,
};

/// Returns the current time in nanoseconds since the Unix epoch.
///
/// If a `SystemTimeError` occurs, will return 0 just to keep things running.
#[inline]
#[inline(always)]
pub fn get_current_time_nanos() -> u128 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_else(|e| {
log::error!("Error getting current time: {}", e);
Duration::new(0, 0)
Default::default()
})
.as_nanos()
}
Expand Down

0 comments on commit b0c91d7

Please sign in to comment.