Skip to content

Commit

Permalink
add motherduck support for duckdb plugin (#2680)
Browse files Browse the repository at this point in the history
* add motherduck support for duckdb plugin

* use secret group key and version

* change secret name and run make fmt

Signed-off-by: Daniel Sola <[email protected]>

* change token name

Signed-off-by: Daniel Sola <[email protected]>

* generalize to other duckdb providers

Signed-off-by: Daniel Sola <[email protected]>

* refactor for secret_requests

Signed-off-by: Daniel Sola <[email protected]>

* add query to execution

Signed-off-by: Daniel Sola <[email protected]>

* add query to execution pt 2

Signed-off-by: Daniel Sola <[email protected]>

* add query to execution pt 3

Signed-off-by: Daniel Sola <[email protected]>

* refactor for callable

Signed-off-by: Daniel Sola <[email protected]>

* add tests

Signed-off-by: Daniel Sola <[email protected]>

* add secret arg

Signed-off-by: Daniel Sola <[email protected]>

* assert secret length

Signed-off-by: Daniel Sola <[email protected]>

* move error message and add docstring

Signed-off-by: Daniel Sola <[email protected]>

* fix unit test

Signed-off-by: Daniel Sola <[email protected]>

* allow for no token to be passed

Signed-off-by: Daniel Sola <[email protected]>

---------

Signed-off-by: Daniel Sola <[email protected]>
  • Loading branch information
dansola authored Sep 3, 2024
1 parent 90699f2 commit d97090d
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 11 deletions.
2 changes: 1 addition & 1 deletion plugins/flytekit-duckdb/flytekitplugins/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
DuckDBQuery
"""

from .task import DuckDBQuery
from .task import DuckDBProvider, DuckDBQuery
92 changes: 85 additions & 7 deletions plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from typing import Dict, List, NamedTuple, Optional, Union
from enum import Enum
from functools import partial
from typing import Callable, Dict, List, NamedTuple, Optional, Union

from flytekit import PythonInstanceTask, lazy_module
from flytekit import PythonInstanceTask, Secret, current_context, lazy_module
from flytekit.extend import Interface
from flytekit.types.structured.structured_dataset import StructuredDataset

Expand All @@ -10,6 +12,25 @@
pa = lazy_module("pyarrow")


class MissingSecretError(ValueError):
pass


def connect_local():
"""Connect to local DuckDB."""
return duckdb.connect(":memory:")


def connect_motherduck(token: str):
"""Connect to MotherDuck."""
return duckdb.connect("md:", config={"motherduck_token": token})


class DuckDBProvider(Enum):
LOCAL = partial(connect_local)
MOTHERDUCK = partial(connect_motherduck)


class QueryOutput(NamedTuple):
counter: int = -1
output: Optional[str] = None
Expand All @@ -21,19 +42,53 @@ class DuckDBQuery(PythonInstanceTask):
def __init__(
self,
name: str,
query: Union[str, List[str]],
query: Optional[Union[str, List[str]]] = None,
inputs: Optional[Dict[str, Union[StructuredDataset, list]]] = None,
provider: Union[DuckDBProvider, Callable] = DuckDBProvider.LOCAL,
**kwargs,
):
"""
This method initializes the DuckDBQuery.
Note that the provider can be one of the default providers listed in DuckDBProvider or a custom callable like the following:
def custom_connect_motherduck(token: str):
return duckdb.connect("md:", config={"motherduck_token": token, "another_config": "hello"})
DuckDBQuery(..., provider=custom_connect_motherduck)
Also note that a query can be provided at runtime if query=None is provided.
duckdb_query = DuckDBQuery(
name="my_duckdb_query",
inputs=kwtypes(query=str)
)
@workflow
def wf(user_query: str) -> pd.DataFrame:
return duckdb_query(query=user_query)
Args:
name: Name of the task
query: DuckDB query to execute
inputs: The query parameters to be used while executing the query
provider: DuckDB provider
"""
self._query = query
self._provider = provider
secret_requests: Optional[list[Secret]] = kwargs.get("secret_requests", None)
self._connect_secret = None
if secret_requests:
assert len(secret_requests) == 1, "Only one secret can be used for a DuckDBQuery task."
self._connect_secret = secret_requests[0]

if (
self._connect_secret is None
and isinstance(self._provider, DuckDBProvider)
and self._provider != DuckDBProvider.LOCAL
):
raise MissingSecretError(f"A secret_requests must be provided for the {self._provider.name} provider.")

outputs = {"result": StructuredDataset}

super(DuckDBQuery, self).__init__(
Expand All @@ -44,6 +99,25 @@ def __init__(
**kwargs,
)

def _connect_to_duckdb(self):
"""
Handles the connection to DuckDB based on the provider.
Returns:
A DuckDB connection object.
"""
connect_token = None
if self._connect_secret:
connect_token = current_context().secrets.get(
group=self._connect_secret.group,
key=self._connect_secret.key,
group_version=self._connect_secret.group_version,
)
if isinstance(self._provider, DuckDBProvider):
return self._provider.value(connect_token) if connect_token else self._provider.value()
else: # callable
return self._provider(connect_token) if connect_token else self._provider()

def _execute_query(
self, con: duckdb.DuckDBPyConnection, params: list, query: str, counter: int, multiple_params: bool
):
Expand Down Expand Up @@ -76,14 +150,15 @@ def _execute_query(

def execute(self, **kwargs) -> StructuredDataset:
# TODO: Enable iterative download after adding the functionality to structured dataset code.

# create an in-memory database that's non-persistent
con = duckdb.connect(":memory:")
con = self._connect_to_duckdb()

params = None
for key in self.python_interface.inputs.keys():
val = kwargs.get(key)
if isinstance(val, StructuredDataset):
if key == "query" and val is not None:
# Execution query takes priority
self._query = val
elif isinstance(val, StructuredDataset):
# register structured dataset
con.register(key, val.open(pa.Table).all())
elif isinstance(val, (pd.DataFrame, pa.Table)):
Expand All @@ -98,6 +173,9 @@ def execute(self, **kwargs) -> StructuredDataset:
else:
raise ValueError(f"Expected inputs of type StructuredDataset, str or list, received {type(val)}")

if self._query is None:
raise ValueError("A query must be specified when defining or executing a DuckDBQuery.")

final_query = self._query
query_output = QueryOutput()
# set flag to indicate the presence of params for multiple queries
Expand Down
31 changes: 28 additions & 3 deletions plugins/flytekit-duckdb/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
from typing import List

import pytest
import pandas as pd
import pyarrow as pa
from flytekitplugins.duckdb import DuckDBQuery
from flytekitplugins.duckdb import DuckDBQuery, DuckDBProvider
from flytekitplugins.duckdb.task import MissingSecretError
from typing_extensions import Annotated

from flytekit import kwtypes, task, workflow
from flytekit import kwtypes, task, workflow, Secret
from flytekit.types.structured.structured_dataset import StructuredDataset


Expand Down Expand Up @@ -146,3 +147,27 @@ def params_wf(params: str) -> pa.Table:
return duckdb_params_query(params=params)

assert isinstance(params_wf(params=json.dumps([[[500], [300], [2]]])), pa.Table)


def test_motherduck_no_token():
with pytest.raises(MissingSecretError, match="A secret_requests must be provided for the MOTHERDUCK provider."):
duckdb_params_query = DuckDBQuery(
name="motherduck_query",
query="SELECT SUM(a) FROM sometable",
provider=DuckDBProvider.MOTHERDUCK,
)


def test_runtime_query():
runtime_duckdb_query = DuckDBQuery(
name="runtime_query", inputs=kwtypes(mydf=pd.DataFrame, query=str)
)

@workflow
def pandas_wf(mydf: pd.DataFrame, query: str) -> pd.DataFrame:
return runtime_duckdb_query(mydf=df, query=query)

df = pd.DataFrame({"a": [1, 2, 3]})
query = "SELECT SUM(a) FROM mydf"
assert isinstance(pandas_wf(mydf=df, query=query), pd.DataFrame)
assert pandas_wf(mydf=df, query=query).iloc[0, 0] == 6

0 comments on commit d97090d

Please sign in to comment.