From 1a3af4b232a2ef5324e99e5b81755e00841f61a4 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 21 Nov 2024 14:31:07 +0100
Subject: [PATCH] Reduce overhead of Function call

---
 pytensor/compile/function/types.py | 152 +++++++++++++++--------------
 1 file changed, 77 insertions(+), 75 deletions(-)

diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py
index 53306d52dc..e2e612ac93 100644
--- a/pytensor/compile/function/types.py
+++ b/pytensor/compile/function/types.py
@@ -393,6 +393,8 @@ def __init__(
         assert len(self.input_storage) == len(self.maker.fgraph.inputs)
         assert len(self.output_storage) == len(self.maker.fgraph.outputs)
 
+        self.has_defaults = any(refeed for _, refeed, _ in self.defaults)
+
         # Group indexes of inputs that are potentially aliased to each other
         # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
         # even though there could be two distinct types that use the same kinds of underlying objects.
@@ -540,14 +542,40 @@ def __contains__(self, item):
         self._value = ValueAttribute()
         self._container = ContainerAttribute()
 
-        # TODO: Get rid of all this `expanded_inputs` nonsense
-        assert len(self.maker.expanded_inputs) == len(self.input_storage)
+        update_storage = [
+            container
+            for inp, container in zip(
+                self.maker.expanded_inputs, input_storage, strict=True
+            )
+            if inp.update is not None
+        ]
+        # Updates are the last inner outputs that are not returned by Function.__call__
+        self.n_returned_outputs = len(self.output_storage) - len(update_storage)
+
+        # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
+        self.update_input_storage: tuple[int, Container] = ()
+        if getattr(vm, "need_update_inputs", True):
+            self.update_input_storage = tuple(
+                zip(
+                    range(self.n_returned_outputs, len(output_storage)),
+                    update_storage,
+                    strict=True,
+                )
+            )
 
-        # This is used only when `vm.need_update_inputs` is `False`, because
-        # we're using one of the VM objects and it is putting updates back into
-        # the input containers all by itself.
-        self.n_returned_outputs = len(self.output_storage) - sum(
-            inp.update is not None for inp in self.maker.expanded_inputs
-        )
+        # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
+        # After the call, we want to erase (some of) these references, to allow Python to GC them if unused
+        # Required input containers are the non-default inputs, must always be provided again, so we GC them
+        self.clear_input_storage_data = tuple(
+            container.storage for container in input_storage if container.required
+        )
+        # This is only done when `vm.allow_gc` is True, which can change at runtime.
+        self.clear_output_storage_data = tuple(
+            container.storage
+            for container, variable in zip(
+                self.output_storage, self.maker.fgraph.outputs, strict=True
+            )
+            if variable.owner is not None  # Not a constant output
+        )
 
         for node in self.maker.fgraph.apply_nodes:
@@ -747,7 +775,7 @@ def checkSV(sv_ori, sv_rpl):
         elif isinstance(profile, str):
             profile = pytensor.compile.profiling.ProfileStats(message=profile)
 
-        f_cpy = maker.__class__(
+        f_cpy = type(maker)(
             inputs=ins,
             outputs=outs,
             fgraph=fg_cpy,
@@ -765,6 +793,8 @@ def checkSV(sv_ori, sv_rpl):
             # check that.
             accept_inplace=True,
             no_fgraph_prep=True,
+            output_keys=maker.output_keys,
+            name=name,
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
@@ -797,8 +827,6 @@ def checkSV(sv_ori, sv_rpl):
 
         f_cpy.trust_input = self.trust_input
         f_cpy.unpack_single = self.unpack_single
-        f_cpy.name = name
-        f_cpy.maker.fgraph.name = name
         return f_cpy
 
     def _restore_defaults(self):
@@ -808,7 +836,7 @@ def _restore_defaults(self):
             if isinstance(value, Container):
                 value = value.storage[0]
             self[i] = value
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, output_subset=None, **kwargs):
         """
         Evaluates value of a function on given arguments.
@@ -836,20 +864,21 @@ def __call__(self, *args, **kwargs):
         List of outputs on indices/keys from ``output_subset`` or all of them,
         if ``output_subset`` is not passed.
 
         """
+        trust_input = self.trust_input
         input_storage = self.input_storage
+        vm = self.vm
         profile = self.profile
         if profile:
             t0 = time.perf_counter()
 
-        output_subset = kwargs.pop("output_subset", None)
         if output_subset is not None:
             warnings.warn("output_subset is deprecated.", FutureWarning)
             if self.output_keys is not None:
                 output_subset = [self.output_keys.index(key) for key in output_subset]
 
         # Reinitialize each container's 'provided' counter
-        if self.trust_input:
+        if trust_input:
             for arg_container, arg in zip(input_storage, args, strict=False):
                 arg_container.storage[0] = arg
         else:
@@ -908,7 +937,7 @@ def __call__(self, *args, **kwargs):
                 for k, arg in kwargs.items():
                     self[k] = arg
 
-            if not self.trust_input:
+            if not trust_input:
                 # Collect aliased inputs among the storage space
                 for potential_group in self._potential_aliased_input_groups:
                     args_share_memory: list[list[int]] = []
@@ -960,11 +989,7 @@ def __call__(self, *args, **kwargs):
         if profile:
             t0_fn = time.perf_counter()
         try:
-            outputs = (
-                self.vm()
-                if output_subset is None
-                else self.vm(output_subset=output_subset)
-            )
+            outputs = vm() if output_subset is None else vm(output_subset=output_subset)
         except Exception:
             self._restore_defaults()
             if hasattr(self.vm, "position_of_error"):
@@ -991,39 +1016,23 @@ def __call__(self, *args, **kwargs):
 
         # Retrieve the values that were computed
         if outputs is None:
-            outputs = [x.data for x in self.output_storage]
-
-        # Remove internal references to required inputs.
-        # These cannot be re-used anyway.
-        for arg_container in input_storage:
-            if arg_container.required:
-                arg_container.storage[0] = None
-
-        # if we are allowing garbage collection, remove the
-        # output reference from the internal storage cells
-        if getattr(self.vm, "allow_gc", False):
-            # strict=False because we are in a hot loop
-            for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs, strict=False
-            ):
-                if o_variable.owner is not None:
-                    # this node is the variable of computation
-                    # WARNING: This circumvents the 'readonly' attribute in x
-                    o_container.storage[0] = None
-
-        if getattr(self.vm, "need_update_inputs", True):
-            # Update the inputs that have an update function
-            # strict=False because we are in a hot loop
-            for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, input_storage, strict=False))
-            ):
-                if input.update is not None:
-                    storage.data = outputs.pop()
-        else:
-            outputs = outputs[: self.n_returned_outputs]
+            outputs = [x.storage[0] for x in self.output_storage]
+
+        # Set updates and filter them out from the returned outputs
+        for i, input_storage in self.update_input_storage:
+            input_storage.storage[0] = outputs[i]
+        outputs = outputs[: self.n_returned_outputs]
+
+        # Remove input and output values from storage data
+        for storage_data in self.clear_input_storage_data:
+            storage_data[0] = None
+        if getattr(vm, "allow_gc", False):
+            for storage_data in self.clear_output_storage_data:
+                storage_data[0] = None
 
         # Put default values back in the storage
-        self._restore_defaults()
+        if self.has_defaults:
+            self._restore_defaults()
 
         if profile:
             dt_call = time.perf_counter() - t0
@@ -1031,33 +1040,29 @@ def __call__(self, *args, **kwargs):
             self.maker.mode.call_time += dt_call
             profile.fct_callcount += 1
             profile.fct_call_time += dt_call
-            if hasattr(self.vm, "update_profile"):
-                self.vm.update_profile(profile)
+            if hasattr(vm, "update_profile"):
+                vm.update_profile(profile)
             if profile.ignore_first_call:
                 profile.reset()
                 profile.ignore_first_call = False
 
         if self.return_none:
             return None
-        elif self.unpack_single and len(outputs) == 1 and output_subset is None:
-            return outputs[0]
-        else:
-            if self.output_keys is not None:
-                assert len(self.output_keys) == len(outputs)
-                if output_subset is None:
-                    # strict=False because we are in a hot loop
-                    return dict(zip(self.output_keys, outputs, strict=False))
-                else:
-                    return {
-                        self.output_keys[index]: outputs[index]
-                        for index in output_subset
-                    }
+
+        if output_subset is not None:
+            outputs = [outputs[i] for i in output_subset]
 
-            if output_subset is None:
-                return outputs
+        if self.output_keys is None:
+            if self.unpack_single:
+                [out] = outputs
+                return out
             else:
-                return [outputs[i] for i in output_subset]
+                return outputs
+        else:
+            output_keys = self.output_keys
+            if output_subset is not None:
+                output_keys = [output_keys[i] for i in output_subset]
+            return dict(zip(output_keys, outputs, strict=True))
 
     value = property(
         lambda self: self._value,
@@ -1077,9 +1082,10 @@ def free(self):
         # 1.no allow_gc return False
         # 2.has allow_gc, if allow_gc is False, return True
         if not getattr(self.vm, "allow_gc", True):
-            for key in self.vm.storage_map:
-                if not isinstance(key, Constant):
-                    self.vm.storage_map[key][0] = None
+            storage_map = self.vm.storage_map
+            for key, value in storage_map.items():
+                if key.owner is not None:  # Not a constant
+                    value[0] = None
 
         for node in self.nodes_with_inner_function:
             if hasattr(node.fn, "free"):
@@ -1091,10 +1097,6 @@ def get_shared(self):
         """
         return [i.variable for i in self.maker.inputs if i.implicit]
 
-    def sync_shared(self):
-        # NOTE: sync was needed on old gpu backend
-        pass
-
     def dprint(self, **kwargs):
         """Debug print itself
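
Illustrative sketch (not part of the patch): a minimal, hypothetical example of
the call path this commit optimizes. A shared-variable update is compiled as an
extra trailing inner output that `Function.__call__` writes back into the input
container (the bookkeeping precomputed above in `update_input_storage` and
`n_returned_outputs`), while `trust_input` skips input validation on the hot
path. All variable names below are illustrative; assumes a recent PyTensor.

# Sketch only, not part of the patch.
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
count = pytensor.shared(np.float64(0.0), name="count")

# One explicit output plus one update: the update value is compiled as the
# last inner output, which __call__ stores back into `count`'s container
# instead of returning it.
f = pytensor.function([x], x + count, updates={count: count + 1.0})

print(f(1.0))  # 1.0 (count was 0.0); count is then updated to 1.0
print(f(1.0))  # 2.0 (count is now 1.0)

# With trust_input, __call__ places arguments directly into input storage
# without validation, so types/dtypes must match exactly (0-d float64 here).
f.trust_input = True
print(f(np.asarray(1.0, dtype="float64")))  # 3.0
print(count.get_value())  # 3.0 after three calls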