Skip to content

Commit

Permalink
feat(agent): basic DB persistence
Browse files Browse the repository at this point in the history
Signed-off-by: Zander Franks <[email protected]>
  • Loading branch information
voximity committed Jun 26, 2024
1 parent 399431e commit 1687dc1
Show file tree
Hide file tree
Showing 18 changed files with 163 additions and 58 deletions.
5 changes: 5 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/snops-agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ nix = { workspace = true, features = ["signal"] }
reqwest = { workspace = true, features = ["json", "stream"] }
serde_json.workspace = true
simple_moving_average.workspace = true
sled.workspace = true
snops-common = { workspace = true, features = ["aot_cmds"] }
tarpc.workspace = true
tokio = { workspace = true, features = [
Expand Down
80 changes: 80 additions & 0 deletions crates/snops-agent/src/db.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::{
io::{Read, Write},
path::PathBuf,
sync::Mutex,
};

use snops_common::{
db::{error::DatabaseError, tree::DbTree, Database as DatabaseTrait},
format::{DataFormat, DataReadError, DataWriteError},
};

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[repr(u8)]
pub enum AgentDbString {
/// JSON web token of agent.
Jwt,
/// Process ID of node. Used to keep track of zombie node processes.
NodePid,
}

impl DataFormat for AgentDbString {
type Header = u8;
const LATEST_HEADER: Self::Header = 1;

fn read_data<R: Read>(reader: &mut R, header: &Self::Header) -> Result<Self, DataReadError> {
if *header != Self::LATEST_HEADER {
return Err(DataReadError::unsupported(
"AgentDbString",
Self::LATEST_HEADER,
header,
));
}

Ok(match u8::read_data(reader, &())? {
0 => Self::Jwt,
1 => Self::NodePid,
_ => return Err(DataReadError::custom("invalid agent DB string type")),
})
}

fn write_data<W: Write>(&self, writer: &mut W) -> Result<usize, DataWriteError> {
(*self as u8).write_data(writer)
}
}

pub struct Database {
#[allow(unused)]
pub db: sled::Db,

pub jwt_mutex: Mutex<Option<String>>,
pub strings: DbTree<AgentDbString, String>,
}

impl DatabaseTrait for Database {
fn open(path: &PathBuf) -> Result<Self, DatabaseError> {
let db = sled::open(path)?;
let strings = DbTree::new(db.open_tree(b"v1/strings")?);
let jwt_mutex = Mutex::new(strings.restore(&AgentDbString::Jwt)?);

Ok(Self {
db,
jwt_mutex,
strings,
})
}
}

impl Database {
pub fn jwt(&self) -> Option<String> {
self.jwt_mutex.lock().unwrap().clone()
}

pub fn set_jwt(&self, jwt: Option<String>) -> Result<(), DatabaseError> {
let mut lock = self.jwt_mutex.lock().unwrap();
self.strings
.save_option(&AgentDbString::Jwt, jwt.as_ref())?;
*lock = jwt;
Ok(())
}
}
30 changes: 14 additions & 16 deletions crates/snops-agent/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod api;
mod cli;
mod db;
mod metrics;
mod net;
mod reconcile;
Expand All @@ -22,7 +23,9 @@ use futures_util::stream::{FuturesUnordered, StreamExt};
use http::HeaderValue;
use snops_common::{
constant::{ENV_AGENT_KEY, HEADER_AGENT_KEY},
db::Database,
rpc::{agent::AgentService, control::ControlServiceClient, RpcTransport},
util::OpaqueDebug,
};
use tarpc::server::Channel;
use tokio::{
Expand All @@ -36,7 +39,7 @@ use tokio_tungstenite::{
use tracing::{error, info, level_filters::LevelFilter, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use crate::rpc::{AgentRpcServer, MuxedMessageIncoming, MuxedMessageOutgoing, JWT_FILE};
use crate::rpc::{AgentRpcServer, MuxedMessageIncoming, MuxedMessageOutgoing};
use crate::state::GlobalState;

const PING_HEADER: &[u8] = b"snops-agent";
Expand Down Expand Up @@ -108,10 +111,8 @@ async fn main() {
.await
.expect("failed to create data path");

// get the JWT from the file, if possible
let jwt = tokio::fs::read_to_string(args.path.join(JWT_FILE))
.await
.ok();
// open the database
let db = db::Database::open(&args.path.join("store")).expect("failed to open database");

// create rpc channels
let (client_response_in, client_transport, mut client_request_out) = RpcTransport::new();
Expand All @@ -134,14 +135,14 @@ async fn main() {

// create the client state
let state = Arc::new(GlobalState {
client,
db: OpaqueDebug(db),
started: Instant::now(),
connected: Mutex::new(Instant::now()),
client,
external_addr,
internal_addrs,
cli: args,
endpoint,
jwt: Mutex::new(jwt),
loki: Default::default(),
env_info: Default::default(),
agent_state: Default::default(),
Expand Down Expand Up @@ -193,15 +194,12 @@ async fn main() {
state.env_info.write().await.take();

// attach JWT if we have one
{
let jwt = state.jwt.lock().expect("failed to acquire jwt");
if let Some(jwt) = jwt.as_deref() {
req.headers_mut().insert(
"Authorization",
HeaderValue::from_bytes(format!("Bearer {jwt}").as_bytes())
.expect("attach authorization header"),
);
}
if let Some(jwt) = state.db.jwt() {
req.headers_mut().insert(
"Authorization",
HeaderValue::from_bytes(format!("Bearer {jwt}").as_bytes())
.expect("attach authorization header"),
);
}

// attach agent key if one is set in env vars
Expand Down
14 changes: 3 additions & 11 deletions crates/snops-agent/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ use tracing::{debug, error, info, trace, warn};

use crate::{api, metrics::MetricComputer, reconcile, state::AppState};

/// The JWT file name.
pub const JWT_FILE: &str = "jwt";

/// A multiplexed message, incoming on the websocket.
pub type MuxedMessageIncoming =
MuxMessage<Response<ControlServiceResponse>, ClientMessage<AgentServiceRequest>>;
Expand Down Expand Up @@ -53,14 +50,9 @@ impl AgentService for AgentRpcServer {
if let Some(token) = handshake.jwt {
// cache the JWT in the state JWT mutex
self.state
.jwt
.lock()
.expect("failed to acquire JWT lock")
.replace(token.to_owned());

tokio::fs::write(self.state.cli.path.join(JWT_FILE), token)
.await
.expect("failed to write jwt file");
.db
.set_jwt(Some(token))
.map_err(|_| ReconcileError::Database)?;
}

// store loki server URL
Expand Down
5 changes: 3 additions & 2 deletions crates/snops-agent/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use snops_common::{
api::EnvInfo,
rpc::control::ControlServiceClient,
state::{AgentId, AgentPeer, AgentState, EnvId, TransferId, TransferStatus},
util::OpaqueDebug,
};
use tarpc::context;
use tokio::{
Expand All @@ -22,7 +23,7 @@ use tokio::{
};
use tracing::info;

use crate::{cli::Cli, metrics::Metrics, transfers::TransferTx};
use crate::{cli::Cli, db::Database, metrics::Metrics, transfers::TransferTx};

pub const NODE_GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);

Expand All @@ -31,6 +32,7 @@ pub type AppState = Arc<GlobalState>;
/// Global state for this agent runner.
pub struct GlobalState {
pub client: ControlServiceClient,
pub db: OpaqueDebug<Database>,
pub started: Instant,
pub connected: Mutex<Instant>,

Expand All @@ -39,7 +41,6 @@ pub struct GlobalState {
pub status_api_port: u16,
pub cli: Cli,
pub endpoint: String,
pub jwt: Mutex<Option<String>>,
pub loki: Mutex<Option<Url>>,
pub agent_state: RwLock<AgentState>,
pub env_info: RwLock<Option<(EnvId, EnvInfo)>>,
Expand Down
4 changes: 4 additions & 0 deletions crates/snops-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ mangen = ["anyhow", "clap_mangen"]

[dependencies]
anyhow = { workspace = true, optional = true }
bincode.workspace = true
bytes.workspace = true
checkpoint = { workspace = true, features = ["serde"] }
chrono = { workspace = true, features = ["serde"] }
clap.workspace = true
Expand All @@ -26,10 +28,12 @@ regex.workspace = true
rand.workspace = true
serde.workspace = true
serde_json.workspace = true
sled.workspace = true
strum_macros.workspace = true
tarpc.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["process"] }
tracing.workspace = true
url.workspace = true
wildmatch.workspace = true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub enum DatabaseError {
#[error("transaction error: {0}")]
TransactionError(#[from] sled::transaction::TransactionError),
#[error("error writing data: {0}")]
DataWriteError(#[from] snops_common::format::DataWriteError),
DataWriteError(#[from] crate::format::DataWriteError),
#[error("error reading data: {0}")]
DataReadError(#[from] snops_common::format::DataReadError),
DataReadError(#[from] crate::format::DataReadError),
}
10 changes: 10 additions & 0 deletions crates/snops-common/src/db/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use std::path::PathBuf;

use self::error::DatabaseError;

pub mod error;
pub mod tree;

pub trait Database: Sized {
fn open(path: &PathBuf) -> Result<Self, DatabaseError>;
}
31 changes: 19 additions & 12 deletions crates/snops/src/db/tree.rs → crates/snops-common/src/db/tree.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
use bytes::Buf;
use snops_common::format::{read_dataformat, write_dataformat, DataFormat};

use super::error::DatabaseError;
use crate::format::{read_dataformat, write_dataformat, DataFormat};

pub struct DbTree<Key: DataFormat, Value: DataFormat> {
pub struct DbTree<K, V> {
tree: sled::Tree,
_phantom: std::marker::PhantomData<(Key, Value)>,
_phantom: std::marker::PhantomData<(K, V)>,
}

impl<Key: DataFormat, Value: DataFormat> DbTree<Key, Value> {
impl<K: DataFormat, V: DataFormat> DbTree<K, V> {
pub fn new(tree: sled::Tree) -> Self {
Self {
tree,
_phantom: std::marker::PhantomData,
}
}

pub fn read_all(&self) -> impl Iterator<Item = (Key, Value)> {
pub fn read_all(&self) -> impl Iterator<Item = (K, V)> {
self.tree.iter().filter_map(|row| {
let (key_bytes, value_bytes) = match row {
Ok((key, value)) => (key, value),
Expand All @@ -26,7 +26,7 @@ impl<Key: DataFormat, Value: DataFormat> DbTree<Key, Value> {
}
};

let key = match Key::read_data(&mut key_bytes.reader(), &Key::LATEST_HEADER) {
let key = match K::read_data(&mut key_bytes.reader(), &K::LATEST_HEADER) {
Ok(key) => key,
Err(e) => {
tracing::error!("Error parsing key from store: {e}");
Expand All @@ -49,7 +49,7 @@ impl<Key: DataFormat, Value: DataFormat> DbTree<Key, Value> {
pub fn read_with_prefix<Prefix: DataFormat>(
&self,
prefix: &Prefix,
) -> Result<impl Iterator<Item = (Key, Value)>, DatabaseError> {
) -> Result<impl Iterator<Item = (K, V)>, DatabaseError> {
Ok(self
.tree
.scan_prefix(prefix.to_byte_vec()?)
Expand All @@ -62,7 +62,7 @@ impl<Key: DataFormat, Value: DataFormat> DbTree<Key, Value> {
}
};

let key = match Key::read_data(&mut key_bytes.reader(), &Key::LATEST_HEADER) {
let key = match K::read_data(&mut key_bytes.reader(), &K::LATEST_HEADER) {
Ok(key) => key,
Err(e) => {
tracing::error!("Error parsing key from store: {e}");
Expand All @@ -82,23 +82,30 @@ impl<Key: DataFormat, Value: DataFormat> DbTree<Key, Value> {
}))
}

pub fn restore(&self, key: &Key) -> Result<Option<Value>, DatabaseError> {
pub fn restore(&self, key: &K) -> Result<Option<V>, DatabaseError> {
Ok(self
.tree
.get(key.to_byte_vec()?)?
.map(|value_bytes| read_dataformat(&mut value_bytes.reader()))
.transpose()?)
}

pub fn save(&self, key: &Key, value: &Value) -> Result<(), DatabaseError> {
pub fn save(&self, key: &K, value: &V) -> Result<(), DatabaseError> {
let key_bytes = key.to_byte_vec()?;
let mut value_bytes = Vec::new();
write_dataformat(&mut value_bytes, value)?;
self.tree.insert(key_bytes, value_bytes)?;
Ok(())
}

pub fn delete(&self, key: &Key) -> Result<bool, DatabaseError> {
pub fn save_option(&self, key: &K, value: Option<&V>) -> Result<(), DatabaseError> {
match value {
Some(value) => self.save(key, value),
None => self.delete(key).map(|_| ()),
}
}

pub fn delete(&self, key: &K) -> Result<bool, DatabaseError> {
Ok(self.tree.remove(key.to_byte_vec()?)?.is_some())
}

Expand All @@ -118,7 +125,7 @@ impl<Key: DataFormat, Value: DataFormat> DbTree<Key, Value> {
}
};

let key = match Key::read_data(&mut key_bytes.reader(), &Key::LATEST_HEADER) {
let key = match K::read_data(&mut key_bytes.reader(), &K::LATEST_HEADER) {
Ok(key) => key,
Err(e) => {
tracing::error!("Error parsing key from store: {e}");
Expand Down
Loading

0 comments on commit 1687dc1

Please sign in to comment.