-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add request adapter to python client
- Loading branch information
Showing
8 changed files
with
331 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import math | ||
from datetime import date, datetime, time | ||
from decimal import Decimal | ||
from pathlib import Path | ||
from typing import Literal | ||
|
||
from database_pb2 import Connect, Date | ||
from database_pb2 import Decimal as ProtoDecimal | ||
from database_pb2 import Interval, Params, ScalarValue, Time | ||
from dateutil.relativedelta import relativedelta | ||
from google.protobuf.struct_pb2 import NULL_VALUE | ||
from google.protobuf.timestamp_pb2 import Timestamp | ||
from location_pb2 import Location | ||
from query_pb2 import Query | ||
from service_pb2 import Request | ||
|
||
from .types import Value | ||
|
||
# Public names exported by ``from <module> import *``.
__all__ = [
    "ConnectionMode",
    "connect",
    "local_file",
    "execute",
    "value",
    "rows",
    "ctas",
    "parquet",
    "request",
]
|
||
# String spellings accepted for the connection mode.
ConnectionMode = Literal[
    "auto",
    "read_write",
    "read_only",
]
|
||
|
||
def _mode(m: ConnectionMode = "auto") -> Connect.Mode:
    """Translate a connection-mode string into a ``Connect.Mode`` enum value.

    Args:
        m: One of ``"auto"``, ``"read_write"`` or ``"read_only"``.

    Returns:
        ``MODE_READ_WRITE`` or ``MODE_READ_ONLY`` for the matching strings,
        ``MODE_AUTO`` for ``"auto"`` and for any other input.
    """
    if m == "read_write":
        return Connect.Mode.MODE_READ_WRITE
    if m == "read_only":
        return Connect.Mode.MODE_READ_ONLY
    # Everything else, including unrecognised input, falls back to auto.
    return Connect.Mode.MODE_AUTO
|
||
|
||
def connect(file_name: str, mode: ConnectionMode) -> Connect:
    """Build a ``Connect`` message for ``file_name`` using ``mode``.

    Args:
        file_name: Value stored in the message's ``file_name`` field.
        mode: Connection mode string, converted with ``_mode``.

    Returns:
        The populated ``Connect`` message.
    """
    proto_mode = _mode(mode)
    return Connect(file_name=file_name, mode=proto_mode)
|
||
|
||
def _value(v: Value) -> ScalarValue:
    """Wrap a Python value in the matching ``ScalarValue`` field.

    The exact type of ``v`` selects the field; subclasses of the supported
    types are rejected. Exact checks also keep ``bool`` apart from ``int`` and
    ``datetime`` apart from ``date``.

    Args:
        v: ``None``, ``bool``, ``int``, ``float``, ``Decimal``, ``str``,
            ``datetime``, ``date``, ``time`` or ``relativedelta``.

    Returns:
        A ``ScalarValue`` with exactly one value field set.

    Raises:
        ValueError: If the type of ``v`` is not one of the supported types.
    """
    if v is None:
        return ScalarValue(null_value=NULL_VALUE)

    kind = type(v)
    if kind is bool:
        return ScalarValue(bool_value=v)
    if kind is int:
        return ScalarValue(int_value=v)
    if kind is float:
        return ScalarValue(double_value=v)
    if kind is Decimal:
        # Decimals travel as their string representation.
        return ScalarValue(decimal_value=ProtoDecimal(value=str(v)))
    if kind is str:
        return ScalarValue(str_value=v)
    if kind is datetime:
        # Whole seconds come from the datetime truncated to the second and the
        # sub-second part from the integer microsecond field. This avoids the
        # float rounding of splitting ``timestamp()`` and keeps ``nanos``
        # non-negative for datetimes before the epoch as well.
        seconds = int(v.replace(microsecond=0).timestamp())
        nanos = v.microsecond * 1000
        return ScalarValue(datetime_value=Timestamp(seconds=seconds, nanos=nanos))
    if kind is date:
        return ScalarValue(date_value=Date(year=v.year, month=v.month, day=v.day))
    if kind is time:
        return ScalarValue(
            time_value=Time(
                hours=v.hour,
                minutes=v.minute,
                seconds=v.second,
                nanos=v.microsecond * 1000,
            )
        )
    if kind is relativedelta:
        nv = v.normalized()
        # Years fold into months; the time-of-day parts fold into nanoseconds.
        total_seconds = nv.hours * 3600 + nv.minutes * 60 + nv.seconds
        return ScalarValue(
            interval_value=Interval(
                months=12 * nv.years + nv.months,
                days=nv.days,
                nanos=1000000000 * total_seconds + 1000 * nv.microseconds,
            )
        )

    raise ValueError(f"Invalid type of value: {v} ({type(v)})")
