From 3b728e257cc50ca3c84dec0fd3f6dad4b8560cc9 Mon Sep 17 00:00:00 2001 From: Jordan Woods <13803242+jorwoods@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:41:51 -0600 Subject: [PATCH] fix: virtual connection ConnectionItem attributes Closes #1558 Connection XML element for VirtualConnections has different attribute keys compared to connection XML elements when returned by Datasources, Workbooks, and Flows. This PR adds in flexibility to ConnectionItem's reading of XML to account for both sets of attributes that may be present elements. --- tableauserverclient/models/connection_item.py | 6 +-- ...rtual_connection_populate_connections2.xml | 6 +++ test/test_virtual_connection.py | 40 +++++++++++-------- 3 files changed, 32 insertions(+), 20 deletions(-) create mode 100644 test/assets/virtual_connection_populate_connections2.xml diff --git a/tableauserverclient/models/connection_item.py b/tableauserverclient/models/connection_item.py index e68958c3..6a8244fb 100644 --- a/tableauserverclient/models/connection_item.py +++ b/tableauserverclient/models/connection_item.py @@ -103,11 +103,11 @@ def from_response(cls, resp, ns) -> list["ConnectionItem"]: all_connection_xml = parsed_response.findall(".//t:connection", namespaces=ns) for connection_xml in all_connection_xml: connection_item = cls() - connection_item._id = connection_xml.get("id", None) + connection_item._id = connection_xml.get("id", connection_xml.get("connectionId", None)) connection_item._connection_type = connection_xml.get("type", connection_xml.get("dbClass", None)) connection_item.embed_password = string_to_bool(connection_xml.get("embedPassword", "")) - connection_item.server_address = connection_xml.get("serverAddress", None) - connection_item.server_port = connection_xml.get("serverPort", None) + connection_item.server_address = connection_xml.get("serverAddress", connection_xml.get("server", None)) + connection_item.server_port = connection_xml.get("serverPort", connection_xml.get("port", None)) connection_item.username = connection_xml.get("userName", None) connection_item._query_tagging = ( string_to_bool(s) if (s := connection_xml.get("queryTagging", None)) else None diff --git a/test/assets/virtual_connection_populate_connections2.xml b/test/assets/virtual_connection_populate_connections2.xml new file mode 100644 index 00000000..f0ad2646 --- /dev/null +++ b/test/assets/virtual_connection_populate_connections2.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/test/test_virtual_connection.py b/test/test_virtual_connection.py index 975033d2..5d9a2d1b 100644 --- a/test/test_virtual_connection.py +++ b/test/test_virtual_connection.py @@ -2,6 +2,7 @@ from pathlib import Path import unittest +import pytest import requests_mock import tableauserverclient as TSC @@ -12,6 +13,7 @@ VIRTUAL_CONNECTION_GET_XML = ASSET_DIR / "virtual_connections_get.xml" VIRTUAL_CONNECTION_POPULATE_CONNECTIONS = ASSET_DIR / "virtual_connection_populate_connections.xml" +VIRTUAL_CONNECTION_POPULATE_CONNECTIONS2 = ASSET_DIR / "virtual_connection_populate_connections2.xml" VC_DB_CONN_UPDATE = ASSET_DIR / "virtual_connection_database_connection_update.xml" VIRTUAL_CONNECTION_DOWNLOAD = ASSET_DIR / "virtual_connections_download.xml" VIRTUAL_CONNECTION_UPDATE = ASSET_DIR / "virtual_connections_update.xml" @@ -54,23 +56,27 @@ def test_virtual_connection_get(self): assert items[0].name == "vconn" def test_virtual_connection_populate_connections(self): - vconn = VirtualConnectionItem("vconn") - vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" - with requests_mock.mock() as m: - m.get(f"{self.baseurl}/{vconn.id}/connections", text=VIRTUAL_CONNECTION_POPULATE_CONNECTIONS.read_text()) - vc_out = self.server.virtual_connections.populate_connections(vconn) - connection_list = list(vconn.connections) - - assert vc_out is vconn - assert vc_out._connections is not None - - assert len(connection_list) == 1 - connection = connection_list[0] - assert connection.id == "37ca6ced-58d7-4dcf-99dc-f0a85223cbef" - assert connection.connection_type == "postgres" - assert connection.server_address == "localhost" - assert connection.server_port == "5432" - assert connection.username == "pgadmin" + for i, populate_connections_xml in enumerate( + (VIRTUAL_CONNECTION_POPULATE_CONNECTIONS, VIRTUAL_CONNECTION_POPULATE_CONNECTIONS2) + ): + with self.subTest(i): + vconn = VirtualConnectionItem("vconn") + vconn._id = "8fd7cc02-bb55-4d15-b8b1-9650239efe79" + with requests_mock.mock() as m: + m.get(f"{self.baseurl}/{vconn.id}/connections", text=populate_connections_xml.read_text()) + vc_out = self.server.virtual_connections.populate_connections(vconn) + connection_list = list(vconn.connections) + + assert vc_out is vconn + assert vc_out._connections is not None + + assert len(connection_list) == 1 + connection = connection_list[0] + assert connection.id == "37ca6ced-58d7-4dcf-99dc-f0a85223cbef" + assert connection.connection_type == "postgres" + assert connection.server_address == "localhost" + assert connection.server_port == "5432" + assert connection.username == "pgadmin" def test_virtual_connection_update_connection_db_connection(self): vconn = VirtualConnectionItem("vconn")