diff --git a/.gitignore b/.gitignore index efc628ec5..95031f272 100644 --- a/.gitignore +++ b/.gitignore @@ -67,4 +67,5 @@ target examples/**/build/ # Performance test results -**/performance-tests/results/ \ No newline at end of file +**/performance-tests/results/ + diff --git a/contract-tests/images/applications/psycopg2/psycopg2_server.py b/contract-tests/images/applications/psycopg2/psycopg2_server.py index 4f4acbde9..04cb0b9f8 100644 --- a/contract-tests/images/applications/psycopg2/psycopg2_server.py +++ b/contract-tests/images/applications/psycopg2/psycopg2_server.py @@ -10,9 +10,10 @@ from typing_extensions import override _PORT: int = 8080 -_SUCCESS: str = "success" +_DROP_TABLE: str = "drop_table" _ERROR: str = "error" _FAULT: str = "fault" +_CREATE_DATABASE: str = "create_database" _DB_HOST = os.getenv("DB_HOST") _DB_USER = os.getenv("DB_USER") @@ -26,11 +27,17 @@ class RequestHandler(BaseHTTPRequestHandler): def do_GET(self): status_code: int = 200 conn = psycopg2.connect(dbname=_DB_NAME, user=_DB_USER, password=_DB_PASS, host=_DB_HOST) - if self.in_path(_SUCCESS): + conn.autocommit = True # CREATE DATABASE cannot run in a transaction block + if self.in_path(_DROP_TABLE): cur = conn.cursor() cur.execute("DROP TABLE IF EXISTS test_table") cur.close() status_code = 200 + elif self.in_path(_CREATE_DATABASE): + cur = conn.cursor() + cur.execute("CREATE DATABASE test_database") + cur.close() + status_code = 200 elif self.in_path(_FAULT): cur = conn.cursor() try: diff --git a/contract-tests/tests/test/amazon/psycopg2/psycopg2_test.py b/contract-tests/tests/test/amazon/psycopg2/psycopg2_test.py index 809041899..a6e619b97 100644 --- a/contract-tests/tests/test/amazon/psycopg2/psycopg2_test.py +++ b/contract-tests/tests/test/amazon/psycopg2/psycopg2_test.py @@ -51,9 +51,13 @@ def get_application_extra_environment_variables(self) -> Dict[str, str]: def get_application_image_name(self) -> str: return "aws-application-signals-tests-psycopg2-app" - def test_success(self) -> None: + def test_drop_table_succeeds(self) -> None: self.mock_collector_client.clear_signals() - self.do_test_requests("success", "GET", 200, 0, 0, sql_command="DROP TABLE") + self.do_test_requests("drop_table", "GET", 200, 0, 0, sql_command="DROP TABLE") + + def test_create_database_succeeds(self) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("create_database", "GET", 200, 0, 0, sql_command="CREATE DATABASE") def test_fault(self) -> None: self.mock_collector_client.clear_signals()