Skip to content

Commit

Permalink
Properly use GET args in SQL query
Browse files Browse the repository at this point in the history
  • Loading branch information
nothingface0 committed Jun 12, 2024
1 parent b0d1692 commit bfa0d46
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 55 deletions.
96 changes: 55 additions & 41 deletions db.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,11 @@ def fill_graph(self, header: dict, document: dict) -> int:
self.log.debug("No 'extra' key found in document")
return

id = header.get("_id")
_id = header.get("_id")
run = header.get("run", None)
if not run:
self.log.warning(
"\n\n DQM2MirrorDB.fill_graph(): no 'run' for header id '%s'" % (id)
"\n\n DQM2MirrorDB.fill_graph(): no 'run' for header id '%s'" % (_id)
)
return

Expand Down Expand Up @@ -333,17 +333,18 @@ def fill_graph(self, header: dict, document: dict) -> int:
self.log.warning(
f"Could not parse {timestamp} as a timestamp. Error: '{repr(e)}'"
)
values = [run, rev, id, timestamp, global_start, stream_data, hostname]
values = [run, rev, _id, timestamp, global_start, stream_data, hostname]
values_dic = {}
for val, name in zip(values, self.TB_DESCRIPTION_GRAPHS_COLS):
values_dic[name] = val

with self.engine.connect() as cur:
session = self.Session(bind=cur)
try:
session.execute(
text(f"DELETE FROM {self.TB_NAME_GRAPHS} WHERE id = '{str(id)}';")
)
with self.engine.connect() as cur:
cur.execute(
text(f"DELETE FROM {self.TB_NAME_GRAPHS} WHERE id=:id;"), id=_id
)
session.execute(
sqlalchemy.insert(self.db_meta.tables[self.TB_NAME_GRAPHS]).values(
values_dic
Expand All @@ -361,12 +362,13 @@ def get_graphs_data(self, run: int) -> list:
"""
Load graph data for a specific run
"""
self.log.debug("DQM2MirrorDB.get_graphs_data() - " + str(run))
self.log.debug(f"DQM2MirrorDB.get_graphs_data() - Run {run}")
with self.engine.connect() as cur:
answer = cur.execute(
text(
f"SELECT * FROM {self.TB_NAME_GRAPHS} WHERE CAST(run as INTEGER) = {str(run)};"
)
f"SELECT * FROM {self.TB_NAME_GRAPHS} WHERE CAST(run as INTEGER)=:run;"
),
run=run,
).all()
if not len(answer):
return []
Expand Down Expand Up @@ -449,9 +451,10 @@ def fill_run(self, header: dict, document: dict) -> int:
with self.engine.connect() as cur:
session = self.Session(bind=cur)
try:
session.execute(
text(f"DELETE FROM {self.TB_NAME_RUNS} WHERE id = '{id}';")
)
with self.engine.connect() as cur:
cur.execute(
text(f"DELETE FROM {self.TB_NAME_RUNS} WHERE id=:id;"), id=id
)
session.execute(
sqlalchemy.insert(self.db_meta.tables[self.TB_NAME_RUNS]).values(
values_dic
Expand All @@ -472,9 +475,7 @@ def fill_run(self, header: dict, document: dict) -> int:
old_min_max = [999999999, -1]
with self.engine.connect() as cur:
answer = cur.execute(
text(
f"SELECT data FROM {self.TB_NAME_META} WHERE name = 'min_max_runs';"
)
text(f"SELECT data FROM {self.TB_NAME_META} WHERE name='min_max_runs';")
).all()
if answer:
old_min_max = eval(answer[0][0])
Expand Down Expand Up @@ -507,8 +508,9 @@ def fill_cluster_status(self, cluster_status: dict):
assert "msg" in status
result = cur.execute(
text(
f"SELECT id FROM {self.TB_NAME_HOST_NAME} WHERE name='{hostname}'"
)
f"SELECT id FROM {self.TB_NAME_HOST_NAME} WHERE name=:hostname"
),
hostname=hostname,
).all()

if len(result) == 0:
Expand All @@ -519,8 +521,9 @@ def fill_cluster_status(self, cluster_status: dict):
)
result = cur.execute(
text(
f"SELECT id FROM {self.TB_NAME_HOST_NAME} WHERE name='{hostname}'"
)
f"SELECT id FROM {self.TB_NAME_HOST_NAME} WHERE name=:hostname"
),
hostname=hostname,
).all()

