Skip to content

Commit

Permalink
Fix tests to use proper database, regardless of local .env file
Browse files Browse the repository at this point in the history
  • Loading branch information
nothingface0 committed Jun 6, 2024
1 parent a1dae2d commit a5b10cc
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 24 deletions.
57 changes: 37 additions & 20 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sqlalchemy
from sqlalchemy import create_engine, text
from sqlalchemy_utils import create_database, database_exists, drop_database
from sqlalchemy.exc import IntegrityError
from flask import Flask

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
Expand All @@ -13,9 +14,12 @@
from db import DQM2MirrorDB
from dqmsquare_cfg import load_cfg

DB_PROD_NAME = "postgres_production_test"
DB_PLAY_NAME = "postgres_playback_test"


def get_or_create_db(
db_uri: str, username: str, password: str, host: str, port: str, db_name: str
db_uri: str, username: str, password: str, host: str, port: int, db_name: str
):
engine: sqlalchemy.engine.Engine = create_engine(db_uri)
if not database_exists(engine.url):
Expand Down Expand Up @@ -67,10 +71,15 @@ def fill_db(db: DQM2MirrorDB) -> None:
try:
session.execute(
text(
f"""INSERT into runs ({str(db.TB_DESCRIPTION_RUNS_COLS).replace("[", "").replace("]", "").replace("'", "")}) VALUES ({format_entry_to_db_entry(run, [13])})"""
"INSERT into runs "
+ f"""({str(db.TB_DESCRIPTION_RUNS_COLS).replace("[", "").replace("]", "").replace("'", "")}) """
+ f"""VALUES ({format_entry_to_db_entry(run, [13])})"""
)
)
session.commit()
except IntegrityError as e:
print(f"Skipping already inserted data", e)
continue
except Exception as e:
print("Error when creating run fixture:", e)
session.rollback()
Expand All @@ -79,32 +88,41 @@ def fill_db(db: DQM2MirrorDB) -> None:
try:
session.execute(
text(
f"""INSERT into graphs ({str(db.TB_DESCRIPTION_GRAPHS_COLS).replace("[", "").replace("]", "").replace("'", "")}) VALUES ({format_entry_to_db_entry(graph, [3, 4])})"""
"INSERT into graphs "
+ f"""({str(db.TB_DESCRIPTION_GRAPHS_COLS).replace("[", "").replace("]", "").replace("'", "")}) """
+ f"""VALUES ({format_entry_to_db_entry(graph, [3, 4])})"""
)
)
session.commit()
except IntegrityError as e:
print(f"Skipping already inserted data")
continue
except Exception as e:
print("Error when creating graph fixture:", e)
session.rollback()


@pytest.fixture
def cfg():
yield load_cfg()
def dqm2_config():
# If there's a local .env file, it will be used to override the environment vars
config: dict = load_cfg()
config["DB_PRODUCTION_NAME"] = DB_PROD_NAME
config["DB_PLAYBACK_NAME"] = DB_PLAY_NAME
yield config


@pytest.fixture
def testing_database():
username = os.environ.get("POSTGRES_USERNAME", "postgres")
password = os.environ.get("POSTGRES_PASSWORD", "postgres")
host = os.environ.get("POSTGRES_HOST", "127.0.0.1")
port = os.environ.get("POSTGRES_PORT", 5432)
port = int(os.environ.get("POSTGRES_PORT", 5432))
db_name = "postgres_test"
db_uri = DQM2MirrorDB.format_db_uri(
username=username, password=password, host=host, port=port, db_name=db_name
)

db = get_or_create_db(
db: DQM2MirrorDB = get_or_create_db(
db_uri=db_uri,
username=username,
password=password,
Expand All @@ -123,30 +141,29 @@ def testing_databases():
username = os.environ.get("POSTGRES_USERNAME", "postgres")
password = os.environ.get("POSTGRES_PASSWORD", "postgres")
host = os.environ.get("POSTGRES_HOST", "127.0.0.1")
port = os.environ.get("POSTGRES_PORT", 5432)
db_prod_name = "postgres_production_test"
db_play_name = "postgres_playback_test"
db_prod_uri = DQM2MirrorDB.format_db_uri(
username=username, password=password, host=host, port=port, db_name=db_prod_name
port: int = int(os.environ.get("POSTGRES_PORT", 5432))

db_prod_uri: str = DQM2MirrorDB.format_db_uri(
username=username, password=password, host=host, port=port, db_name=DB_PROD_NAME
)
db_play_uri = DQM2MirrorDB.format_db_uri(
username=username, password=password, host=host, port=port, db_name=db_play_name
db_play_uri: str = DQM2MirrorDB.format_db_uri(
username=username, password=password, host=host, port=port, db_name=DB_PLAY_NAME
)
db_prod = get_or_create_db(
db_prod: DQM2MirrorDB = get_or_create_db(
db_uri=db_prod_uri,
username=username,
password=password,
host=host,
port=port,
db_name=db_prod_name,
db_name=DB_PROD_NAME,
)
db_play = get_or_create_db(
db_play: DQM2MirrorDB = get_or_create_db(
db_uri=db_play_uri,
username=username,
password=password,
host=host,
port=port,
db_name=db_play_name,
db_name=DB_PLAY_NAME,
)
fill_db(db_prod)
fill_db(db_play)
Expand All @@ -156,8 +173,8 @@ def testing_databases():


@pytest.fixture
def app(cfg, testing_databases: tuple[DQM2MirrorDB, DQM2MirrorDB]):
app: Flask = server.create_app(cfg)
def app(dqm2_config: dict, testing_databases: tuple[DQM2MirrorDB, DQM2MirrorDB]):
app: Flask = server.create_app(dqm2_config)
app.config.update(
{
"TESTING": True,
Expand Down
16 changes: 12 additions & 4 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# Tests to check Flask server and its responses
# Needed to be passed as a fixture
from fixtures import app, cfg, client, testing_databases, testing_database
from fixtures import app, dqm2_config, client, testing_databases, testing_database


def test_server_home(client: FlaskClient):
Expand Down Expand Up @@ -39,9 +39,16 @@ def test_server_get_run(client: FlaskClient):
assert response.status_code == 200
response = json.loads(response.text)
assert isinstance(response, list)
assert response[0][0] == 377187
assert response[0][0] == 358792
assert response[0][1] == None

response = client.get("/api?what=get_run&run=358791&db=production")
assert response.status_code == 200
response = json.loads(response.text)
assert isinstance(response, list)
assert response[0][0] == 358788
assert response[0][1] == 358792


def test_server_get_clients(client: FlaskClient):
# Wrong GET argument
Expand All @@ -57,9 +64,10 @@ def test_server_get_clients(client: FlaskClient):
assert response.status_code == 400

# Proper request
response = client.get("/api?what=get_clients&from=377187&to=381574&db=production")
response = client.get("/api?what=get_clients&from=358788&to=358792&db=production")
assert response.status_code == 200
response = json.loads(response.text)
print(response)
assert isinstance(response, list)
assert response[0] == "beam"
assert response[-1] == "visualization-live"
assert response[-1] == "visualization-live-secondInstance"

0 comments on commit a5b10cc

Please sign in to comment.