Skip to content

Commit

Permalink
🔨 deprecate DBUtils
Browse files Browse the repository at this point in the history
  • Loading branch information
Marigold committed Apr 17, 2024
1 parent 6f4ef28 commit 56031cf
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 219 deletions.
50 changes: 25 additions & 25 deletions etl/chart_revision/v1/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tqdm import tqdm

from etl.config import DEBUG, GRAPHER_USER_ID
from etl.db import open_db
from etl.db import get_engine
from etl.grapher_helpers import IntRange

log = structlog.get_logger()
Expand Down Expand Up @@ -179,23 +179,23 @@ def _get_chart_update_reason(self, variable_ids: List[int]) -> str:
Accesses DB and finds out the name of the recently added dataset with the new variables."""
try:
with open_db() as db:
with get_engine().connect() as con:
if len(variable_ids) == 1:
results = db.fetch_many(
results = con.execute(
f"""
SELECT variables.name, datasets.name, datasets.version FROM datasets
JOIN variables ON datasets.id = variables.datasetId
WHERE variables.id IN ({variable_ids[0]})
"""
)
).fetchmany()
else:
results = db.fetch_many(
results = con.execute(
f"""
SELECT variables.name, datasets.name, datasets.version FROM datasets
JOIN variables ON datasets.id = variables.datasetId
WHERE variables.id IN {*variable_ids,}
"""
)
).fetchmany()
except Exception:
self.report_error(
"Problem found when accessing the DB trying to get details on the newly added variables"
Expand All @@ -220,10 +220,10 @@ def _get_chart_update_reason(self, variable_ids: List[int]) -> str:
def insert(self, suggested_chart_revisions: List[dict[str, Any]]) -> None:
n_before = 0
try:
with open_db() as db:
n_before = db.fetch_one("SELECT COUNT(id) FROM suggested_chart_revisions")[0]
with get_engine().connect() as con:
n_before = con.execute("SELECT COUNT(id) FROM suggested_chart_revisions").fetchone()[0] # type: ignore

res = db.fetch_many(
res = con.execute(
"""
SELECT *
FROM (
Expand All @@ -235,7 +235,7 @@ def insert(self, suggested_chart_revisions: List[dict[str, Any]]) -> None:
) as grouped
WHERE grouped.c > 1
"""
)
).fetchmany()
if len(res):
raise RuntimeError(
"Two or more suggested chart revisions with status IN "
Expand Down Expand Up @@ -267,13 +267,13 @@ def insert(self, suggested_chart_revisions: List[dict[str, Any]]) -> None:
VALUES
(%s, %s, %s, %s, %s, %s, NOW(), NOW())
"""
db.upsert_many(query, tuples)
con.execute(query, tuples)

# checks if any of the affected chartIds now has multiple
# pending suggested revisions. If so, then rejects the whole
# insert and tell the user which suggested chart revisions need
# to be approved/rejected.
res = db.fetch_many(
res = con.execute(
f"""
SELECT id, scr.chartId, c, createdAt
FROM (
Expand All @@ -291,7 +291,7 @@ def insert(self, suggested_chart_revisions: List[dict[str, Any]]) -> None:
WHERE grouped.c > 1
ORDER BY createdAt ASC
"""
)
).fetchmany()
if len(res):
df = pd.DataFrame(res, columns=["id", "chart_id", "count", "created_at"])
df["drop"] = df.groupby("chart_id")["created_at"].transform(lambda gp: gp == gp.max())
Expand Down Expand Up @@ -321,8 +321,8 @@ def insert(self, suggested_chart_revisions: List[dict[str, Any]]) -> None:
self.report_error(f"INSERT operation into `suggested_chart_revisions` cancelled. Error: {e}")
raise e
finally:
with open_db() as db:
n_after = db.fetch_one("SELECT COUNT(id) FROM suggested_chart_revisions")[0]
with get_engine().connect() as con:
n_after = con.execute("SELECT COUNT(id) FROM suggested_chart_revisions").fetchone()[0] # type: ignore