|
||
|
||
def _params(*params: Value) -> Params:
    """Pack positional query parameters into a ``Params`` message.

    Args:
        *params: Python values, each converted with ``_value``.

    Returns:
        A ``Params`` message holding the converted values in call order.

    Raises:
        ValueError: Propagated from ``_value`` for an unsupported value type.
    """
    return Params(params=[_value(p) for p in params])
|
||
|
||
def local_file(path: Path) -> Location:
    """Build a ``Location`` that refers to a local file.

    Args:
        path: Path stored, as a string, in the ``LocalFile`` message.

    Returns:
        A ``Location`` with its ``local`` field set.
    """
    local = Location.LocalFile(path=str(path))
    return Location(local=local)
|
||
|
||
def execute(query: str, *params: Value) -> Query:
    """Build a ``Query`` carrying an ``Execute`` message.

    Args:
        query: Query text.
        *params: Positional parameters, converted with ``_params``.

    Returns:
        A ``Query`` with its ``execute`` field set.
    """
    payload = Query.Execute(query=query, params=_params(*params))
    return Query(execute=payload)
|
||
|
||
def value(query: str, *params: Value) -> Query:
    """Build a ``Query`` carrying a ``QueryValue`` message.

    Args:
        query: Query text.
        *params: Positional parameters, converted with ``_params``.

    Returns:
        A ``Query`` with its ``value`` field set.
    """
    payload = Query.QueryValue(query=query, params=_params(*params))
    return Query(value=payload)
|
||
|
||
def rows(query: str, *params: Value) -> Query:
    """Build a ``Query`` carrying a ``QueryRows`` message.

    Args:
        query: Query text.
        *params: Positional parameters, converted with ``_params``.

    Returns:
        A ``Query`` with its ``rows`` field set.
    """
    payload = Query.QueryRows(query=query, params=_params(*params))
    return Query(rows=payload)
|
||
|
||
def ctas(table_name: str, query: str, *params: Value) -> Query:
    """Build a ``Query`` carrying a ``CreateTableAsQuery`` message.

    Args:
        table_name: Value stored in the message's ``table_name`` field.
        query: Query text.
        *params: Positional parameters, converted with ``_params``.

    Returns:
        A ``Query`` with its ``ctas`` field set.
    """
    payload = Query.CreateTableAsQuery(
        table_name=table_name,
        query=query,
        params=_params(*params),
    )
    return Query(ctas=payload)
|
||
|
||
def parquet(location: Location, query: str, *params: Value) -> Query:
    """Build a ``Query`` carrying a ``ParquetQuery`` message.

    Args:
        location: ``Location`` message stored in the ``location`` field.
        query: Query text.
        *params: Positional parameters, converted with ``_params``.

    Returns:
        A ``Query`` with its ``parquet`` field set.
    """
    payload = Query.ParquetQuery(
        location=location,
        query=query,
        params=_params(*params),
    )
    return Query(parquet=payload)