host_id = result[0][0] # id of host in db
Expand Down Expand Up @@ -558,15 +561,18 @@ def get_run(
answer = cur.execute(
text(
f"SET TIMEZONE = '{TIMEZONE}'; SELECT {self.TB_DESCRIPTION_RUNS_COLS_NOLOGS} FROM {self.TB_NAME_RUNS} "
+ f"WHERE run = {run_start} {postfix} ORDER BY client, id;"
)
+ f"WHERE run=:run_start {postfix} ORDER BY client, id;"
),
run_start=run_start,
).all()
else:
answer = cur.execute(
text(
f"SET TIMEZONE = '{TIMEZONE}'; SELECT {self.TB_DESCRIPTION_RUNS_COLS_NOLOGS} FROM {self.TB_NAME_RUNS} "
+ f"WHERE run BETWEEN {run_start} AND {run_end} {postfix};"
)
+ f"WHERE run BETWEEN :run_start AND :run_end {postfix};"
),
run_start=run_start,
run_end=run_end,
).all()
self.log.debug(f"Read DB for runs {run_start}-{run_end}: {answer}")
return answer
Expand Down Expand Up @@ -619,7 +625,7 @@ def get_cluster_status(
with self.engine.connect() as cur:
answer = cur.execute(
text(
f"SELECT hostnames.name, hoststatuses.is_up, hoststatuses.message, max(hoststatuses.created_at) "
"SELECT hostnames.name, hoststatuses.is_up, hoststatuses.message, max(hoststatuses.created_at) "
+ f"FROM {self.TB_NAME_HOST_STATUS} "
+ f"INNER JOIN {self.TB_NAME_HOST_NAME} "
+ "ON hoststatuses.host_id = hostnames.id "
Expand All @@ -639,9 +645,10 @@ def get_clients(self, run_start: int, run_end: int) -> list:
with self.engine.connect() as cur:
answer = cur.execute(
text(
f"SELECT DISTINCT client FROM {self.TB_NAME_RUNS} "
+ f"WHERE run BETWEEN {run_start} AND {run_end} ORDER BY client;"
)
f"SELECT DISTINCT client FROM {self.TB_NAME_RUNS} WHERE run BETWEEN :run_start AND :run_end ORDER BY client;"
),
run_start=run_start,
run_end=run_end,
).all()
answer = [
get_short_client_name(name[0]) for name in answer if filter_clients(name[0])
Expand All @@ -661,12 +668,14 @@ def update_min_max(self, new_min: int, new_max: int):
f"DELETE FROM {self.TB_NAME_META} WHERE name = 'min_max_runs';"
)
)
session.execute(
text(
f"INSERT INTO {self.TB_NAME_META} {self.TB_DESCRIPTION_META_SHORT} "
+ f"VALUES('min_max_runs', '[{new_min},{new_max}]');"
with self.engine.connect() as cur:
cur.execute(
text(
f"INSERT INTO {self.TB_NAME_META} {self.TB_DESCRIPTION_META_SHORT} VALUES('min_max_runs', '[:min,:max]');"
),
min=new_min,
max=new_max,
)
)
session.commit()
except Exception as e:
self.log.error("Error occurred: ", e)
Expand Down Expand Up @@ -714,27 +723,30 @@ def get_latest_revision(self, host: str) -> int:
if "fu" in host:
answer = cur.execute(
text(
f"SELECT MAX(rev) FROM {self.TB_NAME_RUNS} WHERE hostname = '{host}';"
)
f"SELECT MAX(rev) FROM {self.TB_NAME_RUNS} WHERE hostname=:host;"
),
host=host,
).all()
answer = list(answer[0])
return answer[0]
else:
answer = cur.execute(
text(
f"SELECT MAX(rev) FROM {self.TB_NAME_GRAPHS} WHERE hostname = '{host}';"
)
f"SELECT MAX(rev) FROM {self.TB_NAME_GRAPHS} WHERE hostname=:host;"
),
host=host,
).all()
answer = list(answer[0])
return answer[0]

