From f870d374c36e9d6b5b8409cec21eb6198874cec6 Mon Sep 17 00:00:00 2001
From: Estelle Scifo <stellasia@users.noreply.github.com>
Date: Tue, 17 Sep 2024 15:30:49 +0200
Subject: [PATCH] Make pygraphviz really optional  (#137)

* Make pygraphviz optional in code

* Remove unused import

* Update CHANGELOG.md

* Fix broken import

* Add link to pygraphviz installation page in README
---
 CHANGELOG.md                                         |  6 ++++--
 README.md                                            |  7 +++++++
 src/neo4j_graphrag/embeddings/vertexai.py            |  2 +-
 src/neo4j_graphrag/experimental/pipeline/pipeline.py | 12 +++++++++++-
 tests/unit/experimental/pipeline/test_pipeline.py    | 11 ++++++++++-
 5 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5cd194f8f..14f3477ae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,12 +7,14 @@
 
 ### Added
 - Introduction page to the documentation content tree.
-
-### Added
 - Introduced a new Vertex AI embeddings class for generating text embeddings using Vertex AI.
 - Updated documentation to include OpenAI and Vertex AI embeddings classes.
 - Added google-cloud-aiplatform as an optional dependency for Vertex AI embeddings.
 
+### Fixed
+- Make `pygraphviz` an optional dependency - it is now only required when calling `pipeline.draw`.
+
+
 ## 0.6.2
 
 ### Fixed
diff --git a/README.md b/README.md
index 9f41d403d..d9f56dfba 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,13 @@ To install the latest stable version, use:
 pip install neo4j-graphrag
 ```
 
+### Optional dependencies
+
+#### pygraphviz
+
+`pygraphviz` is used for visualizing pipelines.
+Follow installation instructions [here](https://pygraphviz.github.io/documentation/stable/install.html).
+
 ## Examples
 
 ### Creating a vector index
diff --git a/src/neo4j_graphrag/embeddings/vertexai.py b/src/neo4j_graphrag/embeddings/vertexai.py
index 5b25e3c64..b7b4d64f3 100644
--- a/src/neo4j_graphrag/embeddings/vertexai.py
+++ b/src/neo4j_graphrag/embeddings/vertexai.py
@@ -17,7 +17,7 @@
 
 from typing import Any
 
-from neo4j_genai.embedder import Embedder
+from neo4j_graphrag.embedder import Embedder
 
 try:
     from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
index 97b4b53f0..5221ec9b9 100644
--- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py
+++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -24,7 +24,11 @@
 from timeit import default_timer
 from typing import Any, AsyncGenerator, Optional
 
-import pygraphviz as pgv
+try:
+    import pygraphviz as pgv
+except ImportError:
+    pyg = None
+
 from pydantic import BaseModel, Field
 
 from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
@@ -386,6 +390,12 @@ def draw(
         G.draw(path)
 
     def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph:
+        if pgv is None:
+            raise ImportError(
+                "Could not import pygraphviz. "
+                "Follow installation instruction in pygraphviz documentation "
+                "to get it up and running on your system."
+            )
         self.validate_parameter_mapping()
         G = pgv.AGraph(strict=False, directed=True)
         # create a node for each component
diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py
index 97c8b481f..8191da61e 100644
--- a/tests/unit/experimental/pipeline/test_pipeline.py
+++ b/tests/unit/experimental/pipeline/test_pipeline.py
@@ -17,7 +17,7 @@
 import asyncio
 import tempfile
 from unittest import mock
-from unittest.mock import AsyncMock, call
+from unittest.mock import AsyncMock, call, patch
 
 import pytest
 from neo4j_graphrag.experimental.pipeline import Component, Pipeline
@@ -395,3 +395,12 @@ def test_pipeline_draw() -> None:
     pipe.draw(t.name)
     content = t.file.read()
     assert len(content) > 0
+
+
+@patch("neo4j_graphrag.experimental.pipeline.pipeline.pgv", None)
+def test_pipeline_draw_missing_pygraphviz_dep() -> None:
+    pipe = Pipeline()
+    pipe.add_component(ComponentAdd(), "add")
+    t = tempfile.NamedTemporaryFile()
+    with pytest.raises(ImportError):
+        pipe.draw(t.name)