Skip to content

Commit

Permalink
Merge pull request #13 from PyFR/feature/metal.
Browse files Browse the repository at this point in the history
Add Metal support.
  • Loading branch information
FreddieWitherden committed Apr 4, 2023
2 parents f3e7d90 + 9743b0f commit 2cfaeb5
Show file tree
Hide file tree
Showing 8 changed files with 352 additions and 11 deletions.
1 change: 1 addition & 0 deletions gimmik/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from gimmik.cuda import CUDAMatMul
from gimmik.ispc import ISPCMatMul
from gimmik.hip import HIPMatMul
from gimmik.metal import MetalMatMul
from gimmik.opencl import OpenCLMatMul


Expand Down
11 changes: 0 additions & 11 deletions gimmik/copenmp.py

This file was deleted.

16 changes: 16 additions & 0 deletions gimmik/kernels/metal/base.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <metal_stdlib>

using namespace metal;

## Common prologue inherited by every Metal kernel template.  It provides
## make_zero(), which returns a zero of the kernel data type; the initialiser
## is chosen from the trailing component count in the type name (a name
## ending in 4 or 2 is treated as a four- or two-component vector, anything
## else as a scalar).
<%
    if dtype.endswith('4'):
        zero = f'{dtype}(0, 0, 0, 0)'
    elif dtype.endswith('2'):
        zero = f'{dtype}(0, 0)'
    else:
        zero = '0'
%>
static inline ${dtype} make_zero()
{ return ${zero}; }

${next.body()}
93 changes: 93 additions & 0 deletions gimmik/kernels/metal/bstream-msplit.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
<%inherit file='base'/>

<%
    mx = partition(A, into=msplit, by='rows')
    bchunks = chunk(bix, bsz)
%>

