Skip to content
This repository has been archived by the owner on Apr 8, 2024. It is now read-only.

Commit

Permalink
feat(dbt-adapter): dbt-athena-community support (#741)
Browse files Browse the repository at this point in the history
* tests(dbt-fal): add athena profile

* feat: add dbt-athena-community support

dbt-athena from PyPI doesn't support dbt 1.3, but dbt-athena-community does
https://github.com/dbt-athena/dbt-athena

dbt-athena-community does not support dbt 1.4, so this branch doesn't need to
be merged to main. Instead, we will release a patch to dbt-fal 1.3 in order to
add support for dbt-athena-community

* PR comments

* refactor: move some hacks location (#742)

---------

Co-authored-by: Matteo Ferrando <[email protected]>
  • Loading branch information
mederka and chamini2 authored Feb 6, 2023
1 parent 99521ad commit f533e92
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_integration_cli.yml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ jobs:
dbt: "1.1.*"
- profile: "fal"
dbt: "1.2.*"
# TODO: use dbt-athena-community instead
# dbt-athena-adapter only supports dbt-core==1.0.* for now
- profile: "athena"
dbt: "1.1.*"
Expand Down
16 changes: 16 additions & 0 deletions adapter/integration_tests/profiles/athena/profiles.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
config:
  send_anonymous_usage_stats: False

# Integration-test profile: the `staging` target is a fal output that
# delegates to the Athena output named `db` through `db_profile`.
fal_test:
  target: staging
  outputs:
    staging:
      type: fal
      db_profile: db
    db:
      type: athena
      s3_staging_dir: "{{ env_var('ATHENA_S3_STAGING_DIR') }}"
      region_name: us-east-1
      database: "{{ env_var('ATHENA_DATABASE') }}"
      schema: "{{ env_var('ATHENA_SCHEMA') }}"
      num_retries: 0
9 changes: 9 additions & 0 deletions adapter/src/dbt/adapters/fal/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ def __getattr__(self, name):
else:
getattr(super(), name)

def get_relation(self, database: str, schema: str, identifier: str):
    """Fetch a relation by delegating to the wrapped database adapter.

    When the wrapped adapter reports its type as ``"athena"``, the quoting
    policy on ``self.config`` is overwritten with an all-True policy before
    the lookup is forwarded.
    """
    if self._db_adapter.type() == "athena":
        # HACK: When compiling Python models, we get an all-False quoting
        # policy (this does not happen in 1.4), and dbt-athena-community
        # breaks for that case, so quote every part of the path.
        self.config.quoting = dict(database=True, schema=True, identifier=True)

    return self._db_adapter.get_relation(database, schema, identifier)


def find_funcs_in_stack(funcs: Set[str]) -> bool:
import inspect
Expand Down
19 changes: 16 additions & 3 deletions adapter/src/dbt/adapters/fal_experimental/adapter_support.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
from time import sleep
from typing import Any

import pandas as pd
Expand All @@ -26,7 +25,8 @@ def _get_alchemy_engine(adapter: BaseAdapter, connection: Connection) -> Any:

sqlalchemy_kwargs = {}
format_url = lambda url: url
if adapter_type == 'trino':

if adapter_type == "trino":
import dbt.adapters.fal_experimental.support.trino as support_trino
return support_trino.create_engine(adapter)

Expand Down Expand Up @@ -82,12 +82,19 @@ def write_df_to_relation(

return support_duckdb.write_df_to_relation(adapter, dataframe, relation)

elif adapter.type() == "athena":
import dbt.adapters.fal_experimental.support.athena as support_athena

return support_athena.write_df_to_relation(adapter, dataframe, relation, if_exists)

else:
with new_connection(adapter, "fal:write_df_to_relation") as connection:

# TODO: this should probably live in the materialization macro.
temp_relation = relation.replace_path(
identifier=f"__dbt_fal_temp_{relation.identifier}"
)

drop_relation_if_it_exists(adapter, temp_relation)

alchemy_engine = _get_alchemy_engine(adapter, connection)
Expand All @@ -103,12 +110,12 @@ def write_df_to_relation(
)
adapter.cache.add(temp_relation)
drop_relation_if_it_exists(adapter, relation)

adapter.rename_relation(temp_relation, relation)
adapter.commit_if_has_connection()

return AdapterResponse("OK", rows_affected=rows_affected)


def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.DataFrame:
"""Generic version of the read_df_from_relation."""

Expand All @@ -127,9 +134,15 @@ def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.Data

return support_duckdb.read_relation_as_df(adapter, relation)

elif adapter.type() == "athena":
import dbt.adapters.fal_experimental.support.athena as support_athena

return support_athena.read_relation_as_df(adapter, relation)

else:
with new_connection(adapter, "fal:read_relation_as_df") as connection:
alchemy_engine = _get_alchemy_engine(adapter, connection)

return pd.read_sql_table(
con=alchemy_engine,
table_name=relation.identifier,
Expand Down
1 change: 0 additions & 1 deletion adapter/src/dbt/adapters/fal_experimental/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def manifest(self) -> Manifest:
def macro_manifest(self) -> MacroManifest:
return self._db_adapter.load_macro_manifest()


@telemetry.log_call("experimental_submit_python_job", config=True)
def submit_python_job(
self, parsed_model: dict, compiled_code: str
Expand Down
82 changes: 82 additions & 0 deletions adapter/src/dbt/adapters/fal_experimental/support/athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Any
import six
from dbt.adapters.base.relation import BaseRelation
from dbt.contracts.connection import AdapterResponse
import sqlalchemy
import pandas as pd
from dbt.adapters.base import BaseAdapter
from urllib.parse import quote_plus


def create_engine(adapter: BaseAdapter) -> Any:
    """Build a SQLAlchemy engine for Athena from the adapter's credentials.

    Reads ``region_name``, ``schema`` and ``s3_staging_dir`` from
    ``adapter.config.credentials._db_creds`` and assembles an
    ``awsathena+rest`` connection URL. The URL-quoted staging directory is
    used for both the ``s3_staging_dir`` and ``location`` query parameters.
    """
    db_creds = adapter.config.credentials._db_creds

    # The same quoted value feeds both query parameters.
    staging_dir = quote_plus(db_creds.s3_staging_dir)

    url = (
        f"awsathena+rest://:@athena.{db_creds.region_name}.amazonaws.com:443/"
        f"{db_creds.schema}?s3_staging_dir={staging_dir}"
        f"&location={staging_dir}&compression=snappy"
    )
    return sqlalchemy.create_engine(url)


def drop_relation_if_it_exists(adapter: BaseAdapter, relation: BaseRelation) -> None:
    """Drop ``relation`` through ``adapter`` if the adapter can find it.

    The lookup uses the relation's own ``database``, ``schema`` and
    ``identifier``; nothing happens when the lookup result is falsy.
    """
    found = adapter.get_relation(
        database=relation.database,
        schema=relation.schema,
        identifier=relation.identifier,
    )
    if not found:
        return

    adapter.drop_relation(relation)


def write_df_to_relation(adapter, dataframe, relation, if_exists) -> AdapterResponse:
    """Write ``dataframe`` to an Athena table by way of a temporary table.

    The dataframe is first loaded into a ``dbt_fal_temp_*`` table with
    ``DataFrame.to_sql``; the final table is then (re)created from it with a
    ``create table ... as select`` statement and the temporary table dropped.

    Args:
        adapter: the adapter to run against; its ``type()`` must be "athena".
        dataframe: the pandas DataFrame to persist.
        relation: the destination relation as handed over by dbt.
        if_exists: forwarded unchanged to ``DataFrame.to_sql``.

    Returns:
        An ``AdapterResponse`` of "OK" carrying whatever ``to_sql`` returned
        as ``rows_affected``.
    """
    assert adapter.type() == "athena"

    # dbt-athena-community quirk: the incoming relation carries the table
    # name in ``relation.schema`` and the schema in ``relation.database``.
    table_name = relation.schema
    staging_relation = relation.replace_path(
        schema=relation.database,
        database=adapter.config.credentials._db_creds.database,
        # Athena complains when the table location looks like x.__y, so the
        # temporary prefix has no leading underscores.
        identifier=f"dbt_fal_temp_{table_name}",
    )
    target_relation = staging_relation.replace_path(identifier=table_name)

    # Start from a clean slate for the temporary table.
    drop_relation_if_it_exists(adapter, staging_relation)

    engine = create_engine(adapter)
    rows_affected = dataframe.to_sql(
        name=staging_relation.identifier,
        con=engine,
        schema=staging_relation.schema,
        if_exists=if_exists,
        index=False,
    )
    adapter.cache.add(staging_relation)

    drop_relation_if_it_exists(adapter, target_relation)

    # Athena doesn't let us rename relations, so copy the data by hand.
    copy_stmt = (
        f"create table {target_relation} as select * from {staging_relation} with data"
    )
    adapter.execute(copy_stmt)
    adapter.cache.add(target_relation)
    adapter.drop_relation(staging_relation)

    adapter.commit_if_has_connection()
    return AdapterResponse("OK", rows_affected=rows_affected)

def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.DataFrame:
    """Load an Athena table into a pandas DataFrame.

    An engine is created from the adapter's credentials and the table is
    read with ``pandas.read_sql_table``.
    """
    engine = create_engine(adapter)

    # dbt-athena-community quirk: the table name is in ``relation.schema``
    # and the schema is in ``relation.database``.
    table_name = relation.schema
    schema_name = relation.database

    return pd.read_sql_table(table_name=table_name, con=engine, schema=schema_name)
11 changes: 9 additions & 2 deletions adapter/src/dbt/adapters/fal_experimental/utils/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,21 @@ def _parse_remote_config(config: Dict[str, Any], parsed_config: Dict[str, Any])
"target_environments": [env_definition]
}

def _get_package_from_type(adapter_type: str):
SPECIAL_ADAPTERS = {
# Documented in dbt website
"athena": "dbt-athena-community",
}
return SPECIAL_ADAPTERS.get(adapter_type, f"dbt-{adapter_type}")


def _get_dbt_packages(
adapter_type: str,
is_teleport: bool = False,
is_remote: bool = False
) -> Iterator[Tuple[str, Optional[str]]]:
dbt_adapter = f"dbt-{adapter_type}"
for dbt_plugin_name in ['dbt-core', dbt_adapter]:
dbt_adapter = _get_package_from_type(adapter_type)
for dbt_plugin_name in ["dbt-core", dbt_adapter]:
distribution = importlib_metadata.distribution(dbt_plugin_name)

yield dbt_plugin_name, distribution.version
Expand Down

0 comments on commit f533e92

Please sign in to comment.