diff --git a/google/cloud/storage/internal/async/connection_impl_read_test.cc b/google/cloud/storage/internal/async/connection_impl_read_test.cc index 9e7ea1d241326..06c914467bea5 100644 --- a/google/cloud/storage/internal/async/connection_impl_read_test.cc +++ b/google/cloud/storage/internal/async/connection_impl_read_test.cc @@ -520,6 +520,7 @@ TEST_F(AsyncConnectionImplTest, ReadObjectDetectBadMessageChecksum) { return absl::make_optional(response); }); }); + EXPECT_CALL(*stream, Cancel).Times(1); EXPECT_CALL(*stream, Finish).WillOnce([&] { return sequencer.PushBack("Finish").then([](auto) { return Status{}; }); }); @@ -563,14 +564,13 @@ TEST_F(AsyncConnectionImplTest, ReadObjectDetectBadMessageChecksum) { next = sequencer.PopFrontWithName(); EXPECT_EQ(next.second, "Read"); next.first.set_value(true); - auto response = data.get(); - EXPECT_THAT(response, - VariantWith(StatusIs(StatusCode::kInvalidArgument))); - - // The stream Finish() function should be called in the background. + // The `Finish()` call must complete before the result is ready. next = sequencer.PopFrontWithName(); EXPECT_EQ(next.second, "Finish"); next.first.set_value(true); + + EXPECT_THAT(data.get(), + VariantWith(StatusIs(StatusCode::kInvalidArgument))); } TEST_F(AsyncConnectionImplTest, ReadObjectDetectBadFullChecksum) { diff --git a/google/cloud/storage/internal/async/reader_connection_impl.cc b/google/cloud/storage/internal/async/reader_connection_impl.cc index 0acce6d4e28a9..1433aaba90903 100644 --- a/google/cloud/storage/internal/async/reader_connection_impl.cc +++ b/google/cloud/storage/internal/async/reader_connection_impl.cc @@ -35,13 +35,13 @@ future AsyncReaderConnectionImpl::OnRead(absl::optional r) { if (!r) return DoFinish(); auto response = *std::move(r); - auto hash = - hash_function_->Update(offset_, GetContent(response.checksummed_data()), - response.checksummed_data().crc32c()); - if (!hash.ok()) { - (void)DoFinish(); - return make_ready_future(ReadResponse(std::move(hash))); + if (!offset_ && response.has_content_range()) { + offset_ = response.content_range().start(); } + auto hash = hash_function_->Update(offset_.value_or(0), + GetContent(response.checksummed_data()), + response.checksummed_data().crc32c()); + if (!hash.ok()) return HandleHashError(std::move(hash)); auto result = ReadPayloadImpl::Make( StealMutableContent(*response.mutable_checksummed_data())); if (response.has_object_checksums()) { @@ -64,15 +64,19 @@ AsyncReaderConnectionImpl::OnRead(absl::optional r) { if (response.has_metadata()) { result.set_metadata(std::move(*response.mutable_metadata())); } - if (response.has_content_range()) { - result.set_offset(response.content_range().start()); - } else { - result.set_offset(offset_); - } - offset_ = result.offset() + result.size(); + result.set_offset(offset_.value_or(0)); + offset_ = offset_.value_or(0) + result.size(); return make_ready_future(ReadResponse(std::move(result))); } +future +AsyncReaderConnectionImpl::HandleHashError(Status status) { + impl_->Cancel(); + return impl_->Finish().then([s = std::move(status)](auto) mutable { + return ReadResponse(std::move(s)); + }); +} + future AsyncReaderConnectionImpl::DoFinish() { return impl_->Finish().then([](auto f) { return ReadResponse(f.get()); }); diff --git a/google/cloud/storage/internal/async/reader_connection_impl.h b/google/cloud/storage/internal/async/reader_connection_impl.h index ec932ef7ab537..3b2fa229452fe 100644 --- a/google/cloud/storage/internal/async/reader_connection_impl.h +++ b/google/cloud/storage/internal/async/reader_connection_impl.h @@ -20,6 +20,7 @@ #include "google/cloud/internal/async_streaming_read_rpc.h" #include "google/cloud/options.h" #include "google/cloud/version.h" +#include "absl/types/optional.h" #include #include #include @@ -52,14 +53,14 @@ class AsyncReaderConnectionImpl private: future OnRead(absl::optional r); - + future HandleHashError(Status status); future DoFinish(); google::cloud::internal::ImmutableOptions options_; std::shared_ptr hash_; std::unique_ptr impl_; std::shared_ptr hash_function_; - std::int64_t offset_ = 0; + absl::optional offset_; }; GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END diff --git a/google/cloud/storage/internal/async/reader_connection_impl_test.cc b/google/cloud/storage/internal/async/reader_connection_impl_test.cc index 9cdbb5c434711..aeecf53977806 100644 --- a/google/cloud/storage/internal/async/reader_connection_impl_test.cc +++ b/google/cloud/storage/internal/async/reader_connection_impl_test.cc @@ -19,6 +19,7 @@ #include "google/cloud/storage/options.h" #include "google/cloud/storage/testing/canonical_errors.h" #include "google/cloud/storage/testing/mock_hash_function.h" +#include "google/cloud/testing_util/async_sequencer.h" #include "google/cloud/testing_util/mock_async_streaming_read_rpc.h" #include "google/cloud/testing_util/status_matchers.h" #include @@ -34,6 +35,7 @@ using ::google::cloud::storage::internal::HashValues; using ::google::cloud::storage::testing::MockHashFunction; using ::google::cloud::storage::testing::canonical_errors::PermanentError; using ::google::cloud::storage_experimental::ReadPayload; +using ::google::cloud::testing_util::AsyncSequencer; using ::google::cloud::testing_util::IsOk; using ::google::cloud::testing_util::StatusIs; using ::google::protobuf::TextFormat; @@ -167,13 +169,16 @@ TEST(ReaderConnectionImpl, WithError) { } TEST(ReaderConnectionImpl, HashingError) { + AsyncSequencer sequencer; auto mock = std::make_unique(); - EXPECT_CALL(*mock, Read).WillOnce([] { - return make_ready_future( - absl::make_optional(google::storage::v2::ReadObjectResponse{})); + EXPECT_CALL(*mock, Read).WillOnce([&] { + return sequencer.PushBack("Read").then([](auto) { + return absl::make_optional(google::storage::v2::ReadObjectResponse{}); + }); }); - EXPECT_CALL(*mock, Finish).WillOnce([] { - return make_ready_future(Status{}); + EXPECT_CALL(*mock, Cancel).WillOnce([&] { sequencer.PushBack("Cancel"); }); + EXPECT_CALL(*mock, Finish).WillOnce([&] { + return sequencer.PushBack("Finish").then([](auto) { return Status{}; }); }); auto hash_function = std::make_shared(); @@ -182,7 +187,21 @@ TEST(ReaderConnectionImpl, HashingError) { AsyncReaderConnectionImpl tested(TestOptions(), std::move(mock), std::move(hash_function)); - EXPECT_THAT(tested.Read().get(), + + auto pending = tested.Read(); + auto read = sequencer.PopFrontWithName(); + EXPECT_EQ(read.second, "Read"); + read.first.set_value(false); + auto cancel = sequencer.PopFrontWithName(); + EXPECT_EQ(cancel.second, "Cancel"); + cancel.first.set_value(false); + auto finish = sequencer.PopFrontWithName(); + EXPECT_EQ(finish.second, "Finish"); + + EXPECT_FALSE(pending.is_ready()); + finish.first.set_value(false); + + EXPECT_THAT(pending.get(), VariantWith(StatusIs(StatusCode::kInvalidArgument))); }