|
||
|
||
def request(kind: Connect | Query) -> Request:
    """Wrap a ``Connect`` or ``Query`` message in a ``Request``.

    Args:
        kind: The message to send; its exact type selects the ``Request``
            field that is populated.

    Returns:
        A ``Request`` with either ``connect`` or ``query`` set.

    Raises:
        ValueError: If ``kind`` is neither a ``Connect`` nor a ``Query``.
    """
    message_type = type(kind)
    if message_type is Connect:
        return Request(connect=kind)
    if message_type is Query:
        return Request(query=kind)
    raise ValueError(f"unsupported type of message: {kind}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from __future__ import annotations | ||
|
||
from datetime import date, datetime, time | ||
from decimal import Decimal | ||
from pathlib import Path | ||
from typing import Callable | ||
|
||
from database_pb2 import DataType, Rows, ScalarValue | ||
from dateutil.relativedelta import relativedelta | ||
from location_pb2 import Location | ||
|
||
from .types import ParquetLocation, Schema, Value | ||
|
||
# Public names exported by ``from <module> import *``.
__all__ = [
    "parse_value",
    "parse_rows",
    "parse_location",
]
|
||
|
||
def _null_value(v: ScalarValue) -> None:
    """Return ``None`` regardless of ``v``.

    The argument is unused; it gives this function the same call signature
    as the other ``_*_value`` extractors.
    """
    return None
|
||
|
||
def _bool_value(v: ScalarValue) -> bool:
    """Return the ``bool_value`` field of ``v``."""
    return v.bool_value
|
||
|
||
def _int_value(v: ScalarValue) -> int:
    """Return the ``int_value`` field of ``v``."""
    return v.int_value
|
||
|
||
def _uint_value(v: ScalarValue) -> int:
    """Return the ``uint_value`` field of ``v``."""
    return v.uint_value
|
||
|
||
def _double_value(v: ScalarValue) -> float:
    """Return the ``double_value`` field of ``v``."""
    return v.double_value
|
||
|
||
def _decimal_value(v: ScalarValue) -> Decimal:
    """Build a ``Decimal`` from the ``value`` of ``v.decimal_value``."""
    text = v.decimal_value.value
    return Decimal(text)
|
||
|
||
def _str_value(v: ScalarValue) -> str:
    """Return the ``str_value`` field of ``v``."""
    return v.str_value
|
||
|
||
def _datetime_value(v: ScalarValue) -> datetime:
    """Convert ``v.datetime_value`` with its ``ToDatetime()`` method.

    NOTE(review): protobuf documents ``Timestamp.ToDatetime()`` without
    arguments as returning a naive datetime in UTC -- confirm that callers
    expect this.
    """
    timestamp = v.datetime_value
    return timestamp.ToDatetime()
|
||
|
||
def _date_value(v: ScalarValue) -> date:
    """Build a ``date`` from the year, month and day of ``v.date_value``."""
    proto_date = v.date_value
    return date(
        year=proto_date.year,
        month=proto_date.month,
        day=proto_date.day,
    )
|
||
|
||
def _time_value(v: ScalarValue) -> time:
    """Build a ``time`` from the fields of ``v.time_value``.

    The ``nanos`` field is reduced to whole microseconds, so any
    sub-microsecond part is dropped.
    """
    proto_time = v.time_value
    # FIXME python time does not support nanosecond precision
    micros = int(proto_time.nanos / 1000)
    return time(
        hour=proto_time.hours,
        minute=proto_time.minutes,
        second=proto_time.seconds,
        microsecond=micros,
    )
|
||
|
||
def _interval_value(v: ScalarValue) -> relativedelta:
    """Build a ``relativedelta`` from the fields of ``v.interval_value``.

    ``months`` and ``days`` are passed through unchanged; ``nanos`` is reduced
    to whole microseconds, truncating toward zero.

    Returns:
        A ``relativedelta`` built from ``months``, ``days`` and
        ``microseconds``.
    """
    interval = v.interval_value
    nanos = interval.nanos
    # FIXME python relativedelta does not support nanosecond precision
    # Pure integer arithmetic: going through a float (nanos / 1000) loses
    # precision once the magnitude exceeds 2**53. Truncation toward zero keeps
    # the rounding direction of int() for negative intervals.
    micros = abs(nanos) // 1000
    if nanos < 0:
        micros = -micros
    return relativedelta(months=interval.months, days=interval.days, microseconds=micros)
|
||
|
||
def _getter(data_type: DataType) -> Callable[[ScalarValue], Value]:
    """Pick the function that extracts a column's value from a ``ScalarValue``.

    Args:
        data_type: ``DataType`` enum value of the column.

    Returns:
        The module-level ``_*_value`` extractor matching ``data_type``.

    Raises:
        ValueError: If ``data_type`` is not one of the handled enum values.
    """
    if data_type == DataType.DATATYPE_NULL:
        return _null_value
    if data_type == DataType.DATATYPE_BOOL:
        return _bool_value
    if data_type == DataType.DATATYPE_INT:
        return _int_value
    if data_type == DataType.DATATYPE_UINT:
        # Unsigned values are read from ``uint_value``, as ``parse_value``
        # does; reading ``int_value`` here would not see them.
        return _uint_value
    if data_type == DataType.DATATYPE_DOUBLE:
        return _double_value
    if data_type == DataType.DATATYPE_DECIMAL:
        return _decimal_value
    if data_type == DataType.DATATYPE_STRING:
        return _str_value
    if data_type == DataType.DATATYPE_DATETIME:
        return _datetime_value
    if data_type == DataType.DATATYPE_DATE:
        return _date_value
    if data_type == DataType.DATATYPE_TIME:
        return _time_value
    if data_type == DataType.DATATYPE_INTERVAL:
        return _interval_value
    raise ValueError(f"unknown data type: {data_type}")
|
||
|
||
def parse_value(v: ScalarValue) -> Value:
    """Convert a ``ScalarValue`` into the corresponding Python value.

    Fields are probed with ``HasField`` in a fixed order and the first one
    that is set decides the result.

    Args:
        v: The message to convert.

    Returns:
        ``None`` when ``null_value`` is set, otherwise the result of the
        extractor for the field that is set.

    Raises:
        ValueError: If none of the known value fields is set.
    """
    if v.HasField("null_value"):
        return None

    # (field name, extractor) pairs in probing order.
    extractors = (
        ("bool_value", _bool_value),
        ("int_value", _int_value),
        ("uint_value", _uint_value),
        ("double_value", _double_value),
        ("decimal_value", _decimal_value),
        ("str_value", _str_value),
        ("datetime_value", _datetime_value),
        ("date_value", _date_value),
        ("time_value", _time_value),
        ("interval_value", _interval_value),
    )
    for field_name, extract in extractors:
        if v.HasField(field_name):
            return extract(v)

    raise ValueError(f"unknown type of value {v}")
|
||
|
||
def parse_rows(rows: Rows) -> tuple[Schema, list[tuple[Value, ...]]]:
    """Convert a ``Rows`` message into a schema and a list of value tuples.

    One extractor is chosen per schema column up front and then applied by
    position to the values of every row.

    Args:
        rows: Message providing ``schema`` and ``rows``.

    Returns:
        The ``Schema`` built from ``rows.schema`` and one tuple of Python
        values per row.

    Raises:
        ValueError: If a column has a data type ``_getter`` does not handle.
        IndexError: If a row holds more values than the schema has columns.
    """
    schema = Schema.from_proto(rows.schema)

    # Resolve the extractor once per column instead of once per value.
    getters = [_getter(col.data_type) for col in schema]

    parsed = [
        tuple(getters[i](v) for i, v in enumerate(row.values))
        for row in rows.rows
    ]
    return schema, parsed
|
||
|
||
def parse_location(location: Location) -> ParquetLocation:
    """Return the path in ``location.local`` as a ``Path``.

    Only the ``local`` field of the message is read.
    """
    local_path = location.local.path
    return Path(local_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from dataclasses import dataclass | ||
from datetime import date, datetime, time | ||
from decimal import Decimal | ||
from pathlib import Path | ||
from typing import Iterator, Self, TypeAlias | ||
|
||
from database_pb2 import Column as ProtoColumn | ||
from database_pb2 import DataType | ||
from database_pb2 import Schema as ProtoSchema | ||
from dateutil.relativedelta import relativedelta | ||
|
||
# Python types that can be sent as, or received from, a scalar value.
Value: TypeAlias = (
    bool
    | int
    | float
    | Decimal
    | str
    | datetime
    | date
    | time
    | relativedelta
    | None
)
|
||
|
||
@dataclass(frozen=True)
class Column:
    """Immutable description of one column: its name and data type."""

    name: str
    data_type: DataType

    @classmethod
    def from_proto(cls, col: ProtoColumn) -> Self:
        """Create a column by copying ``name`` and ``data_type`` from ``col``."""
        column_name = col.name
        column_type = col.data_type
        return cls(name=column_name, data_type=column_type)
|
||
|
||
@dataclass(frozen=True)
class Schema:
    """Ordered collection of ``Column`` objects.

    Iterating over a schema, or indexing it, delegates to ``columns``.
    """

    columns: list[Column]

    def __iter__(self) -> Iterator[Column]:
        """Iterate over the columns in order."""
        return iter(self.columns)

    def __getitem__(self, key: int) -> Column:
        """Return the column stored at position ``key``."""
        return self.columns[key]

    def names(self) -> list[str]:
        """Return the column names in column order."""
        return [column.name for column in self.columns]

    @classmethod
    def from_proto(cls, schema: ProtoSchema) -> Self:
        """Create a schema by converting every column of ``schema``."""
        return cls([Column.from_proto(proto_col) for proto_col in schema.columns])
|
||
|
||
@dataclass(frozen=True)
class Rows:
    """A schema together with rows stored as tuples of values.

    Iterating over the object iterates over ``rows``.
    """

    schema: Schema
    # One tuple per row; ``tuple[Value, ...]`` because the number of values
    # per row is not fixed at one.
    rows: list[tuple[Value, ...]]

    def __iter__(self) -> Iterator[tuple[Value, ...]]:
        """Iterate over the row tuples in order."""
        return iter(self.rows)
|
||
|
||
ParquetLocation: TypeAlias = Path |
Oops, something went wrong.