Skip to content

Commit

Permalink
Add option to create and read gzip compressed blob sequences.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714047504
  • Loading branch information
achoum authored and copybara-github committed Jan 10, 2025
1 parent d50edf2 commit 05bc458
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 116 deletions.
75 changes: 43 additions & 32 deletions yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "yggdrasil_decision_forests/utils/filesystem.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/protobuf.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
#include "yggdrasil_decision_forests/utils/zlib.h"

namespace yggdrasil_decision_forests::dataset::tensorflow_no_dep {
Expand All @@ -54,56 +55,62 @@ inline uint32_t Unmask(const uint32_t masked_crc) {
return ((rot >> 17) | (rot << 15));
}

// Returns a view over the raw in-memory bytes of `data` (exactly sizeof(T)
// bytes, as laid out in memory). The view aliases the storage of `data`: it
// must not outlive the referenced object, and it must not be built from a
// temporary that is destroyed before the view is consumed.
//
// Not constexpr: reinterpreting an object pointer as a character pointer is
// never permitted in a constant expression, so the specifier could not be
// honored by any instantiation.
template <typename T>
absl::string_view GetView(const T& data) {
  // reinterpret_cast to `const char*` preserves the const qualifier of `data`.
  return absl::string_view(reinterpret_cast<const char*>(&data), sizeof(T));
}

} // namespace

absl::StatusOr<absl::crc32c_t> TFRecordReader::ReadCRC() {
uint32_t value;
ASSIGN_OR_RETURN(const bool has_content,
stream_->ReadExactly((char*)&value, sizeof(uint32_t)));
stream().ReadExactly((char*)&value, sizeof(uint32_t)));
if (!has_content) {
return absl::InvalidArgumentError("Empty stream");
}
return absl::crc32c_t(Unmask(absl::little_endian::ToHost32(value)));
}

