Skip to content

Commit

Permalink
remove query understanding (#463)
Browse files Browse the repository at this point in the history
* ignore query understanding pipeline

* fix tests

* remove query understanding

* remove unrelated code

* add comments
  • Loading branch information
cyyeh authored Jul 2, 2024
1 parent e97a9de commit 758096e
Show file tree
Hide file tree
Showing 7 changed files with 2 additions and 276 deletions.
10 changes: 1 addition & 9 deletions wren-ai-service/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,7 @@ dev-down:

## wren-ai-service related ##
start:
poetry run python -m src.__main__ & \
make force_deploy

force_deploy:
while ! nc -z localhost 5556; do \
sleep 1; \
done; \
echo "wren-ai-service is up and running" && \
poetry run python src/force_deploy.py
poetry run python -m src.__main__

build:
docker compose -f docker/docker-compose.yaml --env-file .env.prod build
Expand Down
6 changes: 0 additions & 6 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from src.pipelines.ask import (
historical_question,
)
from src.pipelines.ask import (
query_understanding as ask_query_understanding,
)
from src.pipelines.ask import (
retrieval as ask_retrieval,
)
Expand Down Expand Up @@ -60,9 +57,6 @@ def init_globals():
llm_provider=llm_provider,
document_store_provider=document_store_provider,
),
"query_understanding": ask_query_understanding.QueryUnderstanding(
llm_provider=llm_provider,
),
"retrieval": ask_retrieval.Retrieval(
llm_provider=llm_provider,
document_store_provider=document_store_provider,
Expand Down
150 changes: 0 additions & 150 deletions wren-ai-service/src/pipelines/ask/query_understanding.py

This file was deleted.

21 changes: 1 addition & 20 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class AskResult(BaseModel):
class AskError(BaseModel):
code: Literal[
"MISLEADING_QUERY", "NO_RELEVANT_DATA", "NO_RELEVANT_SQL", "OTHERS"
]
] # MISLEADING_QUERY is not in use now; we may add it back in the future when we implement the clarification pipeline
message: str

