Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into feature/neo4j-reader
  • Loading branch information
stellasia committed Oct 25, 2024
2 parents eb4f90e + 99bf50e commit 20ab815
Show file tree
Hide file tree
Showing 21 changed files with 929 additions and 1,070 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-e2e-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-cache --with dev
run: poetry install --no-interaction --no-cache --with dev --all-extras
- name: Clear Poetry cache
run: poetry cache clear --all .
- name: Show disk usage after Poetry installation
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction
run: poetry install --no-interaction --all-extras
- name: Check format and linting
run: |
poetry run ruff check .
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/scheduled-e2e-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-cache --with dev
run: poetry install --no-interaction --no-cache --with dev --all-extras
- name: Clear Poetry cache
run: poetry cache clear --all .
- name: Show disk usage after Poetry installation
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
- Removed support for neo4j.AsyncDriver in the KG creation pipeline, affecting Neo4jWriter and related components.
- Updated examples and unit tests to reflect the removal of async driver support.

### Fixed
- Resolved issue with `AzureOpenAIEmbeddings` incorrectly inheriting from `OpenAIEmbeddings`, now inherits from `BaseOpenAIEmbeddings`.

## 1.1.0

Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ Install dependencies

.. code:: bash
poetry install
poetry install --all-extras
***************
Getting started
Expand Down
4 changes: 2 additions & 2 deletions examples/customize/embeddings/azure_openai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from neo4j_graphrag.embeddings import AzureOpenAIEmbeddings

embeder = AzureOpenAIEmbeddings(
embedder = AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
azure_endpoint="https://my-endpoint.openai.azure.com/",
api_key="<my key>",
api_version="<update version>",
)
res = embeder.embed_query("my question")
res = embedder.embed_query("my question")
print(res[:10])
1,574 changes: 665 additions & 909 deletions poetry.lock

Large diffs are not rendered by default.

70 changes: 27 additions & 43 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,62 +31,49 @@ from = "src"
python = "^3.9.0"
neo4j = "^5.17.0"
pydantic = "^2.6.3"
urllib3 = "<2"
weaviate-client = {version = "^4.6.1", optional = true}
pinecone-client = {version = "^4.1.0", optional = true}
types-mock = "^5.1.0.20240425"
eval-type-backport = "^0.2.0"
pypdf = "^4.3.1"
fsspec = "^2024.9.0"
fsspec = {version = "^2024.9.0", optional = true}
langchain-text-splitters = {version = "^0.3.0", optional = true }
pypdf = {version = "^4.3.1", optional = true}
pygraphviz = [
{version = "^1.13.0", python = ">=3.10,<4.0.0", optional = true},
{version = "^1.0.0", python = "<3.10", optional = true}
]
google-cloud-aiplatform = {version = "^1.66.0", optional = true}
weaviate-client = {version = "^4.6.1", optional = true }
pinecone-client = {version = "^4.1.0", optional = true }
google-cloud-aiplatform = {version = "^1.66.0", optional = true }
cohere = {version = "^5.9.0", optional = true}
anthropic = { version = "^0.34.2", optional = true}
mistralai = {version = "^1.0.3", optional = true}
qdrant-client = {version = "^1.11.3", optional = true}
llama-index = {version = "^0.10.55", optional = true }
openai = {version = "^1.51.1", optional = true }
anthropic = { version = "^0.36.0", optional = true}
sentence-transformers = {version = "^3.0.0", optional = true }

[tool.poetry.group.dev.dependencies]
pylint = "^3.1.0"
urllib3 = "<2"
ruff = "^0.3.0"
mypy = "^1.10.0"
pytest = "^8.0.2"
pytest-mock = "^3.12.0"
pre-commit = { version = "^3.6.2", python = "^3.9" }
coverage = "^7.4.3"
ruff = "^0.3.0"
langchain-text-splitters = "^0.3.0"
weaviate-client = "^4.6.1"
sentence-transformers = "^3.0.0"
pinecone-client = "^4.1.0"
requests = "^2.32.0"
sphinx = { version = "^7.2.6", python = "^3.9" }
tox = "^4.15.1"
numpy = [
{version = "^1.24.0", python = "<3.12"},
{version = "^1.26.0", python = ">=3.12"}
]
scipy = [
{version = "^1", python = "<3.12"},
{version = "^1.7.0", python = ">=3.12"}
]
llama-index = "^0.10.55"
pytest-asyncio = "^0.23.8"
pygraphviz = [
{version = "^1.13.0", python = ">=3.10,<4.0.0"},
{version = "^1.0.0", python = "<3.10"}
]
google-cloud-aiplatform = {version = "^1.66.0"}
cohere = {version = "^5.9.0"}
anthropic = { version = "^0.34.2"}
mistralai = {version = "^1.0.3"}
qdrant-client = {version = "^1.11.3"}
langchain-openai = "^0.2.2" # needed in the examples
pre-commit = { version = "^3.6.2", python = "^3.9" }
sphinx = { version = "^7.2.6", python = "^3.9" }
langchain-openai = {version = "^0.2.2", optional = true }
langchain-huggingface = {version = "^0.1.0", optional = true }