self.report_info(
f"{n_after - n_before} of {len(suggested_chart_revisions)} suggested chart revisions inserted."
Expand All @@ -343,18 +343,18 @@ def _get_charts_from_old_variables(
df_chart_dimensions: dataframe of chart_dimensions rows.
df_chart_revisions: dataframe of chart_revisions rows.
"""
with open_db() as db:
with get_engine().connect() as con:
# retrieves chart_dimensions
variable_ids = list(self.old_var_id2new_var_id.keys())
variable_ids_str = ",".join([str(_id) for _id in variable_ids])
columns = ["id", "chartId", "variableId", "property", "order"]
rows = db.fetch_many(
rows = con.execute(
f"""
SELECT {','.join([f'`{col}`' for col in columns])}
FROM chart_dimensions
WHERE variableId IN ({variable_ids_str})
"""
)
).fetchmany()
df_chart_dimensions = pd.DataFrame(rows, columns=columns)

# retrieves charts
Expand All @@ -369,40 +369,40 @@ def _get_charts_from_old_variables(
"lastEditedAt",
"publishedAt",
]
rows = db.fetch_many(
rows = con.execute(
f"""
SELECT {','.join(columns)}
FROM charts
WHERE id IN ({chart_ids_str})
"""
)
).fetchmany()
df_charts = pd.DataFrame(rows, columns=columns)

# retrieves chart_revisions
columns = ["id", "chartId", "userId", "config", "createdAt", "updatedAt"]
rows = db.fetch_many(
rows = con.execute(
f"""
SELECT {','.join(columns)}
FROM chart_revisions
WHERE chartId IN ({chart_ids_str})
"""
)
).fetchmany()
df_chart_revisions = pd.DataFrame(rows, columns=columns)
return df_charts, df_chart_dimensions, df_chart_revisions

def _get_variable_year_ranges(self) -> Dict[int, List[int]]:
with open_db() as db:
with get_engine().connect() as con:
all_var_ids = list(self.old_var_id2new_var_id.keys()) + list(self.old_var_id2new_var_id.values())
variable_ids_str = ",".join([str(_id) for _id in all_var_ids])
raise NotImplementedError("data_values was deprecated")
rows = db.fetch_many(
rows = con.execute(
f"""
SELECT variableId, MIN(year) AS minYear, MAX(year) AS maxYear
FROM data_values
WHERE variableId IN ({variable_ids_str})
GROUP BY variableId
"""
)
).fetchmany()
var_id2year_range = {}
for variable_id, min_year, max_year in rows:
var_id2year_range[variable_id] = [min_year, max_year]
Expand Down
30 changes: 15 additions & 15 deletions etl/chart_revision/v1/revision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from etl.chart_revision.v1.chart import Chart
from etl.chart_revision.v1.variables import VariablesUpdate
from etl.config import GRAPHER_USER_ID
from etl.db import get_engine, open_db
from etl.db import get_engine

log = get_logger()
# The maximum length of the suggested revision reason can't exceed the maximum length specified by the datatype "suggestedReason" in grapher.suggested_chart_revisions table.
Expand Down Expand Up @@ -341,10 +341,10 @@ def submit_revisions_to_grapher(revisions: List[ChartVariableUpdateRevision]):
"""Submit chart revisions to Grapher."""
n_before = 0
try:
with open_db() as db:
n_before = db.fetch_one("SELECT COUNT(id) FROM suggested_chart_revisions")[0]
with get_engine().connect() as con:
n_before = con.execute("SELECT COUNT(id) FROM suggested_chart_revisions").fetchone()[0] # type: ignore

res = db.fetch_many(
res = con.execute(
"""
SELECT *
FROM (
Expand All @@ -356,7 +356,7 @@ def submit_revisions_to_grapher(revisions: List[ChartVariableUpdateRevision]):
) as grouped
WHERE grouped.c > 1
"""
)
).fetchmany()
if len(res):
raise RuntimeError(
"Two or more suggested chart revisions with status IN "
Expand Down Expand Up @@ -387,13 +387,13 @@ def submit_revisions_to_grapher(revisions: List[ChartVariableUpdateRevision]):
VALUES
(%s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
"""
db.upsert_many(query, tuples)
con.execute(query, tuples)

# checks if any of the affected chartIds now has multiple
# pending suggested revisions. If so, then rejects the whole
# insert and tell the user which suggested chart revisions need
# to be approved/rejected.
res = db.fetch_many(
res = con.execute(
f"""
SELECT id, scr.chartId, c, createdAt
FROM (
Expand All @@ -411,7 +411,7 @@ def submit_revisions_to_grapher(revisions: List[ChartVariableUpdateRevision]):
WHERE grouped.c > 1
ORDER BY createdAt ASC
"""
)
).fetchmany()
if len(res):
df = pd.DataFrame(res, columns=["id", "chart_id", "count", "created_at"])
df["drop"] = df.groupby("chart_id")["created_at"].transform(lambda gp: gp == gp.max())
Expand Down Expand Up @@ -441,8 +441,8 @@ def submit_revisions_to_grapher(revisions: List[ChartVariableUpdateRevision]):
log.info(f"INSERT operation into `suggested_chart_revisions` cancelled. Error: {e}")
raise e
finally:
with open_db() as db:
n_after = db.fetch_one("SELECT COUNT(id) FROM suggested_chart_revisions")[0]
with get_engine().connect() as con:
n_after = con.execute("SELECT COUNT(id) FROM suggested_chart_revisions").fetchone()[0] # type: ignore

log.info(f"{n_after - n_before} of {len(revisions)} suggested chart revisions inserted.")

Expand All @@ -452,23 +452,23 @@ def _get_chart_update_reason(variable_ids: List[int]) -> str:
Accesses DB and finds out the name of the recently added dataset with the new variables."""
try:
with open_db() as db:
with get_engine().connect() as con:
if len(variable_ids) == 1:
results = db.fetch_many(
results = con.execute(
f"""
SELECT variables.name, datasets.name, datasets.version FROM datasets
JOIN variables ON datasets.id = variables.datasetId
WHERE variables.id IN ({variable_ids[0]})
"""
)
).fetchmany()
else:
results = db.fetch_many(
results = con.execute(
f"""
SELECT variables.name, datasets.name, datasets.version FROM datasets
JOIN variables ON datasets.id = variables.datasetId
WHERE variables.id IN {*variable_ids,}
"""
)
).fetchmany()
except Exception:
log.error(
"Problem found when accessing the DB trying to get details on the newly added variables"
Expand Down
29 changes: 0 additions & 29 deletions etl/db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import traceback
import warnings
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, cast
from urllib.parse import quote

Expand All @@ -14,7 +11,6 @@
from sqlmodel import Session

from etl import config
from etl.db_utils import DBUtils

log = structlog.get_logger()

Expand Down Expand Up @@ -59,31 +55,6 @@ def get_engine(conf: Optional[Dict[str, Any]] = None) -> Engine:
)


@contextmanager
def open_db() -> Generator[DBUtils, None, None]:
connection = None
cursor = None
try:
connection = get_connection()
connection.autocommit(False)
cursor = connection.cursor()
yield DBUtils(cursor)
connection.commit()
except Exception as e:
log.error(f"Error encountered during import: {e}")
log.error("Rolling back changes...")
if connection:
connection.rollback()
if config.DEBUG:
traceback.print_exc()
raise e
finally:
if cursor:
cursor.close()
if connection:
connection.close()


def get_dataset_id(
dataset_name: str, db_conn: Optional[MySQLdb.Connection] = None, version: Optional[str] = None
) -> Any:
Expand Down
Loading

0 comments on commit 56031cf

Please sign in to comment.