Skip to content

Commit

Permalink
refactor(rdb_saver): Add SnapshotDataConsumer to SliceSnapshot (#4287)
Browse files Browse the repository at this point in the history
* refactor(rdb_saver): Add SnapshotDataConsumer to SliceSnapshot

fixes #4218

Signed-off-by: Stepan Bagritsevich <[email protected]>

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <[email protected]>

---------

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan authored Dec 23, 2024
1 parent d162094 commit 612d50d
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 75 deletions.
2 changes: 1 addition & 1 deletion src/server/detail/save_stages_controller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ GenericError RdbSnapshot::Start(SaveMode save_mode, const std::string& path,
}

error_code RdbSnapshot::SaveBody() {
return saver_->SaveBody(&cntx_);
return saver_->SaveBody(cntx_);
}

error_code RdbSnapshot::WaitSnapshotInShard(EngineShard* shard) {
Expand Down
2 changes: 1 addition & 1 deletion src/server/dflycmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
}

if (flow->start_partial_sync_at.has_value())
saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
saver->StartIncrementalSnapshotInShard(*flow->start_partial_sync_at, cntx, shard);
else
saver->StartSnapshotInShard(true, cntx, shard);

Expand Down
80 changes: 40 additions & 40 deletions src/server/rdb_save.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ error_code AlignedBuffer::Flush() {
return upstream_->Write(&ivec, 1);
}

class RdbSaver::Impl {
class RdbSaver::Impl final : public SliceSnapshot::SnapshotDataConsumerInterface {
private:
void CleanShardSnapshots();

Expand All @@ -1078,11 +1078,16 @@ class RdbSaver::Impl {
~Impl();

void StartSnapshotting(bool stream_journal, Context* cntx, EngineShard* shard);
void StartIncrementalSnapshotting(Context* cntx, EngineShard* shard, LSN start_lsn);
void StartIncrementalSnapshotting(LSN start_lsn, Context* cntx, EngineShard* shard);

void StopSnapshotting(EngineShard* shard);
void WaitForSnapshottingFinish(EngineShard* shard);

// Pushes snapshot data. Called from SliceSnapshot
void ConsumeData(std::string data, Context* cntx) override;
// Finalizes the snapshot writing. Called from SliceSnapshot
void Finalize() override;

// used only for legacy rdb save flows.
error_code ConsumeChannel(const Cancellation* cll);

Expand Down Expand Up @@ -1115,8 +1120,6 @@ class RdbSaver::Impl {
}

private:
void PushSnapshotData(Context* cntx, string record);
void FinalizeSnapshotWriting();
error_code WriteRecord(io::Bytes src);

unique_ptr<SliceSnapshot>& GetSnapshot(EngineShard* shard);
Expand Down Expand Up @@ -1252,49 +1255,26 @@ error_code RdbSaver::Impl::WriteRecord(io::Bytes src) {
return ec;
}

void RdbSaver::Impl::PushSnapshotData(Context* cntx, string record) {
if (cntx->IsCancelled()) {
return;
}
if (channel_) { // Rdb write to channel
channel_->Push(record);
} else { // Write directly to socket
auto ec = WriteRecord(io::Buffer(record));
if (ec) {
cntx->ReportError(ec);
}
}
}

void RdbSaver::Impl::FinalizeSnapshotWriting() {
if (channel_) {
channel_->StartClosing();
}
}

void RdbSaver::Impl::StartSnapshotting(bool stream_journal, Context* cntx, EngineShard* shard) {
auto& s = GetSnapshot(shard);
auto& db_slice = namespaces->GetDefaultNamespace().GetDbSlice(shard->shard_id());
auto on_snapshot_finish = std::bind(&RdbSaver::Impl::FinalizeSnapshotWriting, this);
auto push_cb = std::bind(&RdbSaver::Impl::PushSnapshotData, this, cntx, std::placeholders::_1);

s = std::make_unique<SliceSnapshot>(&db_slice, compression_mode_, push_cb, on_snapshot_finish);
s = std::make_unique<SliceSnapshot>(compression_mode_, &db_slice, this, cntx);

const auto allow_flush = (save_mode_ != SaveMode::RDB) ? SliceSnapshot::SnapshotFlush::kAllow
: SliceSnapshot::SnapshotFlush::kDisallow;

s->Start(stream_journal, cntx->GetCancellation(), allow_flush);
s->Start(stream_journal, allow_flush);
}

void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard,
LSN start_lsn) {
void RdbSaver::Impl::StartIncrementalSnapshotting(LSN start_lsn, Context* cntx,
EngineShard* shard) {
auto& db_slice = namespaces->GetDefaultNamespace().GetDbSlice(shard->shard_id());
auto& s = GetSnapshot(shard);
auto on_finalize_cb = std::bind(&RdbSaver::Impl::FinalizeSnapshotWriting, this);
auto push_cb = std::bind(&RdbSaver::Impl::PushSnapshotData, this, cntx, std::placeholders::_1);
s = std::make_unique<SliceSnapshot>(&db_slice, compression_mode_, push_cb, on_finalize_cb);

s->StartIncremental(cntx, start_lsn);
s = std::make_unique<SliceSnapshot>(compression_mode_, &db_slice, this, cntx);

s->StartIncremental(start_lsn);
}

// called on save flow
Expand All @@ -1304,6 +1284,26 @@ void RdbSaver::Impl::WaitForSnapshottingFinish(EngineShard* shard) {
snapshot->WaitSnapshotting();
}

void RdbSaver::Impl::ConsumeData(std::string data, Context* cntx) {
if (cntx->IsCancelled()) {
return;
}
if (channel_) { // Rdb write to channel
channel_->Push(std::move(data));
} else { // Write directly to socket
auto ec = WriteRecord(io::Buffer(data));
if (ec) {
cntx->ReportError(ec);
}
}
}

void RdbSaver::Impl::Finalize() {
if (channel_) {
channel_->StartClosing();
}
}

// called from replication flow
void RdbSaver::Impl::StopSnapshotting(EngineShard* shard) {
auto& snapshot = GetSnapshot(shard);
Expand Down Expand Up @@ -1462,8 +1462,8 @@ void RdbSaver::StartSnapshotInShard(bool stream_journal, Context* cntx, EngineSh
impl_->StartSnapshotting(stream_journal, cntx, shard);
}

void RdbSaver::StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn) {
impl_->StartIncrementalSnapshotting(cntx, shard, start_lsn);
void RdbSaver::StartIncrementalSnapshotInShard(LSN start_lsn, Context* cntx, EngineShard* shard) {
impl_->StartIncrementalSnapshotting(start_lsn, cntx, shard);
}

error_code RdbSaver::WaitSnapshotInShard(EngineShard* shard) {
Expand All @@ -1489,17 +1489,17 @@ error_code RdbSaver::SaveHeader(const GlobalData& glob_state) {
return error_code{};
}

error_code RdbSaver::SaveBody(Context* cntx) {
error_code RdbSaver::SaveBody(const Context& cntx) {
RETURN_ON_ERR(impl_->FlushSerializer());

if (save_mode_ == SaveMode::RDB) {
VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation());
error_code io_error = impl_->ConsumeChannel(cntx.GetCancellation());
if (io_error) {
return io_error;
}
if (cntx->GetError()) {
return cntx->GetError();
if (cntx.GetError()) {
return cntx.GetError();
}
} else {
DCHECK(save_mode_ == SaveMode::SUMMARY);
Expand Down
4 changes: 2 additions & 2 deletions src/server/rdb_save.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class RdbSaver {
void StartSnapshotInShard(bool stream_journal, Context* cntx, EngineShard* shard);

// Send only the incremental snapshot since start_lsn.
void StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn);
void StartIncrementalSnapshotInShard(LSN start_lsn, Context* cntx, EngineShard* shard);

// Stops full-sync serialization for replication in the shard's thread.
std::error_code StopFullSyncInShard(EngineShard* shard);
Expand All @@ -107,7 +107,7 @@ class RdbSaver {

// Writes the RDB file into sink. Waits for the serialization to finish.
// Called only for save rdb flow and save df on summary file.
std::error_code SaveBody(Context* cntx);
std::error_code SaveBody(const Context& cntx);

// Fills freq_map with the histogram of rdb types.
void FillFreqMap(RdbTypeFreqMap* freq_map);
Expand Down
39 changes: 17 additions & 22 deletions src/server/snapshot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,9 @@ constexpr size_t kMinBlobSize = 32_KB;

} // namespace

SliceSnapshot::SliceSnapshot(DbSlice* slice, CompressionMode compression_mode,
std::function<void(std::string)> on_push_record,
std::function<void()> on_snapshot_finish)
: db_slice_(slice),
compression_mode_(compression_mode),
on_push_(on_push_record),
on_snapshot_finish_(on_snapshot_finish) {
SliceSnapshot::SliceSnapshot(CompressionMode compression_mode, DbSlice* slice,
SnapshotDataConsumerInterface* consumer, Context* cntx)
: db_slice_(slice), compression_mode_(compression_mode), consumer_(consumer), cntx_(cntx) {
db_array_ = slice->databases();
tl_slice_snapshots.insert(this);
}
Expand All @@ -65,7 +61,7 @@ bool SliceSnapshot::IsSnaphotInProgress() {
return tl_slice_snapshots.size() > 0;
}

void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll, SnapshotFlush allow_flush) {
void SliceSnapshot::Start(bool stream_journal, SnapshotFlush allow_flush) {
DCHECK(!snapshot_fb_.IsJoinable());

auto db_cb = absl::bind_front(&SliceSnapshot::OnDbChange, this);
Expand Down Expand Up @@ -95,19 +91,18 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll, Snapshot

VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_;

snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal, cll] {
IterateBucketsFb(cll, stream_journal);
snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal] {
this->IterateBucketsFb(stream_journal);
db_slice_->UnregisterOnChange(snapshot_version_);
on_snapshot_finish_();
consumer_->Finalize();
});
}

void SliceSnapshot::StartIncremental(Context* cntx, LSN start_lsn) {
void SliceSnapshot::StartIncremental(LSN start_lsn) {
serializer_ = std::make_unique<RdbSerializer>(compression_mode_);

snapshot_fb_ = fb2::Fiber("incremental_snapshot", [cntx, start_lsn, this] {
this->SwitchIncrementalFb(cntx, start_lsn);
});
snapshot_fb_ = fb2::Fiber("incremental_snapshot",
[start_lsn, this] { this->SwitchIncrementalFb(start_lsn); });
}

// Called only for replication use-case.
Expand Down Expand Up @@ -144,7 +139,7 @@ void SliceSnapshot::FinalizeJournalStream(bool cancel) {
// and survived until it finished.

// Serializes all the entries with version less than snapshot_version_.
void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut) {
void SliceSnapshot::IterateBucketsFb(bool send_full_sync_cut) {
{
auto fiber_name = absl::StrCat("SliceSnapshot-", ProactorBase::me()->GetPoolIndex());
ThisFiber::SetName(std::move(fiber_name));
Expand All @@ -156,7 +151,7 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn
}

for (DbIndex db_indx = 0; db_indx < db_array_.size(); ++db_indx) {
if (cll->IsCancelled())
if (cntx_->IsCancelled())
return;

if (!db_array_[db_indx])
Expand All @@ -168,7 +163,7 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn

VLOG(1) << "Start traversing " << pt->size() << " items for index " << db_indx;
do {
if (cll->IsCancelled()) {
if (cntx_->IsCancelled()) {
return;
}

Expand Down Expand Up @@ -204,15 +199,15 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn
<< stats_.loop_serialized << "/" << stats_.side_saved << "/" << stats_.savecb_calls;
}

void SliceSnapshot::SwitchIncrementalFb(Context* cntx, LSN lsn) {
void SliceSnapshot::SwitchIncrementalFb(LSN lsn) {
auto* journal = db_slice_->shard_owner()->journal();
DCHECK(journal);
DCHECK_LE(lsn, journal->GetLsn()) << "The replica tried to sync from the future.";

VLOG(1) << "Starting incremental snapshot from lsn=" << lsn;

// The replica sends the LSN of the next entry is wants to receive.
while (!cntx->IsCancelled() && journal->IsLSNInBuffer(lsn)) {
while (!cntx_->IsCancelled() && journal->IsLSNInBuffer(lsn)) {
serializer_->WriteJournalEntry(journal->GetEntry(lsn));
PushSerialized(false);
lsn++;
Expand All @@ -239,7 +234,7 @@ void SliceSnapshot::SwitchIncrementalFb(Context* cntx, LSN lsn) {
PushSerialized(true);
} else {
// We stopped but we didn't manage to send the whole stream.
cntx->ReportError(
cntx_->ReportError(
std::make_error_code(errc::state_not_recoverable),
absl::StrCat("Partial sync was unsuccessful because entry #", lsn,
" was dropped from the buffer. Current lsn=", journal->GetLsn()));
Expand Down Expand Up @@ -348,7 +343,7 @@ size_t SliceSnapshot::FlushSerialized(SerializerBase::FlushState flush_state) {
seq_cond_.wait(lk, [&] { return id == this->last_pushed_id_ + 1; });

// Blocking point.
on_push_(std::move(sfile.val));
consumer_->ConsumeData(std::move(sfile.val), cntx_);

DCHECK_EQ(last_pushed_id_ + 1, id);
last_pushed_id_ = id;
Expand Down
27 changes: 18 additions & 9 deletions src/server/snapshot.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,18 @@ struct Entry;
// over the sink until explicitly stopped.
class SliceSnapshot {
public:
SliceSnapshot(DbSlice* slice, CompressionMode compression_mode,
std::function<void(std::string)> on_push, std::function<void()> on_snapshot_finish);
// Represents a target for receiving snapshot data.
struct SnapshotDataConsumerInterface {
virtual ~SnapshotDataConsumerInterface() = default;

// Receives a chunk of snapshot data for processing
virtual void ConsumeData(std::string data, Context* cntx) = 0;
// Finalizes the snapshot writing
virtual void Finalize() = 0;
};

SliceSnapshot(CompressionMode compression_mode, DbSlice* slice,
SnapshotDataConsumerInterface* consumer, Context* cntx);
~SliceSnapshot();

static size_t GetThreadLocalMemoryUsage();
Expand All @@ -60,15 +70,14 @@ class SliceSnapshot {
// In journal streaming mode it needs to be stopped by either Stop or Cancel.
enum class SnapshotFlush { kAllow, kDisallow };

void Start(bool stream_journal, const Cancellation* cll,
SnapshotFlush allow_flush = SnapshotFlush::kDisallow);
void Start(bool stream_journal, SnapshotFlush allow_flush = SnapshotFlush::kDisallow);

// Initialize a snapshot that sends only the missing journal updates
// since start_lsn and then registers a callback switches into the
// journal streaming mode until stopped.
// If we're slower than the buffer and can't continue, `Cancel()` is
// called.
void StartIncremental(Context* cntx, LSN start_lsn);
void StartIncremental(LSN start_lsn);

// Finalizes journal streaming writes. Only called for replication.
// Blocking. Must be called from the Snapshot thread.
Expand All @@ -83,10 +92,10 @@ class SliceSnapshot {
private:
// Main snapshotting fiber that iterates over all buckets in the db slice
// and submits them to SerializeBucket.
void IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut);
void IterateBucketsFb(bool send_full_sync_cut);

// A fiber function that switches to the incremental mode
void SwitchIncrementalFb(Context* cntx, LSN lsn);
void SwitchIncrementalFb(LSN lsn);

// Called on traversing cursor by IterateBucketsFb.
bool BucketSaveCb(PrimeTable::bucket_iterator it);
Expand Down Expand Up @@ -171,8 +180,8 @@ class SliceSnapshot {

ThreadLocalMutex big_value_mu_;

std::function<void(std::string)> on_push_;
std::function<void()> on_snapshot_finish_;
SnapshotDataConsumerInterface* consumer_;
Context* cntx_;
};

} // namespace dfly

0 comments on commit 612d50d

Please sign in to comment.