@@ -1589,18 +1589,8 @@ class AttentionBiasSubTensor(torch.Tensor, AttentionBias):
     _subtensor: torch.Tensor
 
     @staticmethod
-    def __new__(cls, *, _subtensor=None):
-        if _subtensor is None:
-            _subtensor = torch.empty((0,), device=_get_default_bias_device())
-        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
-            cls,
-            [],
-            device=_subtensor.device,
-            dtype=_subtensor.dtype,
-            requires_grad=False,
-        )
-        tensor._subtensor = _subtensor
-        return tensor
+    def __new__(cls, *, _subtensor=None, device=None, **kwargs):
+        raise NotImplementedError()
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__()
@@ -1667,6 +1657,24 @@ class LowerTriangularMask(AttentionBiasSubTensor):
 
     HOLDS_DENSE_TENSOR = False
 
+    @staticmethod
+    def __new__(cls, *, _subtensor=None, device="cpu", **kwargs):
+        """
+        Note: create on CPU by default to avoid initializing CUDA context
+        by mistake.
+        """
+        if _subtensor is None:
+            _subtensor = torch.empty((0,), device=device)
+        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            [],
+            device=_subtensor.device,
+            dtype=_subtensor.dtype,
+            requires_grad=False,
+        )
+        tensor._subtensor = _subtensor
+        return tensor
+
     def materialize(
         self,
         shape: Tuple[int, ...],
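For reference, a minimal usage sketch of the intended behavior after this change: the base class AttentionBiasSubTensor no longer constructs a wrapper tensor itself (its __new__ raises NotImplementedError), and LowerTriangularMask creates its backing tensor on CPU by default so that merely building the bias does not initialize a CUDA context. The import path xformers.ops.fmha.attn_bias below is an assumption; adjust it to wherever these classes live in your checkout.

import torch
from xformers.ops.fmha.attn_bias import LowerTriangularMask  # assumed module path

# Default construction stays on CPU, so no CUDA context is initialized:
mask = LowerTriangularMask()
assert mask.device.type == "cpu"

# The device can still be chosen explicitly via the new keyword argument:
if torch.cuda.is_available():
    mask_cuda = LowerTriangularMask(device="cuda")
    assert mask_cuda.device.type == "cuda"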