From ed12cf0637c2f0e27d575b2ec932715d2c835e48 Mon Sep 17 00:00:00 2001 From: index-git Date: Mon, 25 Sep 2023 11:48:16 +0200 Subject: [PATCH] WIP: Do not pool connections internally --- src/db/util.py | 79 ++++-------- .../prime_db_schema/schema_initialization.py | 2 +- src/layman/layer/db/__init__.py | 113 ++++++------------ src/layman/layer/db/table.py | 34 ++---- src/layman/layer/micka/csw.py | 5 +- .../layer/prime_db_schema/file_data_tasks.py | 3 +- src/layman/layer/prime_db_schema/tasks.py | 5 +- src/layman/layer/qgis/wms.py | 7 +- src/layman/layer/util.py | 16 +-- test.sh | 2 +- test_separated.sh | 2 +- test_tools/external_db.py | 19 ++- tests/asserts/final/publication/internal.py | 2 +- .../edge_db_username_and_password.py | 3 +- .../layer_external_db/external_db_test.py | 3 +- .../layer_wfst/new_attribute_test.py | 6 +- 16 files changed, 102 insertions(+), 199 deletions(-) diff --git a/src/db/util.py b/src/db/util.py index 5b8d05889..073a7087f 100644 --- a/src/db/util.py +++ b/src/db/util.py @@ -3,7 +3,6 @@ import psycopg2 import psycopg2.pool from urllib import parse -from flask import g import crs as crs_def from . 
import PG_URI_STR @@ -11,64 +10,37 @@ logger = logging.getLogger(__name__) -FLASK_CONN_CUR_KEY = f'{__name__}:CONN_CUR' - CONNECTION_POOL_DICT = {} -def create_connection_cursor(db_uri_str=None, encapsulate_exception=True): - db_uri_str = db_uri_str or PG_URI_STR - try: - connection = psycopg2.connect(db_uri_str) - connection.set_session(autocommit=True) - except BaseException as exc: - if encapsulate_exception: - raise Error(1) from exc - raise exc - cursor = connection.cursor() - return connection, cursor - - -def get_connection_pool(db_uri_str=None): +def get_connection_pool(db_uri_str=None, encapsulate_exception=True): db_uri_str = db_uri_str or PG_URI_STR connection_pool = CONNECTION_POOL_DICT.get(db_uri_str) if not connection_pool: db_uri_parsed = parse.urlparse(db_uri_str) try: - connection_pool = psycopg2.pool.SimpleConnectionPool(5, 50, - user=db_uri_parsed.username, - password=db_uri_parsed.password, - host=db_uri_parsed.hostname, - port=db_uri_parsed.port, - database=db_uri_parsed.path[1:], - ) + connection_pool = psycopg2.pool.ThreadedConnectionPool(5, 50, + user=db_uri_parsed.username, + password=db_uri_parsed.password, + host=db_uri_parsed.hostname, + port=db_uri_parsed.port, + database=db_uri_parsed.path[1:], + ) except BaseException as exc: - raise Error(1) from exc + if encapsulate_exception: + raise Error(1) from exc + else: + raise exc CONNECTION_POOL_DICT[db_uri_str] = connection_pool return connection_pool -def get_connection_cursor(db_uri_str=None, encapsulate_exception=True): - if db_uri_str is None or db_uri_str == PG_URI_STR: - key = FLASK_CONN_CUR_KEY - if key not in g: - conn_cur = create_connection_cursor(encapsulate_exception=encapsulate_exception) - g.setdefault(key, conn_cur) - result = g.get(key) - else: - result = create_connection_cursor(db_uri_str=db_uri_str, encapsulate_exception=encapsulate_exception) - return result - - -def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False): - if conn_cur 
is None: - pool = get_connection_pool() - conn = pool.getconn() - conn.autocommit = True - cur = conn.cursor() - else: - conn, cur = conn_cur +def run_query(query, data=None, uri_str=None, encapsulate_exception=True, log_query=False): + pool = get_connection_pool(db_uri_str=uri_str, encapsulate_exception=encapsulate_exception, ) + conn = pool.getconn() + conn.autocommit = True + cur = conn.cursor() try: if log_query: logger.info(f"query={cur.mogrify(query, data).decode()}") @@ -87,14 +59,11 @@ def run_query(query, data=None, conn_cur=None, encapsulate_exception=True, log_q return rows -def run_statement(query, data=None, conn_cur=None, encapsulate_exception=True, log_query=False): - if conn_cur is None: - pool = get_connection_pool() - conn = pool.getconn() - conn.autocommit = True - cur = conn.cursor() - else: - conn, cur = conn_cur +def run_statement(query, data=None, uri_str=None, encapsulate_exception=True, log_query=False): + pool = get_connection_pool(db_uri_str=uri_str, encapsulate_exception=encapsulate_exception, ) + conn = pool.getconn() + conn.autocommit = True + cur = conn.cursor() try: if log_query: logger.info(f"query={cur.mogrify(query, data).decode()}") @@ -132,14 +101,14 @@ def get_internal_srid(crs): return srid -def get_crs_from_srid(srid, conn_cur=None, *, use_internal_srid): +def get_crs_from_srid(srid, uri_str=None, *, use_internal_srid): crs = next(( crs_code for crs_code, crs_item_def in crs_def.CRSDefinitions.items() if crs_item_def.internal_srid == srid ), None) if use_internal_srid else None if not crs: sql = 'select auth_name, auth_srid from spatial_ref_sys where srid = %s;' - auth_name, auth_srid = run_query(sql, (srid, ), conn_cur=conn_cur)[0] + auth_name, auth_srid = run_query(sql, (srid, ), uri_str=uri_str)[0] if auth_name or auth_srid: crs = f'{auth_name}:{auth_srid}' return crs diff --git a/src/layman/common/prime_db_schema/schema_initialization.py b/src/layman/common/prime_db_schema/schema_initialization.py index 
36e3b240c..e27f98c58 100644 --- a/src/layman/common/prime_db_schema/schema_initialization.py +++ b/src/layman/common/prime_db_schema/schema_initialization.py @@ -22,7 +22,7 @@ def ensure_schema(db_schema): db_util.run_statement(model.CREATE_SCHEMA_SQL) db_util.run_statement(model.setup_codelists_data()) except BaseException as exc: - db_util.run_statement(model.DROP_SCHEMA_SQL, conn_cur=db_util.get_connection_cursor()) + db_util.run_statement(model.DROP_SCHEMA_SQL, ) raise exc else: logger.info(f"Layman DB schema already exists, schema_name={db_schema}") diff --git a/src/layman/layer/db/__init__.py b/src/layman/layer/db/__init__.py index a6ae979ad..ad7b9213f 100644 --- a/src/layman/layer/db/__init__.py +++ b/src/layman/layer/db/__init__.py @@ -29,12 +29,8 @@ def get_internal_table_name(workspace, layer): return table_name -def get_workspaces(conn_cur=None): +def get_workspaces(): """Returns workspaces from internal DB only""" - if conn_cur is None: - conn_cur = db_util.get_connection_cursor() - _, cur = conn_cur - query = sql.SQL("""select schema_name from information_schema.schemata where schema_name NOT IN ({schemas}) AND schema_owner = {layman_pg_user}""").format( @@ -42,11 +38,10 @@ def get_workspaces(conn_cur=None): layman_pg_user=sql.Literal(settings.LAYMAN_PG_USER), ) try: - cur.execute(query) + rows = db_util.run_query(query) except BaseException as exc: logger.error(f'get_workspaces ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() return [ r[0] for r in rows ] @@ -61,36 +56,26 @@ def check_workspace_name(workspace): raise LaymanError(35, {'reserved_by': __name__, 'schema': workspace}) -def ensure_workspace(workspace, conn_cur=None): +def ensure_workspace(workspace, ): """Ensures workspace in internal DB only""" - if conn_cur is None: - conn_cur = db_util.get_connection_cursor() - conn, cur = conn_cur - statement = sql.SQL("""CREATE SCHEMA IF NOT EXISTS {schema} AUTHORIZATION {user}""").format( schema=sql.Identifier(workspace), 
user=sql.Identifier(settings.LAYMAN_PG_USER), ) try: - cur.execute(statement) - conn.commit() + db_util.run_statement(statement) except BaseException as exc: logger.error(f'ensure_workspace ERROR') raise LaymanError(7) from exc -def delete_workspace(workspace, conn_cur=None): +def delete_workspace(workspace, ): """Deletes workspace from internal DB only""" - if conn_cur is None: - conn_cur = db_util.get_connection_cursor() - conn, cur = conn_cur - statement = sql.SQL("""DROP SCHEMA IF EXISTS {schema} RESTRICT""").format( schema=sql.Identifier(workspace), ) try: - cur.execute(statement, (workspace, )) - conn.commit() + db_util.run_statement(statement, (workspace, )) except BaseException as exc: logger.error(f'delete_workspace ERROR') raise LaymanError(7) from exc @@ -203,9 +188,7 @@ def import_layer_vector_file_to_internal_table_async(schema, table_name, main_fi return process -def get_text_column_names(schema, table_name, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - +def get_text_column_names(schema, table_name, uri_str=None): statement = """ SELECT column_name FROM information_schema.columns @@ -214,25 +197,23 @@ def get_text_column_names(schema, table_name, conn_cur=None): AND data_type IN ('character varying', 'varchar', 'character', 'char', 'text') """ try: - cur.execute(statement, (schema, table_name)) + rows = db_util.run_query(statement, (schema, table_name), uri_str=uri_str) except BaseException as exc: logger.error(f'get_text_column_names ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() return [r[0] for r in rows] -def get_internal_table_all_column_names(workspace, layername, conn_cur=None): +def get_internal_table_all_column_names(workspace, layername, ): table_name = get_internal_table_name(workspace, layername) - return get_all_table_column_names(workspace, table_name, conn_cur=conn_cur) + return get_all_table_column_names(workspace, table_name, ) -def get_all_table_column_names(schema, table_name, 
conn_cur=None): - return [col.name for col in get_all_column_infos(schema, table_name, conn_cur=conn_cur)] +def get_all_table_column_names(schema, table_name, uri_str=None): + return [col.name for col in get_all_column_infos(schema, table_name, uri_str=uri_str)] -def get_all_column_infos(schema, table_name, *, conn_cur=None, omit_geometry_columns=False): - _, cur = conn_cur or db_util.get_connection_cursor() +def get_all_column_infos(schema, table_name, *, uri_str=None, omit_geometry_columns=False): query = """ SELECT inf.column_name, inf.data_type FROM information_schema.columns inf @@ -247,17 +228,14 @@ def get_all_column_infos(schema, table_name, *, conn_cur=None, omit_geometry_col query += " AND gc.f_geometry_column is null" try: - cur.execute(query, (schema, table_name)) + rows = db_util.run_query(query, (schema, table_name), uri_str=uri_str) except BaseException as exc: logger.error(f'get_all_column_names ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() return [ColumnInfo(name=r[0], data_type=r[1]) for r in rows] -def get_number_of_features(schema, table_name, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - +def get_number_of_features(schema, table_name, uri_str=None): statement = sql.SQL(""" select count(*) from {table} @@ -265,20 +243,18 @@ def get_number_of_features(schema, table_name, conn_cur=None): table=sql.Identifier(schema, table_name), ) try: - cur.execute(statement) + rows = db_util.run_query(statement, uri_str=uri_str) except BaseException as exc: logger.error(f'get_number_of_features ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() return rows[0][0] -def get_text_data(schema, table_name, primary_key, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - col_names = get_text_column_names(schema, table_name, conn_cur=conn_cur) +def get_text_data(schema, table_name, primary_key, uri_str=None): + col_names = get_text_column_names(schema, table_name, uri_str=uri_str) if 
len(col_names) == 0: return [], 0 - num_features = get_number_of_features(schema, table_name, conn_cur=conn_cur) + num_features = get_number_of_features(schema, table_name, uri_str=uri_str) if num_features == 0: return [], 0 limit = max(100, num_features // 10) @@ -294,11 +270,10 @@ def get_text_data(schema, table_name, primary_key, conn_cur=None): limit=sql.Literal(limit), ) try: - cur.execute(statement) + rows = db_util.run_query(statement, uri_str=uri_str) except BaseException as exc: logger.error(f'get_text_data ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() col_texts = defaultdict(list) for row in rows: for idx, col_name in enumerate(col_names): @@ -313,8 +288,8 @@ def get_text_data(schema, table_name, primary_key, conn_cur=None): return col_texts, limit -def get_text_languages(schema, table_name, primary_key, *, conn_cur=None): - texts, num_rows = get_text_data(schema, table_name, primary_key, conn_cur) +def get_text_languages(schema, table_name, primary_key, *, uri_str=None): + texts, num_rows = get_text_data(schema, table_name, primary_key, uri_str) all_langs = set() for text in texts: # skip short texts @@ -424,20 +399,17 @@ def get_most_frequent_lower_distance_query(schema, table_name, primary_key, geom return query -def get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_column, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - +def get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_column, uri_str=None): query = get_most_frequent_lower_distance_query(schema, table_name, primary_key, geometry_column) # print(f"\nget_most_frequent_lower_distance v1\nusername={username}, layername={layername}") # print(query) try: - cur.execute(query) + rows = db_util.run_query(query, uri_str) except BaseException as exc: logger.error(f'get_most_frequent_lower_distance ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() # for row in rows: # print(f"row={row}") result = None @@ 
-466,8 +438,8 @@ def get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_c ] -def guess_scale_denominator(schema, table_name, primary_key, geometry_column, *, conn_cur=None): - distance = get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_column, conn_cur=conn_cur) +def guess_scale_denominator(schema, table_name, primary_key, geometry_column, *, uri_str=None): + distance = get_most_frequent_lower_distance(schema, table_name, primary_key, geometry_column, uri_str=uri_str) log_sd_list = [math.log10(sd) for sd in SCALE_DENOMINATORS] if distance is not None: coef = 2000 if distance > 100 else 1000 @@ -480,8 +452,7 @@ def guess_scale_denominator(schema, table_name, primary_key, geometry_column, *, return scale_denominator -def create_string_attributes(attribute_tuples, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() +def create_string_attributes(attribute_tuples, uri_str=None): query = sql.SQL('{alters} \n COMMIT;').format( alters=sql.SQL('\n').join( [sql.SQL("""ALTER TABLE {table} ADD COLUMN {fattrname} VARCHAR(1024);""").format( @@ -491,7 +462,7 @@ def create_string_attributes(attribute_tuples, conn_cur=None): ) ) try: - cur.execute(query) + db_util.run_statement(query, uri_str=uri_str, encapsulate_exception=False) except InsufficientPrivilege as exc: raise LaymanError(7, data={ 'reason': 'Insufficient privilege', @@ -501,9 +472,7 @@ def create_string_attributes(attribute_tuples, conn_cur=None): raise LaymanError(7) from exc -def get_missing_attributes(attribute_tuples, conn_cur=None): - _, cur = conn_cur or db_util.get_connection_cursor() - +def get_missing_attributes(attribute_tuples, uri_str=None): # Find all foursomes which do not already exist query = sql.SQL("""select attribs.* from ({selects}) attribs left join @@ -520,13 +489,12 @@ def get_missing_attributes(attribute_tuples, conn_cur=None): try: if attribute_tuples: - cur.execute(query) + rows = db_util.run_query(query, uri_str=uri_str) 
except BaseException as exc: logger.error(f'get_missing_attributes ERROR') raise LaymanError(7) from exc missing_attributes = set() - rows = cur.fetchall() for row in rows: missing_attributes.add((row[0], row[1], @@ -550,7 +518,7 @@ def ensure_attributes(attribute_tuples, db_uri_str): return missing_attributes -def get_bbox(schema, table_name, conn_cur=None, column=settings.OGR_DEFAULT_GEOMETRY_COLUMN): +def get_bbox(schema, table_name, uri_str=None, column=settings.OGR_DEFAULT_GEOMETRY_COLUMN): query = sql.SQL(''' with tmp as (select ST_Extent(l.{column}) as bbox from {table} l @@ -564,24 +532,23 @@ def get_bbox(schema, table_name, conn_cur=None, column=settings.OGR_DEFAULT_GEOM table=sql.Identifier(schema, table_name), column=sql.Identifier(column), ) - result = db_util.run_query(query, conn_cur=conn_cur)[0] + result = db_util.run_query(query, uri_str=uri_str)[0] return result -def get_table_crs(schema, table_name, conn_cur=None, column=settings.OGR_DEFAULT_GEOMETRY_COLUMN, *, use_internal_srid): - srid = get_column_srid(schema, table_name, column, conn_cur=conn_cur) - crs = db_util.get_crs_from_srid(srid, conn_cur, use_internal_srid=use_internal_srid) +def get_table_crs(schema, table_name, uri_str=None, column=settings.OGR_DEFAULT_GEOMETRY_COLUMN, *, use_internal_srid): + srid = get_column_srid(schema, table_name, column, uri_str=uri_str) + crs = db_util.get_crs_from_srid(srid, uri_str, use_internal_srid=use_internal_srid) return crs -def get_column_srid(schema, table, column, *, conn_cur=None): +def get_column_srid(schema, table, column, *, uri_str=None): query = 'select Find_SRID(%s, %s, %s);' - srid = db_util.run_query(query, (schema, table, column), conn_cur=conn_cur)[0][0] + srid = db_util.run_query(query, (schema, table, column), uri_str=uri_str)[0][0] return srid -def get_geometry_types(schema, table_name, *, column_name=settings.OGR_DEFAULT_GEOMETRY_COLUMN, conn_cur=None): - conn, cur = conn_cur or db_util.get_connection_cursor() +def 
get_geometry_types(schema, table_name, *, column_name=settings.OGR_DEFAULT_GEOMETRY_COLUMN, uri_str=None): query = sql.SQL(""" select distinct ST_GeometryType({column}) as geometry_type_name from {table} @@ -590,11 +557,9 @@ def get_geometry_types(schema, table_name, *, column_name=settings.OGR_DEFAULT_G column=sql.Identifier(column_name), ) try: - cur.execute(query) + rows = db_util.run_query(query, uri_str=uri_str) except BaseException as exc: logger.error(f'get_geometry_types ERROR') raise LaymanError(7) from exc - rows = cur.fetchall() - conn.commit() result = [row[0] for row in rows] return result diff --git a/src/layman/layer/db/table.py b/src/layman/layer/db/table.py index 0d5235993..40baeb05d 100644 --- a/src/layman/layer/db/table.py +++ b/src/layman/layer/db/table.py @@ -23,32 +23,18 @@ def get_layer_info(workspace, layername,): result = {} if table_uri: if layer_info['original_data_source'] == settings.EnumOriginalDataSource.FILE.value: - conn_cur = db_util.get_connection_cursor() + db_uri_str = None else: - try: - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str,) - except BaseException: - result['db'] = { - 'schema': table_uri.schema, - 'table': table_uri.table, - 'geo_column': table_uri.geo_column, - 'external_uri': layer_util.redact_uri(table_uri.db_uri_str), - 'status': 'NOT_AVAILABLE', - 'error': 'Cannot connect to DB.', - } - return result - - _, cur = conn_cur + db_uri_str = table_uri.db_uri_str try: - cur.execute(f""" + rows = db_util.run_query(f""" SELECT schemaname, tablename, tableowner FROM pg_tables WHERE schemaname = %s AND tablename = %s - """, (table_uri.schema, table_uri.table, )) + """, (table_uri.schema, table_uri.table, ), uri_str=db_uri_str) except BaseException as exc: raise LaymanError(7) from exc - rows = cur.fetchall() if len(rows) > 0: result['db'] = { 'schema': table_uri.schema, @@ -70,26 +56,22 @@ def get_layer_info(workspace, layername,): return result -def delete_layer(workspace, layername, 
conn_cur=None): +def delete_layer(workspace, layername, ): """Deletes table from internal DB only""" table_name = get_internal_table_name(workspace, layername) if table_name: - if conn_cur is None: - conn_cur = db_util.get_connection_cursor() - conn, cur = conn_cur query = sql.SQL(""" DROP TABLE IF EXISTS {table} CASCADE """).format( table=sql.Identifier(workspace, table_name), ) try: - cur.execute(query) - conn.commit() + db_util.run_statement(query) except BaseException as exc: raise LaymanError(7)from exc -def set_internal_table_layer_srid(schema, table_name, srid, *, conn_cur=None): +def set_internal_table_layer_srid(schema, table_name, srid, ): query = '''SELECT UpdateGeometrySRID(%s, %s, %s, %s);''' params = (schema, table_name, settings.OGR_DEFAULT_GEOMETRY_COLUMN, srid) - db_util.run_query(query, params, conn_cur=conn_cur) + db_util.run_query(query, params) diff --git a/src/layman/layer/micka/csw.py b/src/layman/layer/micka/csw.py index 5d1c07e3f..870fbec6f 100644 --- a/src/layman/layer/micka/csw.py +++ b/src/layman/layer/micka/csw.py @@ -138,15 +138,14 @@ def get_template_path_and_values(workspace, layername, http_method): if geodata_type == settings.GEODATA_TYPE_VECTOR: table_uri = publ_info['_table_uri'] table_name = table_uri.table - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str) try: languages = db.get_text_languages(table_uri.schema, table_name, table_uri.primary_key_column, - conn_cur=conn_cur) + uri_str=table_uri.db_uri_str) except LaymanError: languages = [] try: scale_denominator = db.guess_scale_denominator(table_uri.schema, table_name, table_uri.primary_key_column, - table_uri.geo_column, conn_cur=conn_cur) + table_uri.geo_column, uri_str=table_uri.db_uri_str) except LaymanError: scale_denominator = None spatial_resolution = { diff --git a/src/layman/layer/prime_db_schema/file_data_tasks.py b/src/layman/layer/prime_db_schema/file_data_tasks.py index e8c52fbef..aca09a713 100644 --- 
a/src/layman/layer/prime_db_schema/file_data_tasks.py +++ b/src/layman/layer/prime_db_schema/file_data_tasks.py @@ -30,8 +30,7 @@ def patch_after_feature_change( assert geodata_type == settings.GEODATA_TYPE_VECTOR table_uri = info['_table_uri'] - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str) - bbox = db_get_bbox(table_uri.schema, table_uri.table, conn_cur=conn_cur, column=table_uri.geo_column) + bbox = db_get_bbox(table_uri.schema, table_uri.table, uri_str=table_uri.db_uri_str, column=table_uri.geo_column) if self.is_aborted(): raise AbortedException diff --git a/src/layman/layer/prime_db_schema/tasks.py b/src/layman/layer/prime_db_schema/tasks.py index 710787100..d6714ef4f 100644 --- a/src/layman/layer/prime_db_schema/tasks.py +++ b/src/layman/layer/prime_db_schema/tasks.py @@ -43,9 +43,8 @@ def refresh_file_data( # because for compressed files sent with chunks file_type would be UNKNOWN and table_uri not set publ_info = layman_util.get_publication_info(username, LAYER_TYPE, layername, context={'keys': ['table_uri']}) table_uri = publ_info['_table_uri'] - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str) - bbox = db_get_bbox(table_uri.schema, table_uri.table, conn_cur=conn_cur, column=table_uri.geo_column) - crs = get_table_crs(table_uri.schema, table_uri.table, conn_cur=conn_cur, column=table_uri.geo_column, use_internal_srid=True) + bbox = db_get_bbox(table_uri.schema, table_uri.table, uri_str=table_uri.db_uri_str, column=table_uri.geo_column) + crs = get_table_crs(table_uri.schema, table_uri.table, uri_str=table_uri.db_uri_str, column=table_uri.geo_column, use_internal_srid=True) elif geodata_type == settings.GEODATA_TYPE_RASTER: bbox = gdal_get_bbox(username, layername) crs = gdal_get_crs(username, layername) diff --git a/src/layman/layer/qgis/wms.py b/src/layman/layer/qgis/wms.py index cce75d718..d24044905 100644 --- a/src/layman/layer/qgis/wms.py +++ b/src/layman/layer/qgis/wms.py @@ -71,15 +71,14 @@ def 
save_qgs_file(workspace, layer): db_schema = table_uri.schema layer_bbox = layer_bbox if not bbox_util.is_empty(layer_bbox) else crs_def.CRSDefinitions[crs].default_bbox qml = util.get_original_style_xml(workspace, layer) - conn_cur = db_util.get_connection_cursor(db_uri_str=table_uri.db_uri_str) - db_types = db.get_geometry_types(db_schema, table_name, column_name=table_uri.geo_column, conn_cur=conn_cur) + db_types = db.get_geometry_types(db_schema, table_name, column_name=table_uri.geo_column, uri_str=table_uri.db_uri_str) qml_geometry = util.get_geometry_from_qml_and_db_types(qml, db_types) db_cols = [ - col for col in db.get_all_column_infos(db_schema, table_name, conn_cur=conn_cur, omit_geometry_columns=True) + col for col in db.get_all_column_infos(db_schema, table_name, uri_str=table_uri.db_uri_str, omit_geometry_columns=True) if col.name != table_uri.primary_key_column ] source_type = util.get_source_type(db_types, qml_geometry) - column_srid = db.get_column_srid(db_schema, table_name, table_uri.geo_column, conn_cur=conn_cur) + column_srid = db.get_column_srid(db_schema, table_name, table_uri.geo_column, uri_str=table_uri.db_uri_str) layer_qml = util.fill_layer_template(layer, uuid, layer_bbox, crs, qml, source_type, db_cols, table_uri, column_srid, db_types) qgs_str = util.fill_project_template(layer, uuid, layer_qml, crs, settings.LAYMAN_OUTPUT_SRS_LIST, diff --git a/src/layman/layer/util.py b/src/layman/layer/util.py index d20bda218..4351f699b 100644 --- a/src/layman/layer/util.py +++ b/src/layman/layer/util.py @@ -290,7 +290,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): }) try: - conn_cur = db_util.get_connection_cursor(db_uri_str, encapsulate_exception=False) + db_util.get_connection_pool(db_uri_str=db_uri_str, encapsulate_exception=False) except psycopg2.OperationalError as exc: raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -304,7 +304,7 @@ def 
parse_and_validate_external_table_uri_str(external_table_uri_str): if not geo_column: query = f'''select f_geometry_column from geometry_columns where f_table_schema = %s and f_table_name = %s order by f_geometry_column asc''' - query_res = db_util.run_query(query, (schema, table), conn_cur=conn_cur) + query_res = db_util.run_query(query, (schema, table), uri_str=db_uri_str, ) if len(query_res) == 0: raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -332,10 +332,10 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): }) query = f'''select count(*) from information_schema.tables WHERE table_schema=%s and table_name=%s''' - query_res = db_util.run_query(query, (schema, table,), conn_cur=conn_cur) + query_res = db_util.run_query(query, (schema, table,), uri_str=db_uri_str, ) if not query_res[0][0]: query = f'''select table_schema, table_name from information_schema.tables WHERE lower(table_schema)=lower(%s) and lower(table_name)=lower(%s)''' - query_res = db_util.run_query(query, (schema, table,), conn_cur=conn_cur) + query_res = db_util.run_query(query, (schema, table,), uri_str=db_uri_str, ) suggestion = f" Did you mean \"{query_res[0][0]}\".\"{query_res[0][1]}\"?" 
if query_res else '' raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -349,7 +349,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): }) query = f'''select count(*) from geometry_columns where f_table_schema = %s and f_table_name = %s and f_geometry_column = %s''' - query_res = db_util.run_query(query, (schema, table, geo_column), conn_cur=conn_cur) + query_res = db_util.run_query(query, (schema, table, geo_column), uri_str=db_uri_str, ) if not query_res[0][0]: raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -363,7 +363,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): } }) - crs = get_table_crs(schema, table, conn_cur=conn_cur, column=geo_column, use_internal_srid=False) + crs = get_table_crs(schema, table, uri_str=db_uri_str, column=geo_column, use_internal_srid=False) if crs not in settings.INPUT_SRS_LIST: raise LaymanError(2, { 'parameter': 'external_table_uri', @@ -386,7 +386,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): idx.indisprimary AND cls.relname = %s AND nspace.nspname = %s''' - query_res = db_util.run_query(query, (table, schema), conn_cur=conn_cur, log_query=True) + query_res = db_util.run_query(query, (table, schema), uri_str=db_uri_str, log_query=True) primary_key_columns = [r[0] for r in query_res] if len(query_res) == 0: raise LaymanError(2, { @@ -413,7 +413,7 @@ def parse_and_validate_external_table_uri_str(external_table_uri_str): } }) - column_names = get_all_table_column_names(schema, table, conn_cur=conn_cur) + column_names = get_all_table_column_names(schema, table, uri_str=db_uri_str) unsafe_column_names = [c for c in column_names if not re.match(SAFE_PG_IDENTIFIER_PATTERN, c)] if unsafe_column_names: raise LaymanError(2, { diff --git a/test.sh b/test.sh index 813eb8ab5..8403d454c 100644 --- a/test.sh +++ b/test.sh @@ -11,7 +11,7 @@ if [ "$CI" == "true" ] then python3 src/assert_db.py && python3 src/wait_for_deps.py && python3 
src/clear_layman_data.py && python3 -m pytest -m "not irritating" --timeout=60 -W ignore::DeprecationWarning -xvv src else - python3 src/assert_db.py && python3 src/wait_for_deps.py && python3 src/clear_layman_data.py && python3 -m pytest -W ignore::DeprecationWarning -xvv src + python3 src/assert_db.py && python3 src/wait_for_deps.py && python3 src/clear_layman_data.py && python3 -m pytest -m "not irritating" --timeout=60 -W ignore::DeprecationWarning -xvv src fi #python3 src/assert_db.py && python3 src/wait_for_deps.py && python3 src/clear_layman_data.py && python3 -m pytest -W ignore::DeprecationWarning --capture=tee-sys -xvv src/layman/gs_wfs_proxy_test.py #python3 src/assert_db.py && python3 src/wait_for_deps.py && python3 src/clear_layman_data.py && python3 -m pytest -W ignore::DeprecationWarning -xsvv src/layman/layer/client_test.py diff --git a/test_separated.sh b/test_separated.sh index e6f04c94d..4c88253a0 100644 --- a/test_separated.sh +++ b/test_separated.sh @@ -10,5 +10,5 @@ rm -rf tmp/artifacts/* max_fail="$1" if [ -z "$max_fail" ]; then max_fail=1; fi -python3 src/assert_db.py && python3 src/wait_for_deps.py && python3 src/clear_layman_data.py && python3 -m pytest -W ignore::DeprecationWarning --maxfail="$max_fail" -vv --ignore=tests/static_data/ tests +python3 src/assert_db.py && python3 src/wait_for_deps.py && python3 src/clear_layman_data.py && python3 -m pytest -W ignore::DeprecationWarning --maxfail="$max_fail" -sxvv --capture=tee-sys --ignore=tests/static_data/ tests/dynamic_data/publications/layer_by_used_servers/layer_by_used_servers_test.py::TestLayer::test_layer #python3 src/assert_db.py && python3 src/wait_for_deps.py && python3 src/clear_layman_data.py && python3 -m TEST_TYPE=optional pytest -W ignore::DeprecationWarning -sxvv --capture=tee-sys --nocleanup --ignore=tests/static_data/ tests diff --git a/test_tools/external_db.py b/test_tools/external_db.py index 8fcdc286b..a08f3d3c7 100644 --- a/test_tools/external_db.py +++ 
b/test_tools/external_db.py @@ -29,17 +29,16 @@ def ensure_db(): GRANT CONNECT ON DATABASE {EXTERNAL_DB_NAME} TO {READ_ONLY_USER}; ALTER DEFAULT PRIVILEGES GRANT SELECT ON TABLES TO {READ_ONLY_USER}; """ - db_util.run_statement(statement, ) + db_util.run_statement(statement, uri_str=URI_STR) yield def ensure_schema(schema): - conn_cur = db_util.get_connection_cursor(URI_STR) statement = sql.SQL(f'CREATE SCHEMA IF NOT EXISTS {{schema}} AUTHORIZATION {settings.LAYMAN_PG_USER}').format( schema=sql.Identifier(schema), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR) def ensure_table(schema, name, geo_column, *, primary_key_columns=None, other_columns=None, srid=4326): @@ -70,8 +69,7 @@ def ensure_table(schema, name, geo_column, *, primary_key_columns=None, other_co table=sql.Identifier(schema, name), columns=sql.SQL(',').join(columns), ) - conn_cur = db_util.get_connection_cursor(URI_STR) - db_util.run_statement(statement, conn_cur=conn_cur, log_query=True) + db_util.run_statement(statement, uri_str=URI_STR, log_query=True) def import_table(input_file_path, *, table=None, schema='public', geo_column=settings.OGR_DEFAULT_GEOMETRY_COLUMN, @@ -109,27 +107,25 @@ def import_table(input_file_path, *, table=None, schema='public', geo_column=set assert return_code == 0 and not stdout and not stderr, f"return_code={return_code}, stdout={stdout}, stderr={stderr}" if primary_key_column is None: - conn_cur = db_util.get_connection_cursor(URI_STR) statement = sql.SQL("alter table {table} drop column {primary_key}").format( table=sql.Identifier(schema, table), primary_key=sql.Identifier(primary_key_to_later_drop), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR, ) if additional_geo_column: - conn_cur = db_util.get_connection_cursor(URI_STR) statement = sql.SQL("alter table {table} add column {geo_column_2} GEOMETRY").format( table=sql.Identifier(schema, table), 
geo_column_2=sql.Identifier(additional_geo_column), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR, ) statement = sql.SQL("update {table} set {geo_column_2} = st_buffer({geo_column}, 15)").format( table=sql.Identifier(schema, table), geo_column_2=sql.Identifier(additional_geo_column), geo_column=sql.Identifier(geo_column), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR, ) return schema, table @@ -139,5 +135,4 @@ def drop_table(schema, name, *, if_exists=False): statement = sql.SQL(f'drop table {if_exists_str} {{table}}').format( table=sql.Identifier(schema, name) ) - conn_cur = db_util.get_connection_cursor(URI_STR) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=URI_STR, ) diff --git a/tests/asserts/final/publication/internal.py b/tests/asserts/final/publication/internal.py index ffd154ecd..99b38aa9a 100644 --- a/tests/asserts/final/publication/internal.py +++ b/tests/asserts/final/publication/internal.py @@ -438,7 +438,7 @@ def point_coordinates(workspace, publ_type, name, *, point_id, crs, exp_coordina ) with app.app_context(): to_srid = db_util.get_internal_srid(crs) - coordinates = db_util.run_query(query, (to_srid, point_id), conn_cur=db_util.get_connection_cursor(table_uri.db_uri_str)) + coordinates = db_util.run_query(query, (to_srid, point_id), uri_str=table_uri.db_uri_str) assert len(coordinates) == 1, coordinates coordinates = coordinates[0] diff --git a/tests/dynamic_data/publications/layer_external_db/edge_db_username_and_password.py b/tests/dynamic_data/publications/layer_external_db/edge_db_username_and_password.py index 703f3b9f5..90661e049 100644 --- a/tests/dynamic_data/publications/layer_external_db/edge_db_username_and_password.py +++ b/tests/dynamic_data/publications/layer_external_db/edge_db_username_and_password.py @@ -55,7 +55,6 @@ class 
TestEdge(base_test.TestSingleRestPublication): def test_layer(self, layer: Publication, rest_method, rest_args, ): """Parametrized using pytest_generate_tests""" - conn_cur = db_util.get_connection_cursor(external_db.URI_STR) statement = sql.SQL(f""" DO $$BEGIN @@ -77,7 +76,7 @@ def test_layer(self, layer: Publication, rest_method, rest_args, ): schema=sql.Identifier(SCHEMA), table=sql.Identifier(TABLE), ) - db_util.run_statement(statement, conn_cur=conn_cur) + db_util.run_statement(statement, uri_str=external_db.URI_STR) # publish layer from external DB table rest_method.fn(layer, args=rest_args) diff --git a/tests/dynamic_data/publications/layer_external_db/external_db_test.py b/tests/dynamic_data/publications/layer_external_db/external_db_test.py index 310057d4d..47b67365b 100644 --- a/tests/dynamic_data/publications/layer_external_db/external_db_test.py +++ b/tests/dynamic_data/publications/layer_external_db/external_db_test.py @@ -219,9 +219,8 @@ def test_layer(self, layer: Publication, rest_method, rest_args, params): 'primary_key_column': primary_key_column, 'additional_geo_column': params['additional_geo_column'], }) - conn_cur = db_util.get_connection_cursor(external_db.URI_STR) query = f'''select type from geometry_columns where f_table_schema = %s and f_table_name = %s and f_geometry_column = %s''' - result = db_util.run_query(query, (schema, table, geo_column), conn_cur=conn_cur) + result = db_util.run_query(query, (schema, table, geo_column), uri_str=external_db.URI_STR) assert result[0][0] == params['exp_geometry_type'] # publish layer from external DB table diff --git a/tests/dynamic_data/publications/layer_wfst/new_attribute_test.py b/tests/dynamic_data/publications/layer_wfst/new_attribute_test.py index 337821c57..488703513 100644 --- a/tests/dynamic_data/publications/layer_wfst/new_attribute_test.py +++ b/tests/dynamic_data/publications/layer_wfst/new_attribute_test.py @@ -245,7 +245,6 @@ def test_new_attribute(self, layer: Publication, rest_args, 
params, parametrizat style_type = parametrization.style_file.style_type wfs_url = f"http://{settings.LAYMAN_SERVER_NAME}/geoserver/{workspace}/wfs" table_uris = {} - conn_cur = None # get current attributes and assert that new attributes are not yet present old_db_attributes = {} @@ -255,10 +254,9 @@ def test_new_attribute(self, layer: Publication, rest_args, params, parametrizat with app.app_context(): table_uri = get_publication_info(workspace, process_client.LAYER_TYPE, layer_name, context={'keys': ['table_uri']})['_table_uri'] - conn_cur = conn_cur or db_util.get_connection_cursor(table_uri.db_uri_str) table_uris[layer_name] = table_uri old_db_attributes[layer_name] = db.get_all_table_column_names(table_uri.schema, table_uri.table, - conn_cur=conn_cur) + uri_str=table_uri.db_uri_str) for attr_name in attr_names: assert attr_name not in old_db_attributes[layer_name], \ f"old_db_attributes={old_db_attributes[layer_name]}, attr_name={attr_name}" @@ -287,7 +285,7 @@ def test_new_attribute(self, layer: Publication, rest_args, params, parametrizat for layer_name, attr_names in new_attributes: # assert that exactly all attr_names were created in DB table table_uri = table_uris[layer_name] - db_attributes = db.get_all_table_column_names(table_uri.schema, table_uri.table, conn_cur=conn_cur) + db_attributes = db.get_all_table_column_names(table_uri.schema, table_uri.table, uri_str=table_uri.db_uri_str) for attr_name in attr_names: assert attr_name in db_attributes, f"db_attributes={db_attributes}, attr_name={attr_name}" assert set(attr_names).union(set(old_db_attributes[layer_name])) == set(db_attributes)