
Commit 5b4c8e0

Merge pull request #8897 from khoaguin/save-small-variables-without-blob-storage

Send small variables to storage without blob storage
shubham3121 authored Jul 3, 2024
2 parents be0deb9 + 49ceece
Showing 18 changed files with 296 additions and 79 deletions.
1 change: 1 addition & 0 deletions packages/grid/backend/grid/core/config.py
@@ -155,6 +155,7 @@ def get_emails_enabled(self) -> Self:
ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool(
os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False")
)
MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16))
REVERSE_TUNNEL_ENABLED: bool = str_to_bool(
os.getenv("REVERSE_TUNNEL_ENABLED", "false")
)
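The threshold is parsed with a plain int(os.getenv(...)), so any override must be numeric. A minimal sketch of the parsing behavior, assuming the variable is not already set in the environment (standard library only; the variable name comes from this diff):

import os

# Unset -> the integer default 16 is used as-is.
assert int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) == 16

# Set -> os.getenv returns the string "64", which int() parses.
os.environ["MIN_SIZE_BLOB_STORAGE_MB"] = "64"
assert int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) == 64

# A non-numeric value would raise ValueError when config.py is imported.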
5 changes: 4 additions & 1 deletion packages/grid/backend/grid/core/node.py
@@ -66,7 +66,10 @@ def seaweedfs_config() -> SeaweedFSConfig:
mount_port=settings.SEAWEED_MOUNT_PORT,
)

return SeaweedFSConfig(client_config=seaweed_client_config)
return SeaweedFSConfig(
client_config=seaweed_client_config,
min_blob_size=settings.MIN_SIZE_BLOB_STORAGE_MB,
)


node_type = NodeType(get_node_type())
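SeaweedFSConfig itself is defined in Syft's blob-storage package and is not shown in this diff; a plausible sketch of the shape this call requires, assuming a pydantic-style model (the base class, the stand-in client fields, and the default of 16 are all assumptions):

from pydantic import BaseModel

class SeaweedFSClientConfig(BaseModel):
    # Stand-in for the real client config built above.
    host: str = "localhost"
    mount_port: int = 4001

class SeaweedFSConfig(BaseModel):
    client_config: SeaweedFSClientConfig
    min_blob_size: int = 16  # MB; objects below this threshold skip blob storage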
1 change: 1 addition & 0 deletions packages/grid/default.env
@@ -55,6 +55,7 @@ CREATE_PRODUCER=False
N_CONSUMERS=1
INMEMORY_WORKERS=True
ASSOCIATION_REQUEST_AUTO_APPROVAL=False
MIN_SIZE_BLOB_STORAGE_MB=16

# New Service Flag
USE_NEW_SERVICE=False
(SeaweedFS StatefulSet Helm template — the file header was not captured in this page)
@@ -137,6 +137,8 @@ spec:
name: {{ .Values.seaweedfs.secretKeyName | required "seaweedfs.secretKeyName is required" }}
key: s3RootPassword
{{- end }}
- name: MIN_SIZE_BLOB_STORAGE_MB
value: {{ .Values.seaweedfs.minSizeBlobStorageMB | quote }}
# Tracing
- name: TRACE
value: "false"
2 changes: 2 additions & 0 deletions packages/grid/helm/syft/values.yaml
@@ -66,6 +66,8 @@ seaweedfs:
s3:
rootUser: admin

minSizeBlobStorageMB: 16

# Mount API
mountApi:
# automount:
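With the chart value in place, the threshold can be tuned per deployment at install time, for example: helm install my-syft ./syft --set seaweedfs.minSizeBlobStorageMB=32 (release name and chart path are illustrative; the value key comes from this diff). The StatefulSet template above forwards it to the container as MIN_SIZE_BLOB_STORAGE_MB, which config.py then reads.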
12 changes: 10 additions & 2 deletions packages/syft/src/syft/client/domain_client.py
@@ -26,6 +26,7 @@
from ..service.dataset.dataset import CreateDataset
from ..service.response import SyftError
from ..service.response import SyftSuccess
from ..service.response import SyftWarning
from ..service.sync.diff_state import ResolvedSyncState
from ..service.sync.sync_state import SyncState
from ..service.user.roles import Roles
@@ -131,7 +132,7 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError:
) as pbar:
for asset in dataset.asset_list:
try:
contains_empty = asset.contains_empty()
contains_empty: bool = asset.contains_empty()
twin = TwinObject(
private_obj=ActionObject.from_obj(asset.data),
mock_obj=ActionObject.from_obj(asset.mock),
@@ -145,8 +146,15 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError:
tqdm.write(f"Failed to create twin for {asset.name}. {e}")
return SyftError(message=f"Failed to create twin. {e}")

if isinstance(res, SyftWarning):
logger.debug(res.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False
response = self.api.services.action.set(
twin, ignore_detached_objs=contains_empty
twin,
ignore_detached_objs=contains_empty,
skip_save_to_blob_store=skip_save_to_blob_store,
)
if isinstance(response, SyftError):
tqdm.write(f"Failed to upload asset: {asset.name}")
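The variable res checked here is assigned in the lines elided between the two hunks, presumably from twin._save_to_blob_storage(). Distilled, the per-asset control flow now looks like this sketch (hedged; SyftWarning signals "small object kept in memory", not a failure):

res = twin._save_to_blob_storage()  # assumption: assigned in the elided lines
if isinstance(res, SyftError):
    return res  # a real failure still aborts the upload
skip_save_to_blob_store = isinstance(res, SyftWarning)  # small object: skip blob store
response = self.api.services.action.set(
    twin,
    ignore_detached_objs=contains_empty,
    skip_save_to_blob_store=skip_save_to_blob_store,
)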
23 changes: 19 additions & 4 deletions packages/syft/src/syft/node/node.py
@@ -486,7 +486,10 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None:
client_config = OnDiskBlobStorageClientConfig(
base_directory=self.get_temp_dir("blob")
)
config_ = OnDiskBlobStorageConfig(client_config=client_config)
config_ = OnDiskBlobStorageConfig(
client_config=client_config,
min_blob_size=os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16),
)
else:
config_ = config
self.blob_store_config = config_
@@ -505,6 +508,17 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None:
remote_profile.profile_name
] = remote_profile

if self.dev_mode:
if isinstance(self.blob_store_config, OnDiskBlobStorageConfig):
logger.debug(
f"Using on-disk blob storage with path: "
f"{self.blob_store_config.client_config.base_directory}",
)
logger.debug(
f"Minimum object size to be saved to the blob storage: "
f"{self.blob_store_config.min_blob_size} (MB)."
)

def run_peer_health_checks(self, context: AuthedServiceContext) -> None:
self.peer_health_manager = PeerHealthCheckTask()
self.peer_health_manager.run(context=context)
@@ -1072,6 +1086,7 @@ def metadata(self) -> NodeMetadata:
node_side_type=node_side_type,
show_warnings=show_warnings,
eager_execution_enabled=eager_execution_enabled,
min_size_blob_storage_mb=self.blob_store_config.min_blob_size,
)

@property
@@ -1796,7 +1811,7 @@ def create_default_worker_pool(node: Node) -> SyftError | None:
)
return default_worker_pool

logger.info(f"Creating default worker image with tag='{default_worker_tag}'")
logger.info(f"Creating default worker image with tag='{default_worker_tag}'. ")
# Get/Create a default worker SyftWorkerImage
default_image = create_default_image(
credentials=credentials,
@@ -1809,7 +1824,7 @@ def create_default_worker_pool(node: Node) -> SyftError | None:
return default_image

if not default_image.is_built:
logger.info(f"Building default worker image with tag={default_worker_tag}")
logger.info(f"Building default worker image with tag={default_worker_tag}. ")
image_build_method = node.get_service_method(SyftWorkerImageService.build)
# Build the Image for given tag
result = image_build_method(
@@ -1829,7 +1844,7 @@ def create_default_worker_pool(node: Node) -> SyftError | None:
f"name={default_pool_name} "
f"workers={worker_count} "
f"image_uid={default_image.id} "
f"in_memory={node.in_memory_workers}"
f"in_memory={node.in_memory_workers}. "
)
if default_worker_pool is None:
worker_to_add_ = worker_count
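One subtlety in the init_blob_storage hunk: when MIN_SIZE_BLOB_STORAGE_MB is set, os.getenv returns a string, so passing it straight into min_blob_size relies on the config model coercing it to int (pydantic-style coercion is an assumption here). The explicit cast used in config.py is the defensive spelling:

import os

# Sketch of the defensive variant (not the PR's code): hand the config model
# an int whether or not the environment variable is set.
min_blob_size = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16))

The metadata hunk then advertises the effective threshold to clients as min_size_blob_storage_mb, which the client-side size check in action_object.py reads via get_metadata().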
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
@@ -16,7 +16,7 @@
"NodeMetadata": {
"5": {
"version": 5,
"hash": "70197b4725dbdea0560ed8388e4d20b76808bee988f3630c5f916ee8f48761f8",
"hash": "f3927d167073a4db369a07e3bbbf756075bbb29e9addec324b8cd2c3597b75a1",
"action": "add"
}
},
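(The NodeMetadata v5 hash changes here presumably because the schema gained the min_size_blob_storage_mb field populated in the metadata hunk above.)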
135 changes: 79 additions & 56 deletions packages/syft/src/syft/service/action/action_object.py
@@ -36,7 +36,10 @@
from ...node.credentials import SyftVerifyKey
from ...serde.serializable import serializable
from ...serde.serialize import _serialize as serialize
from ...service.blob_storage.util import can_upload_to_blob_storage
from ...service.response import SyftError
from ...service.response import SyftSuccess
from ...service.response import SyftWarning
from ...store.linked_obj import LinkedObject
from ...types.base import SyftBaseModel
from ...types.datetime import DateTime
@@ -508,44 +511,33 @@ def convert_to_pointers(
# relative
from ..dataset.dataset import Asset

arg_list = []
kwarg_dict = {}
if args is not None:
for arg in args:
if (
not isinstance(arg, ActionObject | Asset | UID)
and api.signing_key is not None # type: ignore[unreachable]
):
arg = ActionObject.from_obj( # type: ignore[unreachable]
syft_action_data=arg,
syft_client_verify_key=api.signing_key.verify_key,
syft_node_location=api.node_uid,
)
arg.syft_node_uid = node_uid
r = arg._save_to_blob_storage()
if isinstance(r, SyftError):
logger.error(r.message)
arg = api.services.action.set(arg)
arg_list.append(arg)

if kwargs is not None:
for k, arg in kwargs.items():
if (
not isinstance(arg, ActionObject | Asset | UID)
and api.signing_key is not None # type: ignore[unreachable]
):
arg = ActionObject.from_obj( # type: ignore[unreachable]
syft_action_data=arg,
syft_client_verify_key=api.signing_key.verify_key,
syft_node_location=api.node_uid,
)
arg.syft_node_uid = node_uid
r = arg._save_to_blob_storage()
if isinstance(r, SyftError):
logger.error(r.message)
arg = api.services.action.set(arg)
def process_arg(arg: ActionObject | Asset | UID | Any) -> Any:
if (
not isinstance(arg, ActionObject | Asset | UID)
and api.signing_key is not None # type: ignore[unreachable]
):
arg = ActionObject.from_obj( # type: ignore[unreachable]
syft_action_data=arg,
syft_client_verify_key=api.signing_key.verify_key,
syft_node_location=api.node_uid,
)
arg.syft_node_uid = node_uid
r = arg._save_to_blob_storage()
if isinstance(r, SyftError):
print(r.message)
if isinstance(r, SyftWarning):
logger.debug(r.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False
arg = api.services.action.set(
arg,
skip_save_to_blob_store=skip_save_to_blob_store,
)
return arg

kwarg_dict[k] = arg
arg_list = [process_arg(arg) for arg in args] if args else []
kwarg_dict = {k: process_arg(v) for k, v in kwargs.items()} if kwargs else {}

return arg_list, kwarg_dict

@@ -801,7 +793,7 @@ def reload_cache(self) -> SyftError | None:

return None

def _save_to_blob_storage_(self, data: Any) -> SyftError | None:
def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None:
# relative
from ...types.blob_storage import BlobFile
from ...types.blob_storage import CreateBlobStorageEntry
@@ -814,6 +806,18 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None:
)
data._upload_to_blobstorage_from_api(api)
else:
get_metadata = from_api_or_context(
func_or_path="metadata.get_metadata",
syft_node_location=self.syft_node_location,
syft_client_verify_key=self.syft_client_verify_key,
)
if get_metadata is not None and not can_upload_to_blob_storage(
data, get_metadata()
):
return SyftWarning(
message=f"The action object {self.id} was not saved to "
f"the blob store but to memory cache since it is small."
)
serialized = serialize(data, to_bytes=True)
size = sys.getsizeof(serialized)
storage_entry = CreateBlobStorageEntry.from_obj(data, file_size=size)
@@ -830,13 +834,13 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None:
)
if allocate_method is not None:
blob_deposit_object = allocate_method(storage_entry)

if isinstance(blob_deposit_object, SyftError):
return blob_deposit_object

result = blob_deposit_object.write(BytesIO(serialized))
if isinstance(result, SyftError):
return result

self.syft_blob_storage_entry_id = (
blob_deposit_object.blob_storage_entry_id
)
@@ -855,6 +859,33 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None:

return None

def _save_to_blob_storage(
self, allow_empty: bool = False
) -> SyftError | SyftSuccess | SyftWarning:
data = self.syft_action_data
if isinstance(data, SyftError):
return data

if isinstance(data, ActionDataEmpty):
return SyftError(
message=f"cannot store empty object {self.id} to the blob storage"
)

try:
result = self._save_to_blob_storage_(data)
if isinstance(result, SyftError | SyftWarning):
return result
if not TraceResultRegistry.current_thread_is_tracing():
self._clear_cache()
return SyftSuccess(
message=f"Saved action object {self.id} to the blob store"
)
except Exception as e:
raise e

def _clear_cache(self) -> None:
self.syft_action_data_cache = self.as_empty_data()

def _set_reprs(self, data: any) -> None:
if inspect.isclass(data):
self.syft_action_data_repr_ = truncate_str(repr_cls(data))
Expand All @@ -866,22 +897,6 @@ def _set_reprs(self, data: any) -> None:
)
self.syft_action_data_str_ = truncate_str(str(data))

def _save_to_blob_storage(self, allow_empty: bool = False) -> SyftError | None:
data = self.syft_action_data
if isinstance(data, SyftError):
return data
if isinstance(data, ActionDataEmpty) and not allow_empty:
return SyftError(message=f"cannot store empty object {self.id}")
result = self._save_to_blob_storage_(data)
if isinstance(result, SyftError):
return result
if not TraceResultRegistry.current_thread_is_tracing():
self._clear_cache()
return None

def _clear_cache(self) -> None:
self.syft_action_data_cache = self.as_empty_data()

@property
def is_pointer(self) -> bool:
return self.syft_node_uid is not None
@@ -1229,8 +1244,16 @@ def _send(
api = self._get_api()
if isinstance(api, SyftError):
return api

if isinstance(blob_storage_res, SyftWarning):
logger.debug(blob_storage_res.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False
res = api.services.action.set(
self, add_storage_permission=add_storage_permission
self,
add_storage_permission=add_storage_permission,
skip_save_to_blob_store=skip_save_to_blob_store,
)
if isinstance(res, ActionObject):
self.syft_created_at = res.syft_created_at
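can_upload_to_blob_storage is imported at the top of this file but its body is not part of the rendered hunks. Given the call site — can_upload_to_blob_storage(data, get_metadata()) — and the metadata field added in node.py, a plausible reconstruction (the serialization-based sizing and the metadata attribute access are assumptions):

import sys
from typing import Any

def can_upload_to_blob_storage(data: Any, metadata: Any) -> bool:
    # Upload only when the serialized payload reaches the node's advertised
    # minimum size; serialize is the syft serializer imported in this file.
    serialized = serialize(data, to_bytes=True)
    size_mb = sys.getsizeof(serialized) / (1024 * 1024)
    return size_mb >= metadata.min_size_blob_storage_mb

Taken together, the client-side effect of the PR can be sketched end to end (assumes a running node and an authenticated client, e.g. from sy.login; the array sizes are illustrative and the 16 MB default comes from this diff):

import numpy as np
import syft as sy

small = sy.ActionObject.from_obj(np.zeros(1_000))    # ~8 KB, well under 16 MB
small_ptr = small.send(client)  # SyftWarning path: kept in the memory cache

big = sy.ActionObject.from_obj(np.zeros(4_000_000))  # ~32 MB of float64
big_ptr = big.send(client)      # over the threshold: uploaded to blob storage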
(Remaining changed files were not loaded in this capture.)