From c1cb8e1cf42171ee261df6030d1f2e2efefdead3 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 13 Dec 2024 10:32:43 +0100 Subject: [PATCH] RAG Pipeline Template --- .../experimental/components/rag/generate.py | 19 ++++- .../components/rag/prompt_builder.py | 14 +++ .../experimental/components/rag/retrievers.py | 14 +++ .../experimental/pipeline/config/runner.py | 39 +++++++++ .../config/template_pipeline/rag_pipeline.py | 85 +++++++++++++++++++ .../simple_rag_pipeline_config.json | 57 +++++++++++++ .../experimental/pipeline/config/types.py | 1 + 7 files changed, 225 insertions(+), 4 deletions(-) create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/rag_pipeline.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline_config.json diff --git a/src/neo4j_graphrag/experimental/components/rag/generate.py b/src/neo4j_graphrag/experimental/components/rag/generate.py index 98ac0ada..6fe22f7e 100644 --- a/src/neo4j_graphrag/experimental/components/rag/generate.py +++ b/src/neo4j_graphrag/experimental/components/rag/generate.py @@ -1,5 +1,17 @@ -from typing import Any, Optional - +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from neo4j_graphrag.experimental.pipeline import Component, DataModel from neo4j_graphrag.llm import LLMInterface @@ -9,9 +21,8 @@ class GenerationResult(DataModel): class Generate(Component): - def __init__(self, llm: LLMInterface, return_context: bool = True) -> None: + def __init__(self, llm: LLMInterface) -> None: self.llm = llm - self.return_context = return_context async def run(self, prompt: str) -> GenerationResult: llm_response = await self.llm.ainvoke(prompt) diff --git a/src/neo4j_graphrag/experimental/components/rag/prompt_builder.py b/src/neo4j_graphrag/experimental/components/rag/prompt_builder.py index 36d1ba9e..8db58ada 100644 --- a/src/neo4j_graphrag/experimental/components/rag/prompt_builder.py +++ b/src/neo4j_graphrag/experimental/components/rag/prompt_builder.py @@ -1,3 +1,17 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any from neo4j_graphrag.experimental.pipeline import Component, DataModel diff --git a/src/neo4j_graphrag/experimental/components/rag/retrievers.py b/src/neo4j_graphrag/experimental/components/rag/retrievers.py index 30adfdae..c83c263b 100644 --- a/src/neo4j_graphrag/experimental/components/rag/retrievers.py +++ b/src/neo4j_graphrag/experimental/components/rag/retrievers.py @@ -1,3 +1,17 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any from neo4j_graphrag.experimental.pipeline import Component, DataModel diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index a1a22585..3ba3ed1c 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -42,12 +42,15 @@ AbstractPipelineConfig, PipelineConfig, ) +from neo4j_graphrag.experimental.pipeline.config.template_pipeline.rag_pipeline import \ + SimpleRAGPipelineConfig from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import ( SimpleKGPipelineConfig, ) from neo4j_graphrag.experimental.pipeline.config.types import PipelineType from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition +from neo4j_graphrag.generation.types import RagResultModel logger = logging.getLogger(__name__) @@ -67,6 +70,7 @@ class PipelineConfigWrapper(BaseModel): config: Union[ Annotated[PipelineConfig, Tag(PipelineType.NONE)], Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)], + Annotated[SimpleRAGPipelineConfig, Tag(PipelineType.SIMPLE_RAG_PIPELINE)], ] = Field(discriminator=Discriminator(_get_discriminator_value)) def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: @@ -130,3 +134,38 @@ async def close(self) -> None: logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)") if self.config: await self.config.close() + + +class RagPipelineRunner(PipelineRunner): + async def search(self, **kwargs) -> RagResultModel: + result = await self.run(kwargs) + context = None + if kwargs.get("return_context"): + context = await self.pipeline.store.get_result_for_component(result.run_id, "retriever") + context = context.get("result") + return RagResultModel( + answer=result.result["generator"]["content"], + retriever_result=context, + ) + + +if __name__ == "__main__": + import os + import asyncio + + os.environ["NEO4J_URI"] = "neo4j+s://demo.neo4jlabs.com" + os.environ["NEO4J_USER"] = "recommendations" + os.environ["NEO4J_PASSWORD"] = "recommendations" + + runner = RagPipelineRunner.from_config_file( + "src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline_config.json" + ) + print( + asyncio.run(runner.search( + query_text="Show me a movie about love", + retriever_config={ + "top_k": 2, + }, + return_context=True, + )) + ) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/rag_pipeline.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/rag_pipeline.py new file mode 100644 index 00000000..744a0808 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/rag_pipeline.py @@ -0,0 +1,85 @@ +from typing import ClassVar, Literal, Union, Optional, Any, cast + +from pydantic import ConfigDict + +from neo4j_graphrag.experimental.components.rag.generate import Generate +from neo4j_graphrag.experimental.components.rag.prompt_builder import PromptBuilder +from neo4j_graphrag.experimental.components.rag.retrievers import RetrieverWrapper +from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentConfig, T +from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import \ + TemplatePipelineConfig +from neo4j_graphrag.experimental.pipeline.config.types import PipelineType +from neo4j_graphrag.experimental.pipeline.types import ConnectionDefinition +from neo4j_graphrag.generation import RagTemplate +from neo4j_graphrag.retrievers.base import Retriever + + +class RetrieverConfig(ComponentConfig): + INTERFACE = Retriever # the result of _get_class is a Retriever + # it is translated into a RetrieverWrapper (which is a Component) + # in the 'parse' method below + + def parse(self, resolved_data: dict[str, Any] | None = None) -> RetrieverWrapper: + retriever = super().parse(resolved_data) + return RetrieverWrapper(retriever) + + +class SimpleRAGPipelineConfig(TemplatePipelineConfig): + COMPONENTS: ClassVar[list[str]] = [ + "retriever", + "prompt_builder", + "generator", + ] + retriever: RetrieverConfig + prompt_template: Union[RagTemplate, str] = RagTemplate() + template_: Literal[PipelineType.SIMPLE_RAG_PIPELINE] = ( + PipelineType.SIMPLE_RAG_PIPELINE + ) + + model_config = ConfigDict( + arbitrary_types_allowed=True + ) + + def _get_retriever(self) -> RetrieverWrapper: + retriever = self.retriever.parse(self._global_data) + return retriever + + def _get_prompt_builder(self) -> PromptBuilder: + return PromptBuilder(self.prompt_template) + + def _get_generator(self) -> Generate: + llm = self.get_default_llm() + return Generate(llm) + + def _get_connections(self) -> list[ConnectionDefinition]: + connections = [ConnectionDefinition( + start="retriever", + end="prompt_builder", + input_config={"context": "retriever.result"}, + ), ConnectionDefinition( + start="prompt_builder", + end="generator", + input_config={ + "prompt": "prompt_builder.prompt", + }, + )] + return connections + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + # query_text: str = "", + # examples: str = "", + # retriever_config: Optional[dict[str, Any]] = None, + # return_context: bool | None = None, + run_params = { + "retriever": { + "query_text": user_input["query_text"], + **user_input.get("retriever_config", {}), + }, + "prompt_builder": { + "query_text": user_input["query_text"], + "examples": user_input.get("examples", ""), + }, + "generator": { + } + } + return run_params diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline_config.json b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline_config.json new file mode 100644 index 00000000..820597f5 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline_config.json @@ -0,0 +1,57 @@ +{ + "version_": "1", + "template_": "SimpleRAGPipeline", + "neo4j_config": { + "params_": { + "uri": { + "resolver_": "ENV", + "var_": "NEO4J_URI" + }, + "user": { + "resolver_": "ENV", + "var_": "NEO4J_USER" + }, + "password": { + "resolver_": "ENV", + "var_": "NEO4J_PASSWORD" + } + } + }, + "llm_config": { + "class_": "OpenAILLM", + "params_": { + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY" + }, + "model_name": "gpt-4o", + "model_params": { + "temperature": 0, + "max_tokens": 2000 + } + } + }, + "embedder_config": { + "class_": "OpenAIEmbeddings", + "params_": { + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY" + } + } + }, + "retriever": { + "class_": "neo4j_graphrag.retrievers.VectorRetriever", + "params_": { + "driver": { + "resolver_": "CONFIG_KEY", + "key_": "neo4j_config.default" + }, + "index_name": "moviePlotsEmbedding", + "embedder": { + "resolver_": "CONFIG_KEY", + "key_": "embedder_config.default" + } + } + } +} diff --git a/src/neo4j_graphrag/experimental/pipeline/config/types.py b/src/neo4j_graphrag/experimental/pipeline/config/types.py index 48f91f48..d90a7d6a 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/types.py @@ -24,3 +24,4 @@ class PipelineType(str, enum.Enum): NONE = "none" SIMPLE_KG_PIPELINE = "SimpleKGPipeline" + SIMPLE_RAG_PIPELINE = "SimpleRAGPipeline"