Review notes (this text precedes the first "diff --git" header and is not part of any hunk).

Summary of what the hunks below do:
  * thunder/core/pytree.py: the inline set of types tested in tree_flatten is moved to a
    module-level set named allowed_types, NamedTuple is added to that set, and a new
    function register_type(typ) adds a type to the set.
  * thunder/core/transform_common.py: in dce, the comparison against bsym now reads
    producer_map.get(out, None) instead of producer_map[out].

NOTE(review): the first hunk removes "from types import FunctionType" while allowed_types
still names FunctionType, and NamedTuple is newly referenced; no import of either name is
visible in these hunks. Confirm both are in scope at module level of pytree.py.
NOTE(review): the line structure of this patch was rebuilt from a whitespace-collapsed copy.
Hunk line counts match the hunk headers, but leading indentation and the position of blank
context lines were inferred from Python syntax. Verify against the target files before applying.

diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py
index 6a3d3d3a7..dddb04304 100644
--- a/thunder/core/pytree.py
+++ b/thunder/core/pytree.py
@@ -7,7 +7,7 @@
 import thunder.core.dtypes as dtypes
 import thunder.core.devices as devices
 from thunder.core.baseutils import ProxyInterface
-from types import FunctionType
+
 
 OPTREE_NAMESPACE = "thunder"
 
@@ -30,36 +30,43 @@
 )
 
 
+allowed_types = {
+    FunctionType,
+    dict,
+    list,
+    str,
+    int,
+    bool,
+    tuple,
+    NamedTuple,
+    torch.dtype,
+    float,
+    dtypes.floating,
+    dtypes.bool_,
+    devices.Device,
+    torch.memory_format,
+    type(None),
+    slice,
+    complex,
+    type,
+    type(Ellipsis),
+    torch.Size,
+    torch.finfo,
+    dtypes.signedinteger,
+    # FakeTensor type is used for automatic registration of torch ops
+    torch._subclasses.fake_tensor.FakeTensor,
+    torch.device,
+    torch.autograd.function.FunctionCtx,
+}
+
+
+def register_type(typ):
+    allowed_types.add(typ)
+
+
 def tree_flatten(args, namespace=OPTREE_NAMESPACE):
     if (
-        type(args)
-        not in {
-            FunctionType,
-            dict,
-            list,
-            str,
-            int,
-            bool,
-            tuple,
-            torch.dtype,
-            float,
-            dtypes.floating,
-            dtypes.bool_,
-            devices.Device,
-            torch.memory_format,
-            type(None),
-            slice,
-            complex,
-            type,
-            type(Ellipsis),
-            torch.Size,
-            torch.finfo,
-            dtypes.signedinteger,
-            # FakeTensor type is used for automatic registration of torch ops
-            torch._subclasses.fake_tensor.FakeTensor,
-            torch.device,
-            torch.autograd.function.FunctionCtx,
-        }
+        type(args) not in allowed_types
         and not isinstance(args, (ProxyInterface))
         and not dataclasses.is_dataclass(args)
         and not type(args).__module__.startswith("torch.return_types")
diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py
index e21ca4b28..7b7a97469 100644
--- a/thunder/core/transform_common.py
+++ b/thunder/core/transform_common.py
@@ -137,7 +137,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
         # may mark some of the operation's outputs as unused
         some_unused = False
         for out in bsym.flat_proxy_outs:
-            if variableify(out) in needed_proxies and producer_map[out] == bsym:
+            if variableify(out) in needed_proxies and producer_map.get(out, None) == bsym:
                 needed = True
             else:
                 some_unused = True