diff --git a/src/dht.cpp b/src/dht.cpp index f969de940..00675d99e 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -93,7 +93,7 @@ Dht::shutdown(ShutdownCallback cb, bool stop) r.second.done_cb(false, {}); sr.second->callbacks.clear(); for (const auto& a : sr.second->announce) { - if (a.callback) a.callback(false, {}); + for (auto& cb : a.callbacks) cb(false, {}); } sr.second->announce.clear(); sr.second->listeners.clear(); diff --git a/src/search.h b/src/search.h index 99b6e3435..52d991b7f 100644 --- a/src/search.h +++ b/src/search.h @@ -45,7 +45,7 @@ struct Dht::Announce { bool permanent; Sp value; time_point created; - DoneCallback callback; + std::vector callbacks; }; struct Dht::SearchNode { @@ -442,8 +442,9 @@ struct Dht::Search { g.second.done_cb = {}; } for (auto& a : announce) { - a.callback(false, {}); - a.callback = {}; + for (auto& cb : a.callbacks) + cb(false, {}); + a.callbacks.clear(); } } @@ -627,7 +628,9 @@ struct Dht::Search { return a.value->id == value->id; }); if (a_sr == announce.end()) { - announce.emplace_back(Announce {permanent, value, created, callback}); + auto& a = announce.emplace_back(Announce {permanent, value, created, {}} ); + if (callback) + a.callbacks.emplace_back(std::move(callback)); for (auto& n : nodes) { n->probe_query.reset(); n->acked[value->id].req.reset(); @@ -635,24 +638,26 @@ struct Dht::Search { } else { a_sr->permanent = permanent; a_sr->created = created; - if (a_sr->value != value) { + if (a_sr->value != value && *a_sr->value != *value) { + // Value is updated, previous ops are failed a_sr->value = value; for (auto& n : nodes) { n->acked[value->id].req.reset(); n->probe_query.reset(); } - } - if (isAnnounced(value->id)) { - if (a_sr->callback) - a_sr->callback(true, {}); - a_sr->callback = {}; + for (auto& cb: a_sr->callbacks) + cb(false, {}); + a_sr->callbacks.clear(); + if (callback) + a_sr->callbacks.emplace_back(std::move(callback)); + } else if (isAnnounced(value->id)) { + // Same value, already announced if (callback) callback(true, {}); - return; } else { - if (a_sr->callback) - a_sr->callback(false, {}); - a_sr->callback = callback; + // Same value, not announced yet + if (callback) + a_sr->callbacks.emplace_back(std::move(callback)); } } } @@ -722,8 +727,8 @@ struct Dht::Search { std::vector a_cbs; a_cbs.reserve(announce.size()); for (auto ait = announce.begin() ; ait != announce.end(); ) { - if (ait->callback) - a_cbs.emplace_back(std::move(ait->callback)); + a_cbs.insert(a_cbs.end(), std::make_move_iterator(ait->callbacks.begin()), std::make_move_iterator(ait->callbacks.end())); + ait->callbacks.clear(); if (not ait->permanent) ait = announce.erase(ait); else @@ -747,9 +752,11 @@ struct Dht::Search { if (vid != Value::INVALID_ID and (!a.value || a.value->id != vid)) return true; if (isAnnounced(a.value->id)) { - if (a.callback) { - a.callback(true, getNodes()); - a.callback = nullptr; + if (!a.callbacks.empty()) { + const auto& nodes = getNodes(); + for (auto& cb : a.callbacks) + cb(true, nodes); + a.callbacks.clear(); } if (not a.permanent) return false; diff --git a/tests/dhtrunnertester.cpp b/tests/dhtrunnertester.cpp index 01b8a347b..a882794f7 100644 --- a/tests/dhtrunnertester.cpp +++ b/tests/dhtrunnertester.cpp @@ -90,6 +90,59 @@ DhtRunnerTester::testGetPut() { CPPUNIT_ASSERT(vals.front()->data == val_data); } +void +DhtRunnerTester::testPutDuplicate() { + auto key = dht::InfoHash::get("123"); + auto val = std::make_shared("hey"); + val->id = 42; + auto val_data = val->data; + std::promise p1; + std::promise p2; + node2.put(key, val, [&](bool ok){ + p1.set_value(ok); + }); + node2.put(key, val, [&](bool ok){ + p2.set_value(ok); + }); + auto p1ret = p1.get_future().get(); + auto p2ret = p2.get_future().get(); + CPPUNIT_ASSERT(p1ret); + CPPUNIT_ASSERT(p2ret); + auto vals = node1.get(key).get(); + CPPUNIT_ASSERT(not vals.empty()); + CPPUNIT_ASSERT(vals.size() == 1); + CPPUNIT_ASSERT(vals.front()->data == val_data); +} + + +void +DhtRunnerTester::testPutOverride() { + auto key = dht::InfoHash::get("123"); + auto val = std::make_shared("meh"); + val->id = 42; + auto val2 = std::make_shared("hey"); + val2->id = 42; + CPPUNIT_ASSERT_EQUAL(val->id, val2->id); + auto val_data = val2->data; + std::promise p1; + std::promise p2; + node2.put(key, val, [&](bool ok){ + p1.set_value(ok); + }); + node2.put(key, val2, [&](bool ok){ + p2.set_value(ok); + }); + auto p1ret = p1.get_future().get(); + auto p2ret = p2.get_future().get(); + CPPUNIT_ASSERT(!p1ret); + CPPUNIT_ASSERT(p2ret); + auto vals = node1.get(key).get(); + CPPUNIT_ASSERT(not vals.empty()); + CPPUNIT_ASSERT(vals.size() == 1); + CPPUNIT_ASSERT(vals.front()->data == val_data); +} + + void DhtRunnerTester::testListen() { std::mutex mutex; diff --git a/tests/dhtrunnertester.h b/tests/dhtrunnertester.h index 3f5afd3e4..501679d70 100644 --- a/tests/dhtrunnertester.h +++ b/tests/dhtrunnertester.h @@ -31,6 +31,8 @@ class DhtRunnerTester : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(DhtRunnerTester); CPPUNIT_TEST(testConstructors); CPPUNIT_TEST(testGetPut); + CPPUNIT_TEST(testPutDuplicate); + CPPUNIT_TEST(testPutOverride); CPPUNIT_TEST(testListen); CPPUNIT_TEST(testListenLotOfBytes); CPPUNIT_TEST(testIdOps); @@ -55,6 +57,14 @@ class DhtRunnerTester : public CppUnit::TestFixture { * Test get and put methods */ void testGetPut(); + /** + * Test get and multiple put + */ + void testPutDuplicate(); + /** + * Test get and multiple put with changing value + */ + void testPutOverride(); /** * Test listen method */