Skip to content

Commit

Permalink
Add ParametersRecord (#2799)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
Co-authored-by: Heng Pan <[email protected]>
  • Loading branch information
3 people authored Jan 17, 2024
1 parent 815f662 commit 1fcb147
Show file tree
Hide file tree
Showing 4 changed files with 349 additions and 8 deletions.
110 changes: 110 additions & 0 deletions src/py/flwr/common/parametersrecord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ParametersRecord and Array."""


from dataclasses import dataclass, field
from typing import List, Optional, OrderedDict


@dataclass
class Array:
    """Serialized array-like or tensor-like payload plus descriptive metadata.

    Parameters
    ----------
    dtype : str
        Name of the element data type of the original object
        (e.g. `np.float32`).
    shape : List[int]
        Shape of the original, unserialized object. Depending on the
        serialization method this is either required to deserialize `data` or
        carried along purely as metadata.
    stype : str
        Identifier of the serialization mechanism that produced the bytes stored
        in `data`.
    data : bytes
        Buffer holding the serialized content.
    """

    dtype: str
    shape: List[int]
    stype: str
    data: bytes


@dataclass
class ParametersRecord:
    """Parameters record.

    A dataclass storing named Arrays in order. This means that it holds entries
    as an OrderedDict[str, Array]. ParametersRecord objects can be viewed as an
    equivalent to PyTorch's state_dict, but holding serialized tensors instead.
    """

    keep_input: bool
    data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])

    def __init__(
        self,
        array_dict: Optional[OrderedDict[str, Array]] = None,
        keep_input: bool = False,
    ) -> None:
        """Construct a ParametersRecord object.

        Parameters
        ----------
        array_dict : Optional[OrderedDict[str, Array]]
            A dictionary that stores serialized array-like or tensor-like
            objects. When it is `None` or empty the record starts out empty.
        keep_input : bool (default: False)
            A boolean indicating whether entries should be left in the input
            dictionary after adding them to the record. If False, every
            dictionary passed to `set_parameters()` (including `array_dict`
            here) is emptied by that call, so the record ends up being the only
            container referencing the entries. If True, the input dictionary is
            left untouched and both containers reference the same `Array`
            objects.
        """
        self.keep_input = keep_input
        self.data = OrderedDict()
        if array_dict:
            self.set_parameters(array_dict)

    def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
        """Add parameters to record.

        Entries are appended to the ones already stored; an entry whose key is
        already present replaces the stored value. Validation happens before
        anything is modified, so a rejected `array_dict` leaves both the record
        and the input unchanged.

        Parameters
        ----------
        array_dict : OrderedDict[str, Array]
            A dictionary that stores serialized array-like or tensor-like
            objects.

        Raises
        ------
        TypeError
            If any key is not a `str` or any value is not an `Array`.
        """
        if any(not isinstance(k, str) for k in array_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")
        if any(not isinstance(v, Array) for v in array_dict.values()):
            raise TypeError(f"Not all values are of valid type. Expected {Array}")

        if self.keep_input:
            # Add the entries and leave the input dictionary as it is. Using
            # `update` (instead of rebinding `self.data`) keeps what is already
            # stored, matching the behaviour of the branch below.
            self.data.update(array_dict)
        else:
            # Move the entries one by one so that the input dictionary is empty
            # once we are done. Iterate over a snapshot of the keys because the
            # dictionary shrinks while we go.
            for key in list(array_dict.keys()):
                self.data[key] = array_dict.pop(key)
13 changes: 5 additions & 8 deletions src/py/flwr/common/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@
# ==============================================================================
"""RecordSet."""

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ParametersRecord:
"""Parameters record."""
from .parametersrecord import ParametersRecord


@dataclass
Expand All @@ -37,9 +34,9 @@ class ConfigsRecord:
class RecordSet:
"""Definition of RecordSet."""

parameters: Dict[str, ParametersRecord] = {}
metrics: Dict[str, MetricsRecord] = {}
configs: Dict[str, ConfigsRecord] = {}
parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
configs: Dict[str, ConfigsRecord] = field(default_factory=dict)

def set_parameters(self, name: str, record: ParametersRecord) -> None:
"""Add a ParametersRecord."""
Expand Down
147 changes: 147 additions & 0 deletions src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RecordSet tests."""


from typing import Callable, List, OrderedDict, Type, Union

import numpy as np
import pytest

from .parameter import ndarrays_to_parameters, parameters_to_ndarrays
from .parametersrecord import Array, ParametersRecord
from .recordset_utils import (
parameters_to_parametersrecord,
parametersrecord_to_parameters,
)
from .typing import NDArray, NDArrays, Parameters


def get_ndarrays() -> NDArrays:
    """Build the two NumPy arrays used as fixtures throughout these tests."""
    return [
        # Dense 3x2 matrix of floats
        np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]),
        # 2x7 matrix with ones on the diagonal shifted three columns right
        np.eye(N=2, M=7, k=3),
    ]


