From a50872954a1d4ee50d2ae451b3682bae5a8e7e73 Mon Sep 17 00:00:00 2001 From: Termina1 Date: Thu, 5 Sep 2024 22:30:13 +0300 Subject: [PATCH] wait drain state uses context and correctly canceled --- chotki_test.go | 2 +- sync.go | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/chotki_test.go b/chotki_test.go index 47ade91..15f34f5 100644 --- a/chotki_test.go +++ b/chotki_test.go @@ -147,7 +147,7 @@ func TestChotki_SyncLivePingsOk(t *testing.T) { synca := Syncer{ Host: a, PingPeriod: 100 * time.Millisecond, - PingWait: 100 * time.Millisecond, + PingWait: 200 * time.Millisecond, Mode: SyncRWLive, Name: "a", Src: a.src, log: utils.NewDefaultLogger(slog.LevelDebug), diff --git a/sync.go b/sync.go index cdf93ad..cccfe87 100644 --- a/sync.go +++ b/sync.go @@ -2,6 +2,7 @@ package chotki import ( "bytes" + "context" "errors" "io" "sync" @@ -165,13 +166,14 @@ func (sync *Syncer) Feed() (recs protocol.Records, err error) { sync.SetFeedState(SendDiff) case SendDiff: - + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() select { case <-time.After(sync.PingWait): sync.log.Error("handshake took too long", "name", sync.Name) sync.SetFeedState(SendEOF) return - case <-sync.WaitDrainState(SendDiff): + case <-sync.WaitDrainState(ctx, SendDiff): } recs, err = sync.FeedBlockDiff() if err == io.EOF { @@ -231,7 +233,7 @@ func (sync *Syncer) Feed() (recs protocol.Records, err error) { timer := time.AfterFunc(time.Second, func() { sync.SetDrainState(SendNone) }) - <-sync.WaitDrainState(SendNone) + <-sync.WaitDrainState(context.Background(), SendNone) timer.Stop() err = io.EOF } @@ -361,18 +363,27 @@ func (sync *Syncer) SetDrainState(state SyncState) { sync.lock.Unlock() } -func (sync *Syncer) WaitDrainState(state SyncState) chan SyncState { +func (sync *Syncer) WaitDrainState(ctx context.Context, state SyncState) chan SyncState { res := make(chan SyncState) go func() { + <-ctx.Done() + sync.cond.Broadcast() + }() + go func() { + defer close(res) sync.lock.Lock() + defer sync.lock.Unlock() if sync.cond.L == nil { sync.cond.L = &sync.lock } for sync.drainState < state { + if ctx.Err() != nil { + return + } sync.cond.Wait() } ds := sync.drainState - sync.lock.Unlock() + res <- ds }() return res