Skip to content

Commit

Permalink
symbolics: fix printer and arithmetic for sympy 1.13
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 9, 2024
1 parent dee09dd commit 67a4586
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 8 deletions.
2 changes: 1 addition & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _symbolic_functions(self):
@cached_property
def function(self):
if len(self._functions) == 1:
return self._functions.pop()
return set(self._functions).pop()
else:
return None

Expand Down
7 changes: 5 additions & 2 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,12 @@ def _print_Float(self, expr):
rv = to_str(expr._mpf_, dps, strip_zeros=strip, max_fixed=-2, min_fixed=2)

if rv.startswith('-.0'):
rv = '-0.' + rv[3:]
rv = "-0." + rv[3:]
elif rv.startswith('.0'):
rv = '0.' + rv[2:]
rv = "0." + rv[2:]

# Remove trailing zero except first one to avoid 1. instead of 1.0
rv = rv.rstrip('0') + "0"

if self.single_prec():
rv = '%sF' % rv
Expand Down
9 changes: 8 additions & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,18 @@ def adjoint(self, inner=True):
def __add__(self, other):
try:
# Most case support sympy add
return super().__add__(other)
tsum = super().__add__(other)
except TypeError:
# Sympy doesn't support add with scalars
tsum = self.applyfunc(lambda x: x + other)

# As of sympy 1.13, super does not throw an exception but
# only returns NotImplemented for some internal dispatch.
if tsum is NotImplemented:
return self.applyfunc(lambda x: x + other)

return tsum

def _eval_matrix_mul(self, other):
"""
Copy paste from sympy to avoid explicit call to sympy.Add
Expand Down
17 changes: 13 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,20 @@
from setuptools import setup, find_packages


def min_max(pkgs, pkg_name):
pkg = [p for p in pkgs if pkg_name in p][0]
vmin = pkg.split('>=')[1].split(',')[0]
vmax = pkg.split('<')[-1]
return vmin, vmax


def numpy_compat(required):
new_reqs = [r for r in required if "numpy" not in r and "sympy" not in r]
sympy_lb, sympy_ub = min_max(required, "sympy")
numpy_lb, numpy_ub = min_max(required, "numpy")
if sys.version_info < (3, 9):
# Numpy 2.0 requires python > 3.8
new_reqs.extend(["sympy>=1.9,<1.13", "numpy>1.16,<2.0"])
new_reqs.extend([f"sympy>={sympy_lb},<1.12.1", f"numpy>{numpy_lb},<2.0"])
return new_reqs

# Due to api changes in numpy 2.0, it requires sympy 1.12.1 at the minimum
Expand All @@ -20,11 +29,11 @@ def numpy_compat(required):
sympy_version = pkg_resources.get_distribution("sympy").version
min_ver2 = pkg_resources.parse_version("1.12.1")
if pkg_resources.parse_version(sympy_version) < min_ver2:
new_reqs.append("numpy>1.16,<2.0")
new_reqs.append(f"numpy>{numpy_lb},<2.0")
else:
new_reqs.append("numpy>=2.0")
new_reqs.append(f"numpy>=2.0,<{numpy_ub}")
except pkg_resources.DistributionNotFound:
new_reqs.extend(["sympy>=1.12.1", "numpy>=2.0"])
new_reqs.extend([f"sympy>=1.12.1,<{sympy_ub}", f"numpy>=2.0,<{numpy_ub}"])

return new_reqs

Expand Down
13 changes: 13 additions & 0 deletions tests/test_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,16 @@ def test_shifted_lap_of_tensor(shift, ndim):
type(shift) is tuple else d + shift * d.spacing)
ref += getattr(v[j, i], 'd%s2' % d.name)(x0=x0, fd_order=order)
assert df[j] == ref


def test_basic_arithmetic():
grid = Grid(tuple([5]*3))
tau = TensorFunction(name="tau", grid=grid)

# Scalar operations
t1 = tau + 1
print(t1)
assert all(t1i == ti + 1 for (t1i, ti) in zip(t1, tau))

t1 = tau * 2
assert all(t1i == ti * 2 for (t1i, ti) in zip(t1, tau))

0 comments on commit 67a4586

Please sign in to comment.