Skip to content

Commit

Permalink
api: fix printer dtype detection
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jun 18, 2024
1 parent be7c403 commit d1e99f3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
19 changes: 11 additions & 8 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def dtype(self):
def compiler(self):
return self._settings['compiler']

def single_prec(self, expr=None):
dtype = sympy_dtype(expr or self)
return dtype in [np.float32, np.float16]

def parenthesize(self, item, level, strict=False):
if isinstance(item, BooleanFunction):
return "(%s)" % self._print(item)
Expand Down Expand Up @@ -104,9 +108,8 @@ def _print_math_func(self, expr, nest=False, known=None):
except KeyError:
return super()._print_math_func(expr, nest=nest, known=known)

dtype = sympy_dtype(expr)
if dtype is np.float32:
cname += 'f'
if self.single_prec(expr):
cname = '%sf' % cname

args = ', '.join((self._print(arg) for arg in expr.args))

Expand All @@ -116,7 +119,7 @@ def _print_Pow(self, expr):
# Need to override because of issue #1627
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
try:
if expr.exp == -1 and self.dtype == np.float32:
if expr.exp == -1 and self.single_prec():
PREC = precedence(expr)
return '1.0F/%s' % self.parenthesize(expr.base, PREC)
except AttributeError:
Expand Down Expand Up @@ -196,8 +199,8 @@ def _print_Float(self, expr):
elif rv.startswith('.0'):
rv = '0.' + rv[2:]

if self.dtype == np.float32:
rv = rv + 'F'
if self.single_prec():
rv = '%sF' % rv

return rv

Expand Down Expand Up @@ -252,8 +255,8 @@ def _print_ComponentAccess(self, expr):

def _print_TrigonometricFunction(self, expr):
func_name = str(expr.func)
if self.dtype == np.float32:
func_name += 'f'
if self.single_prec():
func_name = '%sf' % func_name
return '%s(%s)' % (func_name, self._print(*expr.args))

def _print_DefFunction(self, expr):
Expand Down
8 changes: 8 additions & 0 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,14 @@ def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):

return super().inject(field, expr, implicit_dims=implicit_dims)

@property
def forward(self):
"""Symbol for the time-forward state of the TimeFunction."""
i = int(self.time_order / 2) if self.time_order >= 2 else 1
_t = self.dimensions[self._time_position]

return self._subs(_t, _t + i * _t.spacing)


class PrecomputedSparseFunction(AbstractSparseFunction):
"""
Expand Down

0 comments on commit d1e99f3

Please sign in to comment.