-
Notifications
You must be signed in to change notification settings - Fork 906
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Daniel J. Beutel <[email protected]> Co-authored-by: Heng Pan <[email protected]>
- Loading branch information
1 parent
815f662
commit 1fcb147
Showing
4 changed files
with
349 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""ParametersRecord and Array.""" | ||
|
||
|
||
from dataclasses import dataclass, field | ||
from typing import List, Optional, OrderedDict | ||
|
||
|
||
@dataclass | ||
class Array: | ||
"""Array type. | ||
A dataclass containing serialized data from an array-like or tensor-like object | ||
along with some metadata about it. | ||
Parameters | ||
---------- | ||
dtype : str | ||
A string representing the data type of the serialised object (e.g. `np.float32`) | ||
shape : List[int] | ||
A list representing the shape of the unserialized array-like object. This is | ||
used to deserialize the data (depending on the serialization method) or simply | ||
as a metadata field. | ||
stype : str | ||
A string indicating the type of serialisation mechanism used to generate the | ||
bytes in `data` from an array-like or tensor-like object. | ||
data: bytes | ||
A buffer of bytes containing the data. | ||
""" | ||
|
||
dtype: str | ||
shape: List[int] | ||
stype: str | ||
data: bytes | ||
|
||
|
||
@dataclass | ||
class ParametersRecord: | ||
"""Parameters record. | ||
A dataclass storing named Arrays in order. This means that it holds entries as an | ||
OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to | ||
PyTorch's state_dict, but holding serialised tensors instead. | ||
""" | ||
|
||
keep_input: bool | ||
data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array]) | ||
|
||
def __init__( | ||
self, | ||
array_dict: Optional[OrderedDict[str, Array]] = None, | ||
keep_input: bool = False, | ||
) -> None: | ||
"""Construct a ParametersRecord object. | ||
Parameters | ||
---------- | ||
array_dict : Optional[OrderedDict[str, Array]] | ||
A dictionary that stores serialized array-like or tensor-like objects. | ||
keep_input : bool (default: False) | ||
A boolean indicating whether parameters should be deleted from the input | ||
dictionary immediately after adding them to the record. If False, the | ||
dictionary passed to `set_parameters()` will be empty once exiting from that | ||
function. This is the desired behaviour when working with very large | ||
models/tensors/arrays. However, if you plan to continue working with your | ||
parameters after adding it to the record, set this flag to True. When set | ||
to True, the data is duplicated in memory. | ||
""" | ||
self.keep_input = keep_input | ||
self.data = OrderedDict() | ||
if array_dict: | ||
self.set_parameters(array_dict) | ||
|
||
def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None: | ||
"""Add parameters to record. | ||
Parameters | ||
---------- | ||
array_dict : OrderedDict[str, Array] | ||
A dictionary that stores serialized array-like or tensor-like objects. | ||
""" | ||
if any(not isinstance(k, str) for k in array_dict.keys()): | ||
raise TypeError(f"Not all keys are of valid type. Expected {str}") | ||
if any(not isinstance(v, Array) for v in array_dict.values()): | ||
raise TypeError(f"Not all values are of valid type. Expected {Array}") | ||
|
||
if self.keep_input: | ||
# Copy | ||
self.data = OrderedDict(array_dict) | ||
else: | ||
# Add entries to dataclass without duplicating memory | ||
for key in list(array_dict.keys()): | ||
self.data[key] = array_dict[key] | ||
del array_dict[key] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""RecordSet tests.""" | ||
|
||
|
||
from typing import Callable, List, OrderedDict, Type, Union | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from .parameter import ndarrays_to_parameters, parameters_to_ndarrays | ||
from .parametersrecord import Array, ParametersRecord | ||
from .recordset_utils import ( | ||
parameters_to_parametersrecord, | ||
parametersrecord_to_parameters, | ||
) | ||
from .typing import NDArray, NDArrays, Parameters | ||
|
||
|
||
def get_ndarrays() -> NDArrays: | ||
"""Return list of NumPy arrays.""" | ||
arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]) | ||
arr2 = np.eye(2, 7, 3) | ||
|
||
return [arr1, arr2] | ||
|
||
|
||
def ndarray_to_array(ndarray: NDArray) -> Array: | ||
"""Represent NumPy ndarray as Array.""" | ||
return Array( | ||
data=ndarray.tobytes(), | ||
dtype=str(ndarray.dtype), | ||
stype="numpy.ndarray.tobytes", | ||
shape=list(ndarray.shape), | ||
) | ||
|
||
|
||
def test_ndarray_to_array() -> None: | ||
"""Test creation of Array object from NumPy ndarray.""" | ||
shape = (2, 7, 9) | ||
arr = np.eye(*shape) | ||
|
||
array = ndarray_to_array(arr) | ||
|
||
arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape) | ||
|
||
assert np.array_equal(arr, arr_) | ||
|
||
|
||
def test_parameters_to_array_and_back() -> None: | ||
"""Test conversion between legacy Parameters and Array.""" | ||
ndarrays = get_ndarrays() | ||
|
||
# Array represents a single array, unlike Paramters, which represent a | ||
# list of arrays | ||
ndarray = ndarrays[0] | ||
|
||
parameters = ndarrays_to_parameters([ndarray]) | ||
|
||
array = Array( | ||
data=parameters.tensors[0], dtype="", stype=parameters.tensor_type, shape=[] | ||
) | ||
|
||
parameters = Parameters(tensors=[array.data], tensor_type=array.stype) | ||
|
||
ndarray_ = parameters_to_ndarrays(parameters=parameters)[0] | ||
|
||
assert np.array_equal(ndarray, ndarray_) | ||
|
||
|
||
def test_parameters_to_parametersrecord_and_back() -> None: | ||
"""Test conversion between legacy Parameters and ParametersRecords.""" | ||
ndarrays = get_ndarrays() | ||
|
||
parameters = ndarrays_to_parameters(ndarrays) | ||
|
||
params_record = parameters_to_parametersrecord(parameters=parameters) | ||
|
||
parameters_ = parametersrecord_to_parameters(params_record) | ||
|
||
ndarrays_ = parameters_to_ndarrays(parameters=parameters_) | ||
|
||
for arr, arr_ in zip(ndarrays, ndarrays_): | ||
assert np.array_equal(arr, arr_) | ||
|
||
|
||
def test_set_parameters_while_keeping_intputs() -> None: | ||
"""Tests keep_input functionality in ParametersRecord.""" | ||
# Adding parameters to a record that doesn't erase entries in the input `array_dict` | ||
p_record = ParametersRecord(keep_input=True) | ||
array_dict = OrderedDict( | ||
{str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} | ||
) | ||
p_record.set_parameters(array_dict) | ||
|
||
# Creating a second parametersrecord passing the same `array_dict` (not erased) | ||
p_record_2 = ParametersRecord(array_dict) | ||
assert p_record.data == p_record_2.data | ||
|
||
# Now it should be empty (the second ParametersRecord wasn't flagged to keep it) | ||
assert len(array_dict) == 0 | ||
|
||
|
||
def test_set_parameters_with_correct_types() -> None: | ||
"""Test adding dictionary of Arrays to ParametersRecord.""" | ||
p_record = ParametersRecord() | ||
array_dict = OrderedDict( | ||
{str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} | ||
) | ||
p_record.set_parameters(array_dict) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"key_type, value_fn", | ||
[ | ||
(str, lambda x: x), # correct key, incorrect value | ||
(str, lambda x: x.tolist()), # correct key, incorrect value | ||
(int, ndarray_to_array), # incorrect key, correct value | ||
(int, lambda x: x), # incorrect key, incorrect value | ||
(int, lambda x: x.tolist()), # incorrect key, incorrect value | ||
], | ||
) | ||
def test_set_parameters_with_incorrect_types( | ||
key_type: Type[Union[int, str]], | ||
value_fn: Callable[[NDArray], Union[NDArray, List[float]]], | ||
) -> None: | ||
"""Test adding dictionary of unsupported types to ParametersRecord.""" | ||
p_record = ParametersRecord() | ||
|
||
array_dict = { | ||
key_type(i): value_fn(ndarray) for i, ndarray in enumerate(get_ndarrays()) | ||
} | ||
|
||
with pytest.raises(TypeError): | ||
p_record.set_parameters(array_dict) # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""RecordSet utilities.""" | ||
|
||
|
||
from typing import OrderedDict | ||
|
||
from .parametersrecord import Array, ParametersRecord | ||
from .typing import Parameters | ||
|
||
|
||
def parametersrecord_to_parameters( | ||
record: ParametersRecord, keep_input: bool = False | ||
) -> Parameters: | ||
"""Convert ParameterRecord to legacy Parameters. | ||
Warning: Because `Arrays` in `ParametersRecord` encode more information of the | ||
array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it | ||
might not be possible to reconstruct such data structures from `Parameters` objects | ||
alone. Additional information or metadta must be provided from elsewhere. | ||
Parameters | ||
---------- | ||
record : ParametersRecord | ||
The record to be conveted into Parameters. | ||
keep_input : bool (default: False) | ||
A boolean indicating whether entries in the record should be deleted from the | ||
input dictionary immediately after adding them to the record. | ||
""" | ||
parameters = Parameters(tensors=[], tensor_type="") | ||
|
||
for key in list(record.data.keys()): | ||
parameters.tensors.append(record.data[key].data) | ||
|
||
if not keep_input: | ||
del record.data[key] | ||
|
||
return parameters | ||
|
||
|
||
def parameters_to_parametersrecord( | ||
parameters: Parameters, keep_input: bool = False | ||
) -> ParametersRecord: | ||
"""Convert legacy Parameters into a single ParametersRecord. | ||
Because there is no concept of names in the legacy Parameters, arbitrary keys will | ||
be used when constructing the ParametersRecord. Similarly, the shape and data type | ||
won't be recorded in the Array objects. | ||
Parameters | ||
---------- | ||
parameters : Parameters | ||
Parameters object to be represented as a ParametersRecord. | ||
keep_input : bool (default: False) | ||
A boolean indicating whether parameters should be deleted from the input | ||
Parameters object (i.e. a list of serialized NumPy arrays) immediately after | ||
adding them to the record. | ||
""" | ||
tensor_type = parameters.tensor_type | ||
|
||
p_record = ParametersRecord() | ||
|
||
num_arrays = len(parameters.tensors) | ||
for idx in range(num_arrays): | ||
if keep_input: | ||
tensor = parameters.tensors[idx] | ||
else: | ||
tensor = parameters.tensors.pop(0) | ||
p_record.set_parameters( | ||
OrderedDict( | ||
{str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} | ||
) | ||
) | ||
|
||
return p_record |