Skip to content

Commit

Permalink
WIP: Do not pool connections internally
Browse files Browse the repository at this point in the history
  • Loading branch information
index-git committed Sep 25, 2023
1 parent 85268aa commit f186066
Show file tree
Hide file tree
Showing 17 changed files with 106 additions and 213 deletions.
81 changes: 24 additions & 57 deletions src/db/util.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,44 @@
import logging
import re
from urllib import parse
import psycopg2
import psycopg2.pool
from urllib import parse
from flask import g

import crs as crs_def
from . import PG_URI_STR
from .error import Error

logger = logging.getLogger(__name__)

FLASK_CONN_CUR_KEY = f'{__name__}:CONN_CUR'

CONNECTION_POOL_DICT = {}


def create_connection_cursor(db_uri_str=None, encapsulate_exception=True):
db_uri_str = db_uri_str or PG_URI_STR
try:
connection = psycopg2.connect(db_uri_str)
connection.set_session(autocommit=True)
except BaseException as exc:
if encapsulate_exception:
raise Error(1) from exc
raise exc
cursor = connection.cursor()
return connection, cursor


def get_connection_pool(db_uri_str=None):
def get_connection_pool(db_uri_str=None, encapsulate_exception=True):
db_uri_str = db_uri_str or PG_URI_STR
connection_pool = CONNECTION_POOL_DICT.get(db_uri_str)
if not connection_pool:
db_uri_parsed = parse.urlparse(db_uri_str)
try:
connection_pool = psycopg2.pool.SimpleConnectionPool(5, 50,
user=db_uri_parsed.username,
password=db_uri_parsed.password,
host=db_uri_parsed.hostname,
port=db_uri_parsed.port,
database=db_uri_parsed.path[1:],
)
connection_pool = psycopg2.pool.ThreadedConnectionPool(5, 50,
user=db_uri_parsed.username,
password=db_uri_parsed.password,
host=db_uri_parsed.hostname,
port=db_uri_parsed.port,
database=db_uri_parsed.path[1:],
)
except BaseException as exc:
raise Error(1) from exc
if encapsulate_exception:
raise Error(1) from exc
raise exc
CONNECTION_POOL_DICT[db_uri_str] = connection_pool
return connection_pool



def get_connection_cursor(db_uri_str=None, encapsulate_exception=True):
if db_uri_str is None or db_uri_str == PG_URI_STR:
key = FLASK_CONN_CUR_KEY
if key not in g:
conn_cur = create_connection_cursor(encapsulate_exception=encapsulate_exception)
g.setdefault(key, conn_cur)
result = g.get(key)
else:
result = create_connection_cursor(db_uri_str=db_uri_str, encapsulate_exception=encapsulate_exception)
return result


def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False):
if conn_cur is None:
pool = get_connection_pool()
conn = pool.getconn()
conn.autocommit = True
cur = conn.cursor()
else:
conn, cur = conn_cur
def run_query(query, data=None, uri_str=None, encapsulate_exception=True, log_query=False):
pool = get_connection_pool(db_uri_str=uri_str, encapsulate_exception=encapsulate_exception, )
conn = pool.getconn()
conn.autocommit = True
cur = conn.cursor()
try:
if log_query:
logger.info(f"query={cur.mogrify(query, data).decode()}")
Expand All @@ -87,14 +57,11 @@ def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_q
return rows


def run_statement(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False):
if conn_cur is None:
pool = get_connection_pool()
conn = pool.getconn()
conn.autocommit = True
cur = conn.cursor()
else:
conn, cur = conn_cur
def run_statement(query, data=None, uri_str=None, encapsulate_exception=True, log_query=False):
pool = get_connection_pool(db_uri_str=uri_str, encapsulate_exception=encapsulate_exception, )
conn = pool.getconn()
conn.autocommit = True
cur = conn.cursor()
try:
if log_query:
logger.info(f"query={cur.mogrify(query, data).decode()}")
Expand Down Expand Up @@ -132,14 +99,14 @@ def get_internal_srid(crs):
return srid


def get_crs_from_srid(srid, conn_cur=None, *, use_internal_srid):
def get_crs_from_srid(srid, uri_str=None, *, use_internal_srid):
crs = next((
crs_code for crs_code, crs_item_def in crs_def.CRSDefinitions.items()
if crs_item_def.internal_srid == srid
), None) if use_internal_srid else None
if not crs:
sql = 'select auth_name, auth_srid from spatial_ref_sys where srid = %s;'
auth_name, auth_srid = run_query(sql, (srid, ), conn_cur=conn_cur)[0]
auth_name, auth_srid = run_query(sql, (srid, ), uri_str=uri_str)[0]
if auth_name or auth_srid:
crs = f'{auth_name}:{auth_srid}'
return crs
Expand Down
2 changes: 1 addition & 1 deletion src/layman/common/prime_db_schema/schema_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def ensure_schema(db_schema):
db_util.run_statement(model.CREATE_SCHEMA_SQL)
db_util.run_statement(model.setup_codelists_data())
except BaseException as exc:
db_util.run_statement(model.DROP_SCHEMA_SQL, conn_cur=db_util.get_connection_cursor())
db_util.run_statement(model.DROP_SCHEMA_SQL, )
raise exc
else:
logger.info(f"Layman DB schema already exists, schema_name={db_schema}")
Expand Down
4 changes: 1 addition & 3 deletions src/layman/geoserver_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from flask import Blueprint, g, current_app as app, request, Response

import crs as crs_def
from db import util as db_util
from geoserver.util import reset as gs_reset
from layman import authn, authz, settings, util as layman_util, LaymanError
from layman.authn import authenticate, is_user_with_name
Expand Down Expand Up @@ -89,9 +88,8 @@ def ensure_attributes_in_db(attributes_by_db):
(schema, table, attr): (workspace, layer, attr)
for workspace, layer, attr, schema, table in attr_tuples
}
conn_cur = db_util.get_connection_cursor(db_uri_str=db_uri_str)
db_attr_tuples = list(db_layman_attr_mapping.keys())
created_db_attr_tuples = db.ensure_attributes(db_attr_tuples, conn_cur)
created_db_attr_tuples = db.ensure_attributes(db_attr_tuples, db_uri_str=db_uri_str)
all_created_attr_tuples.update({db_layman_attr_mapping[a] for a in created_db_attr_tuples})
return all_created_attr_tuples

Expand Down
Loading

0 comments on commit f186066

Please sign in to comment.