Skip to content

Commit

Permalink
revert pytree changes
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 21, 2024
1 parent e7ca8b7 commit dd075db
Showing 1 changed file with 28 additions and 34 deletions.
62 changes: 28 additions & 34 deletions thunder/core/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,42 +30,36 @@
)


allowed_types = {
FunctionType,
dict,
list,
str,
int,
bool,
tuple,
torch.dtype,
float,
dtypes.floating,
dtypes.bool_,
devices.Device,
torch.memory_format,
type(None),
slice,
complex,
type,
type(Ellipsis),
torch.Size,
torch.finfo,
dtypes.signedinteger,
# FakeTensor type is used for automatic registration of torch ops
torch._subclasses.fake_tensor.FakeTensor,
torch.device,
torch.autograd.function.FunctionCtx,
}


def register_type(typ):
allowed_types.add(typ)


def tree_flatten(args, namespace=OPTREE_NAMESPACE):
if (
type(args) not in allowed_types
type(args)
not in {
FunctionType,
dict,
list,
str,
int,
bool,
tuple,
torch.dtype,
float,
dtypes.floating,
dtypes.bool_,
devices.Device,
torch.memory_format,
type(None),
slice,
complex,
type,
type(Ellipsis),
torch.Size,
torch.finfo,
dtypes.signedinteger,
# FakeTensor type is used for automatic registration of torch ops
torch._subclasses.fake_tensor.FakeTensor,
torch.device,
torch.autograd.function.FunctionCtx,
}
and not isinstance(args, (ProxyInterface))
and not is_likely_from_collections_namedtuple(args)
and not dataclasses.is_dataclass(args)
Expand Down

0 comments on commit dd075db

Please sign in to comment.