Skip to content

Commit

Permalink
[Async]HamiltonTracker support passing in custom CA cert
Browse files Browse the repository at this point in the history
  • Loading branch information
flavour committed Aug 26, 2024
1 parent 989115a commit 067377b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
11 changes: 7 additions & 4 deletions ui/sdk/src/hamilton_sdk/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def __init__(
dag_name: str,
tags: Dict[str, str] = None,
client_factory: Callable[
[str, str, str], clients.HamiltonClient
[str, str, str, str | bool], clients.HamiltonClient
] = clients.BasicSynchronousHamiltonClient,
api_key: str = None,
hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL),
hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL),
verify: str | bool = True,
):
"""This hooks into Hamilton execution to track DAG runs in Hamilton UI.
Expand All @@ -61,11 +62,12 @@ def __init__(
:param api_key: the API key to use. See us if you want to use this.
:param hamilton_api_url: API endpoint.
:param hamilton_ui_url: UI Endpoint.
:param verify: SSL verification to pass-through to requests
"""
self.project_id = project_id
self.api_key = api_key
self.username = username
self.client = client_factory(api_key, username, hamilton_api_url)
self.client = client_factory(api_key, username, hamilton_api_url, verify=verify)
self.initialized = False
self.project_version = None
self.base_tags = tags if tags is not None else {}
Expand Down Expand Up @@ -387,16 +389,17 @@ def __init__(
dag_name: str,
tags: Dict[str, str] = None,
client_factory: Callable[
[str, str, str], clients.BasicAsynchronousHamiltonClient
[str, str, str, str | bool], clients.BasicAsynchronousHamiltonClient
] = clients.BasicAsynchronousHamiltonClient,
api_key: str = os.environ.get("HAMILTON_API_KEY", ""),
hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL),
hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL),
verify: str | bool = True,
):
self.project_id = project_id
self.api_key = api_key
self.username = username
self.client = client_factory(api_key, username, hamilton_api_url)
self.client = client_factory(api_key, username, hamilton_api_url, verify=verify)
self.initialized = False
self.project_version = None
self.base_tags = tags if tags is not None else {}
Expand Down
53 changes: 47 additions & 6 deletions ui/sdk/src/hamilton_sdk/api/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import logging
import queue
import ssl
import threading
import time
from collections import defaultdict
Expand Down Expand Up @@ -167,17 +168,21 @@ def __init__(
username: str,
h_api_url: str,
base_path: str = "/api/v1",
verify: str | bool = True,
):
"""Initializes a Hamilton API client
project: Project to save to
:param api_key: API key to save to
:param username: Username to authenticate against
:param h_api_url: API URL for Hamilton API.
:param base_path:
:param verify: SSL verification to pass-through to requests
"""
self.api_key = api_key
self.username = username
self.base_url = h_api_url + base_path
self.verify = verify

self.max_batch_size = 100
self.flush_interval = 5
Expand Down Expand Up @@ -256,6 +261,7 @@ def flush(self, batch):
"task_updates": make_json_safe(task_updates_list),
},
headers=self._common_headers(),
verify=self.verify,
)
try:
response.raise_for_status()
Expand Down Expand Up @@ -290,7 +296,9 @@ def _common_headers(self) -> Dict[str, Any]:

def validate_auth(self):
logger.debug(f"Validating auth against {self.base_url}/phone_home")
response = requests.get(f"{self.base_url}/phone_home", headers=self._common_headers())
response = requests.get(
f"{self.base_url}/phone_home", headers=self._common_headers(), verify=self.verify
)
try:
response.raise_for_status()
logger.debug(f"Successfully validated auth against {self.base_url}/phone_home")
Expand All @@ -311,6 +319,7 @@ def register_code_version_if_not_exists(
response = requests.get(
f"{self.base_url}/project_versions/exists?project_id={project_id}&code_hash={code_hash}",
headers=self._common_headers(),
verify=self.verify,
)
try:
response.raise_for_status()
Expand Down Expand Up @@ -343,6 +352,7 @@ def register_code_version_if_not_exists(
"version_info_schema": 1, # TODO -- wire this through appropriately
"code_log": {"files": code_slurped},
},
verify=self.verify,
)
try:
code_version_created.raise_for_status()
Expand All @@ -358,7 +368,9 @@ def register_code_version_if_not_exists(
def project_exists(self, project_id: int) -> bool:
logger.debug(f"Checking if project {project_id} exists")
response = requests.get(
f"{self.base_url}/projects/{project_id}", headers=self._common_headers()
f"{self.base_url}/projects/{project_id}",
headers=self._common_headers(),
verify=self.verify,
)
try:
response.raise_for_status()
Expand Down Expand Up @@ -401,6 +413,7 @@ def register_dag_template_if_not_exists(
response = requests.get(
f"{self.base_url}/dag_templates/exists/?dag_hash={dag_hash}&{params}",
headers=self._common_headers(),
verify=self.verify,
)
response.raise_for_status()
logger.debug(f"DAG template {dag_hash} exists for project {project_id}")
Expand Down Expand Up @@ -430,6 +443,7 @@ def register_dag_template_if_not_exists(
"code_version_info_schema": 1,
},
headers=self._common_headers(),
verify=self.verify,
)
try:
dag_template_created.raise_for_status()
Expand Down Expand Up @@ -458,6 +472,7 @@ def create_and_start_dag_run(
"run_status": "RUNNING",
}
),
verify=self.verify,
)
try:
response.raise_for_status()
Expand Down Expand Up @@ -498,6 +513,7 @@ def log_dag_run_end(self, dag_run_id: int, status: str):
f"{self.base_url}/dag_runs/{dag_run_id}/",
json=make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()}),
headers=self._common_headers(),
verify=self.verify,
)
try:
response.raise_for_status()
Expand All @@ -508,17 +524,32 @@ def log_dag_run_end(self, dag_run_id: int, status: str):


