Skip to content

Commit

Permalink
Add CI check for incorrect catalog updates
Browse files Browse the repository at this point in the history
There are certain modifications/commands that should not be allowed
in our update/downgrade scripts. For example, when adding or dropping
columns to timescaledb catalog tables, the right way to do this is to
drop and recreate the table with the desired definition instead of doing
ALTER TABLE ... ADD/DROP COLUMN. This is required to ensure consistent
attribute numbers across versions.
This workflow detects this and some other incorrect catalog table
modifications and fails with an error in that case.

Fixes #6049
  • Loading branch information
konskov committed Oct 2, 2023
1 parent 646950f commit f505bd2
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 0 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/catalog-updates-check.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Check for unsafe catalog updates
"on":
pull_request:
types: [opened, synchronize, reopened, edited]
jobs:
check_catalog_correctly_updated:
name: Check updates to latest-dev and reverse-dev are properly handled by PR
runs-on: ubuntu-latest
steps:
- name: Checkout source
uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 0
- name: Install pglast
run: |
python -m pip install pglast
- name: Check latest-dev contents
shell: bash {0}
id: check_latestdev
run: |
if ! python scripts/check_updates_ast.py "sql/updates/latest-dev.sql"; then
exit 1
fi
exit 0
- name: Check reverse-dev contents
if: always()
shell: bash {0}
run: |
if ! python scripts/check_updates_ast.py "sql/updates/reverse-dev.sql"; then
exit 1
fi
exit 0
1 change: 1 addition & 0 deletions .unreleased/enhancement_6049
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implements: #6130 Add CI check for incorrect catalog updates
111 changes: 111 additions & 0 deletions scripts/check_updates_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from pglast import parse_sql
from pglast.visitors import Visitor
from pglast import enums
import sys
import re


class SQLVisitor(Visitor):
def __init__(self):
self.errors = 0
self.catalog_schemata = [
"_timescaledb_catalog",
"_timescaledb_config",
"_timescaledb_internal",
]
super().__init__()

# ALTER TABLE _timescaledb_catalog.<tablename> ADD/DROP COLUMN
def visit_AlterTableStmt(self, ancestors, node):
if (
"schemaname" in node.relation
and node.relation.schemaname in self.catalog_schemata
):
schema = node.relation.schemaname
table = node.relation.relname
for cmd in node.cmds:
if cmd.subtype in (
enums.AlterTableType.AT_AddColumn,
enums.AlterTableType.AT_DropColumn,
):
self.errors += 1
if cmd.subtype == enums.AlterTableType.AT_AddColumn:
column = cmd.def_.colname
print(
f"ERROR: Attempting to ADD COLUMN {column} to catalog table {schema}.{table}"
)
else:
column = cmd.name
print(
f"ERROR: Attempting to DROP COLUMN {column} from catalog table {schema}.{table}"
)

# ALTER TABLE _timescaledb_catalog.<tablename> RENAME TO
def visit_RenameStmt(self, ancestors, node):
if (
node.renameType == enums.ObjectType.OBJECT_TABLE
and node.relation.schemaname in self.catalog_schemata
):
self.errors += 1
print(
f"ERROR: Attempting to RENAME catalog table {node.relation.schemaname}.{node.relation.relname}"
)

# CREATE TEMP | TEMPORARY TABLE ..
def visit_CreateStmt(self, ancestors, node):
if node.relation.relpersistence == "t":
self.errors += 1
schema = (
node.relation.schemaname + "."
if node.relation.schemaname is not None
else ""
)
print(
f"ERROR: Attempting to CREATE TEMPORARY TABLE {schema}{node.relation.relname}"
)

# CREATE FUNCTION / PROCEDURE _timescaledb_internal...
def visit_CreateFunctionStmt(self, ancestors, node):
if len(node.funcname) == 2 and node.funcname[0].sval == "_timescaledb_internal":
self.errors += 1
functype = "procedure" if node.is_procedure else "function"
print(
f"ERROR: Attempting to create {functype} {node.funcname[1].sval} in the internal schema"
)


# copied from pgspot
def visit_sql(sql):
# @extschema@ is placeholder in extension scripts for
# the schema the extension gets installed in
sql = sql.replace("@extschema@", "extschema")
sql = sql.replace("@extowner@", "extowner")
sql = sql.replace("@database_owner@", "database_owner")
# postgres contrib modules are protected by psql meta commands to
# prevent running extension files in psql.
# The SQL parser will error on those since they are not valid
# SQL, so we comment out all psql meta commands before parsing.
sql = re.sub(r"^\\", "-- \\\\", sql, flags=re.MULTILINE)

visitor = SQLVisitor()
# try:
for stmt in parse_sql(sql):
visitor(stmt)
return visitor.errors


def main():
file = sys.argv[1]
with open(file, "r", encoding="utf-8") as f:
sql = f.read()
errors = visit_sql(sql)
if errors > 0:
numbering = "errors" if errors > 1 else "error"
print(f"{errors} {numbering} detected in file {file}")
sys.exit(1)
sys.exit(0)


if __name__ == "__main__":
main()
sys.exit(0)

0 comments on commit f505bd2

Please sign in to comment.