Skip to content

Commit

Permalink
Adds configurable capture for SDK
Browse files Browse the repository at this point in the history
This adds a few things to enable one to more
easily customize what is captured via the SDK.

1. Should you capture any data statistics at all?
2. Max lengths of dicts and lists -> so they don't get too big.

You can configure this via three means:
1. python constants
2. config file variables
3. environment variables

There is a prescedence order. So there are default values in the module.
You can then override them via config file variables.
Which then can in turn be overriden by environment variables.
Lastly the user can always modify the constants by directly
changing the module variables.

Note: we also skip logging metadata from datasavers and loaders if
the CAPTURE_DATA_STATISTICS = False. We can fix this by special
casing it, but for now I don't think people would complain with
the current functionality.
  • Loading branch information
skrawcz committed Nov 16, 2024
1 parent da13c36 commit fc6093a
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 9 deletions.
65 changes: 65 additions & 0 deletions ui/sdk/src/hamilton_sdk/tracking/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""This module contains constants for tracking.
We then override these by:
1. Looking for a configuration file and taking the section under `SDK_CONSTANTS`.
2. Via environment variables. They should be prefixed with `HAMILTON_`.
3. Lastly folks can manually adjust these values directly by importing the module and changing the value.
Note: This module cannot import other Hamilton modules.
"""

import configparser
import logging
import os

logger = logging.getLogger(__name__)

# The following are the default values for the tracking client
CAPTURE_DATA_STATISTICS = True
MAX_LIST_LENGTH_CAPTURE = 50
MAX_DICT_LENGTH_CAPTURE = 100

# Check for configuration file
# TODO -- add configuration file support
DEFAULT_CONFIG_URI = os.environ.get("HAMILTON_CONFIG_URI", "~/.hamilton.conf")
DEFAULT_CONFIG_LOCATION = os.path.expanduser(DEFAULT_CONFIG_URI)


def _load_config(config_location: str) -> configparser.ConfigParser:
"""Pulls config if it exists.
:param config_location: location of the config file.
"""
config = configparser.ConfigParser()
try:
with open(config_location) as f:
config.read_file(f)
except Exception:
pass

return config


_constant_values = globals()
file_config = _load_config(DEFAULT_CONFIG_LOCATION)
# loads from config file and overwrites
if "SDK_CONSTANTS" in file_config:
for key, val in file_config["SDK_CONSTANTS"].items():
upper_key = key.upper()
if upper_key not in _constant_values:
continue
# overwrite value
_constant_values[upper_key] = val

# Check for environment variables & overwrites
# TODO automate this by pulling anything in with a prefix and checking
# globals here and updating them.
CAPTURE_DATA_STATISTICS = os.getenv("HAMILTON_CAPTURE_DATA_STATISTICS", CAPTURE_DATA_STATISTICS)
if isinstance(CAPTURE_DATA_STATISTICS, str):
CAPTURE_DATA_STATISTICS = CAPTURE_DATA_STATISTICS.lower() == "true"
MAX_LIST_LENGTH_CAPTURE = int(
os.getenv("HAMILTON_MAX_LIST_LENGTH_CAPTURE", MAX_LIST_LENGTH_CAPTURE)
)
MAX_DICT_LENGTH_CAPTURE = int(
os.getenv("HAMILTON_MAX_DICT_LENGTH_CAPTURE", MAX_DICT_LENGTH_CAPTURE)
)
21 changes: 20 additions & 1 deletion ui/sdk/src/hamilton_sdk/tracking/data_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pandas as pd
from hamilton_sdk.tracking import sql_utils
from hamilton_sdk.tracking import constants

# Multiple observations per are allowed
ObservationType = Dict[str, Any]
Expand Down Expand Up @@ -74,6 +75,17 @@ def compute_stats_primitives(result, node_name: str, node_tags: dict) -> Observa
@compute_stats.register(dict)
def compute_stats_dict(result: dict, node_name: str, node_tags: dict) -> ObservationType:
"""call summary stats on the values in the dict"""
truncated = False
if len(result) >= constants.MAX_DICT_LENGTH_CAPTURE:
new_result = {}
for i, k in enumerate(result.keys()):
new_result[k] = result[k]
if i + 1 < constants.MAX_DICT_LENGTH_CAPTURE:
continue
else:
break
result = new_result # replace pointer with smaller dict
truncated = True
try:
# if it's JSON serializable, take it.
json.dumps(result)
Expand Down Expand Up @@ -109,7 +121,8 @@ def compute_stats_dict(result: dict, node_name: str, node_tags: dict) -> Observa
else:
# it's a DF, Series -- so take full result.
result_values[k] = v_result["observability_value"]

if truncated:
result["__truncated__"] = "... values truncated ..."
return {
"observability_type": "dict",
"observability_value": {
Expand Down Expand Up @@ -170,6 +183,10 @@ def compute_stats_tuple(result: tuple, node_name: str, node_tags: dict) -> Obser
@compute_stats.register(list)
def compute_stats_list(result: list, node_name: str, node_tags: dict) -> ObservationType:
"""call summary stats on the values in the list"""
truncated = False
if len(result) > constants.MAX_LIST_LENGTH_CAPTURE:
result = result[: constants.MAX_LIST_LENGTH_CAPTURE]
truncated = True
try:
# if it's JSON serializable, take it.
json.dumps(result)
Expand Down Expand Up @@ -200,6 +217,8 @@ def compute_stats_list(result: list, node_name: str, node_tags: dict) -> Observa
elif observed_type == "dict":
v = v_result["observability_value"]
result_values.append(v)
if truncated:
result_values.append("... truncated ...")
return {
# yes dict type -- that's so that we can display in the UI. It's a hack.
"observability_type": "dict",
Expand Down
28 changes: 20 additions & 8 deletions ui/sdk/src/hamilton_sdk/tracking/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional, Tuple

from hamilton_sdk.tracking import data_observation
from hamilton_sdk.tracking import constants, data_observation
from hamilton_sdk.tracking.data_observation import ObservationType
from hamilton_sdk.tracking.trackingtypes import DAGRun, Status, TaskRun

Expand Down Expand Up @@ -59,13 +59,25 @@ def process_result(
statistics = None
schema = None
additional = []
try:
start = py_time.time()
statistics = data_observation.compute_stats(result, node.name, node.tags)
end = py_time.time()
logger.debug(f"Took {end - start} seconds to describe {node.name}")
except Exception as e:
logger.warning(f"Failed to introspect statistics for {node.name}. Error:\n{e}")
if constants.CAPTURE_DATA_STATISTICS:
try:
start = py_time.time()
statistics = data_observation.compute_stats(result, node.name, node.tags)
end = py_time.time()
logger.debug(f"Took {end - start} seconds to describe {node.name}")
except Exception as e:
logger.warning(f"Failed to introspect statistics for {node.name}. Error:\n{e}")
else:
# TODO: handle case where it's metadata from a dataloader/saver right now we don't log that
# info, but we should in this particular case.
statistics = {
"observability_type": "primitive",
"observability_value": {
"type": "str",
"value": "RESULT SUMMARY DISABLED",
},
"observability_schema_version": "0.0.1",
}
try:
start = py_time.time()
schema = data_observation.compute_schema(result, node.name, node.tags)
Expand Down
10 changes: 10 additions & 0 deletions ui/sdk/tests/tracking/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,13 @@ def test_process_result_happy(test_result, test_node, observability_type, stats)
if schema is not None:
json.dumps(schema)
[json.dumps(add) for add in additional]


def test_disable_capturing_data_stats(monkeypatch):
monkeypatch.setattr("hamilton_sdk.tracking.constants.CAPTURE_DATA_STATISTICS", False)
stats, schema, additional = runs.process_result([1, 2, 3, 4], create_node("a", list))
assert stats["observability_type"] == "primitive"
assert stats["observability_value"] == {
"type": "str",
"value": "RESULT SUMMARY DISABLED",
}
29 changes: 29 additions & 0 deletions ui/sdk/tests/tracking/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,35 @@ def test_compute_stats_dict():
}


def test_compute_stats_dict_truncated(monkeypatch):
monkeypatch.setattr("hamilton_sdk.tracking.constants.MAX_DICT_LENGTH_CAPTURE", 1)
actual = data_observation.compute_stats({"a": 1, "b": 2}, "test_node", {})
assert actual == {
"observability_type": "dict",
"observability_value": {
"type": str(type(dict())),
"value": {
"__truncated__": "... values truncated ...",
"a": 1,
},
},
"observability_schema_version": "0.0.2",
}


def test_compute_stats_list_truncated(monkeypatch):
monkeypatch.setattr("hamilton_sdk.tracking.constants.MAX_LIST_LENGTH_CAPTURE", 1)
actual = data_observation.compute_stats([1, 2, 3, 4], "test_node", {})
assert actual == {
"observability_type": "dict",
"observability_value": {
"type": str(type(list())),
"value": [1, "... truncated ..."],
},
"observability_schema_version": "0.0.2",
}


def test_compute_stats_tuple_dataloader():
"""tests case of a dataloader"""
actual = data_observation.compute_stats(
Expand Down

0 comments on commit fc6093a

Please sign in to comment.