Skip to content

Commit

Permalink
Merge pull request #625 from kinnala/add-optional-threading
Browse files Browse the repository at this point in the history
Optionally assemble using multiple threads
  • Loading branch information
kinnala authored Apr 19, 2021
2 parents aa01df2 + 122c298 commit d3da97a
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 10 deletions.
53 changes: 46 additions & 7 deletions skfem/assembly/form/bilinear_form.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional, Tuple
from threading import Thread
from itertools import product

import numpy as np
from numpy import ndarray
Expand Down Expand Up @@ -73,7 +75,10 @@ def _assemble(self,

# initialize COO data structures
sz = ubasis.Nbfun * vbasis.Nbfun * nt
data = np.zeros(sz, dtype=self.dtype)
if self.nthreads > 0:
data = np.zeros((ubasis.Nbfun, vbasis.Nbfun, nt), dtype=self.dtype)
else:
data = np.zeros(sz, dtype=self.dtype)
rows = np.zeros(sz, dtype=np.int64)
cols = np.zeros(sz, dtype=np.int64)

Expand All @@ -84,12 +89,36 @@ def _assemble(self,
nt * (vbasis.Nbfun * j + i + 1))
rows[ixs] = vbasis.element_dofs[i]
cols[ixs] = ubasis.element_dofs[j]
data[ixs] = self._kernel(
ubasis.basis[j],
vbasis.basis[i],
wdict,
dx,
)
if self.nthreads <= 0:
data[ixs] = self._kernel(
ubasis.basis[j],
vbasis.basis[i],
wdict,
dx,
)

if self.nthreads > 0:
# create indices for linear loop over local stiffness matrix
indices = np.array(
[[i, j] for j, i in product(range(ubasis.Nbfun),
range(vbasis.Nbfun))]
)

# split local stiffness matrix elements to threads
threads = [
Thread(
target=self._threaded_kernel,
args=(data, ix, ubasis.basis, vbasis.basis, wdict, dx)
) for ix in np.array_split(indices, self.nthreads, axis=0)
]

# start threads and wait for finishing
for t in threads:
t.start()
for t in threads:
t.join()

data = data.flatten('C')

return data, rows, cols, (vbasis.N, ubasis.N)

Expand All @@ -114,3 +143,13 @@ def assemble(self, *args, **kwargs) -> csr_matrix:

def _kernel(self, u, v, w, dx):
    """Evaluate the form for one pair of basis functions and integrate.

    The user-supplied ``self.form`` is called with the unpacked contents
    of ``u`` and ``v`` followed by ``w``.  The result is multiplied by
    ``dx`` and summed along axis 1.

    Parameters
    ----------
    u, v
        Iterables that are unpacked into positional arguments of the form
        (presumably the trial and test basis function data -- confirm
        against the caller).
    w
        Extra argument forwarded unchanged as the last positional argument
        of the form.
    dx
        Array-like factor multiplied with the form values before summing
        (presumably quadrature weights -- confirm against the caller).

    Returns
    -------
    The result of ``np.sum(..., axis=1)`` applied to the weighted form
    values.

    """
    weighted = self.form(*u, *v, w) * dx
    return np.sum(weighted, axis=1)

def _threaded_kernel(self, data, ix, ubasis, vbasis, wdict, dx):
    """Fill the entries of ``data`` selected by ``ix`` in place.

    For every pair ``(i, j)`` in ``ix`` the kernel is evaluated with
    ``ubasis[j]`` and ``vbasis[i]`` and the result is stored in
    ``data[j, i]``.  Nothing is returned; the output is written into
    ``data``.

    Parameters
    ----------
    data
        Indexable output buffer, written at ``data[j, i]``.
    ix
        Iterable of ``(i, j)`` index pairs handled by this call.
    ubasis, vbasis
        Indexable collections, indexed by ``j`` and ``i`` respectively.
    wdict, dx
        Forwarded unchanged to ``self._kernel``.

    """
    for i, j in ix:
        data[j, i] = self._kernel(ubasis[j], vbasis[i], wdict, dx)
8 changes: 6 additions & 2 deletions skfem/assembly/form/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ class Form:

def __init__(self,
form: Optional[Callable] = None,
dtype: type = np.float64):
dtype: type = np.float64,
nthreads: int = 0):
self.form = form.form if isinstance(form, Form) else form
self.dtype = dtype
self.nthreads = nthreads

def partial(self, *args, **kwargs):
form = deepcopy(self)
Expand All @@ -34,7 +36,9 @@ def partial(self, *args, **kwargs):

def __call__(self, *args):
if self.form is None: # decorate
return type(self)(form=args[0], dtype=self.dtype)
return type(self)(form=args[0],
dtype=self.dtype,
nthreads=self.nthreads)
return self.assemble(self.kernel(*args))

def assemble(self,
Expand Down
33 changes: 32 additions & 1 deletion tests/test_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import pytest
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
from numpy.testing import (assert_equal, assert_almost_equal,
assert_array_almost_equal)

from skfem import BilinearForm, LinearForm, Functional, asm, solve
from skfem.element import (ElementQuad1, ElementQuadS2, ElementHex1,
Expand Down Expand Up @@ -35,6 +36,13 @@ def uv(u, v, w):

B = asm(uv, self.fbasis)

# assemble the same matrix using multiple threads
@BilinearForm(nthreads=2)
def uvt(u, v, w):
return u * v

Bt = asm(uvt, self.fbasis)

@LinearForm
def gv(v, w):
return 1.0 * v
Expand All @@ -45,6 +53,7 @@ def gv(v, w):

self.assertAlmostEqual(ones @ g, self.boundary_area, places=4)
self.assertAlmostEqual(ones @ (B @ ones), self.boundary_area, places=4)
self.assertAlmostEqual(ones @ (Bt @ ones), self.boundary_area, places=4)


class IntegrateOneOverBoundaryS2(IntegrateOneOverBoundaryQ1):
Expand Down Expand Up @@ -434,5 +443,27 @@ def complexfun(v, w):
self.assertAlmostEqual(np.dot(ones, f), 1j * self.interior_area)


class TestThreadedAssembly(TestCase):
    """Compare a bilinear form assembled with and without ``nthreads``."""

    def runTest(self):
        basis = InteriorBasis(MeshTri().refined(), ElementTriP1())

        # The integrand is not symmetric in u and v, so swapping the row
        # and column roles would change the assembled matrix.
        def integrand(u, v, w):
            return u.grad[0] * v

        nonsym = BilinearForm(integrand)
        threaded_nonsym = BilinearForm(nthreads=2)(integrand)

        serial_matrix = nonsym.assemble(basis).toarray()
        threaded_matrix = threaded_nonsym.assemble(basis).toarray()

        assert_almost_equal(serial_matrix, threaded_matrix)


if __name__ == '__main__':
main()

0 comments on commit d3da97a

Please sign in to comment.