Skip to content

Commit

Permalink
adopt KeyPath API in nonstrict mode (#1669)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1669

X-link: pytorch/pytorch#118609

This PR rewrites two paths to use the newly-added keypaths API in pytree:
First: we were hand-rolling a tree_map during fakification because we wanted to track sources. This PR uses keypaths instead, which can do the same thing without needing custom code.

Second: our constraint error formatting was referencing placeholder names in error messages. These placeholder names are not otherwise user-visible, so they are super confusing to users (e.g. "which input does arg1_3 correspond to?"). This diff uses the `keystr` API to format the error message.

This necessitated some small refactors—generating the keystr is expensive so doing it in an f-string was very bad.

It can also be further improved—we can inspect the signature so that instead of `*args[0]` we can give people the actual argument name, which would be the ideal UX. But leaving that for later.
ghstack-source-id: 213477246
exported-using-ghexport
bypass-github-pytorch-ci-checks

Reviewed By: avikchaudhuri, zhxchen17

Differential Revision: D53139358

fbshipit-source-id: de43133060eaf9de1fd61aaba6dd712c6183daa0
  • Loading branch information
suo authored and facebook-github-bot committed Jan 30, 2024
1 parent 0e153c9 commit af86a0f
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
# LICENSE file in the root directory of this source tree.

import abc
import json
import operator
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.autograd.profiler import record_function
from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
from torch.utils._pytree import register_pytree_node
from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node

from torchrec.streamable import Pipelineable

Expand Down Expand Up @@ -575,6 +574,14 @@ def _jt_flatten(
return [getattr(t, a) for a in JaggedTensor._fields], None


def _jt_flatten_with_keys(
t: JaggedTensor,
) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], None]:
values, context = _jt_flatten(t)
# pyre can't tell that GetAttrKey implements the KeyEntry protocol
return [(GetAttrKey(k), v) for k, v in zip(JaggedTensor._fields, values)], context # pyre-ignore[7]


def _jt_unflatten(values: List[Optional[torch.Tensor]], context: None) -> JaggedTensor:
return JaggedTensor(*values)

Expand All @@ -583,7 +590,9 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten
return [getattr(t, a) for a in JaggedTensor._fields]


register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten)
register_pytree_node(
JaggedTensor, _jt_flatten, _jt_unflatten, flatten_with_keys_fn=_jt_flatten_with_keys
)
register_pytree_flatten_spec(JaggedTensor, _jt_flatten_spec)


Expand Down Expand Up @@ -1988,6 +1997,16 @@ def _kjt_flatten(
return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys


def _kjt_flatten_with_keys(
t: KeyedJaggedTensor,
) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
values, context = _kjt_flatten(t)
# pyre can't tell that GetAttrKey implements the KeyEntry protocol
return [ # pyre-ignore[7]
(GetAttrKey(k), v) for k, v in zip(KeyedJaggedTensor._fields, values)
], context


def _kjt_unflatten(
values: List[Optional[torch.Tensor]], context: List[str] # context is the _keys
) -> KeyedJaggedTensor:
Expand All @@ -2000,7 +2019,12 @@ def _kjt_flatten_spec(
return [getattr(t, a) for a in KeyedJaggedTensor._fields]


register_pytree_node(KeyedJaggedTensor, _kjt_flatten, _kjt_unflatten)
register_pytree_node(
KeyedJaggedTensor,
_kjt_flatten,
_kjt_unflatten,
flatten_with_keys_fn=_kjt_flatten_with_keys,
)
register_pytree_flatten_spec(KeyedJaggedTensor, _kjt_flatten_spec)


Expand Down

0 comments on commit af86a0f

Please sign in to comment.