diff --git a/.vscode/launch.json b/.vscode/launch.json
index e615ed5..3db4437 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -15,7 +15,7 @@
             "name":"Debug Python client",
             "type":"debugpy",
             "request":"launch",
-            "program":"${workspaceFolder}/client/client.py",
+            "program":"${workspaceFolder}/client/main.py",
             "console":"integratedTerminal",
             "env": {
                 "PYTHONPATH": "${workspaceFolder}/client/gen"
diff --git a/client/gduck/__init__.py b/client/gduck/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/client/client.py b/client/gduck/client.py
similarity index 62%
rename from client/client.py
rename to client/gduck/client.py
index f83479b..31234ea 100644
--- a/client/client.py
+++ b/client/gduck/client.py
@@ -2,14 +2,21 @@
 
 import threading
 from dataclasses import dataclass
+from pathlib import Path
 from queue import SimpleQueue
 from types import TracebackType
-from typing import Literal, Self
+from typing import Self
 
 import grpc
+from error_pb2 import Error
 from grpc._channel import _MultiThreadedRendezvous
+from query_pb2 import Query
+from service_pb2 import Request, Response
 from service_pb2_grpc import DbServiceStub
 
+from .request import ConnectionMode, Value, connect, ctas, execute, local_file, parquet, request, rows, value
+from .response import parse_location, parse_rows, parse_value
+
 
 @dataclass(frozen=True)
 class Addr:
@@ -46,7 +53,10 @@ def __init__(self, responses: _MultiThreadedRendezvous, out: SimpleQueue, group:
     def run(self) -> None:
         try:
             for response in self._responses:
-                self._out.put(response.result)
+                if response.HasField("success"):
+                    self._out.put(response.success)
+                elif response.HasField("error"):
+                    self._out.put(response.error)
         except _MultiThreadedRendezvous as e:
             if e.code() != grpc.StatusCode.CANCELLED:
                 raise e
@@ -71,10 +81,42 @@ def _request_generator(self):
         while (request := self._requests.get()) != self._END_STREAM:
             yield request
 
-    def query(self, query: str) -> None:
-        self._requests.put(self._query_request(query))
-        return self._results.get()
+    def _query(self, query: Query) -> Response.QueryResult:
+        """Send ``query`` and block until its result arrives.
+
+        Raises:
+            RuntimeError: if the result taken from the queue is an Error message.
+        """
+        self._requests.put(request(query))
+        result = self._results.get()
+        if isinstance(result, Error):
+            raise RuntimeError(f"query failed: {result}")
+        return result
 
+    def execute(self, query: str, *params: Value) -> None:
+        """Run ``query`` with ``params`` and discard the result."""
+        self._query(execute(query, *params))
+
+    def query_value(self, query: str, *params: Value) -> Value:
+        """Run ``query`` with ``params`` and return the parsed ``value`` of the result."""
+        result = self._query(value(query, *params))
+        return parse_value(result.value)
+
+    def query_rows(self, query: str, *params: Value) -> list[tuple[Value, ...]]:
+        """Run ``query`` with ``params`` and return the parsed rows without their schema."""
+        result = self._query(rows(query, *params))
+        _, r = parse_rows(result.rows)
+        return r
+
+    def ctas(self, table_name: str, query: str, *params: Value) -> None:
+        """Send a create-table-as query for ``table_name`` and discard the result."""
+        self._query(ctas(table_name, query, *params))
+
+    def local_parquet(self, file: Path, query: str, *params: Value) -> Path:
+        """Send a Parquet query targeting local ``file`` and return the location in the result."""
+        result = self._query(parquet(local_file(file), query, *params))
+        return parse_location(result.parquet_file)
+
     def __enter__(self) -> Self:
         self._channel = grpc.insecure_channel(target=str(self._addr))
 
@@ -90,20 +132,5 @@ def __exit__(self, exc_type: type, exc_value: Exception, traceback: TracebackTyp
         self._channel.close()
         return False
 
-    @property
-    def mode(self) -> Request.Connect.Mode:
-        if self._mode == "auto":
-            return Request.Connect.Mode.MODE_AUTO
-        elif self._mode == "read_write":
-            return Request.Connect.Mode.MODE_READ_WRITE
-        elif self._mode == "read_only":
-            return Request.Connect.Mode.MODE_READ_ONLY
-        else:
-            raise ValueError(f"Unknown mode: {self._mode}")
-
     def _connect_request(self) -> Request:
-        return Request(connect=Request.Connect(file_name=self._database_file, mode=self.mode))
-
-    @staticmethod
-    def _query_request(query: str) -> Request:
-        return Request(query=Request.Query(query=query))
+        return request(kind=connect(file_name=self._database_file, mode=self._mode))
diff --git a/client/gduck/request.py b/client/gduck/request.py
new file mode 100644
index 0000000..0671174
--- /dev/null
+++ b/client/gduck/request.py
@@ -0,0 +1,127 @@
+"""Builders that translate Python values and queries into gduck protobuf requests."""
+
+import math
+from datetime import date, datetime, time
+from decimal import Decimal
+from pathlib import Path
+from typing import Literal
+
+from database_pb2 import Connect, Date
+from database_pb2 import Decimal as ProtoDecimal
+from database_pb2 import Interval, Params, ScalarValue, Time
+from dateutil.relativedelta import relativedelta
+from google.protobuf.struct_pb2 import NULL_VALUE
+from google.protobuf.timestamp_pb2 import Timestamp
+from location_pb2 import Location
+from query_pb2 import Query
+from service_pb2 import Request
+
+from .types import Value
+
+__all__ = ["ConnectionMode", "connect", "local_file", "execute", "value", "rows", "ctas", "parquet", "request"]
+
+ConnectionMode = Literal["auto", "read_write", "read_only"]
+
+
+def _mode(m: ConnectionMode = "auto") -> Connect.Mode:
+    """Map a connection-mode literal to its protobuf enum; any other value falls back to auto."""
+    if m == "read_write":
+        return Connect.Mode.MODE_READ_WRITE
+    elif m == "read_only":
+        return Connect.Mode.MODE_READ_ONLY
+    else:
+        return Connect.Mode.MODE_AUTO
+
+
+def connect(file_name: str, mode: ConnectionMode) -> Connect:
+    """Build the Connect message for ``file_name`` opened in ``mode``."""
+    return Connect(file_name=file_name, mode=_mode(mode))
+
+
+def _value(v: Value) -> ScalarValue:
+    """Wrap a Python value in a ScalarValue, dispatching on its exact type.
+
+    Raises:
+        ValueError: if the type of ``v`` is not one of the handled types.
+    """
+    if v is None:
+        return ScalarValue(null_value=NULL_VALUE)
+    elif type(v) is bool:
+        return ScalarValue(bool_value=v)
+    elif type(v) is int:
+        return ScalarValue(int_value=v)
+    elif type(v) is float:
+        return ScalarValue(double_value=v)
+    elif type(v) is Decimal:
+        return ScalarValue(decimal_value=ProtoDecimal(value=str(v)))
+    elif type(v) is str:
+        return ScalarValue(str_value=v)
+    elif type(v) is datetime:
+        # Whole seconds and the sub-second part are taken separately: splitting the float
+        # timestamp loses microsecond precision and yields negative nanos before the epoch.
+        seconds = int(v.replace(microsecond=0).timestamp())
+        return ScalarValue(datetime_value=Timestamp(seconds=seconds, nanos=v.microsecond * 1000))
+    elif type(v) is date:
+        return ScalarValue(date_value=Date(year=v.year, month=v.month, day=v.day))
+    elif type(v) is time:
+        return ScalarValue(time_value=Time(hours=v.hour, minutes=v.minute, seconds=v.second, nanos=v.microsecond * 1000))
+    elif type(v) is relativedelta:
+        nv = v.normalized()
+        return ScalarValue(
+            interval_value=Interval(
+                months=12 * nv.years + nv.months,
+                days=nv.days,
+                nanos=1000000000 * (nv.hours * 3600 + nv.minutes * 60 + nv.seconds) + 1000 * nv.microseconds,
+            )
+        )
+    else:
+        raise ValueError(f"Invalid type of value: {v} ({type(v)})")
+
+
+def _params(*params: Value) -> Params:
+    """Convert positional query parameters into a Params message."""
+    return Params(params=[_value(p) for p in params])
+
+
+def local_file(path: Path) -> Location:
+    """Build a Location that points at the local file ``path``."""
+    return Location(local=Location.LocalFile(path=str(path)))
+
+
+def execute(query: str, *params: Value) -> Query:
+    """Build a Query carrying an Execute message for ``query`` and ``params``."""
+    return Query(execute=Query.Execute(query=query, params=_params(*params)))
+
+
+def value(query: str, *params: Value) -> Query:
+    """Build a Query carrying a QueryValue message for ``query`` and ``params``."""
+    return Query(value=Query.QueryValue(query=query, params=_params(*params)))
+
+
+def rows(query: str, *params: Value) -> Query:
+    """Build a Query carrying a QueryRows message for ``query`` and ``params``."""
+    return Query(rows=Query.QueryRows(query=query, params=_params(*params)))
+
+
+def ctas(table_name: str, query: str, *params: Value) -> Query:
+    """Build a Query carrying a CreateTableAsQuery message that targets ``table_name``."""
+    return Query(ctas=Query.CreateTableAsQuery(table_name=table_name, query=query, params=_params(*params)))
+
+
+def parquet(location: Location, query: str, *params: Value) -> Query:
+    """Build a Query carrying a ParquetQuery message that targets ``location``."""
+    return Query(parquet=Query.ParquetQuery(location=location, query=query, params=_params(*params)))
+
+
+def request(kind: Connect | Query) -> Request:
+    """Wrap a Connect or Query message in a Request.
+
+    Raises:
+        ValueError: if ``kind`` is neither a Connect nor a Query.
+    """
+    if type(kind) is Connect:
+        return Request(connect=kind)
+    elif type(kind) is Query:
+        return Request(query=kind)
+    else:
+        raise ValueError(f"unsupported type of message: {kind}")
diff --git a/client/gduck/response.py b/client/gduck/response.py
new file mode 100644
index 0000000..28acfb3
--- /dev/null
+++ b/client/gduck/response.py
@@ -0,0 +1,151 @@
+"""Parsers that turn gduck protobuf responses back into Python values."""
+
+from __future__ import annotations
+
+from datetime import date, datetime, time
+from decimal import Decimal
+from pathlib import Path
+from typing import Callable
+
+from database_pb2 import DataType, Rows, ScalarValue
+from dateutil.relativedelta import relativedelta
+from location_pb2 import Location
+
+from .types import ParquetLocation, Schema, Value
+
+__all__ = ["parse_value", "parse_rows", "parse_location"]
+
+
+def _null_value(v: ScalarValue) -> None:
+    return None
+
+
+def _bool_value(v: ScalarValue) -> bool:
+    return v.bool_value
+
+
+def _int_value(v: ScalarValue) -> int:
+    return v.int_value
+
+
+def _uint_value(v: ScalarValue) -> int:
+    return v.uint_value
+
+
+def _double_value(v: ScalarValue) -> float:
+    return v.double_value
+
+
+def _decimal_value(v: ScalarValue) -> Decimal:
+    return Decimal(v.decimal_value.value)
+
+
+def _str_value(v: ScalarValue) -> str:
+    return v.str_value
+
+
+def _datetime_value(v: ScalarValue) -> datetime:
+    return v.datetime_value.ToDatetime()
+
+
+def _date_value(v: ScalarValue) -> date:
+    dt = v.date_value
+    return date(year=dt.year, month=dt.month, day=dt.day)
+
+
+def _time_value(v: ScalarValue) -> time:
+    t = v.time_value
+    return time(
+        hour=t.hours, minute=t.minutes, second=t.seconds, microsecond=int(t.nanos / 1000)
+    )  # FIXME python time does not support nanosecond precision
+
+
+def _interval_value(v: ScalarValue) -> relativedelta:
+    interval = v.interval_value
+    return relativedelta(
+        months=interval.months, days=interval.days, microseconds=int(interval.nanos / 1000)
+    )  # FIXME python relativedelta does not support nanosecond precision
+
+
+def _getter(data_type: DataType) -> Callable[[ScalarValue], Value]:
+    """Pick the accessor that reads a column of ``data_type`` from a ScalarValue.
+
+    Raises:
+        ValueError: if ``data_type`` matches none of the handled DataType members.
+    """
+    if data_type == DataType.DATATYPE_NULL:
+        return _null_value
+    elif data_type == DataType.DATATYPE_BOOL:
+        return _bool_value
+    elif data_type == DataType.DATATYPE_INT:
+        return _int_value
+    elif data_type == DataType.DATATYPE_UINT:
+        # ScalarValue has a separate ``uint_value`` field; ``_int_value`` would read ``int_value`` instead.
+        return _uint_value
+    elif data_type == DataType.DATATYPE_DOUBLE:
+        return _double_value
+    elif data_type == DataType.DATATYPE_DECIMAL:
+        return _decimal_value
+    elif data_type == DataType.DATATYPE_STRING:
+        return _str_value
+    elif data_type == DataType.DATATYPE_DATETIME:
+        return _datetime_value
+    elif data_type == DataType.DATATYPE_DATE:
+        return _date_value
+    elif data_type == DataType.DATATYPE_TIME:
+        return _time_value
+    elif data_type == DataType.DATATYPE_INTERVAL:
+        return _interval_value
+    else:
+        raise ValueError(f"unknown data type: {data_type}")
+
+
+def parse_value(v: ScalarValue) -> Value:
+    """Convert a ScalarValue into the Python value of whichever field is set.
+
+    Raises:
+        ValueError: if none of the checked value fields is set.
+    """
+    if v.HasField("null_value"):
+        return None
+    elif v.HasField("bool_value"):
+        return _bool_value(v)
+    elif v.HasField("int_value"):
+        return _int_value(v)
+    elif v.HasField("uint_value"):
+        return _uint_value(v)
+    elif v.HasField("double_value"):
+        return _double_value(v)
+    elif v.HasField("decimal_value"):
+        return _decimal_value(v)
+    elif v.HasField("str_value"):
+        return _str_value(v)
+    elif v.HasField("datetime_value"):
+        return _datetime_value(v)
+    elif v.HasField("date_value"):
+        return _date_value(v)
+    elif v.HasField("time_value"):
+        return _time_value(v)
+    elif v.HasField("interval_value"):
+        return _interval_value(v)
+    else:
+        raise ValueError(f"unknown type of value {v}")
+
+
+def parse_rows(rows: Rows) -> tuple[Schema, list[tuple[Value, ...]]]:
+    """Convert a Rows message into its schema and a list of row tuples."""
+    schema = Schema.from_proto(rows.schema)
+
+    getters = [_getter(col.data_type) for col in schema]
+
+    ret = []
+    for row in rows.rows:
+        values = tuple(getters[i](v) for i, v in enumerate(row.values))
+        ret.append(values)
+
+    return schema, ret
+
+
+def parse_location(location: Location) -> ParquetLocation:
+    """Return the local path stored in ``location`` as a Path."""
+    return Path(location.local.path)
diff --git a/client/gduck/types.py b/client/gduck/types.py
new file mode 100644
index 0000000..9d9d37e
--- /dev/null
+++ b/client/gduck/types.py
@@ -0,0 +1,65 @@
+"""Value, schema and row types shared by the gduck request and response helpers."""
+
+from dataclasses import dataclass
+from datetime import date, datetime, time
+from decimal import Decimal
+from pathlib import Path
+from typing import Iterator, Self, TypeAlias
+
+from database_pb2 import Column as ProtoColumn
+from database_pb2 import DataType
+from database_pb2 import Schema as ProtoSchema
+from dateutil.relativedelta import relativedelta
+
+# Python types used for query parameters and for parsed result values.
+Value: TypeAlias = bool | int | float | Decimal | str | datetime | date | time | relativedelta | None
+
+
+@dataclass(frozen=True)
+class Column:
+    """Name and data type of one column."""
+
+    name: str
+    data_type: DataType
+
+    @classmethod
+    def from_proto(cls, col: ProtoColumn) -> Self:
+        """Build a Column from its protobuf counterpart."""
+        return cls(name=col.name, data_type=col.data_type)
+
+
+@dataclass(frozen=True)
+class Schema:
+    """List of columns, iterable and indexable by position."""
+
+    columns: list[Column]
+
+    def __iter__(self) -> Iterator[Column]:
+        return iter(self.columns)
+
+    def __getitem__(self, key: int) -> Column:
+        return self.columns[key]
+
+    def names(self) -> list[str]:
+        return [col.name for col in self.columns]
+
+    @classmethod
+    def from_proto(cls, schema: ProtoSchema) -> Self:
+        """Build a Schema from its protobuf counterpart."""
+        columns = [Column.from_proto(col) for col in schema.columns]
+        return cls(columns)
+
+
+@dataclass(frozen=True)
+class Rows:
+    """Row tuples together with their schema."""
+
+    schema: Schema
+    rows: list[tuple[Value, ...]]
+
+    def __iter__(self) -> Iterator[tuple[Value, ...]]:
+        return iter(self.rows)
+
+
+# Location of a Parquet result, represented as a pathlib.Path.
+ParquetLocation: TypeAlias = Path
diff --git a/client/main.py b/client/main.py
index eb959f6..d22e062 100644
--- a/client/main.py
+++ b/client/main.py
@@ -1,6 +1,7 @@
-from client import Connection
+from gduck.client import Connection
 
 if __name__ == "__main__":
-    with Connection("localhost:50051").transaction(database_file="example.duckdb", mode="read_write") as trans:
-        result = trans.query("SELECT '1'")
-        print(result)
+    with Connection("localhost:50051").transaction(database_file="datasets/example.duckdb", mode="read_write") as trans:
+        result = trans.query_rows("SELECT * FROM videos WHERE comment_count > ? LIMIT 10;", 2)
+        for row in result:
+            print(row)
diff --git a/client/request.py b/client/request.py
deleted file mode 100644
index fc09c19..0000000
--- a/client/request.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import math
-from datetime import date, datetime, time, timedelta
-from decimal import Decimal
-from typing import Literal, TypeAlias
-
-from database_pb2 import Connect, DataType, Date
-from database_pb2 import Decimal as ProtoDecimal
-from database_pb2 import Interval, Params, Row, Rows, ScalarValue, Schema, Time
-from error_pb2 import Error, ErrorCode
-from google.protobuf.struct_pb2 import NULL_VALUE
-from google.protobuf.timestamp_pb2 import Timestamp
-from location_pb2 import Location
-from query_pb2 import Query
-
-__all__ = ["ConnectionMode", "connect"]
-
-ConnectionMode = Literal["auto", "read_write", "read_only"]
-Value: TypeAlias = bool | int | float | Decimal | str | datetime | date | time | timedelta | None
-
-
-def _mode(m: ConnectionMode = Connect.Mode.MODE_AUTO) -> Connect.Mode:
-    if m == "read_write":
-        return Connect.Mode.MODE_READ_WRITE
-    elif m == "read_only":
-        return Connect.Mode.MODE_READ_ONLY
-    else:
-        return Connect.Mode.MODE_AUTO
-
-
-def connect(file_name: str, mode: ConnectionMode) -> Connect:
-    return Connect(file_name=file_name, mode=_mode(mode))
-
-
-def _value(v: Value) -> ScalarValue:
-    if v is None:
-        return ScalarValue(null_value=NULL_VALUE)
-    elif type(v) is bool:
-        return ScalarValue(bool_value=v)
-    elif type(v) is int:
-        return ScalarValue(int_value=v)
-    elif type(v) is float:
-        return ScalarValue(double_value=v)
-    elif type(v) is Decimal:
-        return ScalarValue(decimal_value=ProtoDecimal(value=str(v)))
-    elif type(v) is str:
-        return ScalarValue(str_value=v)
-    elif type(v) is datetime:
-        fraction, seconds = math.modf(v.timestamp())
-        return ScalarValue(datetime_value=Timestamp(seconds=int(seconds), nanos=int(fraction * 1000000000)))
-    elif type(v) is date:
-        return ScalarValue(date_value=Date(year=v.year, month=v.month, day=v.day))
-    elif type(v) is time:
-        return ScalarValue(time_value=Time(hours=v.hour, minutes=v.minute, seconds=v.second, nanos=v.microsecond * 1000))
-    elif type(v) is timedelta:
-        pass
-    else:
-        raise ValueError(f"Invalid type of value: {v} ({type(v)})")