Commit

add request adapter to python client
saint1991 committed Oct 19, 2024
1 parent a4b2c89 commit ccb1013
Showing 8 changed files with 331 additions and 82 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
@@ -15,7 +15,7 @@
"name":"Debug Python client",
"type":"debugpy",
"request":"launch",
"program":"${workspaceFolder}/client/client.py",
"program":"${workspaceFolder}/client/main.py",
"console":"integratedTerminal",
"env": {
"PYTHONPATH": "${workspaceFolder}/client/gen"
Empty file added client/gduck/__init__.py
54 changes: 34 additions & 20 deletions client/client.py → client/gduck/client.py
@@ -2,14 +2,21 @@

import threading
from dataclasses import dataclass
from pathlib import Path
from queue import SimpleQueue
from types import TracebackType
from typing import Literal, Self
from typing import Self

import grpc
from error_pb2 import Error
from grpc._channel import _MultiThreadedRendezvous
from query_pb2 import Query
from service_pb2 import Request, Response
from service_pb2_grpc import DbServiceStub

from .request import ConnectionMode, Value, connect, ctas, execute, local_file, parquet, request, rows, value
from .response import parse_location, parse_rows, parse_value


@dataclass(frozen=True)
class Addr:
@@ -46,7 +53,10 @@ def __init__(self, responses: _MultiThreadedRendezvous, out: SimpleQueue, group:
def run(self) -> None:
try:
for response in self._responses:
self._out.put(response.result)
if response.HasField("success"):
self._out.put(response.success)
elif response.HasField("error"):
self._out.put(response.error)
except _MultiThreadedRendezvous as e:
if e.code() != grpc.StatusCode.CANCELLED:
raise e
@@ -71,10 +81,29 @@ def _request_generator(self):
while (request := self._requests.get()) != self._END_STREAM:
yield request

def query(self, query: str) -> None:
self._requests.put(self._query_request(query))
def _query(self, query: Query) -> Response.QueryResult | Error:
self._requests.put(request(query))
return self._results.get()

def execute(self, query: str, *params: tuple[Value]) -> None:
self._query(execute(query, *params))

def query_value(self, query: str, *params: tuple[Value]) -> Value:
result = self._query(value(query, *params))
return parse_value(result.value)

def query_rows(self, query: str, *params: tuple[Value]) -> list[tuple[Value]]:
result = self._query(rows(query, *params))
_, r = parse_rows(result.rows)
return r

def ctas(self, table_name: str, query: str, *params: tuple[Value]) -> None:
self._query(ctas(table_name, query, *params))

def local_parquet(self, file: Path, query: str, *params: tuple[Value]) -> Path:
result = self._query(parquet(local_file(file), query, *params))
return parse_location(result.parquet_file)

def __enter__(self) -> Self:
self._channel = grpc.insecure_channel(target=str(self._addr))

@@ -90,20 +119,5 @@ def __exit__(self, exc_type: type, exc_value: Exception, traceback: TracebackTyp
self._channel.close()
return False

@property
def mode(self) -> Request.Connect.Mode:
if self._mode == "auto":
return Request.Connect.Mode.MODE_AUTO
elif self._mode == "read_write":
return Request.Connect.Mode.MODE_READ_WRITE
elif self._mode == "read_only":
return Request.Connect.Mode.MODE_READ_ONLY
else:
raise ValueError(f"Unknown mode: {self._mode}")

def _connect_request(self) -> Request:
return Request(connect=Request.Connect(file_name=self._database_file, mode=self.mode))

@staticmethod
def _query_request(query: str) -> Request:
return Request(query=Request.Query(query=query))
return request(kind=connect(file_name=self._database_file, mode=self._mode))
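A minimal usage sketch of the new client surface: the class name and constructor below are assumptions (this hunk does not show them), but the method names and the context-manager protocol match the diff.

from pathlib import Path

from gduck.client import DbClient  # hypothetical name; the class definition is outside this hunk

# Constructor arguments are placeholders for whatever address/file/mode the client actually takes.
with DbClient(addr="localhost:50051", database_file="example.db", mode="read_write") as db:
    db.execute("CREATE TABLE t (id INTEGER, name VARCHAR)")
    db.execute("INSERT INTO t VALUES (?, ?)", 1, "alice")

    one = db.query_value("SELECT count(*) FROM t")                     # single scalar value
    rows = db.query_rows("SELECT id, name FROM t WHERE id = ?", 1)     # list of row tuples

    db.ctas("t_copy", "SELECT * FROM t")                               # CREATE TABLE ... AS SELECT
    out = db.local_parquet(Path("/tmp/t.parquet"), "SELECT * FROM t")  # export to a local parquet file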
104 changes: 104 additions & 0 deletions client/gduck/request.py
@@ -0,0 +1,104 @@
import math
from datetime import date, datetime, time
from decimal import Decimal
from pathlib import Path
from typing import Literal

from database_pb2 import Connect, Date
from database_pb2 import Decimal as ProtoDecimal
from database_pb2 import Interval, Params, ScalarValue, Time
from dateutil.relativedelta import relativedelta
from google.protobuf.struct_pb2 import NULL_VALUE
from google.protobuf.timestamp_pb2 import Timestamp
from location_pb2 import Location
from query_pb2 import Query
from service_pb2 import Request

from .types import Value

__all__ = ["ConnectionMode", "connect", "local_file", "execute", "value", "rows", "ctas", "parquet", "request"]

ConnectionMode = Literal["auto", "read_write", "read_only"]


def _mode(m: ConnectionMode = "auto") -> Connect.Mode:
if m == "read_write":
return Connect.Mode.MODE_READ_WRITE
elif m == "read_only":
return Connect.Mode.MODE_READ_ONLY
else:
return Connect.Mode.MODE_AUTO


def connect(file_name: str, mode: ConnectionMode) -> Connect:
return Connect(file_name=file_name, mode=_mode(mode))


def _value(v: Value) -> ScalarValue:
if v is None:
return ScalarValue(null_value=NULL_VALUE)
elif type(v) is bool:
return ScalarValue(bool_value=v)
elif type(v) is int:
return ScalarValue(int_value=v)
elif type(v) is float:
return ScalarValue(double_value=v)
elif type(v) is Decimal:
return ScalarValue(decimal_value=ProtoDecimal(value=str(v)))
elif type(v) is str:
return ScalarValue(str_value=v)
elif type(v) is datetime:
fraction, seconds = math.modf(v.timestamp())
return ScalarValue(datetime_value=Timestamp(seconds=int(seconds), nanos=int(fraction * 1000000000)))
elif type(v) is date:
return ScalarValue(date_value=Date(year=v.year, month=v.month, day=v.day))
elif type(v) is time:
return ScalarValue(time_value=Time(hours=v.hour, minutes=v.minute, seconds=v.second, nanos=v.microsecond * 1000))
elif type(v) is relativedelta:
nv = v.normalized()
return ScalarValue(
interval_value=Interval(
months=12 * nv.years + nv.months,
days=nv.days,
nanos=1000000000 * (nv.hours * 3600 + nv.minutes * 60 + nv.seconds) + 1000 * nv.microseconds,
)
)
else:
raise ValueError(f"Invalid type of value: {v} ({type(v)})")


def _params(*params: tuple[Value]) -> Params:
return Params(params=[_value(p) for p in params])


def local_file(path: Path) -> Location:
return Location(local=Location.LocalFile(path=str(path)))


def execute(query: str, *params: tuple[Value]) -> Query:
return Query(execute=Query.Execute(query=query, params=_params(*params)))


def value(query: str, *params: tuple[Value]) -> Query:
return Query(value=Query.QueryValue(query=query, params=_params(*params)))


def rows(query: str, *params: tuple[Value]) -> Query:
return Query(rows=Query.QueryRows(query=query, params=_params(*params)))


def ctas(table_name: str, query: str, *params: tuple[Value]) -> Query:
return Query(ctas=Query.CreateTableAsQuery(table_name=table_name, query=query, params=_params(*params)))


def parquet(location: Location, query: str, *params: tuple[Value]) -> Query:
return Query(parquet=Query.ParquetQuery(location=location, query=query, params=_params(*params)))


def request(kind: Connect | Query) -> Request:
if type(kind) is Connect:
return Request(connect=kind)
elif type(kind) is Query:
return Request(query=kind)
else:
raise ValueError(f"unsupported type of message: {kind}")
134 changes: 134 additions & 0 deletions client/gduck/response.py
@@ -0,0 +1,134 @@
from __future__ import annotations

from datetime import date, datetime, time
from decimal import Decimal
from pathlib import Path
from typing import Callable

from database_pb2 import DataType, Rows, ScalarValue
from dateutil.relativedelta import relativedelta
from location_pb2 import Location

from .types import ParquetLocation, Schema, Value

__all__ = ["parse_value", "parse_rows", "parse_location"]


def _null_value(v: ScalarValue) -> None:
return None


def _bool_value(v: ScalarValue) -> bool:
return v.bool_value


def _int_value(v: ScalarValue) -> int:
return v.int_value


def _uint_value(v: ScalarValue) -> int:
return v.uint_value


def _double_value(v: ScalarValue) -> float:
return v.double_value


def _decimal_value(v: ScalarValue) -> Decimal:
return Decimal(v.decimal_value.value)


def _str_value(v: ScalarValue) -> str:
return v.str_value


def _datetime_value(v: ScalarValue) -> datetime:
return v.datetime_value.ToDatetime()


def _date_value(v: ScalarValue) -> date:
dt = v.date_value
return date(year=dt.year, month=dt.month, day=dt.day)


def _time_value(v: ScalarValue) -> time:
t = v.time_value
return time(
hour=t.hours, minute=t.minutes, second=t.seconds, microsecond=int(t.nanos / 1000)
) # FIXME python time does not support nanosecond precision


def _interval_value(v: ScalarValue) -> relativedelta:
interval = v.interval_value
return relativedelta(
months=interval.months, days=interval.days, microseconds=int(interval.nanos / 1000)
) # FIXME python relativedelta does not support nanosecond precision


def _getter(data_type: DataType) -> Callable[[ScalarValue], Value]:
if data_type == DataType.DATATYPE_NULL:
return _null_value
elif data_type == DataType.DATATYPE_BOOL:
return _bool_value
elif data_type == DataType.DATATYPE_INT:
return _int_value
elif data_type == DataType.DATATYPE_UINT:
return _uint_value
elif data_type == DataType.DATATYPE_DOUBLE:
return _double_value
elif data_type == DataType.DATATYPE_DECIMAL:
return _decimal_value
elif data_type == DataType.DATATYPE_STRING:
return _str_value
elif data_type == DataType.DATATYPE_DATETIME:
return _datetime_value
elif data_type == DataType.DATATYPE_DATE:
return _date_value
elif data_type == DataType.DATATYPE_TIME:
return _time_value
elif data_type == DataType.DATATYPE_INTERVAL:
return _interval_value
else:
raise ValueError(f"unknown data type: {data_type}")


def parse_value(v: ScalarValue) -> Value:
if v.HasField("null_value"):
return None
elif v.HasField("bool_value"):
return _bool_value(v)
elif v.HasField("int_value"):
return _int_value(v)
elif v.HasField("uint_value"):
return _uint_value(v)
elif v.HasField("double_value"):
return _double_value(v)
elif v.HasField("decimal_value"):
return _decimal_value(v)
elif v.HasField("str_value"):
return _str_value(v)
elif v.HasField("datetime_value"):
return _datetime_value(v)
elif v.HasField("date_value"):
return _date_value(v)
elif v.HasField("time_value"):
return _time_value(v)
elif v.HasField("interval_value"):
return _interval_value(v)
else:
raise ValueError(f"unknown type of value {v}")


def parse_rows(rows: Rows) -> tuple[Schema, list[tuple[Value]]]:
schema = Schema.from_proto(rows.schema)

getters = [_getter(col.data_type) for col in schema]

ret = []
for row in rows.rows:
values = tuple(getters[i](v) for i, v in enumerate(row.values))
ret.append(values)

return schema, ret


def parse_location(location: Location) -> ParquetLocation:
return Path(location.local.path)
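For plain scalar values these parsers mirror the builders in request.py; a hedged round-trip sketch (it reaches into the module-private _value helper purely for illustration, and leaves out datetime/interval values, whose conversions depend on timezone handling and normalization):

from datetime import date
from decimal import Decimal

from gduck.request import _value
from gduck.response import parse_value

for v in (None, True, 42, 3.14, Decimal("1.50"), "hello", date(2024, 10, 19)):
    assert parse_value(_value(v)) == v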
53 changes: 53 additions & 0 deletions client/gduck/types.py
@@ -0,0 +1,53 @@
from dataclasses import dataclass
from datetime import date, datetime, time
from decimal import Decimal
from pathlib import Path
from typing import Iterator, Self, TypeAlias

from database_pb2 import Column as ProtoColumn
from database_pb2 import DataType
from database_pb2 import Schema as ProtoSchema
from dateutil.relativedelta import relativedelta

Value: TypeAlias = bool | int | float | Decimal | str | datetime | date | time | relativedelta | None


@dataclass(frozen=True)
class Column:
name: str
data_type: DataType

@classmethod
def from_proto(cls, col: ProtoColumn) -> Self:
return cls(name=col.name, data_type=col.data_type)


@dataclass(frozen=True)
class Schema:
columns: list[Column]

def __iter__(self) -> Iterator[Column]:
return iter(self.columns)

def __getitem__(self, key: int) -> Column:
return self.columns[key]

def names(self) -> list[str]:
return [col.name for col in self.columns]

@classmethod
def from_proto(cls, schema: ProtoSchema) -> Self:
columns = [Column.from_proto(col) for col in schema.columns]
return cls(columns)


@dataclass(frozen=True)
class Rows:
schema: Schema
rows: list[tuple[Value]]

def __iter__(self) -> Iterator[tuple[Value]]:
return iter(self.rows)


ParquetLocation: TypeAlias = Path
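A small sketch of the wrapper types in use, building the protobuf schema from the database_pb2 messages referenced above (field names follow this diff; the sample values are illustrative):

from database_pb2 import Column as ProtoColumn
from database_pb2 import DataType
from database_pb2 import Schema as ProtoSchema

from gduck.types import Schema

proto_schema = ProtoSchema(
    columns=[
        ProtoColumn(name="id", data_type=DataType.DATATYPE_INT),
        ProtoColumn(name="name", data_type=DataType.DATATYPE_STRING),
    ]
)

schema = Schema.from_proto(proto_schema)
assert schema.names() == ["id", "name"]
assert schema[0].data_type == DataType.DATATYPE_INT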
