Skip to content

Commit

Permalink
Override neo4j user agent when driver is injected
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbesga committed Jan 13, 2025
1 parent 4054c46 commit 2e49311
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from neo4j_graphrag.experimental.components.kg_writer import KGWriter, KGWriterModel
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig, Neo4jGraph
from pydantic import validate_call
from neo4j_graphrag.utils import telemetry


class MyWriter(KGWriter):
def __init__(self, driver: neo4j.Driver) -> None:
self.driver = driver
self.driver = telemetry.override_user_agent(driver)

@validate_call
async def run(
Expand Down
6 changes: 6 additions & 0 deletions src/neo4j_graphrag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.metadata import version, PackageNotFoundError

try:
__version__ = version("neo4j-graphrag")
except PackageNotFoundError:
__version__ = "0.0.0"
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
UPSERT_RELATIONSHIP_QUERY,
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
)
from neo4j_graphrag.utils import telemetry

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,7 +114,7 @@ def __init__(
neo4j_database: Optional[str] = None,
batch_size: int = 1000,
):
self.driver = driver
self.driver = telemetry.override_user_agent(driver)
self.neo4j_database = neo4j_database
self.batch_size = batch_size
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()
Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/experimental/components/neo4j_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TextChunks,
)
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.utils import telemetry


class Neo4jChunkReader(Component):
Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(
fetch_embeddings: bool = False,
neo4j_database: Optional[str] = None,
):
self.driver = driver
self.driver = telemetry.override_user_agent(driver)
self.fetch_embeddings = fetch_embeddings
self.neo4j_database = neo4j_database

Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/experimental/components/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from neo4j_graphrag.experimental.components.types import ResolutionStats
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.utils import telemetry


class EntityResolver(Component, abc.ABC):
Expand All @@ -34,7 +35,7 @@ def __init__(
driver: neo4j.Driver,
filter_query: Optional[str] = None,
) -> None:
self.driver = driver
self.driver = telemetry.override_user_agent(driver)
self.filter_query = filter_query

@abc.abstractmethod
Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from neo4j_graphrag.exceptions import Neo4jVersionError
from neo4j_graphrag.types import RawSearchResult, RetrieverResult, RetrieverResultItem
from neo4j_graphrag.utils import telemetry

T = ParamSpec("T")
P = TypeVar("P")
Expand Down Expand Up @@ -82,7 +83,7 @@ class Retriever(ABC, metaclass=RetrieverMetaclass):
VERIFY_NEO4J_VERSION = True

def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None):
self.driver = driver
self.driver = telemetry.override_user_agent(driver)
self.neo4j_database = neo4j_database
if self.VERIFY_NEO4J_VERSION:
self._verify_version()
Expand Down
8 changes: 8 additions & 0 deletions src/neo4j_graphrag/utils/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import neo4j
from neo4j_graphrag import __version__


# Override user-agent used by neo4j package so we can measure usage of the package by version
def override_user_agent(driver: neo4j.Driver) -> neo4j.Driver:
driver._pool.pool_config.user_agent = f"neo4j-graphrag-python/v{__version__}"
return driver

0 comments on commit 2e49311

Please sign in to comment.