From 4c3e44c395c05a1a6b04bdd36bef38b4bb3c3e26 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Thu, 8 Aug 2024 14:18:50 -0400 Subject: [PATCH] update get_mb_size util function to handle collections --- packages/syft/src/syft/util/util.py | 65 ++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 43620c4cab5..e0d729ba123 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -1,6 +1,7 @@ # stdlib import asyncio from asyncio.selector_events import BaseSelectorEventLoop +from collections import deque from collections.abc import Callable from collections.abc import Iterator from collections.abc import Sequence @@ -11,6 +12,7 @@ from datetime import datetime import functools import hashlib +from itertools import chain from itertools import repeat import json import logging @@ -29,6 +31,7 @@ from secrets import randbelow import socket import sys +from sys import getsizeof import threading import time import types @@ -93,8 +96,66 @@ def get_name_for(klass: type) -> str: return klass_name -def get_mb_size(data: Any) -> float: - return sys.getsizeof(data) / (1024 * 1024) +def get_mb_size(data: Any, handlers: dict | None = None) -> float: + """Returns the approximate memory footprint an object and all of its contents. + + Automatically finds the contents of the following builtin containers and + their subclasses: tuple, list, deque, dict, set and frozenset. + Otherwise, tries to read from the __slots__ or __dict__ of the object. + To search other containers, add handlers to iterate over their contents: + + handlers = {SomeContainerClass: iter, + OtherContainerClass: OtherContainerClass.get_elements} + + Lightly modified from + https://code.activestate.com/recipes/577504-compute-memory-footprint-of-an-object-and-its-cont/ + which is referenced in official sys.getsizeof documentation + https://docs.python.org/3/library/sys.html#sys.getsizeof. + + """ + + def dict_handler(d: dict[Any, Any]) -> Iterator[Any]: + return chain.from_iterable(d.items()) + + all_handlers = { + tuple: iter, + list: iter, + deque: iter, + dict: dict_handler, + set: iter, + frozenset: iter, + } + if handlers: + all_handlers.update(handlers) # user handlers take precedence + seen = set() # track which object id's have already been seen + default_size = getsizeof(0) # estimate sizeof object without __sizeof__ + + def sizeof(o: Any) -> int: + if id(o) in seen: # do not double count the same object + return 0 + seen.add(id(o)) + s = getsizeof(o, default_size) + + for typ, handler in all_handlers.items(): + if isinstance(o, typ): + s += sum(map(sizeof, handler(o))) # type: ignore + break + else: + # no __slots__ *usually* means a __dict__, but some special builtin classes + # (such as `type(None)`) have neither else, `o` has no attributes at all, + # so sys.getsizeof() actually returned the correct value + if not hasattr(o.__class__, "__slots__"): + if hasattr(o, "__dict__"): + s += sizeof(o.__dict__) + else: + s += sum( + sizeof(getattr(o, x)) + for x in o.__class__.__slots__ + if hasattr(o, x) + ) + return s + + return sizeof(data) / (1024.0 * 1024.0) def get_mb_serialized_size(data: Any) -> float: