Skip to content

Commit

Permalink
WIP: Do not pool connections internally
Browse files Browse the repository at this point in the history
  • Loading branch information
index-git committed Sep 25, 2023
1 parent 7f5238a commit ed12cf0
Show file tree
Hide file tree
Showing 16 changed files with 102 additions and 199 deletions.
79 changes: 24 additions & 55 deletions src/db/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,72 +3,44 @@
import psycopg2
import psycopg2.pool
from urllib import parse
from flask import g

import crs as crs_def
from . import PG_URI_STR
from .error import Error

logger = logging.getLogger(__name__)

FLASK_CONN_CUR_KEY = f'{__name__}:CONN_CUR'

CONNECTION_POOL_DICT = {}


def create_connection_cursor(db_uri_str=None, encapsulate_exception=True):
db_uri_str = db_uri_str or PG_URI_STR
try:
connection = psycopg2.connect(db_uri_str)
connection.set_session(autocommit=True)
except BaseException as exc:
if encapsulate_exception:
raise Error(1) from exc
raise exc
cursor = connection.cursor()
return connection, cursor


def get_connection_pool(db_uri_str=None):
def get_connection_pool(db_uri_str=None, encapsulate_exception=True):
db_uri_str = db_uri_str or PG_URI_STR
connection_pool = CONNECTION_POOL_DICT.get(db_uri_str)
if not connection_pool:
db_uri_parsed = parse.urlparse(db_uri_str)
try:
connection_pool = psycopg2.pool.SimpleConnectionPool(5, 50,
user=db_uri_parsed.username,
password=db_uri_parsed.password,
host=db_uri_parsed.hostname,
port=db_uri_parsed.port,
database=db_uri_parsed.path[1:],
)
connection_pool = psycopg2.pool.ThreadedConnectionPool(5, 50,
user=db_uri_parsed.username,
password=db_uri_parsed.password,
host=db_uri_parsed.hostname,
port=db_uri_parsed.port,
database=db_uri_parsed.path[1:],
)
except BaseException as exc:
raise Error(1) from exc
if encapsulate_exception:
raise Error(1) from exc
else:
raise exc
CONNECTION_POOL_DICT[db_uri_str] = connection_pool
return connection_pool



def get_connection_cursor(db_uri_str=None, encapsulate_exception=True):
if db_uri_str is None or db_uri_str == PG_URI_STR:
key = FLASK_CONN_CUR_KEY
if key not in g:
conn_cur = create_connection_cursor(encapsulate_exception=encapsulate_exception)
g.setdefault(key, conn_cur)
result = g.get(key)
else:
result = create_connection_cursor(db_uri_str=db_uri_str, encapsulate_exception=encapsulate_exception)
return result


def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False):
if conn_cur is None:
pool = get_connection_pool()
conn = pool.getconn()
conn.autocommit = True
cur = conn.cursor()
else:
conn, cur = conn_cur
def run_query(query, data=None, uri_str=None, encapsulate_exception=True, log_query=False):
pool = get_connection_pool(db_uri_str=uri_str, encapsulate_exception=encapsulate_exception, )
conn = pool.getconn()
conn.autocommit = True
cur = conn.cursor()
try:
if log_query:
logger.info(f"query={cur.mogrify(query, data).decode()}")
Expand All @@ -87,14 +59,11 @@ def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_q
return rows


def run_statement(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False):
if conn_cur is None:
pool = get_connection_pool()
conn = pool.getconn()
conn.autocommit = True
cur = conn.cursor()
else:
conn, cur = conn_cur
def run_statement(query, data=None, uri_str=None, encapsulate_exception=True, log_query=False):
pool = get_connection_pool(db_uri_str=uri_str, encapsulate_exception=encapsulate_exception, )
conn = pool.getconn()
conn.autocommit = True
cur = conn.cursor()
try:
if log_query:
logger.info(f"query={cur.mogrify(query, data).decode()}")
Expand Down Expand Up @@ -132,14 +101,14 @@ def get_internal_srid(crs):
return srid


def get_crs_from_srid(srid, conn_cur=None, *, use_internal_srid):
def get_crs_from_srid(srid, uri_str=None, *, use_internal_srid):
crs = next((
crs_code for crs_code, crs_item_def in crs_def.CRSDefinitions.items()
if crs_item_def.internal_srid == srid
), None) if use_internal_srid else None
if not crs:
sql = 'select auth_name, auth_srid from spatial_ref_sys where srid = %s;'
auth_name, auth_srid = run_query(sql, (srid, ), conn_cur=conn_cur)[0]
auth_name, auth_srid = run_query(sql, (srid, ), uri_str=uri_str)[0]
if auth_name or auth_srid:
crs = f'{auth_name}:{auth_srid}'
return crs
Expand Down
2 changes: 1 addition & 1 deletion src/layman/common/prime_db_schema/schema_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def ensure_schema(db_schema):
db_util.run_statement(model.CREATE_SCHEMA_SQL)
db_util.run_statement(model.setup_codelists_data())
except BaseException as exc:
db_util.run_statement(model.DROP_SCHEMA_SQL, conn_cur=db_util.get_connection_cursor())
db_util.run_statement(model.DROP_SCHEMA_SQL, )
raise exc
else:
logger.info(f"Layman DB schema already exists, schema_name={db_schema}")
Expand Down
Loading

0 comments on commit ed12cf0

Please sign in to comment.