Skip to content

Commit

Permalink
Improve import system (neo4j#175)
Browse files: browse the repository at this point in the history
* Improve import system for cohere

* Improve import system for openai

* Improve import system for anthropic

* Improve import system for sentence_transformers

* Fix async tests to support older Python versions

* Add tox for multiversion testing

* Reorganise dependencies

* Add qdrant to extras

* Move llama-index to experimental group

* Update lock file
  • Loading branch information
jonbesga authored Oct 24, 2024
1 parent 57529d4 commit bc8540e
Show file tree
Hide file tree
Showing 17 changed files with 861 additions and 1,065 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-e2e-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-cache --with dev
run: poetry install --no-interaction --no-cache --with dev --all-extras
- name: Clear Poetry cache
run: poetry cache clear --all .
- name: Show disk usage after Poetry installation
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction
run: poetry install --no-interaction --all-extras
- name: Check format and linting
run: |
poetry run ruff check .
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/scheduled-e2e-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-cache --with dev
run: poetry install --no-interaction --no-cache --with dev --all-extras
- name: Clear Poetry cache
run: poetry cache clear --all .
- name: Show disk usage after Poetry installation
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ Install dependencies

.. code:: bash
poetry install
poetry install --all-extras
***************
Getting started
Expand Down
1,574 changes: 665 additions & 909 deletions poetry.lock

Large diffs are not rendered by default.

70 changes: 27 additions & 43 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,62 +31,49 @@ from = "src"
python = "^3.9.0"
neo4j = "^5.17.0"
pydantic = "^2.6.3"
urllib3 = "<2"
weaviate-client = {version = "^4.6.1", optional = true}
pinecone-client = {version = "^4.1.0", optional = true}
types-mock = "^5.1.0.20240425"
eval-type-backport = "^0.2.0"
pypdf = "^4.3.1"
fsspec = "^2024.9.0"
fsspec = {version = "^2024.9.0", optional = true}
langchain-text-splitters = {version = "^0.3.0", optional = true }
pypdf = {version = "^4.3.1", optional = true}
pygraphviz = [
{version = "^1.13.0", python = ">=3.10,<4.0.0", optional = true},
{version = "^1.0.0", python = "<3.10", optional = true}
]
google-cloud-aiplatform = {version = "^1.66.0", optional = true}
weaviate-client = {version = "^4.6.1", optional = true }
pinecone-client = {version = "^4.1.0", optional = true }
google-cloud-aiplatform = {version = "^1.66.0", optional = true }
cohere = {version = "^5.9.0", optional = true}
anthropic = { version = "^0.34.2", optional = true}
mistralai = {version = "^1.0.3", optional = true}
qdrant-client = {version = "^1.11.3", optional = true}
llama-index = {version = "^0.10.55", optional = true }
openai = {version = "^1.51.1", optional = true }
anthropic = { version = "^0.36.0", optional = true}
sentence-transformers = {version = "^3.0.0", optional = true }

[tool.poetry.group.dev.dependencies]
pylint = "^3.1.0"
urllib3 = "<2"
ruff = "^0.3.0"
mypy = "^1.10.0"
pytest = "^8.0.2"
pytest-mock = "^3.12.0"
pre-commit = { version = "^3.6.2", python = "^3.9" }
coverage = "^7.4.3"
ruff = "^0.3.0"
langchain-text-splitters = "^0.3.0"
weaviate-client = "^4.6.1"
sentence-transformers = "^3.0.0"
pinecone-client = "^4.1.0"
requests = "^2.32.0"
sphinx = { version = "^7.2.6", python = "^3.9" }
tox = "^4.15.1"
numpy = [
{version = "^1.24.0", python = "<3.12"},
{version = "^1.26.0", python = ">=3.12"}
]
scipy = [
{version = "^1", python = "<3.12"},
{version = "^1.7.0", python = ">=3.12"}
]
llama-index = "^0.10.55"
pytest-asyncio = "^0.23.8"
pygraphviz = [
{version = "^1.13.0", python = ">=3.10,<4.0.0"},
{version = "^1.0.0", python = "<3.10"}
]
google-cloud-aiplatform = {version = "^1.66.0"}
cohere = {version = "^5.9.0"}
anthropic = { version = "^0.34.2"}
mistralai = {version = "^1.0.3"}
qdrant-client = {version = "^1.11.3"}
langchain-openai = "^0.2.2" # needed in the examples
pre-commit = { version = "^3.6.2", python = "^3.9" }
sphinx = { version = "^7.2.6", python = "^3.9" }
langchain-openai = {version = "^0.2.2", optional = true }
langchain-huggingface = {version = "^0.1.0", optional = true }

[tool.poetry.extras]
external_clients = ["weaviate-client", "pinecone-client", "google-cloud-aiplatform", "cohere", "anthropic", "mistralai", "qdrant-client"]
weaviate = ["weaviate-client"]
pinecone = ["pinecone-client"]
google = ["google-cloud-aiplatform"]
cohere = ["cohere"]
anthropic = ["anthropic"]
openai = ["openai"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]
kg_creation_tools = ["pygraphviz"]
sentence-transformers = ["sentence-transformers"]
experimental = ["pypdf", "fsspec", "langchain-text-splitters", "pygraphviz", "llama-index"]
examples = ["langchain-openai", "langchain-huggingface"]

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand All @@ -101,9 +88,6 @@ filterwarnings = [
[tool.coverage.paths]
source = ["src"]

[tool.pylint."MESSAGES CONTROL"]
disable="C0114,C0115"

[tool.mypy]
strict = true
ignore_missing_imports = true
Expand Down
33 changes: 12 additions & 21 deletions src/neo4j_graphrag/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,11 @@

from __future__ import annotations

import abc
from typing import Any

from neo4j_graphrag.embeddings.base import Embedder

try:
import openai
except ImportError:
openai = None # type: ignore


class BaseOpenAIEmbeddings(Embedder, abc.ABC):
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
if openai is None:
raise ImportError(
"Could not import openai python client. "
"Please install it with `pip install openai`."
)
self.model = model


class OpenAIEmbeddings(BaseOpenAIEmbeddings):
class OpenAIEmbeddings(Embedder):
"""
OpenAI embeddings class.
This class uses the OpenAI python client to generate embeddings for text data.
Expand All @@ -47,8 +30,16 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):
"""

def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
super().__init__(model, **kwargs)
self.openai_client = openai.OpenAI(**kwargs)
try:
import openai
except ImportError:
raise ImportError(
"Could not import openai python client. "
"Please install it with `pip install openai`."
)
self.openai = openai
self.model = model
self.openai_client = self.openai.OpenAI(**kwargs)

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Expand All @@ -67,4 +58,4 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
class AzureOpenAIEmbeddings(OpenAIEmbeddings):
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
super().__init__(model, **kwargs)
self.openai_client = openai.AzureOpenAI(**kwargs)
self.openai_client = self.openai.AzureOpenAI(**kwargs)
20 changes: 9 additions & 11 deletions src/neo4j_graphrag/embeddings/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,32 @@

from typing import Any

try:
import numpy as np
import sentence_transformers
import torch
except ImportError:
sentence_transformers = None # type: ignore


from neo4j_graphrag.embeddings.base import Embedder


class SentenceTransformerEmbeddings(Embedder):
def __init__(
self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any
) -> None:
if sentence_transformers is None:
try:
import numpy as np
import sentence_transformers
import torch
except ImportError:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
)
self.torch = torch
self.np = np
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)

