Skip to content

Commit

Permalink
fix(GCS+gRPC): delay resume until Finish() returns (#14569)
Browse files Browse the repository at this point in the history
Another case where we need to wait for `Finish()` before deleting the
underlying RPC.  This is slightly wasteful, as we could start the resume
while the `Finish()` operation is running.  We can improve that later,
once the code no longer crashes.
  • Loading branch information
coryan authored Jul 25, 2024
1 parent e8af3d7 commit 00b80da
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 25 deletions.
10 changes: 5 additions & 5 deletions google/cloud/storage/internal/async/connection_impl_read_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ TEST_F(AsyncConnectionImplTest, ReadObjectDetectBadMessageChecksum) {
return absl::make_optional(response);
});
});
EXPECT_CALL(*stream, Cancel).Times(1);
EXPECT_CALL(*stream, Finish).WillOnce([&] {
return sequencer.PushBack("Finish").then([](auto) { return Status{}; });
});
Expand Down Expand Up @@ -563,14 +564,13 @@ TEST_F(AsyncConnectionImplTest, ReadObjectDetectBadMessageChecksum) {
next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "Read");
next.first.set_value(true);
auto response = data.get();
EXPECT_THAT(response,
VariantWith<Status>(StatusIs(StatusCode::kInvalidArgument)));

// The stream Finish() function should be called in the background.
// The `Finish()` call must complete before the result is ready.
next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "Finish");
next.first.set_value(true);

EXPECT_THAT(data.get(),
VariantWith<Status>(StatusIs(StatusCode::kInvalidArgument)));
}

