diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py index 4e6359e..8206bc2 100644 --- a/src/gfn/utils/training.py +++ b/src/gfn/utils/training.py @@ -1,5 +1,5 @@ from collections import Counter -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import torch from torchtyping import TensorType as TT