Add more connection executor tests (native types, select data, cast for output, etc.) #8

Merged · 3 commits · Oct 10, 2023
----------------------------------------------------------------------
@@ -1,4 +1,5 @@
from collections import defaultdict
+import functools
from http import HTTPStatus
from typing import Iterable

@@ -70,7 +71,7 @@ def test_lod_fixed_single_dim_in_two_dim_query(self, control_api, data_api, save
assert float(row[3]) == pytest.approx(sum_by_city[row[0]])
assert float(row[4]) == pytest.approx(sum_by_category[row[1]])

-    def test_null_dimensions(self, control_api, data_api, db, saved_connection_id):
+    def test_null_dimensions(self, request, control_api, data_api, db, saved_connection_id):
connection_id = saved_connection_id

raw_data = [
@@ -88,6 +89,7 @@ def test_null_dimensions(self, control_api, data_api, db, saved_connection_id):
C("sales", UserDataType.integer, vg=lambda rn, **kwargs: raw_data[rn]["sales"]),
]
db_table = make_table(db, columns=columns, rows=len(raw_data))
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))

ds = Dataset()
ds.sources["source_1"] = ds.source(connection_id=connection_id, **data_source_settings_from_table(db_table))
@@ -165,12 +167,14 @@ def get_total_value() -> float:

def test_total_lod_2(
self,
+        request,
control_api,
data_api,
saved_connection_id,
db,
):
db_table = make_table(db=db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
ds = create_basic_dataset(
api_v1=control_api,
connection_id=saved_connection_id,
@@ -213,12 +217,14 @@ def check_equality_of_totals(*field_names: str) -> None:

def test_lod_in_order_by(
self,
+        request,
control_api,
data_api,
saved_connection_id,
db,
):
db_table = make_table(db=db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))

data_api = data_api
ds = create_basic_dataset(
@@ -258,8 +264,9 @@ def get_data(order_by: list[ResultField]) -> list:


class DefaultBasicLookupFunctionTestSuite(DataApiTestBase, DatasetTestBase, DbServiceFixtureTextClass):
-    def test_ago_any_db(self, saved_connection_id, control_api, data_api, db):
+    def test_ago_any_db(self, request, saved_connection_id, control_api, data_api, db):
db_table = make_table(db=db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
ds = create_basic_dataset(
api_v1=control_api,
connection_id=saved_connection_id,
@@ -287,8 +294,9 @@ def test_ago_any_db(self, saved_connection_id, control_api, data_api, db):

check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=2, day_offset=2)

-    def test_triple_ago_any_db(self, saved_connection_id, control_api, data_api, db):
+    def test_triple_ago_any_db(self, request, saved_connection_id, control_api, data_api, db):
db_table = make_table(db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
ds = create_basic_dataset(
api_v1=control_api,
connection_id=saved_connection_id,
@@ -322,10 +330,17 @@ def test_triple_ago_any_db(self, saved_connection_id, control_api, data_api, db)
check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=3, day_offset=2)
check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=4, day_offset=3)

-    def test_ago_any_db_multisource(self, saved_connection_id, control_api, data_api, db):
+    def test_ago_any_db_multisource(self, request, saved_connection_id, control_api, data_api, db):
connection_id = saved_connection_id
table_1 = make_table(db)
table_2 = make_table(db)
+
+        def teardown(db, *tables):
+            for table in tables:
+                db.drop_table(table)
+
+        request.addfinalizer(functools.partial(teardown, db, table_1.table, table_2.table))
+
ds = Dataset()
ds.sources["source_1"] = ds.source(
connection_id=connection_id,
@@ -369,8 +384,9 @@ def test_ago_any_db_multisource(self, saved_connection_id, control_api, data_api

check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=2, day_offset=2)

-    def test_nested_ago(self, saved_connection_id, control_api, data_api, db):
+    def test_nested_ago(self, request, saved_connection_id, control_api, data_api, db):
db_table = make_table(db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
ds = create_basic_dataset(
api_v1=control_api,
connection_id=saved_connection_id,
@@ -404,8 +420,9 @@ def test_nested_ago(self, saved_connection_id, control_api, data_api, db):
check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=3, day_offset=2)
check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=4, day_offset=3)

-    def test_month_ago_for_shorter_month(self, db, saved_connection_id, control_api, data_api):
+    def test_month_ago_for_shorter_month(self, request, db, saved_connection_id, control_api, data_api):
any_db_table_200 = make_table(db, rows=200)
+        request.addfinalizer(functools.partial(db.drop_table, any_db_table_200.table))

# FIXME
# if any_db.conn_type == CONNECTION_TYPE_ORACLE:
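The change repeated throughout this file is the same two-line pattern: accept pytest's built-in `request` fixture and register table cleanup via `request.addfinalizer`, with `functools.partial` pre-binding the arguments. A minimal self-contained sketch of the pattern (the `FakeDb` class and `db` fixture here are hypothetical stand-ins for the repo's fixtures):

```python
import functools

import pytest


class FakeDb:
    """Hypothetical stand-in for the suite's db fixture."""

    def drop_table(self, table: str) -> None:
        print(f"DROP TABLE {table}")


@pytest.fixture
def db() -> FakeDb:
    return FakeDb()


def test_with_cleanup(request, db):
    table = "test_table_1"  # imagine make_table(db) created this
    # partial() binds the table name now; the finalizer runs after the
    # test finishes, even if an assertion in the test body fails.
    request.addfinalizer(functools.partial(db.drop_table, table))
    assert table.startswith("test_")
```

Unlike an inline try/finally, a finalizer still runs when the test fails midway, which is also why the multisource test funnels several tables through one `teardown` helper bound with `functools.partial`.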
----------------------------------------------------------------------
@@ -1,4 +1,5 @@
import datetime
+import functools
import json
from typing import Any

@@ -60,6 +61,7 @@ def test_basic_result(

def _test_contains(
self,
+        request,
db: Db,
saved_connection_id: str,
dataset_params: dict,
@@ -82,6 +84,8 @@ def _test_contains(
),
]
db_table = make_table(db, columns=columns)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
+
params = self.get_dataset_params(dataset_params, db_table)
ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)

@@ -156,6 +160,7 @@ def test_array_not_contains_filter(
@for_features(array_support)
def test_array_contains_field(
self,
+        request,
db: Db,
saved_connection_id: str,
dataset_params: dict,
@@ -179,6 +184,7 @@ def test_array_contains_field(
),
]
db_table = make_table(db, columns=columns)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
params = self.get_dataset_params(dataset_params, db_table)
ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)

@@ -384,6 +390,7 @@ def test_distinct_with_nonexistent_filter(

def test_date_filter_distinct(
self,
+        request,
db: Db,
saved_connection_id: str,
dataset_params: dict,
@@ -398,6 +405,8 @@ def test_date_filter_distinct(
{"date_val": datetime.date(2023, 4, 2)},
]
db_table = make_table(db, columns=columns, data=data)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
+
params = self.get_dataset_params(dataset_params, db_table)
ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)

----------------------------------------------------------------------
@@ -143,6 +143,15 @@ async def s3_bucket(self, s3_client: AsyncS3Client) -> str:
def sample_table_spec(self) -> FixtureTableSpec:
return attr.evolve(TABLE_SPEC_SAMPLE_SUPERSTORE, nullable=True)

+    def _get_s3_func_schema_for_table(self, table: DbTable) -> str:
+        field_id_gen = get_field_id_generator(self.conn_type)
+        tbl_schema = ", ".join(
+            "{} {}".format(field_id_gen.make_field_id(dict(title=col.name, index=idx)), col.type.compile())
+            for idx, col in enumerate(table.table.columns)
+        )
+        tbl_schema = tbl_schema.replace("()", "")  # String() -> String: type arguments are not needed here
+        return tbl_schema
+
@pytest.fixture(scope="function")
async def sample_s3_file(
self,
@@ -153,13 +162,7 @@ async def sample_s3_file
) -> AsyncGenerator[str, None]:
filename = f"my_file_{uuid.uuid4()}.native"

-        field_id_gen = get_field_id_generator(self.conn_type)
-        tbl_schema = ", ".join(
-            "{} {}".format(field_id_gen.make_field_id(dict(title=col.name, index=idx)), col.type.compile())
-            for idx, col in enumerate(sample_table.table.columns)
-        )
-        tbl_schema = tbl_schema.replace("()", "")  # String() -> String: type arguments are not needed here
-
+        tbl_schema = self._get_s3_func_schema_for_table(sample_table)
create_s3_native_from_ch_table(filename, s3_bucket, s3_settings, sample_table, tbl_schema)

yield filename
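The fixture change above is a pure extraction: the schema-line construction moves into `_get_s3_func_schema_for_table` so the new executor tests can reuse it. The helper renders the table's columns as the `name Type, name Type, ...` string that ClickHouse's S3 table function expects. A rough standalone sketch of that rendering, using plain SQLAlchemy column names in place of the repo's generated field IDs (an assumption made for brevity):

```python
import sqlalchemy as sa

table = sa.Table(
    "sample",
    sa.MetaData(),
    sa.Column("city", sa.String()),
    sa.Column("sales", sa.Integer()),
)

# Render "name TYPE" pairs. With SQLAlchemy's default dialect this gives
# "city VARCHAR, sales INTEGER"; under the ClickHouse dialect the types
# compile to String/Int32/..., which is what the table function wants.
schema = ", ".join(f"{col.name} {col.type.compile()}" for col in table.columns)
schema = schema.replace("()", "")  # e.g. String() -> String: drop empty type args
print(schema)
```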
----------------------------------------------------------------------
@@ -1,14 +1,23 @@
import abc

import pytest
+import sqlalchemy as sa

+from dl_configs.settings_submodels import S3Settings
+from dl_constants.enums import UserDataType
+from dl_core.connection_executors import (
+    AsyncConnExecutorBase,
+    ConnExecutorQuery,
+)
from dl_core.connection_models import DBIdent
+from dl_core_testing.database import DbTable
from dl_core_testing.testcases.connection_executor import (
DefaultAsyncConnectionExecutorTestSuite,
DefaultSyncAsyncConnectionExecutorCheckBase,
DefaultSyncConnectionExecutorTestSuite,
)
from dl_testing.regulated_test import RegulatedTestParams
+from dl_testing.s3_utils import s3_tbl_func_maker

from dl_connector_bundle_chs3_tests.db.base.core.base import (
FILE_CONN_TV,
@@ -43,6 +52,7 @@ class CHS3SyncConnectionExecutorTestBase(
test_params = RegulatedTestParams(
mark_tests_skipped={
DefaultSyncConnectionExecutorTestSuite.test_type_recognition: "Not implemented",
+            DefaultSyncConnectionExecutorTestSuite.test_get_table_names: "Not implemented",
},
)

@@ -52,4 +62,67 @@ class CHS3AsyncConnectionExecutorTestBase(
DefaultAsyncConnectionExecutorTestSuite[FILE_CONN_TV],
metaclass=abc.ABCMeta,
):
-    pass
+    async def test_select_data(
+        self,
+        sample_table: DbTable,
+        saved_connection: FILE_CONN_TV,
+        async_connection_executor: AsyncConnExecutorBase,
+        s3_settings: S3Settings,
+        sample_s3_file: str,
+    ) -> None:
+        schema_line = self._get_s3_func_schema_for_table(sample_table)
+        s3_tbl_func = s3_tbl_func_maker(s3_settings)(
+            for_="dba",
+            conn_dto=saved_connection.get_conn_dto(),
+            filename=sample_s3_file,
+            file_fmt="Native",
+            schema_line=schema_line,
+        )
+        file_columns = [sa.column(col_desc.split(" ")[0]) for col_desc in schema_line.split(", ")]
+        # ^ "col1 String, col2 Int32, col3 Date32" -> [col1, col2, col3]
+
+        n_rows = 3
+        result = await async_connection_executor.execute(
+            ConnExecutorQuery(
+                query=sa.select(columns=file_columns)
+                .select_from(sa.text(s3_tbl_func))
+                .order_by(file_columns[0])
+                .limit(n_rows),
+                chunk_size=6,
+            )
+        )
+        rows = await result.get_all()
+        assert len(rows) == n_rows
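`test_select_data` runs a plain `SELECT ... ORDER BY ... LIMIT` over the S3 table function and checks that the executor returns the expected number of rows. The column list is recovered from the schema line; a small sketch of the same construction with a made-up table-function string (`s3_tbl_func_maker` derives the real one from the connection and S3 settings):

```python
import sqlalchemy as sa

schema_line = "col1 String, col2 Int32, col3 Date32"
# "col1 String, col2 Int32, col3 Date32" -> [col1, col2, col3]
file_columns = [sa.column(desc.split(" ")[0]) for desc in schema_line.split(", ")]

query = (
    sa.select(*file_columns)  # modern spelling of the sa.select(columns=...) used above
    .select_from(sa.text("s3('https://storage.example/my_file.native', 'Native')"))  # placeholder
    .order_by(file_columns[0])
    .limit(3)
)
print(query)  # SELECT col1, col2, col3 FROM s3(...) ORDER BY col1 LIMIT ...
```

`sa.column` produces a bare, unquoted column reference, which is all the table function needs since the schema line already declares the types.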
+
+    @pytest.mark.asyncio
+    async def test_cast_row_to_output(
+        self,
+        sample_table: DbTable,
+        saved_connection: FILE_CONN_TV,
+        async_connection_executor: AsyncConnExecutorBase,
+        s3_settings: S3Settings,
+        sample_s3_file: str,
+    ) -> None:
+        schema_line = self._get_s3_func_schema_for_table(sample_table)
+        s3_tbl_func = s3_tbl_func_maker(s3_settings)(
+            for_="dba",
+            conn_dto=saved_connection.get_conn_dto(),
+            filename=sample_s3_file,
+            file_fmt="Native",
+            schema_line=schema_line,
+        )
+
+        result = await async_connection_executor.execute(
+            ConnExecutorQuery(
+                sa.select(columns=[sa.literal(1), sa.literal(2), sa.literal(3)])
+                .select_from(sa.text(s3_tbl_func))
+                .limit(1),
+                user_types=[
+                    UserDataType.boolean,
+                    UserDataType.float,
+                    UserDataType.integer,
+                ],
+            )
+        )
+        rows = await result.get_all()
+        assert rows == [(True, 2.0, 3)], rows
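`test_cast_row_to_output` checks that `user_types` on `ConnExecutorQuery` drives how fetched values are coerced for output: the query selects the literals `(1, 2, 3)` and the executor returns `(True, 2.0, 3)`. A toy model of that per-column cast (the converter table below is illustrative; the real executor goes through its type-transformer machinery):

```python
from enum import Enum
from typing import Any, Callable, Sequence


class UserDataType(Enum):
    boolean = "boolean"
    float = "float"
    integer = "integer"


# Illustrative converters only; the actual casts also cover dates,
# arrays, NULL handling, and so on.
CASTS: dict[UserDataType, Callable[[Any], Any]] = {
    UserDataType.boolean: bool,
    UserDataType.float: float,
    UserDataType.integer: int,
}


def cast_row(row: Sequence[Any], user_types: Sequence[UserDataType]) -> tuple:
    # Coerce each fetched value according to its declared user type.
    return tuple(CASTS[t](v) for v, t in zip(row, user_types))


assert cast_row((1, 2, 3), [UserDataType.boolean, UserDataType.float, UserDataType.integer]) == (True, 2.0, 3)
```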