Commit 8fa35b5

Creating a LowerTriangularMask no longer creates a CUDA tensor (fairinternal/xformers#1274)

danthe3rd authored and xFormers Bot committed
__original_commit__ = fairinternal/xformers@4a6a2a1
1 parent a2f37f8

File tree

2 files changed: +22 -12

CHANGELOG.md (+2)

@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ## [0.0.28.post3] - TBD
+### Fixed:
+- Creating a `LowerTriangularMask` no longer creates a CUDA tensor
 ### Removed:
 - Following PyTorch, xFormers no longer builds binaries for conda. Pip is now the only recommended way to get xFormers
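In effect, constructing the bias object is now CPU-only by default. A minimal sketch of the intended behavior (the import path is taken from the diff below; the `torch.cuda.is_initialized()` check is illustrative, not part of the commit):

```python
import torch

from xformers.ops.fmha.attn_bias import LowerTriangularMask

# Previously, the 0-element placeholder backing the mask was allocated
# on _get_default_bias_device(), which could select CUDA and initialize
# the CUDA context as a side effect of merely building the mask.
mask = LowerTriangularMask()

print(mask.device)                  # cpu -- the placeholder now defaults to CPU
print(torch.cuda.is_initialized())  # False, assuming no earlier CUDA usage
```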

xformers/ops/fmha/attn_bias.py (+20 -12)

@@ -1589,18 +1589,8 @@ class AttentionBiasSubTensor(torch.Tensor, AttentionBias):
     _subtensor: torch.Tensor

     @staticmethod
-    def __new__(cls, *, _subtensor=None):
-        if _subtensor is None:
-            _subtensor = torch.empty((0,), device=_get_default_bias_device())
-        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
-            cls,
-            [],
-            device=_subtensor.device,
-            dtype=_subtensor.dtype,
-            requires_grad=False,
-        )
-        tensor._subtensor = _subtensor
-        return tensor
+    def __new__(cls, *, _subtensor=None, device=None, **kwargs):
+        raise NotImplementedError()

     def __init__(self, *args, **kwargs) -> None:
         super().__init__()
@@ -1667,6 +1657,24 @@ class LowerTriangularMask(AttentionBiasSubTensor):

     HOLDS_DENSE_TENSOR = False

+    @staticmethod
+    def __new__(cls, *, _subtensor=None, device="cpu", **kwargs):
+        """
+        Note: create on CPU by default to avoid initializing CUDA context
+        by mistake.
+        """
+        if _subtensor is None:
+            _subtensor = torch.empty((0,), device=device)
+        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            [],
+            device=_subtensor.device,
+            dtype=_subtensor.dtype,
+            requires_grad=False,
+        )
+        tensor._subtensor = _subtensor
+        return tensor
+
     def materialize(
         self,
         shape: Tuple[int, ...],
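The mask remains a `torch.Tensor` wrapper subclass whose only storage is the 0-element `_subtensor`, and the new `device` keyword lets callers opt back into a GPU placeholder explicitly. A hedged usage sketch (the `device` kwarg and `HOLDS_DENSE_TENSOR` come from the diff above; the `materialize` call follows the signature shown, with an assumed `dtype` keyword):

```python
import torch

from xformers.ops.fmha.attn_bias import LowerTriangularMask

# Default construction keeps the placeholder on CPU, so no CUDA context
# is initialized behind the caller's back.
cpu_mask = LowerTriangularMask()
assert cpu_mask.device.type == "cpu"

# Explicit opt-in to a CUDA-backed placeholder via the new keyword.
if torch.cuda.is_available():
    cuda_mask = LowerTriangularMask(device="cuda")
    assert cuda_mask.device.type == "cuda"

# The subclass holds no dense data (HOLDS_DENSE_TENSOR = False); a dense
# lower-triangular bias is only produced on demand by materialize().
dense = cpu_mask.materialize((4, 4), dtype=torch.float32)
```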
