Skip to content

Commit

Permalink
Adds capture of NamedTuples
Browse files Browse the repository at this point in the history
They can be converted to dictionaries easily. So we do that.
Otherwise we special case 'secret_key' since that's a legacy way
we were telling people to wrap API Keys with.

Fixes some unit tests that weren't updated.
  • Loading branch information
skrawcz committed May 8, 2024
1 parent a9c21b0 commit 6758624
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 4 deletions.
5 changes: 5 additions & 0 deletions ui/sdk/src/hamilton_sdk/tracking/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ def compute_stats_tuple(result: tuple, node_name: str, node_tags: dict) -> Dict[
},
"observability_schema_version": "0.0.2",
}
# namedtuple -- this how we guide people to not have something tracked easily.
# so we skip it if it has a `secret_key`. This is hacky -- better choice would
# be to have an internal object or way to decorate a parameter to not track it.
if hasattr(result, "_asdict") and not hasattr(result, "secret_key"):
return compute_stats_dict(result._asdict(), node_name, node_tags)
return {
"observability_type": "unsupported",
"observability_value": {
Expand Down
4 changes: 2 additions & 2 deletions ui/sdk/tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_hash_module_with_subpackage():
seen_modules = set()
result = _hash_module(submodule1, hash_object, seen_modules)

assert result.hexdigest() == "9f0b697b6b071d0ca6df18532031a8e553a8327531e249ff457d772b8bd392c7"
assert result.hexdigest() == "4466d1f61b2c57c2b5bfe8a9fec09acd53befcfdf2f5720075aef83e3d6c6bf8"
assert len(seen_modules) == 2
assert {m.__name__ for m in seen_modules} == {
"tests.test_package_to_hash.subpackage",
Expand All @@ -62,7 +62,7 @@ def test_hash_module_complex():
seen_modules = set()
result = _hash_module(test_package_to_hash, hash_object, seen_modules)

assert result.hexdigest() == "fc568608a2f766eac3cbae4021fb367247c6aa36ac4ae72ea98104c1ba2a5e1c"
assert result.hexdigest() == "c22023a4fdc8564de1cda70d05a19d5e8c0ddaaa9dcccf644a2b789b80f19896"
assert len(seen_modules) == 4
assert {m.__name__ for m in seen_modules} == {
"tests.test_package_to_hash",
Expand Down
39 changes: 37 additions & 2 deletions ui/sdk/tests/tracking/test_stats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import sys
from collections import namedtuple
from typing import NamedTuple

import pandas as pd
from hamilton_sdk.tracking import stats
Expand Down Expand Up @@ -40,19 +42,52 @@ def test_compute_stats_tuple_dataloader():
"test_node",
{"hamilton.data_loader": True},
)
# in python 3.11 numbers change
if sys.version_info >= (3, 11):
index = 132
memory = 156
else:
index = 128
memory = 152
assert actual == {
"observability_schema_version": "0.0.2",
"observability_type": "dict",
"observability_value": {
"type": "<class 'dict'>",
"value": {
"CONNECTION_INFO": {"URL": "SOME_URL"},
"DF_MEMORY_BREAKDOWN": {"Index": 128, "a": 24},
"DF_MEMORY_TOTAL": 152,
"DF_MEMORY_BREAKDOWN": {"Index": index, "a": 24},
"DF_MEMORY_TOTAL": memory,
"QUERIED_TABLES": {
"table-0": {"catalog": "FOO", "database": "BAR", "name": "TABLE"}
},
"SQL_QUERY": "SELECT * FROM FOO.BAR.TABLE",
},
},
}


def test_compute_states_tuple_namedtuple():
"""Tests namedtuple capture correctly"""

class Foo(NamedTuple):
x: int
y: str

f = Foo(1, "a")
actual = stats.compute_stats(
f,
"test_node",
{"some_tag": "foo-bar"},
)
assert actual == {
"observability_schema_version": "0.0.2",
"observability_type": "dict",
"observability_value": {
"type": "<class 'dict'>",
"value": {
"x": 1,
"y": "a",
},
},
}

0 comments on commit 6758624

Please sign in to comment.