Skip to content

Commit

Permalink
Merge pull request #2 from blackwer/optimize-encoder
Browse files Browse the repository at this point in the history
Add C++ implementation for fast unique/count in some cases (~10x)
  • Loading branch information
magland authored Jan 24, 2025
2 parents a576928 + 49e9b87 commit cd78804
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
45 changes: 44 additions & 1 deletion simple_ans/cpp/simple_ans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,50 @@ void ans_decode_t(T* output,

namespace simple_ans
{
// Largest value range (max - min + 1) for which a direct-indexed lookup array is used
// instead of a slower, more general strategy: one slot per possible uint16_t value.
constexpr int lookup_array_threshold = std::numeric_limits<uint16_t>::max() + 1;

/// Compute the distinct values of an integer array together with their occurrence counts.
///
/// WARNING: This is ONLY a helper function. It doesn't support arrays with a large domain:
/// when (max - min + 1) exceeds lookup_array_threshold it returns two empty vectors instead
/// of a result, and it is up to the caller to handle that case separately (for example by
/// falling back to numpy.unique(), which is quite fast thanks to vectorized sorts in 2.x).
///
/// NOTE(review): the range is computed in int64_t, so this assumes T is an integer type
/// narrower than 64 bits -- confirm before instantiating with (u)int64_t.
///
/// @param values Pointer to the first of n elements; not dereferenced when n == 0.
/// @param n      Number of elements in values.
/// @return A tuple (unique values in ascending order, matching counts). Both vectors are
///         empty when n == 0 or when the value range is too wide for the lookup array.
template <typename T>
std::tuple<std::vector<T>, std::vector<uint64_t>> unique_with_counts(const T* values, size_t n)
{
    std::vector<T> unique_values;
    std::vector<uint64_t> counts;
    if (!n)
    {
        return {unique_values, counts};
    }

    // Find the value range to decide whether a lookup array is affordable.
    int64_t min_value = static_cast<int64_t>(values[0]);
    int64_t max_value = min_value;
    for (size_t i = 1; i < n; ++i)
    {
        const int64_t value = static_cast<int64_t>(values[i]);
        min_value = std::min(min_value, value);
        max_value = std::max(max_value, value);
    }

    const int64_t domain_size = max_value - min_value + 1;
    if (domain_size <= lookup_array_threshold)
    {
        // One counter per possible value, indexed by the offset from min_value.
        std::vector<uint64_t> raw_counts(static_cast<size_t>(domain_size));
        for (size_t i = 0; i < n; ++i)
        {
            raw_counts[static_cast<size_t>(static_cast<int64_t>(values[i]) - min_value)]++;
        }

        // Walk the histogram itself (not the still-empty output vector) so that every
        // populated slot is emitted; ascending index order yields ascending values.
        for (size_t i = 0; i < raw_counts.size(); ++i)
        {
            if (raw_counts[i])
            {
                unique_values.push_back(static_cast<T>(static_cast<int64_t>(i) + min_value));
                counts.push_back(raw_counts[i]);
            }
        }
    }

    return {std::move(unique_values), std::move(counts)};
}

inline void read_bits_from_end_of_bitstream(const uint64_t* bitstream,
int64_t& source_bit_position,
Expand Down Expand Up @@ -123,7 +167,6 @@ EncodedData ans_encode_t(const T* signal,
}

// Map lookups can be a bottleneck, so we use a lookup array if the number of symbols is "small"
constexpr int lookup_array_threshold = 4096;
const bool use_lookup_array = (max_symbol - min_symbol + 1) <= lookup_array_threshold;
std::vector<size_t> symbol_index_lookup_array;
if (use_lookup_array)
Expand Down
30 changes: 29 additions & 1 deletion simple_ans/encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,45 @@
from ._simple_ans import (
ans_encode_int16 as _ans_encode_int16,
ans_decode_int16 as _ans_decode_int16,
ans_unique_int16 as _ans_unique_int16,
ans_encode_int32 as _ans_encode_int32,
ans_decode_int32 as _ans_decode_int32,
ans_unique_int32 as _ans_unique_int32,
ans_encode_uint16 as _ans_encode_uint16,
ans_decode_uint16 as _ans_decode_uint16,
ans_unique_uint16 as _ans_unique_uint16,
ans_encode_uint32 as _ans_encode_uint32,
ans_decode_uint32 as _ans_decode_uint32,
ans_unique_uint32 as _ans_unique_uint32,
ans_encode_uint8 as _ans_encode_uint8,
ans_decode_uint8 as _ans_decode_uint8,
ans_unique_uint8 as _ans_unique_uint8,
)


def _ans_unique(arr: np.ndarray):
    """Return the unique values of ``arr`` and their counts.

    The array's dtype selects one of the native ``ans_unique_*`` bindings.
    If that binding returns no values, the result is recomputed with
    ``np.unique(arr, return_counts=True)``.

    Raises:
        TypeError: if the dtype is not int16, int32, uint8, uint16 or uint32.
    """
    native_by_dtype = (
        (np.uint8, _ans_unique_uint8),
        (np.uint16, _ans_unique_uint16),
        (np.uint32, _ans_unique_uint32),
        (np.int16, _ans_unique_int16),
        (np.int32, _ans_unique_int32),
    )

    # The supported dtypes are mutually exclusive, so the probing order is irrelevant.
    for candidate, native_unique in native_by_dtype:
        if arr.dtype == candidate:
            vals, counts = native_unique(arr)
            break
    else:
        raise TypeError("Invalid numpy type")

    assert len(vals) == len(counts)

    # An empty native result triggers the numpy fallback.
    if len(vals) == 0:
        vals, counts = np.unique(arr, return_counts=True)

    return vals, counts


def ans_encode(signal: np.ndarray, *, index_size: Union[int, None] = None, verbose=False) -> EncodedSignal:
"""Encode a signal using Asymmetric Numeral Systems (ANS).
Expand All @@ -36,7 +64,7 @@ def ans_encode(signal: np.ndarray, *, index_size: Union[int, None] = None, verbo
assert signal.ndim == 1, "Input signal must be a 1D array"

signal_length = len(signal)
vals, counts = np.unique(signal, return_counts=True)
vals, counts = _ans_unique(signal)
vals = np.array(vals, dtype=signal.dtype)
probs = counts / np.sum(counts)

Expand Down
13 changes: 13 additions & 0 deletions simple_ans_bind.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cstdint>
#include <cstring> // for memcpy

#include "simple_ans/cpp/simple_ans.hpp"

namespace py = pybind11;
Expand All @@ -12,6 +14,17 @@ void bind_ans_functions(py::module& m, const char* type_suffix)
{
std::string ans_encode_name = std::string("ans_encode_") + type_suffix;
std::string ans_decode_name = std::string("ans_decode_") + type_suffix;
std::string ans_unique_name = std::string("ans_unique_") + type_suffix;

m.def(
ans_unique_name.c_str(),
[](py::array_t<T> signal)
{
py::buffer_info buf = signal.request();
return simple_ans::unique_with_counts(static_cast<const T*>(buf.ptr), buf.size);
},
"Get unique values and their counts",
py::arg("signal").noconvert());

m.def(
ans_encode_name.c_str(),
Expand Down

0 comments on commit cd78804

Please sign in to comment.