0.1
This package implements the fastest first-order parallel associative scan on the GPU for forward and backward.
The scan efficiently solves first-order recurrences of the form x[t] = gate[t] * x[t-1] + token[t], common in state space models and linear RNNs.
The accelerated_scan.warp C++ CUDA kernel uses a chunked processing algorithm that leverages the fastest GPU communication primitives available on each level of hierarchy: warp shuffles within warps of 32 threads and shared memory (SRAM) between warps within a thread block. One sequence per channel dimension is confined to one thread block.
The derivation of Chunked Scan has been used to extend tree-level Blelloch algorithm to block
A similar implementation is available in accelerated_scan.triton using a Triton's tl.associative_scan primitive. It requires Triton 2.2 for its enable_fp_fusion flag.