Skip to content

Commit

Permalink
Merge pull request #8 from PyFR/constdim.
Browse files Browse the repository at this point in the history
  • Loading branch information
FreddieWitherden committed Mar 21, 2022
2 parents cd8f77c + c89c8a6 commit fd14267
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 4 deletions.
20 changes: 18 additions & 2 deletions gimmik/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from gimmik._version import __version__


def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm'):
def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm',
n=None, ldb=None, ldc=None):
# Data type
dtype = np.dtype(dtype).type
if dtype == np.float32:
Expand All @@ -19,11 +20,22 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm'):
else:
raise ValueError('Invalid floating point data type')

if 0 < (n is None) + (ldb is None) + (ldc is None) < 3:
raise ValueError('Must provide all of (n, ldb, ldc) or none')

# Multiply the matrix through by alpha
mat = alpha*mat

# Template arguments
tplargs = {'dtype': dtype, 'mat': mat, 'beta': beta, 'funcn': funcn}
tplargs = {
'dtype': dtype,
'mat': mat,
'beta': beta,
'funcn': funcn,
'n': n,
'ldb': ldb,
'ldc': ldc
}

# Load and render the template
tpl = pkgutil.get_data(__name__, f'kernels/{platform}.mako')
Expand All @@ -34,5 +46,9 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm'):
src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?',
r'\g<0>f', src)

# Cleanup
src = re.sub(r'\n\n+', r'\n\n', src.strip()) + '\n'
src = re.sub(r'\w+$', '', src)

# Return the source
return src
12 changes: 10 additions & 2 deletions gimmik/kernels/c-omp.mako
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# -*- coding: utf-8 -*-

void
${funcn}(int ncol,
% if n is None:
${funcn}(int n,
const ${dtype}* restrict b, int ldb,
${dtype}* restrict c, int ldc)
{
% else:
${funcn}(const ${dtype}* restrict b, ${dtype}* restrict c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
${dtype} dotp;

#pragma omp parallel for simd private(dotp)
for (int i = 0; i < ncol; i++)
for (int i = 0; i < n; i++)
{
% for j, jx in enumerate(mat):
dotp = ${' + '.join(f'{kx}*b[i + {k}*ldb]'
Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/c.mako
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-

void
% if n is None:
${funcn}(int n,
const ${dtype}* restrict b, int ldb,
${dtype}* restrict c, int ldc)
{
% else:
${funcn}(const ${dtype}* restrict b, ${dtype}* restrict c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
${dtype} dotp;

#pragma omp simd
Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/cuda.mako
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-

__global__ void
% if n is None:
${funcn}(int n,
const ${dtype}* __restrict__ b, int ldb,
${dtype}* __restrict__ c, int ldc)
{
% else:
${funcn}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
int i = blockDim.x*blockIdx.x + threadIdx.x;
${dtype} dotp;

Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/hip.mako
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-

__global__ __launch_bounds__(128) void
% if n is None:
${funcn}(int n,
const ${dtype}* __restrict__ b, int ldb,
${dtype}* __restrict__ c, int ldc)
{
% else:
${funcn}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
int i = blockDim.x*blockIdx.x + threadIdx.x;
${dtype} dotp;

Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/ispc.mako
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-

export void
% if n is None:
${funcn}(uniform int n,
const uniform ${dtype} b[], uniform int ldb,
${dtype} uniform c[], uniform int ldc)
{
% else:
${funcn}(const uniform ${dtype} b[], ${dtype} uniform c[])
{
const uniform int n = ${n};
const uniform int ldb = ${ldb};
const uniform int ldc = ${ldc};
% endif
${dtype} dotp;

foreach (i = 0 ... n)
Expand Down
8 changes: 8 additions & 0 deletions gimmik/kernels/opencl.mako
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@
% endif

__kernel void
% if n is None:
${funcn}(int n,
__global const ${dtype}* restrict b, int ldb,
__global ${dtype}* restrict c, int ldc)
{
% else:
${funcn}(__global const ${dtype}* restrict b, __global ${dtype}* restrict c)
{
const int n = ${n};
const int ldb = ${ldb};
const int ldc = ${ldc};
% endif
int i = get_global_id(0);
${dtype} dotp;

Expand Down

0 comments on commit fd14267

Please sign in to comment.