diff --git a/pyproject.toml b/pyproject.toml index 71e8360..5a4187c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,9 +44,7 @@ test = [ "pytest-asyncio", "pytest-benchmark", "httpx", - "psycopg2", - "pytest-pgsql", - "sqlalchemy>=1.1,<1.4", + "pytest-postgresql", "mapbox-vector-tile", "protobuf>=3.0,<4.0", "numpy", diff --git a/tests/conftest.py b/tests/conftest.py index db5753a..951318e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,9 @@ import os from contextlib import asynccontextmanager +import psycopg import pytest -import pytest_pgsql +from pytest_postgresql.janitor import DatabaseJanitor from tipg.settings import CustomSQLSettings, DatabaseSettings, PostgresSettings @@ -17,76 +18,87 @@ TEMPLATE_DIRECTORY = os.path.join(FIXTURES_DIR, "templates") SQL_FUNCTIONS_DIRECTORY = os.path.join(FIXTURES_DIR, "functions") -test_db = pytest_pgsql.TransactedPostgreSQLTestDB.create_fixture( - "test_db", scope="session", use_restore_state=False -) - @pytest.fixture(scope="session") -def database_url(test_db): - """ - Session scoped fixture to launch a postgresql database in a separate process. We use psycopg2 to ingest test data - because pytest-asyncio event loop is a function scoped fixture and cannot be called within the current scope. Yields - a database url which we pass to our application through a monkeypatched environment variable. - """ - assert test_db.install_extension("postgis") - - # make sure we have a `public` schema - test_db.create_schema("public", exists_ok=True) - - test_db.run_sql_file(os.path.join(FIXTURES_DIR, "landsat_wrs.sql")) - assert test_db.has_table("public.landsat_wrs") - - test_db.run_sql_file(os.path.join(FIXTURES_DIR, "my_data.sql")) - assert test_db.has_table("public.my_data") - - test_db.run_sql_file(os.path.join(FIXTURES_DIR, "nongeo_data.sql")) - assert test_db.has_table("public.nongeo_data") - - test_db.connection.execute( - "CREATE TABLE public.landsat AS SELECT geom, ST_Centroid(geom) as centroid, ogc_fid, id, pr, path, row from public.landsat_wrs;" - ) - test_db.connection.execute("ALTER TABLE public.landsat ADD PRIMARY KEY (ogc_fid);") - assert test_db.has_table("public.landsat") - - count_landsat = test_db.connection.execute( - "SELECT COUNT(*) FROM public.landsat_wrs" - ).scalar() - count_landsat_centroid = test_db.connection.execute( - "SELECT COUNT(*) FROM public.landsat" - ).scalar() - assert count_landsat == count_landsat_centroid - - # Add table with Huge geometries - test_db.run_sql_file(os.path.join(FIXTURES_DIR, "canada.sql")) - assert test_db.has_table("public.canada") - - # add table with geometries not in WGS84 - test_db.run_sql_file(os.path.join(FIXTURES_DIR, "minnesota.sql")) - assert test_db.has_table("public.minnesota") +def database(postgresql_proc): + """Create Database Fixture.""" + with DatabaseJanitor( + user=postgresql_proc.user, + host=postgresql_proc.host, + port=postgresql_proc.port, + dbname="test_db", + version=postgresql_proc.version, + password="password", + ) as jan: + yield jan - # add a `myschema` schema - test_db.create_schema("myschema") - assert test_db.has_schema("myschema") - test_db.connection.execute( - "CREATE TABLE myschema.landsat AS SELECT * FROM public.landsat_wrs;" - ) - assert test_db.has_table("myschema.landsat") - count_landsat_schema = test_db.connection.execute( - "SELECT COUNT(*) FROM myschema.landsat" - ).scalar() - assert count_landsat == count_landsat_schema +def _get_sql(source: str) -> str: + with open(source, "r") as fd: + to_run = fd.readlines() - # add a `userschema` schema - test_db.create_schema("userschema") - assert test_db.has_schema("userschema") + return "\n".join(to_run) - test_db.connection.execute( - "CREATE OR REPLACE FUNCTION userschema.test_no_params() RETURNS TABLE(foo integer, location geometry) AS 'SELECT 1, ST_MakePoint(0,0);' LANGUAGE SQL;" - ) - return str(test_db.connection.engine.url) +@pytest.fixture(scope="session") +def database_url(database): + """add data to the database fixture""" + db_url = f"postgresql://{database.user}:{database.password}@{database.host}:{database.port}/{database.dbname}" + with psycopg.connect(db_url, autocommit=True) as conn: + with conn.cursor() as cur: + cur.execute(f"ALTER DATABASE {database.dbname} SET TIMEZONE='UTC';") + cur.execute("SET TIME ZONE 'UTC';") + + cur.execute("CREATE EXTENSION IF NOT EXISTS postgis;") + # make sure postgis extension exists + assert cur.execute( + "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname='postgis' LIMIT 1);" + ).fetchone()[0] + + cur.execute("CREATE SCHEMA IF NOT EXISTS public;") + + # load table + cur.execute(_get_sql(os.path.join(FIXTURES_DIR, "landsat_wrs.sql"))) + cur.execute(_get_sql(os.path.join(FIXTURES_DIR, "my_data.sql"))) + cur.execute(_get_sql(os.path.join(FIXTURES_DIR, "nongeo_data.sql"))) + + cur.execute( + "CREATE TABLE public.landsat AS SELECT geom, ST_Centroid(geom) as centroid, ogc_fid, id, pr, path, row from public.landsat_wrs;" + ) + cur.execute("ALTER TABLE public.landsat ADD PRIMARY KEY (ogc_fid);") + + count_landsat = cur.execute( + "SELECT COUNT(*) FROM public.landsat_wrs" + ).fetchone()[0] + count_landsat_centroid = cur.execute( + "SELECT COUNT(*) FROM public.landsat" + ).fetchone()[0] + assert count_landsat == count_landsat_centroid + + # Add table with Huge geometries + cur.execute(_get_sql(os.path.join(FIXTURES_DIR, "canada.sql"))) + + # add table with geometries not in WGS84 + cur.execute(_get_sql(os.path.join(FIXTURES_DIR, "minnesota.sql"))) + + # add a `myschema` schema + cur.execute("CREATE SCHEMA IF NOT EXISTS myschema;") + cur.execute( + "CREATE TABLE myschema.landsat AS SELECT * FROM public.landsat_wrs;" + ) + count_landsat_schema = cur.execute( + "SELECT COUNT(*) FROM myschema.landsat" + ).fetchone()[0] + assert count_landsat == count_landsat_schema + + # add a `userschema` schema + cur.execute("CREATE SCHEMA IF NOT EXISTS userschema;") + + cur.execute( + "CREATE OR REPLACE FUNCTION userschema.test_no_params() RETURNS TABLE(foo integer, location geometry) AS 'SELECT 1, ST_MakePoint(0,0);' LANGUAGE SQL;" + ) + + return db_url def create_tipg_app( @@ -174,10 +186,6 @@ def app(database_url, monkeypatch): db_settings.tables = None db_settings.functions = None - # Remove middlewares https://github.com/encode/starlette/issues/472 - app.user_middleware = [] - app.middleware_stack = app.build_middleware_stack() - with TestClient(app) as app: yield app