From c1524d6359d8a4376cb77e5753274faa9dd7c044 Mon Sep 17 00:00:00 2001 From: Freddie Witherden Date: Tue, 10 Jan 2023 12:26:54 -0600 Subject: [PATCH] Lay the groundwork for adaptive kernel generation. --- gimmik/base.py | 35 ++++++++++++++++++++++------------- gimmik/cuda.py | 2 +- gimmik/opencl.py | 13 +++++++++---- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/gimmik/base.py b/gimmik/base.py index 26539ca..f547afc 100644 --- a/gimmik/base.py +++ b/gimmik/base.py @@ -113,19 +113,28 @@ def kernels(self, dtype, kname='gimmik_mm', **kwargs): } # Incrementally generate and render the kernels - for name, exargs, exmeta in self._kernel_generators(dtype, dsize): - # Merge in the base arguments and metadata - args = baseargs | exargs - meta = basemeta | exmeta - - # Render the kernel template - src = self._render_kernel(dtype, name, args) - - # Post-process the metadata - meta['tplname'] = name - self._process_meta(meta) - - yield (src, meta) + gen = self._kernel_generators(dtype, dsize, **kwargs) + try: + resp = None + while True: + # Generate the next kernel in the sequence + name, exargs, exmeta = gen.send(resp) + + # Merge in the base arguments and metadata + args = baseargs | exargs + meta = basemeta | exmeta + + # Render the kernel template + src = self._render_kernel(dtype, name, args) + + # Post-process the metadata + meta['tplname'] = name + self._process_meta(meta) + + # Yield the source and metadata and await a response + resp = yield (src, meta) + except StopIteration: + pass def _process_meta(self, meta): pass diff --git a/gimmik/cuda.py b/gimmik/cuda.py index 74ebf95..b18c509 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -8,7 +8,7 @@ class CUDAMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - def _kernel_generators(self, dtype, dsize): + def _kernel_generators(self, dtype, dsize, *, compute_capability=None): # B loading, C streaming kernel yield ('cstream', {}, {}) diff --git a/gimmik/opencl.py b/gimmik/opencl.py index f158e40..a247cbd 100644 --- a/gimmik/opencl.py +++ b/gimmik/opencl.py @@ -7,7 +7,9 @@ class OpenCLMatMul(MatMul): platform = 'opencl' basemeta = {'local_work_size': None, 'local_mem_size': 0, 'width': 1} - def _kernel_generators(self, dtype, dsize): + def _kernel_generators(self, dtype, dsize, *, local_mem_size=None): + max_local_mem = local_mem_size or 1024**3 + # B loading, C streaming kernel yield ('cstream', {}, {}) @@ -19,14 +21,16 @@ def _kernel_generators(self, dtype, dsize): args = {'msplit': ms, 'blockx': blkx, 'bsz': bsz} meta = {'local_work_size': (blkx, ms), 'local_mem_size': 2*blkx*bsz*dsize} - yield ('bstream-msplit', args, meta) + if meta['local_mem_size'] < max_local_mem: + yield ('bstream-msplit', args, meta) # Two-way k-split B loading, C streaming kernel ks, csz, blkx = 2, 32, 64 args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} meta = {'local_work_size': (blkx, ks), 'local_mem_size': (ks - 1)*csz*blkx*dsize} - yield ('cstream-ksplit', args, meta) + if meta['local_mem_size'] < max_local_mem: + yield ('cstream-ksplit', args, meta) # At single precision also consider vectorized kernels if (dtype == 'float' and @@ -42,7 +46,8 @@ def _kernel_generators(self, dtype, dsize): 'blockx': blkx, 'bsz': bsz} meta = {'local_work_size': (blkx, ms), 'local_mem_size': 2*blkx*bsz*dsize, 'width': 2} - yield ('bstream-msplit', args, meta) + if meta['local_mem_size'] < max_local_mem: + yield ('bstream-msplit', args, meta) def _process_meta(self, meta): if self.n is not None: