Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(function-list): search file by data source name #859

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ibis-server/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ def update(self, diagnose: bool):
else:
self.init_logger()

def get_remote_function_list_path(self, data_source: str) -> str:
return (
f"{self.remote_function_list_path}/{data_source}.csv"
if self.remote_function_list_path
else None
)

def set_remote_function_list_path(self, path: str):
self.remote_function_list_path = path

Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
self.data_source = data_source
if experiment:
config = get_config()
function_path = config.remote_function_list_path
function_path = config.get_remote_function_list_path(data_source)
self._rewriter = EmbeddedEngineRewriter(manifest_str, function_path)
else:
self._rewriter = ExternalEngineRewriter(manifest_str)
Expand Down
15 changes: 7 additions & 8 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import APIRouter, Depends, Query, Response
from fastapi.responses import JSONResponse

from app.config import get_config
from app.dependencies import verify_query_dto
from app.mdl.rewriter import Rewriter
from app.model import (
Expand Down Expand Up @@ -57,14 +58,12 @@ def validate(data_source: DataSource, rule_name: str, dto: ValidateDTO) -> Respo
return Response(status_code=204)


@router.get("/functions")
def functions() -> Response:
@router.get("/{data_source}/functions")
def functions(data_source: DataSource) -> Response:
from wren_core import SessionContext

from app.config import get_config
file_path = get_config().get_remote_function_list_path(data_source)
session_context = SessionContext(None, file_path)
func_list = [f.to_dict() for f in session_context.get_available_functions()]

config = get_config()
session_context = SessionContext(None, config.remote_function_list_path)
functions = [f.to_dict() for f in session_context.get_available_functions()]

return JSONResponse(functions)
return JSONResponse(func_list)
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
function_type,name,return_type,description
scalar,add_two,int,"Adds two numbers together."
aggregate,median,int,"Returns the median value of a numeric column."
window,max_if,int,"If the condition is true, returns the maximum value in the window."
scalar,unistr,varchar,"Postgres: Evaluate escaped Unicode characters in the argument".
scalar,unistr,varchar,"Postgres: Evaluate escaped Unicode characters in the argument".
23 changes: 13 additions & 10 deletions ibis-server/tests/routers/v3/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_dry_plan(manifest_str):

def test_query_with_remote_function(manifest_str, postgres: PostgresContainer):
config = get_config()
config.set_remote_function_list_path(file_path("resource/functions.csv"))
config.set_remote_function_list_path(file_path("resource/function_list"))

connection_info = _to_connection_info(postgres)
response = client.post(
Expand All @@ -391,22 +391,25 @@ def test_query_with_remote_function(manifest_str, postgres: PostgresContainer):

def test_function_list():
config = get_config()
config.set_remote_function_list_path(file_path("resource/functions.csv"))
config.set_remote_function_list_path(file_path("resource/function_list"))

response = client.get(
url="/v3/connector/functions",
)
response = client.get(url=f"{base_url}/functions")
assert response.status_code == 200
result = response.json()
assert len(result) == 261
add_two = next(filter(lambda x: x["name"] == "add_two", result))
assert add_two["name"] == "add_two"
the_func = next(filter(lambda x: x["name"] == "add_two", result))
assert the_func == {
"name": "add_two",
"description": "Adds two numbers together.",
"function_type": "scalar",
"param_names": None,
"param_types": None,
"return_type": "int",
}

config.set_remote_function_list_path(None)

response = client.get(
url="/v3/connector/functions",
)
response = client.get(url=f"{base_url}/functions")
assert response.status_code == 200
result = response.json()
assert len(result) == 258
Expand Down
Loading