[tool.poetry.extras]
external_clients = ["weaviate-client", "pinecone-client", "google-cloud-aiplatform", "cohere", "anthropic", "mistralai", "qdrant-client"]
weaviate = ["weaviate-client"]
pinecone = ["pinecone-client"]
google = ["google-cloud-aiplatform"]
cohere = ["cohere"]
anthropic = ["anthropic"]
openai = ["openai"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]
kg_creation_tools = ["pygraphviz"]
sentence-transformers = ["sentence-transformers"]
experimental = ["pypdf", "fsspec", "langchain-text-splitters", "pygraphviz", "llama-index"]
examples = ["langchain-openai", "langchain-huggingface"]

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand All @@ -101,9 +88,6 @@ filterwarnings = [
[tool.coverage.paths]
source = ["src"]

[tool.pylint."MESSAGES CONTROL"]
disable="C0114,C0115"

[tool.mypy]
strict = true
ignore_missing_imports = true
Expand Down
68 changes: 45 additions & 23 deletions src/neo4j_graphrag/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,52 @@
from __future__ import annotations

import abc
from typing import Any
from typing import TYPE_CHECKING, Any

from neo4j_graphrag.embeddings.base import Embedder

try:
if TYPE_CHECKING:
import openai
except ImportError:
openai = None # type: ignore


class BaseOpenAIEmbeddings(Embedder, abc.ABC):
"""
Abstract base class for OpenAI embeddings.
"""

client: openai.OpenAI

def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
if openai is None:
try:
import openai
except ImportError:
raise ImportError(
"Could not import openai python client. "
"Please install it with `pip install openai`."
)
self.openai = openai
self.model = model
self.client = self._initialize_client(**kwargs)

@abc.abstractmethod
def _initialize_client(self, **kwargs: Any) -> Any:
"""
Initialize the OpenAI client.
Must be implemented by subclasses.
"""
pass

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate embeddings for a given query using an OpenAI text embedding model.
Args:
text (str): The text to generate an embedding for.
**kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
"""
response = self.client.embeddings.create(input=text, model=self.model, **kwargs)
embedding: list[float] = response.data[0].embedding
return embedding


class OpenAIEmbeddings(BaseOpenAIEmbeddings):
Expand All @@ -46,25 +74,19 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):
kwargs: All other parameters will be passed to the openai.OpenAI init.
"""

def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
super().__init__(model, **kwargs)
self.openai_client = openai.OpenAI(**kwargs)
def _initialize_client(self, **kwargs: Any) -> Any:
return self.openai.OpenAI(**kwargs)

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate embeddings for a given query using a OpenAI text embedding model.

Args:
text (str): The text to generate an embedding for.
**kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
"""
response = self.openai_client.embeddings.create(
input=text, model=self.model, **kwargs
)
return response.data[0].embedding
class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):
"""
Azure OpenAI embeddings class.
This class uses the Azure OpenAI python client to generate embeddings for text data.
Args:
model (str): The name of the Azure OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
kwargs: All other parameters will be passed to the openai.AzureOpenAI init.
"""

class AzureOpenAIEmbeddings(OpenAIEmbeddings):
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
super().__init__(model, **kwargs)
self.openai_client = openai.AzureOpenAI(**kwargs)
def _initialize_client(self, **kwargs: Any) -> Any:
return self.openai.AzureOpenAI(**kwargs)
20 changes: 9 additions & 11 deletions src/neo4j_graphrag/embeddings/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,32 @@

from typing import Any

try:
import numpy as np
import sentence_transformers
import torch
except ImportError:
sentence_transformers = None # type: ignore


from neo4j_graphrag.embeddings.base import Embedder


class SentenceTransformerEmbeddings(Embedder):
def __init__(
self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any
) -> None:
if sentence_transformers is None:
try:
import numpy as np
import sentence_transformers
import torch
except ImportError:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
)
self.torch = torch
self.np = np
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)

def embed_query(self, text: str) -> Any:
result = self.model.encode([text])
if isinstance(result, torch.Tensor) or isinstance(result, np.ndarray):
if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray):
return result.flatten().tolist()
elif isinstance(result, list) and all(
isinstance(x, torch.Tensor) for x in result
isinstance(x, self.torch.Tensor) for x in result
):
return [item for tensor in result for item in tensor.flatten().tolist()]
else:
Expand Down
6 changes: 4 additions & 2 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from __future__ import annotations

import asyncio
import datetime
import enum
import logging
import uuid
import warnings
from collections import defaultdict
from datetime import datetime
from timeit import default_timer
from typing import Any, AsyncGenerator, Optional

Expand Down Expand Up @@ -61,7 +61,9 @@ class RunStatus(enum.Enum):
class RunResult(BaseModel):
status: RunStatus = RunStatus.DONE
result: Optional[DataModel] = None
timestamp: datetime = Field(default_factory=datetime.utcnow)
timestamp: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
)


class TaskPipelineNode(PipelineNode):
Expand Down
16 changes: 6 additions & 10 deletions src/neo4j_graphrag/llm/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse

try:
import anthropic
from anthropic import APIError
except ImportError:
anthropic = None # type: ignore
APIError = None # type: ignore


class AnthropicLLM(LLMInterface):
"""Interface for large language models on Anthropic
Expand Down Expand Up @@ -58,12 +51,15 @@ def __init__(
model_params: Optional[dict[str, Any]] = None,
**kwargs: Any,
):
if anthropic is None:
try:
import anthropic
except ImportError:
raise ImportError(
"Could not import Anthropic Python client. "
"Please install it with `pip install anthropic`."
)
super().__init__(model_name, model_params)
self.anthropic = anthropic
self.client = anthropic.Anthropic(**kwargs)
self.async_client = anthropic.AsyncAnthropic(**kwargs)

Expand All @@ -88,7 +84,7 @@ def invoke(self, input: str) -> LLMResponse:
**self.model_params,
)
return LLMResponse(content=response.content)
except APIError as e:
except self.anthropic.APIError as e:
raise LLMGenerationError(e)

async def ainvoke(self, input: str) -> LLMResponse:
Expand All @@ -112,5 +108,5 @@ async def ainvoke(self, input: str) -> LLMResponse:
**self.model_params,
)
return LLMResponse(content=response.content)
except APIError as e:
except self.anthropic.APIError as e:
raise LLMGenerationError(e)
Loading

0 comments on commit 20ab815

Please sign in to comment.