diff --git a/src/db/util.py b/src/db/util.py index 6663a92e4..5b8d05889 100644 --- a/src/db/util.py +++ b/src/db/util.py @@ -1,6 +1,8 @@ import logging import re import psycopg2 +import psycopg2.pool +from urllib import parse from flask import g import crs as crs_def @@ -11,6 +13,8 @@ FLASK_CONN_CUR_KEY = f'{__name__}:CONN_CUR' +CONNECTION_POOL_DICT = {} + def create_connection_cursor(db_uri_str=None, encapsulate_exception=True): db_uri_str = db_uri_str or PG_URI_STR @@ -25,6 +29,26 @@ def create_connection_cursor(db_uri_str=None, encapsulate_exception=True): return connection, cursor +def get_connection_pool(db_uri_str=None): + db_uri_str = db_uri_str or PG_URI_STR + connection_pool = CONNECTION_POOL_DICT.get(db_uri_str) + if not connection_pool: + db_uri_parsed = parse.urlparse(db_uri_str) + try: + connection_pool = psycopg2.pool.SimpleConnectionPool(5, 50, + user=db_uri_parsed.username, + password=db_uri_parsed.password, + host=db_uri_parsed.hostname, + port=db_uri_parsed.port, + database=db_uri_parsed.path[1:], + ) + except BaseException as exc: + raise Error(1) from exc + CONNECTION_POOL_DICT[db_uri_str] = connection_pool + return connection_pool + + + def get_connection_cursor(db_uri_str=None, encapsulate_exception=True): if db_uri_str is None or db_uri_str == PG_URI_STR: key = FLASK_CONN_CUR_KEY @@ -39,8 +63,12 @@ def get_connection_cursor(db_uri_str=None, encapsulate_exception=True): def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False): if conn_cur is None: - conn_cur = get_connection_cursor() - conn, cur = conn_cur + pool = get_connection_pool() + conn = pool.getconn() + conn.autocommit = True + cur = conn.cursor() + else: + conn, cur = conn_cur try: if log_query: logger.info(f"query={cur.mogrify(query, data).decode()}") @@ -52,14 +80,21 @@ def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_q logger.error(f"run_query, query={query}, data={data}, exc={exc}") raise Error(2) from exc raise exc + finally: + if pool and conn: + pool.putconn(conn) return rows def run_statement(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False): if conn_cur is None: - conn_cur = get_connection_cursor() - conn, cur = conn_cur + pool = get_connection_pool() + conn = pool.getconn() + conn.autocommit = True + cur = conn.cursor() + else: + conn, cur = conn_cur try: if log_query: logger.info(f"query={cur.mogrify(query, data).decode()}") @@ -71,6 +106,10 @@ def run_statement(query, data=None, conn_cur=None, encapsulate_exception=True, l logger.error(f"run_query, query={query}, data={data}, exc={exc}") raise Error(2) from exc raise exc + finally: + if pool and conn: + pool.putconn(conn) + return rows