From 2e4931121d532cf6871b2de68cd064774febe825 Mon Sep 17 00:00:00 2001 From: Jon Besga Date: Mon, 13 Jan 2025 23:22:26 +0000 Subject: [PATCH] Override neo4j user agent when driver is injected --- .../build_graph/components/writers/custom_writer.py | 3 ++- src/neo4j_graphrag/__init__.py | 6 ++++++ src/neo4j_graphrag/experimental/components/kg_writer.py | 3 ++- .../experimental/components/neo4j_reader.py | 3 ++- src/neo4j_graphrag/experimental/components/resolver.py | 3 ++- src/neo4j_graphrag/retrievers/base.py | 3 ++- src/neo4j_graphrag/utils/telemetry.py | 8 ++++++++ 7 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 src/neo4j_graphrag/utils/telemetry.py diff --git a/examples/customize/build_graph/components/writers/custom_writer.py b/examples/customize/build_graph/components/writers/custom_writer.py index 2c64bd16..24d6c59d 100644 --- a/examples/customize/build_graph/components/writers/custom_writer.py +++ b/examples/customize/build_graph/components/writers/custom_writer.py @@ -6,11 +6,12 @@ from neo4j_graphrag.experimental.components.kg_writer import KGWriter, KGWriterModel from neo4j_graphrag.experimental.components.types import LexicalGraphConfig, Neo4jGraph from pydantic import validate_call +from neo4j_graphrag.utils import telemetry class MyWriter(KGWriter): def __init__(self, driver: neo4j.Driver) -> None: - self.driver = driver + self.driver = telemetry.override_user_agent(driver) @validate_call async def run( diff --git a/src/neo4j_graphrag/__init__.py b/src/neo4j_graphrag/__init__.py index c0199c14..a65a7134 100644 --- a/src/neo4j_graphrag/__init__.py +++ b/src/neo4j_graphrag/__init__.py @@ -12,3 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from importlib.metadata import version, PackageNotFoundError + +try: + __version__ = version("neo4j-graphrag") +except PackageNotFoundError: + __version__ = "0.0.0" diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index fad24b7a..965ef784 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -34,6 +34,7 @@ UPSERT_RELATIONSHIP_QUERY, UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, ) +from neo4j_graphrag.utils import telemetry logger = logging.getLogger(__name__) @@ -113,7 +114,7 @@ def __init__( neo4j_database: Optional[str] = None, batch_size: int = 1000, ): - self.driver = driver + self.driver = telemetry.override_user_agent(driver) self.neo4j_database = neo4j_database self.batch_size = batch_size self.is_version_5_23_or_above = self._check_if_version_5_23_or_above() diff --git a/src/neo4j_graphrag/experimental/components/neo4j_reader.py b/src/neo4j_graphrag/experimental/components/neo4j_reader.py index 352ed1a6..cf7771ed 100644 --- a/src/neo4j_graphrag/experimental/components/neo4j_reader.py +++ b/src/neo4j_graphrag/experimental/components/neo4j_reader.py @@ -25,6 +25,7 @@ TextChunks, ) from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.utils import telemetry class Neo4jChunkReader(Component): @@ -58,7 +59,7 @@ def __init__( fetch_embeddings: bool = False, neo4j_database: Optional[str] = None, ): - self.driver = driver + self.driver = telemetry.override_user_agent(driver) self.fetch_embeddings = fetch_embeddings self.neo4j_database = neo4j_database diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index f2da0bff..1f1c6443 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -19,6 +19,7 @@ from neo4j_graphrag.experimental.components.types import ResolutionStats from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.utils import telemetry class EntityResolver(Component, abc.ABC): @@ -34,7 +35,7 @@ def __init__( driver: neo4j.Driver, filter_query: Optional[str] = None, ) -> None: - self.driver = driver + self.driver = telemetry.override_user_agent(driver) self.filter_query = filter_query @abc.abstractmethod diff --git a/src/neo4j_graphrag/retrievers/base.py b/src/neo4j_graphrag/retrievers/base.py index 55ae06ef..778a6816 100644 --- a/src/neo4j_graphrag/retrievers/base.py +++ b/src/neo4j_graphrag/retrievers/base.py @@ -24,6 +24,7 @@ from neo4j_graphrag.exceptions import Neo4jVersionError from neo4j_graphrag.types import RawSearchResult, RetrieverResult, RetrieverResultItem +from neo4j_graphrag.utils import telemetry T = ParamSpec("T") P = TypeVar("P") @@ -82,7 +83,7 @@ class Retriever(ABC, metaclass=RetrieverMetaclass): VERIFY_NEO4J_VERSION = True def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None): - self.driver = driver + self.driver = telemetry.override_user_agent(driver) self.neo4j_database = neo4j_database if self.VERIFY_NEO4J_VERSION: self._verify_version() diff --git a/src/neo4j_graphrag/utils/telemetry.py b/src/neo4j_graphrag/utils/telemetry.py new file mode 100644 index 00000000..44eb8b36 --- /dev/null +++ b/src/neo4j_graphrag/utils/telemetry.py @@ -0,0 +1,8 @@ +import neo4j +from neo4j_graphrag import __version__ + + +# Override user-agent used by neo4j package so we can measure usage of the package by version +def override_user_agent(driver: neo4j.Driver) -> neo4j.Driver: + driver._pool.pool_config.user_agent = f"neo4j-graphrag-python/v{__version__}" + return driver