-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from PyFR/feature/metal.
Add Metal support.
- Loading branch information
Showing
8 changed files
with
352 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#include <metal_stdlib>

using namespace metal;

## make_zero() returns an all-zero value of the kernel's working type.
## Vector types are recognised only by the trailing digit of the dtype name:
## '4' selects a four-component constructor, '2' a two-component constructor,
## and anything else is treated as a scalar.
static inline ${dtype} make_zero()
% if dtype.endswith('4'):
{ return ${dtype}(0, 0, 0, 0); }
% elif dtype.endswith('2'):
{ return ${dtype}(0, 0); }
% else:
{ return 0; }
% endif

## Body of the inheriting kernel template
${next.body()}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
<%inherit file='base'/>

## Row-split B streaming, C accumulation kernel.
##
## The rows of A are partitioned into msplit groups and each value of tpitg.y
## handles one group.  The used rows of B (bix) are processed in chunks of bsz
## rows which are staged through the threadgroup array bsub; bsub has two
## slots so that the chunk for iteration bb + 1 is loaded while the chunk for
## iteration bb is being consumed.  Every non-zero entry of A is baked into
## the generated source as a literal.
##
## Template inputs read here: A, afix, alix, bix, beta, m, n, ldb, ldc, width,
## dtype, kname, msplit, bsz, blockx, plus the helpers partition and chunk.
<%
mx = partition(A, into=msplit, by='rows')
bchunks = chunk(bix, bsz)
%>

