diff --git a/tests/test_model.py b/tests/test_model.py index 3be8137..a4cdfaa 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -47,6 +47,9 @@ def raise_for_status(self): "giza_actions.model.GizaModel._parse_cairo_response", return_value=np.array([[1, 2], [3, 4]], dtype=np.uint32), ) +@patch( + "giza_actions.model.VersionsClient.download_original", return_value=b"some bytes" +) def test_predict_success(*args): model = GizaModel(id=50, version=2) @@ -86,6 +89,9 @@ def test_predict_success(*args): "giza_actions.model.GizaModel._parse_cairo_response", return_value=np.array([[1, 2], [3, 4]], dtype=np.uint32), ) +@patch( + "giza_actions.model.VersionsClient.download_original", return_value=b"some bytes" +) def test_predict_success_with_file(*args): model = GizaModel(id=50, version=2) @@ -121,6 +127,9 @@ def test_predict_success_with_file(*args): @patch("giza_actions.model.GizaModel._get_output_dtype") @patch("giza_actions.model.GizaModel._retrieve_uri") @patch("giza_actions.model.GizaModel._get_endpoint_id", return_value=1) +@patch( + "giza_actions.model.VersionsClient.download_original", return_value=b"some bytes" +) def test_cache_implementation(*args): model = GizaModel(id=50, version=2)