Skip to content

Commit

Permalink
RAG Pipeline Template
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 13, 2024
1 parent 5491b6d commit c1cb8e1
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/neo4j_graphrag/experimental/components/rag/generate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from typing import Any, Optional

# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.llm import LLMInterface

Expand All @@ -9,9 +21,8 @@ class GenerationResult(DataModel):


class Generate(Component):
def __init__(self, llm: LLMInterface, return_context: bool = True) -> None:
def __init__(self, llm: LLMInterface) -> None:
self.llm = llm
self.return_context = return_context

async def run(self, prompt: str) -> GenerationResult:
llm_response = await self.llm.ainvoke(prompt)
Expand Down
14 changes: 14 additions & 0 deletions src/neo4j_graphrag/experimental/components/rag/prompt_builder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from neo4j_graphrag.experimental.pipeline import Component, DataModel
Expand Down
14 changes: 14 additions & 0 deletions src/neo4j_graphrag/experimental/components/rag/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from neo4j_graphrag.experimental.pipeline import Component, DataModel
Expand Down
39 changes: 39 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/config/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@
AbstractPipelineConfig,
PipelineConfig,
)
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.rag_pipeline import \
SimpleRAGPipelineConfig
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import (
SimpleKGPipelineConfig,
)
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition
from neo4j_graphrag.generation.types import RagResultModel

logger = logging.getLogger(__name__)

Expand All @@ -67,6 +70,7 @@ class PipelineConfigWrapper(BaseModel):
config: Union[
Annotated[PipelineConfig, Tag(PipelineType.NONE)],
Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)],
Annotated[SimpleRAGPipelineConfig, Tag(PipelineType.SIMPLE_RAG_PIPELINE)],
] = Field(discriminator=Discriminator(_get_discriminator_value))

def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition:
Expand Down Expand Up @@ -130,3 +134,38 @@ async def close(self) -> None:
logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)")
if self.config:
await self.config.close()


class RagPipelineRunner(PipelineRunner):
async def search(self, **kwargs) -> RagResultModel:
result = await self.run(kwargs)
context = None
if kwargs.get("return_context"):
context = await self.pipeline.store.get_result_for_component(result.run_id, "retriever")
context = context.get("result")
return RagResultModel(
answer=result.result["generator"]["content"],
retriever_result=context,
)


if __name__ == "__main__":
import os
import asyncio

os.environ["NEO4J_URI"] = "neo4j+s://demo.neo4jlabs.com"
os.environ["NEO4J_USER"] = "recommendations"
os.environ["NEO4J_PASSWORD"] = "recommendations"

runner = RagPipelineRunner.from_config_file(
"src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline_config.json"
)
print(
asyncio.run(runner.search(
query_text="Show me a movie about love",
retriever_config={
"top_k": 2,
},
return_context=True,
))
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import ClassVar, Literal, Union, Optional, Any, cast

from pydantic import ConfigDict

from neo4j_graphrag.experimental.components.rag.generate import Generate
from neo4j_graphrag.experimental.components.rag.prompt_builder import PromptBuilder
from neo4j_graphrag.experimental.components.rag.retrievers import RetrieverWrapper
from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentConfig, T
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import \
TemplatePipelineConfig
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.types import ConnectionDefinition
from neo4j_graphrag.generation import RagTemplate
from neo4j_graphrag.retrievers.base import Retriever


class RetrieverConfig(ComponentConfig):
INTERFACE = Retriever # the result of _get_class is a Retriever
# it is translated into a RetrieverWrapper (which is a Component)
# in the 'parse' method below

def parse(self, resolved_data: dict[str, Any] | None = None) -> RetrieverWrapper:
retriever = super().parse(resolved_data)
return RetrieverWrapper(retriever)


class SimpleRAGPipelineConfig(TemplatePipelineConfig):
COMPONENTS: ClassVar[list[str]] = [
"retriever",
"prompt_builder",
"generator",
]
retriever: RetrieverConfig
prompt_template: Union[RagTemplate, str] = RagTemplate()
template_: Literal[PipelineType.SIMPLE_RAG_PIPELINE] = (
PipelineType.SIMPLE_RAG_PIPELINE
)

model_config = ConfigDict(
arbitrary_types_allowed=True
)

def _get_retriever(self) -> RetrieverWrapper:
retriever = self.retriever.parse(self._global_data)
return retriever

def _get_prompt_builder(self) -> PromptBuilder:
return PromptBuilder(self.prompt_template)

def _get_generator(self) -> Generate:
llm = self.get_default_llm()
return Generate(llm)

def _get_connections(self) -> list[ConnectionDefinition]:
connections = [ConnectionDefinition(
start="retriever",
end="prompt_builder",
input_config={"context": "retriever.result"},
), ConnectionDefinition(
start="prompt_builder",
end="generator",
input_config={
"prompt": "prompt_builder.prompt",
},
)]
return connections

def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
# query_text: str = "",
# examples: str = "",
# retriever_config: Optional[dict[str, Any]] = None,
# return_context: bool | None = None,
run_params = {
"retriever": {
"query_text": user_input["query_text"],
**user_input.get("retriever_config", {}),
},
"prompt_builder": {
"query_text": user_input["query_text"],
"examples": user_input.get("examples", ""),
},
"generator": {
}
}
return run_params
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"version_": "1",
"template_": "SimpleRAGPipeline",
"neo4j_config": {
"params_": {
"uri": {
"resolver_": "ENV",
"var_": "NEO4J_URI"
},
"user": {
"resolver_": "ENV",
"var_": "NEO4J_USER"
},
"password": {
"resolver_": "ENV",
"var_": "NEO4J_PASSWORD"
}
}
},
"llm_config": {
"class_": "OpenAILLM",
"params_": {
"api_key": {
"resolver_": "ENV",
"var_": "OPENAI_API_KEY"
},
"model_name": "gpt-4o",
"model_params": {
"temperature": 0,
"max_tokens": 2000
}
}
},
"embedder_config": {
"class_": "OpenAIEmbeddings",
"params_": {
"api_key": {
"resolver_": "ENV",
"var_": "OPENAI_API_KEY"
}
}
},
"retriever": {
"class_": "neo4j_graphrag.retrievers.VectorRetriever",
"params_": {
"driver": {
"resolver_": "CONFIG_KEY",
"key_": "neo4j_config.default"
},
"index_name": "moviePlotsEmbedding",
"embedder": {
"resolver_": "CONFIG_KEY",
"key_": "embedder_config.default"
}
}
}
}
1 change: 1 addition & 0 deletions src/neo4j_graphrag/experimental/pipeline/config/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ class PipelineType(str, enum.Enum):

NONE = "none"
SIMPLE_KG_PIPELINE = "SimpleKGPipeline"
SIMPLE_RAG_PIPELINE = "SimpleRAGPipeline"

0 comments on commit c1cb8e1

Please sign in to comment.