Add more connection executor tests
KonstantAnxiety committed Oct 10, 2023 · 1 parent 35a63b2 · commit 29e3aaa
Showing 7 changed files with 431 additions and 69 deletions.
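Most of these changes follow one pattern: tests that create database tables now take pytest's built-in request fixture and register a finalizer that drops the table at teardown, so cleanup happens even when a test fails mid-run. A minimal sketch of the pattern, assuming the db fixture and make_table helper used throughout this suite (test_something is a hypothetical name):

import functools

def test_something(request, db):
    db_table = make_table(db=db)  # throwaway table for this test only
    # the finalizer runs at teardown whether the test passed or failed
    request.addfinalizer(functools.partial(db.drop_table, db_table.table))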
@@ -1,4 +1,5 @@
from collections import defaultdict
import functools
from http import HTTPStatus
from typing import Iterable

@@ -70,7 +71,7 @@ def test_lod_fixed_single_dim_in_two_dim_query(self, control_api, data_api, save
assert float(row[3]) == pytest.approx(sum_by_city[row[0]])
assert float(row[4]) == pytest.approx(sum_by_category[row[1]])

def test_null_dimensions(self, control_api, data_api, db, saved_connection_id):
def test_null_dimensions(self, request, control_api, data_api, db, saved_connection_id):
connection_id = saved_connection_id

raw_data = [
@@ -88,6 +89,7 @@ def test_null_dimensions(self, control_api, data_api, db, saved_connection_id):
C("sales", UserDataType.integer, vg=lambda rn, **kwargs: raw_data[rn]["sales"]),
]
db_table = make_table(db, columns=columns, rows=len(raw_data))
request.addfinalizer(functools.partial(db.drop_table, db_table.table))

ds = Dataset()
ds.sources["source_1"] = ds.source(connection_id=connection_id, **data_source_settings_from_table(db_table))
@@ -165,12 +167,14 @@ def get_total_value() -> float:

def test_total_lod_2(
self,
request,
control_api,
data_api,
saved_connection_id,
db,
):
db_table = make_table(db=db)
request.addfinalizer(functools.partial(db.drop_table, db_table.table))
ds = create_basic_dataset(
api_v1=control_api,
connection_id=saved_connection_id,
@@ -213,12 +217,14 @@ def check_equality_of_totals(*field_names: str) -> None:

def test_lod_in_order_by(
self,
request,
control_api,
data_api,
saved_connection_id,
db,
):
db_table = make_table(db=db)
request.addfinalizer(functools.partial(db.drop_table, db_table.table))

data_api = data_api
ds = create_basic_dataset(
@@ -258,8 +264,9 @@ def get_data(order_by: list[ResultField]) -> list:


class DefaultBasicLookupFunctionTestSuite(DataApiTestBase, DatasetTestBase, DbServiceFixtureTextClass):
def test_ago_any_db(self, saved_connection_id, control_api, data_api, db):
def test_ago_any_db(self, request, saved_connection_id, control_api, data_api, db):
db_table = make_table(db=db)
request.addfinalizer(functools.partial(db.drop_table, db_table.table))
ds = create_basic_dataset(
api_v1=control_api,
connection_id=saved_connection_id,
@@ -287,8 +294,9 @@ def test_ago_any_db(self, saved_connection_id, control_api, data_api, db):

check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=2, day_offset=2)

def test_triple_ago_any_db(self, saved_connection_id, control_api, data_api, db):
def test_triple_ago_any_db(self, request, saved_connection_id, control_api, data_api, db):
db_table = make_table(db)
request.addfinalizer(functools.partial(db.drop_table, db_table.table))
ds = create_basic_dataset(
api_v1=control_api,
connection_id=saved_connection_id,
@@ -322,10 +330,17 @@ def test_triple_ago_any_db(self, saved_connection_id, control_api, data_api, db)
check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=3, day_offset=2)
check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=4, day_offset=3)

def test_ago_any_db_multisource(self, saved_connection_id, control_api, data_api, db):
def test_ago_any_db_multisource(self, request, saved_connection_id, control_api, data_api, db):
connection_id = saved_connection_id
table_1 = make_table(db)
table_2 = make_table(db)

def teardown(db, *tables):
for table in tables:
db.drop_table(table)

request.addfinalizer(functools.partial(teardown, db, table_1.table, table_2.table))
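# Hypothetical alternative: registering one finalizer per table avoids the
# helper entirely, since addfinalizer accepts any number of callables and
# runs each of them at teardown:
#     for table in (table_1.table, table_2.table):
#         request.addfinalizer(functools.partial(db.drop_table, table))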

ds = Dataset()
ds.sources["source_1"] = ds.source(
connection_id=connection_id,
@@ -369,8 +384,9 @@ def test_ago_any_db_multisource(self, saved_connection_id, control_api, data_api

check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=2, day_offset=2)

def test_nested_ago(self, saved_connection_id, control_api, data_api, db):
def test_nested_ago(self, request, saved_connection_id, control_api, data_api, db):
db_table = make_table(db)
request.addfinalizer(functools.partial(db.drop_table, db_table.table))
ds = create_basic_dataset(
api_v1=control_api,
connection_id=saved_connection_id,
@@ -404,8 +420,9 @@ def test_nested_ago(self, saved_connection_id, control_api, data_api, db):
check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=3, day_offset=2)
check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=4, day_offset=3)

def test_month_ago_for_shorter_month(self, db, saved_connection_id, control_api, data_api):
def test_month_ago_for_shorter_month(self, request, db, saved_connection_id, control_api, data_api):
any_db_table_200 = make_table(db, rows=200)
request.addfinalizer(functools.partial(db.drop_table, any_db_table_200.table))

# FIXME
# if any_db.conn_type == CONNECTION_TYPE_ORACLE:
@@ -1,4 +1,5 @@
import datetime
import functools
import json
from typing import Any

@@ -60,6 +61,7 @@ def test_basic_result(

def _test_contains(
self,
request,
db: Db,
saved_connection_id: str,
dataset_params: dict,
@@ -82,6 +84,8 @@ def _test_contains(
),
]
db_table = make_table(db, columns=columns)
request.addfinalizer(functools.partial(db.drop_table, db_table.table))

params = self.get_dataset_params(dataset_params, db_table)
ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)

@@ -156,6 +160,7 @@ def test_array_not_contains_filter(
@for_features(array_support)
def test_array_contains_field(
self,
request,
db: Db,
saved_connection_id: str,
dataset_params: dict,
@@ -179,6 +184,7 @@ def test_array_contains_field(
),
]
db_table = make_table(db, columns=columns)
request.addfinalizer(functools.partial(db.drop_table, db_table.table))
params = self.get_dataset_params(dataset_params, db_table)
ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)

@@ -384,6 +390,7 @@ def test_distinct_with_nonexistent_filter(

def test_date_filter_distinct(
self,
request,
db: Db,
saved_connection_id: str,
dataset_params: dict,
@@ -398,6 +405,8 @@ def test_date_filter_distinct(
{"date_val": datetime.date(2023, 4, 2)},
]
db_table = make_table(db, columns=columns, data=data)
request.addfinalizer(functools.partial(db.drop_table, db_table.table))

params = self.get_dataset_params(dataset_params, db_table)
ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)

@@ -143,6 +143,15 @@ async def s3_bucket(self, s3_client: AsyncS3Client) -> str:
def sample_table_spec(self) -> FixtureTableSpec:
return attr.evolve(TABLE_SPEC_SAMPLE_SUPERSTORE, nullable=True)

def _get_s3_func_schema_for_table(self, table: DbTable) -> str:
field_id_gen = get_field_id_generator(self.conn_type)
tbl_schema = ", ".join(
"{} {}".format(field_id_gen.make_field_id(dict(title=col.name, index=idx)), col.type.compile())
for idx, col in enumerate(table.table.columns)
)
tbl_schema = tbl_schema.replace("()", "") # String() -> String: type arguments are not needed here
return tbl_schema
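# Illustration with hypothetical columns: given [("Category", "String()"),
# ("Sales", "Int32")], the join-and-strip above amounts to
#     ", ".join(f"{name} {typ}" for name, typ in cols).replace("()", "")
#     -> "Category String, Sales Int32"
# Real identifiers come from the connection type's field ID generator.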

@pytest.fixture(scope="function")
async def sample_s3_file(
self,
@@ -153,13 +162,7 @@ async def sample_s3_file(
) -> AsyncGenerator[str, None]:
filename = f"my_file_{uuid.uuid4()}.native"

field_id_gen = get_field_id_generator(self.conn_type)
tbl_schema = ", ".join(
"{} {}".format(field_id_gen.make_field_id(dict(title=col.name, index=idx)), col.type.compile())
for idx, col in enumerate(sample_table.table.columns)
)
tbl_schema = tbl_schema.replace("()", "") # String() -> String: type arguments are not needed here

tbl_schema = self._get_s3_func_schema_for_table(sample_table)
create_s3_native_from_ch_table(filename, s3_bucket, s3_settings, sample_table, tbl_schema)

yield filename
@@ -1,14 +1,23 @@
import abc

import pytest
import sqlalchemy as sa

from dl_configs.settings_submodels import S3Settings
from dl_constants.enums import UserDataType
from dl_core.connection_executors import (
AsyncConnExecutorBase,
ConnExecutorQuery,
)
from dl_core.connection_models import DBIdent
from dl_core_testing.database import DbTable
from dl_core_testing.testcases.connection_executor import (
DefaultAsyncConnectionExecutorTestSuite,
DefaultSyncAsyncConnectionExecutorCheckBase,
DefaultSyncConnectionExecutorTestSuite,
)
from dl_testing.regulated_test import RegulatedTestParams
from dl_testing.s3_utils import s3_tbl_func_maker

from dl_connector_bundle_chs3_tests.db.base.core.base import (
FILE_CONN_TV,
@@ -43,6 +52,7 @@ class CHS3SyncConnectionExecutorTestBase(
test_params = RegulatedTestParams(
mark_tests_skipped={
DefaultSyncConnectionExecutorTestSuite.test_type_recognition: "Not implemented",
DefaultSyncConnectionExecutorTestSuite.test_get_table_names: "Not implemented",
},
)

@@ -52,4 +62,67 @@ class CHS3AsyncConnectionExecutorTestBase(
DefaultAsyncConnectionExecutorTestSuite[FILE_CONN_TV],
metaclass=abc.ABCMeta,
):
pass
async def test_select_data(
self,
sample_table: DbTable,
saved_connection: FILE_CONN_TV,
async_connection_executor: AsyncConnExecutorBase,
s3_settings: S3Settings,
sample_s3_file: str,
) -> None:
schema_line = self._get_s3_func_schema_for_table(sample_table)
s3_tbl_func = s3_tbl_func_maker(s3_settings)(
for_="dba",
conn_dto=saved_connection.get_conn_dto(),
filename=sample_s3_file,
file_fmt="Native",
schema_line=schema_line,
)
file_columns = [sa.column(col_desc.split(" ")[0]) for col_desc in schema_line.split(", ")]
# ^ "col1 String, col2 Int32, col3 Date32" -> [col1, col2, col3]

n_rows = 3
result = await async_connection_executor.execute(
ConnExecutorQuery(
query=sa.select(columns=file_columns)
.select_from(sa.text(s3_tbl_func))
.order_by(file_columns[0])
.limit(n_rows),
chunk_size=6,
)
)
rows = await result.get_all()
assert len(rows) == n_rows
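# Note: s3_tbl_func_maker presumably renders a ClickHouse s3(...) table
# function expression, so the query above would read roughly like
#     SELECT col1 FROM s3('<bucket>/my_file_<uuid>.native', ..., 'Native',
#                         'col1 String, ...') ORDER BY col1 LIMIT 3
# (illustrative only; the exact rendering is an assumption).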

@pytest.mark.asyncio
async def test_cast_row_to_output(
self,
sample_table: DbTable,
saved_connection: FILE_CONN_TV,
async_connection_executor: AsyncConnExecutorBase,
s3_settings: S3Settings,
sample_s3_file: str,
) -> None:
schema_line = self._get_s3_func_schema_for_table(sample_table)
s3_tbl_func = s3_tbl_func_maker(s3_settings)(
for_="dba",
conn_dto=saved_connection.get_conn_dto(),
filename=sample_s3_file,
file_fmt="Native",
schema_line=schema_line,
)

result = await async_connection_executor.execute(
ConnExecutorQuery(
sa.select(columns=[sa.literal(1), sa.literal(2), sa.literal(3)])
.select_from(sa.text(s3_tbl_func))
.limit(1),
user_types=[
UserDataType.boolean,
UserDataType.float,
UserDataType.integer,
],
)
)
rows = await result.get_all()
assert rows == [(True, 2.0, 3)], rows
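# The assertion above exercises user_types coercion: each selected literal is
# cast to the declared UserDataType. A plain-Python analogue of the expected
# result, for illustration only:
#     casts = (bool, float, int)
#     assert tuple(c(v) for c, v in zip(casts, (1, 2, 3))) == (True, 2.0, 3)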