diff --git a/gimmik/__init__.py b/gimmik/__init__.py
index fc7f015..b32ebdc 100644
--- a/gimmik/__init__.py
+++ b/gimmik/__init__.py
@@ -6,6 +6,7 @@
 from gimmik.cuda import CUDAMatMul
 from gimmik.ispc import ISPCMatMul
 from gimmik.hip import HIPMatMul
+from gimmik.metal import MetalMatMul
 from gimmik.opencl import OpenCLMatMul
diff --git a/gimmik/copenmp.py b/gimmik/copenmp.py
deleted file mode 100644
index 09f3124..0000000
--- a/gimmik/copenmp.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from gimmik.base import MatMul
-
-
-class COpenMPMatMul(MatMul):
-    platform = 'c-openmp'
-    basemeta = {}
-
-    def _kernel_generators(self, dtype, dsize):
-        yield ('cstream', {}, {})
diff --git a/gimmik/kernels/metal/base.mako b/gimmik/kernels/metal/base.mako
new file mode 100644
index 0000000..f889892
--- /dev/null
+++ b/gimmik/kernels/metal/base.mako
@@ -0,0 +1,16 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+% if dtype.endswith('4'):
+static inline ${dtype} make_zero()
+{ return ${dtype}(0, 0, 0, 0); }
+% elif dtype.endswith('2'):
+static inline ${dtype} make_zero()
+{ return ${dtype}(0, 0); }
+% else:
+static inline ${dtype} make_zero()
+{ return 0; }
+% endif
+
+${next.body()}
diff --git a/gimmik/kernels/metal/bstream-msplit.mako b/gimmik/kernels/metal/bstream-msplit.mako
new file mode 100644
index 0000000..f95b9f2
--- /dev/null
+++ b/gimmik/kernels/metal/bstream-msplit.mako
@@ -0,0 +1,93 @@
+<%inherit file='base'/>
+
+<%
+mx = partition(A, into=msplit, by='rows')
+bchunks = chunk(bix, bsz)
+%>
+
+kernel void
+% if n is None:
+${kname}(constant int& n_,
+         device ${dtype}* b, constant int& ldb_,
+         device ${dtype}* c, constant int& ldc_,
+         uint2 tpig [[thread_position_in_grid]],
+         uint2 tpitg [[thread_position_in_threadgroup]])
+{
+    const int n = ((n_ + ${width} - 1) / ${width}) * ${width};
+    const int ldb = ldb_ / ${width};
+    const int ldc = ldc_ / ${width};
+% else:
+${kname}(device const ${dtype}* b, device ${dtype}* c,
+         uint2 tpig [[thread_position_in_grid]],
+         uint2 tpitg [[thread_position_in_threadgroup]])
+{
+    const int n = ${-(-n // width)};
+    const int ldb = ${ldb // width};
+    const int ldc = ${ldc // width};
+% endif
+    const int i = tpig.x;
+
+    ${dtype} bv, csub[${-(-m // msplit)}];
+    threadgroup ${dtype} bsub[2][${bsz}][${blockx}];
+
+## Fill the initial shared memory block
+% for cid in range(msplit):
+    if (i < n && tpitg.y == ${cid})
+    {
+    % for kx in bchunks[0]:
+    % if loop.index % msplit == cid:
+        bsub[0][${loop.index}][tpitg.x] = b[i + ${kx}*ldb];
+    % endif
+    % endfor
+    }
+% endfor
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+## Iterate over each row-chunk of B
+% for bb in range(len(bchunks)):
+    ## Iterate over each row-chunk of C
+    % for cid, mcx in enumerate(mx):
+    if (i < n && tpitg.y == ${cid})
+    {
+        ## Start filling the next shared memory block
+        % if not loop.parent.last:
+        % for kx in bchunks[bb + 1]:
+        % if loop.index % msplit == cid:
+        bsub[${(bb + 1) % 2}][${loop.index}][tpitg.x] = b[i + ${kx}*ldb];
+        % endif
+        % endfor
+        % endif
+        ## Accumulate our dot products
+        % for kx in bchunks[bb]:
+        bv = bsub[${bb % 2}][${loop.index}][tpitg.x];
+        % for j, jx in enumerate(A[mcx, kx]):
+        % if jx != 0 and kx == afix[mcx[j]]:
+        csub[${j}] = ${jx}*bv;
+        % elif jx != 0:
+        csub[${j}] += ${jx}*bv;
+        % endif
+        ## If we're done with this dot product then store to global
+        % if kx == alix[mcx[j]] and beta == 0:
+        c[i + ${mcx[j]}*ldc] = csub[${j}];
+        % elif kx == alix[mcx[j]] and beta == 1:
+        c[i + ${mcx[j]}*ldc] += csub[${j}];
+        % elif kx == alix[mcx[j]]:
+        c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc];
+        % endif
+        % endfor
+        % endfor
+        ## Handle rows of A which are all zero
+        % if loop.parent.last:
+        % for j, jx in enumerate(afix):
+        % if jx == -1 and j % msplit == cid and beta == 0:
+        c[i + ${j}*ldc] = make_zero();
+        % elif jx == -1 and j % msplit == cid and beta != 1:
+        c[i + ${j}*ldc] *= ${beta};
+        % endif
+        % endfor
+        % endif
+    }
+    % endfor
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+% endfor
+}
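
Both B-streaming templates (bstream-msplit.mako above and bstream.mako below) drive their register bookkeeping off the afix/alix arrays from the template context: csub[j] is first written when kx == afix[j], flushed to C when kx == alix[j], and rows flagged with -1 are swept up by the all-zero epilogue. The sketch below spells out the semantics these templates appear to assume; the helper name and implementation are illustrative inferences from their use here, not GiMMiK's actual code.

```python
import numpy as np

# Assumed meaning of the afix/alix template variables: the first/last
# column index at which each row of A is non-zero, with -1 marking an
# all-zero row.  Illustrative sketch only.
def first_last_nonzero(A):
    afix = np.full(A.shape[0], -1)
    alix = np.full(A.shape[0], -1)
    for j, row in enumerate(A):
        nz = np.flatnonzero(row)
        if nz.size:
            afix[j], alix[j] = nz[0], nz[-1]
    return afix, alix

# e.g. for A = np.array([[0, 2, 0], [0, 0, 0], [1, 0, 3]]) this gives
#   afix = [1, -1, 0] and alix = [1, -1, 2]
```
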
diff --git a/gimmik/kernels/metal/bstream.mako b/gimmik/kernels/metal/bstream.mako
new file mode 100644
index 0000000..5ff34ea
--- /dev/null
+++ b/gimmik/kernels/metal/bstream.mako
@@ -0,0 +1,55 @@
+<%inherit file='base'/>
+
+kernel void
+% if n is None:
+${kname}(constant int& n_,
+         device ${dtype}* b, constant int& ldb_,
+         device ${dtype}* c, constant int& ldc_,
+         uint i [[thread_position_in_grid]])
+{
+    const int n = ((n_ + ${width} - 1) / ${width}) * ${width};
+    const int ldb = ldb_ / ${width};
+    const int ldc = ldc_ / ${width};
+% else:
+${kname}(device const ${dtype}* b, device ${dtype}* c,
+         uint i [[thread_position_in_grid]])
+{
+    const int n = ${-(-n // width)};
+    const int ldb = ${ldb // width};
+    const int ldc = ${ldc // width};
+% endif
+
+    if (i < n)
+    {
+        ${dtype} bv, csub[${m}];
+
+## Iterate through the used rows of B
+% for kx in bix:
+        bv = b[i + ${kx}*ldb];
+    % for j, jx in enumerate(A[:, kx]):
+    % if jx != 0 and kx == afix[j]:
+        csub[${j}] = ${jx}*bv;
+    % elif jx != 0:
+        csub[${j}] += ${jx}*bv;
+    % endif
+    ##
+    % if kx == alix[j] and beta == 0:
+        c[i + ${j}*ldc] = csub[${j}];
+    % elif kx == alix[j] and beta == 1:
+        c[i + ${j}*ldc] += csub[${j}];
+    % elif kx == alix[j]:
+        c[i + ${j}*ldc] = csub[${j}] + ${beta}*c[i + ${j}*ldc];
+    % endif
+    % endfor
+% endfor
+
+## Handle rows of A which are all zero
+% for j, jx in enumerate(afix):
+    % if jx == -1 and beta == 0:
+        c[i + ${j}*ldc] = make_zero();
+    % elif jx == -1 and beta != 1:
+        c[i + ${j}*ldc] *= ${beta};
+    % endif
+% endfor
+    }
+}
diff --git a/gimmik/kernels/metal/cstream-ksplit.mako b/gimmik/kernels/metal/cstream-ksplit.mako
new file mode 100644
index 0000000..d6fb927
--- /dev/null
+++ b/gimmik/kernels/metal/cstream-ksplit.mako
@@ -0,0 +1,86 @@
+<%inherit file='base'/>
+
+<%
+kparts = partition(A, ksplit, by='cols')
+cchunks = chunk(range(m), csz)
+loaded = set()
+%>
+
+kernel void
+% if n is None:
+${kname}(constant int& n_,
+         device ${dtype}* b, constant int& ldb_,
+         device ${dtype}* c, constant int& ldc_,
+         uint2 tpig [[thread_position_in_grid]],
+         uint2 tpitg [[thread_position_in_threadgroup]])
+{
+    const int n = ((n_ + ${width} - 1) / ${width}) * ${width};
+    const int ldb = ldb_ / ${width};
+    const int ldc = ldc_ / ${width};
+% else:
+${kname}(device const ${dtype}* b, device ${dtype}* c,
+         uint2 tpig [[thread_position_in_grid]],
+         uint2 tpitg [[thread_position_in_threadgroup]])
+{
+    const int n = ${-(-n // width)};
+    const int ldb = ${ldb // width};
+    const int ldc = ${ldc // width};
+% endif
+    const int i = tpig.x;
+
+    ${dtype} cv[${-(-csz // ksplit)}], bv[${-(-k // ksplit)}], dotp;
+    threadgroup ${dtype} csub[${ksplit - 1}][${csz}][${blockx}];
+
+## Iterate over the row-partitions of C
+% for cchunk in cchunks:
+    ## Iterate over the row-partitions of B
+    % for bid, kbx in enumerate(kparts):
+    if (i < n && tpitg.y == ${bid})
+    {
+        ## Evaluate our partial dot products
+        % for j in cchunk:
+        ## Load in any missing parts of B
+        % for kx in kbx:
+        % if A[j, kx] != 0 and kx not in loaded:
+        bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %>
+        % endif
+        % endfor
+        % if (dotex := dot(lambda kx: f'bv[{kx}]', A[j, kbx])) != '0.0':
+        dotp = ${dotex};
+        % else:
+        dotp = make_zero();
+        % endif
+        ## Save to a register
+        % if loop.index % ksplit == bid:
+        cv[${loop.index // ksplit}] = dotp;
+        ## Save to shared memory
+        % else:
+        csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][tpitg.x] = dotp;
+        % endif
+        % endfor
+    }
+    % endfor
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    ## Iterate over the row-partitions of B
+    % for bid, kbx in enumerate(kparts):
+    if (i < n && tpitg.y == ${bid})
+    {
+        ## Sum and output the final set of dot products
+        % for j in cchunk:
+        % if loop.index % ksplit == bid:
+        dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][tpitg.x]'
+                                                          for i in range(ksplit - 1))};
+        % if beta == 0:
+        c[i + ${j}*ldc] = dotp;
+        % elif beta == 1:
+        c[i + ${j}*ldc] += dotp;
+        % else:
+        c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc];
+        % endif
+        % endif
+        % endfor
+    }
+    % endfor
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+% endfor
+}
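
The k-split kernel above staggers its reduction: each of the ksplit threadgroup rows forms partial dot products over its own slice of A's columns, row j of C is owned by the lane with j % ksplit == bid, and the non-owning lanes stage their partials in the csub threadgroup array until the barrier. A NumPy emulation of that data flow is below; the contiguous column split is only a stand-in for whatever partition(A, ksplit, by='cols') actually does.

```python
import numpy as np

def ksplit_rowwise(A, b, ksplit=2):
    # Each "lane" computes a partial dot product over its slice of k; the
    # owning lane (j % ksplit) then adds the other lanes' partials,
    # mirroring the cv-register / csub-threadgroup split above.
    kparts = np.array_split(np.arange(A.shape[1]), ksplit)  # assumed split
    partial = np.stack([A[:, kp] @ b[kp] for kp in kparts])
    c = np.empty(A.shape[0])
    for j in range(A.shape[0]):
        owner = j % ksplit
        others = [p for p in range(ksplit) if p != owner]
        c[j] = partial[owner, j] + partial[others, j].sum()
    return c

A, b = np.random.rand(8, 6), np.random.rand(6)
assert np.allclose(ksplit_rowwise(A, b), A @ b)
```
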
diff --git a/gimmik/kernels/metal/cstream.mako b/gimmik/kernels/metal/cstream.mako
new file mode 100644
index 0000000..ce77951
--- /dev/null
+++ b/gimmik/kernels/metal/cstream.mako
@@ -0,0 +1,42 @@
+<%inherit file='base'/>
+
+<% ksplit = 2 if m < 36 else 1 %>
+
+kernel void
+% if n is None:
+${kname}(constant int& n_,
+         device ${dtype}* b, constant int& ldb_,
+         device ${dtype}* c, constant int& ldc_,
+         uint i [[thread_position_in_grid]])
+{
+    const int n = ((n_ + ${width} - 1) / ${width}) * ${width};
+    const int ldb = ldb_ / ${width};
+    const int ldc = ldc_ / ${width};
+% else:
+${kname}(device const ${dtype}* b, device ${dtype}* c,
+         uint i [[thread_position_in_grid]])
+{
+    const int n = ${-(-n // width)};
+    const int ldb = ${ldb // width};
+    const int ldc = ${ldc // width};
+% endif
+    ${dtype} dotp;
+
+    if (i < n)
+    {
+% for j, jx in enumerate(A):
+    % if (dotex := dot(lambda kx: f'b[i + {kx}*ldb]', jx, maxsplit=ksplit)) != '0.0':
+        dotp = ${dotex};
+    % else:
+        dotp = make_zero();
+    % endif
+    % if beta == 0:
+        c[i + ${j}*ldc] = dotp;
+    % elif beta == 1 and dotex != '0.0':
+        c[i + ${j}*ldc] += dotp;
+    % else:
+        c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc];
+    % endif
+% endfor
+    }
+}
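
cstream.mako relies on the template-context dot() helper to emit one fully unrolled expression per row of C, with the string '0.0' signalling an all-zero row so make_zero() can be substituted. A rough, assumption-laden sketch of such a helper is below; the real GiMMiK helper may differ, and in particular maxsplit (which presumably breaks long sums into independent partial accumulators) is accepted but ignored here.

```python
# Hypothetical stand-in for the dot() helper used above: build an unrolled
# expression string for one row of A, skipping zero entries and returning
# '0.0' for an empty row.  Sketch only; not GiMMiK's implementation.
def dot(bref, row, maxsplit=1):
    terms = [f'{ax}*{bref(kx)}' for kx, ax in enumerate(row) if ax != 0]
    return ' + '.join(terms) if terms else '0.0'

# dot(lambda kx: f'b[i + {kx}*ldb]', [0, 2.5, 0, -1.0])
#   -> '2.5*b[i + 1*ldb] + -1.0*b[i + 3*ldb]'
```
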
diff --git a/gimmik/metal.py b/gimmik/metal.py
new file mode 100644
index 0000000..86043fb
--- /dev/null
+++ b/gimmik/metal.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+
+from gimmik.base import MatMul
+
+
+class MetalMatMul(MatMul):
+    platform = 'metal'
+    basemeta = {'threadgroup': (128, 1, 1), 'threadgroup_mem_size': 0,
+                'width': 1}
+
+    def _kernel_generators(self, dtype, dsize):
+        # B loading, C streaming kernel
+        yield ('cstream', {}, {})
+
+        # B streaming, C accumulation kernel
+        yield ('bstream', {}, {})
+
+        # Four-way m-split B streaming, C accumulation kernel
+        ms, bsz, blkx = 4, 16, 32
+        args = {'msplit': ms, 'blockx': blkx, 'bsz': bsz}
+        meta = {'threadgroup': (blkx, ms, 1),
+                'threadgroup_mem_size': 2*blkx*bsz*dsize}
+        yield ('bstream-msplit', args, meta)
+
+        # Four-way m-split B streaming, C accumulation kernel
+        ms, bsz, blkx = 4, 20, 32
+        args = {'msplit': ms, 'blockx': blkx, 'bsz': bsz}
+        meta = {'threadgroup': (blkx, ms, 1),
+                'threadgroup_mem_size': 2*blkx*bsz*dsize}
+        yield ('bstream-msplit', args, meta)
+
+        # Two-way k-split B loading, C streaming kernel
+        ks, csz, blkx = 2, 20, 32
+        args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
+        meta = {'threadgroup': (blkx, ks, 1),
+                'threadgroup_mem_size': (ks - 1)*csz*blkx*dsize}
+        yield ('cstream-ksplit', args, meta)
+
+        if self.aligne is not None and self.aligne % 2 == 0:
+            # Vector B loading, C streaming kernel
+            args = {'dtype': 'float2', 'width': 2}
+            meta = {'width': 2}
+            yield ('cstream', args, meta)
+
+            # Vector B streaming, C accumulation kernel
+            yield ('bstream', args, meta)
+
+            # Vector four-way m-split B streaming, C accumulation kernel
+            ms, bsz, blkx = 4, 16, 32
+            args = {'dtype': 'float2', 'width': 2, 'msplit': ms,
+                    'blockx': blkx, 'bsz': bsz}
+            meta = {'threadgroup': (blkx, ms, 1),
+                    'threadgroup_mem_size': 2*blkx*bsz*dsize, 'width': 2}
+            yield ('bstream-msplit', args, meta)
+
+    def _process_meta(self, meta):
+        if self.n is not None:
+            tg = meta['threadgroup']
+            meta['grid'] = (-(-self.n // meta['width']), tg[1], 1)
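
To make the launch metadata above concrete, here is the arithmetic for the four-way m-split variant with blkx, ms, bsz = 32, 4, 16 in fp32 (dsize = 4), dispatched with a fixed n through the width-2 float2 path; only quantities appearing in this diff are used.

```python
# Launch-metadata arithmetic from _kernel_generators/_process_meta above.
blkx, ms, bsz, dsize = 32, 4, 16, 4

threadgroup = (blkx, ms, 1)            # 32*4 = 128 threads per threadgroup
tg_mem = 2*blkx*bsz*dsize              # double-buffered bsub: 4096 bytes

n, width = 1000, 2                     # width 2 => the float2 kernel variants
grid = (-(-n // width), threadgroup[1], 1)

assert threadgroup == (32, 4, 1)
assert tg_mem == 4096
assert grid == (500, 4, 1)             # grid.x = ceil(n / width)
```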