Skip to content

Commit 06d9865

Browse files
committed
fix failing tests
1 parent b560110 commit 06d9865

File tree

3 files changed

+126
-77
lines changed

3 files changed

+126
-77
lines changed

fetch/handler.go

+16-19
Original file line numberDiff line numberDiff line change
@@ -101,26 +101,23 @@ func (h *handler) handleLegacyMaliciousIDsReqStream(ctx context.Context, _ p2p.P
101101
}
102102

103103
func (h *handler) handleMaliciousIDsReqStream(ctx context.Context, _ p2p.Peer, _ []byte, s io.ReadWriter) error {
104-
tx, err := h.db.TxImmediate(ctx)
105-
if err != nil {
106-
h.logger.Debug("failed to stream malicious node IDs", log.ZContext(ctx), zap.Error(err))
107-
return nil
108-
}
109-
defer tx.Release()
110-
total, err := malfeasance.Count(tx)
104+
err := h.streamIDs(ctx, s, func(cbk retrieveCallback) error {
105+
return h.db.WithTxImmediate(ctx, func(tx sql.Transaction) error {
106+
total, err := malfeasance.Count(tx)
107+
if err != nil {
108+
return fmt.Errorf("counting malicious nodes: %w", err)
109+
}
110+
return malfeasance.IterateOps(tx, builder.Operations{},
111+
func(nodeID types.NodeID, _ []byte, _ int, _ time.Time) bool {
112+
if err := cbk(total, nodeID.Bytes()); err != nil {
113+
h.logger.Debug("failed to stream malicious node IDs", log.ZContext(ctx), zap.Error(err))
114+
return false
115+
}
116+
return true
117+
})
118+
})
119+
})
111120
if err != nil {
112-
return fmt.Errorf("counting malicious nodes: %w", err)
113-
}
114-
if err := h.streamIDs(ctx, s, func(cbk retrieveCallback) error {
115-
return malfeasance.IterateOps(tx, builder.Operations{},
116-
func(nodeID types.NodeID, _ []byte, _ int, _ time.Time) bool {
117-
if err := cbk(total, nodeID.Bytes()); err != nil {
118-
h.logger.Debug("failed to stream malicious node IDs", log.ZContext(ctx), zap.Error(err))
119-
return false
120-
}
121-
return true
122-
})
123-
}); err != nil {
124121
h.logger.Debug("failed to stream malicious node IDs", log.ZContext(ctx), zap.Error(err))
125122
}
126123
return nil

syncer/malsync/syncer.go

+58-56
Original file line numberDiff line numberDiff line change
@@ -539,36 +539,37 @@ func (s *Syncer) downloadLegacyMalfeasanceProofs(ctx context.Context, initial bo
539539
}
540540

541541
nothingToDownload = len(batch) == 0
542-
if len(batch) != 0 {
543-
s.logger.Debug("retrieving legacy malicious identities",
544-
log.ZContext(ctx),
545-
zap.Int("count", len(batch)),
546-
)
547-
if err := s.fetcher.LegacyMalfeasanceProofs(ctx, batch); err != nil {
548-
if errors.Is(err, context.Canceled) {
549-
return ctx.Err()
550-
}
551-
s.logger.Debug("failed to download malfeasance proofs",
552-
log.ZContext(ctx),
553-
log.NiceZapError(err),
554-
)
555-
}
556-
batchError := &fetch.BatchError{}
557-
if errors.As(err, &batchError) {
558-
for hash, err := range batchError.Errors {
559-
nodeID := types.NodeID(hash)
560-
switch {
561-
case !sst.has(nodeID):
562-
continue
563-
case errors.Is(err, pubsub.ErrValidationReject):
564-
sst.rejected(nodeID)
565-
default:
566-
sst.failed(nodeID)
567-
}
542+
if len(batch) == 0 {
543+
s.logger.Debug("no new legacy malicious identities", log.ZContext(ctx))
544+
continue
545+
}
546+
547+
s.logger.Debug("retrieving legacy malicious identities",
548+
log.ZContext(ctx),
549+
zap.Int("count", len(batch)),
550+
)
551+
batchError := &fetch.BatchError{}
552+
err = s.fetcher.LegacyMalfeasanceProofs(ctx, batch)
553+
switch {
554+
case errors.Is(err, context.Canceled):
555+
return ctx.Err()
556+
case errors.As(err, &batchError):
557+
for hash, err := range batchError.Errors {
558+
nodeID := types.NodeID(hash)
559+
switch {
560+
case !sst.has(nodeID):
561+
continue
562+
case errors.Is(err, pubsub.ErrValidationReject):
563+
sst.rejected(nodeID)
564+
default:
565+
sst.failed(nodeID)
568566
}
569567
}
570-
} else {
571-
s.logger.Debug("no new legacy malicious identities", log.ZContext(ctx))
568+
case err != nil:
569+
s.logger.Debug("failed to download malfeasance proofs",
570+
log.ZContext(ctx),
571+
log.NiceZapError(err),
572+
)
572573
}
573574
}
574575
}
@@ -634,36 +635,37 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up
634635
}
635636

636637
nothingToDownload = len(batch) == 0
637-
if len(batch) != 0 {
638-
s.logger.Debug("retrieving malicious identities",
639-
log.ZContext(ctx),
640-
zap.Int("count", len(batch)),
641-
)
642-
if err := s.fetcher.MalfeasanceProofs(ctx, batch); err != nil {
643-
if errors.Is(err, context.Canceled) {
644-
return ctx.Err()
645-
}
646-
s.logger.Debug("failed to download malfeasance proofs",
647-
log.ZContext(ctx),
648-
log.NiceZapError(err),
649-
)
650-
}
651-
batchError := &fetch.BatchError{}
652-
if errors.As(err, &batchError) {
653-
for hash, err := range batchError.Errors {
654-
nodeID := types.NodeID(hash)
655-
switch {
656-
case !sst.has(nodeID):
657-
continue
658-
case errors.Is(err, pubsub.ErrValidationReject):
659-
sst.rejected(nodeID)
660-
default:
661-
sst.failed(nodeID)
662-
}
638+
if len(batch) == 0 {
639+
s.logger.Debug("no new malicious identities", log.ZContext(ctx))
640+
continue
641+
}
642+
643+
s.logger.Debug("retrieving malicious identities",
644+
log.ZContext(ctx),
645+
zap.Int("count", len(batch)),
646+
)
647+
batchError := &fetch.BatchError{}
648+
err = s.fetcher.MalfeasanceProofs(ctx, batch)
649+
switch {
650+
case errors.Is(err, context.Canceled):
651+
return ctx.Err()
652+
case errors.As(err, &batchError):
653+
for hash, err := range batchError.Errors {
654+
nodeID := types.NodeID(hash)
655+
switch {
656+
case !sst.has(nodeID):
657+
continue
658+
case errors.Is(err, pubsub.ErrValidationReject):
659+
sst.rejected(nodeID)
660+
default:
661+
sst.failed(nodeID)
663662
}
664663
}
665-
} else {
666-
s.logger.Debug("no new malicious identities", log.ZContext(ctx))
664+
case err != nil:
665+
s.logger.Debug("failed to download malfeasance proofs",
666+
log.ZContext(ctx),
667+
log.NiceZapError(err),
668+
)
667669
}
668670
}
669671
}

syncer/malsync/syncer_test.go

+52-2
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ func TestSyncer(t *testing.T) {
405405
require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd))
406406
require.Equal(t, 1, tester.peerErrCount.n)
407407
})
408-
t.Run("skip hashes after max retries", func(t *testing.T) {
408+
t.Run("skip hashes after max retries - legacy", func(t *testing.T) {
409409
cfg := DefaultConfig()
410410
cfg.RequestsLimit = 3
411411
tester := newTester(t, cfg)
@@ -430,7 +430,32 @@ func TestSyncer(t *testing.T) {
430430
// second call does nothing after recent sync
431431
require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd))
432432
})
433-
t.Run("skip hashes after validation reject", func(t *testing.T) {
433+
t.Run("skip hashes after max retries", func(t *testing.T) {
434+
cfg := DefaultConfig()
435+
cfg.RequestsLimit = 3
436+
tester := newTester(t, cfg)
437+
tester.expectPeers(tester.peers)
438+
tester.expectMaliciousIDs()
439+
tester.expectProofs(map[types.NodeID]error{
440+
nid("102"): errors.New("fail"),
441+
})
442+
epochStart := tester.clock.Now().Truncate(time.Second)
443+
epochEnd := epochStart.Add(10 * time.Minute)
444+
require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd))
445+
require.ElementsMatch(t, []types.NodeID{
446+
nid("101"), nid("103"), nid("104"),
447+
}, maps.Keys(tester.received))
448+
require.Equal(t, map[types.NodeID]int{
449+
nid("101"): 1,
450+
nid("102"): tester.cfg.RequestsLimit,
451+
nid("103"): 1,
452+
nid("104"): 1,
453+
}, tester.attempts)
454+
tester.clock.Advance(1 * time.Minute)
455+
// second call does nothing after recent sync
456+
require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd))
457+
})
458+
t.Run("skip hashes after validation reject - legacy", func(t *testing.T) {
434459
tester := newTester(t, DefaultConfig())
435460
tester.expectPeers(tester.peers)
436461
tester.expectLegacyMaliciousIDs()
@@ -455,4 +480,29 @@ func TestSyncer(t *testing.T) {
455480
// second call does nothing after recent sync
456481
require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd))
457482
})
483+
t.Run("skip hashes after validation reject", func(t *testing.T) {
484+
tester := newTester(t, DefaultConfig())
485+
tester.expectPeers(tester.peers)
486+
tester.expectMaliciousIDs()
487+
tester.expectProofs(map[types.NodeID]error{
488+
// note that "102" comes just from a single peer
489+
// (see expectMaliciousIDs)
490+
nid("102"): pubsub.ErrValidationReject,
491+
})
492+
epochStart := tester.clock.Now().Truncate(time.Second)
493+
epochEnd := epochStart.Add(10 * time.Minute)
494+
require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd))
495+
require.ElementsMatch(t, []types.NodeID{
496+
nid("101"), nid("103"), nid("104"),
497+
}, maps.Keys(tester.received))
498+
require.Equal(t, map[types.NodeID]int{
499+
nid("101"): 1,
500+
nid("102"): 1,
501+
nid("103"): 1,
502+
nid("104"): 1,
503+
}, tester.attempts)
504+
tester.clock.Advance(1 * time.Minute)
505+
// second call does nothing after recent sync
506+
require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd))
507+
})
458508
}

0 commit comments

Comments
 (0)