diff --git a/pytorchvideo/layers/attention.py b/pytorchvideo/layers/attention.py index 5aa4df0..d0f0b80 100644 --- a/pytorchvideo/layers/attention.py +++ b/pytorchvideo/layers/attention.py @@ -4,6 +4,7 @@ import numpy import torch +import torch.fx import torch.nn as nn from torch.nn.common_types import _size_3_t