def ndarray_to_array(ndarray: NDArray) -> Array:
    """Wrap a NumPy ndarray in an Array using its raw bytes as payload."""
    dimensions = list(ndarray.shape)
    type_name = str(ndarray.dtype)
    raw_bytes = ndarray.tobytes()

    return Array(
        dtype=type_name,
        shape=dimensions,
        stype="numpy.ndarray.tobytes",
        data=raw_bytes,
    )


def test_ndarray_to_array() -> None:
    """Test creation of Array object from NumPy ndarray."""
    shape = (2, 7, 9)
    # Use a genuinely 3-D array with distinct values so that the round trip
    # checks both the recorded shape and the element order. (Note that
    # `np.eye(2, 7, 9)` would be an all-zero 2x7 matrix, because the third
    # argument of `np.eye` is the diagonal offset, not a dimension.)
    arr = np.arange(np.prod(shape), dtype=np.float64).reshape(shape)

    array = ndarray_to_array(arr)

    # Rebuild the ndarray using only what is stored in the Array
    arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape)

    assert array.shape == list(shape)
    assert arr_.dtype == arr.dtype
    assert np.array_equal(arr, arr_)


def test_parameters_to_array_and_back() -> None:
    """Check that one tensor survives Parameters -> Array -> Parameters."""
    # An Array holds a single array whereas the legacy Parameters type holds a
    # list of them, so only the first fixture array is used here.
    original = get_ndarrays()[0]

    legacy = ndarrays_to_parameters([original])

    # Shape and dtype are not available from the legacy object, hence left blank
    wrapped = Array(
        dtype="",
        shape=[],
        stype=legacy.tensor_type,
        data=legacy.tensors[0],
    )

    rebuilt = Parameters(tensor_type=wrapped.stype, tensors=[wrapped.data])
    restored = parameters_to_ndarrays(parameters=rebuilt)[0]

    assert np.array_equal(original, restored)


def test_parameters_to_parametersrecord_and_back() -> None:
    """Test conversion between legacy Parameters and ParametersRecords."""
    ndarrays = get_ndarrays()

    parameters = ndarrays_to_parameters(ndarrays)

    params_record = parameters_to_parametersrecord(parameters=parameters)

    parameters_ = parametersrecord_to_parameters(params_record)

    ndarrays_ = parameters_to_ndarrays(parameters=parameters_)

    # `zip` stops at the shorter input, so make sure no array went missing
    # before comparing the contents pairwise.
    assert len(ndarrays) == len(ndarrays_)
    for arr, arr_ in zip(ndarrays, ndarrays_):
        assert np.array_equal(arr, arr_)


def test_set_parameters_while_keeping_intputs() -> None:
    """Exercise the `keep_input` flag of ParametersRecord."""
    arrays = OrderedDict(
        (str(idx), ndarray_to_array(nd)) for idx, nd in enumerate(get_ndarrays())
    )

    # First record is told to keep its input, so `arrays` must stay populated
    keeping_record = ParametersRecord(keep_input=True)
    keeping_record.set_parameters(arrays)

    # Second record gets the very same dictionary through the constructor, with
    # the default flag (entries are moved out of the input)
    consuming_record = ParametersRecord(arrays)
    assert keeping_record.data == consuming_record.data

    # The second record drained the dictionary
    assert len(arrays) == 0