status: Literal[
Expand Down Expand Up @@ -177,25 +177,6 @@ async def ask(
status="understanding",
)

query_understanding_result = await self._pipelines[
"query_understanding"
].run(
query=ask_request.query,
)

if not query_understanding_result["post_process"]["is_valid_query"]:
logger.exception(
f"ask pipeline - MISLEADING_QUERY: {ask_request.query}"
)
self._ask_results[query_id] = AskResultResponse(
status="failed",
error=AskResultResponse.AskError(
code="MISLEADING_QUERY",
message="Misleading query, please ask a more specific question.",
),
)
return

if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="searching",
Expand Down
13 changes: 0 additions & 13 deletions wren-ai-service/tests/pytest/pipelines/test_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from src.core.provider import DocumentStoreProvider, LLMProvider
from src.pipelines.ask.followup_generation import FollowUpGeneration
from src.pipelines.ask.generation import Generation
from src.pipelines.ask.query_understanding import QueryUnderstanding
from src.pipelines.ask.retrieval import Retrieval
from src.pipelines.ask.sql_correction import SQLCorrection
from src.pipelines.indexing.indexing import Indexing
Expand Down Expand Up @@ -98,18 +97,6 @@ def test_indexing_pipeline(
)


def test_query_understanding_pipeline():
    """Check that QueryUnderstanding accepts a sensible question and rejects gibberish."""
    providers = init_providers()
    # Only the first provider returned by init_providers() is used here.
    understanding = QueryUnderstanding(llm_provider=providers[0])

    def is_valid(question):
        # Run the pipeline for one question and pull out its validity flag.
        outcome = async_validate(lambda: understanding.run(question))
        return outcome["post_process"]["is_valid_query"]

    assert is_valid("How many books are there?")
    assert not is_valid("fds dsio me")


def test_retrieval_pipeline(
llm_provider: LLMProvider,
document_store_provider: DocumentStoreProvider,
Expand Down
48 changes: 0 additions & 48 deletions wren-ai-service/tests/pytest/services/test_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from src.pipelines.ask import (
generation,
historical_question,
query_understanding,
retrieval,
sql_correction,
)
Expand All @@ -32,9 +31,6 @@ def ask_service():
llm_provider=llm_provider,
document_store_provider=document_store_provider,
),
"query_understanding": query_understanding.QueryUnderstanding(
llm_provider=llm_provider,
),
"retrieval": retrieval.Retrieval(
llm_provider=llm_provider,
document_store_provider=document_store_provider,
Expand Down Expand Up @@ -107,47 +103,3 @@ def test_ask_with_successful_query(ask_service: AskService, mdl_str: str):
# assert ask_result_response.response[0].sql != ""
# assert ask_result_response.response[0].summary != ""
# assert ask_result_response.response[0].type == "llm" or "view"


def test_ask_with_failed_query(ask_service: AskService, mdl_str: str):
    """Check that a nonsensical query ends as a failed ask with MISLEADING_QUERY.

    The test prepares semantics from ``mdl_str``, submits the query ``"xxxx"``
    and polls the ask result until it reaches a terminal status
    (``"finished"`` or ``"failed"``).

    Args:
        ask_service: the service under test (pytest fixture).
        mdl_str: the MDL document passed to semantics preparation (pytest fixture).

    Raises:
        TimeoutError: if the ask never reaches a terminal status within the
            polling deadline, so a dangling status fails the test instead of
            hanging the whole run.
    """
    # Local import: the file's import block is not visible from this hunk.
    import time

    poll_timeout_seconds = 60.0
    poll_interval_seconds = 0.1
    terminal_statuses = ("finished", "failed")

    # Named ``semantics_id`` rather than ``id`` to avoid shadowing the builtin.
    semantics_id = str(uuid.uuid4())
    async_validate(
        lambda: ask_service.prepare_semantics(
            SemanticsPreparationRequest(
                mdl=mdl_str,
                id=semantics_id,
            )
        )
    )

    # asking
    query_id = str(uuid.uuid4())
    ask_request = AskRequest(
        query="xxxx",
        id=semantics_id,
    )
    ask_request.query_id = query_id
    async_validate(lambda: ask_service.ask(ask_request))

    # getting ask result
    ask_result_response = ask_service.get_ask_result(
        AskResultRequest(
            query_id=query_id,
        )
    )

    # Poll with a deadline: a status that never becomes terminal raises
    # instead of spinning forever.
    deadline = time.monotonic() + poll_timeout_seconds
    while ask_result_response.status not in terminal_statuses:
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"ask result for query_id={query_id} did not reach a terminal "
                f"status within {poll_timeout_seconds} seconds "
                f"(last status: {ask_result_response.status})"
            )
        time.sleep(poll_interval_seconds)
        ask_result_response = ask_service.get_ask_result(
            AskResultRequest(
                query_id=query_id,
            )
        )

    assert ask_result_response.status == "failed"
    assert ask_result_response.error.code == "MISLEADING_QUERY"
30 changes: 0 additions & 30 deletions wren-ai-service/tests/pytest/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,36 +119,6 @@ def test_asks_with_successful_query():
# assert r["summary"] is not None and r["summary"] != ""


def test_asks_with_failed_query():
    """Check over HTTP that a nonsensical query fails with MISLEADING_QUERY.

    Posts the query ``"xxxx"`` to ``/v1/asks``, stores the returned
    ``query_id`` in ``GLOBAL_DATA`` and polls ``/v1/asks/{query_id}/result``
    until the status is ``"finished"`` or ``"failed"``.

    Raises:
        TimeoutError: if the ask never reaches a terminal status within the
            polling deadline.
    """
    # Local import: the file's import block is not visible from this hunk.
    import time

    poll_timeout_seconds = 60.0
    poll_interval_seconds = 0.1
    terminal_statuses = ("finished", "failed")

    with TestClient(app) as client:
        # The key spelling ("preperation") is kept as-is; presumably it matches
        # the key under which the id was stored elsewhere -- verify before renaming.
        semantics_preparation_id = GLOBAL_DATA["semantics_preperation_id"]

        response = client.post(
            url="/v1/asks",
            json={
                "query": "xxxx",
                "id": semantics_preparation_id,
            },
        )

        assert response.status_code == 200
        assert response.json()["query_id"] != ""

        query_id = response.json()["query_id"]
        GLOBAL_DATA["query_id"] = query_id

        result_url = f"/v1/asks/{query_id}/result"
        response = client.get(url=result_url)
        status = response.json()["status"]

        # Poll with a deadline so a status that never becomes terminal fails
        # the test instead of looping forever; parse the body once per poll.
        deadline = time.monotonic() + poll_timeout_seconds
        while status not in terminal_statuses:
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"ask result for query_id={query_id} did not reach a "
                    f"terminal status within {poll_timeout_seconds} seconds "
                    f"(last status: {status})"
                )
            time.sleep(poll_interval_seconds)
            response = client.get(url=result_url)
            status = response.json()["status"]

        assert response.status_code == 200
        assert response.json()["status"] == "failed"
        assert response.json()["error"]["code"] == "MISLEADING_QUERY"


def test_stop_asks():
with TestClient(app) as client:
query_id = GLOBAL_DATA["query_id"]
Expand Down

0 comments on commit 758096e

Please sign in to comment.