kernel void
% if n is None:
${kname}(constant int& n_,
         device const ${dtype}* b, constant int& ldb_,
         device ${dtype}* c, constant int& ldc_,
         uint2 tpig [[thread_position_in_grid]],
         uint2 tpitg [[thread_position_in_threadgroup]])
{
    ## b and c are indexed in units of dtype, so the column bound has to be
    ## the number of dtype elements per row, i.e. ceil(n_ / width).  This is
    ## the same quantity the compile-time branch below bakes in.
    const int n = (n_ + ${width} - 1) / ${width};
    const int ldb = ldb_ / ${width};
    const int ldc = ldc_ / ${width};
% else:
${kname}(device const ${dtype}* b, device ${dtype}* c,
         uint2 tpig [[thread_position_in_grid]],
         uint2 tpitg [[thread_position_in_threadgroup]])
{
    const int n = ${-(-n // width)};
    const int ldb = ${ldb // width};
    const int ldc = ${ldc // width};
% endif
    const int i = tpig.x;

    ${dtype} bv, csub[${-(-m // msplit)}];
    threadgroup ${dtype} bsub[2][${bsz}][${blockx}];

## Fill the initial shared memory block; the rows of the chunk are dealt out
## round-robin between the msplit values of tpitg.y
% for cid in range(msplit):
    if (i < n && tpitg.y == ${cid})
    {
  % for kx in bchunks[0]:
    % if loop.index % msplit == cid:
        bsub[0][${loop.index}][tpitg.x] = b[i + ${kx}*ldb];
    % endif
  % endfor
    }
% endfor
    ## The barriers sit outside the bounds checks so every thread reaches them
    threadgroup_barrier(mem_flags::mem_threadgroup);

## Iterate over each row-chunk of B
% for bb in range(len(bchunks)):
  ## Iterate over each row-chunk of C
  % for cid, mcx in enumerate(mx):
    if (i < n && tpitg.y == ${cid})
    {
    ## Start filling the next shared memory block
    % if not loop.parent.last:
      % for kx in bchunks[bb + 1]:
        % if loop.index % msplit == cid:
        bsub[${(bb + 1) % 2}][${loop.index}][tpitg.x] = b[i + ${kx}*ldb];
        % endif
      % endfor
    % endif
    ## Accumulate our dot products
    % for kx in bchunks[bb]:
        bv = bsub[${bb % 2}][${loop.index}][tpitg.x];
      % for j, jx in enumerate(A[mcx, kx]):
        ## afix marks the row of B where an accumulator is first assigned
        % if jx != 0 and kx == afix[mcx[j]]:
        csub[${j}] = ${jx}*bv;
        % elif jx != 0:
        csub[${j}] += ${jx}*bv;
        % endif
        ## If we're done with this dot product then store to global
        % if kx == alix[mcx[j]] and beta == 0:
        c[i + ${mcx[j]}*ldc] = csub[${j}];
        % elif kx == alix[mcx[j]] and beta == 1:
        c[i + ${mcx[j]}*ldc] += csub[${j}];
        % elif kx == alix[mcx[j]]:
        c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc];
        % endif
      % endfor
    % endfor
    ## Handle rows of A which are all zero (afix == -1); only done once, on
    ## the final chunk of B
    % if loop.parent.last:
      % for j, jx in enumerate(afix):
        % if jx == -1 and j % msplit == cid and beta == 0:
        c[i + ${j}*ldc] = make_zero();
        % elif jx == -1 and j % msplit == cid and beta != 1:
        c[i + ${j}*ldc] *= ${beta};
        % endif
      % endfor
    % endif
    }
  % endfor
    threadgroup_barrier(mem_flags::mem_threadgroup);
% endfor
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
<%inherit file='base'/>

## B streaming, C accumulation kernel.
##
## Each thread owns one column index i.  It walks the used rows of B (bix)
## once, accumulating every row of the product in the local array csub, and
## writes a row of C as soon as its last contributing row of B (alix) has
## been consumed.  Every non-zero entry of A is baked into the generated
## source as a literal.
##
## Template inputs read here: A, afix, alix, bix, beta, m, n, ldb, ldc, width,
## dtype and kname.
kernel void
% if n is None:
${kname}(constant int& n_,
         device const ${dtype}* b, constant int& ldb_,
         device ${dtype}* c, constant int& ldc_,
         uint i [[thread_position_in_grid]])
{
    ## b and c are indexed in units of dtype, so the column bound has to be
    ## the number of dtype elements per row, i.e. ceil(n_ / width).  This is
    ## the same quantity the compile-time branch below bakes in.
    const int n = (n_ + ${width} - 1) / ${width};
    const int ldb = ldb_ / ${width};
    const int ldc = ldc_ / ${width};
% else:
${kname}(device const ${dtype}* b, device ${dtype}* c,
         uint i [[thread_position_in_grid]])
{
    const int n = ${-(-n // width)};
    const int ldb = ${ldb // width};
    const int ldc = ${ldc // width};
% endif

    if (i < n)
    {
        ${dtype} bv, csub[${m}];

## Iterate through the used rows of B
% for kx in bix:
        bv = b[i + ${kx}*ldb];
  % for j, jx in enumerate(A[:, kx]):
    ## afix marks the row of B where an accumulator is first assigned
    % if jx != 0 and kx == afix[j]:
        csub[${j}] = ${jx}*bv;
    % elif jx != 0:
        csub[${j}] += ${jx}*bv;
    % endif
    ## Store to global once the last contributing row of B has been seen
    % if kx == alix[j] and beta == 0:
        c[i + ${j}*ldc] = csub[${j}];
    % elif kx == alix[j] and beta == 1:
        c[i + ${j}*ldc] += csub[${j}];
    % elif kx == alix[j]:
        c[i + ${j}*ldc] = csub[${j}] + ${beta}*c[i + ${j}*ldc];
    % endif
  % endfor
% endfor

## Handle rows of A which are all zero (afix == -1)
% for j, jx in enumerate(afix):
  % if jx == -1 and beta == 0:
        c[i + ${j}*ldc] = make_zero();
  % elif jx == -1 and beta != 1:
        c[i + ${j}*ldc] *= ${beta};
  % endif
% endfor
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
<%inherit file='base'/>

## Column-split ("k-split") B loading, C streaming kernel.
##
## The columns of A are partitioned into ksplit groups and each value of
## tpitg.y evaluates the partial dot products for one group.  The rows of C
## are processed in chunks of csz rows.  Within a chunk, row number r is
## finalised by the threads with tpitg.y == r % ksplit: they keep their own
## partial result in cv and read the remaining ksplit - 1 partials from the
## threadgroup array csub after a barrier.
##
## The Python set `loaded` records, at render time, which rows of B have
## already been emitted as loads into bv so each is only loaded once.
##
## Template inputs read here: A, beta, m, k, n, ldb, ldc, width, dtype, kname,
## ksplit, csz, blockx, plus the helpers partition, chunk and dot.
<%
kparts = partition(A, ksplit, by='cols')
cchunks = chunk(range(m), csz)
loaded = set()
%>

kernel void
% if n is None:
${kname}(constant int& n_,
         device const ${dtype}* b, constant int& ldb_,
         device ${dtype}* c, constant int& ldc_,
         uint2 tpig [[thread_position_in_grid]],
         uint2 tpitg [[thread_position_in_threadgroup]])
{
    ## b and c are indexed in units of dtype, so the column bound has to be
    ## the number of dtype elements per row, i.e. ceil(n_ / width).  This is
    ## the same quantity the compile-time branch below bakes in.
    const int n = (n_ + ${width} - 1) / ${width};
    const int ldb = ldb_ / ${width};
    const int ldc = ldc_ / ${width};
% else:
${kname}(device const ${dtype}* b, device ${dtype}* c,
         uint2 tpig [[thread_position_in_grid]],
         uint2 tpitg [[thread_position_in_threadgroup]])
{
    const int n = ${-(-n // width)};
    const int ldb = ${ldb // width};
    const int ldc = ${ldc // width};
% endif
    const int i = tpig.x;

    ${dtype} cv[${-(-csz // ksplit)}], bv[${-(-k // ksplit)}], dotp;
    threadgroup ${dtype} csub[${ksplit - 1}][${csz}][${blockx}];

## Iterate over the row-partitions of C
% for cchunk in cchunks:
  ## Iterate over the column-partitions of A (row-partitions of B)
  % for bid, kbx in enumerate(kparts):
    if (i < n && tpitg.y == ${bid})
    {
    ## Evaluate our partial dot products
    % for j in cchunk:
      ## Load in any missing parts of B
      % for kx in kbx:
        % if A[j, kx] != 0 and kx not in loaded:
        bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %>
        % endif
      % endfor
      % if (dotex := dot(lambda kx: f'bv[{kx}]', A[j, kbx])) != '0.0':
        dotp = ${dotex};
      % else:
        dotp = make_zero();
      % endif
      ## Save to a register
      % if loop.index % ksplit == bid:
        cv[${loop.index // ksplit}] = dotp;
      ## Save to shared memory
      % else:
        csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][tpitg.x] = dotp;
      % endif
    % endfor
    }
  % endfor
    ## The barriers sit outside the bounds checks so every thread reaches them
    threadgroup_barrier(mem_flags::mem_threadgroup);
  ## Iterate over the column-partitions of B
  % for bid, kbx in enumerate(kparts):
    if (i < n && tpitg.y == ${bid})
    {
    ## Sum and output the final set of dot products
    % for j in cchunk:
      % if loop.index % ksplit == bid:
        dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][tpitg.x]'
                                                     for i in range(ksplit - 1))};
        % if beta == 0:
        c[i + ${j}*ldc] = dotp;
        % elif beta == 1:
        c[i + ${j}*ldc] += dotp;
        % else:
        c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc];
        % endif
      % endif
    % endfor
    }
  % endfor
    threadgroup_barrier(mem_flags::mem_threadgroup);
% endfor
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
<%inherit file='base'/>

## B loading, C streaming kernel.
##
## Each thread owns one column index i and, for every row of A in turn, emits
## one fully unrolled dot product (built by the dot helper) whose operands
## are read directly from b, then writes the result straight to c.
##
## ksplit is forwarded to dot as maxsplit; it is 2 when m < 36 and 1
## otherwise.
##
## Template inputs read here: A, beta, m, n, ldb, ldc, width, dtype, kname,
## plus the helper dot.
<% ksplit = 2 if m < 36 else 1 %>

kernel void
% if n is None:
${kname}(constant int& n_,
         device const ${dtype}* b, constant int& ldb_,
         device ${dtype}* c, constant int& ldc_,
         uint i [[thread_position_in_grid]])
{
    ## b and c are indexed in units of dtype, so the column bound has to be
    ## the number of dtype elements per row, i.e. ceil(n_ / width).  This is
    ## the same quantity the compile-time branch below bakes in.
    const int n = (n_ + ${width} - 1) / ${width};
    const int ldb = ldb_ / ${width};
    const int ldc = ldc_ / ${width};
% else:
${kname}(device const ${dtype}* b, device ${dtype}* c,
         uint i [[thread_position_in_grid]])
{
    const int n = ${-(-n // width)};
    const int ldb = ${ldb // width};
    const int ldc = ${ldc // width};
% endif
    ${dtype} dotp;

    if (i < n)
    {
% for j, jx in enumerate(A):
  ## dot returns the literal string '0.0' when the row has no non-zeros
  % if (dotex := dot(lambda kx: f'b[i + {kx}*ldb]', jx, maxsplit=ksplit)) != '0.0':
        dotp = ${dotex};
  % else:
        dotp = make_zero();
  % endif
  % if beta == 0:
        c[i + ${j}*ldc] = dotp;
  % elif beta == 1 and dotex != '0.0':
        c[i + ${j}*ldc] += dotp;
  % else:
        c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc];
  % endif
% endfor
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from gimmik.base import MatMul | ||
|
||
|
||
class MetalMatMul(MatMul):
    """Kernel provider for the Metal platform.

    Enumerates the kernel template variants to generate and fills in the
    Metal-specific launch metadata (threadgroup shape, threadgroup memory
    size, vector width and, when ``n`` is known, the grid).
    """

    platform = 'metal'

    # Defaults for any metadata key a generator does not set explicitly
    basemeta = {'threadgroup': (128, 1, 1), 'threadgroup_mem_size': 0,
                'width': 1}

    def _kernel_generators(self, dtype, dsize):
        """Yield ``(template name, template args, metadata)`` triples.

        ``dsize`` is used as the per-element byte size when computing
        ``threadgroup_mem_size``.  ``dtype`` is not read here; the vector
        variants always request ``'float2'``.
        """
        # B loading, C streaming kernel
        yield ('cstream', {}, {})

        # B streaming, C accumulation kernel
        yield ('bstream', {}, {})

        # Four-way m-split B streaming, C accumulation kernel
        ms, bsz, blkx = 4, 16, 32
        args = {'msplit': ms, 'blockx': blkx, 'bsz': bsz}
        meta = {'threadgroup': (blkx, ms, 1),
                'threadgroup_mem_size': 2*blkx*bsz*dsize}
        yield ('bstream-msplit', args, meta)

        # Four-way m-split B streaming, C accumulation kernel
        ms, bsz, blkx = 4, 20, 32
        args = {'msplit': ms, 'blockx': blkx, 'bsz': bsz}
        meta = {'threadgroup': (blkx, ms, 1),
                'threadgroup_mem_size': 2*blkx*bsz*dsize}
        yield ('bstream-msplit', args, meta)

        # Two-way k-split B loading, C streaming kernel
        ks, csz, blkx = 2, 20, 32
        args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
        meta = {'threadgroup': (blkx, ks, 1),
                'threadgroup_mem_size': (ks - 1)*csz*blkx*dsize}
        yield ('cstream-ksplit', args, meta)

        # The two-wide vector kernels are only offered for even alignments
        if self.aligne is not None and self.aligne % 2 == 0:
            # Each yield gets its own dicts: _process_meta mutates metadata
            # in place, so the variants must not share mutable state.

            # Vector B loading, C streaming kernel
            yield ('cstream', {'dtype': 'float2', 'width': 2}, {'width': 2})

            # Vector B streaming, C accumulation kernel
            yield ('bstream', {'dtype': 'float2', 'width': 2}, {'width': 2})

            # Vector four-way m-split B streaming, C accumulation kernel
            ms, bsz, blkx = 4, 16, 32
            args = {'dtype': 'float2', 'width': 2, 'msplit': ms,
                    'blockx': blkx, 'bsz': bsz}
            # The template's staging buffer is declared with the kernel
            # dtype, here float2, so each of the 2*blkx*bsz slots holds two
            # values.  NOTE(review): assumes dsize is the byte size of the
            # scalar type passed by the caller - confirm.
            meta = {'threadgroup': (blkx, ms, 1),
                    'threadgroup_mem_size': 2*blkx*bsz*2*dsize, 'width': 2}
            yield ('bstream-msplit', args, meta)

    def _process_meta(self, meta):
        """Add the launch grid to ``meta`` in place when ``n`` is known.

        The grid has ``ceil(n / width)`` threads along x, the threadgroup's
        y extent along y and 1 along z.  When ``self.n`` is ``None`` the
        metadata is left untouched.
        """
        if self.n is not None:
            tg = meta['threadgroup']
            meta['grid'] = (-(-self.n // meta['width']), tg[1], 1)