def test_set_parameters_with_correct_types() -> None:
    """Check that a str -> Array dictionary is accepted without raising."""
    valid_entries = OrderedDict(
        (str(idx), ndarray_to_array(nd)) for idx, nd in enumerate(get_ndarrays())
    )

    record = ParametersRecord()
    record.set_parameters(valid_entries)


@pytest.mark.parametrize(
    "key_type, value_fn",
    [
        # valid key type, invalid value type
        (str, lambda x: x),
        (str, lambda x: x.tolist()),
        # invalid key type, valid value type
        (int, ndarray_to_array),
        # invalid key type, invalid value type
        (int, lambda x: x),
        (int, lambda x: x.tolist()),
    ],
)
def test_set_parameters_with_incorrect_types(
    key_type: Type[Union[int, str]],
    value_fn: Callable[[NDArray], Union[NDArray, List[float]]],
) -> None:
    """Check that dictionaries with unsupported key/value types are rejected."""
    bad_entries = {}
    for position, ndarray in enumerate(get_ndarrays()):
        bad_entries[key_type(position)] = value_fn(ndarray)

    record = ParametersRecord()

    with pytest.raises(TypeError):
        record.set_parameters(bad_entries)  # type: ignore
87 changes: 87 additions & 0 deletions src/py/flwr/common/recordset_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RecordSet utilities."""


from typing import OrderedDict

from .parametersrecord import Array, ParametersRecord
from .typing import Parameters


def parametersrecord_to_parameters(
    record: ParametersRecord, keep_input: bool = False
) -> Parameters:
    """Convert ParametersRecord to legacy Parameters.

    Warning: Because `Arrays` in `ParametersRecord` encode more information of
    the array-like or tensor-like data (e.g. their datatype, shape) than
    `Parameters`, it might not be possible to reconstruct such data structures
    from `Parameters` objects alone. Additional information or metadata must be
    provided from elsewhere.

    Parameters
    ----------
    record : ParametersRecord
        The record to be converted into Parameters.
    keep_input : bool (default: False)
        A boolean indicating whether entries should be left in `record`. If
        False, each entry is deleted from `record.data` immediately after its
        bytes have been appended to the output, so the record is empty when
        this function returns.

    Returns
    -------
    Parameters
        An object whose `tensors` list holds the `data` of every Array in
        record order. Its `tensor_type` is the `stype` of the first Array with
        a non-empty `stype` (legacy Parameters can only carry a single type),
        or an empty string if there is none.
    """
    tensors = []
    tensor_type = ""

    # Iterate over a snapshot of the keys since entries may be deleted below
    for key in list(record.data.keys()):
        array = record.data[key]
        tensors.append(array.data)

        if not tensor_type:
            # Carry the serialization type over instead of dropping it
            tensor_type = array.stype

        if not keep_input:
            del record.data[key]

    return Parameters(tensors=tensors, tensor_type=tensor_type)


def parameters_to_parametersrecord(
    parameters: Parameters, keep_input: bool = False
) -> ParametersRecord:
    """Build a single ParametersRecord out of a legacy Parameters object.

    Legacy Parameters carry no names, so each tensor is stored under its
    position in the list rendered as a string ("0", "1", ...). Shape and data
    type are not available either, so the resulting Array objects have an empty
    `dtype` and an empty `shape`.

    Parameters
    ----------
    parameters : Parameters
        Parameters object to be represented as a ParametersRecord.
    keep_input : bool (default: False)
        A boolean indicating whether tensors should be left in the input
        Parameters object. If False, each tensor is removed from
        `parameters.tensors` right before it is added to the record.
    """
    stype = parameters.tensor_type
    tensors = parameters.tensors
    record = ParametersRecord()

    # The number of iterations is fixed up front: the list shrinks while looping
    # when the input is being consumed.
    for position in range(len(tensors)):
        # When consuming, the next tensor in order is always the current head
        buffer = tensors[position] if keep_input else tensors.pop(0)
        entry = Array(data=buffer, dtype="", stype=stype, shape=[])
        record.set_parameters(OrderedDict({str(position): entry}))

    return record

0 comments on commit 1fcb147

Please sign in to comment.