From 067377b99caa23c88d4b8a712ae1ee3040a002fb Mon Sep 17 00:00:00 2001 From: Fran Boon Date: Mon, 26 Aug 2024 09:57:29 +0100 Subject: [PATCH] [Async]HamiltonTracker support passing in custom CA cert --- ui/sdk/src/hamilton_sdk/adapters.py | 11 ++++-- ui/sdk/src/hamilton_sdk/api/clients.py | 53 +++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/ui/sdk/src/hamilton_sdk/adapters.py b/ui/sdk/src/hamilton_sdk/adapters.py index b74703a70..e0106f4ef 100644 --- a/ui/sdk/src/hamilton_sdk/adapters.py +++ b/ui/sdk/src/hamilton_sdk/adapters.py @@ -45,11 +45,12 @@ def __init__( dag_name: str, tags: Dict[str, str] = None, client_factory: Callable[ - [str, str, str], clients.HamiltonClient + [str, str, str, str | bool], clients.HamiltonClient ] = clients.BasicSynchronousHamiltonClient, api_key: str = None, hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL), hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL), + verify: str | bool = True, ): """This hooks into Hamilton execution to track DAG runs in Hamilton UI. @@ -61,11 +62,12 @@ def __init__( :param api_key: the API key to use. See us if you want to use this. :param hamilton_api_url: API endpoint. :param hamilton_ui_url: UI Endpoint. 
+ :param verify: SSL verification to pass-through to requests """ self.project_id = project_id self.api_key = api_key self.username = username - self.client = client_factory(api_key, username, hamilton_api_url) + self.client = client_factory(api_key, username, hamilton_api_url, verify=verify) self.initialized = False self.project_version = None self.base_tags = tags if tags is not None else {} @@ -387,16 +389,17 @@ def __init__( dag_name: str, tags: Dict[str, str] = None, client_factory: Callable[ - [str, str, str], clients.BasicAsynchronousHamiltonClient + [str, str, str, str | bool], clients.BasicAsynchronousHamiltonClient ] = clients.BasicAsynchronousHamiltonClient, api_key: str = os.environ.get("HAMILTON_API_KEY", ""), hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL), hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL), + verify: str | bool = True, ): self.project_id = project_id self.api_key = api_key self.username = username - self.client = client_factory(api_key, username, hamilton_api_url) + self.client = client_factory(api_key, username, hamilton_api_url, verify=verify) self.initialized = False self.project_version = None self.base_tags = tags if tags is not None else {} diff --git a/ui/sdk/src/hamilton_sdk/api/clients.py b/ui/sdk/src/hamilton_sdk/api/clients.py index 9a43a67dd..f292ada12 100644 --- a/ui/sdk/src/hamilton_sdk/api/clients.py +++ b/ui/sdk/src/hamilton_sdk/api/clients.py @@ -4,6 +4,7 @@ import functools import logging import queue +import ssl import threading import time from collections import defaultdict @@ -167,6 +168,7 @@ def __init__( username: str, h_api_url: str, base_path: str = "/api/v1", + verify: str | bool = True, ): """Initializes a Hamilton API client @@ -174,10 +176,13 @@ def __init__( :param api_key: API key to save to :param username: Username to authenticate against :param h_api_url: API URL for Hamilton API. 
+ :param base_path: Path appended to h_api_url to form the base URL of the API. + :param verify: SSL verification to pass-through to requests """ self.api_key = api_key self.username = username self.base_url = h_api_url + base_path + self.verify = verify self.max_batch_size = 100 self.flush_interval = 5 @@ -256,6 +261,7 @@ def flush(self, batch): "task_updates": make_json_safe(task_updates_list), }, headers=self._common_headers(), + verify=self.verify, ) try: response.raise_for_status() @@ -290,7 +296,9 @@ def _common_headers(self) -> Dict[str, Any]: def validate_auth(self): logger.debug(f"Validating auth against {self.base_url}/phone_home") - response = requests.get(f"{self.base_url}/phone_home", headers=self._common_headers()) + response = requests.get( + f"{self.base_url}/phone_home", headers=self._common_headers(), verify=self.verify + ) try: response.raise_for_status() logger.debug(f"Successfully validated auth against {self.base_url}/phone_home") @@ -311,6 +319,7 @@ def register_code_version_if_not_exists( response = requests.get( f"{self.base_url}/project_versions/exists?project_id={project_id}&code_hash={code_hash}", headers=self._common_headers(), + verify=self.verify, ) try: response.raise_for_status() @@ -343,6 +352,7 @@ def register_code_version_if_not_exists( "version_info_schema": 1, # TODO -- wire this through appropriately "code_log": {"files": code_slurped}, }, + verify=self.verify, ) try: code_version_created.raise_for_status() @@ -358,7 +368,9 @@ def register_code_version_if_not_exists( def project_exists(self, project_id: int) -> bool: logger.debug(f"Checking if project {project_id} exists") response = requests.get( - f"{self.base_url}/projects/{project_id}", headers=self._common_headers() + f"{self.base_url}/projects/{project_id}", + headers=self._common_headers(), + verify=self.verify, ) try: response.raise_for_status() @@ -401,6 +413,7 @@ def register_dag_template_if_not_exists( response = requests.get( f"{self.base_url}/dag_templates/exists/?dag_hash={dag_hash}&{params}", 
headers=self._common_headers(), + verify=self.verify, ) response.raise_for_status() logger.debug(f"DAG template {dag_hash} exists for project {project_id}") @@ -430,6 +443,7 @@ def register_dag_template_if_not_exists( "code_version_info_schema": 1, }, headers=self._common_headers(), + verify=self.verify, ) try: dag_template_created.raise_for_status() @@ -458,6 +472,7 @@ def create_and_start_dag_run( "run_status": "RUNNING", } ), + verify=self.verify, ) try: response.raise_for_status() @@ -498,6 +513,7 @@ def log_dag_run_end(self, dag_run_id: int, status: str): f"{self.base_url}/dag_runs/{dag_run_id}/", json=make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()}), headers=self._common_headers(), + verify=self.verify, ) try: response.raise_for_status() @@ -508,17 +524,32 @@ def log_dag_run_end(self, dag_run_id: int, status: str): class BasicAsynchronousHamiltonClient(HamiltonClient): - def __init__(self, api_key: str, username: str, h_api_url: str, base_path: str = "/api/v1"): + def __init__( + self, + api_key: str, + username: str, + h_api_url: str, + base_path: str = "/api/v1", + verify: str | bool = True, + ): """Initializes an async Hamilton API client project: Project to save to :param api_key: API key to save to :param username: Username to authenticate against :param h_api_url: API URL for Hamilton API. 
+ :param base_path: Path appended to h_api_url to form the base URL of the API. + :param verify: SSL verification options in requests format: True or False to enable or disable certificate verification, or a path to a CA certificate file """ self.api_key = api_key self.username = username self.base_url = h_api_url + base_path + if verify is True: + self.ssl = True + elif verify is False: + self.ssl = False + else: + self.ssl = ssl.create_default_context(cafile=verify) self.flush_interval = 5 self.data_queue = asyncio.Queue() self.running = True @@ -542,6 +573,7 @@ async def flush(self, batch): "task_updates": make_json_safe(task_updates_list), }, headers=self._common_headers(), + ssl=self.ssl, ) as response: try: response.raise_for_status() @@ -590,7 +622,9 @@ async def validate_auth(self): logger.debug(f"Validating auth against {self.base_url}/phone_home") async with aiohttp.ClientSession() as session: async with session.get( - f"{self.base_url}/phone_home", headers=self._common_headers() + f"{self.base_url}/phone_home", + headers=self._common_headers(), + ssl=self.ssl, ) as response: try: response.raise_for_status() @@ -613,6 +647,7 @@ async def register_code_version_if_not_exists( async with session.get( f"{self.base_url}/project_versions/exists?project_id={project_id}&code_hash={code_hash}", headers=self._common_headers(), + ssl=self.ssl, ) as response: try: response.raise_for_status() @@ -648,6 +683,7 @@ async def register_code_version_if_not_exists( "version_info_schema": 1, # TODO -- wire this through appropriately "code_log": {"files": code_slurped}, }, + ssl=self.ssl, ) as response: try: response.raise_for_status() @@ -664,7 +700,9 @@ async def project_exists(self, project_id: int) -> bool: logger.debug(f"Checking if project {project_id} exists") async with aiohttp.ClientSession() as session: async with session.get( - f"{self.base_url}/projects/{project_id}", headers=self._common_headers() + f"{self.base_url}/projects/{project_id}", + headers=self._common_headers(), + ssl=self.ssl, ) as response: try: response.raise_for_status() @@ -706,6 +744,7 @@ async def register_dag_template_if_not_exists( async with 
session.get( f"{self.base_url}/dag_templates/exists/?dag_hash={dag_hash}&{params}", headers=self._common_headers(), + ssl=self.ssl, ) as response: try: response.raise_for_status() @@ -741,6 +780,7 @@ async def register_dag_template_if_not_exists( "code_version_info_schema": 1, }, headers=self._common_headers(), + ssl=self.ssl, ) as response: try: response.raise_for_status() @@ -770,6 +810,7 @@ async def create_and_start_dag_run( } ), headers=self._common_headers(), + ssl=self.ssl, ) as response: try: response.raise_for_status() @@ -802,7 +843,7 @@ async def log_dag_run_end(self, dag_run_id: int, status: str): data = make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()}) headers = self._common_headers() async with aiohttp.ClientSession() as session: - async with session.put(url, json=data, headers=headers) as response: + async with session.put(url, json=data, headers=headers, ssl=self.ssl) as response: try: response.raise_for_status() logger.debug(f"Logged end of DAG run {dag_run_id}")