Skip to content

Commit

Permalink
Fixed creation of dsrc cache key parts
Browse files Browse the repository at this point in the history
  • Loading branch information
altvod committed Nov 29, 2023
1 parent 485555a commit 5899669
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Optional,
)

from sqlalchemy.sql.elements import ClauseElement

from dl_constants.enums import DataSourceType
from dl_core.connection_models import (
TableDefinition,
Expand Down Expand Up @@ -75,7 +77,7 @@ def get_parameters(self) -> dict:
)

@require_table_name
def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
q = self.quote
alias_str = "" if alias is None else f" AS {q(alias)}"
return sa_plain_text(f"{q(self.db_name)}" f".{q(self.spec.dataset_name)}" f".{q(self.table_name)}{alias_str}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)

from clickhouse_sqlalchemy.quoting import Quoter
from sqlalchemy.sql.elements import ClauseElement

from dl_constants.enums import (
DataSourceType,
Expand Down Expand Up @@ -113,7 +114,7 @@ def quote_str(self, value: str) -> str:
def _handle_component_errors(self) -> None:
pass

def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
origin_src = self._get_origin_src()
status = origin_src.status
raw_schema = self.spec.raw_schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_component_error(
assert len(actual_errors) == 1, actual_errors
assert actual_errors[0]["code"] == "ERR.DS_API.SOURCE.FILE.CUSTOM_FILE_ERROR"

@pytest.mark.xfail # FIXME: Refactor get_sql_source method of data source
@pytest.mark.asyncio
def test_component_error_warning(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_component_error(
assert len(actual_errors) == 1, actual_errors
assert actual_errors[0]["code"] == "ERR.DS_API.SOURCE.FILE.CUSTOM_FILE_ERROR"

@pytest.mark.xfail # FIXME: Refactor get_sql_source method of data source
@pytest.mark.asyncio
def test_component_error_warning(
self,
Expand Down
11 changes: 6 additions & 5 deletions lib/dl_connector_chyt/dl_connector_chyt/core/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import attr
import sqlalchemy as sa
from sqlalchemy.sql.elements import ClauseElement

from dl_constants.enums import (
DataSourceType,
Expand Down Expand Up @@ -100,7 +101,7 @@ def default_title(self) -> str:
return self.spec.table_name.split("/")[-1]

@require_table_name
def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
if alias:
return sa.alias(self.get_sql_source(), name=alias)
path = self.spec.table_name
Expand All @@ -122,7 +123,7 @@ class BaseCHYTSpecialDataSource(CHYTDataSourceBaseMixin, BaseSQLDataSource, abc.
def default_title(self) -> str:
raise NotImplementedError

def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
raise NotImplementedError

def get_table_definition(self) -> TableDefinition:
Expand Down Expand Up @@ -184,7 +185,7 @@ def get_parameters(self) -> dict:
table_names=self.spec.table_names,
)

def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
if not self.spec.table_names:
raise exc.TableNameNotConfiguredError
table_names = self.normalize_tables_paths(self.spec.table_names)
Expand Down Expand Up @@ -231,7 +232,7 @@ def get_parameters(self) -> dict:
range_to=self.range_to or "",
)

def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
if not self.directory_path:
raise exc.TableNameNotConfiguredError
path = self.directory_path
Expand All @@ -256,7 +257,7 @@ def get_parameters(self) -> dict:
subsql=self.subsql,
)

def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
if not self.connection.is_subselect_allowed:
raise exc.SubselectNotAllowed()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Optional,
)

from sqlalchemy.sql.elements import ClauseElement

from dl_constants.enums import DataSourceType
from dl_core.data_source.sql import (
BaseSQLDataSource,
Expand Down Expand Up @@ -37,7 +39,7 @@ class OracleDataSource(OracleDataSourceMixin, StandardSchemaSQLDataSource):
"""Oracle table"""

@require_table_name
def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
q = self.quote
alias_str = "" if alias is None else f" {q(alias)}"
schema_str = "" if self.schema_name is None else f"{q(self.schema_name)}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Type,
)

from sqlalchemy.sql.elements import ClauseElement

from dl_constants.enums import DataSourceType
from dl_core.connection_models import (
TableDefinition,
Expand Down Expand Up @@ -92,7 +94,7 @@ def get_schema_info(
return super(SnowFlakeTableDataSource, self).get_schema_info(conn_executor_factory=conn_executor_factory)

@require_table_name
def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
q = self.quote
alias_str = "" if alias is None else f" AS {q(alias)}"
return sa_plain_text(f"{q(self.db_name)}.{q(self.schema_name)}.{q(self.table_name)}{alias_str}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Optional,
)

from sqlalchemy.sql.elements import ClauseElement

from dl_constants.enums import DataSourceType
from dl_core.data_source.sql import (
StandardSQLDataSource,
Expand Down Expand Up @@ -33,7 +35,7 @@ class YDBTableDataSource(YDBDataSourceMixin, StandardSQLDataSource):
"""YDB table"""

@require_table_name
def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
# cross-db joins are not supported
assert not self.db_name or self.db_name == self.connection.db_name

Expand Down
8 changes: 5 additions & 3 deletions lib/dl_core/dl_core/data_source/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

import attr
from sqlalchemy.sql.elements import ClauseElement

from dl_constants.enums import (
ConnectionType,
Expand Down Expand Up @@ -192,12 +193,13 @@ def saved_raw_schema(self) -> Optional[list[SchemaColumn]]:
def saved_index_info_set(self) -> Optional[frozenset[IndexInfo]]:
return None

def get_sql_source(self, alias: Optional[str] = None) -> Any:
@abc.abstractmethod
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
"""
Return something that can be used in a ``select_from`` ``SQLAlchemy`` clause for fetching data.
Optionally assign the source an alias that can be used in the query to refer to it.
"""
return None
raise NotImplementedError()

def source_exists(
self,
Expand Down Expand Up @@ -248,7 +250,7 @@ def get_cache_key_part(self) -> LocalKeyRepresentation:
local_key_rep = local_key_rep.multi_extend(
DataKeyPart(
part_type="data_source_sql",
part_content=self.get_sql_source(),
part_content=self.get_sql_source().compile(compile_kwargs={"literal_binds": True}).string,
),
)
return local_key_rep
Expand Down
11 changes: 6 additions & 5 deletions lib/dl_core/dl_core/data_source/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import attr
import sqlalchemy as sa
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.sql.elements import ClauseElement

from dl_constants.enums import JoinType
from dl_core import exc
Expand Down Expand Up @@ -134,7 +135,7 @@ def get_dialect(self) -> DefaultDialect:
def quote(self, value) -> sa.sql.quoted_name: # type: ignore # TODO: fix # subclass of str
return self.connection.quote(value)

def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
raise NotImplementedError()

def get_table_definition(self) -> TableDefinition:
Expand Down Expand Up @@ -196,7 +197,7 @@ def get_parameters(self) -> dict:
_subquery_alias_joiner = " AS "
_subquery_auto_alias = "source"

def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
if not self.connection.is_subselect_allowed:
raise exc.SubselectNotAllowed()

Expand Down Expand Up @@ -268,7 +269,7 @@ def source_exists(
return super().source_exists(conn_executor_factory=conn_executor_factory, force_refresh=force_refresh)

@require_table_name
def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
q = self.quote
alias_str = "" if alias is None else f" AS {q(alias)}"
return sa_plain_text(f"{q(self.db_name)}.{q(self.table_name)}{alias_str}")
Expand Down Expand Up @@ -353,7 +354,7 @@ class PseudoSQLDataSource(StandardSQLDataSource, IncompatibleDataSourceMixin):
supports_schema_update: ClassVar[bool] = False

@require_table_name
def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
# ignore alias
return sa.table(self.table_name)

Expand Down Expand Up @@ -383,7 +384,7 @@ def get_parameters(self) -> dict:
)

@require_table_name
def get_sql_source(self, alias: Optional[str] = None) -> Any:
def get_sql_source(self, alias: Optional[str] = None) -> ClauseElement:
if not self.schema_name:
return super().get_sql_source(alias=alias)
q = self.quote
Expand Down

0 comments on commit 5899669

Please sign in to comment.