Byte calculation in analysis_wo_ec missing Cholesky channels multiplier #11

Downchuck opened this issue Aug 26, 2024 · 6 comments

@Downchuck

In gaussianimage_cholesky.py we have a six-bit quantizer with three channels:

            self.cholesky_quantizer = UniformQuantizer(signed=False, bits=6, learned=True, num_channels=3)

In the analysis_wo_ec method, the number of channels is missing from the bit measurement:

        total_bits += quant_cholesky_elements.size * 6 #cholesky bits 

and

        cholesky_bits += len(quant_cholesky_elements) * 6

I believe those should be multiplied by 3.

@Xinjie-Q
Owner

I have carefully checked the analysis. The calculation of total_bits is correct: we convert the tensor quant_cholesky_elements to an ndarray with quant_cholesky_elements = quant_cholesky_elements.cpu().numpy(), then use ndarray.size to obtain the number of elements, which already counts every element across all channels. Thus, our compression results are correct.

However, the analysis of cholesky_bits is indeed wrong: since len(quant_cholesky_elements) equals the number of gaussians, it should be multiplied by 3. We have fixed this problem. Thanks for pointing out our error in analyzing the cholesky_bits.
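The distinction can be sketched as follows (a minimal sketch with hypothetical shapes; 1000 gaussians stands in for the real count):

```python
import numpy as np

# Hypothetical (N, 3) array of quantized cholesky values: N gaussians x 3 channels.
quant_cholesky_elements = np.zeros((1000, 3))

# ndarray.size counts every element across all channels, so no extra factor is needed.
total_bits = quant_cholesky_elements.size * 6

# len() counts only the rows (the number of gaussians), so the 3 channels must be
# multiplied in explicitly.
cholesky_bits = len(quant_cholesky_elements) * 6 * 3

assert total_bits == cholesky_bits == 18000
```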

@Downchuck
Author

Thanks for the clear code base and quick response!

I've reproduced the reported file size (total bytes) on a round trip, serializing and deserializing per the analysis_wo_ec method as a reference. My version added 5 bytes of overhead in my test, which is exactly right, since that's the size of the extra header I write.
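As a rough sanity check on the total, the expected byte count can be computed from the component sizes. This is a sketch with hypothetical counts; the constants assume N gaussians, 2 feature indices and 3 six-bit cholesky values per gaussian, two (1, 8, 3) float32 codebooks, 3 float32 scale and beta values, and float16 xyz pairs:

```python
def pack_bytesize(num_elements, bit_width):
    # Bytes needed to hold num_elements values of bit_width bits each.
    return (num_elements * bit_width + 7) // 8

def expected_size(n_gaussians, max_bit, max_header_bit):
    header = 2                                      # max_header_bit byte + max_bit byte
    header += pack_bytesize(2, max_header_bit)      # two packed length fields
    body = pack_bytesize(n_gaussians * 2, max_bit)  # feature_dc_index (N x 2)
    body += pack_bytesize(n_gaussians * 3, 6)       # quant_cholesky_elements (N x 3, 6 bits)
    body += 12 + 12                                 # scale + beta (3 float32 each)
    body += 96 * 2                                  # two (1, 8, 3) float32 codebooks
    body += n_gaussians * 2 * 2                     # xyz as float16 pairs
    return header + body
```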

Serialize:

        # simple serde:
        with open(self.log_dir / 'simple.bin', 'wb+') as f:
            xyz_bytes = encoding_dict["xyz"].detach().cpu().numpy().tobytes()
            feature_dc_index = encoding_dict["feature_dc_index"].int().detach().cpu().numpy().reshape(-1)
            quant_cholesky_elements = encoding_dict["quant_cholesky_elements"].int().detach().cpu().numpy().reshape(-1)

            scale_floats = self.gaussian_model.cholesky_quantizer.scale.detach().cpu().numpy().tobytes()
            beta_floats = self.gaussian_model.cholesky_quantizer.beta.detach().cpu().numpy().tobytes()

            layer_floats = bytearray()
            for _, layer in enumerate(self.gaussian_model.features_dc_quantizer.quantizer.layers):
                current_layer_floats = layer._codebook.embed.detach().cpu().numpy().tobytes()
                layer_floats.extend(current_layer_floats)

            max_bit = np.max(feature_dc_index).item().bit_length()

            header = [round(len(feature_dc_index)/2), round(len(quant_cholesky_elements)/3)]
            max_header_bit = max(header).bit_length()
            header_byte = max_header_bit.to_bytes(1, 'big')  # explicit args for Python < 3.11
            max_cholesky_bit = 6

            f.write(header_byte)
            f.write(max_bit.to_bytes(1, 'big'))  # explicit args for Python < 3.11
            packed_bits = bytearray(packbits.pack_bytesize(len(header), max_header_bit))
            packbits.pack_word(packed_bits, header, max_header_bit)
            f.write(packed_bits)

            packed_bits = bytearray(packbits.pack_bytesize(len(feature_dc_index), max_bit))
            packbits.pack_word(packed_bits, feature_dc_index, max_bit)
            f.write(packed_bits)

            packed_bits = bytearray(packbits.pack_bytesize(len(quant_cholesky_elements), max_cholesky_bit))
            packbits.pack_word(packed_bits, quant_cholesky_elements, max_cholesky_bit)
            f.write(packed_bits)
            f.write(scale_floats + beta_floats)
            f.write(layer_floats)
            f.write(xyz_bytes)

Deserialize:

        with open(self.log_dir / 'simple.bin', 'rb') as f:
            mm = memoryview(mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ))
            max_header_bit = mm[0]
            max_bit = mm[1]

            header_packet_size = packbits.pack_bytesize(2, max_header_bit)
            pos = header_packet_size + 2
            header, _ = packbits.unpack_bits(mm[2:pos], 2, max_header_bit)
            len_feature_half, len_cholesky_half = header

            next_pos = pos + packbits.pack_bytesize(len_feature_half*2, max_bit)
            feature_dc_index, _= packbits.unpack_bits(mm[pos:next_pos], len_feature_half*2, max_bit)
            pos = next_pos

            max_cholesky_bit = 6
            next_pos = pos + packbits.pack_bytesize(len_cholesky_half*3, max_cholesky_bit)
            quant_cholesky_elements, _ = packbits.unpack_bits(mm[pos:next_pos], len_cholesky_half*3, max_cholesky_bit)
            pos = next_pos

            next_pos = pos + 12
            scale_floats = mm[pos:next_pos]
            pos = next_pos

            next_pos = pos + 12
            beta_floats = mm[pos:next_pos]
            pos = next_pos

            next_pos = pos + 96
            layer_one_floats = mm[pos:next_pos]
            pos = next_pos

            next_pos = pos + 96
            layer_two_floats = mm[pos:next_pos]
            pos = next_pos

            estate_dict = self.gaussian_model.state_dict()
            estate_dict['cholesky_quantizer.scale'] = torch.frombuffer(scale_floats, dtype=torch.float32).cuda()
            estate_dict['cholesky_quantizer.beta'] = torch.frombuffer(beta_floats, dtype=torch.float32).cuda()

            self.gaussian_model.load_state_dict(estate_dict)

            restate_dict = self.gaussian_model.features_dc_quantizer.state_dict()
            restate_dict['quantizer.layers.0._codebook.embed'] = torch.frombuffer(layer_one_floats, dtype=torch.float32).cuda().reshape((1, 8, 3,))
            restate_dict['quantizer.layers.1._codebook.embed'] = torch.frombuffer(layer_two_floats, dtype=torch.float32).cuda().reshape((1, 8, 3,))
            self.gaussian_model.features_dc_quantizer.load_state_dict(restate_dict)

            xyz = mm[pos:]
            encoding_dict = {
                "feature_dc_index": torch.frombuffer(feature_dc_index, dtype=torch.uint8).cuda().int().reshape((-1,2,)),
                "quant_cholesky_elements": torch.frombuffer(quant_cholesky_elements, dtype=torch.uint8).cuda().int().reshape((-1,3,)),
                "xyz": torch.frombuffer(xyz, dtype=torch.float16).cuda().float().reshape((-1,2,))
            }

            transform = transforms.ToPILImage()
            out = self.gaussian_model.decompress_wo_ec(encoding_dict)
            out_img = out["render"].float()
            img = transform(out_img.squeeze(0))
            name = "repro.png"
            img.save(str(self.log_dir / name))

@Xinjie-Q Xinjie-Q reopened this Aug 29, 2024
@Downchuck
Author

Not meant to reopen -- the write confirmed the correct bit counts -- I just have an extra five bytes for the header I am using to stash the dynamic parameters.

@Xinjie-Q
Owner

Thank you for your contributions to accurately calculating the storage for GaussianImage. If you don't mind, I would like to integrate this code for storing files into GaussianImage. Alternatively, you could submit a pull request to do this.

@Downchuck
Author

Downchuck commented Sep 1, 2024

@Xinjie-Q - here's the code I used for packing bits at sub-byte widths, plus a generic version for bit widths over eight.

I was toying with packing across sections, by allowing bit_offset to be specified, but that really wasn't worth the effort.

import random
from typing import List

def pack_bytesize(num_elements, bit_width):
    return (num_elements * bit_width + 7) // 8

def pack_subbyte(byte_array: memoryview, int_view: bytearray, bit_width: int, bit_offset = 0, byte_index = 0):
    if bit_offset >= bit_width:
        raise ValueError("Subbyte bit offset should be less than the bit width")

    if bit_width > 8:
        raise ValueError("Subbyte width should not exceed 8")

    for value in int_view:
        if bit_width + bit_offset > 8:
            overflow = bit_width + bit_offset - 8
            byte_array[byte_index] |= value >> overflow
            byte_index += 1
            byte_array[byte_index] |= (value & ((1 << overflow) - 1)) << (8 - overflow)
        else:
            byte_array[byte_index] |= value << (8 - bit_offset - bit_width)
        
        bit_offset = (bit_offset + bit_width) % 8
        if bit_offset == 0:
            byte_index += 1

def unpack_subbyte(byte_view: memoryview, int_view: memoryview, bit_width: int, bit_offset=0, byte_index=0):
    if bit_width > 8:
        raise ValueError("Subbyte width should not exceed 8")
    mask = ((1 << bit_width) - 1)
    for i in range(len(int_view)):
        if bit_width + bit_offset > 8:
            # When the value spans across two bytes
            overflow = bit_width + bit_offset - 8
            int_view[i] = ((byte_view[byte_index] << overflow) & mask | (byte_view[byte_index + 1] >> (8 - overflow))) 
        else:
            # When the value fits within one byte
            shift_amount = 8 - bit_offset - bit_width
            int_view[i] = (byte_view[byte_index] >> shift_amount) & mask

        byte_index += (bit_offset + bit_width) // 8
        bit_offset = (bit_offset + bit_width) % 8


# Test cases and pack_word generated via GPT.

def pack_word(bit_array: memoryview, int_view: List[int], bit_width: int, bit_offset = 0) -> bytearray:
    byte_index = 0
    for value in int_view:
        bits_remaining = bit_width
        while bits_remaining > 0:
            current_byte_bits = min(8 - bit_offset, bits_remaining)
            mask = (1 << current_byte_bits) - 1
            value_part = (value >> (bits_remaining - current_byte_bits)) & mask
            
            bit_array[byte_index] |= value_part << (8 - bit_offset - current_byte_bits)
            
            bits_remaining -= current_byte_bits
            bit_offset = (bit_offset + current_byte_bits) % 8
            if bit_offset == 0:
                byte_index += 1
    
    return bit_array


def unpack_word(byte_view: bytes, int_array: List[int], bit_width: int, bit_offset = 0):
    byte_index = 0
    for i in range(len(int_array)):
        value = 0
        bits_remaining = bit_width
        while bits_remaining > 0:
            current_byte_bits = min(8 - bit_offset, bits_remaining)
            mask = (1 << current_byte_bits) - 1
            value_part = (byte_view[byte_index] >> (8 - bit_offset - current_byte_bits)) & mask
            value = (value << current_byte_bits) | value_part
            
            bits_remaining -= current_byte_bits
            bit_offset = (bit_offset + current_byte_bits) % 8
            if bit_offset == 0:
                byte_index += 1
        
        int_array[i] = value



def test_pack_unpack_subbyte():
    # meant to be np.array of appropriate np.int16 size.
    test_cases = [
        {'bit_width': 3, 'int_view': bytearray([5, 2, 7, 3])},        # bit_width 3: values fit within 3 bits
        {'bit_width': 5, 'int_view': bytearray([10, 15, 2, 18])},     # bit_width 5: values fit within 5 bits
        {'bit_width': 8, 'int_view': bytearray([255, 128, 64, 32])},  # bit_width 8: max value within a single byte
        {'bit_width': 11, 'int_view': [3, 2000, 2000]},               # bit_width 11: values span byte boundaries
        {'bit_width': 12, 'int_view': [4095, 1024, 512, 256]},        # bit_width 12: values exceed a single byte
        {'bit_width': 15, 'int_view': [32767, 16384, 8192, 4096]}     # bit_width 15: values span across bytes
    ]

    for i, case in enumerate(test_cases):
        bit_width = case['bit_width']
        int_view = case['int_view']
        
        if bit_width <= 8:
            # Use the pack_subbyte function for bit widths <= 8
            packed_bits = bytearray(pack_bytesize(len(int_view), bit_width))
            pack_subbyte(packed_bits, int_view, bit_width)
            unpacked_ints = bytearray(len(int_view))
            unpack_subbyte(memoryview(packed_bits), memoryview(unpacked_ints), bit_width)
        else:
            # Use the pack_word function for bit widths > 8
            packed_bits = bytearray(pack_bytesize(len(int_view), bit_width))
            pack_word(packed_bits, int_view, bit_width)
            unpacked_ints = [0] * len(int_view)
            unpack_word(packed_bits, unpacked_ints, bit_width)
        
        # Verify that the original and unpacked integers match
        assert unpacked_ints == int_view, f"Test case {i + 1} failed: {unpacked_ints} != {int_view} with bit width {bit_width}"
        print(f"Test case {i + 1} passed. Packed bits: {packed_bits.hex()} Unpacked ints: {list(unpacked_ints)}")

def pack_lists(buffers, bit_offset, packed_bytes):
    # packed_bytes = bytearray(pack_lists_bytesize(buffers))
    for int_array, bit_width in buffers:
        # pack is not intended for byte-aligned data.
        if bit_width % 8 == 0:
            raise ValueError("Wrong place for byte-aligned data")
        elif bit_width < 8:
            pack_subbyte(packed_bytes, int_array, bit_width)
        else:
            pack_word(packed_bytes, int_array, bit_width)
        # let's not bother with trying to bit pack heterogeneous fixed sections
        bit_offset += (bit_width * len(int_array)) % 8
    return bit_offset, packed_bytes

def pack_lists_bytesize(buffers):
    bytesize = 0
    for elements, bit_width in buffers:
        bytesize += (len(elements) * bit_width)
    return (bytesize + 7) // 8

def unpack_bits(packed_bytes, count, bit_width, bit_offset = 0):
    if bit_width % 8 == 0 and bit_offset % 8 == 0:
        return packed_bytes[bit_offset // 8:], bit_offset
    elif bit_width < 8:
        int_view = bytearray(count) # pack_bytesize(count, bit_width))
        unpack_subbyte(packed_bytes, int_view, bit_width, bit_offset)
        bit_offset += (bit_width * len(int_view) + bit_offset) % 8
        return int_view, bit_offset
    else:
        int_view = [0] * count
        unpack_word(packed_bytes, int_view, bit_width, bit_offset)
        bit_offset += (bit_width * len(int_view) + bit_offset) % 8
        return int_view, bit_offset

# Notes on copying bits.
import timeit
def benchmark_unpack():
    bit_width = 8
    int_view = bytearray([255, 128, 64, 32])
    packed_bits = bytearray(pack_bytesize(len(int_view), bit_width))

    run_a = timeit.timeit(lambda: pack_subbyte(packed_bits, int_view, bit_width), number=10000)
    run_b = timeit.timeit(lambda: pack_word(packed_bits, int_view, bit_width), number=10000)
    print([run_a, run_b ])
    unpacked_ints = bytearray(len(int_view))
    run_c = timeit.timeit(lambda: unpack_subbyte(memoryview(packed_bits), memoryview(unpacked_ints), bit_width), number=10000)
    unpacked_ints = [0] * len(int_view)
    run_d = timeit.timeit(lambda: unpack_word(packed_bits, unpacked_ints, bit_width), number=10000)
    print([run_c, run_d ])
    def cp():
        unpacked_ints[:] = packed_bits[:]
    print([timeit.timeit(cp, number=10000)])
    unpacked_ints = bytearray(len(int_view))
    print([timeit.timeit(cp, number=10000)])
    print([timeit.timeit(lambda : pack_lists([((int_view), bit_width,)], 0, (packed_bits)), number=10000)])


# benchmark_unpack()
# test_pack_unpack_subbyte()

@Downchuck Downchuck reopened this Sep 1, 2024
@Downchuck
Author

Here's a nice start for a PyTorch implementation I came across recently for packing to bit sizes 1, 2 and 4:
https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a

From their request:
pytorch/ao#292

That discussion thread shows the evolution in torch to uint1 - uint7:
https://dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833
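The basic idea behind that gist can be illustrated with a minimal NumPy sketch for the 4-bit case (two nibbles per byte, high nibble first; pack_uint4/unpack_uint4 are hypothetical names, not part of the linked code):

```python
import numpy as np

def pack_uint4(values: np.ndarray) -> np.ndarray:
    # Pack pairs of 4-bit values (0..15) into single bytes, high nibble first.
    assert values.size % 2 == 0
    v = values.astype(np.uint8).reshape(-1, 2)
    return ((v[:, 0] << 4) | v[:, 1]).astype(np.uint8)

def unpack_uint4(packed: np.ndarray) -> np.ndarray:
    # Split each byte back into its high and low nibbles, preserving order.
    high = packed >> 4
    low = packed & 0x0F
    return np.stack([high, low], axis=1).reshape(-1)

vals = np.array([1, 15, 7, 0], dtype=np.uint8)
assert np.array_equal(unpack_uint4(pack_uint4(vals)), vals)
```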

-Charles
