Skip to content

Commit

Permalink
Stop checking for input alias in Function.__call__
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 22, 2024
1 parent 4258475 commit 23b4fe9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 129 deletions.
81 changes: 12 additions & 69 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,41 +393,6 @@ def __init__(
assert len(self.input_storage) == len(self.maker.fgraph.inputs)
assert len(self.output_storage) == len(self.maker.fgraph.outputs)

# Group indexes of inputs that are potentially aliased to each other
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# even though there could be two distinct types that use the same kinds of underlying objects.
potential_aliased_input_groups = []
for inp in maker.inputs:
# If the input is a shared variable, the memory region is under PyTensor control
# and can't be aliased.
if not (
isinstance(inp, In)
and inp.borrow
and not inp.shared
and hasattr(inp.variable.type, "may_share_memory")
):
continue

for group in potential_aliased_input_groups:
# If one is super of the other, that means one could be replaced by the other
if any(
inp.variable.type.is_super(other_inp.variable.type)
or other_inp.variable.type.is_super(inp.variable.type)
for other_inp in group
):
group.append(inp)
break
else: # no break
# Input makes a new group
potential_aliased_input_groups.append([inp])

# Potential aliased inputs are those that belong to the same group
self._potential_aliased_input_groups: tuple[tuple[int, ...], ...] = tuple(
tuple(maker.inputs.index(inp) for inp in group)
for group in potential_aliased_input_groups
if len(group) > 1
)

# We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage)
finder = {}
Expand Down Expand Up @@ -844,11 +809,18 @@ def __call__(self, *args, **kwargs):
if self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]

# Reinitialize each container's 'provided' counter
if self.trust_input:
# Set positional arguments
for arg_container, arg in zip(input_storage, args, strict=False):
arg_container.storage[0] = arg

# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg

else:
# Reinitialize each container's 'provided' counter
for arg_container in input_storage:
arg_container.provided = 0

Expand Down Expand Up @@ -899,39 +871,10 @@ def __call__(self, *args, **kwargs):
raise
arg_container.provided += 1

# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg

if not self.trust_input:
# Collect aliased inputs among the storage space
for potential_group in self._potential_aliased_input_groups:
args_share_memory: list[list[int]] = []
for i in potential_group:
i_type = self.maker.inputs[i].variable.type
i_val = input_storage[i].storage[0]

# Check if value is aliased with any of the values in one of the groups
for j_group in args_share_memory:
if any(
i_type.may_share_memory(input_storage[j].storage[0], i_val)
for j in j_group
):
j_group.append(i)
break
else: # no break
# Create a new group
args_share_memory.append([i])

# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# copy all but the first
for i in group[1:]:
input_storage[i].storage[0] = copy.copy(
input_storage[i].storage[0]
)
# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg

# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
Expand Down
19 changes: 11 additions & 8 deletions tests/compile/function/test_pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,16 +732,13 @@ class TestAliasingRules:
# 2. shared variables are allocated in this memory space, as are the
# temporaries used for Function evaluation.
#
# 3. Physically, this managed memory space may be spread across the host,
# on a GPU device(s), or even on a remote machine.
#
# 4. PyTensor assumes that shared variables are never aliased to one another,
# 3. PyTensor assumes that shared variables are never aliased to one another,
# and tries to make it impossible to accidentally alias them.
#
# 5. PyTensor's managed data is constant while PyTensor Functions are not running
# 4. PyTensor's managed data is constant while PyTensor Functions are not running
# and pytensor library code is not running.
#
# 6. The default behaviour of Function is to return user-space values for
# 5. The default behaviour of Function is to return user-space values for
# outputs, but this can be overridden (borrow=True) for better performance,
# in which case the returned value may be aliased to managed memory, and
# potentially invalidated by the next PyTensor Function call or call to pytensor
Expand Down Expand Up @@ -810,6 +807,9 @@ def test_sparse_input_aliasing_affecting_inplace_operations(self):
assert np.allclose(vals.todense(), bogus_vals.todense())

def test_input_aliasing_affecting_inplace_operations(self):
# Note: The input aliasing check was disabled, so this test now just confirms that wrong values
# will be obtained if the inputs are aliased.

