diff --git a/.dockerignore b/.dockerignore index fb78557..201f34e 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,5 +1,14 @@ **/*.log **/*.db **/*.pdf +.env* +.venv +.git +.pytest_cache +.vscode tmp __pycache__ +tests +docs +log +other diff --git a/Dockerfile b/Dockerfile index f3671bf..6149062 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,9 +12,11 @@ ENV TZ="Europe/Zurich" RUN apt update \ && apt upgrade -y \ && apt-get update --fix-missing \ - && apt install -y python3-psycopg2 python3-pip \ - && python3 -m pip install -r requirements.txt \ - && useradd dqm \ + && apt install -y python3-psycopg2 python3-pip + +RUN python3 -m pip install -r requirements.txt + +RUN useradd dqm \ && mkdir -p /dqmsquare_mirror/log \ && mkdir -p /home/dqm/ \ && chown -R dqm /home/dqm/ \ diff --git a/db.py b/db.py index 11c3306..e7aa3db 100644 --- a/db.py +++ b/db.py @@ -103,18 +103,43 @@ class DQM2MirrorDB: def __str__(self): return f"{self.__class__.__name__}: {self.db_uri}" - def __init__(self, log: logging.Logger, db_uri: str = None, server: bool = False): + def __init__( + self, + log: logging.Logger, + username: str = "postgres", + password: str = "postgres", + host: str = "postgres", + port: int = 5432, + db_name: str = "postgres", + server: bool = False, + ): """ The server flag will determine if table creation will take place or not, upon initialization. """ - self.log = log + self.password: str = password + self.username: str = username + self.host: str = host + self.port: int = port + self.db_name: str = db_name + + self.log: logging.Logger = log self.log.info("\n\n DQM2MirrorDB ===== init ") - self.db_uri = db_uri + self.db_uri: str = self.format_db_uri( + host=self.host, + port=self.port, + username=self.username, + password=self.password, + db_name=self.db_name, + ) - if not self.db_uri: + if self.host == ":memory:": self.db_uri = ":memory:" + self.log.info( + f"Connecting to database {self.db_name} on {self.username}@{self.host}:{self.port}" + ) + self.engine = sqlalchemy.create_engine( url=self.db_uri, poolclass=sqlalchemy.pool.QueuePool, @@ -123,7 +148,7 @@ def __init__(self, log: logging.Logger, db_uri: str = None, server: bool = False ) if not database_exists(self.engine.url): raise DatabaseNotFoundError( - f"Database name was not found when connecting to '{self.db_uri}'" + f"Database {self.db_name} was not found on '{self.host}:{self.port}'" ) self.Session = sessionmaker(bind=self.engine) @@ -132,6 +157,19 @@ def __init__(self, log: logging.Logger, db_uri: str = None, server: bool = False self.db_meta = sqlalchemy.MetaData(bind=self.engine) self.db_meta.reflect() + @staticmethod + def format_db_uri( + username: str = "postgres", + password: str = "postgres", + host: str = "postgres", + port: int = 5432, + db_name="postgres", + ) -> str: + """ + Helper function to format the DB URI for SQLAclhemy + """ + return f"postgresql://{username}:{password}@{host}:{port}/{db_name}" + def create_tables(self): """ Initialize the databases diff --git a/dqmsquare_cfg.py b/dqmsquare_cfg.py index abe5423..843b569 100644 --- a/dqmsquare_cfg.py +++ b/dqmsquare_cfg.py @@ -17,41 +17,31 @@ TZ = pytz.timezone(TIMEZONE) -def format_db_uri( - username: str = "postgres", - password: str = "postgres", - host: str = "postgres", - port: int = 5432, - db_name="postgres", -) -> str: - """ - Helper function to format the DB URI for SQLAclhemy - """ - return f"postgresql://{username}:{password}@{host}:{port}/{db_name}" - - def load_cfg() -> dict: """ Prepare configuration, using .env file """ load_dotenv() + # No leading slash: cinder/dqmsquare mount_path = os.path.join("cinder", "dqmsquare") - ### default values === > + ### default values cfg = {} - cfg["VERSION"] = "1.3.1" + cfg["VERSION"] = "1.3.2" cfg["ENV"] = os.environ.get("ENV", "development") # How often to try to get CMSSW jobs info # sec, int - cfg["GRABBER_SLEEP_TIME_INFO"] = os.environ.get("GRABBER_SLEEP_TIME_INFO", 5) + cfg["GRABBER_SLEEP_TIME_INFO"] = int(os.environ.get("GRABBER_SLEEP_TIME_INFO", 5)) # How often to ping the cluster machines for their status. # Keep it above 30 secs. # sec, int - cfg["GRABBER_SLEEP_TIME_STATUS"] = os.environ.get("GRABBER_SLEEP_TIME_STATUS", 30) + cfg["GRABBER_SLEEP_TIME_STATUS"] = int( + os.environ.get("GRABBER_SLEEP_TIME_STATUS", 30) + ) cfg["LOGGER_ROTATION_TIME"] = 24 # h, int cfg["LOGGER_MAX_N_LOG_FILES"] = 10 # int @@ -61,7 +51,7 @@ def load_cfg() -> dict: cfg["FFF_PORT"] = "9215" # Flask server config - cfg["SERVER_DEBUG"] = os.environ.get("SERVER_DEBUG", False) + cfg["SERVER_DEBUG"] = bool(os.environ.get("SERVER_DEBUG", False)) # MACHETE if isinstance(cfg["SERVER_DEBUG"], str): cfg["SERVER_DEBUG"] = True if cfg["SERVER_DEBUG"] == "True" else False @@ -89,15 +79,15 @@ def load_cfg() -> dict: "CMSWEB_FRONTEND_PROXY_URL", # If value is not found in .env ( - "https://cmsweb-testbed.cern.ch/dqm/dqm-square-origin-rubu" + "https://cmsweb.cern.ch/dqm/dqm-square-origin" if cfg["ENV"] == "testbed" else ( - "https://cmsweb-testbed.cern.ch/dqm/dqm-square-origin-rubu" + "https://cmsweb.cern.ch/dqm/dqm-square-origin" if cfg["ENV"] == "production" else ( - "https://cmsweb-testbed.cern.ch/dqm/dqm-square-origin-rubu" + "https://cmsweb.cern.ch/dqm/dqm-square-origin" if cfg["ENV"] == "test4" - else "https://cmsweb-testbed.cern.ch/dqm/dqm-square-origin-rubu" + else "https://cmsweb.cern.ch/dqm/dqm-square-origin" ) ) ), @@ -159,20 +149,20 @@ def load_cfg() -> dict: if isinstance(cfg["GRABBER_DEBUG"], str): cfg["GRABBER_DEBUG"] = True if cfg["GRABBER_DEBUG"] == "True" else False - cfg["DB_PLAYBACK_URI"] = format_db_uri( - username=os.environ.get("POSTGRES_USERNAME", "postgres"), - password=os.environ.get("POSTGRES_PASSWORD", "postgres"), - host=os.environ.get("POSTGRES_HOST", "127.0.0.1"), - port=os.environ.get("POSTGRES_PORT", 5432), - db_name=os.environ.get("POSTGRES_PLAYBACK_DB_NAME", "postgres"), - ) - cfg["DB_PRODUCTION_URI"] = format_db_uri( - username=os.environ.get("POSTGRES_USERNAME", "postgres"), - password=os.environ.get("POSTGRES_PASSWORD", "postgres"), - host=os.environ.get("POSTGRES_HOST", "127.0.0.1"), - port=os.environ.get("POSTGRES_PORT", 5432), - db_name=os.environ.get("POSTGRES_PRODUCTION_DB_NAME", "postgres_production"), + cfg["DB_PLAYBACK_USERNAME"] = os.environ.get("POSTGRES_USERNAME", "postgres") + cfg["DB_PLAYBACK_PASSWORD"] = os.environ.get("POSTGRES_PASSWORD", "postgres") + cfg["DB_PLAYBACK_HOST"] = os.environ.get("POSTGRES_HOST", "127.0.0.1") + cfg["DB_PLAYBACK_PORT"] = os.environ.get("POSTGRES_PORT", 5432) + cfg["DB_PLAYBACK_NAME"] = os.environ.get("POSTGRES_PLAYBACK_DB_NAME", "postgres") + + cfg["DB_PRODUCTION_USERNAME"] = os.environ.get("POSTGRES_USERNAME", "postgres") + cfg["DB_PRODUCTION_PASSWORD"] = os.environ.get("POSTGRES_PASSWORD", "postgres") + cfg["DB_PRODUCTION_HOST"] = os.environ.get("POSTGRES_HOST", "127.0.0.1") + cfg["DB_PRODUCTION_PORT"] = os.environ.get("POSTGRES_PORT", 5432) + cfg["DB_PRODUCTION_NAME"] = os.environ.get( + "POSTGRES_PRODUCTION_DB_NAME", "postgres_production" ) + cfg["TIMEZONE"] = TIMEZONE return cfg diff --git a/dqmsquare_report.py b/dqmsquare_report.py index 63f85a6..c53dcaf 100644 --- a/dqmsquare_report.py +++ b/dqmsquare_report.py @@ -20,11 +20,15 @@ mode = ["p7"] ### DQM^2-MIRROR DB - DB_PLAYBACK_URI = DQM2MirrorDB(log, cfg["DB_PLAYBACK_URI"], server=True) - db_production = DQM2MirrorDB(log, cfg["DB_PRODUCTION_URI"], server=True) - - ### - db = db_production + db = DQM2MirrorDB( + log=log, + host=cfg.get("DB_PRODUCTION_HOST"), + port=cfg.get("DB_PRODUCTION_PORT"), + username=cfg.get("DB_PRODUCTION_USERNAME"), + password=cfg.get("DB_PRODUCTION_PASSWORD"), + db_name=cfg.get("DB_PRODUCTION_NAME"), + server=True, + ) start_date = "01/06/2022" end_date = "01/09/2022" diff --git a/grabber.py b/grabber.py index 55a9497..48af517 100644 --- a/grabber.py +++ b/grabber.py @@ -145,6 +145,7 @@ def get_cluster_status(db: DQM2MirrorDB, cluster: str = "playback"): Function that queries the gateway playback machine periodically to get the status of the production or playback cluster machines. """ + logger.debug(f"Requesting {cluster} cluster status.") url = urljoin( cfg["CMSWEB_FRONTEND_PROXY_URL"] + "/", "cr/exe?" + urlencode({"cluster": cluster, "what": "get_cluster_status"}), @@ -163,6 +164,7 @@ def get_cluster_status(db: DQM2MirrorDB, cluster: str = "playback"): raise Exception( f"Failed to fetch {cluster} status. Got ({response.status_code}) {response.text}" ) + logger.debug(f"Got {cluster} cluster status.") try: response = response.json() @@ -184,9 +186,9 @@ def get_latest_info_from_hosts(hosts: list[str], db: DQM2MirrorDB) -> None: if __name__ == "__main__": - run_modes = ["playback", "production"] - playback_machines = cfg["FFF_PLAYBACK_MACHINES"] - production_machines = cfg["FFF_PRODUCTION_MACHINES"] + run_modes: list[str] = ["playback", "production"] + playback_machines: list[str] = cfg["FFF_PLAYBACK_MACHINES"] + production_machines: list[str] = cfg["FFF_PRODUCTION_MACHINES"] if len(sys.argv) > 1 and sys.argv[1] == "playback": set_log_handler( @@ -224,10 +226,10 @@ def get_latest_info_from_hosts(hosts: list[str], db: DQM2MirrorDB) -> None: logger.info(f"Configured logger for grabber, level={level}") ### global variables and auth cookies - cmsweb_proxy_url = cfg["CMSWEB_FRONTEND_PROXY_URL"] - cert_path = [cfg["SERVER_GRID_CERT_PATH"], cfg["SERVER_GRID_KEY_PATH"]] + cmsweb_proxy_url: str = cfg["CMSWEB_FRONTEND_PROXY_URL"] + cert_path: list[str] = [cfg["SERVER_GRID_CERT_PATH"], cfg["SERVER_GRID_KEY_PATH"]] - env_secret = os.environ.get("DQM_FFF_SECRET") + env_secret: str = os.environ.get("DQM_FFF_SECRET") if env_secret: fff_secret = env_secret logger.debug("Found secret in environmental variables") @@ -235,14 +237,30 @@ def get_latest_info_from_hosts(hosts: list[str], db: DQM2MirrorDB) -> None: logger.warning("No secret found in environmental variables") # Trailing whitespace in secret leads to crashes, strip it - cookies = {str(cfg["FFF_SECRET_NAME"]): env_secret.strip()} + cookies: dict[str, str] = {str(cfg["FFF_SECRET_NAME"]): env_secret.strip()} # DB CONNECTION - db_playback, db_production = None, None + db_playback: DQM2MirrorDB = None + db_production: DQM2MirrorDB = None + if "playback" in run_modes: - db_playback = DQM2MirrorDB(logger, cfg["DB_PLAYBACK_URI"]) + db_playback = DQM2MirrorDB( + log=logger, + host=cfg.get("DB_PLAYBACK_HOST"), + port=cfg.get("DB_PLAYBACK_PORT"), + username=cfg.get("DB_PLAYBACK_USERNAME"), + password=cfg.get("DB_PLAYBACK_PASSWORD"), + db_name=cfg.get("DB_PLAYBACK_NAME"), + ) if "production" in run_modes: - db_production = DQM2MirrorDB(logger, cfg["DB_PRODUCTION_URI"]) + db_production = DQM2MirrorDB( + log=logger, + host=cfg.get("DB_PRODUCTION_HOST"), + port=cfg.get("DB_PRODUCTION_PORT"), + username=cfg.get("DB_PRODUCTION_USERNAME"), + password=cfg.get("DB_PRODUCTION_PASSWORD"), + db_name=cfg.get("DB_PRODUCTION_NAME"), + ) logger.info("Starting loop for modes " + str(run_modes)) diff --git a/server.py b/server.py index f6a2ee8..7fa35cc 100644 --- a/server.py +++ b/server.py @@ -17,9 +17,10 @@ log = logging.getLogger(__name__) log.info("start_server() call ... ") +VALID_DATABASE_OPTIONS = ["playback", "production"] -def create_app(cfg): +def create_app(cfg: dict): app = Flask( __name__, static_url_path=os.path.join("/", cfg["SERVER_URL_PREFIX"], "static") ) @@ -56,8 +57,24 @@ def create_app(cfg): ).strip() } - db_playback = DQM2MirrorDB(log, cfg["DB_PLAYBACK_URI"], server=True) - db_production = DQM2MirrorDB(log, cfg["DB_PRODUCTION_URI"], server=True) + db_playback = DQM2MirrorDB( + log=log, + host=cfg.get("DB_PLAYBACK_HOST"), + port=cfg.get("DB_PLAYBACK_PORT"), + username=cfg.get("DB_PLAYBACK_USERNAME"), + password=cfg.get("DB_PLAYBACK_PASSWORD"), + db_name=cfg.get("DB_PLAYBACK_NAME"), + server=True, + ) + db_production = DQM2MirrorDB( + log=log, + host=cfg.get("DB_PRODUCTION_HOST"), + port=cfg.get("DB_PRODUCTION_PORT"), + username=cfg.get("DB_PRODUCTION_USERNAME"), + password=cfg.get("DB_PRODUCTION_PASSWORD"), + db_name=cfg.get("DB_PRODUCTION_NAME"), + server=True, + ) databases = { "playback": db_playback, "production": db_production, @@ -138,28 +155,41 @@ def dqm2_api(): """ Get data from DQM^2 Mirror's Databases. """ - log.info(flask.request.base_url) what = flask.request.args.get("what") - if what == "get_run": - run = flask.request.args.get("run", default=0) - db_name = flask.request.args.get("db", default="") + try: + run = int(flask.request.args.get("run", type=int)) + except (ValueError, TypeError): + return f"run must be an integer", 400 + db_name = flask.request.args.get("db", type=str) + if db_name not in VALID_DATABASE_OPTIONS: + return f"db must be one of {VALID_DATABASE_OPTIONS}", 400 db_ = databases.get(db_name, db_playback) run_data = db_.get_mirror_data(run) runs_around = db_.get_runs_around(run) return json.dumps([runs_around, run_data]) elif what == "get_graph": - run = flask.request.args.get("run", default=0) - db_name = flask.request.args.get("db", default="") + try: + run = int(flask.request.args.get("run", type=int)) + except (ValueError, TypeError): + return f"run must be an integer", 400 + db_name = flask.request.args.get("db", type=str) + if db_name not in VALID_DATABASE_OPTIONS: + return f"db must be one of {VALID_DATABASE_OPTIONS}", 400 db_ = databases.get(db_name, db_playback) graph_data = db_.get_graphs_data(run) return json.dumps(graph_data) elif what == "get_runs": - run_from = flask.request.args.get("from", default=0) - run_to = flask.request.args.get("to", default=0) - bad_only = flask.request.args.get("bad_only", default=0) - with_ls_only = flask.request.args.get("ls", default=0) - db_name = flask.request.args.get("db", default="") + try: + run_from = int(flask.request.args.get("from", type=int)) + run_to = int(flask.request.args.get("to", type=int)) + bad_only = int(flask.request.args.get("bad_only", type=int)) + with_ls_only = int(flask.request.args.get("ls", type=int)) + except (ValueError, TypeError): + return (f"to, from, bad_only and ls must be integers", 400) + db_name = flask.request.args.get("db", default="", type=str) + if db_name not in VALID_DATABASE_OPTIONS: + return f"db must be one of {VALID_DATABASE_OPTIONS}", 400 db_ = databases.get(db_name, db_playback) answer = db_.get_timeline_data( min(run_from, run_to), @@ -169,20 +199,32 @@ def dqm2_api(): ) return json.dumps(answer) elif what == "get_clients": - run_from = flask.request.args.get("from", default=0) - run_to = flask.request.args.get("to", default=0) - db_name = flask.request.args.get("db", default="") + try: + run_from = int(flask.request.args.get("from", type=int)) + run_to = int(flask.request.args.get("to", type=int)) + except (ValueError, TypeError): + return (f"to, from must be integers", 400) + db_name = flask.request.args.get("db", type=str) + if db_name not in VALID_DATABASE_OPTIONS: + return f"db must be one of {VALID_DATABASE_OPTIONS}", 400 db_ = databases.get(db_name, db_playback) answer = db_.get_clients(run_from, run_to) return json.dumps(answer) elif what == "get_info": - db_name = flask.request.args.get("db", default="") + db_name = flask.request.args.get("db", type=str) + if db_name not in VALID_DATABASE_OPTIONS: + return f"db must be one of {VALID_DATABASE_OPTIONS}", 400 db_ = databases.get(db_name, db_playback) answer = db_.get_info() return json.dumps(answer) elif what == "get_logs": - client_id = flask.request.args.get("id", default=0) - db_name = flask.request.args.get("db", default="") + try: + client_id = int(flask.request.args.get("id", type=int)) + except (ValueError, TypeError): + return (f"client_id must be an integer", 400) + db_name = flask.request.args.get("db", type=str) + if db_name not in VALID_DATABASE_OPTIONS: + return f"db must be one of {VALID_DATABASE_OPTIONS}", 400 db_ = databases.get(db_name, db_playback) answer = db_.get_logs(client_id) a1 = a2 = "" @@ -192,11 +234,14 @@ def dqm2_api(): a2 = "".join(eval(answer[1])) return "
" + a1 + "\n ... \n\n" + a2 + "" elif what == "get_cluster_status": - # WIP - cluster = flask.request.args.get("cluster", default="playback") + cluster = flask.request.args.get("cluster", default="playback", type=str) + if cluster not in VALID_DATABASE_OPTIONS: + return f"cluster must be one of {VALID_DATABASE_OPTIONS}", 400 db_ = databases.get(cluster, db_playback) answer = db_.get_cluster_status() return json.dumps(answer) + else: + return f"{what} is not supported", 404 ### TIMELINE ### @app.route(os.path.join("/", cfg["SERVER_URL_PREFIX"], "timeline/")) diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 0000000..9f92f9e --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,190 @@ +import os +import sys +import pickle +import pytest +import sqlalchemy +from sqlalchemy import create_engine, text +from sqlalchemy_utils import create_database, database_exists, drop_database +from sqlalchemy.exc import IntegrityError +from flask import Flask + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +import server +from custom_logger import dummy_log +from db import DQM2MirrorDB +from dqmsquare_cfg import load_cfg + +DB_PROD_NAME = "postgres_production_test" +DB_PLAY_NAME = "postgres_playback_test" + + +def get_or_create_db( + db_uri: str, username: str, password: str, host: str, port: int, db_name: str +): + engine: sqlalchemy.engine.Engine = create_engine(db_uri) + if not database_exists(engine.url): + print(f"Creating database {db_uri}") + create_database(db_uri) + return DQM2MirrorDB( + log=dummy_log(), + username=username, + password=password, + host=host, + port=port, + db_name=db_name, + server=False, + ) + + +def format_entry_to_db_entry(graph_entry: list, datetime_cols: list): + return_value = "" + for i, col in enumerate(graph_entry): + if i in datetime_cols: + return_value += f"'{col.isoformat()}', " + else: + return_value += ( + f"""{"'" + str(col).replace("'", "''") + "'" if col else 'NULL'}, """ + ) + + return return_value[:-2] + + +def fill_db(db: DQM2MirrorDB) -> None: + runs = [] + graphs = [] + runs_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "runs_data.pkl" + ) + graphs_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "graphs_data.pkl" + ) + + print(f"Filling database {db.db_uri} using files {runs_file}, {graphs_file}") + with open(runs_file, "rb") as f: + runs = pickle.load(f) + with open(graphs_file, "rb") as f: + graphs = pickle.load(f) + + with db.engine.connect() as cur: + session = db.Session(bind=cur) + for run in runs: + try: + session.execute( + text( + "INSERT into runs " + + f"""({str(db.TB_DESCRIPTION_RUNS_COLS).replace("[", "").replace("]", "").replace("'", "")}) """ + + f"""VALUES ({format_entry_to_db_entry(run, [13])})""" + ) + ) + session.commit() + except IntegrityError as e: + print(f"Skipping already inserted data", e) + continue + except Exception as e: + print("Error when creating run fixture:", e) + session.rollback() + + for graph in graphs: + try: + session.execute( + text( + "INSERT into graphs " + + f"""({str(db.TB_DESCRIPTION_GRAPHS_COLS).replace("[", "").replace("]", "").replace("'", "")}) """ + + f"""VALUES ({format_entry_to_db_entry(graph, [3, 4])})""" + ) + ) + session.commit() + except IntegrityError as e: + print(f"Skipping already inserted data") + continue + except Exception as e: + print("Error when creating graph fixture:", e) + session.rollback() + + +@pytest.fixture +def dqm2_config(): + # If there's a local .env file, it will be used to override the environment vars + config: dict = load_cfg() + config["DB_PRODUCTION_NAME"] = DB_PROD_NAME + config["DB_PLAYBACK_NAME"] = DB_PLAY_NAME + yield config + + +@pytest.fixture +def testing_database(): + username = os.environ.get("POSTGRES_USERNAME", "postgres") + password = os.environ.get("POSTGRES_PASSWORD", "postgres") + host = os.environ.get("POSTGRES_HOST", "127.0.0.1") + port = int(os.environ.get("POSTGRES_PORT", 5432)) + db_name = "postgres_test" + db_uri = DQM2MirrorDB.format_db_uri( + username=username, password=password, host=host, port=port, db_name=db_name + ) + + db: DQM2MirrorDB = get_or_create_db( + db_uri=db_uri, + username=username, + password=password, + host=host, + port=port, + db_name=db_name, + ) + + fill_db(db) + yield db + drop_database(db_uri) + + +@pytest.fixture +def testing_databases(): + username = os.environ.get("POSTGRES_USERNAME", "postgres") + password = os.environ.get("POSTGRES_PASSWORD", "postgres") + host = os.environ.get("POSTGRES_HOST", "127.0.0.1") + port: int = int(os.environ.get("POSTGRES_PORT", 5432)) + + db_prod_uri: str = DQM2MirrorDB.format_db_uri( + username=username, password=password, host=host, port=port, db_name=DB_PROD_NAME + ) + db_play_uri: str = DQM2MirrorDB.format_db_uri( + username=username, password=password, host=host, port=port, db_name=DB_PLAY_NAME + ) + db_prod: DQM2MirrorDB = get_or_create_db( + db_uri=db_prod_uri, + username=username, + password=password, + host=host, + port=port, + db_name=DB_PROD_NAME, + ) + db_play: DQM2MirrorDB = get_or_create_db( + db_uri=db_play_uri, + username=username, + password=password, + host=host, + port=port, + db_name=DB_PLAY_NAME, + ) + fill_db(db_prod) + fill_db(db_play) + yield db_prod, db_play + drop_database(db_prod_uri) + drop_database(db_play_uri) + + +@pytest.fixture +def app(dqm2_config: dict, testing_databases: tuple[DQM2MirrorDB, DQM2MirrorDB]): + app: Flask = server.create_app(dqm2_config) + app.config.update( + { + "TESTING": True, + } + ) + yield app + + +@pytest.fixture +def client(app: Flask): + """Returns a FlaskClient (An instance of :class:`flask.testing.TestClient`)""" + with app.test_client() as client: + yield client diff --git a/tests/test_2.py b/tests/test_2.py deleted file mode 100644 index 85dd676..0000000 --- a/tests/test_2.py +++ /dev/null @@ -1,95 +0,0 @@ -# tests to check Flask server -# TODO create tests using testing DB -import os -import sys - -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) - -import pytest -from sqlalchemy import create_engine -from sqlalchemy_utils import create_database, database_exists, drop_database -import server -from db import DQM2MirrorDB -from custom_logger import dummy_log -from dqmsquare_cfg import format_db_uri, load_cfg - - -def get_or_create_db(db_uri: str): - engine = create_engine(db_uri) - if not database_exists(engine.url): - create_database(db_uri) - return DQM2MirrorDB( - log=dummy_log(), - db_uri=db_uri, - server=False, - ) - - -@pytest.fixture -def cfg(): - yield load_cfg() - - -@pytest.fixture -def testing_databases(): - db_uri_prod = format_db_uri( - username=os.environ.get("POSTGRES_USERNAME", "postgres"), - password=os.environ.get("POSTGRES_PASSWORD", "postgres"), - host=os.environ.get("POSTGRES_HOST", "127.0.0.1"), - port=os.environ.get("POSTGRES_PORT", 5432), - db_name=os.environ.get("POSTGRES_PRODUCTION_DB_NAME") + "_test", - ) - db_uri_playback = format_db_uri( - username=os.environ.get("POSTGRES_USERNAME", "postgres"), - password=os.environ.get("POSTGRES_PASSWORD", "postgres"), - host=os.environ.get("POSTGRES_HOST", "127.0.0.1"), - port=os.environ.get("POSTGRES_PORT", 5432), - db_name=os.environ.get("POSTGRES_PLAYBACK_DB_NAME") + "_test", - ) - - yield get_or_create_db(db_uri_prod), get_or_create_db(db_uri_playback) - drop_database(db_uri_prod) - drop_database(db_uri_playback) - - -@pytest.fixture -def app(cfg, testing_databases: list[DQM2MirrorDB]): - cfg_test = cfg - # Override databases for the test - cfg_test["DB_PRODUCTION_URI"] = testing_databases[0].db_uri - cfg_test["DB_PLAYBACK_URI"] = testing_databases[1].db_uri - print(cfg_test["DB_PRODUCTION_URI"], cfg_test["DB_PLAYBACK_URI"]) - - app = server.create_app(cfg_test) - app.config.update( - { - "TESTING": True, - } - ) - yield app - - -@pytest.fixture -def client(app): - """A Flask test client. An instance of :class:`flask.testing.TestClient` - by default. - """ - with app.test_client() as client: - yield client - - -def test_server_1(client): - response = client.get("/") - assert response.status_code == 200 - - -def test_server_2(client): - response = client.get("/timeline/") - # assert b"// DQM TIMELINE PAGE //" in response.data - assert response.status_code == 200 - - -def test_server_3(client): - response = client.get("/cr/") - assert response.status_code == 302 - assert response.headers.get("Location") == "/cr/login/" diff --git a/tests/test_1.py b/tests/test_db.py similarity index 71% rename from tests/test_1.py rename to tests/test_db.py index dbecea8..60923f8 100644 --- a/tests/test_1.py +++ b/tests/test_db.py @@ -1,91 +1,18 @@ -# tests to check DB +# Tests to check DB and data filling and fetching import os import sys import json sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -import pickle import pytest from truth_values import TEST_DB_7_TRUTH, TEST_DB_9_TRUTH from datetime import datetime from db import DQM2MirrorDB, DEFAULT_DATETIME -from sqlalchemy import create_engine, text -from sqlalchemy_utils import create_database, database_exists, drop_database -from custom_logger import dummy_log -from dqmsquare_cfg import format_db_uri, TZ +from sqlalchemy import text +from dqmsquare_cfg import TZ - -def format_entry_to_db_entry(graph_entry: list, datetime_cols: list): - return_value = "" - for i, col in enumerate(graph_entry): - if i in datetime_cols: - return_value += f"'{col.isoformat()}', " - else: - return_value += ( - f"""{"'" + str(col).replace("'", "''") + "'" if col else 'NULL'}, """ - ) - - return return_value[:-2] - - -@pytest.fixture -def testing_database() -> DQM2MirrorDB: - db_uri = format_db_uri( - username=os.environ.get("POSTGRES_USERNAME", "postgres"), - password=os.environ.get("POSTGRES_PASSWORD", "postgres"), - host=os.environ.get("POSTGRES_HOST", "127.0.0.1"), - port=os.environ.get("POSTGRES_PORT", 5432), - db_name="postgres_test", - ) - - engine = create_engine(db_uri) - if not database_exists(engine.url): - create_database(db_uri) - db = DQM2MirrorDB( - log=dummy_log(), - db_uri=db_uri, - server=False, - ) - - runs = [] - graphs = [] - with open( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "runs_data.pkl"), "rb" - ) as f: - runs = pickle.load(f) - with open( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "graphs_data.pkl"), - "rb", - ) as f: - graphs = pickle.load(f) - - with db.engine.connect() as cur: - session = db.Session(bind=cur) - for run in runs: - try: - session.execute( - text( - f"""INSERT into runs ({str(db.TB_DESCRIPTION_RUNS_COLS).replace("[", "").replace("]", "").replace("'", "")}) VALUES ({format_entry_to_db_entry(run, [13])})""" - ) - ) - session.commit() - except Exception as e: - print("Error when creating run fixture:", e) - session.rollback() - - for graph in graphs: - try: - session.execute( - text( - f"""INSERT into graphs ({str(db.TB_DESCRIPTION_GRAPHS_COLS).replace("[", "").replace("]", "").replace("'", "")}) VALUES ({format_entry_to_db_entry(graph, [3, 4])})""" - ) - ) - session.commit() - except Exception as e: - print("Error when creating graph fixture:", e) - session.rollback() - yield db - drop_database(db_uri) +# Needed to be passed as a fixture +from fixtures import testing_database pytest.production = [ diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..eb6e2a2 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,73 @@ +import json +from flask.testing import FlaskClient + +# Tests to check Flask server and its responses +# Needed to be passed as a fixture +from fixtures import app, dqm2_config, client, testing_databases, testing_database + + +def test_server_home(client: FlaskClient): + response = client.get("/") + assert response.status_code == 200 + + +def test_server_timeline(client: FlaskClient): + response = client.get("/timeline/") + # assert b"// DQM TIMELINE PAGE //" in response.data + assert response.status_code == 200 + + +def test_server_cr(client: FlaskClient): + response = client.get("/cr/") + assert response.status_code == 302 + assert response.headers.get("Location") == "/cr/login/" + + +def test_server_get_run(client: FlaskClient): + # Wrong GET argument + response = client.get("/api?what=get_run&run=381574&db_name=production") + assert response.status_code == 400 + + # Dumb string instead of number + response = client.get( + "/api?what=get_run&run=;DROP DATABASE postgres_test&db=production" + ) + assert response.status_code == 400 + + # Proper request + response = client.get("/api?what=get_run&run=381574&db=production") + assert response.status_code == 200 + response = json.loads(response.text) + assert isinstance(response, list) + assert response[0][0] == 358792 + assert response[0][1] == None + + response = client.get("/api?what=get_run&run=358791&db=production") + assert response.status_code == 200 + response = json.loads(response.text) + assert isinstance(response, list) + assert response[0][0] == 358788 + assert response[0][1] == 358792 + + +def test_server_get_clients(client: FlaskClient): + # Wrong GET argument + response = client.get( + "/api?what=get_clients&from=377187&to=381574&db_name=production" + ) + assert response.status_code == 400 + + # Dumb string instead of number. + response = client.get( + "/api?what=get_clients&client_id=;DROP DATABASE postgres_test&db=production" + ) + assert response.status_code == 400 + + # Proper request + response = client.get("/api?what=get_clients&from=358788&to=358792&db=production") + assert response.status_code == 200 + response = json.loads(response.text) + print(response) + assert isinstance(response, list) + assert response[0] == "beam" + assert response[-1] == "visualization-live-secondInstance"