Skip to content

Commit

Permalink
Improved CreateRangeEncoder/Decoder ops to reduce cdf scanning when there are a large number of batch items.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 448377181
Change-Id: Icc440c0af2825b2c9ef81b7209685b498fdfff2a
  • Loading branch information
ssjhv authored and copybara-github committed May 13, 2022
1 parent 7bd5363 commit 3b70187
Showing 1 changed file with 51 additions and 61 deletions.
112 changes: 51 additions & 61 deletions tensorflow_compression/cc/kernels/range_coder_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow_compression/cc/kernels/range_coder_kernels.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

Expand Down Expand Up @@ -114,24 +115,9 @@ Status IndexCDFMatrix(const TTypes<int32_t>::ConstMatrix& table,

class RangeEncoderInterface final : public EntropyEncoderInterface {
public:
// Builds a RangeEncoderInterface for `lookup` and, on success, stores it in
// `*ptr`.
//
// `lookup` must be rank 1 or rank 2. A rank-1 tensor is indexed with
// IndexCDFVector and a rank-2 tensor with IndexCDFMatrix; either call fills
// the new encoder's `lookup_` member. Any other rank produces
// InvalidArgument.
//
// Returns the status of the indexing step. `*ptr` is reset only when that
// status is OK; otherwise the partially built encoder is deleted and `*ptr`
// is left untouched.
static Status MakeShared(const Tensor lookup,
std::shared_ptr<EntropyEncoderInterface>* ptr) {
Status status;
// Ownership stays with this raw pointer until it is either handed to `*ptr`
// or deleted below.
RangeEncoderInterface* re = new RangeEncoderInterface(lookup);
if (lookup.dims() == 1) {
status = IndexCDFVector(lookup.flat<int32_t>(), &re->lookup_);
} else if (lookup.dims() == 2) {
status = IndexCDFMatrix(lookup.matrix<int32_t>(), &re->lookup_);
} else {
status = errors::InvalidArgument("`lookup` must be rank 1 or 2.");
}
if (status.ok()) {
// Success: the shared_ptr takes over ownership of the encoder.
ptr->reset(re);
} else {
// Failure: nothing else references `re`, so free it here.
delete re;
}
return status;
}
// Constructs an encoder from an already indexed CDF table.
//
// `lookup`: the span objects in [begin, end) are copied into `lookup_`; the
//   int32 data they refer to is not copied by this constructor.
// `hold`: moved into `hold_`.
//   NOTE(review): presumably `hold` is the tensor whose buffer the spans in
//   `lookup` point into, retained so that buffer outlives this object --
//   confirm against the callers and IndexCDFVector/IndexCDFMatrix.
RangeEncoderInterface(absl::Span<const absl::Span<const int32_t>> lookup,
Tensor hold)
: lookup_(lookup.begin(), lookup.end()), hold_(std::move(hold)) {}

Status Encode(int32_t index, int32_t value) override {
TF_RETURN_IF_ERROR(CheckInRange("index", index, 0, lookup_.size()));
Expand All @@ -153,8 +139,6 @@ class RangeEncoderInterface final : public EntropyEncoderInterface {
}

private:
// Private constructor used by MakeShared: moves `lookup` into `hold_` and
// does not list `lookup_` in its initializer list; MakeShared fills `lookup_`
// afterwards via IndexCDFVector/IndexCDFMatrix.
explicit RangeEncoderInterface(Tensor lookup) : hold_(std::move(lookup)) {}

void OverflowEncode(const absl::Span<const int32_t> row, int32_t value) {
const int32_t max_value = row.size() - 3;
const int32_t sign = value < 0;
Expand Down Expand Up @@ -193,24 +177,12 @@ class RangeEncoderInterface final : public EntropyEncoderInterface {

class RangeDecoderInterface final : public EntropyDecoderInterface {
public:
// Builds a RangeDecoderInterface that reads from `encoded` using `lookup`
// and, on success, stores it in `*ptr`.
//
// `lookup` must be rank 1 or rank 2. A rank-1 tensor is indexed with
// IndexCDFVector and a rank-2 tensor with IndexCDFMatrix; either call fills
// the new decoder's `lookup_` member. Any other rank produces
// InvalidArgument.
//
// Returns the status of the indexing step. `*ptr` is reset only when that
// status is OK; otherwise the partially built decoder is deleted and `*ptr`
// is left untouched.
static Status MakeShared(absl::string_view encoded, const Tensor lookup,
std::shared_ptr<EntropyDecoderInterface>* ptr) {
Status status;
// Ownership stays with this raw pointer until it is either handed to `*ptr`
// or deleted below.
RangeDecoderInterface* rd = new RangeDecoderInterface(encoded, lookup);
if (lookup.dims() == 1) {
status = IndexCDFVector(lookup.flat<int32_t>(), &rd->lookup_);
} else if (lookup.dims() == 2) {
status = IndexCDFMatrix(lookup.matrix<int32_t>(), &rd->lookup_);
} else {
status = errors::InvalidArgument("`lookup` must be rank 1 or 2.");
}
if (status.ok()) {
// Success: the shared_ptr takes over ownership of the decoder.
ptr->reset(rd);
} else {
// Failure: nothing else references `rd`, so free it here.
delete rd;
}
return status;
}
// Constructs a decoder over `encoded` from an already indexed CDF table.
//
// `encoded`: passed to the `decoder_` member's constructor.
//   NOTE(review): `encoded` is a non-owning view and it is not visible here
//   whether `decoder_` copies the bytes -- confirm that the caller keeps the
//   underlying storage alive for the decoder's lifetime.
// `lookup`: the span objects in [begin, end) are copied into `lookup_`; the
//   int32 data they refer to is not copied by this constructor.
// `hold`: moved into `hold_`.
//   NOTE(review): presumably the tensor whose buffer the spans in `lookup`
//   point into, retained so that buffer outlives this object -- confirm
//   against the callers.
RangeDecoderInterface(absl::string_view encoded,
absl::Span<const absl::Span<const int32_t>> lookup,
Tensor hold)
: lookup_(lookup.begin(), lookup.end()),
decoder_(encoded),
hold_(std::move(hold)) {}

Status Decode(int32_t index, int32_t* output) override {
TF_RETURN_IF_ERROR(CheckInRange("index", index, 0, lookup_.size()));
Expand All @@ -232,9 +204,6 @@ class RangeDecoderInterface final : public EntropyDecoderInterface {
}

private:
// Private constructor used by MakeShared: constructs `decoder_` from
// `encoded` and moves `lookup` into `hold_`. `lookup_` is not listed in the
// initializer list; MakeShared fills it afterwards via
// IndexCDFVector/IndexCDFMatrix.
RangeDecoderInterface(absl::string_view encoded, Tensor lookup)
: decoder_(encoded), hold_(std::move(lookup)) {}

int32_t OverflowDecode(const absl::Span<const int32_t> row) {
constexpr int32_t binary_uniform_cdf[] = {0, 1, 2};
const int32_t max_value = row.size() - 3;
Expand Down Expand Up @@ -313,11 +282,21 @@ class CreateRangeEncoderOp : public tensorflow::OpKernel {
context->allocate_output(0, handle_shape, &output_tensor));

const Tensor& lookup = context->input(1);
OP_REQUIRES(context, lookup.dims() == 1 || lookup.dims() == 2,
errors::InvalidArgument("`lookup` must be rank 1 or 2."));

std::vector<absl::Span<const int32_t>> table;
if (lookup.dims() == 1) {
OP_REQUIRES_OK(context, IndexCDFVector(lookup.flat<int32_t>(), &table));
} else {
DCHECK_EQ(lookup.dims(), 2);
OP_REQUIRES_OK(context, IndexCDFMatrix(lookup.matrix<int32_t>(), &table));
}

auto output = output_tensor->flat<Variant>();
for (int64_t i = 0; i < output.size(); ++i) {
EntropyEncoderVariant wrap;
OP_REQUIRES_OK(context,
RangeEncoderInterface::MakeShared(lookup, &wrap.encoder));
wrap.encoder = std::make_shared<RangeEncoderInterface>(table, lookup);
output(i) = std::move(wrap);
}
}
Expand Down Expand Up @@ -388,10 +367,10 @@ class EntropyEncodeChannelOp : public tensorflow::OpKernel {
context->SetStatus(status); \
return; \
}
#define REQUIRES_OK(status) \
{ \
auto s = (status); \
REQUIRES(s.ok(), s); \
#define REQUIRES_OK(status) \
{ \
auto s = (status); \
REQUIRES(s.ok(), s); \
}

const int64_t num_elements = value.dimension(1);
Expand Down Expand Up @@ -484,10 +463,10 @@ class EntropyEncodeIndexOp : public tensorflow::OpKernel {
context->SetStatus(status); \
return; \
}
#define REQUIRES_OK(status) \
{ \
auto s = (status); \
REQUIRES(s.ok(), s); \
#define REQUIRES_OK(status) \
{ \
auto s = (status); \
REQUIRES(s.ok(), s); \
}

const int64_t num_elements = value.dimension(1);
Expand Down Expand Up @@ -560,11 +539,22 @@ class CreateRangeDecoderOp : public tensorflow::OpKernel {
&output_tensor));

const Tensor& lookup = context->input(1);
OP_REQUIRES(context, lookup.dims() == 1 || lookup.dims() == 2,
errors::InvalidArgument("`lookup` must be rank 1 or 2."));

std::vector<absl::Span<const int32_t>> table;
if (lookup.dims() == 1) {
OP_REQUIRES_OK(context, IndexCDFVector(lookup.flat<int32_t>(), &table));
} else {
DCHECK_EQ(lookup.dims(), 2);
OP_REQUIRES_OK(context, IndexCDFMatrix(lookup.matrix<int32_t>(), &table));
}

auto output = output_tensor->flat<Variant>();
for (int64_t i = 0; i < output.size(); ++i) {
EntropyDecoderVariant wrap;
OP_REQUIRES_OK(context, RangeDecoderInterface::MakeShared(
encoded(i), lookup, &wrap.decoder));
wrap.decoder =
std::make_shared<RangeDecoderInterface>(encoded(i), table, lookup);
wrap.holder = encoded_tensor;
output(i) = std::move(wrap);
}
Expand Down Expand Up @@ -636,10 +626,10 @@ class EntropyDecodeChannelOp : public tensorflow::OpKernel {
context->SetStatus(status); \
return; \
}
#define REQUIRES_OK(status) \
{ \
auto s = (status); \
REQUIRES(s.ok(), s); \
#define REQUIRES_OK(status) \
{ \
auto s = (status); \
REQUIRES(s.ok(), s); \
}

const int64_t num_elements = output.dimension(1);
Expand Down Expand Up @@ -736,10 +726,10 @@ class EntropyDecodeIndexOp : public tensorflow::OpKernel {
context->SetStatus(status); \
return; \
}
#define REQUIRES_OK(status) \
{ \
auto s = (status); \
REQUIRES(s.ok(), s); \
#define REQUIRES_OK(status) \
{ \
auto s = (status); \
REQUIRES(s.ok(), s); \
}

const int64_t num_elements = output.dimension(1);
Expand Down

0 comments on commit 3b70187

Please sign in to comment.