Skip to content

Commit

Permalink
some tweaks
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 19, 2024
1 parent 92af537 commit e7ca8b7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 29 deletions.
62 changes: 34 additions & 28 deletions thunder/core/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,42 @@
)


allowed_types = {
FunctionType,
dict,
list,
str,
int,
bool,
tuple,
torch.dtype,
float,
dtypes.floating,
dtypes.bool_,
devices.Device,
torch.memory_format,
type(None),
slice,
complex,
type,
type(Ellipsis),
torch.Size,
torch.finfo,
dtypes.signedinteger,
# FakeTensor type is used for automatic registration of torch ops
torch._subclasses.fake_tensor.FakeTensor,
torch.device,
torch.autograd.function.FunctionCtx,
}


def register_type(typ):
allowed_types.add(typ)


def tree_flatten(args, namespace=OPTREE_NAMESPACE):
if (
type(args)
not in {
FunctionType,
dict,
list,
str,
int,
bool,
tuple,
torch.dtype,
float,
dtypes.floating,
dtypes.bool_,
devices.Device,
torch.memory_format,
type(None),
slice,
complex,
type,
type(Ellipsis),
torch.Size,
torch.finfo,
dtypes.signedinteger,
# FakeTensor type is used for automatic registration of torch ops
torch._subclasses.fake_tensor.FakeTensor,
torch.device,
torch.autograd.function.FunctionCtx,
}
type(args) not in allowed_types
and not isinstance(args, (ProxyInterface))
and not is_likely_from_collections_namedtuple(args)
and not dataclasses.is_dataclass(args)
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
# may mark some of the operation's outputs as unused
some_unused = False
for out in bsym.flat_proxy_outs:
if variableify(out) in needed_proxies and producer_map[out] == bsym:
if variableify(out) in needed_proxies and producer_map.get(out, None) == bsym:
needed = True
else:
some_unused = True
Expand Down

0 comments on commit e7ca8b7

Please sign in to comment.