Skip to content

Commit

Permalink
Introduce NamedDataStore
Browse files Browse the repository at this point in the history
Pull Request resolved: #8587

Introduce NamedDataStore for weight sharing. See 'NamedBlobStore' in [RFC]

Rename 'NamedBlobStore' --> 'NamedDataStore' to mirror 'NamedDataMap' in the runtime.


The NamedDataStore exposes two methods:
- add_named_data: add a blob to the store
- get_named_data_store_output: return the contents of the store, to pass to serialization.

Invariants on the NamedDataStore
- Keys are unique regardless of whether they are in PTE or external file.
- Different keys can point to the same data.

NamedDataStore is used in D69764150. It's owned by the EdgeProgramManager.
ghstack-source-id: 268328940
@exported-using-ghexport

Differential Revision: [D69764094](https://our.internmc.facebook.com/intern/diff/D69764094/)
  • Loading branch information
lucylq authored and swolchok committed Feb 26, 2025
1 parent 65432b1 commit 9731a75
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 3 deletions.
1 change: 1 addition & 0 deletions exir/_serialize/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ runtime.python_library(
"_cord.py",
"_dataclass.py",
"_flatbuffer.py",
"_named_data_store.py",
"_program.py",
"_serialize.py",
"data_serializer.py",
Expand Down
183 changes: 183 additions & 0 deletions exir/_serialize/_named_data_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import hashlib
import math
from dataclasses import dataclass

# from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class BufferEntry:
"""A class to hold the buffer entries for serialization.
Attributes:
buffer: The buffer bytes.
alignment: The alignment of the buffer.
"""

buffer: bytes
alignment: int


@dataclass
class NamedDataStoreOutput:
"""
Holds named data for serialization.
Attributes:
buffers: A list of unique buffer entries.
pte_data: Contains data that is stored inside the PTE file. A mapping from
{key: buffer_index}.
external_data: Contains data that is stored external to the PTE. A mapping
from {filename: {key: buffer_index}}.
"""

buffers: List[BufferEntry]
pte_data: Dict[str, int]
external_data: Dict[str, Dict[str, int]]


class NamedDataStore:
"""
NamedDataStore manages the data that delegates want to share. Backends add
bytes to the store under a unique key. These bytes can be retrieved at
runtime using the same key with the NamedDataMap.
Note:
- Keys are unique in the data store, regardless of whether they are stored
in the PTE or externally.
- Multiple keys can point to the same buffer entry.
- The same data can be added multiple times and all keys will point to one
buffer. If a duplicate blob is added with a different alignment, the
lcm of the current and new alignment is taken for that blob.
"""

# List of unique blobs.
buffers: List[BufferEntry]
# Named data stored inside the PTE file. Map of {key: buffer_index}.
pte_data: Dict[str, int]
# Named data stored outside of the PTE file.
# Map of {filename: {key: buffer_index}}.
external_data: Dict[str, Dict[str, int]]

# Cache of the data hash for deduplication.
# Use a hash instead of the data as a key because a sha256 collision is
# unlikely, and the data may be large.
data_hash_to_buffer_idx: Dict[bytes, int]
# Cache of the key to buffer idx to ensure uniqueness.
# If a key is added multiple times, check the buffer idx to ensure that the
# data is identical too.
key_to_buffer_idx: Dict[str, int]

def __init__(self) -> None:
"""
Initializes a new NamedDataStore.
"""
self.buffers = []
self.pte_data = {}
self.external_data = {}

self.data_hash_to_buffer_idx = {}
self.key_to_buffer_idx = {}

def _add_named_data_to_map(
self,
key: str,
data: bytes,
alignment: int,
local_key_to_buffer_idx: Dict[str, int],
) -> None:
"""
Add data to a map and update the alignment. Ensure that the key-data
pair is unique.
- If the key exists, the data must be identical.
- If multiple unique keys exist for the same data, those keys should
point to the same buffer.
Args:
key (str): key associated with the data.
data (bytes): Bytes being requested to be serialized.
alignment (int): alignment for bytes to be serialized with.
local_key_to_buffer_idx (Dict[str, int]): map to add the data to.
Raises:
ValueError: when the key exists in the store, and corresponding data
is different.
"""
# Get data hash.
hashed = hashlib.sha256(data).digest()

# Check if the key exists.
buffer_idx = self.key_to_buffer_idx.get(key, -1)
if buffer_idx != -1:
# If the key exists, the corresponding data must be identical.
if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
raise ValueError(
f"Duplicate key {key} with different data. "
f"Existing data: {self.buffers[buffer_idx].buffer}. "
f"New data: {data}."
)
self.buffers[buffer_idx].alignment = math.lcm(
self.buffers[buffer_idx].alignment, alignment
)
else:
# Key doesn't exist; check if the data exists.
buffer_idx = self.data_hash_to_buffer_idx.get(hashed, -1)
if buffer_idx != -1:
# The data exists; update the alignment.
self.buffers[buffer_idx].alignment = math.lcm(
self.buffers[buffer_idx].alignment, alignment
)
else:
# The data doesn't exist; add it to the data store.
buffer_idx = len(self.buffers)
self.buffers.append(BufferEntry(data, alignment))
self.data_hash_to_buffer_idx[hashed] = buffer_idx

# Add key to the map and the key cache.
local_key_to_buffer_idx[key] = buffer_idx
self.key_to_buffer_idx[key] = buffer_idx

def add_named_data(
self,
key: str,
data: bytes,
alignment: Optional[int] = 1,
external_tag: Optional[str] = None,
) -> None:
"""
Adds a named blob to the NamedDataStore.
Args:
key (str): key associated with the data.
data (bytes): Bytes being requested to be serialized.
alignment (int): alignment for bytes to be serialized with.
external (Optional[str]): the external filename that this data is saved to.
Raises:
ValueError: when the key exists in the store, and corresponding data
is different.
"""

# Set default alignment.
if alignment is None:
alignment = 1
if alignment <= 0:
raise ValueError(f"Alignment must be greater than 0, received {alignment}.")

if external_tag is None:
self._add_named_data_to_map(key, data, alignment, self.pte_data)
else:
self._add_named_data_to_map(
key, data, alignment, self.external_data.setdefault(external_tag, {})
)

def get_named_data_store_output(self) -> NamedDataStoreOutput:
# Clean up empty maps inside self.external_data
self.external_data = {k: v for k, v in self.external_data.items() if len(v) > 0}
return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)
16 changes: 13 additions & 3 deletions exir/_serialize/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
oncall("executorch")

python_unittest(
name = "program",
name = "test_program",
srcs = [
"test_program.py",
],
Expand All @@ -15,7 +15,7 @@ python_unittest(
)

python_unittest(
name = "flatbuffer",
name = "test_flatbuffer",
srcs = [
"test_flatbuffer.py",
],
Expand All @@ -25,11 +25,21 @@ python_unittest(
)

python_unittest(
name = "cord",
name = "test_cord",
srcs = [
"test_cord.py",
],
deps = [
"//executorch/exir/_serialize:lib",
],
)

python_unittest(
name = "test_named_data_store",
srcs = [
"test_named_data_store.py",
],
deps = [
"//executorch/exir/_serialize:lib",
],
)
85 changes: 85 additions & 0 deletions exir/_serialize/test/test_named_data_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore


class TestNamedDataStore(unittest.TestCase):
def test_add(self) -> None:
store = NamedDataStore()
store.add_named_data("key1", b"data1", None, None)
store.add_named_data("key2", b"data2", 16, "file1")
store.add_named_data("key3", b"data3", 16, "file1")

output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 3)
self.assertEqual(output.buffers[0], BufferEntry(b"data1", 1))
self.assertEqual(output.buffers[1], BufferEntry(b"data2", 16))
self.assertEqual(output.buffers[2], BufferEntry(b"data3", 16))

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key1"], 0)

self.assertEqual(len(output.external_data), 1)
self.assertEqual(len(output.external_data["file1"]), 2)
self.assertEqual(output.external_data["file1"]["key2"], 1)
self.assertEqual(output.external_data["file1"]["key3"], 2)

def test_add_duplicate_name_and_data(self) -> None:
store = NamedDataStore()
store.add_named_data("key", b"data", None, None)
store.add_named_data("key", b"data", None, None)

output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key"], 0)

self.assertEqual(len(output.external_data), 0)

def test_add_same_data_with_different_alignment(self) -> None:
store = NamedDataStore()
store.add_named_data("key", b"data", 3, None)
store.add_named_data("key1", b"data", 4, None)

output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
# Check that we take the LCM of the two alignments (3, 4) = 12
self.assertEqual(output.buffers[0], BufferEntry(b"data", 12))

self.assertEqual(len(output.pte_data), 2)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(output.pte_data["key1"], 0)

self.assertEqual(len(output.external_data), 0)

def test_add_duplicate_key_fail(self) -> None:
store = NamedDataStore()
store.add_named_data("key", b"data", None, None)

# Cannot add item with the same key and different data.
self.assertRaises(ValueError, store.add_named_data, "key", b"data1", None, None)
self.assertRaises(
ValueError, store.add_named_data, "key", b"data1", 16, "file1"
)

output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(len(output.external_data), 0)

0 comments on commit 9731a75

Please sign in to comment.