
Commit

fix conflicts
cyyeh committed Jul 3, 2024
1 parent 07b5288 commit 9f8ce49
Showing 14 changed files with 288 additions and 108 deletions.
2 changes: 1 addition & 1 deletion wren-ai-service/.env.dev.example
@@ -65,4 +65,4 @@ LOGGING_LEVEL=INFO
LANGFUSE_ENABLE=
LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY=
LANGFUSE_HOST=https://cloud.langfuse.com
303 changes: 202 additions & 101 deletions wren-ai-service/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions wren-ai-service/pyproject.toml
@@ -23,6 +23,7 @@ orjson = "==3.10.3"
sf-hamilton = {version = "==1.63.0", extras = ["visualization"]}
aiohttp = "==3.9.5"
ollama-haystack = "==0.0.6"
langfuse = "==2.35.0"

[tool.poetry.group.dev.dependencies]
pytest = "==8.2.0"
3 changes: 2 additions & 1 deletion wren-ai-service/src/__main__.py
@@ -9,7 +9,7 @@
from fastapi.responses import ORJSONResponse, RedirectResponse

import src.globals as container
from src.utils import load_env_vars, setup_custom_logger
from src.utils import init_langfuse, load_env_vars, setup_custom_logger
from src.web.v1 import routers

env = load_env_vars()
@@ -26,6 +26,7 @@
async def lifespan(app: FastAPI):
# startup events
container.init_globals()
init_langfuse()

yield

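Note: src/utils.py itself is not among the rendered files, so the body of init_langfuse does not appear in this diff. A minimal sketch of what such a helper could look like, assuming it only toggles tracing via the LANGFUSE_ENABLE flag from .env.dev.example above (the langfuse SDK picks up LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, and LANGFUSE_HOST from the environment on its own); the actual implementation may differ:

import os

from langfuse.decorators import langfuse_context


def init_langfuse() -> None:
    # Hypothetical sketch: enable tracing only when LANGFUSE_ENABLE=true.
    # langfuse_context.configure(enabled=...) is the langfuse v2
    # decorator-SDK switch for turning tracing on or off.
    enabled = os.getenv("LANGFUSE_ENABLE", "false").lower() == "true"
    langfuse_context.configure(enabled=enabled)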
10 changes: 9 additions & 1 deletion wren-ai-service/src/pipelines/ask/followup_generation.py
@@ -7,6 +7,7 @@
from hamilton.experimental.h_async import AsyncDriver
from haystack import Document
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import langfuse_context, observe

from src.core.engine import Engine
from src.core.pipeline import BasicPipeline, async_validate
@@ -16,7 +16,7 @@
TEXT_TO_SQL_RULES,
text_to_sql_system_prompt,
)
from src.utils import async_timer, init_providers, timer
from src.utils import async_timer, init_langfuse, init_providers, timer
from src.web.v1.services.ask import AskRequest

logger = logging.getLogger("wren-ai-service")
@@ -130,6 +131,7 @@

## Start of Pipeline
@timer
@observe(capture_input=False)
def prompt(
query: str,
documents: List[Document],
@@ -146,12 +148,14 @@ def prompt(


@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {prompt}")
return await generator.run(prompt=prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def post_process(generate: dict, post_processor: GenerationPostProcessor) -> dict:
logger.debug(f"generate: {generate}")
return await post_processor.run(generate.get("replies"))
@@ -205,6 +209,7 @@ def visualize(
)

@async_timer
@observe(name="Ask Follow Up Generation")
async def run(
self,
query: str,
@@ -230,6 +235,7 @@ async def run(
from src.utils import load_env_vars

load_env_vars()
init_langfuse()

llm_provider, _, _, engine = init_providers()
pipeline = FollowUpGeneration(llm_provider=llm_provider, engine=engine)
@@ -250,3 +256,5 @@ async def run(
),
)
)

langfuse_context.flush()
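The same instrumentation pattern repeats across the pipeline files below: every Hamilton pipeline step gets an @observe decorator so it shows up as a child observation, the pipeline's run() method names the trace, generate() is tagged as_type="generation" so Langfuse treats it as an LLM call, and each __main__ smoke test flushes before the process exits. A stripped-down, runnable sketch of that pattern (the trace name and echo logic are illustrative placeholders, and it assumes the LANGFUSE_* variables are set):

from langfuse.decorators import langfuse_context, observe


@observe(as_type="generation", capture_input=False)
def generate(prompt: str) -> str:
    # Stand-in for an LLM call; as_type="generation" marks the span as a
    # generation so model/usage metadata can be attached to it in Langfuse.
    return f"echo: {prompt}"


@observe(name="Demo Pipeline")
def run(query: str) -> str:
    # Calls into other @observe-decorated functions nest as child spans
    # under the "Demo Pipeline" trace.
    return generate(query)


if __name__ == "__main__":
    run("this is a test query")
    # The decorators batch events in a background thread; flush so a
    # short-lived script does not exit before the trace is sent.
    langfuse_context.flush()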
10 changes: 9 additions & 1 deletion wren-ai-service/src/pipelines/ask/generation.py
@@ -7,6 +7,7 @@
from hamilton.experimental.h_async import AsyncDriver
from haystack import Document
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import langfuse_context, observe

from src.core.engine import Engine
from src.core.pipeline import BasicPipeline, async_validate
@@ -18,7 +19,7 @@
TEXT_TO_SQL_RULES,
text_to_sql_system_prompt,
)
from src.utils import async_timer, init_providers, timer
from src.utils import async_timer, init_langfuse, init_providers, timer

logger = logging.getLogger("wren-ai-service")

@@ -90,6 +91,7 @@

## Start of Pipeline
@timer
@observe(capture_input=False)
def prompt(
query: str,
documents: List[Document],
@@ -106,12 +108,14 @@ def prompt(


@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {prompt}")
return await generator.run(prompt=prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def post_process(generate: dict, post_processor: GenerationPostProcessor) -> dict:
logger.debug(f"generate: {generate}")
return await post_processor.run(generate.get("replies"))
@@ -163,6 +167,7 @@ def visualize(
)

@async_timer
@observe(name="Ask Generation")
async def run(
self,
query: str,
@@ -188,6 +193,7 @@ async def run(
from src.utils import load_env_vars

load_env_vars()
init_langfuse()

llm_provider, _, _, engine = init_providers()
pipeline = Generation(
@@ -197,3 +203,5 @@

pipeline.visualize("this is a test query", [], [])
async_validate(lambda: pipeline.run("this is a test query", [], []))

langfuse_context.flush()
10 changes: 10 additions & 0 deletions wren-ai-service/src/pipelines/ask/historical_question.py
@@ -7,11 +7,13 @@
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack import Document, component
from langfuse.decorators import langfuse_context, observe

from src.core.pipeline import BasicPipeline, async_validate
from src.core.provider import DocumentStoreProvider, EmbedderProvider
from src.utils import (
async_timer,
init_langfuse,
init_providers,
timer,
)
@@ -56,25 +58,29 @@ def run(self, documents: List[Document]):

## Start of Pipeline
@async_timer
@observe(capture_input=False, capture_output=False)
async def embedding(query: str, embedder: Any) -> dict:
logger.debug(f"query: {query}")
return await embedder.run(query)


@async_timer
@observe(capture_input=False)
async def retrieval(embedding: dict, retriever: Any) -> dict:
res = await retriever.run(query_embedding=embedding.get("embedding"))
documents = res.get("documents")
return dict(documents=documents)


@timer
@observe(capture_input=False)
def filtered_documents(retrieval: dict, score_filter: ScoreFilter) -> dict:
logger.debug(f"retrieval: {retrieval}")
return score_filter.run(documents=retrieval.get("documents"))


@timer
@observe(capture_input=False)
def formatted_output(
filtered_documents: dict, output_formatter: OutputFormatter
) -> dict:
@@ -126,6 +132,7 @@ def visualize(
)

@async_timer
@observe(name="Ask Historical Question")
async def run(self, query: str):
logger.info("Ask HistoricalQuestion pipeline is running...")
return await self._pipe.execute(
@@ -144,6 +151,7 @@ async def run(self, query: str):
from src.utils import load_env_vars

load_env_vars()
init_langfuse()

_, embedder_provider, document_store_provider, _ = init_providers()

@@ -153,3 +161,5 @@

pipeline.visualize("this is a query")
async_validate(lambda: pipeline.run("this is a query"))

langfuse_context.flush()
9 changes: 8 additions & 1 deletion wren-ai-service/src/pipelines/ask/retrieval.py
@@ -5,6 +5,7 @@

from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from langfuse.decorators import langfuse_context, observe

from src.core.pipeline import BasicPipeline, async_validate
from src.core.provider import DocumentStoreProvider, EmbedderProvider
@@ -15,12 +16,14 @@

## Start of Pipeline
@async_timer
@observe(capture_input=False, capture_output=False)
async def embedding(query: str, embedder: Any) -> dict:
logger.debug(f"query: {query}")
return await embedder.run(query)


@async_timer
@observe(capture_input=False)
async def retrieval(embedding: dict, retriever: Any) -> dict:
return await retriever.run(query_embedding=embedding.get("embedding"))

@@ -64,6 +67,7 @@ def visualize(
)

@async_timer
@observe(name="Ask Retrieval")
async def run(self, query: str):
logger.info("Ask Retrieval pipeline is running...")
return await self._pipe.execute(
@@ -77,9 +81,10 @@ async def run(self, query: str):


if __name__ == "__main__":
from src.utils import load_env_vars
from src.utils import init_langfuse, load_env_vars

load_env_vars()
init_langfuse()

_, embedder_provider, document_store_provider, _ = init_providers()
pipeline = Retrieval(
@@ -89,3 +94,5 @@

pipeline.visualize("this is a query")
async_validate(lambda: pipeline.run("this is a query"))

langfuse_context.flush()
10 changes: 9 additions & 1 deletion wren-ai-service/src/pipelines/ask/sql_correction.py
@@ -7,6 +7,7 @@
from hamilton.experimental.h_async import AsyncDriver
from haystack import Document
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import langfuse_context, observe

from src.core.engine import Engine
from src.core.pipeline import BasicPipeline, async_validate
@@ -16,7 +17,7 @@
TEXT_TO_SQL_RULES,
text_to_sql_system_prompt,
)
from src.utils import async_timer, init_providers, timer
from src.utils import async_timer, init_langfuse, init_providers, timer

logger = logging.getLogger("wren-ai-service")

@@ -58,6 +59,7 @@

## Start of Pipeline
@timer
@observe(capture_input=False)
def prompt(
documents: List[Document],
invalid_generation_results: List[Dict],
@@ -74,12 +76,14 @@ def prompt(


@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {prompt}")
return await generator.run(prompt=prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def post_process(generate: dict, post_processor: GenerationPostProcessor) -> dict:
logger.debug(f"generate: {generate}")
return await post_processor.run(generate.get("replies"))
@@ -131,6 +135,7 @@ def visualize(
)

@async_timer
@observe(name="Ask SQL Correction")
async def run(
self,
contexts: List[Document],
@@ -154,6 +159,7 @@ async def run(
from src.utils import load_env_vars

load_env_vars()
init_langfuse()

llm_provider, _, _, engine = init_providers()
pipeline = SQLCorrection(
@@ -163,3 +169,5 @@

pipeline.visualize([], [])
async_validate(lambda: pipeline.run([], []))

langfuse_context.flush()
9 changes: 9 additions & 0 deletions wren-ai-service/src/pipelines/ask_details/generation.py
@@ -10,6 +10,7 @@
from hamilton.experimental.h_async import AsyncDriver
from haystack import component
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import langfuse_context, observe

from src.core.engine import (
Engine,
@@ -23,6 +24,7 @@
)
from src.utils import (
async_timer,
init_langfuse,
init_providers,
timer,
)
@@ -121,18 +123,21 @@ async def _check_if_sql_executable(

## Start of Pipeline
@timer
@observe(capture_input=False)
def prompt(sql: str, prompt_builder: PromptBuilder) -> dict:
logger.debug(f"sql: {sql}")
return prompt_builder.run(sql=sql)


@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {prompt}")
return await generator.run(prompt=prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def post_process(generate: dict, post_processor: GenerationPostProcessor) -> dict:
logger.debug(f"generate: {generate}")
return await post_processor.run(generate.get("replies"))
@@ -176,6 +181,7 @@ def visualize(self, sql: str) -> None:
)

@async_timer
@observe(name="Ask_Details Generation")
async def run(self, sql: str):
logger.info("Ask_Details Generation pipeline is running...")
return await self._pipe.execute(
@@ -193,6 +199,7 @@ async def run(self, sql: str):
from src.utils import load_env_vars

load_env_vars()
init_langfuse()

llm_provider, _, _, engine = init_providers()
pipeline = Generation(
@@ -202,3 +209,5 @@

pipeline.visualize("SELECT * FROM table_name")
async_validate(lambda: pipeline.run("SELECT * FROM table_name"))

langfuse_context.flush()
(The remaining four changed files in this commit are not rendered here.)
