0.1

@proger proger released this 10 Jan 10:50
· 36 commits to main since this release

This package implements the fastest first-order parallel associative scan on the GPU for both the forward and backward passes.

The scan efficiently solves first-order recurrences of the form x[t] = gate[t] * x[t-1] + token[t], common in state space models and linear RNNs.
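To make the recurrence concrete, here is a minimal pure-Python sketch, not the package's code. It assumes an initial state of zero, and the names `scan_reference`, `combine`, and `scan_associative` are hypothetical. The recurrence can be parallelized because each step is an affine map x -> gate * x + token, and composing affine maps is associative.

```python
from itertools import accumulate


def scan_reference(gates, tokens, x0=0.0):
    """Sequential evaluation of x[t] = gate[t] * x[t-1] + token[t]."""
    x = x0
    out = []
    for g, t in zip(gates, tokens):
        x = g * x + t
        out.append(x)
    return out


def combine(left, right):
    """Compose two affine steps; `left` is applied first, then `right`.

    Each step (g, t) represents the map x -> g * x + t.
    """
    g1, t1 = left
    g2, t2 = right
    return (g1 * g2, g2 * t1 + t2)


def scan_associative(gates, tokens):
    """Same result as scan_reference with x0 = 0, expressed via `combine`.

    accumulate runs sequentially here; because `combine` is associative,
    a parallel scan may group the same operations in any order.
    """
    prefixes = accumulate(zip(gates, tokens), combine)
    return [t for _, t in prefixes]


gates = [0.5, 0.5, 0.5]
tokens = [1.0, 1.0, 1.0]
print(scan_reference(gates, tokens))    # [1.0, 1.5, 1.75]
print(scan_associative(gates, tokens))  # [1.0, 1.5, 1.75]
```

The second component of each accumulated pair is the state x[t], while the first is the running product of gates, which is what lets a later element absorb an earlier partial result.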

The accelerated_scan.warp C++ CUDA kernel uses a chunked processing algorithm that leverages the fastest GPU communication primitives available at each level of the hierarchy: warp shuffles within warps of 32 threads, and shared memory (SRAM) between warps within a thread block. Each sequence along the channel dimension is confined to a single thread block.
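The chunking idea can be illustrated on the CPU. The sketch below is a sequential emulation in plain Python, not the CUDA kernel: it assumes a zero initial state, and `chunked_scan` and `combine` are hypothetical names. Each chunk is first scanned locally, which is the role warp shuffles play within a warp, and the final state of the previous chunk is then folded in as a carry, which is the role of the exchange through shared memory.

```python
def combine(left, right):
    """Compose affine steps x -> g * x + t; `left` is applied first."""
    g1, t1 = left
    g2, t2 = right
    return (g1 * g2, g2 * t1 + t2)


def chunked_scan(gates, tokens, chunk=32):
    """Scan x[t] = gate[t] * x[t-1] + token[t] chunk by chunk, with x[-1] = 0."""
    out = []
    carry = 0.0  # state at the end of the previous chunk
    for start in range(0, len(gates), chunk):
        # Local inclusive scan: prefixes relative to the chunk start.
        acc = None
        local = []
        for pair in zip(gates[start:start + chunk], tokens[start:start + chunk]):
            acc = pair if acc is None else combine(acc, pair)
            local.append(acc)
        # Fold in the carry: each prefix (g_prod, t_acc) maps the incoming
        # state to the state at that position.
        for g_prod, t_acc in local:
            out.append(g_prod * carry + t_acc)
        carry = out[-1]
    return out


gates = [0.5] * 5
tokens = [1.0] * 5
print(chunked_scan(gates, tokens, chunk=2))
# [1.0, 1.5, 1.75, 1.875, 1.9375]
```

On a GPU the local scans of all chunks run concurrently and only the carries need to be communicated between them; this emulation processes chunks one after another purely to show that the decomposition gives the same result as the plain recurrence.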

The derivation of Chunked Scan was used to extend the tree-level Blelloch algorithm to the block level.

A similar implementation is available in accelerated_scan.triton, built on Triton's tl.associative_scan primitive. It requires Triton 2.2 for its enable_fp_fusion flag.

[Figure: bench]