Generic packing algorithms from size N to M #284

Open
vayuda opened this issue May 26, 2024 · 5 comments
Labels
enhancement New feature or request

Comments

vayuda (Collaborator) commented May 26, 2024

In order to support sub-byte dtypes for quantization, I (and many others) believe it is better to pack these smaller dtypes into existing PyTorch dtypes in order to reduce memory-bandwidth contention, at the cost of a bit of extra computation. Here is a preliminary algorithm in PyTorch for doing this. It supports many types of conversions, as seen in the tests.
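
For reference, a minimal sketch of this kind of pack/unpack (assuming 2-bit values stored one per uint8 element and packed 4 per byte, which matches the shift/mask pattern in the compiled kernel below; the helper names here are made up, not the actual API):

```
import torch

def pack_uint2(x: torch.Tensor) -> torch.Tensor:
    # Pack groups of 4 two-bit values (one per uint8 element) into single uint8 bytes.
    assert x.dtype == torch.uint8 and x.shape[-1] % 4 == 0
    x = x.reshape(*x.shape[:-1], -1, 4)
    return (x[..., 0] << 6) | (x[..., 1] << 4) | (x[..., 2] << 2) | x[..., 3]

def unpack_uint2(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_uint2: expand each byte back into 4 two-bit values along the last dim.
    out = torch.stack([(packed >> s) & 3 for s in (6, 4, 2, 0)], dim=-1)
    return out.reshape(*packed.shape[:-1], -1)
```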

Inspecting the compiled Triton code looks promising: it launches only one kernel and allocates only one buffer. Here is a snippet:

```
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 4
    x1 = (xindex // 4)
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x1), tmp4 & xmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.full([1], 6, tl.uint8)
    tmp7 = tmp5 >> tmp6
    tmp8 = tl.full([1], 3, tl.uint8)
    tmp9 = tmp7 & tmp8
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp4, tmp9, tmp10)
    tmp12 = tmp0 >= tmp3
    tmp13 = tl.full([1], 2, tl.int64)
    tmp14 = tmp0 < tmp13
    tmp15 = tmp12 & tmp14
    tmp16 = tl.load(in_ptr0 + (x1), tmp15 & xmask, eviction_policy='evict_last', other=0.0)
    tmp17 = tl.full([1], 4, tl.uint8)
    tmp18 = tmp16 >> tmp17
    tmp19 = tmp18 & tmp8
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tmp0 >= tmp13
    tmp23 = tl.full([1], 3, tl.int64)
    tmp24 = tmp0 < tmp23
    tmp25 = tmp22 & tmp24
    tmp26 = tl.load(in_ptr0 + (x1), tmp25 & xmask, eviction_policy='evict_last', other=0.0)
    tmp27 = tl.full([1], 2, tl.uint8)
    tmp28 = tmp26 >> tmp27
    tmp29 = tmp28 & tmp8
    tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
    tmp31 = tl.where(tmp25, tmp29, tmp30)
    tmp32 = tmp0 >= tmp23
    tmp33 = tl.full([1], 4, tl.int64)
    tmp34 = tmp0 < tmp33
    tmp35 = tl.load(in_ptr0 + (x1), tmp32 & xmask, eviction_policy='evict_last', other=0.0)
    tmp36 = tl.full([1], 0, tl.uint8)
    tmp37 = tmp35 >> tmp36
    tmp38 = tmp37 & tmp8
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp32, tmp38, tmp39)
    tmp41 = tl.where(tmp25, tmp31, tmp40)
    tmp42 = tl.where(tmp15, tmp21, tmp41)
    tmp43 = tl.where(tmp4, tmp11, tmp42)
    tl.store(out_ptr0 + (x2), tmp43, xmask)
''', device_str='cuda')

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    s2 = arg2_1
    assert_size_stride(arg3_1, (s0, s1, s2), (s1*s2, s2, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s0, s1, s2, 4), (4*s1*s2, 4*s2, 4, 1), torch.uint8)
        # Source Nodes: [stack], Original ATen: [aten.stack]
        triton_poi_fused_stack_0_xnumel = 4*s0*s1*s2
        stream0 = get_raw_stream(0)
        triton_poi_fused_stack_0.run(arg3_1, buf0, triton_poi_fused_stack_0_xnumel, grid=grid(triton_poi_fused_stack_0_xnumel), stream=stream0)
        del arg3_1
    return (reinterpret_tensor(buf0, (s0, s1, 4*s2), (4*s1*s2, 4*s2, 1), 0), )
```

msaroufim (Member) commented May 26, 2024

This is quite cool and I've been thinking along similar lines.

I think what we could do to ship this is to merge the pack and unpack functions into quantization/ and then have tests to ensure the codegen is efficient. In practice, you can test that a single kernel is launched by using torch.compile(..., fullgraph=True) in your tests. I'm not sure how we can validate that a single buffer is used, but perhaps @eellison does.
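
A rough test along those lines (reusing the shift/mask logic from the sketch above; the test name and shapes are just illustrative):

```
import torch

def _unpack_then_pack(packed: torch.Tensor) -> torch.Tensor:
    # Same 2-bit shift/mask logic as above, written as one traceable function.
    vals = torch.stack([(packed >> s) & 3 for s in (6, 4, 2, 0)], dim=-1)
    return (vals[..., 0] << 6) | (vals[..., 1] << 4) | (vals[..., 2] << 2) | vals[..., 3]

def test_roundtrip_compiles_fullgraph():
    packed = torch.randint(0, 256, (8, 16), dtype=torch.uint8)
    # fullgraph=True makes torch.compile error out on any graph break.
    compiled = torch.compile(_unpack_then_pack, fullgraph=True)
    torch.testing.assert_close(compiled(packed), packed)
```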

And this can be a baseline for smaller dtypes. I'd be specific somewhere in the function names or docs that this is padding-based, because conceptually I can imagine another alternative where, instead of wasting space, you pack 8 uint3 values into 3 uint8 values as a more general algorithm. But that's finicky enough that we don't have to worry about it right now.
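
For concreteness, that denser alternative exploits 8 × 3 bits = 24 bits = 3 bytes. A hypothetical sketch just to illustrate the arithmetic, not something being proposed here:

```
import torch

def pack_uint3_dense(x: torch.Tensor) -> torch.Tensor:
    # Densely pack groups of 8 three-bit values (24 bits) into 3 uint8 bytes, wasting no space.
    assert x.dtype == torch.uint8 and x.shape[-1] % 8 == 0
    groups = x.reshape(*x.shape[:-1], -1, 8).to(torch.int32)
    word = torch.zeros_like(groups[..., 0])
    for i in range(8):
        # Value i occupies bits [21 - 3*i, 23 - 3*i] of a 24-bit word.
        word = word | (groups[..., i] << (21 - 3 * i))
    out = torch.stack([(word >> s) & 0xFF for s in (16, 8, 0)], dim=-1)
    return out.to(torch.uint8).reshape(*x.shape[:-1], -1)
```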

msaroufim added the enhancement (New feature or request) label on May 26, 2024
msaroufim (Member) commented

Also, @mobicham had been asking us about standardizing the bitpacking logic, so I'm curious about his thoughts too.

mobicham (Collaborator) commented May 27, 2024

Thanks @vayuda, very interesting, thanks for sharing!

Normally, bit-unpacking is almost never used in isolation; it's either fused into a dequant kernel or a low-bit matmul kernel. There are two main things to consider while designing bitpacking logic:

@msaroufim do you know by any chance what kind of bitpacking logic is used in tiny_gemm?
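
To illustrate the kind of fusion described here, an unpack-plus-dequantize step can be written as one traceable function so torch.compile can fuse the bit manipulation with the dequant math into a single kernel (the scale/zero-point handling below is made up for the example):

```
import torch

def dequant_uint2(packed: torch.Tensor, scale: torch.Tensor, zero_point: float) -> torch.Tensor:
    # Unpack 4 two-bit values per byte, then dequantize, all in one compiled region.
    vals = torch.stack([(packed >> s) & 3 for s in (6, 4, 2, 0)], dim=-1)
    vals = vals.reshape(*packed.shape[:-1], -1).to(torch.float16)
    return (vals - zero_point) * scale

dequant_uint2_compiled = torch.compile(dequant_uint2, fullgraph=True)
```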

vayuda (Collaborator, Author) commented May 27, 2024

@mobicham Thanks for the input. The interleaved accessing is interesting, though I'm not really sure what it means to take full advantage of tensor cores. I think this is something we can iterate on. For now I can create a version that does row-wise pack/unpack.
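
A rough idea of what such a row-wise variant could look like (hypothetical signature, packing along a caller-chosen dimension):

```
import torch

def pack_uint2_along(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Pack 4 two-bit values into one byte along `dim`, leaving the other dims untouched.
    x = x.movedim(dim, -1)
    assert x.dtype == torch.uint8 and x.shape[-1] % 4 == 0
    x = x.reshape(*x.shape[:-1], -1, 4)
    packed = (x[..., 0] << 6) | (x[..., 1] << 4) | (x[..., 2] << 2) | x[..., 3]
    return packed.movedim(-1, dim)
```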

As per @msaroufim's suggestions, I will place these functions in the api file and write appropriate tests.

vadimkantorov commented Jun 27, 2024

Even in relative isolation (without op support), bit packing/unpacking is still useful for saving memory footprint when storing bool tensors / masks / bitsets:

But of course, more op support is needed for compressed bool tensors / bittensors / bitsets as well...

(Similarly, for some other use cases where the bottleneck is actually memory efficiency and some speed overhead can be tolerated, packing/unpacking is still useful even when it is not fused into ops.)
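
As a concrete illustration of that use case, a sketch of packing a bool mask into bits (an assumed helper, roughly what numpy.packbits does):

```
import torch

def pack_bool_mask(mask: torch.Tensor) -> torch.Tensor:
    # Pack a bool tensor into uint8 bits (~8x smaller), zero-padding the tail.
    assert mask.dtype == torch.bool
    flat = mask.flatten().to(torch.uint8)
    pad = (-flat.numel()) % 8
    if pad:
        flat = torch.nn.functional.pad(flat, (0, pad))
    # Little-endian bit order within each byte.
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
    return (flat.reshape(-1, 8) * weights).sum(dim=-1).to(torch.uint8)
```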

yanbing-j pushed a commit to yanbing-j/ao that referenced this issue Dec 9, 2024