diff --git a/hezar/embeddings/embedding.py b/hezar/embeddings/embedding.py index 9c5495ce..6e680a06 100644 --- a/hezar/embeddings/embedding.py +++ b/hezar/embeddings/embedding.py @@ -17,6 +17,8 @@ ) from ..utils import Logger, get_lib_version, verify_dependencies +from packaging import version +import importlib.metadata logger = Logger(__name__) @@ -28,9 +30,12 @@ # Check if the right combo of gensim/numpy versions are installed def _verify_gensim_installation(): + numpy_version = importlib.metadata.version("numpy") + gensim_version = importlib.metadata.version("gensim") + if ( - not get_lib_version("numpy").startswith(REQUIRED_NUMPY_VERSION) - or not get_lib_version("gensim").startswith(REQUIRED_GENSIM_VERSION) + version.parse(numpy_version) < version.parse(REQUIRED_NUMPY_VERSION) + or version.parse(gensim_version) < version.parse(REQUIRED_GENSIM_VERSION) ): raise ImportError( f"The embeddings module in this version of Hezar, requires a combo of numpy>={REQUIRED_NUMPY_VERSION} and "