Commit
Merge branch 'main' into bugfix/multiple_table_selection
himanshu634 committed Jul 9, 2024
2 parents 9c2dd6a + ae657bb commit e9253ca
Showing 14 changed files with 208 additions and 14 deletions.
2 changes: 1 addition & 1 deletion wren-ai-service/.env.dev.example
@@ -65,4 +65,4 @@ LOGGING_LEVEL=INFO
 LANGFUSE_ENABLE=
 LANGFUSE_SECRET_KEY=
 LANGFUSE_PUBLIC_KEY=
-LANGFUSE_HOST=https://cloud.langfuse.com
+LANGFUSE_HOST=https://cloud.langfuse.com
113 changes: 108 additions & 5 deletions wren-ai-service/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions wren-ai-service/pyproject.toml
@@ -23,6 +23,7 @@ orjson = "==3.10.3"
 sf-hamilton = {version = "==1.63.0", extras = ["visualization"]}
 aiohttp = "==3.9.5"
 ollama-haystack = "==0.0.6"
+langfuse = "==2.35.0"

 [tool.poetry.group.dev.dependencies]
 pytest = "==8.2.0"
3 changes: 2 additions & 1 deletion wren-ai-service/src/__main__.py
@@ -9,7 +9,7 @@
 from fastapi.responses import ORJSONResponse, RedirectResponse

 import src.globals as container
-from src.utils import load_env_vars, setup_custom_logger
+from src.utils import init_langfuse, load_env_vars, setup_custom_logger
 from src.web.v1 import routers

 env = load_env_vars()
@@ -26,6 +26,7 @@
 async def lifespan(app: FastAPI):
     # startup events
     container.init_globals()
+    init_langfuse()

     yield
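The helper imported here, init_langfuse(), lives in wren-ai-service/src/utils.py, which this diff does not render. A hypothetical sketch of what such a helper might do, assuming it simply keys off the LANGFUSE_ENABLE flag added to .env.dev.example (the real implementation may differ):

```python
# Hypothetical sketch only; the actual init_langfuse() in src/utils.py is not
# shown in this diff.
import os

from langfuse.decorators import langfuse_context


def init_langfuse() -> None:
    # Assumption: LANGFUSE_ENABLE toggles tracing; the SDK itself picks up
    # LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, and LANGFUSE_HOST from the env.
    enabled = os.getenv("LANGFUSE_ENABLE", "false").lower() == "true"
    langfuse_context.configure(enabled=enabled)
```

Wiring this into the FastAPI lifespan hook means tracing is configured once at startup rather than per request.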
12 changes: 11 additions & 1 deletion wren-ai-service/src/pipelines/ask/followup_generation.py
@@ -7,6 +7,7 @@
 from hamilton.experimental.h_async import AsyncDriver
 from haystack import Document
 from haystack.components.builders.prompt_builder import PromptBuilder
+from langfuse.decorators import observe

 from src.core.engine import Engine
 from src.core.pipeline import BasicPipeline, async_validate
@@ -130,6 +131,7 @@

 ## Start of Pipeline
 @timer
+@observe(capture_input=False)
 def prompt(
     query: str,
     documents: List[Document],
@@ -146,12 +148,14 @@ def prompt(


 @async_timer
+@observe(as_type="generation", capture_input=False)
 async def generate(prompt: dict, generator: Any) -> dict:
     logger.debug(f"prompt: {prompt}")
     return await generator.run(prompt=prompt.get("prompt"))


 @async_timer
+@observe(capture_input=False)
 async def post_process(generate: dict, post_processor: GenerationPostProcessor) -> dict:
     logger.debug(f"generate: {generate}")
     return await post_processor.run(generate.get("replies"))
@@ -205,6 +209,7 @@ def visualize(
     )

     @async_timer
+    @observe(name="Ask Follow Up Generation")
     async def run(
         self,
         query: str,
@@ -227,9 +232,12 @@


 if __name__ == "__main__":
-    from src.utils import load_env_vars
+    from langfuse.decorators import langfuse_context
+
+    from src.utils import init_langfuse, load_env_vars

     load_env_vars()
+    init_langfuse()

     llm_provider, _, _, engine = init_providers()
     pipeline = FollowUpGeneration(llm_provider=llm_provider, engine=engine)
@@ -250,3 +258,5 @@ async def run(
             ),
         )
     )
+
+    langfuse_context.flush()
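The instrumentation pattern in this file repeats in the other pipelines below: import observe, decorate each Hamilton node and the pipeline's run() entry point, call init_langfuse() at startup, and flush() before a standalone script exits. A minimal, self-contained sketch of that pattern (illustrative only, not code from this repository; assumes the LANGFUSE_* variables from .env.dev are set):

```python
"""Minimal sketch of the langfuse v2 decorator pattern adopted above."""
import asyncio

from langfuse.decorators import langfuse_context, observe


@observe(as_type="generation", capture_input=False)
async def generate(prompt: str) -> str:
    # as_type="generation" renders this span as an LLM call in the Langfuse UI;
    # capture_input=False keeps the (potentially large) prompt out of the trace.
    return f"echo: {prompt}"


@observe(name="Demo Pipeline")
async def run(query: str) -> str:
    # Functions called inside an observed function become child spans,
    # so generate() nests under the "Demo Pipeline" trace automatically.
    return await generate(query)


if __name__ == "__main__":
    print(asyncio.run(run("hello")))
    # The SDK batches events on a background thread; flush before exit so
    # short-lived scripts do not drop their traces.
    langfuse_context.flush()
```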
12 changes: 11 additions & 1 deletion wren-ai-service/src/pipelines/ask/generation.py
@@ -7,6 +7,7 @@
 from hamilton.experimental.h_async import AsyncDriver
 from haystack import Document
 from haystack.components.builders.prompt_builder import PromptBuilder
+from langfuse.decorators import observe

 from src.core.engine import Engine
 from src.core.pipeline import BasicPipeline, async_validate
@@ -90,6 +91,7 @@

 ## Start of Pipeline
 @timer
+@observe(capture_input=False)
 def prompt(
     query: str,
     documents: List[Document],
@@ -106,12 +108,14 @@ def prompt(


 @async_timer
+@observe(as_type="generation", capture_input=False)
 async def generate(prompt: dict, generator: Any) -> dict:
     logger.debug(f"prompt: {prompt}")
     return await generator.run(prompt=prompt.get("prompt"))


 @async_timer
+@observe(capture_input=False)
 async def post_process(generate: dict, post_processor: GenerationPostProcessor) -> dict:
     logger.debug(f"generate: {generate}")
     return await post_processor.run(generate.get("replies"))
@@ -163,6 +167,7 @@ def visualize(
     )

     @async_timer
+    @observe(name="Ask Generation")
     async def run(
         self,
         query: str,
@@ -185,9 +190,12 @@


 if __name__ == "__main__":
-    from src.utils import load_env_vars
+    from langfuse.decorators import langfuse_context
+
+    from src.utils import init_langfuse, load_env_vars

     load_env_vars()
+    init_langfuse()

     llm_provider, _, _, engine = init_providers()
     pipeline = Generation(
@@ -197,3 +205,5 @@ async def run(

     pipeline.visualize("this is a test query", [], [])
     async_validate(lambda: pipeline.run("this is a test query", [], []))
+
+    langfuse_context.flush()
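One detail worth calling out: with capture_input=False, nothing about the prompt reaches the trace unless it is attached explicitly. If a trimmed payload were wanted on these generation spans, the langfuse v2 SDK exposes langfuse_context.update_current_observation() for that purpose; a hypothetical sketch, not part of this change:

```python
import asyncio

from langfuse.decorators import langfuse_context, observe


@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict) -> dict:
    # capture_input=False keeps the full prompt dict out of the trace;
    # attach only the rendered prompt string, explicitly and deliberately.
    langfuse_context.update_current_observation(input=prompt.get("prompt"))
    return {"replies": [f"echo: {prompt.get('prompt')}"]}


if __name__ == "__main__":
    print(asyncio.run(generate({"prompt": "hello"})))
    langfuse_context.flush()
```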
13 changes: 12 additions & 1 deletion wren-ai-service/src/pipelines/ask/historical_question.py
@@ -7,6 +7,7 @@
 from hamilton import base
 from hamilton.experimental.h_async import AsyncDriver
 from haystack import Document, component
+from langfuse.decorators import observe

 from src.core.pipeline import BasicPipeline, async_validate
 from src.core.provider import DocumentStoreProvider, EmbedderProvider
@@ -56,25 +57,29 @@ def run(self, documents: List[Document]):

 ## Start of Pipeline
 @async_timer
+@observe(capture_input=False, capture_output=False)
 async def embedding(query: str, embedder: Any) -> dict:
     logger.debug(f"query: {query}")
     return await embedder.run(query)


 @async_timer
+@observe(capture_input=False)
 async def retrieval(embedding: dict, retriever: Any) -> dict:
     res = await retriever.run(query_embedding=embedding.get("embedding"))
     documents = res.get("documents")
     return dict(documents=documents)


 @timer
+@observe(capture_input=False)
 def filtered_documents(retrieval: dict, score_filter: ScoreFilter) -> dict:
     logger.debug(f"retrieval: {retrieval}")
     return score_filter.run(documents=retrieval.get("documents"))


 @timer
+@observe(capture_input=False)
 def formatted_output(
     filtered_documents: dict, output_formatter: OutputFormatter
 ) -> dict:
@@ -126,6 +131,7 @@ def visualize(
     )

     @async_timer
+    @observe(name="Ask Historical Question")
     async def run(self, query: str):
         logger.info("Ask HistoricalQuestion pipeline is running...")
         return await self._pipe.execute(
@@ -141,9 +147,12 @@ async def run(self, query: str):


 if __name__ == "__main__":
-    from src.utils import load_env_vars
+    from langfuse.decorators import langfuse_context
+
+    from src.utils import init_langfuse, load_env_vars

     load_env_vars()
+    init_langfuse()

     _, embedder_provider, document_store_provider, _ = init_providers()

@@ -153,3 +162,5 @@ async def run(self, query: str):

     pipeline.visualize("this is a query")
     async_validate(lambda: pipeline.run("this is a query"))
+
+    langfuse_context.flush()
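The embedding node is the only one here that also sets capture_output=False, a reasonable choice since embedding vectors are large and carry little diagnostic value in a trace. A small illustrative sketch of the effect (not code from this repository):

```python
import asyncio

from langfuse.decorators import langfuse_context, observe


@observe(capture_input=False, capture_output=False)
async def embedding(query: str) -> dict:
    # Neither the query nor the returned vector is recorded on this span;
    # only timing and the span's place in the trace tree are kept.
    return {"embedding": [0.0] * 768}


if __name__ == "__main__":
    result = asyncio.run(embedding("hello"))
    print(len(result["embedding"]))
    langfuse_context.flush()
```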
(The remaining 7 of the 14 changed files were not loaded in this view.)