diff --git a/gimmik/__init__.py b/gimmik/__init__.py index 5724d1d..7ee8a28 100644 --- a/gimmik/__init__.py +++ b/gimmik/__init__.py @@ -9,7 +9,8 @@ from gimmik._version import __version__ -def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm'): +def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', + n=None, ldb=None, ldc=None): # Data type dtype = np.dtype(dtype).type if dtype == np.float32: @@ -19,11 +20,22 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm'): else: raise ValueError('Invalid floating point data type') + if 0 < (n is None) + (ldb is None) + (ldc is None) < 3: + raise ValueError('Must provide all of (n, ldb, ldc) or none') + # Multiply the matrix through by alpha mat = alpha*mat # Template arguments - tplargs = {'dtype': dtype, 'mat': mat, 'beta': beta, 'funcn': funcn} + tplargs = { + 'dtype': dtype, + 'mat': mat, + 'beta': beta, + 'funcn': funcn, + 'n': n, + 'ldb': ldb, + 'ldc': ldc + } # Load and render the template tpl = pkgutil.get_data(__name__, f'kernels/{platform}.mako') @@ -34,5 +46,9 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm'): src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?', r'\g<0>f', src) + # Cleanup + src = re.sub(r'\n\n+', r'\n\n', src.strip()) + '\n' + src = re.sub(r'\w+$', '', src) + # Return the source return src diff --git a/gimmik/kernels/c-omp.mako b/gimmik/kernels/c-omp.mako index 74d8cd0..1f77d94 100644 --- a/gimmik/kernels/c-omp.mako +++ b/gimmik/kernels/c-omp.mako @@ -1,14 +1,22 @@ # -*- coding: utf-8 -*- void -${funcn}(int ncol, +% if n is None: +${funcn}(int n, const ${dtype}* restrict b, int ldb, ${dtype}* restrict c, int ldc) { +% else: +${funcn}(const ${dtype}* restrict b, ${dtype}* restrict c) +{ + const int n = ${n}; + const int ldb = ${ldb}; + const int ldc = ${ldc}; +% endif ${dtype} dotp; #pragma omp parallel for simd private(dotp) - for (int i = 0; i < ncol; i++) + for (int i = 0; i < n; i++) { % for j, jx in enumerate(mat): dotp = ${' + '.join(f'{kx}*b[i + {k}*ldb]' diff --git a/gimmik/kernels/c.mako b/gimmik/kernels/c.mako index c844985..aa42194 100644 --- a/gimmik/kernels/c.mako +++ b/gimmik/kernels/c.mako @@ -1,10 +1,18 @@ # -*- coding: utf-8 -*- void +% if n is None: ${funcn}(int n, const ${dtype}* restrict b, int ldb, ${dtype}* restrict c, int ldc) { +% else: +${funcn}(const ${dtype}* restrict b, ${dtype}* restrict c) +{ + const int n = ${n}; + const int ldb = ${ldb}; + const int ldc = ${ldc}; +% endif ${dtype} dotp; #pragma omp simd diff --git a/gimmik/kernels/cuda.mako b/gimmik/kernels/cuda.mako index 17bf159..c1b7564 100644 --- a/gimmik/kernels/cuda.mako +++ b/gimmik/kernels/cuda.mako @@ -1,10 +1,18 @@ # -*- coding: utf-8 -*- __global__ void +% if n is None: ${funcn}(int n, const ${dtype}* __restrict__ b, int ldb, ${dtype}* __restrict__ c, int ldc) { +% else: +${funcn}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${n}; + const int ldb = ${ldb}; + const int ldc = ${ldc}; +% endif int i = blockDim.x*blockIdx.x + threadIdx.x; ${dtype} dotp; diff --git a/gimmik/kernels/hip.mako b/gimmik/kernels/hip.mako index f9b4b0c..92c2897 100644 --- a/gimmik/kernels/hip.mako +++ b/gimmik/kernels/hip.mako @@ -1,10 +1,18 @@ # -*- coding: utf-8 -*- __global__ __launch_bounds__(128) void +% if n is None: ${funcn}(int n, const ${dtype}* __restrict__ b, int ldb, ${dtype}* __restrict__ c, int ldc) { +% else: +${funcn}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) +{ + const int n = ${n}; + const int ldb = ${ldb}; + const int ldc = ${ldc}; +% endif int i = blockDim.x*blockIdx.x + threadIdx.x; ${dtype} dotp; diff --git a/gimmik/kernels/ispc.mako b/gimmik/kernels/ispc.mako index 5cebafd..ac186a8 100644 --- a/gimmik/kernels/ispc.mako +++ b/gimmik/kernels/ispc.mako @@ -1,10 +1,18 @@ # -*- coding: utf-8 -*- export void +% if n is None: ${funcn}(uniform int n, const uniform ${dtype} b[], uniform int ldb, ${dtype} uniform c[], uniform int ldc) { +% else: +${funcn}(const uniform ${dtype} b[], ${dtype} uniform c[]) +{ + const uniform int n = ${n}; + const uniform int ldb = ${ldb}; + const uniform int ldc = ${ldc}; +% endif ${dtype} dotp; foreach (i = 0 ... n) diff --git a/gimmik/kernels/opencl.mako b/gimmik/kernels/opencl.mako index 16edf9d..79f483e 100644 --- a/gimmik/kernels/opencl.mako +++ b/gimmik/kernels/opencl.mako @@ -7,10 +7,18 @@ % endif __kernel void +% if n is None: ${funcn}(int n, __global const ${dtype}* restrict b, int ldb, __global ${dtype}* restrict c, int ldc) { +% else: +${funcn}(__global const ${dtype}* restrict b, __global ${dtype}* restrict c) +{ + const int n = ${n}; + const int ldb = ${ldb}; + const int ldc = ${ldc}; +% endif int i = get_global_id(0); ${dtype} dotp;