TEST_F(AsyncConnectionImplTest, ReadObjectDetectBadFullChecksum) {
Expand Down
28 changes: 16 additions & 12 deletions google/cloud/storage/internal/async/reader_connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ future<AsyncReaderConnectionImpl::ReadResponse>
AsyncReaderConnectionImpl::OnRead(absl::optional<ProtoPayload> r) {
if (!r) return DoFinish();
auto response = *std::move(r);
auto hash =
hash_function_->Update(offset_, GetContent(response.checksummed_data()),
response.checksummed_data().crc32c());
if (!hash.ok()) {
(void)DoFinish();
return make_ready_future(ReadResponse(std::move(hash)));
if (!offset_ && response.has_content_range()) {
offset_ = response.content_range().start();
}
auto hash = hash_function_->Update(offset_.value_or(0),
GetContent(response.checksummed_data()),
response.checksummed_data().crc32c());
if (!hash.ok()) return HandleHashError(std::move(hash));
auto result = ReadPayloadImpl::Make(
StealMutableContent(*response.mutable_checksummed_data()));
if (response.has_object_checksums()) {
Expand All @@ -64,15 +64,19 @@ AsyncReaderConnectionImpl::OnRead(absl::optional<ProtoPayload> r) {
if (response.has_metadata()) {
result.set_metadata(std::move(*response.mutable_metadata()));
}
if (response.has_content_range()) {
result.set_offset(response.content_range().start());
} else {
result.set_offset(offset_);
}
offset_ = result.offset() + result.size();
result.set_offset(offset_.value_or(0));
offset_ = offset_.value_or(0) + result.size();
return make_ready_future(ReadResponse(std::move(result)));
}

future<AsyncReaderConnectionImpl::ReadResponse>
AsyncReaderConnectionImpl::HandleHashError(Status status) {
impl_->Cancel();
return impl_->Finish().then([s = std::move(status)](auto) mutable {
return ReadResponse(std::move(s));
});
}

future<AsyncReaderConnectionImpl::ReadResponse>
AsyncReaderConnectionImpl::DoFinish() {
return impl_->Finish().then([](auto f) { return ReadResponse(f.get()); });
Expand Down
5 changes: 3 additions & 2 deletions google/cloud/storage/internal/async/reader_connection_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "google/cloud/internal/async_streaming_read_rpc.h"
#include "google/cloud/options.h"
#include "google/cloud/version.h"
#include "absl/types/optional.h"
#include <google/storage/v2/storage.pb.h>
#include <cstdint>
#include <memory>
Expand Down Expand Up @@ -52,14 +53,14 @@ class AsyncReaderConnectionImpl

private:
future<ReadResponse> OnRead(absl::optional<ProtoPayload> r);

future<ReadResponse> HandleHashError(Status status);
future<ReadResponse> DoFinish();

google::cloud::internal::ImmutableOptions options_;
std::shared_ptr<storage::internal::HashFunction> hash_;
std::unique_ptr<StreamingRpc> impl_;
std::shared_ptr<storage::internal::HashFunction> hash_function_;
std::int64_t offset_ = 0;
absl::optional<std::int64_t> offset_;
};

GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END
Expand Down
31 changes: 25 additions & 6 deletions google/cloud/storage/internal/async/reader_connection_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "google/cloud/storage/options.h"
#include "google/cloud/storage/testing/canonical_errors.h"
#include "google/cloud/storage/testing/mock_hash_function.h"
#include "google/cloud/testing_util/async_sequencer.h"
#include "google/cloud/testing_util/mock_async_streaming_read_rpc.h"
#include "google/cloud/testing_util/status_matchers.h"
#include <google/protobuf/text_format.h>
Expand All @@ -34,6 +35,7 @@ using ::google::cloud::storage::internal::HashValues;
using ::google::cloud::storage::testing::MockHashFunction;
using ::google::cloud::storage::testing::canonical_errors::PermanentError;
using ::google::cloud::storage_experimental::ReadPayload;
using ::google::cloud::testing_util::AsyncSequencer;
using ::google::cloud::testing_util::IsOk;
using ::google::cloud::testing_util::StatusIs;
using ::google::protobuf::TextFormat;
Expand Down Expand Up @@ -167,13 +169,16 @@ TEST(ReaderConnectionImpl, WithError) {
}

TEST(ReaderConnectionImpl, HashingError) {
AsyncSequencer<bool> sequencer;
auto mock = std::make_unique<MockStream>();
EXPECT_CALL(*mock, Read).WillOnce([] {
return make_ready_future(
absl::make_optional(google::storage::v2::ReadObjectResponse{}));
EXPECT_CALL(*mock, Read).WillOnce([&] {
return sequencer.PushBack("Read").then([](auto) {
return absl::make_optional(google::storage::v2::ReadObjectResponse{});
});
});
EXPECT_CALL(*mock, Finish).WillOnce([] {
return make_ready_future(Status{});
EXPECT_CALL(*mock, Cancel).WillOnce([&] { sequencer.PushBack("Cancel"); });
EXPECT_CALL(*mock, Finish).WillOnce([&] {
return sequencer.PushBack("Finish").then([](auto) { return Status{}; });
});

auto hash_function = std::make_shared<MockHashFunction>();
Expand All @@ -182,7 +187,21 @@ TEST(ReaderConnectionImpl, HashingError) {

AsyncReaderConnectionImpl tested(TestOptions(), std::move(mock),
std::move(hash_function));
EXPECT_THAT(tested.Read().get(),

auto pending = tested.Read();
auto read = sequencer.PopFrontWithName();
EXPECT_EQ(read.second, "Read");
read.first.set_value(false);
auto cancel = sequencer.PopFrontWithName();
EXPECT_EQ(cancel.second, "Cancel");
cancel.first.set_value(false);
auto finish = sequencer.PopFrontWithName();
EXPECT_EQ(finish.second, "Finish");

EXPECT_FALSE(pending.is_ready());
finish.first.set_value(false);

EXPECT_THAT(pending.get(),
VariantWith<Status>(StatusIs(StatusCode::kInvalidArgument)));
}

Expand Down

0 comments on commit 00b80da

Please sign in to comment.