diff --git a/src/server/detail/save_stages_controller.cc b/src/server/detail/save_stages_controller.cc index 6aa43db8c478..05e77679ad4a 100644 --- a/src/server/detail/save_stages_controller.cc +++ b/src/server/detail/save_stages_controller.cc @@ -124,7 +124,7 @@ GenericError RdbSnapshot::Start(SaveMode save_mode, const std::string& path, } error_code RdbSnapshot::SaveBody() { - return saver_->SaveBody(&cntx_); + return saver_->SaveBody(cntx_); } error_code RdbSnapshot::WaitSnapshotInShard(EngineShard* shard) { diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc index c30f2ab51237..31afc8881800 100644 --- a/src/server/dflycmd.cc +++ b/src/server/dflycmd.cc @@ -586,7 +586,7 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha } if (flow->start_partial_sync_at.has_value()) - saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at); + saver->StartIncrementalSnapshotInShard(*flow->start_partial_sync_at, cntx, shard); else saver->StartSnapshotInShard(true, cntx, shard); diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc index 72f977c36edc..40741108e866 100644 --- a/src/server/rdb_save.cc +++ b/src/server/rdb_save.cc @@ -1158,7 +1158,7 @@ error_code AlignedBuffer::Flush() { return upstream_->Write(&ivec, 1); } -class RdbSaver::Impl { +class RdbSaver::Impl final : public SliceSnapshot::SnapshotDataConsumerInterface { private: void CleanShardSnapshots(); @@ -1171,11 +1171,16 @@ class RdbSaver::Impl { ~Impl(); void StartSnapshotting(bool stream_journal, Context* cntx, EngineShard* shard); - void StartIncrementalSnapshotting(Context* cntx, EngineShard* shard, LSN start_lsn); + void StartIncrementalSnapshotting(LSN start_lsn, Context* cntx, EngineShard* shard); void StopSnapshotting(EngineShard* shard); void WaitForSnapshottingFinish(EngineShard* shard); + // Pushes snapshot data. Called from SliceSnapshot + void ConsumeData(std::string data, Context* cntx) override; + // Finalizes the snapshot writing. Called from SliceSnapshot + void Finalize() override; + // used only for legacy rdb save flows. error_code ConsumeChannel(const Cancellation* cll); @@ -1208,8 +1213,6 @@ class RdbSaver::Impl { } private: - void PushSnapshotData(Context* cntx, string record); - void FinalizeSnapshotWriting(); error_code WriteRecord(io::Bytes src); unique_ptr& GetSnapshot(EngineShard* shard); @@ -1345,49 +1348,26 @@ error_code RdbSaver::Impl::WriteRecord(io::Bytes src) { return ec; } -void RdbSaver::Impl::PushSnapshotData(Context* cntx, string record) { - if (cntx->IsCancelled()) { - return; - } - if (channel_) { // Rdb write to channel - channel_->Push(record); - } else { // Write directly to socket - auto ec = WriteRecord(io::Buffer(record)); - if (ec) { - cntx->ReportError(ec); - } - } -} - -void RdbSaver::Impl::FinalizeSnapshotWriting() { - if (channel_) { - channel_->StartClosing(); - } -} - void RdbSaver::Impl::StartSnapshotting(bool stream_journal, Context* cntx, EngineShard* shard) { auto& s = GetSnapshot(shard); auto& db_slice = namespaces->GetDefaultNamespace().GetDbSlice(shard->shard_id()); - auto on_snapshot_finish = std::bind(&RdbSaver::Impl::FinalizeSnapshotWriting, this); - auto push_cb = std::bind(&RdbSaver::Impl::PushSnapshotData, this, cntx, std::placeholders::_1); - s = std::make_unique(&db_slice, compression_mode_, push_cb, on_snapshot_finish); + s = std::make_unique(compression_mode_, &db_slice, this, cntx); const auto allow_flush = (save_mode_ != SaveMode::RDB) ? SliceSnapshot::SnapshotFlush::kAllow : SliceSnapshot::SnapshotFlush::kDisallow; - s->Start(stream_journal, cntx->GetCancellation(), allow_flush); + s->Start(stream_journal, allow_flush); } -void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard, - LSN start_lsn) { +void RdbSaver::Impl::StartIncrementalSnapshotting(LSN start_lsn, Context* cntx, + EngineShard* shard) { auto& db_slice = namespaces->GetDefaultNamespace().GetDbSlice(shard->shard_id()); auto& s = GetSnapshot(shard); - auto on_finalize_cb = std::bind(&RdbSaver::Impl::FinalizeSnapshotWriting, this); - auto push_cb = std::bind(&RdbSaver::Impl::PushSnapshotData, this, cntx, std::placeholders::_1); - s = std::make_unique(&db_slice, compression_mode_, push_cb, on_finalize_cb); - s->StartIncremental(cntx, start_lsn); + s = std::make_unique(compression_mode_, &db_slice, this, cntx); + + s->StartIncremental(start_lsn); } // called on save flow @@ -1397,6 +1377,26 @@ void RdbSaver::Impl::WaitForSnapshottingFinish(EngineShard* shard) { snapshot->WaitSnapshotting(); } +void RdbSaver::Impl::ConsumeData(std::string data, Context* cntx) { + if (cntx->IsCancelled()) { + return; + } + if (channel_) { // Rdb write to channel + channel_->Push(std::move(data)); + } else { // Write directly to socket + auto ec = WriteRecord(io::Buffer(data)); + if (ec) { + cntx->ReportError(ec); + } + } +} + +void RdbSaver::Impl::Finalize() { + if (channel_) { + channel_->StartClosing(); + } +} + // called from replication flow void RdbSaver::Impl::StopSnapshotting(EngineShard* shard) { auto& snapshot = GetSnapshot(shard); @@ -1547,8 +1547,8 @@ void RdbSaver::StartSnapshotInShard(bool stream_journal, Context* cntx, EngineSh impl_->StartSnapshotting(stream_journal, cntx, shard); } -void RdbSaver::StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn) { - impl_->StartIncrementalSnapshotting(cntx, shard, start_lsn); +void RdbSaver::StartIncrementalSnapshotInShard(LSN start_lsn, Context* cntx, EngineShard* shard) { + impl_->StartIncrementalSnapshotting(start_lsn, cntx, shard); } error_code RdbSaver::WaitSnapshotInShard(EngineShard* shard) { @@ -1574,17 +1574,17 @@ error_code RdbSaver::SaveHeader(const GlobalData& glob_state) { return error_code{}; } -error_code RdbSaver::SaveBody(Context* cntx) { +error_code RdbSaver::SaveBody(const Context& cntx) { RETURN_ON_ERR(impl_->FlushSerializer()); if (save_mode_ == SaveMode::RDB) { VLOG(1) << "SaveBody , snapshots count: " << impl_->Size(); - error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation()); + error_code io_error = impl_->ConsumeChannel(cntx.GetCancellation()); if (io_error) { return io_error; } - if (cntx->GetError()) { - return cntx->GetError(); + if (cntx.GetError()) { + return cntx.GetError(); } } else { DCHECK(save_mode_ == SaveMode::SUMMARY); diff --git a/src/server/rdb_save.h b/src/server/rdb_save.h index bed9ebfdbc7f..fc22e7b7939f 100644 --- a/src/server/rdb_save.h +++ b/src/server/rdb_save.h @@ -92,7 +92,7 @@ class RdbSaver { void StartSnapshotInShard(bool stream_journal, Context* cntx, EngineShard* shard); // Send only the incremental snapshot since start_lsn. - void StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn); + void StartIncrementalSnapshotInShard(LSN start_lsn, Context* cntx, EngineShard* shard); // Stops full-sync serialization for replication in the shard's thread. std::error_code StopFullSyncInShard(EngineShard* shard); @@ -105,7 +105,7 @@ class RdbSaver { // Writes the RDB file into sink. Waits for the serialization to finish. // Called only for save rdb flow and save df on summary file. - std::error_code SaveBody(Context* cntx); + std::error_code SaveBody(const Context& cntx); // Fills freq_map with the histogram of rdb types. void FillFreqMap(RdbTypeFreqMap* freq_map); diff --git a/src/server/snapshot.cc b/src/server/snapshot.cc index e23bfe821d63..8c0f6d6cde34 100644 --- a/src/server/snapshot.cc +++ b/src/server/snapshot.cc @@ -37,13 +37,9 @@ constexpr size_t kMinBlobSize = 32_KB; } // namespace -SliceSnapshot::SliceSnapshot(DbSlice* slice, CompressionMode compression_mode, - std::function on_push_record, - std::function on_snapshot_finish) - : db_slice_(slice), - compression_mode_(compression_mode), - on_push_(on_push_record), - on_snapshot_finish_(on_snapshot_finish) { +SliceSnapshot::SliceSnapshot(CompressionMode compression_mode, DbSlice* slice, + SnapshotDataConsumerInterface* consumer, Context* cntx) + : db_slice_(slice), compression_mode_(compression_mode), consumer_(consumer), cntx_(cntx) { db_array_ = slice->databases(); tl_slice_snapshots.insert(this); } @@ -65,7 +61,7 @@ bool SliceSnapshot::IsSnaphotInProgress() { return tl_slice_snapshots.size() > 0; } -void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll, SnapshotFlush allow_flush) { +void SliceSnapshot::Start(bool stream_journal, SnapshotFlush allow_flush) { DCHECK(!snapshot_fb_.IsJoinable()); auto db_cb = absl::bind_front(&SliceSnapshot::OnDbChange, this); @@ -95,19 +91,18 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll, Snapshot VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_; - snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal, cll] { - IterateBucketsFb(cll, stream_journal); + snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal] { + this->IterateBucketsFb(stream_journal); db_slice_->UnregisterOnChange(snapshot_version_); - on_snapshot_finish_(); + consumer_->Finalize(); }); } -void SliceSnapshot::StartIncremental(Context* cntx, LSN start_lsn) { +void SliceSnapshot::StartIncremental(LSN start_lsn) { serializer_ = std::make_unique(compression_mode_); - snapshot_fb_ = fb2::Fiber("incremental_snapshot", [cntx, start_lsn, this] { - this->SwitchIncrementalFb(cntx, start_lsn); - }); + snapshot_fb_ = fb2::Fiber("incremental_snapshot", + [start_lsn, this] { this->SwitchIncrementalFb(start_lsn); }); } // Called only for replication use-case. @@ -144,7 +139,7 @@ void SliceSnapshot::FinalizeJournalStream(bool cancel) { // and survived until it finished. // Serializes all the entries with version less than snapshot_version_. -void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut) { +void SliceSnapshot::IterateBucketsFb(bool send_full_sync_cut) { { auto fiber_name = absl::StrCat("SliceSnapshot-", ProactorBase::me()->GetPoolIndex()); ThisFiber::SetName(std::move(fiber_name)); @@ -156,7 +151,7 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn } for (DbIndex db_indx = 0; db_indx < db_array_.size(); ++db_indx) { - if (cll->IsCancelled()) + if (cntx_->IsCancelled()) return; if (!db_array_[db_indx]) @@ -168,7 +163,7 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn VLOG(1) << "Start traversing " << pt->size() << " items for index " << db_indx; do { - if (cll->IsCancelled()) { + if (cntx_->IsCancelled()) { return; } @@ -204,7 +199,7 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn << stats_.loop_serialized << "/" << stats_.side_saved << "/" << stats_.savecb_calls; } -void SliceSnapshot::SwitchIncrementalFb(Context* cntx, LSN lsn) { +void SliceSnapshot::SwitchIncrementalFb(LSN lsn) { auto* journal = db_slice_->shard_owner()->journal(); DCHECK(journal); DCHECK_LE(lsn, journal->GetLsn()) << "The replica tried to sync from the future."; @@ -212,7 +207,7 @@ void SliceSnapshot::SwitchIncrementalFb(Context* cntx, LSN lsn) { VLOG(1) << "Starting incremental snapshot from lsn=" << lsn; // The replica sends the LSN of the next entry is wants to receive. - while (!cntx->IsCancelled() && journal->IsLSNInBuffer(lsn)) { + while (!cntx_->IsCancelled() && journal->IsLSNInBuffer(lsn)) { serializer_->WriteJournalEntry(journal->GetEntry(lsn)); PushSerialized(false); lsn++; @@ -239,7 +234,7 @@ void SliceSnapshot::SwitchIncrementalFb(Context* cntx, LSN lsn) { PushSerialized(true); } else { // We stopped but we didn't manage to send the whole stream. - cntx->ReportError( + cntx_->ReportError( std::make_error_code(errc::state_not_recoverable), absl::StrCat("Partial sync was unsuccessful because entry #", lsn, " was dropped from the buffer. Current lsn=", journal->GetLsn())); @@ -348,7 +343,7 @@ size_t SliceSnapshot::FlushSerialized(SerializerBase::FlushState flush_state) { seq_cond_.wait(lk, [&] { return id == this->last_pushed_id_ + 1; }); // Blocking point. - on_push_(std::move(sfile.val)); + consumer_->ConsumeData(std::move(sfile.val), cntx_); DCHECK_EQ(last_pushed_id_ + 1, id); last_pushed_id_ = id; diff --git a/src/server/snapshot.h b/src/server/snapshot.h index 98ef11c78166..da4be49bd147 100644 --- a/src/server/snapshot.h +++ b/src/server/snapshot.h @@ -49,8 +49,18 @@ struct Entry; // over the sink until explicitly stopped. class SliceSnapshot { public: - SliceSnapshot(DbSlice* slice, CompressionMode compression_mode, - std::function on_push, std::function on_snapshot_finish); + // Represents a target for receiving snapshot data. + struct SnapshotDataConsumerInterface { + virtual ~SnapshotDataConsumerInterface() = default; + + // Receives a chunk of snapshot data for processing + virtual void ConsumeData(std::string data, Context* cntx) = 0; + // Finalizes the snapshot writing + virtual void Finalize() = 0; + }; + + SliceSnapshot(CompressionMode compression_mode, DbSlice* slice, + SnapshotDataConsumerInterface* consumer, Context* cntx); ~SliceSnapshot(); static size_t GetThreadLocalMemoryUsage(); @@ -60,15 +70,14 @@ class SliceSnapshot { // In journal streaming mode it needs to be stopped by either Stop or Cancel. enum class SnapshotFlush { kAllow, kDisallow }; - void Start(bool stream_journal, const Cancellation* cll, - SnapshotFlush allow_flush = SnapshotFlush::kDisallow); + void Start(bool stream_journal, SnapshotFlush allow_flush = SnapshotFlush::kDisallow); // Initialize a snapshot that sends only the missing journal updates // since start_lsn and then registers a callback switches into the // journal streaming mode until stopped. // If we're slower than the buffer and can't continue, `Cancel()` is // called. - void StartIncremental(Context* cntx, LSN start_lsn); + void StartIncremental(LSN start_lsn); // Finalizes journal streaming writes. Only called for replication. // Blocking. Must be called from the Snapshot thread. @@ -83,10 +92,10 @@ class SliceSnapshot { private: // Main snapshotting fiber that iterates over all buckets in the db slice // and submits them to SerializeBucket. - void IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut); + void IterateBucketsFb(bool send_full_sync_cut); // A fiber function that switches to the incremental mode - void SwitchIncrementalFb(Context* cntx, LSN lsn); + void SwitchIncrementalFb(LSN lsn); // Called on traversing cursor by IterateBucketsFb. bool BucketSaveCb(PrimeTable::bucket_iterator it); @@ -171,8 +180,8 @@ class SliceSnapshot { ThreadLocalMutex big_value_mu_; - std::function on_push_; - std::function on_snapshot_finish_; + SnapshotDataConsumerInterface* consumer_; + Context* cntx_; }; } // namespace dfly