Skip to content

Commit

Permalink
Recordset basic enhancements (#2830)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jan 19, 2024
1 parent fd581f2 commit 818f6b7
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 17 deletions.
19 changes: 14 additions & 5 deletions src/py/flwr/common/configsrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
class ConfigsRecord:
"""Configs record."""

keep_input: bool
data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)

def __init__(
Expand All @@ -47,12 +46,13 @@ def __init__(
to True, the data is duplicated in memory. If memory is a concern, set
it to False.
"""
self.keep_input = keep_input
self.data = {}
if configs_dict:
self.set_configs(configs_dict)
self.set_configs(configs_dict, keep_input=keep_input)

def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
def set_configs(
self, configs_dict: Dict[str, ConfigsRecordValues], keep_input: bool = True
) -> None:
"""Add configs to the record.
Parameters
Expand All @@ -61,6 +61,11 @@ def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as
defined in `ConfigsRecordValues`) and list of such types (see
`ConfigsScalarList`).
keep_input : bool (default: True)
A boolean indicating whether the configs should be kept in the input
dictionary after adding them to the record. When set to True, the data is
duplicated in memory. If memory is a concern, set it to False, in which
case each entry is deleted from the input dictionary once it is added.
"""
if any(not isinstance(k, str) for k in configs_dict.keys()):
raise TypeError(f"Not all keys are of valid type. Expected {str}")
Expand Down Expand Up @@ -88,11 +93,15 @@ def is_valid(value: ConfigsScalar) -> None:
is_valid(value)

# Add configs to record
if self.keep_input:
if keep_input:
# Copy
self.data = configs_dict.copy()
else:
# Add entries to dataclass without duplicating memory
for key in list(configs_dict.keys()):
self.data[key] = configs_dict[key]
del configs_dict[key]

def __getitem__(self, key: str) -> ConfigsRecordValues:
"""Retrieve an element stored in record."""
return self.data[key]
19 changes: 14 additions & 5 deletions src/py/flwr/common/metricsrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
class MetricsRecord:
"""Metrics record."""

keep_input: bool
data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

def __init__(
Expand All @@ -46,19 +45,25 @@ def __init__(
to True, the data is duplicated in memory. If memory is a concern, set
it to False.
"""
self.keep_input = keep_input
self.data = {}
if metrics_dict:
self.set_metrics(metrics_dict)
self.set_metrics(metrics_dict, keep_input=keep_input)

def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
def set_metrics(
self, metrics_dict: Dict[str, MetricsRecordValues], keep_input: bool = True
) -> None:
"""Add metrics to the record.
Parameters
----------
metrics_dict : Dict[str, MetricsRecordValues]
A dictionary that stores basic types (i.e. `int`, `float` as defined
in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
keep_input : bool (default: True)
A boolean indicating whether the metrics should be kept in the input
dictionary after adding them to the record. When set to True, the data is
duplicated in memory. If memory is a concern, set it to False, in which
case each entry is deleted from the input dictionary once it is added.
"""
if any(not isinstance(k, str) for k in metrics_dict.keys()):
raise TypeError(f"Not all keys are of valid type. Expected {str}.")
Expand Down Expand Up @@ -86,11 +91,15 @@ def is_valid(value: MetricsScalar) -> None:
is_valid(value)

# Add metrics to record
if self.keep_input:
if keep_input:
# Copy
self.data = metrics_dict.copy()
else:
# Add entries to dataclass without duplicating memory
for key in list(metrics_dict.keys()):
self.data[key] = metrics_dict[key]
del metrics_dict[key]

def __getitem__(self, key: str) -> MetricsRecordValues:
"""Retrieve an element stored in record."""
return self.data[key]
17 changes: 12 additions & 5 deletions src/py/flwr/common/parametersrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class ParametersRecord:
PyTorch's state_dict, but holding serialised tensors instead.
"""

keep_input: bool
data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])

def __init__(
Expand All @@ -82,29 +81,37 @@ def __init__(
parameters after adding it to the record, set this flag to True. When set
to True, the data is duplicated in memory.
"""
self.keep_input = keep_input
self.data = OrderedDict()
if array_dict:
self.set_parameters(array_dict)
self.set_parameters(array_dict, keep_input=keep_input)

def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
def set_parameters(
self, array_dict: OrderedDict[str, Array], keep_input: bool = False
) -> None:
"""Add parameters to record.
Parameters
----------
array_dict : OrderedDict[str, Array]
A dictionary that stores serialized array-like or tensor-like objects.
keep_input : bool (default: False)
A boolean indicating whether the parameters should be kept in the input
dictionary after adding them to the record. When set to False, each entry
is deleted from the input dictionary once it has been added to the record.
"""
if any(not isinstance(k, str) for k in array_dict.keys()):
raise TypeError(f"Not all keys are of valid type. Expected {str}")
if any(not isinstance(v, Array) for v in array_dict.values()):
raise TypeError(f"Not all values are of valid type. Expected {Array}")

if self.keep_input:
if keep_input:
# Copy
self.data = OrderedDict(array_dict)
else:
# Add entries to dataclass without duplicating memory
for key in list(array_dict.keys()):
self.data[key] = array_dict[key]
del array_dict[key]

def __getitem__(self, key: str) -> Array:
"""Retrieve an element stored in record."""
return self.data[key]
4 changes: 2 additions & 2 deletions src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_set_parameters_while_keeping_intputs() -> None:
array_dict = OrderedDict(
{str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
)
p_record.set_parameters(array_dict)
p_record.set_parameters(array_dict, keep_input=True)

# Creating a second parametersrecord passing the same `array_dict` (not erased)
p_record_2 = ParametersRecord(array_dict)
Expand Down Expand Up @@ -253,7 +253,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
my_metrics_copy = my_metrics.copy()

# Add metric
m_record.set_metrics(my_metrics)
m_record.set_metrics(my_metrics, keep_input=keep_input)

# Check metrics are actually added
# Check that the input dict has been emptied when such behaviour is enabled
Expand Down

0 comments on commit 818f6b7

Please sign in to comment.