class BasicAsynchronousHamiltonClient(HamiltonClient):
def __init__(self, api_key: str, username: str, h_api_url: str, base_path: str = "/api/v1"):
def __init__(
self,
api_key: str,
username: str,
h_api_url: str,
base_path: str = "/api/v1",
verify: str | bool = True,
):
"""Initializes an async Hamilton API client
project: Project to save to
:param api_key: API key to save to
:param username: Username to authenticate against
:param h_api_url: API URL for Hamilton API.
:param base_path:
:param verify: SSL verification options in requests format
"""
self.api_key = api_key
self.username = username
self.base_url = h_api_url + base_path
if verify is True:
self.ssl = True
elif verify is False:
self.ssl = False
else:
self.ssl = ssl.create_default_context(cafile=verify)
self.flush_interval = 5
self.data_queue = asyncio.Queue()
self.running = True
Expand All @@ -542,6 +573,7 @@ async def flush(self, batch):
"task_updates": make_json_safe(task_updates_list),
},
headers=self._common_headers(),
ssl=self.ssl,
) as response:
try:
response.raise_for_status()
Expand Down Expand Up @@ -590,7 +622,9 @@ async def validate_auth(self):
logger.debug(f"Validating auth against {self.base_url}/phone_home")
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/phone_home", headers=self._common_headers()
f"{self.base_url}/phone_home",
headers=self._common_headers(),
ssl=self.ssl,
) as response:
try:
response.raise_for_status()
Expand All @@ -613,6 +647,7 @@ async def register_code_version_if_not_exists(
async with session.get(
f"{self.base_url}/project_versions/exists?project_id={project_id}&code_hash={code_hash}",
headers=self._common_headers(),
ssl=self.ssl,
) as response:
try:
response.raise_for_status()
Expand Down Expand Up @@ -648,6 +683,7 @@ async def register_code_version_if_not_exists(
"version_info_schema": 1, # TODO -- wire this through appropriately
"code_log": {"files": code_slurped},
},
ssl=self.ssl,
) as response:
try:
response.raise_for_status()
Expand All @@ -664,7 +700,9 @@ async def project_exists(self, project_id: int) -> bool:
logger.debug(f"Checking if project {project_id} exists")
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/projects/{project_id}", headers=self._common_headers()
f"{self.base_url}/projects/{project_id}",
headers=self._common_headers(),
ssl=self.ssl,
) as response:
try:
response.raise_for_status()
Expand Down Expand Up @@ -706,6 +744,7 @@ async def register_dag_template_if_not_exists(
async with session.get(
f"{self.base_url}/dag_templates/exists/?dag_hash={dag_hash}&{params}",
headers=self._common_headers(),
ssl=self.ssl,
) as response:
try:
response.raise_for_status()
Expand Down Expand Up @@ -741,6 +780,7 @@ async def register_dag_template_if_not_exists(
"code_version_info_schema": 1,
},
headers=self._common_headers(),
ssl=self.ssl,
) as response:
try:
response.raise_for_status()
Expand Down Expand Up @@ -770,6 +810,7 @@ async def create_and_start_dag_run(
}
),
headers=self._common_headers(),
ssl=self.ssl,
) as response:
try:
response.raise_for_status()
Expand Down Expand Up @@ -802,7 +843,7 @@ async def log_dag_run_end(self, dag_run_id: int, status: str):
data = make_json_safe({"run_status": status, "run_end_time": datetime.datetime.utcnow()})
headers = self._common_headers()
async with aiohttp.ClientSession() as session:
async with session.put(url, json=data, headers=headers) as response:
async with session.put(url, json=data, headers=headers, ssl=self.ssl) as response:
try:
response.raise_for_status()
logger.debug(f"Logged end of DAG run {dag_run_id}")
Expand Down

0 comments on commit 067377b

Please sign in to comment.