Merge pull request #9150 from OpenMined/bschell/fix-incorrect-dataset-size

Update get_mb_size util function to handle containers
IonesioJunior authored Aug 9, 2024
2 parents 4dcd2ae + 01e16c0 commit 62e6630
Showing 1 changed file with 63 additions and 2 deletions.
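For context on the fix: the old implementation (visible in the diff below) measured data with a single sys.getsizeof call, which reports only the container object itself and ignores everything it holds, so container-like datasets were reported far smaller than they really are. A standalone illustration of that behaviour, not part of the commit:

import sys

payload = [bytes(1024 * 1024) for _ in range(10)]  # ten 1 MB bytes objects, ~10 MB total

# sys.getsizeof counts only the outer list object (a handful of pointers),
# not the ~10 MB held by its elements, so the reported size is tiny.
print(sys.getsizeof(payload) / (1024 * 1024))  # on the order of 0.0001 MB
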
65 changes: 63 additions & 2 deletions packages/syft/src/syft/util/util.py
@@ -1,6 +1,7 @@
# stdlib
import asyncio
from asyncio.selector_events import BaseSelectorEventLoop
from collections import deque
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
@@ -11,6 +12,7 @@
from datetime import datetime
import functools
import hashlib
from itertools import chain
from itertools import repeat
import json
import logging
@@ -29,6 +31,7 @@
from secrets import randbelow
import socket
import sys
from sys import getsizeof
import threading
import time
import types
@@ -93,8 +96,66 @@ def get_name_for(klass: type) -> str:
return klass_name


def get_mb_size(data: Any) -> float:
return sys.getsizeof(data) / (1024 * 1024)
def get_mb_size(data: Any, handlers: dict | None = None) -> float:
    """Returns the approximate memory footprint of an object and all of its contents.

    Automatically finds the contents of the following builtin containers and
    their subclasses: tuple, list, deque, dict, set and frozenset.
    Otherwise, tries to read from the __slots__ or __dict__ of the object.
    To search other containers, add handlers to iterate over their contents:
        handlers = {SomeContainerClass: iter,
                    OtherContainerClass: OtherContainerClass.get_elements}
    Lightly modified from
    https://code.activestate.com/recipes/577504-compute-memory-footprint-of-an-object-and-its-cont/
    which is referenced in the official sys.getsizeof documentation
    https://docs.python.org/3/library/sys.html#sys.getsizeof.
    """

    def dict_handler(d: dict[Any, Any]) -> Iterator[Any]:
        return chain.from_iterable(d.items())

    all_handlers = {
        tuple: iter,
        list: iter,
        deque: iter,
        dict: dict_handler,
        set: iter,
        frozenset: iter,
    }
    if handlers:
        all_handlers.update(handlers)  # user handlers take precedence
    seen = set()  # track which object id's have already been seen
    default_size = getsizeof(0)  # estimate sizeof object without __sizeof__

    def sizeof(o: Any) -> int:
        if id(o) in seen:  # do not double count the same object
            return 0
        seen.add(id(o))
        s = getsizeof(o, default_size)

        for typ, handler in all_handlers.items():
            if isinstance(o, typ):
                s += sum(map(sizeof, handler(o)))  # type: ignore
                break
        else:
            # no __slots__ *usually* means a __dict__, but some special builtin
            # classes (such as `type(None)`) have neither; in that case `o` has no
            # attributes at all, so sys.getsizeof() already returned the correct value
            if not hasattr(o.__class__, "__slots__"):
                if hasattr(o, "__dict__"):
                    s += sizeof(o.__dict__)
            else:
                s += sum(
                    sizeof(getattr(o, x))
                    for x in o.__class__.__slots__
                    if hasattr(o, x)
                )
        return s

    return sizeof(data) / (1024.0 * 1024.0)


def get_mb_serialized_size(data: Any) -> float:
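A minimal usage sketch of the updated helper (illustrative only, not part of the commit). It assumes the syft.util.util module path implied by the file shown above; the Batch class and its get_elements method are made up for the example:

# Illustrative usage of the updated get_mb_size; `Batch` is a hypothetical
# container, not part of Syft.
from collections.abc import Iterator

from syft.util.util import get_mb_size


class Batch:
    def __init__(self, items: list[bytes]) -> None:
        self._items = items

    def get_elements(self) -> Iterator[bytes]:
        return iter(self._items)


rows = [bytes(1024) for _ in range(1000)]  # ~1 MB of payload across 1000 objects

# Builtin containers are now walked recursively, so the elements are counted.
print(get_mb_size(rows))

# Custom containers can be registered via the new `handlers` argument; a
# registered handler takes precedence over the __dict__/__slots__ fallback.
print(get_mb_size(Batch(rows), handlers={Batch: Batch.get_elements}))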
