Skip to content

Commit

Permalink
refactor: switch wal to sync implementation
Browse files Browse the repository at this point in the history
Signed-off-by: bsbds <[email protected]>
  • Loading branch information
bsbds committed Mar 19, 2024
1 parent e56a82d commit a95e03c
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 334 deletions.
129 changes: 53 additions & 76 deletions crates/curp/src/server/storage/wal/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use curp_external_api::LogIndex;
use serde::{de::DeserializeOwned, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
use tokio_util::codec::{Decoder, Encoder};

use super::{
error::{CorruptType, WALError},
framed::{Decoder, Encoder},
util::{get_checksum, validate_data},
};
use crate::log_entry::LogEntry;
Expand Down Expand Up @@ -104,18 +104,13 @@ where
{
type Error = io::Error;

fn encode(
&mut self,
frames: Vec<DataFrame<C>>,
dst: &mut bytes::BytesMut,
) -> Result<(), Self::Error> {
let frames_bytes: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect();
let commit_frame = CommitFrame::new_from_data(&frames_bytes);
/// Encodes a frame
fn encode(&mut self, frames: Vec<DataFrame<C>>) -> Result<Vec<u8>, Self::Error> {
let mut frame_data: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect();
let commit_frame = CommitFrame::new_from_data(&frame_data);
frame_data.extend_from_slice(&commit_frame.encode());

dst.extend(frames_bytes);
dst.extend(commit_frame.encode());

Ok(())
Ok(frame_data)
}
}

Expand All @@ -127,30 +122,34 @@ where

type Error = WALError;

fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop {
if let Some((frame, len)) = WALFrame::<C>::decode(src)? {
let decoded_bytes = src.split_to(len);
match frame {
WALFrame::Data(data) => {
self.frames.push(data);
self.hasher.update(decoded_bytes);
}
WALFrame::Commit(commit) => {
let frames_bytes: Vec<_> =
self.frames.iter().flat_map(DataFrame::encode).collect();
let checksum = self.hasher.clone().finalize();
self.hasher.reset();
if commit.validate(&checksum) {
return Ok(Some(self.frames.drain(..).collect()));
}
return Err(WALError::Corrupted(CorruptType::Checksum));
#[allow(clippy::arithmetic_side_effects)] // the arithmetic only used as slice indices
fn decode(&mut self, src: &[u8]) -> Result<(Self::Item, usize), Self::Error> {
let mut current = 0;
while current < src.len() {
let next = src.get(current..).ok_or(WALError::UnexpectedEof)?;
let Some((frame, len)) = WALFrame::<C>::decode(next)? else {
return Err(WALError::UnexpectedEof);
};
let decoded_bytes = src
.get(current..current + len)
.ok_or(WALError::UnexpectedEof)?;
current += len;
match frame {
WALFrame::Data(data) => {
self.frames.push(data);
self.hasher.update(decoded_bytes);
}
WALFrame::Commit(commit) => {
let checksum = self.hasher.clone().finalize();
self.hasher.reset();
if commit.validate(&checksum) {
return Ok((self.frames.drain(..).collect(), current));
}
return Err(WALError::Corrupted(CorruptType::Checksum));
}
} else {
return Ok(None);
}
}
Err(WALError::UnexpectedEof)
}
}

Expand Down Expand Up @@ -191,7 +190,7 @@ where
.unwrap_or_else(|_| unreachable!("this conversion will always succeed"));
let frame_type = header[0];
match frame_type {
INVALID => Err(WALError::MaybeEnded),
INVALID => Err(WALError::UnexpectedEof),
ENTRY => Self::decode_entry(header, &src[8..]),
SEAL => Self::decode_seal_index(header),
COMMIT => Self::decode_commit(&src[8..]),
Expand Down Expand Up @@ -323,25 +322,19 @@ mod tests {

#[tokio::test]
async fn frame_encode_decode_is_ok() {
let file = TokioFile::from(tempfile().unwrap());
let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
framed.send(vec![data_frame]).await.unwrap();
framed.send(vec![seal_frame]).await.unwrap();
framed.get_mut().flush().await;

let mut file = framed.into_inner();
file.seek(io::SeekFrom::Start(0)).await.unwrap();
let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded.extend_from_slice(&codec.encode(vec![seal_frame]).unwrap());

let data_frame_get = &framed.next().await.unwrap().unwrap()[0];
let seal_frame_get = &framed.next().await.unwrap().unwrap()[0];
let DataFrame::Entry(ref entry_get) = *data_frame_get else {
let (data_frame_get, len) = codec.decode(&encoded).unwrap();
let (seal_frame_get, _) = codec.decode(&encoded[len..]).unwrap();
let DataFrame::Entry(ref entry_get) = data_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};
let DataFrame::SealIndex(ref index) = *seal_frame_get else {
let DataFrame::SealIndex(ref index) = seal_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};

Expand All @@ -351,46 +344,30 @@ mod tests {

#[tokio::test]
async fn frame_zero_write_will_be_detected() {
let file = TokioFile::from(tempfile().unwrap());
let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
framed.send(vec![data_frame]).await.unwrap();
framed.get_mut().flush().await;

let mut file = framed.into_inner();
/// zero the first byte, it will reach a success state,
/// all following data will be truncated
file.seek(io::SeekFrom::Start(0)).await.unwrap();
file.write_u8(0).await;

file.seek(io::SeekFrom::Start(0)).await.unwrap();

let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded[0] = 0;

let err = framed.next().await.unwrap().unwrap_err();
assert!(matches!(err, WALError::MaybeEnded), "error {err} not match");
let err = codec.decode(&encoded).unwrap_err();
assert!(
matches!(err, WALError::UnexpectedEof),
"error {err} not match"
);
}

#[tokio::test]
async fn frame_corrupt_will_be_detected() {
let file = TokioFile::from(tempfile().unwrap());
let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
framed.send(vec![data_frame]).await.unwrap();
framed.get_mut().flush().await;

let mut file = framed.into_inner();
/// This will cause a failure state
file.seek(io::SeekFrom::Start(1)).await.unwrap();
file.write_u8(0).await;

file.seek(io::SeekFrom::Start(0)).await.unwrap();

let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded[1] = 0;

let err = framed.next().await.unwrap().unwrap_err();
let err = codec.decode(&encoded).unwrap_err();
assert!(
matches!(err, WALError::Corrupted(_)),
"error {err} not match"
Expand Down
7 changes: 2 additions & 5 deletions crates/curp/src/server/storage/wal/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@ use thiserror::Error;
/// Errors of the `WALStorage`
#[derive(Debug, Error)]
pub(crate) enum WALError {
/// The WAL segment might reach on end
///
/// NOTE: This exists because we cannot tell the difference between a corrupted WAL
/// and a normally ended WAL, as the segment files are all preallocated with zeros
/// Unexpected end of file of the WAL
#[error("WAL ended")]
MaybeEnded,
UnexpectedEof,
/// The WAL corrupt error
#[error("WAL corrupted: {0}")]
Corrupted(CorruptType),
Expand Down
22 changes: 22 additions & 0 deletions crates/curp/src/server/storage/wal/framed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::io;

/// Decoding of frames via buffers.
pub(super) trait Decoder {
/// The type of decoded frames.
type Item;

/// The type of unrecoverable frame decoding errors.
type Error: From<io::Error>;

/// Attempts to decode a frame from the provided buffer of bytes.
fn decode(&mut self, src: &[u8]) -> Result<(Self::Item, usize), Self::Error>;
}

/// Trait of helper objects to write out messages as bytes
pub(super) trait Encoder<Item> {
/// The type of encoding errors.
type Error: From<io::Error>;

/// Encodes a frame
fn encode(&mut self, item: Item) -> Result<Vec<u8>, Self::Error>;
}
3 changes: 3 additions & 0 deletions crates/curp/src/server/storage/wal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ mod segment;
/// File utils
mod util;

/// Framed traits
mod framed;

/// The magic of the WAL file
const WAL_MAGIC: u32 = 0xd86e_0be2;

Expand Down
22 changes: 9 additions & 13 deletions crates/curp/src/server/storage/wal/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub(super) struct FilePipeline {
/// The size of the temp file
file_size: u64,
/// The file receive stream
file_stream: RecvStream<'static, LockedFile>,
file_stream: flume::IntoIter<LockedFile>,
/// Stopped flag
stopped: Arc<AtomicBool>,
}
Expand Down Expand Up @@ -97,7 +97,7 @@ impl FilePipeline {
Ok(Self {
dir,
file_size,
file_stream: file_rx.into_stream(),
file_stream: file_rx.into_iter(),
stopped,
})
}
Expand Down Expand Up @@ -136,18 +136,14 @@ impl Drop for FilePipeline {
}
}

impl Stream for FilePipeline {
impl Iterator for FilePipeline {
type Item = io::Result<LockedFile>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
fn next(&mut self) -> Option<Self::Item> {
if self.stopped.load(Ordering::Relaxed) {
return Poll::Ready(None);
return None;
}

self.file_stream.poll_next_unpin(cx).map(|opt| opt.map(Ok))
self.file_stream.next().map(Ok)
}
}

Expand Down Expand Up @@ -175,11 +171,11 @@ mod tests {
let file = file.into_std();
assert_eq!(file.metadata().unwrap().len(), file_size,);
};
let file0 = pipeline.next().await.unwrap().unwrap();
let file0 = pipeline.next().unwrap().unwrap();
check_size(file0);
let file1 = pipeline.next().await.unwrap().unwrap();
let file1 = pipeline.next().unwrap().unwrap();
check_size(file1);
pipeline.stop();
assert!(pipeline.next().await.is_none());
assert!(pipeline.next().is_none());
}
}
Loading

0 comments on commit a95e03c

Please sign in to comment.