def get_logs(self, client_id: int) -> list:
def get_logs(self, client_id: str) -> list:
self.log.debug("DQM2MirrorDB.get_logs()")
with self.engine.connect() as cur:
answer = cur.execute(
text(
f"SELECT stdlog_start, stdlog_end FROM {self.TB_NAME_RUNS} WHERE id = '{client_id}';"
)
f"SELECT stdlog_start, stdlog_end FROM {self.TB_NAME_RUNS} WHERE id=:client_id;"
),
client_id=client_id,
).all()
if not answer:
answer = ["None", "None"]
Expand All @@ -751,8 +763,10 @@ def get_runs_around(self, run: int) -> list:
with self.engine.connect() as cur:
answer = cur.execute(
text(
f"SELECT min(run) FROM {self.TB_NAME_RUNS} WHERE run > {run} union SELECT max(run) FROM {self.TB_NAME_RUNS} WHERE run < {run};"
)
f"SELECT min(run) FROM {self.TB_NAME_RUNS} WHERE run>:run_number union "
+ f"SELECT max(run) FROM {self.TB_NAME_RUNS} WHERE run<:run_number;"
),
run_number=run,
).all()
answer = [item[0] for item in answer]
return answer
11 changes: 4 additions & 7 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def create_app(cfg: dict):
cr_usernames = {}
# Read in CR credentials from env var
try:
username, password = os.environ.get("DQM_CR_USERNAMES").split(":")
username, password = os.environ.get("DQM_CR_USERNAMES", "").split(":")
cr_usernames = {username: password}
except Exception as e:
log.error(
Expand Down Expand Up @@ -194,8 +194,8 @@ def dqm2_api():
answer = db_.get_timeline_data(
min(run_from, run_to),
max(run_from, run_to),
int(bad_only),
int(with_ls_only),
bad_only,
with_ls_only,
)
return json.dumps(answer)
elif what == "get_clients":
Expand All @@ -218,10 +218,7 @@ def dqm2_api():
answer = db_.get_info()
return json.dumps(answer)
elif what == "get_logs":
try:
client_id = int(flask.request.args.get("id", type=int))
except (ValueError, TypeError):
return (f"client_id must be an integer", 400)
client_id = flask.request.args.get("id", type=str)
db_name = flask.request.args.get("db", type=str)
if db_name not in VALID_DATABASE_OPTIONS:
return f"db must be one of {VALID_DATABASE_OPTIONS}", 400
Expand Down
1 change: 1 addition & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_db_10(testing_database: DQM2MirrorDB):
document = {"extra": {None: None}}
testing_database.fill_graph(header, document)
answer = testing_database.get_graphs_data(123456)
print(answer)
assert all([c1 == c2 for c1, c2 in zip(truth, answer)])


Expand Down
7 changes: 0 additions & 7 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,10 @@ def test_server_get_clients(client: FlaskClient):
)
assert response.status_code == 400

# Dumb string instead of number.
response = client.get(
"/api?what=get_clients&client_id=;DROP DATABASE postgres_test&db=production"
)
assert response.status_code == 400

# Proper request
response = client.get("/api?what=get_clients&from=358788&to=358792&db=production")
assert response.status_code == 200
response = json.loads(response.text)
print(response)
assert isinstance(response, list)
assert response[0] == "beam"
assert response[-1] == "visualization-live-secondInstance"

0 comments on commit bfa0d46

Please sign in to comment.