# Note: to trigger this bug with pytensor rev 4586:2bc6fc7f218b,
# you need to make the inputs mutable (so that inplace
# operations are used) and to break the elemwise composition
Expand Down Expand Up @@ -860,9 +860,12 @@ def test_input_aliasing_affecting_inplace_operations(self):
v_copy = v.copy()
vals = f(v, v_copy, m, m_copy)

assert np.allclose(vals, bogus_vals)
assert not np.allclose(vals, bogus_vals)

def test_partial_input_aliasing_affecting_inplace_operations(self):
# Note: The input aliasing check was disabled, so this test now just confirms that wrong values
# will be obtained if the inputs are aliased.

# Note: to trigger this bug with pytensor rev 4586:2bc6fc7f218b,
# you need to make the inputs mutable (so that inplace
# operations are used) and to break the elemwise composition
Expand Down Expand Up @@ -906,7 +909,7 @@ def test_partial_input_aliasing_affecting_inplace_operations(self):
v_copy2 = v.copy()
vals = f(v[:2], v_copy1[1:3], v_copy2[2:4], m, m_copy1, m_copy2)

assert np.allclose(vals, bogus_vals)
assert not np.allclose(vals, bogus_vals)

def test_potential_output_aliasing_induced_by_updates(self):
A = self.shared(np.zeros((2, 2)))
Expand Down
52 changes: 0 additions & 52 deletions tests/compile/function/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,52 +752,6 @@ def test_default_values(self):
except TypeError:
assert funct(first=1) == x

def test_check_for_aliased_inputs(self):
    """Check when a compiled function records potentially aliased inputs.

    Several input/output/update configurations are compiled with
    ``function`` and ``_potential_aliased_input_groups`` is asserted to be
    empty for the first batch of configurations and non-empty for the
    second batch.
    """
    values = np.random.random((5, 4))
    s1 = shared(values)
    s2 = shared(values)
    x1 = vector()
    x2 = vector(shape=(3,))
    x3 = vector(shape=(1,))

    # Configurations for which no aliased-input check should be set up
    unchecked_cases = [
        {"outputs": [s1 + 1]},
        {"outputs": [s1 + 1, s2 + 3]},
        {"outputs": [s1 + 1], "updates": [(s2, s2 + 3)]},
        {"inputs": [x1], "outputs": [x1 + 1], "updates": [(s2, s2 + 3)]},
        {
            "inputs": [In(x1, mutable=True)],
            "outputs": [x1 + 1],
            "updates": [(s2, s2 + 3)],
        },
        {
            "inputs": [In(x2, mutable=True), In(x3, mutable=True)],
            "outputs": [x2 + 2, x3 + 3],
        },
    ]
    for kwargs in unchecked_cases:
        # Make sure an `inputs` entry exists before compiling
        kwargs.setdefault("inputs", [])
        compiled = function(**kwargs)
        assert not compiled._potential_aliased_input_groups, kwargs

    # Configurations for which an aliased-input check should be set up.
    # NOTE(review): presumably x1 (no static shape) is type-compatible with
    # x2 / x3, which is what puts each pair in one group -- confirm.
    checked_cases = [
        {
            "inputs": [In(x1, mutable=True), In(x2, mutable=True)],
            "outputs": [x1 + 1, x2 + 2],
            "updates": [(s2, s2 + 3)],
        },
        {
            "inputs": [In(x1, mutable=True), In(x3, mutable=True)],
            "outputs": [x1 + 1, x3 + 3],
            "updates": [(s2, s2 + 3)],
        },
    ]
    for kwargs in checked_cases:
        kwargs.setdefault("inputs", [])
        compiled = function(**kwargs)

        assert compiled._potential_aliased_input_groups, kwargs


def test_output_dictionary(self):
# Tests that function works when outputs is a dictionary

Expand Down Expand Up @@ -939,12 +893,6 @@ def test_deepcopy(self):
assert x not in g.container
assert x not in g.value
assert len(f.defaults) == len(g.defaults)
# Shared variable is the first input
assert (
f._potential_aliased_input_groups
== g._potential_aliased_input_groups
== ((1, 2),)
)
assert f.name == g.name
assert f.maker.fgraph.name == g.maker.fgraph.name
# print(f"{f.defaults = }")
Expand Down

0 comments on commit 23b4fe9

Please sign in to comment.