TFRecordReader::~TFRecordReader() {
if (stream_) {
if (raw_stream_) {
LOG(WARNING) << "Destruction of a non closed TFRecordReader";
Close().IgnoreError();
}
}

absl::StatusOr<std::unique_ptr<TFRecordReader>> TFRecordReader::Create(
const absl::string_view path, bool compressed) {
ASSIGN_OR_RETURN(std::unique_ptr<utils::InputByteStream> stream,
file::OpenInputFile(path));
auto reader = std::make_unique<TFRecordReader>();

ASSIGN_OR_RETURN(reader->raw_stream_, file::OpenInputFile(path));
if (compressed) {
ASSIGN_OR_RETURN(stream,
utils::GZipInputByteStream::Create(std::move(stream)));
ASSIGN_OR_RETURN(reader->zlib_stream_, utils::GZipInputByteStream::Create(
reader->raw_stream_.get()));
}
return std::make_unique<TFRecordReader>(std::move(stream));
return reader;
}

absl::StatusOr<bool> TFRecordReader::Next(google::protobuf::MessageLite* message) {
uint64_t raw_length;
ASSIGN_OR_RETURN(bool has_content,
stream_->ReadExactly((char*)&raw_length, sizeof(uint64_t)));
stream().ReadExactly((char*)&raw_length, sizeof(uint64_t)));
if (!has_content) {
return false;
}
const uint64_t length = absl::little_endian::ToHost64(raw_length);

ASSIGN_OR_RETURN(const absl::crc32c_t raw_length_expected_crc, ReadCRC());
const absl::crc32c_t raw_length_real_checksum = absl::ComputeCrc32c(
absl::string_view((char*)&raw_length, sizeof(uint64_t)));
const absl::crc32c_t raw_length_real_checksum =
absl::ComputeCrc32c(GetView(raw_length));
if (raw_length_expected_crc != raw_length_real_checksum) {
return absl::InvalidArgumentError(kInvalidDataMessage);
}

buffer_.resize(length);
// TODO: Use buffer_.data() in c++>=17.
if (length > 0) {
ASSIGN_OR_RETURN(has_content, stream_->ReadExactly(&buffer_[0], length));
ASSIGN_OR_RETURN(has_content, stream().ReadExactly(&buffer_[0], length));
}
if (!has_content) {
return absl::InvalidArgumentError(kInvalidDataMessage);
Expand All @@ -123,26 +130,30 @@ absl::StatusOr<bool> TFRecordReader::Next(google::protobuf::MessageLite* message

// Closes the stream.
absl::Status TFRecordReader::Close() {
if (stream_) {
RETURN_IF_ERROR(stream_->Close());
stream_.reset();
if (zlib_stream_) {
RETURN_IF_ERROR(zlib_stream_->Close());
zlib_stream_.reset();
}
if (raw_stream_) {
RETURN_IF_ERROR(raw_stream_->Close());
raw_stream_.reset();
}
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<TFRecordWriter>> TFRecordWriter::Create(
absl::string_view path, bool compressed) {
ASSIGN_OR_RETURN(std::unique_ptr<utils::OutputByteStream> stream,
file::OpenOutputFile(path));
auto writer = std::make_unique<TFRecordWriter>();
ASSIGN_OR_RETURN(writer->raw_stream_, file::OpenOutputFile(path));
if (compressed) {
ASSIGN_OR_RETURN(stream,
utils::GZipOutputByteStream::Create(std::move(stream)));
ASSIGN_OR_RETURN(writer->zlib_stream_, utils::GZipOutputByteStream::Create(
writer->raw_stream_.get()));
}
return std::make_unique<TFRecordWriter>(std::move(stream));
return writer;
}

TFRecordWriter::~TFRecordWriter() {
if (stream_) {
if (raw_stream_) {
LOG(WARNING) << "Destruction of a non closed TFRecordWriter";
Close().IgnoreError();
}
Expand All @@ -157,30 +168,30 @@ absl::Status TFRecordWriter::Write(const google::protobuf::MessageLite& message)

absl::Status TFRecordWriter::Write(const absl::string_view data) {
uint64_t length = data.size();
RETURN_IF_ERROR(
stream_->Write(absl::string_view((char*)&length, sizeof(uint64_t))));
RETURN_IF_ERROR(stream().Write(GetView(length)));

const uint64_t net_length = absl::little_endian::FromHost64(length);
const uint32_t net_length_checksum =
Mask(static_cast<uint32_t>(absl::ComputeCrc32c(
absl::string_view((char*)&net_length, sizeof(uint64_t)))));
RETURN_IF_ERROR(stream_->Write(
absl::string_view((char*)&net_length_checksum, sizeof(uint32_t))));
Mask(static_cast<uint32_t>(absl::ComputeCrc32c(GetView(net_length))));
RETURN_IF_ERROR(stream().Write(GetView(net_length_checksum)));

RETURN_IF_ERROR(stream_->Write(data));
RETURN_IF_ERROR(stream().Write(data));

const uint32_t net_data_checksum =
Mask(static_cast<uint32_t>(absl::ComputeCrc32c(data)));
RETURN_IF_ERROR(stream_->Write(
absl::string_view((char*)&net_data_checksum, sizeof(uint32_t))));
RETURN_IF_ERROR(stream().Write(GetView(net_data_checksum)));

return absl::OkStatus();
}

absl::Status TFRecordWriter::Close() {
if (stream_) {
RETURN_IF_ERROR(stream_->Close());
stream_.reset();
if (zlib_stream_) {
RETURN_IF_ERROR(zlib_stream_->Close());
zlib_stream_.reset();
}
if (raw_stream_) {
RETURN_IF_ERROR(raw_stream_->Close());
raw_stream_.reset();
}
return absl::OkStatus();
}
Expand Down
26 changes: 19 additions & 7 deletions yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "yggdrasil_decision_forests/utils/protobuf.h"
#include "yggdrasil_decision_forests/utils/sharded_io.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
#include "yggdrasil_decision_forests/utils/zlib.h"

namespace yggdrasil_decision_forests::dataset::tensorflow_no_dep {

Expand All @@ -51,17 +52,23 @@ class TFRecordReader {
// Closes the stream.
absl::Status Close();

TFRecordReader(std::unique_ptr<utils::InputByteStream>&& stream)
: stream_(std::move(stream)) {}

// Value of the last read record. Includes skipped messages.
const std::string& buffer() const { return buffer_; }

TFRecordReader() {}

private:
// Returns the stream records are read from: the zlib stream when it is set,
// the raw stream otherwise. The chosen pointer is dereferenced without a null
// check, so at least one of the two streams must be non-null.
utils::InputByteStream& stream() {
  if (zlib_stream_) {
    return *zlib_stream_;
  }
  return *raw_stream_;
}

// Reads a CRC.
absl::StatusOr<absl::crc32c_t> ReadCRC();

std::unique_ptr<utils::InputByteStream> stream_;
// Underlying stream. If nullptr, the reader is closed.
std::unique_ptr<utils::InputByteStream> raw_stream_;
// Optional stream, if zlib compression is enabled.
std::unique_ptr<utils::GZipInputByteStream> zlib_stream_;
std::string buffer_;
};

Expand Down Expand Up @@ -101,11 +108,16 @@ class TFRecordWriter {
// Closes the stream.
absl::Status Close();

TFRecordWriter(std::unique_ptr<utils::OutputByteStream>&& stream)
: stream_(std::move(stream)) {}
TFRecordWriter() {}

private:
std::unique_ptr<utils::OutputByteStream> stream_;
// Returns the stream records are written to: the zlib stream when it is set,
// the raw stream otherwise. The chosen pointer is dereferenced without a null
// check, so at least one of the two streams must be non-null.
utils::OutputByteStream& stream() {
  if (zlib_stream_) {
    return *zlib_stream_;
  }
  return *raw_stream_;
}
// Underlying stream. If nullptr, the writer is closed.
std::unique_ptr<utils::OutputByteStream> raw_stream_;
// Optional stream, if zlib compression is enabled.
std::unique_ptr<utils::GZipOutputByteStream> zlib_stream_;
std::string buffer_;
};

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
4 changes: 4 additions & 0 deletions yggdrasil_decision_forests/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ cc_library_ydf(
deps = [
":bytestream",
":status_macros",
":zlib",
"@com_google_absl//absl/base:endian",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -855,6 +856,7 @@ cc_library_ydf(
deps = [
":bytestream",
":status_macros",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -1195,12 +1197,14 @@ cc_test(
cc_test(
name = "blob_sequence_test",
srcs = ["blob_sequence_test.cc"],
data = ["//yggdrasil_decision_forests/test_data"],
deps = [
":blob_sequence",
":filesystem",
":test",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
Expand Down
Loading

0 comments on commit 05bc458

Please sign in to comment.