Skip to content
This repository has been archived by the owner on Apr 8, 2024. It is now read-only.

Commit

Permalink
Do a single connection for whole model
Browse files Browse the repository at this point in the history
  • Loading branch information
chamini2 committed Sep 28, 2022
1 parent 9dc3741 commit 34043e8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 45 deletions.
73 changes: 32 additions & 41 deletions adapter/src/dbt/adapters/fal/adapter_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import pandas as pd
import sqlalchemy
from contextlib import contextmanager
from dbt.adapters.base import BaseAdapter, BaseRelation, RelationType
from dbt.adapters.base.connections import AdapterResponse, Connection
from dbt.adapters.base.connections import AdapterResponse
from dbt.config import RuntimeConfig
from dbt.parser.manifest import ManifestLoader

Expand All @@ -17,11 +16,12 @@
}


def _get_alchemy_engine(adapter: BaseAdapter, connection: Connection) -> Any:
def _get_alchemy_engine(adapter: BaseAdapter) -> Any:
# The following code heavily depends on the implementation
# details of the known adapters, hence it can't work for
# arbitrary ones.
adapter_type = adapter.type()
connection = adapter.connections.get_if_exists()

sqlalchemy_kwargs = {}
format_url = lambda url: url
Expand Down Expand Up @@ -80,30 +80,29 @@ def write_df_to_relation(
return write_df_to_relation(adapter, dataframe, relation)

else:
with new_connection(adapter, "fal:write_df_to_relation") as connection:
# TODO: this should probably live in the materialization macro.
temp_relation = relation.replace_path(
identifier=f"__dbt_fal_temp_{relation.identifier}"
)
drop_relation_if_it_exists(adapter, temp_relation)

alchemy_engine = _get_alchemy_engine(adapter, connection)

# TODO: probably worth handling errors here and returning
# a proper adapter response.
rows_affected = dataframe.to_sql(
con=alchemy_engine,
name=temp_relation.identifier,
schema=temp_relation.schema,
if_exists=if_exists,
index=False,
)
adapter.cache.add(temp_relation)
drop_relation_if_it_exists(adapter, relation)
adapter.rename_relation(temp_relation, relation)
adapter.commit_if_has_connection()
# TODO: this should probably live in the materialization macro.
temp_relation = relation.replace_path(
identifier=f"__dbt_fal_temp_{relation.identifier}"
)
drop_relation_if_it_exists(adapter, temp_relation)

alchemy_engine = _get_alchemy_engine(adapter)

# TODO: probably worth handling errors here and returning
# a proper adapter response.
rows_affected = dataframe.to_sql(
con=alchemy_engine,
name=temp_relation.identifier,
schema=temp_relation.schema,
if_exists=if_exists,
index=False,
)
adapter.cache.add(temp_relation)
drop_relation_if_it_exists(adapter, relation)
adapter.rename_relation(temp_relation, relation)
adapter.commit_if_has_connection()

return AdapterResponse("OK", rows_affected=rows_affected)
return AdapterResponse("OK", rows_affected=rows_affected)


def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.DataFrame:
Expand All @@ -120,13 +119,12 @@ def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.Data
return read_relation_as_df(adapter, relation)

else:
with new_connection(adapter, "fal:read_relation_as_df") as connection:
alchemy_engine = _get_alchemy_engine(adapter, connection)
return pd.read_sql_table(
con=alchemy_engine,
table_name=relation.identifier,
schema=relation.schema,
)
alchemy_engine = _get_alchemy_engine(adapter)
return pd.read_sql_table(
con=alchemy_engine,
table_name=relation.identifier,
schema=relation.schema,
)


def prepare_for_adapter(adapter: BaseAdapter, function: Any) -> Any:
Expand Down Expand Up @@ -163,11 +161,4 @@ def reconstruct_adapter(config: RuntimeConfig) -> BaseAdapter:

def reload_adapter_cache(adapter: BaseAdapter, config: RuntimeConfig) -> None:
manifest = ManifestLoader.get_full_manifest(config)
with new_connection(adapter, "fal:reload_adapter_cache"):
adapter.set_relations_cache(manifest, True)


@contextmanager
def new_connection(adapter: BaseAdapter, connection_name: str) -> Connection:
with adapter.connection_named(connection_name):
yield adapter.connections.get_thread_connection()
adapter.set_relations_cache(manifest, True)
10 changes: 6 additions & 4 deletions adapter/src/dbt/adapters/fal/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def _run_with_adapter(code: str, adapter: BaseAdapter) -> Any:
# main symbol is defined during dbt-fal's compilation
# and acts as an entrypoint for us to run the model.
main = retrieve_symbol(code, "main")
return main(
read_df=prepare_for_adapter(adapter, read_relation_as_df),
write_df=prepare_for_adapter(adapter, write_df_to_relation),
)

with adapter.connection_named("fal:model"):
return main(
read_df=prepare_for_adapter(adapter, read_relation_as_df),
write_df=prepare_for_adapter(adapter, write_df_to_relation),
)


def _isolated_runner(code: str, config: RuntimeConfig) -> Any:
Expand Down

0 comments on commit 34043e8

Please sign in to comment.