Simplify subtensor shape inference #1299

Draft · wants to merge 7 commits into base: main
4 changes: 2 additions & 2 deletions pytensor/compile/profiling.py
@@ -1480,8 +1480,8 @@ def print_tips(self, file):
ps.XOR,
ps.AND,
ps.Invert,
ps.ScalarMaximum,
ps.ScalarMinimum,
ps.Maximum,
ps.Minimum,
ps.Add,
ps.Mul,
ps.Sub,
18 changes: 18 additions & 0 deletions pytensor/link/jax/dispatch/scalar.py
@@ -14,6 +14,8 @@
Composite,
Identity,
IntDiv,
Maximum,
Minimum,
Mod,
Mul,
ScalarOp,
@@ -172,6 +174,22 @@ def elemwise(x, y):
return elemwise


@jax_funcify.register(Maximum)
def jax_funcify_scalar_Maximum(op, **kwargs):
def elemwise(*inputs):
return functools.reduce(jnp.maximum, inputs[1:], inputs[0])

return elemwise


@jax_funcify.register(Minimum)
def jax_funcify_scalar_Minimum(op, **kwargs):
def elemwise(*inputs):
return functools.reduce(jnp.minimum, inputs[1:], inputs[0])

return elemwise


@jax_funcify.register(Cast)
def jax_funcify_Cast(op, **kwargs):
def cast(x):
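For reference, a minimal standalone sketch of what the new JAX dispatch produces at runtime: a callable that folds `jnp.maximum` over however many inputs the scalar node has. The function name and test values below are illustrative and not part of the diff.

```python
import functools

import jax.numpy as jnp


def jax_scalar_maximum(*inputs):
    # Fold jnp.maximum left-to-right over all inputs, mirroring the
    # functools.reduce call registered in the dispatch above.
    return functools.reduce(jnp.maximum, inputs[1:], inputs[0])


print(jax_scalar_maximum(jnp.asarray(1.0), jnp.asarray(3.0), jnp.asarray(2.0)))  # 3.0
```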
18 changes: 9 additions & 9 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -26,13 +26,13 @@
XOR,
Add,
IntDiv,
Maximum,
Minimum,
Mul,
ScalarMaximum,
ScalarMinimum,
Sub,
TrueDiv,
get_scalar_type,
scalar_maximum,
maximum,
)
from pytensor.scalar.basic import add as add_as
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -103,16 +103,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr):
return f"{res}[{idx}] //= {arr}"


@scalar_in_place_fn.register(ScalarMaximum)
def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr):
@scalar_in_place_fn.register(Maximum)
def scalar_in_place_fn_Maximum(op, idx, res, arr):
return f"""
if {res}[{idx}] < {arr}:
{res}[{idx}] = {arr}
"""


@scalar_in_place_fn.register(ScalarMinimum)
def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
@scalar_in_place_fn.register(Minimum)
def scalar_in_place_fn_Minimum(op, idx, res, arr):
return f"""
if {res}[{idx}] > {arr}:
{res}[{idx}] = {arr}
@@ -458,7 +458,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_multiaxis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
@@ -522,7 +522,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_multiaxis_reducer(
scalar_maximum,
maximum,
-np.inf,
(axis,),
x_at.ndim,
35 changes: 35 additions & 0 deletions pytensor/link/numba/dispatch/scalar.py
@@ -9,19 +9,23 @@
create_numba_signature,
generate_fallback_impl,
numba_funcify,
numba_njit,
)
from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
from pytensor.link.utils import (
compile_function_src,
get_name_for_object,
unique_name_generator,
)
from pytensor.scalar import discrete_dtypes
from pytensor.scalar.basic import (
Add,
Cast,
Clip,
Composite,
Identity,
Maximum,
Minimum,
Mul,
Reciprocal,
ScalarOp,
@@ -186,6 +190,37 @@ def numba_funcify_Mul(op, node, **kwargs):
return numba_basic.numba_njit(signature)(nary_add_fn)


@numba_funcify.register(Maximum)
@numba_funcify.register(Minimum)
def numba_funcify_Extremum(op, node, **kwargs):
input_names = [f"x{i}" for i in range(len(node.inputs))]
input_signature = ", ".join(input_names)
assert len(input_names) > 0

inner_code = f"res = {input_names[0]}\n"

if isinstance(op, Maximum):
op = ">"
func_name = "maximum"
else:
op = "<"
func_name = "minimum"

if all(inp.dtype in discrete_dtypes for inp in node.inputs):
for x in input_names[1:]:
inner_code += f" res = {x} if {x} {op} res else res\n"
else:
for x in input_names[1:]:
inner_code += f" res = {x} if {x} {op} res else (res if res {op}= {x} else np.nan)\n"
inner_code += " return res"

src = f"""
def {func_name}({input_signature}):
{inner_code}
"""
return numba_njit(compile_function_src(src, func_name, globals() | {"np": np}))


@numba_funcify.register(Cast)
def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype)
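The string-generation in `numba_funcify_Extremum` is easier to follow with its intended semantics written out directly. Below is a plain-Python sketch (illustrative only, not the exact generated source) of the floating-point branch, which propagates NaN the same way the generated comparison chain does.

```python
import numpy as np


def maximum3(x0, x1, x2):
    # res = x if x > res else (res if res >= x else nan): a NaN operand makes
    # both comparisons fail, so NaN wins and then propagates to the result.
    res = x0
    for x in (x1, x2):
        res = x if x > res else (res if res >= x else np.nan)
    return res


print(maximum3(1.0, 3.0, 2.0))     # 3.0
print(maximum3(1.0, np.nan, 2.0))  # nan
```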
119 changes: 75 additions & 44 deletions pytensor/scalar/basic.py
@@ -14,6 +14,7 @@
import math
from collections.abc import Callable
from copy import copy
from functools import reduce
from itertools import chain
from textwrap import dedent
from typing import Any, TypeAlias
@@ -1868,89 +1869,119 @@ def c_code(self, node, name, inputs, outputs, sub):
##############
# Arithmetic
##############
class ScalarMaximum(BinaryScalarOp):
class AtLeastUnaryScalarOp(ScalarOp):
def make_node(self, *inputs):
if len(inputs) == 0:
raise TypeError(f"{self} requires at least 1 input: got 0")
return super().make_node(*inputs)


class Maximum(AtLeastUnaryScalarOp):
commutative = True
associative = True
nfunc_spec = ("maximum", 2, 1)
nfunc_variadic = "maximum"
nfunc_variadic = "max"
identity = -np.inf

def impl(self, *inputs):
# The built-in max function doesn't support complex types
return np.maximum(*inputs)
return reduce(np.maximum, inputs)

def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs
(z,) = outputs
if any(i.type in complex_types for i in node.inputs):
raise NotImplementedError()
# Test for both y>x and x>=y to detect NaN
return f'{z} = (({y})>({x})? ({y}): (({x})>=({y})? ({x}): nan("")));'

x, *ys = inputs
[z] = outputs

# We need an intermediate variable in case we are working inplace
tmp = f"{z}_tmp"
res = f"{node.outputs[0].type.dtype_specs()[1]} {tmp} = ({x});"
if all(i.dtype in discrete_dtypes for i in node.inputs):
for y in ys:
res += f"\n{tmp} = (({y}) > {tmp})? ({y}): {tmp};"
else:
# Need to check for nans
for y in ys:
res += (
f"\n{tmp} = (({y}) > {tmp})? ({y}): (({tmp} >= ({y}))? {tmp}: NAN);"
)
res += f"\n{z} = {tmp};"
return res

def c_code_cache_version(self):
return (2,)

def L_op(self, inputs, outputs, gout):
(x, y) = inputs
(gz,) = gout
[gz] = gout
if gz.type in complex_types:
# max is currently defined for complex_types,
# but the gradient for complex is not.
raise NotImplementedError()

if outputs[0].type in discrete_types:
return [
x.zeros_like(dtype=config.floatX),
y.zeros_like(dtype=config.floatX),
]
# This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0.
e = eq(outputs[0], x)
gx = e * gz
gy = (constant(1, dtype=gz.dtype) - e) * gz
return (gx, gy)
[out] = outputs

if out.type in discrete_types:
return [inp.zeros_like(dtype=config.floatX) for inp in inputs]

# We propagate the gradient to the maximum value(s) in the input
return [eq(inp, out) * gz for inp in inputs]

scalar_maximum = ScalarMaximum(upcast_out, name="maximum")

maximum = Maximum(upcast_out, name="maximum")

class ScalarMinimum(BinaryScalarOp):

class Minimum(AtLeastUnaryScalarOp):
commutative = True
associative = True
nfunc_spec = ("minimum", 2, 1)
nfunc_variadic = "minimum"
nfunc_variadic = "min"
identity = np.inf

def impl(self, *inputs):
# The built-in min function doesn't support complex types
return np.minimum(*inputs)
return reduce(np.minimum, inputs)

def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs
(z,) = outputs
if any(i.type in complex_types for i in node.inputs):
raise NotImplementedError()
return f'{z} = (({y})<({x})? ({y}): (({x})<=({y})? ({x}): nan("")));'

x, *ys = inputs
[z] = outputs

# We need an intermediate variable in case we are working inplace
tmp = f"{z}_tmp"
res = f"{node.outputs[0].type.dtype_specs()[1]} {tmp} = ({x});"
if all(i.dtype in discrete_dtypes for i in node.inputs):
for y in ys:
res += f"\n{tmp} = (({y}) < {tmp})? ({y}): {tmp};"
else:
# Need to check for nans
for y in ys:
res += (
f"\n{tmp} = (({y}) < {tmp})? ({y}): (({tmp} <= ({y}))? {tmp}: NAN);"
)
res += f"\n{z} = {tmp};"
return res

def c_code_cache_version(self):
return (2,)

def L_op(self, inputs, outputs, gout):
(x, y) = inputs
(gz,) = gout
[gz] = gout
if gz.type in complex_types:
# min is currently defined for complex_types,
# max is currently defined for complex_types,
# but the gradient for complex is not.
raise NotImplementedError()

if outputs[0].type in discrete_types:
return [
x.zeros_like(dtype=config.floatX),
y.zeros_like(dtype=config.floatX),
]
# This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0.
e = eq(outputs[0], x)
gx = e * gz
gy = (constant(1, dtype=gz.dtype) - e) * gz
return (gx, gy)
[out] = outputs

if out.type in discrete_types:
return [inp.zeros_like(dtype=config.floatX) for inp in inputs]

# We propagate the gradient to the minimum value(s) in the input
return [eq(inp, out) * gz for inp in inputs]


scalar_minimum = ScalarMinimum(upcast_out, name="minimum")
minimum = Minimum(upcast_out, name="minimum")


class Add(ScalarOp):
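The new `L_op` returns `eq(inp, out) * gz` for every input, so each input that attains the extremum receives the full upstream gradient and the others receive zero. A quick NumPy check of that rule, using hypothetical values not taken from the diff:

```python
import numpy as np

inputs = np.array([1.0, 3.0, 3.0, 2.0])
out = inputs.max()            # 3.0
gz = 1.0                      # upstream gradient
grads = (inputs == out) * gz  # eq(inp, out) * gz for each input
print(grads)                  # [0. 1. 1. 0.]
```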
10 changes: 4 additions & 6 deletions pytensor/scalar/math.py
@@ -32,8 +32,8 @@
isinf,
log,
log1p,
maximum,
reciprocal,
scalar_maximum,
sqrt,
switch,
true_div,
@@ -1305,7 +1305,7 @@ def c_code_cache_version(self):
return v


softplus = Softplus(upgrade_to_float, name="scalar_softplus")
softplus = Softplus(upgrade_to_float, name="softplus")


class Log1mexp(UnaryScalarOp):
@@ -1575,9 +1575,7 @@ def inner_loop(
derivative_new = K * (F1 * dK + F2)

errapx = scalar_abs(derivative - derivative_new)
d_errapx = errapx / scalar_maximum(
err_threshold, scalar_abs(derivative_new)
)
d_errapx = errapx / maximum(err_threshold, scalar_abs(derivative_new))

min_iters_cond = n > (min_iters - 1)
derivative = switch(
@@ -1823,7 +1821,7 @@ def inner_loop(*args):
if len(grad_incs) == 1:
[max_abs_grad_inc] = grad_incs
else:
max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)
max_abs_grad_inc = reduce(maximum, abs_grad_incs)

return (
(*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
4 changes: 2 additions & 2 deletions pytensor/tensor/basic.py
@@ -262,8 +262,8 @@ def _obj_is_wrappable_as_tensor(x):
ps.Mul,
ps.IntDiv,
ps.TrueDiv,
ps.ScalarMinimum,
ps.ScalarMaximum,
ps.Minimum,
ps.Maximum,
)


4 changes: 2 additions & 2 deletions pytensor/tensor/blas.py
@@ -947,8 +947,8 @@ def infer_shape(self, fgraph, node, input_shapes):
z_shape, _, x_shape, y_shape, _ = input_shapes
return [
(
pytensor.scalar.scalar_maximum(z_shape[0], x_shape[0]),
pytensor.scalar.scalar_maximum(z_shape[1], y_shape[1]),
pytensor.scalar.maximum(z_shape[0], x_shape[0]),
pytensor.scalar.maximum(z_shape[1], y_shape[1]),
)
]
