From 0cc11902c73a081e9e637bf66be22a4848b3da7c Mon Sep 17 00:00:00 2001 From: Kunal Jain Date: Mon, 10 Jun 2024 16:22:53 +0530 Subject: [PATCH 1/2] fix: Session is not initialized. When using GizaModel with `id` and `version`, it tries to set the session after downloading the model but since the `_output_path` is not set it is unable to download the model and initialize the session. Tested with ``` m = GizaModel(id=766, version=1) m.predict(X) ``` --- giza/agents/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/giza/agents/model.py b/giza/agents/model.py index 8ef9209..b046f79 100644 --- a/giza/agents/model.py +++ b/giza/agents/model.py @@ -92,8 +92,6 @@ def __init__( self.framework = self.version.framework self.uri = self._retrieve_uri() self.endpoint_id = self._get_endpoint_id() - self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir")) - self.session = self._set_session() if output_path is not None: self._output_path = output_path else: @@ -101,6 +99,8 @@ def __init__( tempfile.gettempdir(), f"{self.model_id}_{self.version_id}_{self.model.name}", ) + self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir")) + self.session = self._set_session() self._download_model() def _get_endpoint_id(self) -> int: From e42bbf7b712861b0b08d13b80a15ff4f88f51764 Mon Sep 17 00:00:00 2001 From: Kunal Jain Date: Mon, 10 Jun 2024 16:25:18 +0530 Subject: [PATCH 2/2] Update model.py --- giza/agents/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/giza/agents/model.py b/giza/agents/model.py index b046f79..fdf98fb 100644 --- a/giza/agents/model.py +++ b/giza/agents/model.py @@ -92,6 +92,7 @@ def __init__( self.framework = self.version.framework self.uri = self._retrieve_uri() self.endpoint_id = self._get_endpoint_id() + self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir")) if output_path is not None: self._output_path = output_path else: @@ -99,7 +100,6 @@ def __init__( tempfile.gettempdir(), f"{self.model_id}_{self.version_id}_{self.model.name}", ) - self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir")) self.session = self._set_session() self._download_model()