## B streaming, C accumulation kernel with the rows of A split msplit ways.
## Each value of tpitg.y handles one partition of rows (mx[tpitg.y]).  The
## used rows of B (bix) are processed in chunks of bsz rows which are staged
## through the two-slot threadgroup array bsub: while chunk bb is consumed
## from slot bb % 2 the next chunk is loaded into slot (bb + 1) % 2.
## NOTE(review): afix[j] / alix[j] are presumably the first / last non-zero
## column of row j of A, with -1 marking an all-zero row; confirm against the
## code which renders this template.
kernel void
% if n is None:
${kname}(constant int& n_,
         device const ${dtype}* b, constant int& ldb_,
         device ${dtype}* c, constant int& ldc_,
         uint2 tpig [[thread_position_in_grid]],
         uint2 tpitg [[thread_position_in_threadgroup]])
{
    ## i indexes whole dtype elements, so the bound, like ldb and ldc, has to
    ## be expressed in units of width; this mirrors the compile-time branch
    const int n = (n_ + ${width} - 1) / ${width};
    const int ldb = ldb_ / ${width};
    const int ldc = ldc_ / ${width};
% else:
${kname}(device const ${dtype}* b, device ${dtype}* c,
         uint2 tpig [[thread_position_in_grid]],
         uint2 tpitg [[thread_position_in_threadgroup]])
{
    const int n = ${-(-n // width)};
    const int ldb = ${ldb // width};
    const int ldc = ${ldc // width};
% endif
    const int i = tpig.x;

    ${dtype} bv, csub[${-(-m // msplit)}];
    threadgroup ${dtype} bsub[2][${bsz}][${blockx}];

    ## Fill the initial shared memory block
% for cid in range(msplit):
    if (i < n && tpitg.y == ${cid})
    {
  % for kx in bchunks[0]:
  % if loop.index % msplit == cid:
        bsub[0][${loop.index}][tpitg.x] = b[i + ${kx}*ldb];
  % endif
  % endfor
    }
% endfor
    threadgroup_barrier(mem_flags::mem_threadgroup);

    ## Iterate over each row-chunk of B
% for bb in range(len(bchunks)):
    ## Iterate over each row-chunk of C
  % for cid, mcx in enumerate(mx):
    if (i < n && tpitg.y == ${cid})
    {
    ## Start filling the next shared memory block
    % if not loop.parent.last:
      % for kx in bchunks[bb + 1]:
      % if loop.index % msplit == cid:
        bsub[${(bb + 1) % 2}][${loop.index}][tpitg.x] = b[i + ${kx}*ldb];
      % endif
      % endfor
    % endif
    ## Accumulate our dot products
    % for kx in bchunks[bb]:
        bv = bsub[${bb % 2}][${loop.index}][tpitg.x];
      % for j, jx in enumerate(A[mcx, kx]):
      % if jx != 0 and kx == afix[mcx[j]]:
        csub[${j}] = ${jx}*bv;
      % elif jx != 0:
        csub[${j}] += ${jx}*bv;
      % endif
      ## If we're done with this dot product then store to global
      % if kx == alix[mcx[j]] and beta == 0:
        c[i + ${mcx[j]}*ldc] = csub[${j}];
      % elif kx == alix[mcx[j]] and beta == 1:
        c[i + ${mcx[j]}*ldc] += csub[${j}];
      % elif kx == alix[mcx[j]]:
        c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc];
      % endif
      % endfor
    % endfor
    ## Handle rows of A which are all zero
    % if loop.parent.last:
      % for j, jx in enumerate(afix):
      % if jx == -1 and j % msplit == cid and beta == 0:
        c[i + ${j}*ldc] = make_zero();
      % elif jx == -1 and j % msplit == cid and beta != 1:
        c[i + ${j}*ldc] *= ${beta};
      % endif
      % endfor
    % endif
    }
  % endfor
    threadgroup_barrier(mem_flags::mem_threadgroup);
% endfor
}
55 changes: 55 additions & 0 deletions gimmik/kernels/metal/bstream.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
<%inherit file='base'/>

## B streaming, C accumulation kernel.  Each thread owns one dtype-wide
## column i: it walks the used rows of B (bix), accumulates every row of C in
## the private array csub and writes a row back as soon as its last
## contributing row of B has been consumed.
## NOTE(review): afix[j] / alix[j] are presumably the first / last non-zero
## column of row j of A, with -1 marking an all-zero row; confirm against the
## code which renders this template.
kernel void
% if n is None:
${kname}(constant int& n_,
         device const ${dtype}* b, constant int& ldb_,
         device ${dtype}* c, constant int& ldc_,
         uint i [[thread_position_in_grid]])
{
    ## i indexes whole dtype elements, so the bound, like ldb and ldc, has to
    ## be expressed in units of width; this mirrors the compile-time branch
    const int n = (n_ + ${width} - 1) / ${width};
    const int ldb = ldb_ / ${width};
    const int ldc = ldc_ / ${width};
% else:
${kname}(device const ${dtype}* b, device ${dtype}* c,
         uint i [[thread_position_in_grid]])
{
    const int n = ${-(-n // width)};
    const int ldb = ${ldb // width};
    const int ldc = ${ldc // width};
% endif

    if (i < n)
    {
        ${dtype} bv, csub[${m}];

        ## Iterate through the used rows of B
    % for kx in bix:
        bv = b[i + ${kx}*ldb];
      % for j, jx in enumerate(A[:, kx]):
      ## The first contribution initialises the accumulator, later ones add
      % if jx != 0 and kx == afix[j]:
        csub[${j}] = ${jx}*bv;
      % elif jx != 0:
        csub[${j}] += ${jx}*bv;
      % endif
      ## Once the final contribution is in, store the row according to beta
      % if kx == alix[j] and beta == 0:
        c[i + ${j}*ldc] = csub[${j}];
      % elif kx == alix[j] and beta == 1:
        c[i + ${j}*ldc] += csub[${j}];
      % elif kx == alix[j]:
        c[i + ${j}*ldc] = csub[${j}] + ${beta}*c[i + ${j}*ldc];
      % endif
      % endfor
    % endfor

        ## Handle rows of A which are all zero
    % for j, jx in enumerate(afix):
    % if jx == -1 and beta == 0:
        c[i + ${j}*ldc] = make_zero();
    % elif jx == -1 and beta != 1:
        c[i + ${j}*ldc] *= ${beta};
    % endif
    % endfor
    }
}
86 changes: 86 additions & 0 deletions gimmik/kernels/metal/cstream-ksplit.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
<%inherit file='base'/>

<%
    kparts = partition(A, ksplit, by='cols')
    cchunks = chunk(range(m), csz)
    loaded = set()
%>

## B loading, C streaming kernel with the columns of A (and hence the rows
## of B) split ksplit ways.  Each value of tpitg.y handles one partition
## (kparts[tpitg.y]) and evaluates a partial dot product for every row of the
## current chunk of C.  A thread keeps the partial sums for the rows it will
## finally write in the private array cv and publishes the rest through the
## threadgroup array csub; after a barrier the owning thread adds the
## partials together and updates C.
kernel void
% if n is None:
${kname}(constant int& n_,
         device const ${dtype}* b, constant int& ldb_,
         device ${dtype}* c, constant int& ldc_,
         uint2 tpig [[thread_position_in_grid]],
         uint2 tpitg [[thread_position_in_threadgroup]])
{
    ## i indexes whole dtype elements, so the bound, like ldb and ldc, has to
    ## be expressed in units of width; this mirrors the compile-time branch
    const int n = (n_ + ${width} - 1) / ${width};
    const int ldb = ldb_ / ${width};
    const int ldc = ldc_ / ${width};
% else:
${kname}(device const ${dtype}* b, device ${dtype}* c,
         uint2 tpig [[thread_position_in_grid]],
         uint2 tpitg [[thread_position_in_threadgroup]])
{
    const int n = ${-(-n // width)};
    const int ldb = ${ldb // width};
    const int ldc = ${ldc // width};
% endif
    const int i = tpig.x;

    ${dtype} cv[${-(-csz // ksplit)}], bv[${-(-k // ksplit)}], dotp;
    threadgroup ${dtype} csub[${ksplit - 1}][${csz}][${blockx}];

    ## Iterate over the row-partitions of C
% for cchunk in cchunks:
    ## Iterate over the row-partitions of B
  % for bid, kbx in enumerate(kparts):
    if (i < n && tpitg.y == ${bid})
    {
    ## Evaluate our partial dot products
    % for j in cchunk:
      ## Load in any missing parts of B
      % for kx in kbx:
      % if A[j, kx] != 0 and kx not in loaded:
        bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %>
      % endif
      % endfor
      % if (dotex := dot(lambda kx: f'bv[{kx}]', A[j, kbx])) != '0.0':
        dotp = ${dotex};
      % else:
        dotp = make_zero();
      % endif
      ## Save to a register
      % if loop.index % ksplit == bid:
        cv[${loop.index // ksplit}] = dotp;
      ## Save to shared memory
      % else:
        csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][tpitg.x] = dotp;
      % endif
    % endfor
    }
  % endfor
    threadgroup_barrier(mem_flags::mem_threadgroup);
    ## Iterate over the row-partitions of B
  % for bid, kbx in enumerate(kparts):
    if (i < n && tpitg.y == ${bid})
    {
    ## Sum and output the final set of dot products
    % for j in cchunk:
    % if loop.index % ksplit == bid:
        dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][tpitg.x]'
                                                     for i in range(ksplit - 1))};
      % if beta == 0:
        c[i + ${j}*ldc] = dotp;
      % elif beta == 1:
        c[i + ${j}*ldc] += dotp;
      % else:
        c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc];
      % endif
    % endif
    % endfor
    }
  % endfor
    threadgroup_barrier(mem_flags::mem_threadgroup);
% endfor
}
42 changes: 42 additions & 0 deletions gimmik/kernels/metal/cstream.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<%inherit file='base'/>

## Smaller matrices have their dot product expressions split in two
<% ksplit = 2 if m < 36 else 1 %>

## B loading, C streaming kernel.  Each thread owns one dtype-wide column i
## and, for every row of A in turn, evaluates the fully unrolled dot product
## against B (read straight from device memory) and updates the matching row
## of C according to beta.
kernel void
% if n is None:
${kname}(constant int& n_,
         device const ${dtype}* b, constant int& ldb_,
         device ${dtype}* c, constant int& ldc_,
         uint i [[thread_position_in_grid]])
{
    ## i indexes whole dtype elements, so the bound, like ldb and ldc, has to
    ## be expressed in units of width; this mirrors the compile-time branch
    const int n = (n_ + ${width} - 1) / ${width};
    const int ldb = ldb_ / ${width};
    const int ldc = ldc_ / ${width};
% else:
${kname}(device const ${dtype}* b, device ${dtype}* c,
         uint i [[thread_position_in_grid]])
{
    const int n = ${-(-n // width)};
    const int ldb = ${ldb // width};
    const int ldc = ${ldc // width};
% endif
    ${dtype} dotp;

    if (i < n)
    {
    % for j, jx in enumerate(A):
    ## dot() yields the literal '0.0' when the row has no non-zero entries
    % if (dotex := dot(lambda kx: f'b[i + {kx}*ldb]', jx, maxsplit=ksplit)) != '0.0':
        dotp = ${dotex};
    % else:
        dotp = make_zero();
    % endif
    % if beta == 0:
        c[i + ${j}*ldc] = dotp;
    % elif beta == 1 and dotex != '0.0':
        c[i + ${j}*ldc] += dotp;
    % else:
        c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc];
    % endif
    % endfor
    }
}
59 changes: 59 additions & 0 deletions gimmik/metal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-

from gimmik.base import MatMul


class MetalMatMul(MatMul):
    """Generator of candidate matrix multiplication kernels for Metal.

    Each candidate is described by the name of a template under
    ``kernels/metal``, the extra arguments used to render it, and the
    launch metadata which differs from :attr:`basemeta`.
    """

    platform = 'metal'

    # Launch metadata assumed for any key a kernel does not override
    basemeta = {'threadgroup': (128, 1, 1), 'threadgroup_mem_size': 0,
                'width': 1}

    def _kernel_generators(self, dtype, dsize):
        """Yield ``(template, args, meta)`` for every candidate kernel.

        Args:
            dtype: Name of the scalar data type.  Not read here; the
                vector kernels override it with ``'float2'`` in *args*.
            dsize: Size of one *dtype* element, used to compute the
                threadgroup memory footprint (presumably in bytes; confirm
                against the caller).

        Yields:
            Tuples of the template name, a dict of template arguments and
            a dict of launch metadata overrides.
        """
        # B loading, C streaming kernel
        yield ('cstream', {}, {})

        # B streaming, C accumulation kernel
        yield ('bstream', {}, {})

        # Four-way m-split B streaming, C accumulation kernels, once for
        # each B chunk size; the template declares bsub[2][bsz][blockx]
        for ms, bsz, blkx in [(4, 16, 32), (4, 20, 32)]:
            args = {'msplit': ms, 'blockx': blkx, 'bsz': bsz}
            meta = {'threadgroup': (blkx, ms, 1),
                    'threadgroup_mem_size': 2*blkx*bsz*dsize}
            yield ('bstream-msplit', args, meta)

        # Two-way k-split B loading, C streaming kernel; the template
        # declares csub[ksplit - 1][csz][blockx]
        ks, csz, blkx = 2, 20, 32
        args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
        meta = {'threadgroup': (blkx, ks, 1),
                'threadgroup_mem_size': (ks - 1)*csz*blkx*dsize}
        yield ('cstream-ksplit', args, meta)

        # Vector kernels are only considered for an even alignment
        if self.aligne is not None and self.aligne % 2 == 0:
            width = 2

            # Vector B loading, C streaming kernel
            args = {'dtype': 'float2', 'width': width}
            meta = {'width': width}
            yield ('cstream', args, meta)

            # Vector B streaming, C accumulation kernel
            yield ('bstream', args, meta)

            # Vector four-way m-split B streaming, C accumulation kernel.
            # bsub holds float2 elements here, so each entry occupies
            # width scalars and the footprint scales accordingly.
            ms, bsz, blkx = 4, 16, 32
            args = {'dtype': 'float2', 'width': width, 'msplit': ms,
                    'blockx': blkx, 'bsz': bsz}
            meta = {'threadgroup': (blkx, ms, 1),
                    'threadgroup_mem_size': 2*blkx*bsz*width*dsize,
                    'width': width}
            yield ('bstream-msplit', args, meta)

    def _process_meta(self, meta):
        """Add the launch grid to *meta* when ``n`` is known.

        The grid spans ``ceil(n / width)`` in x, the threadgroup's own y
        extent in y and ``1`` in z.  *meta* is modified in place; nothing
        is added when ``self.n`` is ``None``.
        """
        if self.n is not None:
            tg = meta['threadgroup']
            meta['grid'] = (-(-self.n // meta['width']), tg[1], 1)

0 comments on commit 2cfaeb5

Please sign in to comment.