def embed_query(self, text: str) -> Any:
result = self.model.encode([text])
if isinstance(result, torch.Tensor) or isinstance(result, np.ndarray):
if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray):
return result.flatten().tolist()
elif isinstance(result, list) and all(
isinstance(x, torch.Tensor) for x in result
isinstance(x, self.torch.Tensor) for x in result
):
return [item for tensor in result for item in tensor.flatten().tolist()]
else:
Expand Down
16 changes: 6 additions & 10 deletions src/neo4j_graphrag/llm/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse

try:
import anthropic
from anthropic import APIError
except ImportError:
anthropic = None # type: ignore
APIError = None # type: ignore


class AnthropicLLM(LLMInterface):
"""Interface for large language models on Anthropic
Expand Down Expand Up @@ -58,12 +51,15 @@ def __init__(
model_params: Optional[dict[str, Any]] = None,
**kwargs: Any,
):
if anthropic is None:
try:
import anthropic
except ImportError:
raise ImportError(
"Could not import Anthropic Python client. "
"Please install it with `pip install anthropic`."
)
super().__init__(model_name, model_params)
self.anthropic = anthropic
self.client = anthropic.Anthropic(**kwargs)
self.async_client = anthropic.AsyncAnthropic(**kwargs)

Expand All @@ -88,7 +84,7 @@ def invoke(self, input: str) -> LLMResponse:
**self.model_params,
)
return LLMResponse(content=response.content)
except APIError as e:
except self.anthropic.APIError as e:
raise LLMGenerationError(e)

async def ainvoke(self, input: str) -> LLMResponse:
Expand All @@ -112,5 +108,5 @@ async def ainvoke(self, input: str) -> LLMResponse:
**self.model_params,
)
return LLMResponse(content=response.content)
except APIError as e:
except self.anthropic.APIError as e:
raise LLMGenerationError(e)
21 changes: 10 additions & 11 deletions src/neo4j_graphrag/llm/cohere_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse

try:
import cohere
from cohere.core import ApiError
except ImportError:
cohere = None # type: ignore
ApiError = Exception # type: ignore[assignment, misc]


class CohereLLM(LLMInterface):
"""Interface for large language models on the Cohere platform
Expand Down Expand Up @@ -55,12 +48,18 @@ def __init__(
model_params: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
if cohere is None:
super().__init__(model_name, model_params)
try:
import cohere
except ImportError:
raise ImportError(
"Could not import cohere python client. "
"Please install it with `pip install cohere`."
)
super().__init__(model_name, model_params)

self.cohere = cohere
self.cohere_api_error = cohere.core.api_error.ApiError

self.client = cohere.Client(**kwargs)
self.async_client = cohere.AsyncClient(**kwargs)

Expand All @@ -78,7 +77,7 @@ def invoke(self, input: str) -> LLMResponse:
message=input,
model=self.model_name,
)
except ApiError as e:
except self.cohere_api_error as e:
raise LLMGenerationError(e)
return LLMResponse(
content=res.text,
Expand All @@ -98,7 +97,7 @@ async def ainvoke(self, input: str) -> LLMResponse:
message=input,
model=self.model_name,
)
except ApiError as e:
except self.cohere_api_error as e:
raise LLMGenerationError(e)
return LLMResponse(
content=res.text,
Expand Down
Loading

0 comments on commit bc8540e

Please sign in to comment.