diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f4798dfc..055f46c71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ - [#880](https://github.com/LayerManager/layman/issues/880) Use Docker Compose v2 (`docker compose`) in Makefile without `compatibility` flag and remove `Makefile_docker-compose_v1` file. Docker containers are named according to Docker Compose v2 and may have different name after upgrade. - [#765](https://github.com/LayerManager/layman/issues/765) Stop saving OAuth2 claims in filesystem, use prime DB schema only. - [#893](https://github.com/LayerManager/layman/issues/893) It is possible to specify logging level by new environment variable [LAYMAN_LOGLEVEL](doc/env-settings.md#LAYMAN_LOGLEVEL). Default level is `INFO`. +- Use `psycopg2.pool.ThreadedConnectionPool` to share DB connections. - Add new test Python dependency: - jsonpath-ng 1.6.0 - Upgrade Python dependencies diff --git a/src/db/util.py b/src/db/util.py index 6663a92e4..78b91689c 100644 --- a/src/db/util.py +++ b/src/db/util.py @@ -1,7 +1,8 @@ import logging import re +from urllib import parse import psycopg2 -from flask import g +import psycopg2.pool import crs as crs_def from . import PG_URI_STR @@ -9,38 +10,35 @@ logger = logging.getLogger(__name__) -FLASK_CONN_CUR_KEY = f'{__name__}:CONN_CUR' +CONNECTION_POOL_DICT = {} -def create_connection_cursor(db_uri_str=None, encapsulate_exception=True): +def get_connection_pool(db_uri_str=None, encapsulate_exception=True): db_uri_str = db_uri_str or PG_URI_STR - try: - connection = psycopg2.connect(db_uri_str) - connection.set_session(autocommit=True) - except BaseException as exc: - if encapsulate_exception: - raise Error(1) from exc - raise exc - cursor = connection.cursor() - return connection, cursor - - -def get_connection_cursor(db_uri_str=None, encapsulate_exception=True): - if db_uri_str is None or db_uri_str == PG_URI_STR: - key = FLASK_CONN_CUR_KEY - if key not in g: - conn_cur = create_connection_cursor(encapsulate_exception=encapsulate_exception) - g.setdefault(key, conn_cur) - result = g.get(key) - else: - result = create_connection_cursor(db_uri_str=db_uri_str, encapsulate_exception=encapsulate_exception) - return result - - -def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False): - if conn_cur is None: - conn_cur = get_connection_cursor() - conn, cur = conn_cur + connection_pool = CONNECTION_POOL_DICT.get(db_uri_str) + if not connection_pool: + db_uri_parsed = parse.urlparse(db_uri_str) + try: + connection_pool = psycopg2.pool.ThreadedConnectionPool(3, 20, + user=db_uri_parsed.username, + password=db_uri_parsed.password, + host=db_uri_parsed.hostname, + port=db_uri_parsed.port, + database=db_uri_parsed.path[1:], + ) + except BaseException as exc: + if encapsulate_exception: + raise Error(1) from exc + raise exc + CONNECTION_POOL_DICT[db_uri_str] = connection_pool + return connection_pool + + +def run_query(query, data=None, uri_str=None, encapsulate_exception=True, log_query=False): + pool = get_connection_pool(db_uri_str=uri_str, encapsulate_exception=encapsulate_exception, ) + conn = pool.getconn() + conn.autocommit = True + cur = conn.cursor() try: if log_query: logger.info(f"query={cur.mogrify(query, data).decode()}") @@ -52,14 +50,17 @@ def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_q logger.error(f"run_query, query={query}, data={data}, exc={exc}") raise Error(2) from exc raise exc + finally: + pool.putconn(conn) return rows -def run_statement(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False): - if conn_cur is None: - conn_cur = get_connection_cursor() - conn, cur = conn_cur +def run_statement(query, data=None, uri_str=None, encapsulate_exception=True, log_query=False): + pool = get_connection_pool(db_uri_str=uri_str, encapsulate_exception=encapsulate_exception, ) + conn = pool.getconn() + conn.autocommit = True + cur = conn.cursor() try: if log_query: logger.info(f"query={cur.mogrify(query, data).decode()}") @@ -71,6 +72,9 @@ def run_statement(query, data=None, conn_cur=None, encapsulate_exception=True, l logger.error(f"run_query, query={query}, data={data}, exc={exc}") raise Error(2) from exc raise exc + finally: + pool.putconn(conn) + return rows @@ -93,14 +97,14 @@ def get_internal_srid(crs): return srid -def get_crs_from_srid(srid, conn_cur=None, *, use_internal_srid): +def get_crs_from_srid(srid, uri_str=None, *, use_internal_srid): crs = next(( crs_code for crs_code, crs_item_def in crs_def.CRSDefinitions.items() if crs_item_def.internal_srid == srid ), None) if use_internal_srid else None if not crs: sql = 'select auth_name, auth_srid from spatial_ref_sys where srid = %s;' - auth_name, auth_srid = run_query(sql, (srid, ), conn_cur=conn_cur)[0] + auth_name, auth_srid = run_query(sql, (srid, ), uri_str=uri_str)[0] if auth_name or auth_srid: crs = f'{auth_name}:{auth_srid}' return crs diff --git a/src/layman/common/prime_db_schema/schema_initialization.py b/src/layman/common/prime_db_schema/schema_initialization.py index 36e3b240c..e27f98c58 100644 --- a/src/layman/common/prime_db_schema/schema_initialization.py +++ b/src/layman/common/prime_db_schema/schema_initialization.py @@ -22,7 +22,7 @@ def ensure_schema(db_schema): db_util.run_statement(model.CREATE_SCHEMA_SQL) db_util.run_statement(model.setup_codelists_data()) except BaseException as exc: - db_util.run_statement(model.DROP_SCHEMA_SQL, conn_cur=db_util.get_connection_cursor()) + db_util.run_statement(model.DROP_SCHEMA_SQL, ) raise exc else: logger.info(f"Layman DB schema already exists, schema_name={db_schema}") diff --git a/src/layman/geoserver_proxy.py b/src/layman/geoserver_proxy.py index 38b98cf83..75fa8b3aa 100644 --- a/src/layman/geoserver_proxy.py +++ b/src/layman/geoserver_proxy.py @@ -10,7 +10,6 @@ from flask import Blueprint, g, current_app as app, request, Response import crs as crs_def -from db import util as db_util from geoserver.util import reset as gs_reset from layman import authn, authz, settings, util as layman_util, LaymanError from layman.authn import authenticate, is_user_with_name @@ -89,9 +88,8 @@ def ensure_attributes_in_db(attributes_by_db): (schema, table, attr): (workspace, layer, attr) for workspace, layer, attr, schema, table in attr_tuples } - conn_cur = db_util.get_connection_cursor(db_uri_str=db_uri_str) db_attr_tuples = list(db_layman_attr_mapping.keys()) - created_db_attr_tuples = db.ensure_attributes(db_attr_tuples, conn_cur) + created_db_attr_tuples = db.ensure_attributes(db_attr_tuples, db_uri_str=db_uri_str) all_created_attr_tuples.update({db_layman_attr_mapping[a] for a in created_db_attr_tuples}) return all_created_attr_tuples diff --git a/src/layman/layer/db/__init__.py b/src/layman/layer/db/__init__.py index 210e4fda7..75ffa835c 100644 --- a/src/layman/layer/db/__init__.py +++ b/src/layman/layer/db/__init__.py @@ -29,12 +29,8 @@ def get_internal_table_name(workspace, layer): return table_name -def get_workspaces(conn_cur=None): +def get_workspaces(): """Returns workspaces from internal DB only""" - if conn_cur is None: - conn_cur = db_util.get_connection_cursor() - _, cur = conn_cur - query = sql.SQL("""select schema_name from information_schema.schemata where schema_name NOT IN ({schemas}) AND schema_owner = {layman_pg_user}""").format( @@ -42,11 +38,10 @@ def get_workspaces(conn_cur=None): layman_pg_user=sql.Literal(settings.LAYMAN_PG_USER), ) try: - cur.execute(query) + rows = db_util.run_query(query) except BaseException as exc: logger.error(f'get_workspaces ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() return [ r[0] for r in rows ] @@ -61,36 +56,26 @@ def check_workspace_name(workspace): raise LaymanError(35, {'reserved_by': __name__, 'schema': workspace}) -def ensure_workspace(workspace, conn_cur=None): +def ensure_workspace(workspace, ): """Ensures workspace in internal DB only""" - if conn_cur is None: - conn_cur = db_util.get_connection_cursor() - conn, cur = conn_cur - statement = sql.SQL("""CREATE SCHEMA IF NOT EXISTS {schema} AUTHORIZATION {user}""").format( schema=sql.Identifier(workspace), user=sql.Identifier(settings.LAYMAN_PG_USER), ) try: - cur.execute(statement) - conn.commit() + db_util.run_statement(statement) except BaseException as exc: logger.error(f'ensure_workspace ERROR') raise LaymanError(7) from exc -def delete_workspace(workspace, conn_cur=None): +def delete_workspace(workspace, ): """Deletes workspace from internal DB only""" - if conn_cur is None: - conn_cur = db_util.get_connection_cursor() - conn, cur = conn_cur - statement = sql.SQL("""DROP SCHEMA IF EXISTS {schema} RESTRICT""").format( schema=sql.Identifier(workspace), ) try: - cur.execute(statement, (workspace, )) - conn.commit() + db_util.run_statement(statement, (workspace, )) except BaseException as exc: logger.error(f'delete_workspace ERROR') raise LaymanError(7) from exc @@ -203,9 +188,7 @@ def import_layer_vector_file_to_internal_table_async(schema, table_name, main_fi return process -def get_text_column_names(schema, table_name, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - +def get_text_column_names(schema, table_name, uri_str=None): statement = """ SELECT column_name FROM information_schema.columns @@ -214,25 +197,23 @@ def get_text_column_names(schema, table_name, conn_cur=None): AND data_type IN ('character varying', 'varchar', 'character', 'char', 'text') """ try: - cur.execute(statement, (schema, table_name)) + rows = db_util.run_query(statement, (schema, table_name), uri_str=uri_str) except BaseException as exc: logger.error(f'get_text_column_names ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() return [r[0] for r in rows] -def get_internal_table_all_column_names(workspace, layername, conn_cur=None): +def get_internal_table_all_column_names(workspace, layername, ): table_name = get_internal_table_name(workspace, layername) - return get_all_table_column_names(workspace, table_name, conn_cur=conn_cur) + return get_all_table_column_names(workspace, table_name, ) -def get_all_table_column_names(schema, table_name, conn_cur=None): - return [col.name for col in get_all_column_infos(schema, table_name, conn_cur=conn_cur)] +def get_all_table_column_names(schema, table_name, uri_str=None): + return [col.name for col in get_all_column_infos(schema, table_name, uri_str=uri_str)] -def get_all_column_infos(schema, table_name, *, conn_cur=None, omit_geometry_columns=False): - _, cur = conn_cur or db_util.get_connection_cursor() +def get_all_column_infos(schema, table_name, *, uri_str=None, omit_geometry_columns=False): query = """ SELECT inf.column_name, inf.data_type FROM information_schema.columns inf @@ -247,17 +228,14 @@ def get_all_column_infos(schema, table_name, *, conn_cur=None, omit_geometry_col query += " AND gc.f_geometry_column is null" try: - cur.execute(query, (schema, table_name)) + rows = db_util.run_query(query, (schema, table_name), uri_str=uri_str) except BaseException as exc: logger.error(f'get_all_column_names ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() return [ColumnInfo(name=r[0], data_type=r[1]) for r in rows] -def get_number_of_features(schema, table_name, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - +def get_number_of_features(schema, table_name, uri_str=None): statement = sql.SQL(""" select count(*) from {table} @@ -265,20 +243,18 @@ def get_number_of_features(schema, table_name, conn_cur=None): table=sql.Identifier(schema, table_name), ) try: - cur.execute(statement) + rows = db_util.run_query(statement, uri_str=uri_str) except BaseException as exc: logger.error(f'get_number_of_features ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() return rows[0][0] -def get_text_data(schema, table_name, primary_key, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - col_names = get_text_column_names(schema, table_name, conn_cur=conn_cur) +def get_text_data(schema, table_name, primary_key, uri_str=None): + col_names = get_text_column_names(schema, table_name, uri_str=uri_str) if len(col_names) == 0: return [], 0 - num_features = get_number_of_features(schema, table_name, conn_cur=conn_cur) + num_features = get_number_of_features(schema, table_name, uri_str=uri_str) if num_features == 0: return [], 0 limit = max(100, num_features // 10) @@ -294,11 +270,10 @@ def get_text_data(schema, table_name, primary_key, conn_cur=None): limit=sql.Literal(limit), ) try: - cur.execute(statement) + rows = db_util.run_query(statement, uri_str=uri_str) except BaseException as exc: logger.error(f'get_text_data ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() col_texts = defaultdict(list) for row in rows: for idx, col_name in enumerate(col_names): @@ -313,8 +288,8 @@ def get_text_data(schema, table_name, primary_key, conn_cur=None): return col_texts, limit -def get_text_languages(schema, table_name, primary_key, *, conn_cur=None): - texts, num_rows = get_text_data(schema, table_name, primary_key, conn_cur) +def get_text_languages(schema, table_name, primary_key, *, uri_str=None): + texts, num_rows = get_text_data(schema, table_name, primary_key, uri_str) all_langs = set() for text in texts: # skip short texts @@ -424,20 +399,17 @@ def get_most_frequent_lower_distance_query(schema, table_name, primary_key, geom return query -def get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_column, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - +def get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_column, uri_str=None): query = get_most_frequent_lower_distance_query(schema, table_name, primary_key, geometry_column) # print(f"\nget_most_frequent_lower_distance v1\nusername={username}, layername={layername}") # print(query) try: - cur.execute(query) + rows = db_util.run_query(query, uri_str=uri_str) except BaseException as exc: logger.error(f'get_most_frequent_lower_distance ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() # for row in rows: # print(f"row={row}") result = None @@ -466,8 +438,8 @@ def get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_c ] -def guess_scale_denominator(schema, table_name, primary_key, geometry_column, *, conn_cur=None): - distance = get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_column, conn_cur=conn_cur) +def guess_scale_denominator(schema, table_name, primary_key, geometry_column, *, uri_str=None): + distance = get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_column, uri_str=uri_str) log_sd_list = [math.log10(sd) for sd in SCALE_DENOMINATORS] if distance is not None: coef = 2000 if distance > 100 else 1000 @@ -480,8 +452,7 @@ def guess_scale_denominator(schema, table_name, primary_key, geometry_column, *, return scale_denominator -def create_string_attributes(attribute_tuples, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() +def create_string_attributes(attribute_tuples, uri_str=None): query = sql.SQL('{alters} \n COMMIT;').format( alters=sql.SQL('\n').join( [sql.SQL("""ALTER TABLE {table} ADD COLUMN {fattrname} VARCHAR(1024);""").format( @@ -491,7 +462,7 @@ def create_string_attributes(attribute_tuples, conn_cur=None): ) ) try: - cur.execute(query) + db_util.run_statement(query, uri_str=uri_str, encapsulate_exception=False) except InsufficientPrivilege as exc: raise LaymanError(7, data={ 'reason': 'Insufficient privilege', @@ -501,9 +472,7 @@ def create_string_attributes(attribute_tuples, conn_cur=None): raise LaymanError(7) from exc -def get_missing_attributes(attribute_tuples, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - +def get_missing_attributes(attribute_tuples, uri_str=None): # Find all foursomes which do not already exist query = sql.SQL("""select attribs.* from ({selects}) attribs left join @@ -520,13 +489,12 @@ def get_missing_attributes(attribute_tuples, conn_cur=None): try: if attribute_tuples: - cur.execute(query) + rows = db_util.run_query(query, uri_str=uri_str) except BaseException as exc: logger.error(f'get_missing_attributes ERROR') raise LaymanError(7) from exc missing_attributes = set() - rows = cur.fetchall() for row in rows: missing_attributes.add((row[0], row[1], @@ -534,8 +502,8 @@ def get_missing_attributes(attribute_tuples, conn_cur=None): return missing_attributes -def ensure_attributes(attribute_tuples, conn_cur): - missing_attributes = get_missing_attributes(attribute_tuples, conn_cur) +def ensure_attributes(attribute_tuples, db_uri_str): + missing_attributes = get_missing_attributes(attribute_tuples, db_uri_str) if missing_attributes: dangerous_attribute_names = { a for _, _, a in missing_attributes @@ -546,11 +514,11 @@ def ensure_attributes(attribute_tuples, conn_cur): 'expected': r'Attribute names matching regex ^[a-zA-Z_][a-zA-Z_0-9]*$', 'found': sorted(dangerous_attribute_names), }) - create_string_attributes(missing_attributes, conn_cur) + create_string_attributes(missing_attributes, db_uri_str) return missing_attributes -def get_bbox(schema, table_name, conn_cur=None, column=settings.OGR_DEFAULT_GEOMETRY_COLUMN): +def get_bbox(schema, table_name, uri_str=None, column=settings.OGR_DEFAULT_GEOMETRY_COLUMN): query = sql.SQL(''' with tmp as (select ST_Extent(l.{column}) as bbox from {table} l @@ -564,24 +532,23 @@ def get_bbox(schema, table_name, conn_cur=None, column=settings.OGR_DEFAULT_GEOM table=sql.Identifier(schema, table_name), column=sql.Identifier(column), ) - result = db_util.run_query(query, conn_cur=conn_cur)[0] + result = db_util.run_query(query, uri_str=uri_str)[0] return result -def get_table_crs(schema, table_name, conn_cur=None, column=settings.OGR_DEFAULT_GEOMETRY_COLUMN, *, use_internal_srid): - srid = get_column_srid(schema, table_name, column, conn_cur=conn_cur) - crs = db_util.get_crs_from_srid(srid, conn_cur, use_internal_srid=use_internal_srid) +def get_table_crs(schema, table_name, uri_str=None, column=settings.OGR_DEFAULT_GEOMETRY_COLUMN, *, use_internal_srid): + srid = get_column_srid(schema, table_name, column, uri_str=uri_str) + crs = db_util.get_crs_from_srid(srid, uri_str, use_internal_srid=use_internal_srid) return crs -def get_column_srid(schema, table, column, *, conn_cur=None): +def get_column_srid(schema, table, column, *, uri_str=None): query = 'select Find_SRID(%s, %s, %s);' - srid = db_util.run_query(query, (schema, table, column), conn_cur=conn_cur)[0][0] + srid = db_util.run_query(query, (schema, table, column), uri_str=uri_str)[0][0] return srid -def get_geometry_types(schema, table_name, *, column_name=settings.OGR_DEFAULT_GEOMETRY_COLUMN, conn_cur=None): - conn, cur = conn_cur or db_util.get_connection_cursor() +def get_geometry_types(schema, table_name, *, column_name=settings.OGR_DEFAULT_GEOMETRY_COLUMN, uri_str=None): query = sql.SQL(""" select distinct ST_GeometryType({column}) as geometry_type_name from {table} @@ -590,11 +557,9 @@ def get_geometry_types(schema, table_name, *, column_name=settings.OGR_DEFAULT_G column=sql.Identifier(column_name), ) try: - cur.execute(query) + rows = db_util.run_query(query, uri_str=uri_str) except BaseException as exc: logger.error(f'get_geometry_types ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() - conn.commit() result = [row[0] for row in rows] return result diff --git a/src/layman/layer/db/table.py b/src/layman/layer/db/table.py index 0d5235993..40baeb05d 100644 --- a/src/layman/layer/db/table.py +++ b/src/layman/layer/db/table.py @@ -23,32 +23,18 @@ def get_layer_info(workspace, layername,): result = {} if table_uri: if layer_info['original_data_source'] == settings.EnumOriginalDataSource.FILE.value: - conn_cur = db_util.get_connection_cursor() + db_uri_str = None else: - try: - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str,) - except BaseException: - result['db'] = { - 'schema': table_uri.schema, - 'table': table_uri.table, - 'geo_column': table_uri.geo_column, - 'external_uri': layer_util.redact_uri(table_uri.db_uri_str), - 'status': 'NOT_AVAILABLE', - 'error': 'Cannot connect to DB.', - } - return result - - _, cur = conn_cur + db_uri_str = table_uri.db_uri_str try: - cur.execute(f""" + rows = db_util.run_query(f""" SELECT schemaname, tablename, tableowner FROM pg_tables WHERE schemaname = %s AND tablename = %s - """, (table_uri.schema, table_uri.table, )) + """, (table_uri.schema, table_uri.table, ), uri_str=db_uri_str) except BaseException as exc: raise LaymanError(7) from exc - rows = cur.fetchall() if len(rows) > 0: result['db'] = { 'schema': table_uri.schema, @@ -70,26 +56,22 @@ def get_layer_info(workspace, layername,): return result -def delete_layer(workspace, layername, conn_cur=None): +def delete_layer(workspace, layername, ): """Deletes table from internal DB only""" table_name = get_internal_table_name(workspace, layername) if table_name: - if conn_cur is None: - conn_cur = db_util.get_connection_cursor() - conn, cur = conn_cur query = sql.SQL(""" DROP TABLE IF EXISTS {table} CASCADE """).format( table=sql.Identifier(workspace, table_name), ) try: - cur.execute(query) - conn.commit() + db_util.run_statement(query) except BaseException as exc: raise LaymanError(7)from exc -def set_internal_table_layer_srid(schema, table_name, srid, *, conn_cur=None): +def set_internal_table_layer_srid(schema, table_name, srid, ): query = '''SELECT UpdateGeometrySRID(%s, %s, %s, %s);''' params = (schema, table_name, settings.OGR_DEFAULT_GEOMETRY_COLUMN, srid) - db_util.run_query(query, params, conn_cur=conn_cur) + db_util.run_query(query, params) diff --git a/src/layman/layer/micka/csw.py b/src/layman/layer/micka/csw.py index 5d1c07e3f..32c42ee4b 100644 --- a/src/layman/layer/micka/csw.py +++ b/src/layman/layer/micka/csw.py @@ -8,7 +8,6 @@ from flask import current_app import crs as crs_def -from db import util as db_util from layman.common.filesystem.uuid import get_publication_uuid_file from layman.common.micka import util as common_util, requests as micka_requests from layman.common import language as common_language, empty_method, empty_method_returns_none, bbox as bbox_util @@ -138,15 +137,14 @@ def get_template_path_and_values(workspace, layername, http_method): if geodata_type == settings.GEODATA_TYPE_VECTOR: table_uri = publ_info['_table_uri'] table_name = table_uri.table - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str) try: languages = db.get_text_languages(table_uri.schema, table_name, table_uri.primary_key_column, - conn_cur=conn_cur) + uri_str=table_uri.db_uri_str) except LaymanError: languages = [] try: scale_denominator = db.guess_scale_denominator(table_uri.schema, table_name, table_uri.primary_key_column, - table_uri.geo_column, conn_cur=conn_cur) + table_uri.geo_column, uri_str=table_uri.db_uri_str) except LaymanError: scale_denominator = None spatial_resolution = { diff --git a/src/layman/layer/prime_db_schema/file_data_tasks.py b/src/layman/layer/prime_db_schema/file_data_tasks.py index e8c52fbef..bc8ec3540 100644 --- a/src/layman/layer/prime_db_schema/file_data_tasks.py +++ b/src/layman/layer/prime_db_schema/file_data_tasks.py @@ -1,6 +1,5 @@ from celery.utils.log import get_task_logger -from db import util as db_util from layman.celery import AbortedException from layman import celery_app, util as layman_util, settings from .. import LAYER_TYPE @@ -30,8 +29,7 @@ def patch_after_feature_change( assert geodata_type == settings.GEODATA_TYPE_VECTOR table_uri = info['_table_uri'] - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str) - bbox = db_get_bbox(table_uri.schema, table_uri.table, conn_cur=conn_cur, column=table_uri.geo_column) + bbox = db_get_bbox(table_uri.schema, table_uri.table, uri_str=table_uri.db_uri_str, column=table_uri.geo_column) if self.is_aborted(): raise AbortedException diff --git a/src/layman/layer/prime_db_schema/tasks.py b/src/layman/layer/prime_db_schema/tasks.py index 710787100..df22680f0 100644 --- a/src/layman/layer/prime_db_schema/tasks.py +++ b/src/layman/layer/prime_db_schema/tasks.py @@ -1,6 +1,5 @@ from celery.utils.log import get_task_logger -from db import util as db_util from layman.celery import AbortedException from layman.common import empty_method_returns_true from layman.common.prime_db_schema import publications @@ -43,9 +42,8 @@ def refresh_file_data( # because for compressed files sent with chunks file_type would be UNKNOWN and table_uri not set publ_info = layman_util.get_publication_info(username, LAYER_TYPE, layername, context={'keys': ['table_uri']}) table_uri = publ_info['_table_uri'] - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str) - bbox = db_get_bbox(table_uri.schema, table_uri.table, conn_cur=conn_cur, column=table_uri.geo_column) - crs = get_table_crs(table_uri.schema, table_uri.table, conn_cur=conn_cur, column=table_uri.geo_column, use_internal_srid=True) + bbox = db_get_bbox(table_uri.schema, table_uri.table, uri_str=table_uri.db_uri_str, column=table_uri.geo_column) + crs = get_table_crs(table_uri.schema, table_uri.table, uri_str=table_uri.db_uri_str, column=table_uri.geo_column, use_internal_srid=True) elif geodata_type == settings.GEODATA_TYPE_RASTER: bbox = gdal_get_bbox(username, layername) crs = gdal_get_crs(username, layername) diff --git a/src/layman/layer/qgis/wms.py b/src/layman/layer/qgis/wms.py index cce75d718..705cf3033 100644 --- a/src/layman/layer/qgis/wms.py +++ b/src/layman/layer/qgis/wms.py @@ -2,7 +2,6 @@ from owslib.wms import WebMapService import crs as crs_def -from db import util as db_util from layman import patch_mode, settings, util as layman_util from layman.common import bbox as bbox_util, empty_method, empty_method_returns_none, empty_method_returns_dict from . import util, LAYER_TYPE @@ -71,15 +70,14 @@ def save_qgs_file(workspace, layer): db_schema = table_uri.schema layer_bbox = layer_bbox if not bbox_util.is_empty(layer_bbox) else crs_def.CRSDefinitions[crs].default_bbox qml = util.get_original_style_xml(workspace, layer) - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str) - db_types = db.get_geometry_types(db_schema, table_name, column_name=table_uri.geo_column, conn_cur=conn_cur) + db_types = db.get_geometry_types(db_schema, table_name, column_name=table_uri.geo_column, uri_str=table_uri.db_uri_str) qml_geometry = util.get_geometry_from_qml_and_db_types(qml, db_types) db_cols = [ - col for col in db.get_all_column_infos(db_schema, table_name, conn_cur=conn_cur, omit_geometry_columns=True) + col for col in db.get_all_column_infos(db_schema, table_name, uri_str=table_uri.db_uri_str, omit_geometry_columns=True) if col.name != table_uri.primary_key_column ] source_type = util.get_source_type(db_types, qml_geometry) - column_srid = db.get_column_srid(db_schema, table_name, table_uri.geo_column, conn_cur=conn_cur) + column_srid = db.get_column_srid(db_schema, table_name, table_uri.geo_column, uri_str=table_uri.db_uri_str) layer_qml = util.fill_layer_template(layer, uuid, layer_bbox, crs, qml, source_type, db_cols, table_uri, column_srid, db_types) qgs_str = util.fill_project_template(layer, uuid, layer_qml, crs, settings.LAYMAN_OUTPUT_SRS_LIST, diff --git a/src/layman/layer/util.py b/src/layman/layer/util.py index d20bda218..4351f699b 100644 --- a/src/layman/layer/util.py +++ b/src/layman/layer/util.py @@ -290,7 +290,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): }) try: - conn_cur = db_util.get_connection_cursor(db_uri_str, encapsulate_exception=False) + db_util.get_connection_pool(db_uri_str=db_uri_str, encapsulate_exception=False) except psycopg2.OperationalError as exc: raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -304,7 +304,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): if not geo_column: query = f'''select f_geometry_column from geometry_columns where f_table_schema = %s and f_table_name = %s order by f_geometry_column asc''' - query_res = db_util.run_query(query, (schema, table), conn_cur=conn_cur) + query_res = db_util.run_query(query, (schema, table), uri_str=db_uri_str, ) if len(query_res) == 0: raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -332,10 +332,10 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): }) query = f'''select count(*) from information_schema.tables WHERE table_schema=%s and table_name=%s''' - query_res = db_util.run_query(query, (schema, table,), conn_cur=conn_cur) + query_res = db_util.run_query(query, (schema, table,), uri_str=db_uri_str, ) if not query_res[0][0]: query = f'''select table_schema, table_name from information_schema.tables WHERE lower(table_schema)=lower(%s) and lower(table_name)=lower(%s)''' - query_res = db_util.run_query(query, (schema, table,), conn_cur=conn_cur) + query_res = db_util.run_query(query, (schema, table,), uri_str=db_uri_str, ) suggestion = f" Did you mean \"{query_res[0][0]}\".\"{query_res[0][1]}\"?" if query_res else '' raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -349,7 +349,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): }) query = f'''select count(*) from geometry_columns where f_table_schema = %s and f_table_name = %s and f_geometry_column = %s''' - query_res = db_util.run_query(query, (schema, table, geo_column), conn_cur=conn_cur) + query_res = db_util.run_query(query, (schema, table, geo_column), uri_str=db_uri_str, ) if not query_res[0][0]: raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -363,7 +363,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): } }) - crs = get_table_crs(schema, table, conn_cur=conn_cur, column=geo_column, use_internal_srid=False) + crs = get_table_crs(schema, table, uri_str=db_uri_str, column=geo_column, use_internal_srid=False) if crs not in settings.INPUT_SRS_LIST: raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -386,7 +386,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): idx.indisprimary AND cls.relname = %s AND nspace.nspname = %s''' - query_res = db_util.run_query(query, (table, schema), conn_cur=conn_cur, log_query=True) + query_res = db_util.run_query(query, (table, schema), uri_str=db_uri_str, log_query=True) primary_key_columns = [r[0] for r in query_res] if len(query_res) == 0: raise LaymanError(2, { @@ -413,7 +413,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): } }) - column_names = get_all_table_column_names(schema, table, conn_cur=conn_cur) + column_names = get_all_table_column_names(schema, table, uri_str=db_uri_str) unsafe_column_names = [c for c in column_names if not re.match(SAFE_PG_IDENTIFIER_PATTERN, c)] if unsafe_column_names: raise LaymanError(2, { diff --git a/test_tools/external_db.py b/test_tools/external_db.py index 03890ff82..a08f3d3c7 100644 --- a/test_tools/external_db.py +++ b/test_tools/external_db.py @@ -24,23 +24,21 @@ def ensure_db(): with app.app_context(): db_util.run_statement(statement) - conn_cur = db_util.get_connection_cursor(URI_STR) statement = f""" CREATE USER {READ_ONLY_USER} WITH PASSWORD '{READ_ONLY_PASSWORD}'; GRANT CONNECT ON DATABASE {EXTERNAL_DB_NAME} TO {READ_ONLY_USER}; ALTER DEFAULT PRIVILEGES GRANT SELECT ON TABLES TO {READ_ONLY_USER}; """ - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR) yield def ensure_schema(schema): - conn_cur = db_util.get_connection_cursor(URI_STR) statement = sql.SQL(f'CREATE SCHEMA IF NOT EXISTS {{schema}} AUTHORIZATION {settings.LAYMAN_PG_USER}').format( schema=sql.Identifier(schema), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR) def ensure_table(schema, name, geo_column, *, primary_key_columns=None, other_columns=None, srid=4326): @@ -71,8 +69,7 @@ def ensure_table(schema, name, geo_column, *, primary_key_columns=None, other_co table=sql.Identifier(schema, name), columns=sql.SQL(',').join(columns), ) - conn_cur = db_util.get_connection_cursor(URI_STR) - db_util.run_statement(statement, conn_cur=conn_cur, log_query=True) + db_util.run_statement(statement, uri_str=URI_STR, log_query=True) def import_table(input_file_path, *, table=None, schema='public', geo_column=settings.OGR_DEFAULT_GEOMETRY_COLUMN, @@ -110,27 +107,25 @@ def import_table(input_file_path, *, table=None, schema='public', geo_column=set assert return_code == 0 and not stdout and not stderr, f"return_code={return_code}, stdout={stdout}, stderr={stderr}" if primary_key_column is None: - conn_cur = db_util.get_connection_cursor(URI_STR) statement = sql.SQL("alter table {table} drop column {primary_key}").format( table=sql.Identifier(schema, table), primary_key=sql.Identifier(primary_key_to_later_drop), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR, ) if additional_geo_column: - conn_cur = db_util.get_connection_cursor(URI_STR) statement = sql.SQL("alter table {table} add column {geo_column_2} GEOMETRY").format( table=sql.Identifier(schema, table), geo_column_2=sql.Identifier(additional_geo_column), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR, ) statement = sql.SQL("update {table} set {geo_column_2} = st_buffer({geo_column}, 15)").format( table=sql.Identifier(schema, table), geo_column_2=sql.Identifier(additional_geo_column), geo_column=sql.Identifier(geo_column), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR, ) return schema, table @@ -140,5 +135,4 @@ def drop_table(schema, name, *, if_exists=False): statement = sql.SQL(f'drop table {if_exists_str} {{table}}').format( table=sql.Identifier(schema, name) ) - conn_cur = db_util.get_connection_cursor(URI_STR) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR, ) diff --git a/tests/asserts/final/publication/internal.py b/tests/asserts/final/publication/internal.py index ffd154ecd..99b38aa9a 100644 --- a/tests/asserts/final/publication/internal.py +++ b/tests/asserts/final/publication/internal.py @@ -438,7 +438,7 @@ def point_coordinates(workspace, publ_type, name, *, point_id, crs, exp_coordina ) with app.app_context(): to_srid = db_util.get_internal_srid(crs) - coordinates = db_util.run_query(query, (to_srid, point_id), conn_cur=db_util.get_connection_cursor(table_uri.db_uri_str)) + coordinates = db_util.run_query(query, (to_srid, point_id), uri_str=table_uri.db_uri_str) assert len(coordinates) == 1, coordinates coordinates = coordinates[0] diff --git a/tests/dynamic_data/publications/layer_external_db/edge_db_username_and_password.py b/tests/dynamic_data/publications/layer_external_db/edge_db_username_and_password.py index 703f3b9f5..90661e049 100644 --- a/tests/dynamic_data/publications/layer_external_db/edge_db_username_and_password.py +++ b/tests/dynamic_data/publications/layer_external_db/edge_db_username_and_password.py @@ -55,7 +55,6 @@ class TestEdge(base_test.TestSingleRestPublication): def test_layer(self, layer: Publication, rest_method, rest_args, ): """Parametrized using pytest_generate_tests""" - conn_cur = db_util.get_connection_cursor(external_db.URI_STR) statement = sql.SQL(f""" DO $$BEGIN @@ -77,7 +76,7 @@ def test_layer(self, layer: Publication, rest_method, rest_args, ): schema=sql.Identifier(SCHEMA), table=sql.Identifier(TABLE), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=external_db.URI_STR) # publish layer from external DB table rest_method.fn(layer, args=rest_args) diff --git a/tests/dynamic_data/publications/layer_external_db/external_db_test.py b/tests/dynamic_data/publications/layer_external_db/external_db_test.py index 310057d4d..47b67365b 100644 --- a/tests/dynamic_data/publications/layer_external_db/external_db_test.py +++ b/tests/dynamic_data/publications/layer_external_db/external_db_test.py @@ -219,9 +219,8 @@ def test_layer(self, layer: Publication, rest_method, rest_args, params): 'primary_key_column': primary_key_column, 'additional_geo_column': params['additional_geo_column'], }) - conn_cur = db_util.get_connection_cursor(external_db.URI_STR) query = f'''select type from geometry_columns where f_table_schema = %s and f_table_name = %s and f_geometry_column = %s''' - result = db_util.run_query(query, (schema, table, geo_column), conn_cur=conn_cur) + result = db_util.run_query(query, (schema, table, geo_column), uri_str=external_db.URI_STR) assert result[0][0] == params['exp_geometry_type'] # publish layer from external DB table diff --git a/tests/dynamic_data/publications/layer_wfst/new_attribute_test.py b/tests/dynamic_data/publications/layer_wfst/new_attribute_test.py index 337821c57..043e4ed58 100644 --- a/tests/dynamic_data/publications/layer_wfst/new_attribute_test.py +++ b/tests/dynamic_data/publications/layer_wfst/new_attribute_test.py @@ -3,7 +3,6 @@ from owslib.feature.schema import get_schema as get_wfs_schema import pytest -from db import util as db_util from layman import app, settings from layman.layer import db from layman.layer.geoserver import wfs as geoserver_wfs @@ -245,7 +244,6 @@ def test_new_attribute(self, layer: Publication, rest_args, params, parametrizat style_type = parametrization.style_file.style_type wfs_url = f"http://{settings.LAYMAN_SERVER_NAME}/geoserver/{workspace}/wfs" table_uris = {} - conn_cur = None # get current attributes and assert that new attributes are not yet present old_db_attributes = {} @@ -255,10 +253,9 @@ def test_new_attribute(self, layer: Publication, rest_args, params, parametrizat with app.app_context(): table_uri = get_publication_info(workspace, process_client.LAYER_TYPE, layer_name, context={'keys': ['table_uri']})['_table_uri'] - conn_cur = conn_cur or db_util.get_connection_cursor(table_uri.db_uri_str) table_uris[layer_name] = table_uri old_db_attributes[layer_name] = db.get_all_table_column_names(table_uri.schema, table_uri.table, - conn_cur=conn_cur) + uri_str=table_uri, ) for attr_name in attr_names: assert attr_name not in old_db_attributes[layer_name], \ f"old_db_attributes={old_db_attributes[layer_name]}, attr_name={attr_name}" @@ -287,7 +284,7 @@ def test_new_attribute(self, layer: Publication, rest_args, params, parametrizat for layer_name, attr_names in new_attributes: # assert that exactly all attr_names were created in DB table table_uri = table_uris[layer_name] - db_attributes = db.get_all_table_column_names(table_uri.schema, table_uri.table, conn_cur=conn_cur) + db_attributes = db.get_all_table_column_names(table_uri.schema, table_uri.table, uri_str=table_uri) for attr_name in attr_names: assert attr_name in db_attributes, f"db_attributes={db_attributes}, attr_name={attr_name}" assert set(attr_names).union(set(old_db_attributes[layer_name])) == set(db_attributes)