Skip to content

Commit

Permalink
Enable kernels to be generated with constant dimensions.
Browse files Browse the repository at this point in the history
This can reduce the size of the generated code and enable
further optimisations on the part of the compiler.
  • Loading branch information
FreddieWitherden committed Dec 17, 2021
1 parent d7f330c commit c89c8a6
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 4 deletions.
16 changes: 14 additions & 2 deletions gimmik/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from gimmik._version import __version__


def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm'):
def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm',
n=None, ldb=None, ldc=None):
# Data type
dtype = np.dtype(dtype).type
if dtype == np.float32:
Expand All @@ -19,11 +20,22 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm'):
else:
raise ValueError('Invalid floating point data type')

if 0 < (n is None) + (ldb is None) + (ldc is None) < 3:
raise ValueError('Must provide all of (n, ldb, ldc) or none')

# Multiply the matrix through by alpha
mat = alpha*mat

# Template arguments
tplargs = {'dtype': dtype, 'mat': mat, 'beta': beta, 'funcn': funcn}
tplargs = {
'dtype': dtype,
'mat': mat,
'beta': beta,
'funcn': funcn,
'n': n,
'ldb': ldb,
'ldc': ldc
}

# Load and render the template
tpl = pkgutil.get_data(__name__, f'kernels/{platform}.mako')
Expand Down
12 changes: 10 additions & 2 deletions gimmik/kernels/c-omp.mako
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# -*- coding: utf-8 -*-

void
${funcn}(int ncol,
## Function signature and opening brace.  When n is None the function takes
## n, ldb and ldc as runtime int arguments; otherwise it takes only the b and
## c pointers and the supplied n/ldb/ldc values are emitted as const locals.
% if n is None:
${funcn}(int n,
const ${dtype}* restrict b, int ldb,
${dtype}* restrict c, int ldc)
{
% else:
${funcn}(const ${dtype}* restrict b, ${dtype}* restrict c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
${dtype} dotp;

#pragma omp parallel for simd private(dotp)
for (int i = 0; i < ncol; i++)
for (int i = 0; i < n; i++)
{
% for j, jx in enumerate(mat):
dotp = ${' + '.join(f'{kx}*b[i + {k}*ldb]'
Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/c.mako
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-

void
## Function signature and opening brace.  When n is None the function takes
## n, ldb and ldc as runtime int arguments; otherwise it takes only the b and
## c pointers and the supplied n/ldb/ldc values are emitted as const locals.
% if n is None:
${funcn}(int n,
const ${dtype}* restrict b, int ldb,
${dtype}* restrict c, int ldc)
{
% else:
${funcn}(const ${dtype}* restrict b, ${dtype}* restrict c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
${dtype} dotp;

#pragma omp simd
Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/cuda.mako
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-

__global__ void
## Kernel signature and opening brace.  When n is None the kernel takes
## n, ldb and ldc as runtime int arguments; otherwise it takes only the b and
## c pointers and the supplied n/ldb/ldc values are emitted as const locals.
% if n is None:
${funcn}(int n,
const ${dtype}* __restrict__ b, int ldb,
${dtype}* __restrict__ c, int ldc)
{
% else:
${funcn}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
int i = blockDim.x*blockIdx.x + threadIdx.x;
${dtype} dotp;

Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/hip.mako
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-

__global__ __launch_bounds__(128) void
## Kernel signature and opening brace.  When n is None the kernel takes
## n, ldb and ldc as runtime int arguments; otherwise it takes only the b and
## c pointers and the supplied n/ldb/ldc values are emitted as const locals.
% if n is None:
${funcn}(int n,
const ${dtype}* __restrict__ b, int ldb,
${dtype}* __restrict__ c, int ldc)
{
% else:
${funcn}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
int i = blockDim.x*blockIdx.x + threadIdx.x;
${dtype} dotp;

Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/ispc.mako
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-

export void
## Exported function signature and opening brace.  When n is None the function
## takes n, ldb and ldc as uniform int arguments; otherwise it takes only the
## b and c arrays and the supplied n/ldb/ldc values are emitted as const
## uniform locals.
% if n is None:
${funcn}(uniform int n,
const uniform ${dtype} b[], uniform int ldb,
${dtype} uniform c[], uniform int ldc)
{
% else:
${funcn}(const uniform ${dtype} b[], ${dtype} uniform c[])
{
const uniform int n = ${n};
const uniform int ldb = ${ldb};
const uniform int ldc = ${ldc};
% endif
${dtype} dotp;

foreach (i = 0 ... n)
Expand Down
7 changes: 7 additions & 0 deletions gimmik/kernels/opencl.mako
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@
% endif

__kernel void
## Kernel signature and opening brace.  When n is None the kernel takes
## n, ldb and ldc as runtime int arguments; otherwise it takes only the b and
## c buffers and the supplied n/ldb/ldc values are emitted as const locals.
## Both branches must emit the opening brace of the kernel body.
% if n is None:
${funcn}(int n,
__global const ${dtype}* restrict b, int ldb,
__global ${dtype}* restrict c, int ldc)
{
% else:
${funcn}(__global const ${dtype}* restrict b, __global ${dtype}* restrict c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
int i = get_global_id(0);
${dtype} dotp;

Expand Down

0 comments on commit c89c8a6

Please sign in to comment.