Merge pull request #12 from PyFR/feature/adapt.
Lay the groundwork for adaptive kernel generation.
FreddieWitherden authored Jan 19, 2023
2 parents 8cf403f + c1524d6 commit af42472
Showing 3 changed files with 32 additions and 18 deletions.
35 changes: 22 additions & 13 deletions gimmik/base.py
@@ -113,19 +113,28 @@ def kernels(self, dtype, kname='gimmik_mm', **kwargs):
         }
 
         # Incrementally generate and render the kernels
-        for name, exargs, exmeta in self._kernel_generators(dtype, dsize):
-            # Merge in the base arguments and metadata
-            args = baseargs | exargs
-            meta = basemeta | exmeta
-
-            # Render the kernel template
-            src = self._render_kernel(dtype, name, args)
-
-            # Post-process the metadata
-            meta['tplname'] = name
-            self._process_meta(meta)
-
-            yield (src, meta)
+        gen = self._kernel_generators(dtype, dsize, **kwargs)
+        try:
+            resp = None
+            while True:
+                # Generate the next kernel in the sequence
+                name, exargs, exmeta = gen.send(resp)
+
+                # Merge in the base arguments and metadata
+                args = baseargs | exargs
+                meta = basemeta | exmeta
+
+                # Render the kernel template
+                src = self._render_kernel(dtype, name, args)
+
+                # Post-process the metadata
+                meta['tplname'] = name
+                self._process_meta(meta)
+
+                # Yield the source and metadata and await a response
+                resp = yield (src, meta)
+        except StopIteration:
+            pass
 
     def _process_meta(self, meta):
         pass
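Note: `kernels` now drives `_kernel_generators` with `gen.send(resp)`, so whatever the consumer sends back after each `(src, meta)` pair is forwarded to the backend-specific generator. A minimal sketch of one way a caller could use this, assuming an existing `MatMul` instance `mm` and a caller-supplied `benchmark` callable (both the feedback value and its interpretation are assumptions; this commit only forwards the response):

```python
def fastest_kernel(mm, dtype, benchmark, **kwargs):
    # Drive the send-based protocol introduced in base.py: render each
    # candidate, time it, and feed the timing back to the generator
    best = None

    gen = mm.kernels(dtype, **kwargs)
    try:
        resp = None
        while True:
            # The first send must be None to start the generator
            src, meta = gen.send(resp)

            # Caller-side benchmarking (hypothetical); resp is forwarded
            # to the backend's _kernel_generators on the next iteration
            resp = benchmark(src, meta)
            if best is None or resp < best[0]:
                best = (resp, src, meta)
    except StopIteration:
        pass

    return best
```

Plain iteration (`for src, meta in mm.kernels(dtype)`) continues to work, since iterating a generator is equivalent to sending `None` each time.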
2 changes: 1 addition & 1 deletion gimmik/cuda.py
@@ -8,7 +8,7 @@ class CUDAMatMul(MatMul):
     basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0,
                 'dynamic_shared': 0}
 
-    def _kernel_generators(self, dtype, dsize):
+    def _kernel_generators(self, dtype, dsize, *, compute_capability=None):
         # B loading, C streaming kernel
         yield ('cstream', {}, {})
 
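The only functional change here is the new keyword-only `compute_capability` parameter, which the `**kwargs` forwarding in base.py now delivers to the CUDA generator; nothing in this commit consumes it yet. A hedged sketch of the caller side, assuming `mm` is an existing `CUDAMatMul` instance and `compile_and_register` is a caller-supplied helper (the capability value is illustrative):

```python
# Forward the target device's compute capability to the generators.
# As of this commit the value is accepted but not yet acted upon.
cc = (8, 0)  # e.g. queried from the CUDA driver by the caller

for src, meta in mm.kernels('double', compute_capability=cc):
    # meta merges basemeta (block, width, shared, dynamic_shared) with
    # any per-kernel overrides plus the rendered template name
    compile_and_register(src, meta)  # hypothetical caller-side helper
```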
13 changes: 9 additions & 4 deletions gimmik/opencl.py
@@ -7,7 +7,9 @@ class OpenCLMatMul(MatMul):
     platform = 'opencl'
     basemeta = {'local_work_size': None, 'local_mem_size': 0, 'width': 1}
 
-    def _kernel_generators(self, dtype, dsize):
+    def _kernel_generators(self, dtype, dsize, *, local_mem_size=None):
+        max_local_mem = local_mem_size or 1024**3
+
         # B loading, C streaming kernel
         yield ('cstream', {}, {})
 
@@ -19,14 +21,16 @@ def _kernel_generators(self, dtype, dsize):
         args = {'msplit': ms, 'blockx': blkx, 'bsz': bsz}
         meta = {'local_work_size': (blkx, ms),
                 'local_mem_size': 2*blkx*bsz*dsize}
-        yield ('bstream-msplit', args, meta)
+        if meta['local_mem_size'] < max_local_mem:
+            yield ('bstream-msplit', args, meta)
 
         # Two-way k-split B loading, C streaming kernel
         ks, csz, blkx = 2, 32, 64
         args = {'ksplit': ks, 'csz': csz, 'blockx': blkx}
         meta = {'local_work_size': (blkx, ks),
                 'local_mem_size': (ks - 1)*csz*blkx*dsize}
-        yield ('cstream-ksplit', args, meta)
+        if meta['local_mem_size'] < max_local_mem:
+            yield ('cstream-ksplit', args, meta)
 
         # At single precision also consider vectorized kernels
         if (dtype == 'float' and
@@ -42,7 +46,8 @@ def _kernel_generators(self, dtype, dsize):
                     'blockx': blkx, 'bsz': bsz}
             meta = {'local_work_size': (blkx, ms),
                     'local_mem_size': 2*blkx*bsz*dsize, 'width': 2}
-            yield ('bstream-msplit', args, meta)
+            if meta['local_mem_size'] < max_local_mem:
+                yield ('bstream-msplit', args, meta)
 
     def _process_meta(self, meta):
         if self.n is not None:
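The new `local_mem_size` keyword lets a caller cap the candidates by the device's local memory (e.g. the value reported by CL_DEVICE_LOCAL_MEM_SIZE); when it is omitted, the 1 GiB fallback keeps every previously generated kernel. A worked example of the guard using the values visible in the cstream-ksplit hunk above, with dsize = 8 for double precision (the 8 KiB device limit is illustrative):

```python
# Local memory required by the cstream-ksplit candidate above
ks, csz, blkx, dsize = 2, 32, 64, 8
required = (ks - 1)*csz*blkx*dsize      # 16384 bytes = 16 KiB

# No limit supplied: the 1 GiB fallback keeps the candidate
print(required < 1024**3)               # True  -> kernel is yielded

# A device with only 8 KiB of local memory would now skip it
print(required < 8192)                  # False -> kernel is dropped
```

On the caller side the limit passes straight through the `**kwargs` of `kernels()`, e.g. `mm.kernels(dtype, local_mem_size=dev_local_mem)`.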
