Cleanup Function.__call__
ricardoV94 committed Oct 17, 2024
1 parent 147c892 commit d77f26c
Showing 7 changed files with 136 additions and 161 deletions.
238 changes: 113 additions & 125 deletions pytensor/compile/function/types.py
@@ -326,8 +326,8 @@ class Function:
     def __init__(
         self,
         vm: "VM",
-        input_storage,
-        output_storage,
+        input_storage: list[Container],
+        output_storage: list[Container],
         indices,
         outputs,
         defaults,
@@ -372,7 +372,6 @@ def __init__(
         name
             A string name.
         """
-        # TODO: Rename to `vm`
         self.vm = vm
         self.input_storage = input_storage
         self.output_storage = output_storage
@@ -388,31 +387,49 @@ def __init__(
         self.nodes_with_inner_function = []
         self.output_keys = output_keys

-        # See if we have any mutable / borrow inputs
-        # TODO: this only need to be set if there is more than one input
-        self._check_for_aliased_inputs = False
-        for i in maker.inputs:
-            # If the input is a shared variable, the memory region is
-            # under PyTensor control and so we don't need to check if it
-            # is aliased as we never do that.
-            if (
-                isinstance(i, In)
-                and not i.shared
-                and (getattr(i, "borrow", False) or getattr(i, "mutable", False))
+        assert len(self.input_storage) == len(self.maker.fgraph.inputs)
+        assert len(self.output_storage) == len(self.maker.fgraph.outputs)
+
+        # Group indexes of inputs that are potentially aliased to each other
+        # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
+        # even though there could be two distinct types that use the same kinds of underlying objects.
+        potential_aliased_input_groups = []
+        for inp in maker.inputs:
+            # If the input is a shared variable, the memory region is under PyTensor control
+            # and can't be aliased.
+            if not (
+                isinstance(inp, In)
+                and inp.borrow
+                and not inp.shared
+                and hasattr(inp.variable.type, "may_share_memory")
             ):
-                self._check_for_aliased_inputs = True
-                break
+                continue
+
+            for group in potential_aliased_input_groups:
+                # If one is super of the other, that means one could be replaced by the other
+                if any(
+                    inp.variable.type.is_super(other_inp.variable.type)
+                    or other_inp.variable.type.is_super(inp.variable.type)
+                    for other_inp in group
+                ):
+                    group.append(inp)
+                    break
+            else:  # no break
+                # Input makes a new group
+                potential_aliased_input_groups.append([inp])
+
+        # Potential aliased inputs are those that belong to the same group
+        self._potential_aliased_input_groups: tuple[tuple[int, ...], ...] = tuple(
+            tuple(maker.inputs.index(inp) for inp in group)
+            for group in potential_aliased_input_groups
+            if len(group) > 1
+        )

         # We will be popping stuff off this `containers` object. It is a copy.
         containers = list(self.input_storage)
         finder = {}
         inv_finder = {}

-        def distribute(indices, cs, value):
-            input.distribute(value, indices, cs)
-            for c in cs:
-                c.provided += 1
-
         # Store the list of names of named inputs.
         named_inputs = []
         # Count the number of un-named inputs.
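
Note on the grouping loop above: the `for ... else` construct runs the `else` branch only when the inner loop finishes without `break`, i.e. when `inp` matched no existing group. A minimal standalone sketch of the same pattern, using integers in place of `In` instances and a hypothetical `compatible` predicate standing in for the `is_super` checks:

    def group_compatible(items, compatible):
        # Mirrors the for/else grouping above: append to the first
        # matching group, otherwise start a new group.
        groups = []
        for item in items:
            for group in groups:
                if any(compatible(item, other) for other in group):
                    group.append(item)
                    break
            else:  # no break: item matched no existing group
                groups.append([item])
        return groups

    # Example: group numbers by parity
    print(group_compatible([1, 2, 3, 4], lambda a, b: a % 2 == b % 2))
    # -> [[1, 3], [2, 4]]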
@@ -777,6 +794,13 @@ def checkSV(sv_ori, sv_rpl):
         f_cpy.maker.fgraph.name = name
         return f_cpy

+    def _restore_defaults(self):
+        for i, (required, refeed, value) in enumerate(self.defaults):
+            if refeed:
+                if isinstance(value, Container):
+                    value = value.storage[0]
+                self[i] = value
+
     def __call__(self, *args, **kwargs):
         """
         Evaluates value of a function on given arguments.
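
Extracting the old nested `restore_defaults` closure into a `_restore_defaults` method means the helper is no longer rebuilt as a fresh function object on every `__call__`, and other methods can reuse it. A toy illustration of the difference (class and method names hypothetical):

    class Before:
        def work(self):
            def cleanup():  # a new function object on every work() call
                ...
            cleanup()

    class After:
        def _cleanup(self):  # defined once, shared by all calls
            ...

        def work(self):
            self._cleanup()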
@@ -805,52 +829,43 @@ def __call__(self, *args, **kwargs):
         List of outputs on indices/keys from ``output_subset`` or all of them,
         if ``output_subset`` is not passed.
         """
-
-        def restore_defaults():
-            for i, (required, refeed, value) in enumerate(self.defaults):
-                if refeed:
-                    if isinstance(value, Container):
-                        value = value.storage[0]
-                    self[i] = value
-
+        input_storage = self.input_storage
         profile = self.profile
-        t0 = time.perf_counter()
+
+        if profile:
+            t0 = time.perf_counter()

         output_subset = kwargs.pop("output_subset", None)
         if output_subset is not None and self.output_keys is not None:
             output_subset = [self.output_keys.index(key) for key in output_subset]

         # Reinitialize each container's 'provided' counter
         if self.trust_input:
-            i = 0
-            for arg in args:
-                s = self.input_storage[i]
-                s.storage[0] = arg
-                i += 1
+            for arg_container, arg in zip(input_storage, args, strict=False):
+                arg_container.storage[0] = arg
         else:
-            for c in self.input_storage:
-                c.provided = 0
+            for arg_container in input_storage:
+                arg_container.provided = 0

-            if len(args) + len(kwargs) > len(self.input_storage):
+            if len(args) + len(kwargs) > len(input_storage):
                 raise TypeError("Too many parameter passed to pytensor function")

             # Set positional arguments
-            i = 0
-            for arg in args:
-                # TODO: provide a option for skipping the filter if we really
-                # want speed.
-                s = self.input_storage[i]
-                # see this emails for a discuation about None as input
+            for arg_container, arg in zip(input_storage, args, strict=False):
+                # See discussion about None as input
                 # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
                 if arg is None:
-                    s.storage[0] = arg
+                    arg_container.storage[0] = arg
                 else:
                     try:
-                        s.storage[0] = s.type.filter(
-                            arg, strict=s.strict, allow_downcast=s.allow_downcast
+                        arg_container.storage[0] = arg_container.type.filter(
+                            arg,
+                            strict=arg_container.strict,
+                            allow_downcast=arg_container.allow_downcast,
                         )

                     except Exception as e:
+                        i = input_storage.index(arg_container)
                         function_name = "pytensor function"
                         argument_name = "argument"
                         if self.name:
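
The `trust_input` fast path above copies arguments straight into the input containers, skipping `type.filter` and the `provided` bookkeeping entirely. A sketch of the intended usage, assuming the standard pytensor API (the caller becomes responsible for passing correctly typed positional arguments):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    f = pytensor.function([x], 2 * x)

    f.trust_input = True  # opt in to the unchecked fast path
    # Arguments must now already have the exact expected dtype/shape
    print(f(np.array([1.0, 2.0], dtype=pytensor.config.floatX)))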
@@ -875,93 +890,74 @@ def restore_defaults():
                                 + function_name
                                 + f" at index {int(i)} (0-based). {where}"
                             ) + e.args
-                        restore_defaults()
+                        self._restore_defaults()
                         raise
-                s.provided += 1
-                i += 1
+                arg_container.provided += 1

             # Set keyword arguments
             if kwargs:  # for speed, skip the items for empty kwargs
                 for k, arg in kwargs.items():
                     self[k] = arg

-        if (
-            not self.trust_input
-            and
-            # The getattr is only needed for old pickle
-            getattr(self, "_check_for_aliased_inputs", True)
-        ):
+        if not self.trust_input:
             # Collect aliased inputs among the storage space
-            args_share_memory = []
-            for i in range(len(self.input_storage)):
-                i_var = self.maker.inputs[i].variable
-                i_val = self.input_storage[i].storage[0]
-                if hasattr(i_var.type, "may_share_memory"):
-                    is_aliased = False
-                    for j in range(len(args_share_memory)):
-                        group_j = zip(
-                            [
-                                self.maker.inputs[k].variable
-                                for k in args_share_memory[j]
-                            ],
-                            [
-                                self.input_storage[k].storage[0]
-                                for k in args_share_memory[j]
-                            ],
-                        )
+            for potential_group in self._potential_aliased_input_groups:
+                args_share_memory: list[list[int]] = []
+                for i in potential_group:
+                    i_type = self.maker.inputs[i].variable.type
+                    i_val = input_storage[i].storage[0]
+
+                    # Check if value is aliased with any of the values in one of the groups
+                    for j_group in args_share_memory:
                         if any(
-                            (
-                                var.type is i_var.type
-                                and var.type.may_share_memory(val, i_val)
-                            )
-                            for (var, val) in group_j
+                            i_type.may_share_memory(input_storage[j].storage[0], i_val)
+                            for j in j_group
                         ):
-                            is_aliased = True
-                            args_share_memory[j].append(i)
+                            j_group.append(i)
                             break

-                    if not is_aliased:
+                    else:  # no break
+                        # Create a new group
                         args_share_memory.append([i])

-            # Check for groups of more than one argument that share memory
-            for group in args_share_memory:
-                if len(group) > 1:
-                    # copy all but the first
-                    for j in group[1:]:
-                        self.input_storage[j].storage[0] = copy.copy(
-                            self.input_storage[j].storage[0]
-                        )
+                # Check for groups of more than one argument that share memory
+                for group in args_share_memory:
+                    if len(group) > 1:
+                        # copy all but the first
+                        for i in group[1:]:
+                            input_storage[i].storage[0] = copy.copy(
+                                input_storage[i].storage[0]
+                            )

-        # Check if inputs are missing, or if inputs were set more than once, or
-        # if we tried to provide inputs that are supposed to be implicit.
-        if not self.trust_input:
-            for c in self.input_storage:
-                if c.required and not c.provided:
-                    restore_defaults()
+            # Check if inputs are missing, or if inputs were set more than once, or
+            # if we tried to provide inputs that are supposed to be implicit.
+            for arg_container in input_storage:
+                if arg_container.required and not arg_container.provided:
+                    self._restore_defaults()
                     raise TypeError(
-                        f"Missing required input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
+                        f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
                     )
-                if c.provided > 1:
-                    restore_defaults()
+                if arg_container.provided > 1:
+                    self._restore_defaults()
                     raise TypeError(
-                        f"Multiple values for input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
+                        f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
                     )
-                if c.implicit and c.provided > 0:
-                    restore_defaults()
+                if arg_container.implicit and arg_container.provided > 0:
+                    self._restore_defaults()
                     raise TypeError(
-                        f"Tried to provide value for implicit input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
+                        f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
                    )

         # Do the actual work
-        t0_fn = time.perf_counter()
+        if profile:
+            t0_fn = time.perf_counter()
         try:
             outputs = (
                 self.vm()
                 if output_subset is None
                 else self.vm(output_subset=output_subset)
             )
         except Exception:
-            restore_defaults()
+            self._restore_defaults()
             if hasattr(self.vm, "position_of_error"):
                 # this is a new vm-provided function or c linker
                 # they need this because the exception manipulation
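
The copies above exist because two borrowed inputs backed by the same buffer could otherwise clobber each other when the compiled function writes into one of them. For tensor types, `may_share_memory` boils down to NumPy's overlap check, so the situation being guarded against can be sketched with plain NumPy:

    import numpy as np

    base = np.zeros(4)
    view = base[::-1]  # a reversed view over the same buffer

    # The kind of overlap the call-time check detects:
    print(np.may_share_memory(base, view))  # True

    # Copying all but the first member of an aliased group breaks the
    # overlap, mirroring the copy.copy() in the loop above.
    safe = view.copy()
    print(np.may_share_memory(base, safe))  # False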
@@ -979,26 +975,24 @@ def restore_defaults():
                 # old-style linkers raise their own exceptions
                 raise

-        dt_fn = time.perf_counter() - t0_fn
-        self.maker.mode.fn_time += dt_fn
         if profile:
+            dt_fn = time.perf_counter() - t0_fn
+            self.maker.mode.fn_time += dt_fn
             profile.vm_call_time += dt_fn

         # Retrieve the values that were computed
         if outputs is None:
             outputs = [x.data for x in self.output_storage]
-        assert len(outputs) == len(self.output_storage)

         # Remove internal references to required inputs.
         # These cannot be re-used anyway.
-        for c in self.input_storage:
-            if c.required:
-                c.storage[0] = None
+        for arg_container in input_storage:
+            if arg_container.required:
+                arg_container.storage[0] = None

         # if we are allowing garbage collection, remove the
         # output reference from the internal storage cells
         if getattr(self.vm, "allow_gc", False):
-            assert len(self.output_storage) == len(self.maker.fgraph.outputs)
             for o_container, o_variable in zip(
                 self.output_storage, self.maker.fgraph.outputs
             ):
@@ -1007,37 +1001,31 @@ def restore_defaults():
                     # WARNING: This circumvents the 'readonly' attribute in x
                     o_container.storage[0] = None

-        # TODO: Get rid of this and `expanded_inputs`, since all the VMs now
-        # perform the updates themselves
         if getattr(self.vm, "need_update_inputs", True):
             # Update the inputs that have an update function
             for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, self.input_storage))
+                list(zip(self.maker.expanded_inputs, input_storage))
             ):
                 if input.update is not None:
                     storage.data = outputs.pop()
         else:
             outputs = outputs[: self.n_returned_outputs]

         # Put default values back in the storage
-        restore_defaults()
-        #
-        # NOTE: This logic needs to be replicated in
-        # scan.
-        # grep for 'PROFILE_CODE'
-        #
-
-        dt_call = time.perf_counter() - t0
-        pytensor.compile.profiling.total_fct_exec_time += dt_call
-        self.maker.mode.call_time += dt_call
+        self._restore_defaults()

         if profile:
+            dt_call = time.perf_counter() - t0
+            pytensor.compile.profiling.total_fct_exec_time += dt_call
+            self.maker.mode.call_time += dt_call
             profile.fct_callcount += 1
             profile.fct_call_time += dt_call
             if hasattr(self.vm, "update_profile"):
                 self.vm.update_profile(profile)
             if profile.ignore_first_call:
                 profile.reset()
                 profile.ignore_first_call = False

         if self.return_none:
             return None
         elif self.unpack_single and len(outputs) == 1 and output_subset is None:
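
For the return path above: `output_subset` restricts which outputs are computed (string keys are translated to indices near the top of `__call__` when the function was built with named outputs), and `unpack_single` controls whether a single output is returned bare or as a list. A usage sketch under the standard pytensor API; exact `output_subset` support can depend on the linker in use:

    import pytensor
    import pytensor.tensor as pt

    x = pt.dscalar("x")
    f = pytensor.function([x], [x + 1, x * 2, x ** 2])

    print(f(3.0))                        # all outputs
    print(f(3.0, output_subset=[0, 2]))  # only outputs 0 and 2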
3 changes: 0 additions & 3 deletions pytensor/gradient.py
@@ -128,9 +128,6 @@ def fiter_variable(self, other):
             " a symbolic placeholder."
         )

-    def may_share_memory(a, b):
-        return False
-
     def value_eq(a, b, force_same_dtype=True):
         raise AssertionError(
             "If you're assigning to a DisconnectedType you're"