Skip to content

Commit

Permalink
Merge pull request #622 from kinnala/do-not-overwrite
Browse files Browse the repository at this point in the history
Do not overwrite by default
  • Loading branch information
kinnala authored Apr 13, 2021
2 parents 13f5630 + a932c06 commit 6b27580
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
44 changes: 25 additions & 19 deletions skfem/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
Solution = Union[ndarray, Tuple[ndarray, ndarray]]
LinearSolver = Callable[..., ndarray]
EigenSolver = Callable[..., Tuple[ndarray, ndarray]]
EnforcedSystem = Union[spmatrix,
Tuple[spmatrix, ndarray],
Tuple[spmatrix, spmatrix]]
CondensedSystem = Union[spmatrix,
Tuple[spmatrix, ndarray],
Tuple[spmatrix, spmatrix],
LinearSystem = Union[spmatrix,
Tuple[spmatrix, ndarray],
Tuple[spmatrix, spmatrix]]
CondensedSystem = Union[LinearSystem,
Tuple[spmatrix, ndarray, ndarray],
Tuple[spmatrix, ndarray, ndarray, ndarray],
Tuple[spmatrix, spmatrix, ndarray, ndarray]]
Expand Down Expand Up @@ -248,7 +246,7 @@ def _flatten_helper(S, key):


def _init_bc(A: spmatrix,
b: Optional[ndarray] = None,
b: Optional[Union[ndarray, spmatrix]] = None,
x: Optional[ndarray] = None,
I: Optional[DofsCollection] = None,
D: Optional[DofsCollection] = None) -> Tuple[Optional[ndarray],
Expand Down Expand Up @@ -284,7 +282,8 @@ def enforce(A: spmatrix,
x: Optional[ndarray] = None,
I: Optional[DofsCollection] = None,
D: Optional[DofsCollection] = None,
diag: float = 1.) -> EnforcedSystem:
diag: float = 1.,
overwrite: bool = False) -> LinearSystem:
r"""Enforce degrees-of-freedom of a linear system.
.. note::
Expand All @@ -307,39 +306,46 @@ def enforce(A: spmatrix,
D
Specify either this or ``I``: The set of degree-of-freedom indices to
enforce (rows/diagonal set to zero/one).
overwrite
Optionally, the original system is both modified (for performance) and
returned (for compatibility with :func:`skfem.utils.solve`). By
default, ``False``.
Returns
-------
EnforcedSystem
LinearSystem
A linear system with the enforced rows/diagonals set to zero/one.
"""
b, x, I, D = _init_bc(A, b, x, I, D)

Aout = A if overwrite else A.copy()

# set rows on lhs to zero
start = A.indptr[D]
stop = A.indptr[D + 1]
start = Aout.indptr[D]
stop = Aout.indptr[D + 1]
count = stop - start
idx = np.ones(count.sum(), dtype=np.int64)
idx[np.cumsum(count)[:-1]] -= count[:-1]
idx = np.repeat(start, count) + np.cumsum(idx) - 1
A.data[idx] = 0.
Aout.data[idx] = 0.

# set diagonal value
d = A.diagonal()
d = Aout.diagonal()
d[D] = diag
A.setdiag(d)
Aout.setdiag(d)

if b is not None:
if isinstance(b, spmatrix):
# eigenvalue problem
b = enforce(b, D=D, diag=0.)
# mass matrix (eigen- or initial value problem)
bout = enforce(b, D=D, diag=0., overwrite=overwrite)
else:
# set rhs to the given value
b[D] = x[D]
return A, b
bout = b if overwrite else b.copy()
bout[D] = x[D]
return Aout, bout

return A
return Aout


def condense(A: spmatrix,
Expand Down
7 changes: 5 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ def runTest(self):
M = mass.assemble(basis)
D = m.boundary_nodes()

assert_almost_equal(enforce(A, D=D).todense(), np.eye(A.shape[0]))
assert_almost_equal(enforce(M, D=D, diag=0.).todense(),
assert_almost_equal(enforce(A, D=D).toarray(), np.eye(A.shape[0]))
assert_almost_equal(enforce(M, D=D, diag=0.).toarray(),
np.zeros(M.shape))

enforce(A, D=D, overwrite=True)
assert_almost_equal(A.toarray(), np.eye(A.shape[0]))


if __name__ == '__main__':
unittest.main()

0 comments on commit 6b27580

Please sign in to comment.