diff --git a/wren-ai-service/Makefile b/wren-ai-service/Makefile index fbaf2a418..30d23c045 100644 --- a/wren-ai-service/Makefile +++ b/wren-ai-service/Makefile @@ -12,15 +12,7 @@ dev-down: ## wren-ai-service related ## start: - poetry run python -m src.__main__ & \ - make force_deploy - -force_deploy: - while ! nc -z localhost 5556; do \ - sleep 1; \ - done; \ - echo "wren-ai-service is up and running" && \ - poetry run python src/force_deploy.py + poetry run python -m src.__main__ build: docker compose -f docker/docker-compose.yaml --env-file .env.prod build diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 2b070ecec..c0d42efd2 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -7,9 +7,6 @@ from src.pipelines.ask import ( historical_question, ) -from src.pipelines.ask import ( - query_understanding as ask_query_understanding, -) from src.pipelines.ask import ( retrieval as ask_retrieval, ) @@ -60,9 +57,6 @@ def init_globals(): llm_provider=llm_provider, document_store_provider=document_store_provider, ), - "query_understanding": ask_query_understanding.QueryUnderstanding( - llm_provider=llm_provider, - ), "retrieval": ask_retrieval.Retrieval( llm_provider=llm_provider, document_store_provider=document_store_provider, diff --git a/wren-ai-service/src/pipelines/ask/query_understanding.py b/wren-ai-service/src/pipelines/ask/query_understanding.py deleted file mode 100644 index 3a4dbfce1..000000000 --- a/wren-ai-service/src/pipelines/ask/query_understanding.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -import sys -from pathlib import Path -from typing import Any, List - -import orjson -from hamilton import base -from hamilton.experimental.h_async import AsyncDriver -from haystack import component -from haystack.components.builders.prompt_builder import PromptBuilder - -from src.core.pipeline import BasicPipeline, async_validate -from src.core.provider import LLMProvider -from src.utils import async_timer, init_providers, timer - -logger = logging.getLogger("wren-ai-service") - - -_prompt = """ -### TASK ### -Based on the user's input below, classify whether the query is not random words. -Provide your classification as 'Yes' or 'No'. Yes if you think the query is not random words, and No if you think the query is random words. - -### FINAL ANSWER FORMAT ### -The final answer must be the JSON format like following: - -{ - "result": "yes" or "no" -} - -### INPUT ### -{{ query }} - -Let's think step by step. -""" - - -@component -class QueryUnderstandingPostProcessor: - @component.output_types( - is_valid_query=bool, - ) - def run(self, replies: List[str]): - try: - result = orjson.loads(replies[0])["result"].lower() - - if result == "yes": - return { - "is_valid_query": True, - } - - return { - "is_valid_query": False, - } - except Exception as e: - logger.exception(f"Error in QueryUnderstandingPostProcessor: {e}") - - return { - "is_valid_query": True, - } - - -## Start of Pipeline -@timer -def prompt(query: str, prompt_builder: PromptBuilder) -> dict: - logger.debug(f"query: {query}") - return prompt_builder.run(query=query) - - -@async_timer -async def generate(prompt: dict, generator: Any) -> dict: - logger.debug(f"prompt: {prompt}") - return await generator.run(prompt=prompt.get("prompt")) - - -@timer -def post_process( - generate: dict, post_processor: QueryUnderstandingPostProcessor -) -> dict: - logger.debug(f"generate: {generate}") - return post_processor.run(generate.get("replies")) - - -## End of Pipeline - - -class QueryUnderstanding(BasicPipeline): - def __init__( - self, - llm_provider: LLMProvider, - ): - self.generator = llm_provider.get_generator() - self.prompt_builder = PromptBuilder(template=_prompt) - self.post_processor = QueryUnderstandingPostProcessor() - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - - def visualize( - self, - query: str, - ) -> None: - destination = "outputs/pipelines/ask" - if not Path(destination).exists(): - Path(destination).mkdir(parents=True, exist_ok=True) - - self._pipe.visualize_execution( - ["post_process"], - output_file_path=f"{destination}/query_understanding.dot", - inputs={ - "query": query, - "generator": self.generator, - "prompt_builder": self.prompt_builder, - "post_processor": self.post_processor, - }, - show_legend=True, - orient="LR", - ) - - @async_timer - async def run( - self, - query: str, - ): - logger.info("Ask QueryUnderstanding pipeline is running...") - return await self._pipe.execute( - ["post_process"], - inputs={ - "query": query, - "generator": self.generator, - "prompt_builder": self.prompt_builder, - "post_processor": self.post_processor, - }, - ) - - -if __name__ == "__main__": - from src.utils import load_env_vars - - load_env_vars() - - llm_provider, _, _ = init_providers() - pipeline = QueryUnderstanding( - llm_provider=llm_provider, - ) - - input = "this is a test query" - pipeline.visualize(input) - async_validate(lambda: pipeline.run(input)) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 2095fdbb2..83c537add 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -94,7 +94,7 @@ class AskResult(BaseModel): class AskError(BaseModel): code: Literal[ "MISLEADING_QUERY", "NO_RELEVANT_DATA", "NO_RELEVANT_SQL", "OTHERS" - ] + ] # MISLEADING_QUERY is not in use now, we may add it back in the future when we implement the clarification pipeline message: str status: Literal[ @@ -177,25 +177,6 @@ async def ask( status="understanding", ) - query_understanding_result = await self._pipelines[ - "query_understanding" - ].run( - query=ask_request.query, - ) - - if not query_understanding_result["post_process"]["is_valid_query"]: - logger.exception( - f"ask pipeline - MISLEADING_QUERY: {ask_request.query}" - ) - self._ask_results[query_id] = AskResultResponse( - status="failed", - error=AskResultResponse.AskError( - code="MISLEADING_QUERY", - message="Misleading query, please ask a more specific question.", - ), - ) - return - if not self._is_stopped(query_id): self._ask_results[query_id] = AskResultResponse( status="searching", diff --git a/wren-ai-service/tests/pytest/pipelines/test_ask.py b/wren-ai-service/tests/pytest/pipelines/test_ask.py index 812332ba7..75ab9451b 100644 --- a/wren-ai-service/tests/pytest/pipelines/test_ask.py +++ b/wren-ai-service/tests/pytest/pipelines/test_ask.py @@ -7,7 +7,6 @@ from src.core.provider import DocumentStoreProvider, LLMProvider from src.pipelines.ask.followup_generation import FollowUpGeneration from src.pipelines.ask.generation import Generation -from src.pipelines.ask.query_understanding import QueryUnderstanding from src.pipelines.ask.retrieval import Retrieval from src.pipelines.ask.sql_correction import SQLCorrection from src.pipelines.indexing.indexing import Indexing @@ -98,18 +97,6 @@ def test_indexing_pipeline( ) -def test_query_understanding_pipeline(): - llm_provider, _, _ = init_providers() - pipeline = QueryUnderstanding(llm_provider=llm_provider) - - assert async_validate(lambda: pipeline.run("How many books are there?"))[ - "post_process" - ]["is_valid_query"] - assert not async_validate(lambda: pipeline.run("fds dsio me"))["post_process"][ - "is_valid_query" - ] - - def test_retrieval_pipeline( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, diff --git a/wren-ai-service/tests/pytest/services/test_ask.py b/wren-ai-service/tests/pytest/services/test_ask.py index 2838583a6..ea38c196e 100644 --- a/wren-ai-service/tests/pytest/services/test_ask.py +++ b/wren-ai-service/tests/pytest/services/test_ask.py @@ -8,7 +8,6 @@ from src.pipelines.ask import ( generation, historical_question, - query_understanding, retrieval, sql_correction, ) @@ -32,9 +31,6 @@ def ask_service(): llm_provider=llm_provider, document_store_provider=document_store_provider, ), - "query_understanding": query_understanding.QueryUnderstanding( - llm_provider=llm_provider, - ), "retrieval": retrieval.Retrieval( llm_provider=llm_provider, document_store_provider=document_store_provider, @@ -107,47 +103,3 @@ def test_ask_with_successful_query(ask_service: AskService, mdl_str: str): # assert ask_result_response.response[0].sql != "" # assert ask_result_response.response[0].summary != "" # assert ask_result_response.response[0].type == "llm" or "view" - - -def test_ask_with_failed_query(ask_service: AskService, mdl_str: str): - id = str(uuid.uuid4()) - async_validate( - lambda: ask_service.prepare_semantics( - SemanticsPreparationRequest( - mdl=mdl_str, - id=id, - ) - ) - ) - - # asking - query_id = str(uuid.uuid4()) - ask_request = AskRequest( - query="xxxx", - id=id, - ) - ask_request.query_id = query_id - async_validate(lambda: ask_service.ask(ask_request)) - - # getting ask result - ask_result_response = ask_service.get_ask_result( - AskResultRequest( - query_id=query_id, - ) - ) - - # from Pao Sheng: I think it has a potential risk if a dangling status case happens. - # maybe we could consider adding an approach that if over a time limit, - # the process will throw an exception. - while ( - ask_result_response.status != "finished" - and ask_result_response.status != "failed" - ): - ask_result_response = ask_service.get_ask_result( - AskResultRequest( - query_id=query_id, - ) - ) - - assert ask_result_response.status == "failed" - assert ask_result_response.error.code == "MISLEADING_QUERY" diff --git a/wren-ai-service/tests/pytest/test_main.py b/wren-ai-service/tests/pytest/test_main.py index f4e867865..1bd70b62f 100644 --- a/wren-ai-service/tests/pytest/test_main.py +++ b/wren-ai-service/tests/pytest/test_main.py @@ -119,36 +119,6 @@ def test_asks_with_successful_query(): # assert r["summary"] is not None and r["summary"] != "" -def test_asks_with_failed_query(): - with TestClient(app) as client: - semantics_preparation_id = GLOBAL_DATA["semantics_preperation_id"] - - response = client.post( - url="/v1/asks", - json={ - "query": "xxxx", - "id": semantics_preparation_id, - }, - ) - - assert response.status_code == 200 - assert response.json()["query_id"] != "" - - query_id = response.json()["query_id"] - GLOBAL_DATA["query_id"] = query_id - - response = client.get(url=f"/v1/asks/{query_id}/result") - while ( - response.json()["status"] != "finished" - and response.json()["status"] != "failed" - ): - response = client.get(url=f"/v1/asks/{query_id}/result") - - assert response.status_code == 200 - assert response.json()["status"] == "failed" - assert response.json()["error"]["code"] == "MISLEADING_QUERY" - - def test_stop_asks(): with TestClient(app) as client: query_id = GLOBAL_DATA["query_id"]