diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index bed0b451d..048bc364c 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -92,6 +92,7 @@ def __init__( attn_head_count: int, attn_head_dim: int, seq_length: int, + shard_count: int = 1, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): @@ -100,6 +101,7 @@ def __init__( self.attn_head_count = attn_head_count self.attn_head_dim = attn_head_dim self.seq_length = seq_length + self.shard_count = shard_count self.device = device self.dtype = dtype @@ -113,15 +115,109 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]: Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] """ - return [ + allocations = [ torch.empty( - [bs, self.seq_length, self.attn_head_count, self.attn_head_dim], + [ + bs, + self.seq_length, + self.attn_head_count, + self.attn_head_dim, + ], dtype=self.dtype, device=self.device, ) for _ in range(2 * self.transformer_block_count) ] + if self.shard_count == 1: + return allocations + + return [ + ops.reshard_split(allocation, dim=2, count=self.shard_count) + for allocation in allocations + ] + + def read( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + transformer_block_index: int, + seq_len: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Reads cache partitions from the page table for the given page_ids. + + Args: + state: State struct as returned from allocate(). + read_into_partitions: List of cache partitions to read into in-place. + transformer_block_index: The index of the transformer block accessing + the cache. + page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids + to access. + + Returns a tuple of cache partitions (i.e. k and v caches for the transformer + block), linearized. Note that this reference approach to reading by + materializing linearly may not be terribly efficient unless if the + compiler can fuse the gather. + """ + read_count = len(read_into_partitions) + reads = [] + for i in range(read_count): + reads.append( + state[transformer_block_index * read_count + i][:, :seq_len, :, :] + ) + + return tuple(reads) + + def write_timestep( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + # List of [bs, 1, attn_head_count, attn_head_dim] + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + # [bs] + seq_positions: Union[torch.Tensor, ReplicatedTensor], + # [bs, max_seqlen // block_pos_stride] + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes a single batched timestep across all cache partitions. + + Note that this internally loops over the batch size, which cannot be + dynamic. 
+ """ + bs, _, _, _ = cache_partitions[0].shape + update_count = len(cache_partitions) + + for b in range(bs): + row_index = torch.tensor(b, dtype=torch.int64) + row_start_pos = seq_positions[row_index] + + for i, update in enumerate(cache_partitions): + cache = state[transformer_block_index * update_count + i] + cache.index_put_((row_index, row_start_pos), update[row_index, 0]) + + def write( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes cache partitions from a linear layout to the page table. + + This is the inverse of the linear read. The same caveat applies if the + in-place scatter cannot be fused. + """ + update_count = len(cache_partitions) + + for idx, update_src in enumerate(cache_partitions): + cache_dest = state[transformer_block_index * update_count + idx] + _, batch_seq_len, _, _ = update_src.shape + cache_dest[:, :batch_seq_len, :, :] = update_src + class PagedKVCache(BaseKVCache): """Implementation of a KV cache on top of a 'page table'. @@ -238,24 +334,19 @@ def allocate( """Allocates tensor state for a page table for the given capacity in pages. """ + shards = [ + torch.empty( + [page_count, self.page_slab_flat_dim], + dtype=self.dtype, + device=self.device, + ) + for _ in range(self.shard_count) + ] + if self.shard_count == 1: - return [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - ] - else: - shards = [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - for _ in range(self.shard_count) - ] - return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] + return shards + + return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] def read( self, @@ -263,6 +354,7 @@ def read( *, read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, + seq_len: int, page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Reads cache partitions from the page table for the given page_ids. @@ -331,6 +423,8 @@ def read_cache_partition( for index, read_into_partition in enumerate(read_into_partitions): read_cache_partition(index, read_into_partition) + return tuple([p[:, :seq_len, :] for p in read_into_partitions]) + def write_timestep( self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]], diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 59ed7b43a..958dc954e 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -113,27 +113,16 @@ def forward( # Full sequence length. 
kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride - if self.cache.is_paged: - xk, xv = self.transact_cache_paged( - xk_cache_update=xk, - xv_cache_update=xv, - seq_block_ids=seq_block_ids, - kv_seq_len=kv_seq_len, - start_positions=start_positions, - cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - elif self.cache.is_direct: - xk, xv = self.transact_cache_direct( - xk_cache_update=xk, - xv_cache_update=xv, - start_positions=start_positions, - kv_seq_len=kv_seq_len, - cache_state=cache_state, - ) - else: - raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}") + xk, xv = self.transact_cache( + xk_cache_update=xk, + xv_cache_update=xv, + seq_block_ids=seq_block_ids, + kv_seq_len=kv_seq_len, + start_positions=start_positions, + cache_state=cache_state, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) # Expand kv heads for GQA. gqa_n_rep = self.head_count // self.head_count_kv @@ -202,58 +191,20 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: h = h + attn_output return h - def transact_cache_direct( - self, - *, - cache_state: list[torch.Tensor], - xk_cache_update: torch.Tensor, - xv_cache_update: torch.Tensor, - kv_seq_len: int, - start_positions: Optional[torch.Tensor] = None, - ): - bs, batch_seq_len, _, _ = xk_cache_update.shape - cache_k = cache_state[self.block_index * 2] - cache_v = cache_state[self.block_index * 2 + 1] - - if start_positions is None: - # Prefill. Write the entire cache. - cache_k[:, :batch_seq_len] = xk_cache_update - cache_v[:, :batch_seq_len] = xv_cache_update - return xk_cache_update, xv_cache_update - else: - # Decode. Write a single timestep. - # TODO: This needs to be reworked with index ops. - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - for b in range(bs): - # Make a tensor because indices must be all tensors, so we can avoid - # doing start_positions[row_index].item(), which generates a lot of SymInts. - row_index = torch.tensor( - b, dtype=torch.int64, device=xk_cache_update.device - ) - row_start_pos = start_positions[row_index] - cache_k.index_put_( - (row_index, row_start_pos), xk_cache_update[row_index, 0] - ) - cache_v.index_put_( - (row_index, row_start_pos), xv_cache_update[row_index, 0] - ) - return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len] - - def transact_cache_paged( + def transact_cache( self, *, xk_cache_update: torch.Tensor, xv_cache_update: torch.Tensor, cache_state: list[torch.Tensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, + seq_block_ids: Optional[torch.Tensor], kv_seq_len: int, start_positions: Optional[torch.Tensor] = None, xk_temp: Optional[torch.Tensor] = None, xv_temp: Optional[torch.Tensor] = None, ): - cache = self.cache.paged + cache = self.cache # Manage the cache. if start_positions is None: # Prefill: Write the entire cache. @@ -264,46 +215,45 @@ def transact_cache_paged( page_ids=seq_block_ids, ) return xk_cache_update, xv_cache_update - else: - # Decode at ragged start positions. - # We need to initialize/read the K/V from the cache for the whole - # sequence. Note that at this point, it is possible to fork and - # use a memory efficient attention kernel that can do indirect - # reads, skipping this materialization. This path is taken for - # a decode step. 
- assert xk_temp is not None and xv_temp is not None - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride - - # Write our one updated cache row into the cache. - cache.write_timestep( - cache_state, - cache_partitions=[ - xk_cache_update, - xv_cache_update, - ], - transformer_block_index=self.block_index, - seq_positions=start_positions, - page_ids=seq_block_ids, - ) - # Restore from the cache. - cache.read( - cache_state, - read_into_partitions=[ - xk_temp[:, 0:kv_seq_len, ...], - xv_temp[:, 0:kv_seq_len, ...], - ], - transformer_block_index=self.block_index, - page_ids=seq_block_ids, - ) + # Decode at ragged start positions. + # We need to initialize/read the K/V from the cache for the whole + # sequence. Note that at this point, it is possible to fork and + # use a memory efficient attention kernel that can do indirect + # reads, skipping this materialization. This path is taken for + # a decode step. + assert xk_temp is not None and xv_temp is not None + assert xk_cache_update.shape[1] == 1 + assert xv_cache_update.shape[1] == 1 + assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride + + # Write our one updated cache row into the cache. + cache.write_timestep( + cache_state, + cache_partitions=[ + xk_cache_update, + xv_cache_update, + ], + transformer_block_index=self.block_index, + seq_positions=start_positions, + page_ids=seq_block_ids, + ) + + # Restore from the cache. + xk, xv = cache.read( + cache_state, + read_into_partitions=[ + xk_temp[:, 0:kv_seq_len, ...], + xv_temp[:, 0:kv_seq_len, ...], + ], + transformer_block_index=self.block_index, + page_ids=seq_block_ids, + seq_len=kv_seq_len, + ) - # For computation, we create a subview of the xk/xv tensors to have - # a sequence length covering the blocked size. This must include - # the newly added row (the caller is responsible for ensuring that - # every block has at least one row left). We'll compute on this - # ragged view and use an appropriate mask. - xk = xk_temp[:, 0:kv_seq_len, ...] - xv = xv_temp[:, 0:kv_seq_len, ...] - return xk, xv + # For computation, we create a subview of the xk/xv tensors to have + # a sequence length covering the blocked size. This must include + # the newly added row (the caller is responsible for ensuring that + # every block has at least one row left). We'll compute on this + # ragged view and use an appropriate mask. + return xk, xv diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 226ffd777..7b3d2e04b 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -990,7 +990,7 @@ def _is_slicing_split_dim(self, key): else: # Any other collection is a indexing only dimension 0. return self.shard_dim == 0 - if len(key) < self.shard_dim: + if len(key) <= self.shard_dim: return False if not isinstance(key[self.shard_dim], slice): return True diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py new file mode 100644 index 000000000..65b42c986 --- /dev/null +++ b/sharktank/tests/layers/kv_cache_test.py @@ -0,0 +1,502 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest + +import torch + +from sharktank.ops import replicate, reshard_split, unshard +from sharktank.layers import * +from sharktank.types import * + + +def test_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 + ) + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + shard_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + 
attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + # allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + write_twos = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) + + +def test_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, 
attn_head_dim), 2.0, dtype=torch.float32 + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + shard_count = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + page_ids = replicate(page_ids, shard_count) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + write_twos = reshard_split( + torch.full( + (bs, 
write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + empty_k = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + read_empty = [empty_k, empty_v] + + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + empty_k = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + read_back = cache.read( + allocation, + read_into_partitions=[empty_k, empty_v], + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index d58874f25..d7b6a0b33 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -123,6 +123,7 @@ def testRead(self): read_into_partitions=read_into_partitions, transformer_block_index=transformer_block_index, page_ids=page_ids, + seq_len=self.block_seq_len * self.block_seq_stride, ) sharded_read_into_partitions = deepcopy( [ @@ -136,6 +137,7 @@ def testRead(self): read_into_partitions=sharded_read_into_partitions, transformer_block_index=transformer_block_index, page_ids=sharded_page_ids, + seq_len=self.block_seq_len * self.block_seq_stride, ) for unsharded, sharded in zip( read_into_partitions, sharded_read_into_partitions diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc index dca171d67..c73bf5a93 100644 --- a/shortfin/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -248,8 +248,9 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) { } local::ProgramInvocation::Future 
PyFunctionCall(local::ProgramFunction &self, - py::args args) { - auto inv = self.CreateInvocation(); + py::args args, + local::Fiber &fiber) { + auto inv = self.CreateInvocation(fiber.shared_from_this()); py::capsule inv_capsule(inv.get()); for (py::handle arg : args) { PyAddProgramInvocationArg(inv_capsule, arg); @@ -592,13 +593,14 @@ void BindLocal(py::module_ &m) { py::class_(m, "Program") .def(py::new_([](std::span modules, - local::Fiber &fiber, bool trace_execution) { + std::vector devices, + bool trace_execution) { local::Program::Options options; + options.devices = devices; options.trace_execution = trace_execution; - return local::Program::Load(fiber.shared_from_this(), modules, - std::move(options)); + return local::Program::Load(modules, std::move(options)); }), - py::arg("modules"), py::arg("fiber"), py::kw_only(), + py::arg("modules"), py::kw_only(), py::arg("devices"), py::arg("trace_execution") = false) .def_prop_ro("exports", &local::Program::exports) .def("lookup_function", &local::Program::LookupRequiredFunction) @@ -607,9 +609,14 @@ void BindLocal(py::module_ &m) { .def_prop_ro("name", &local::ProgramFunction::name) .def_prop_ro("calling_convention", &local::ProgramFunction::calling_convention) - .def("invocation", &local::ProgramFunction::CreateInvocation, - DOCSTRING_PROGRAM_FUNCTION_INVOCATION) - .def("__call__", PyFunctionCall, py::arg("args")) + .def( + "invocation", + [](local::ProgramFunction &self, local::Fiber &fiber) { + return self.CreateInvocation(fiber.shared_from_this()); + }, + DOCSTRING_PROGRAM_FUNCTION_INVOCATION) + .def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(), + py::arg("fiber")) .def("__repr__", &local::ProgramFunction::to_s); py::class_(m, "ProgramModule") .def_prop_ro("exports", &local::ProgramModule::exports) @@ -718,8 +725,17 @@ void BindLocal(py::module_ &m) { }; py::class_(m, "Fiber") .def("__repr__", &local::Fiber::to_s) - .def_prop_ro("raw_devices", &local::Fiber::raw_devices, - py::rv_policy::reference_internal) + .def_prop_ro( + "raw_devices", + [](local::Fiber &self) { + std::vector devices; + devices.reserve(self.raw_devices().size()); + for (auto it : self.raw_devices()) { + devices.push_back(it.second); + } + return devices; + }, + py::rv_policy::reference_internal) .def( "raw_device", [](local::Fiber &self, int index) { return self.raw_device(index); }, diff --git a/shortfin/src/shortfin/local/fiber.h b/shortfin/src/shortfin/local/fiber.h index dd63b30f4..afd65b346 100644 --- a/shortfin/src/shortfin/local/fiber.h +++ b/shortfin/src/shortfin/local/fiber.h @@ -146,6 +146,23 @@ class SHORTFIN_API Fiber : public std::enable_shared_from_this { std::unordered_map device_class_count_; // Ordered devices named as ``. std::vector> devices_; + + // Program isolation control. + // This data structure is manipulated by APIs on the Program class hierarchy. + // It maps a parent context pointer to an isolate accounting struct. This + // struct contains a strong reference to the parent_context and a vector + // of fork contexts. For PER_FIBER invocations, there will only ever be either + // zero or one fork_contexts: when no calls have been issued there will be one + // and if a call is outstanding, there will be zero. This is used to guard + // concurrent access. For PER_CALL invocations, there will be as many + // fork_contexts as are needed to satisfy the peak number of calls in flight + // at any time. + // The program_isolate_mu_ must be held to manipulate the accounting structs. 
+ iree::slim_mutex program_isolate_mu_; + std::unordered_map> + program_isolates_; + friend struct detail::ProgramIsolate; }; } // namespace shortfin::local diff --git a/shortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc index 2f2d95ab4..038cd106a 100644 --- a/shortfin/src/shortfin/local/program.cc +++ b/shortfin/src/shortfin/local/program.cc @@ -36,12 +36,12 @@ void GetVmModuleExports(iree_vm_module_t *vm_module, // -------------------------------------------------------------------------- // ProgramFunction::ProgramFunction( - std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t vm_function, + iree::vm_context_ptr vm_context, iree_vm_function_t vm_function, + ProgramIsolation isolation, std::optional invocation_model) - : fiber_(std::move(fiber)), - vm_context_(std::move(vm_context)), + : vm_context_(std::move(vm_context)), vm_function_(vm_function), + isolation_(isolation), invocation_model_(invocation_model ? *invocation_model : GetInvocationModelFromFunction(vm_function)) {} @@ -73,9 +73,19 @@ std::string_view ProgramFunction::calling_convention() const { iree_vm_function_signature(&vm_function_).calling_convention); } -ProgramInvocation::Ptr ProgramFunction::CreateInvocation() { - return ProgramInvocation::New(fiber_, vm_context_, vm_function_, - invocation_model_); +ProgramInvocation::Ptr ProgramFunction::CreateInvocation( + std::shared_ptr fiber) { + // Low-overhead NONE isolation handling (saves some ref-count twiddling). + if (isolation_ == ProgramIsolation::NONE) { + return ProgramInvocation::New(std::move(fiber), vm_context_, vm_function_, + invocation_model_, /*isolate=*/nullptr); + } + + // Create an isolated invocation. + auto [isolated_context, isolate] = + detail::ProgramIsolate::AcquireIsolate(*fiber, vm_context_, isolation_); + return ProgramInvocation::New(std::move(fiber), std::move(isolated_context), + vm_function_, invocation_model_, isolate); } std::string ProgramFunction::to_s() const { @@ -106,7 +116,7 @@ ProgramModule ProgramModule::Load(System &system, system.vm_instance(), contents.const_buffer(), contents.deallocator(), system.host_allocator(), module.for_output())); contents.release(); // Must be invoked on success path only. - return ProgramModule(std::move(module)); + return ProgramModule(system.shared_from_this(), std::move(module)); } ProgramModule ProgramModule::ParameterProvider( @@ -126,7 +136,7 @@ ProgramModule ProgramModule::ParameterProvider( SHORTFIN_THROW_IF_ERROR(iree_io_parameters_module_create( system.vm_instance(), providers.size(), providers.data(), system.host_allocator(), module.for_output())); - return ProgramModule(std::move(module)); + return ProgramModule(system.shared_from_this(), std::move(module)); } std::string_view ProgramModule::name() const { @@ -158,14 +168,27 @@ std::vector ProgramModule::exports() const { // Program // -------------------------------------------------------------------------- // -Program Program::Load(std::shared_ptr fiber, - std::span modules, Options options) { +Program Program::Load(std::span modules, + Options &&options) { std::vector all_modules; std::vector raw_devices; + System *system = nullptr; // By default, bind all devices in the fiber in order to the program. 
- for (auto &it : fiber->raw_devices()) { - raw_devices.push_back(it.second->hal_device()); + for (auto &it : options.devices) { + raw_devices.push_back(it->hal_device()); + } + + for (auto &mod : modules) { + if (system && &mod.system() != system) { + throw std::invalid_argument( + "Cannot create Program from modules loaded from multiple system " + "instances"); + } + system = &mod.system(); + } + if (!system) { + throw std::invalid_argument("Cannot create Program with no modules"); } // Add a HAL module. @@ -177,12 +200,11 @@ Program Program::Load(std::shared_ptr fiber, // functionality (or module versions; iree_vm_module_dependency_t has the // minimum version required so you can switch between them, and whether they // are optional/required). - auto &system = fiber->system(); iree::vm_module_ptr hal_module; - SHORTFIN_THROW_IF_ERROR( - iree_hal_module_create(system.vm_instance(), raw_devices.size(), - raw_devices.data(), IREE_HAL_MODULE_FLAG_NONE, - system.host_allocator(), hal_module.for_output())); + SHORTFIN_THROW_IF_ERROR(iree_hal_module_create( + system->vm_instance(), raw_devices.size(), raw_devices.data(), + IREE_HAL_MODULE_FLAG_NONE, system->host_allocator(), + hal_module.for_output())); all_modules.push_back(hal_module); // Add explicit modules. @@ -195,10 +217,10 @@ Program Program::Load(std::shared_ptr fiber, iree_vm_context_flags_t flags = IREE_VM_CONTEXT_FLAG_CONCURRENT; if (options.trace_execution) flags |= IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION; SHORTFIN_THROW_IF_ERROR(iree_vm_context_create_with_modules( - system.vm_instance(), flags, all_modules.size(), all_modules.data(), - system.host_allocator(), context.for_output())); + system->vm_instance(), flags, all_modules.size(), all_modules.data(), + system->host_allocator(), context.for_output())); - return Program(std::move(fiber), std::move(context)); + return Program(std::move(context), options.isolation); } std::optional Program::LookupFunction(std::string_view name) { @@ -217,7 +239,7 @@ std::optional Program::LookupFunction(std::string_view name) { // TODO: Torch import is not setting the coarse-fences abi.model on // its functions. Get it from there instead of just assuming based on // name. 
- return ProgramFunction(fiber_, vm_context_, f, + return ProgramFunction(vm_context_, f, isolation_, ProgramInvocationModel::COARSE_FENCES); } else if (!iree_status_is_not_found(status)) { SHORTFIN_THROW_IF_ERROR(status); @@ -229,7 +251,7 @@ std::optional Program::LookupFunction(std::string_view name) { vm_context_, to_iree_string_view(name), &f); if (iree_status_is_not_found(status)) return {}; SHORTFIN_THROW_IF_ERROR(status); - return ProgramFunction(fiber_, vm_context_, f); + return ProgramFunction(vm_context_, f, isolation_); } ProgramFunction Program::LookupRequiredFunction(std::string_view name) { @@ -260,6 +282,15 @@ std::vector Program::exports() const { return results; } +void Program::PrepareIsolate(Fiber &fiber) { + if (isolation_ == ProgramIsolation::NONE) return; + auto [context, isolate] = + detail::ProgramIsolate::AcquireIsolate(fiber, vm_context_, isolation_); + if (isolate) { + detail::ProgramIsolate::ReleaseIsolate(fiber, std::move(context), isolate); + } +} + // -------------------------------------------------------------------------- // // ProgramInvocation // -------------------------------------------------------------------------- // @@ -287,18 +318,23 @@ void ProgramInvocation::Deleter::operator()(ProgramInvocation *inst) { } ProgramInvocation::ProgramInvocation() = default; -ProgramInvocation::~ProgramInvocation() { - if (!scheduled()) { - // This instance was dropped on the floor before scheduling. - // Clean up the initialization parameters. - iree::vm_context_ptr drop = - iree::vm_context_ptr::steal_reference(state.params.context); +ProgramInvocation::~ProgramInvocation() { ReleaseContext(); } + +void ProgramInvocation::ReleaseContext() { + if (vm_context_) { + if (isolate_) { + detail::ProgramIsolate::ReleaseIsolate(*fiber_, std::move(vm_context_), + isolate_); + } else { + vm_context_.reset(); + } } } ProgramInvocation::Ptr ProgramInvocation::New( std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model) { + iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model, + detail::ProgramIsolate *isolate) { auto sig = iree_vm_function_signature(&vm_function); iree_host_size_t arg_count; iree_host_size_t result_count; @@ -337,8 +373,8 @@ ProgramInvocation::Ptr ProgramInvocation::New( static_cast(inst_storage.release())), Deleter()); inst->fiber_ = std::move(fiber); - inst->state.params.context = - vm_context.release(); // Ref transfer to ProgramInvocation. 
+ inst->vm_context_ = std::move(vm_context); + inst->isolate_ = isolate; inst->state.params.function = vm_function; inst->state.params.invocation_model = invocation_model; inst->result_list_ = result_list; @@ -421,7 +457,6 @@ ProgramInvocation::Future ProgramInvocation::Invoke( Params params = invocation->state.params; auto schedule = [](ProgramInvocation *raw_invocation, Worker *worker, - iree_vm_context_t *owned_context, iree_vm_function_t function, ProgramInvocationModel invocation_model, std::optional failure_future) { @@ -440,6 +475,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( ProgramInvocation::Ptr invocation( static_cast(user_data)); ProgramInvocation *raw_invocation = invocation.get(); + raw_invocation->ReleaseContext(); if (iree_status_is_ok(status)) { raw_invocation->future_->set_result(std::move(invocation)); } else { @@ -469,7 +505,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( if (iree_status_is_ok(status)) { status = iree_vm_async_invoke(worker->loop(), &invocation->state.async_invoke_state, - owned_context, function, + invocation->vm_context_.get(), function, /*flags=*/IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/nullptr, /*inputs=*/invocation->arg_list(), @@ -478,10 +514,6 @@ ProgramInvocation::Future ProgramInvocation::Invoke( /*user_data=*/invocation.get()); } - // Regardless of status, the context reference we were holding is no - // longer needed. Drop it on the floor. - iree::vm_context_ptr::steal_reference(owned_context); - // On success, then the complete callback takes ownership of the // invocation, so we release it here and return. We have to treat // the invocation as possibly deallocated at this point, since the @@ -490,9 +522,11 @@ ProgramInvocation::Future ProgramInvocation::Invoke( invocation.release(); } else if (failure_future) { // Requested to set any failure on the future. + invocation->ReleaseContext(); failure_future->set_failure(status); } else { // Synchronous: just throw. + invocation->ReleaseContext(); SHORTFIN_THROW_IF_ERROR(status); } }; @@ -504,14 +538,13 @@ ProgramInvocation::Future ProgramInvocation::Invoke( if (&worker == Worker::GetCurrent()) { // On the same worker: fast-path directly to the loop. - schedule(invocation.release(), &worker, params.context, params.function, + schedule(invocation.release(), &worker, params.function, params.invocation_model, /*failure_future=*/{}); } else { // Cross worker coordination: submit an external task to bootstrap. - auto bound_schedule = - std::bind(schedule, invocation.release(), &worker, params.context, - params.function, params.invocation_model, - /*failure_future=*/fork_future); + auto bound_schedule = std::bind(schedule, invocation.release(), &worker, + params.function, params.invocation_model, + /*failure_future=*/fork_future); worker.CallThreadsafe(bound_schedule); } @@ -623,4 +656,69 @@ void StaticProgramParameters::Load(std::filesystem::path file_path, to_iree_string_view(options.format), file_handle.get(), index_.get())); } +// -------------------------------------------------------------------------- // +// ProgramIsolate +// -------------------------------------------------------------------------- // + +std::pair +detail::ProgramIsolate::AcquireIsolate(Fiber &fiber, + iree::vm_context_ptr root_context, + ProgramIsolation isolation) { + assert(isolation != ProgramIsolation::NONE && + "cannot AcquireIsolate when isolation == NONE"); + // Some isolation required. 
+  detail::ProgramIsolate *isolate = nullptr;
+  {
+    iree::slim_mutex_lock_guard lock(fiber.program_isolate_mu_);
+    auto found_it = fiber.program_isolates_.find(root_context.get());
+    if (found_it != fiber.program_isolates_.end()) {
+      isolate = found_it->second.get();
+    }
+    if (isolate && !isolate->fork_contexts.empty()) {
+      // Fast path: there is an existing isolate and a context available.
+      auto isolated_context = std::move(isolate->fork_contexts.back());
+      isolate->fork_contexts.pop_back();
+      return std::make_pair(std::move(isolated_context), isolate);
+    } else if (!isolate) {
+      // Initialize a new isolate accounting struct while in the lock.
+      // Note that this can cause a fault for PER_FIBER mode if the call
+      // to fork fails below as it will leave the isolate with no available
+      // context and every future call will raise an exception indicating that
+      // the context is busy (vs trying to create a new one). This is deemed
+      // an acceptable situation for a system fault (which is the only reason
+      // a fork will fail).
+      auto [inserted_it, inserted] =
+          fiber.program_isolates_.insert(std::make_pair(
+              root_context.get(),
+              std::make_unique<detail::ProgramIsolate>(root_context)));
+      isolate = inserted_it->second.get();
+    } else if (isolation == ProgramIsolation::PER_FIBER) {
+      throw std::logic_error(
+          "Cannot make concurrent invocations of a PER_FIBER program from "
+          "the same Fiber. This typically means that two invocations were "
+          "attempted on the same program on the same fiber without an "
+          "await. Consider adding appropriate sequencing or switching to "
+          "either PER_CALL or NONE isolation if appropriate for the use "
+          "case. This exception can also occur if the first invocation to this "
+          "Program failed, leaving no initialized Program for this fiber.");
+    }
+  }
+
+  // Slow-path: fork needed (and possibly new isolate registration needed).
+  iree::vm_context_ptr new_context;
+  SHORTFIN_THROW_IF_ERROR(iree_vm_context_fork(
+      root_context.get(), fiber.host_allocator(), new_context.for_output()));
+  return std::make_pair(std::move(new_context), isolate);
+}
+
+void detail::ProgramIsolate::ReleaseIsolate(Fiber &fiber,
+                                            iree::vm_context_ptr context,
+                                            detail::ProgramIsolate *isolate) {
+  assert(isolate && "attempt to release null isolate");
+  {
+    iree::slim_mutex_lock_guard lock(fiber.program_isolate_mu_);
+    isolate->fork_contexts.push_back(std::move(context));
+  }
+}
+
 }  // namespace shortfin::local
diff --git a/shortfin/src/shortfin/local/program.h b/shortfin/src/shortfin/local/program.h
index bc5ae05dc..ea4f0cc3f 100644
--- a/shortfin/src/shortfin/local/program.h
+++ b/shortfin/src/shortfin/local/program.h
@@ -26,6 +26,10 @@ class BaseProgramParameters;
 class Fiber;
 class System;
 
+namespace detail {
+struct ProgramIsolate;
+}  // namespace detail
+
 enum class ProgramInvocationModel {
   // Uses the coarse-fences invocation model. In this model, the last two
   // arguments are a wait and signal fence, which are used for function-level
@@ -37,6 +41,24 @@ enum class ProgramInvocationModel {
   UNKNOWN,
 };
 
+// The level of isolation that a program has with respect to concurrent use.
+enum class ProgramIsolation {
+  // There is no isolation: Callers are completely on their own to only issue
+  // concurrent invocations if supported.
+  NONE = 0,
+
+  // Each fiber in the system that makes calls into the program will have its
+  // own shallow fork of the module. This is done on-demand and the root
+  // program is retained for the lifetime of any referencing fibers.
+ // Concurrent calls on the same fiber are considered programming errors and + // will be flagged as such at an appropriate debug level. + PER_FIBER = 1, + + // Each call triggers a shallow fork of the module. This is the most expensive + // but safest way to ensure complete isolation of stateless invocations. + PER_CALL = 2, +}; + // State related to making an invocation of a function on a program. // // Since ownership of this object is transferred to the loop/callback and @@ -67,7 +89,8 @@ class SHORTFIN_API ProgramInvocation { static Ptr New(std::shared_ptr fiber, iree::vm_context_ptr vm_context, iree_vm_function_t &vm_function, - ProgramInvocationModel invocation_model); + ProgramInvocationModel invocation_model, + detail::ProgramIsolate *isolate); ProgramInvocation(const ProgramInvocation &) = delete; ProgramInvocation &operator=(const ProgramInvocation &) = delete; ProgramInvocation &operator=(ProgramInvocation &&) = delete; @@ -133,6 +156,11 @@ class SHORTFIN_API ProgramInvocation { private: ProgramInvocation(); void CheckNotScheduled(); + // Eagerly releases context when it is known that no further use of it can + // be made (allowing it to be returned to a pool prior to the invocation + // actually being recycled). Object destruction also does this, but possibly + // extending the context lifetime. + void ReleaseContext(); // Returns a pointer to the trailing arg list. iree_vm_list_t *arg_list(); @@ -156,8 +184,6 @@ class SHORTFIN_API ProgramInvocation { // This must not contain entities that require destruction or cannot be // trivially copied. struct Params { - // Context is retained upon construction and released when scheduled. - iree_vm_context_t *context; iree_vm_function_t function; ProgramInvocationModel invocation_model; }; @@ -169,6 +195,8 @@ class SHORTFIN_API ProgramInvocation { } state; std::shared_ptr fiber_; + iree::vm_context_ptr vm_context_; + detail::ProgramIsolate *isolate_; iree_vm_list_t *result_list_ = nullptr; std::optional future_; iree::hal_fence_ptr wait_fence_; @@ -187,7 +215,7 @@ class SHORTFIN_API ProgramFunction { std::string_view calling_convention() const; ProgramInvocationModel invocation_model() const { return invocation_model_; } - ProgramInvocation::Ptr CreateInvocation(); + ProgramInvocation::Ptr CreateInvocation(std::shared_ptr fiber); std::string to_s() const; @@ -195,17 +223,16 @@ class SHORTFIN_API ProgramFunction { operator iree_vm_function_t &() { return vm_function_; } private: - ProgramFunction(std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t vm_function, + ProgramFunction(iree::vm_context_ptr vm_context, + iree_vm_function_t vm_function, ProgramIsolation isolation, std::optional invocation_model = {}); static ProgramInvocationModel GetInvocationModelFromFunction( iree_vm_function_t &f); - // The context that this function was resolved against. - std::shared_ptr fiber_; iree::vm_context_ptr vm_context_; iree_vm_function_t vm_function_; + ProgramIsolation isolation_; ProgramInvocationModel invocation_model_; friend class Program; }; @@ -231,6 +258,7 @@ class SHORTFIN_API ProgramModule { std::string to_s() const; iree_vm_module_t *vm_module() const { return vm_module_; } std::string_view name() const; + System &system() const { return *system_; } // Loads a dynamic bytecode module (VMFB) from a path on the file system. 
static ProgramModule Load(System &system, const std::filesystem::path &path, @@ -246,10 +274,12 @@ class SHORTFIN_API ProgramModule { std::vector exports() const; protected: - explicit ProgramModule(iree::vm_module_ptr vm_module) - : vm_module_(std::move(vm_module)) {} + explicit ProgramModule(std::shared_ptr system, + iree::vm_module_ptr vm_module) + : system_(std::move(system)), vm_module_(std::move(vm_module)) {} private: + std::shared_ptr system_; iree::vm_module_ptr vm_module_; }; @@ -269,15 +299,19 @@ class SHORTFIN_API Program { struct Options { Options() {} + // Ordered list of devices to bind this program to. + std::span devices; + + // The isolation level to apply to program invocation. + ProgramIsolation isolation = ProgramIsolation::PER_FIBER; + // Enables program-wide execution tracing (to stderr). bool trace_execution = false; }; - // Loads a program attached to a fiber with a list of user provided modules - // and options. - static Program Load(std::shared_ptr fiber, - std::span modules, - Options options = {}); + // Load a program from a list of modules and options. + static Program Load(std::span modules, + Options &&options); // Looks up a public function by fully qualified name (i.e. module.function). // Returns nothing if not found. @@ -290,12 +324,16 @@ class SHORTFIN_API Program { // Gets the name of all exported functions. std::vector exports() const; + // Eagerly does any per-fiber isolation preparation for the program at a + // convenient point (usually init time) to avoid first-invocation overhead. + void PrepareIsolate(Fiber &fiber); + private: - explicit Program(std::shared_ptr fiber, - iree::vm_context_ptr vm_context) - : fiber_(std::move(fiber)), vm_context_(std::move(vm_context)) {} - std::shared_ptr fiber_; + explicit Program(iree::vm_context_ptr vm_context, ProgramIsolation isolation) + : vm_context_(std::move(vm_context)), isolation_(isolation) {} + iree::vm_context_ptr vm_context_; + ProgramIsolation isolation_; friend class Fiber; }; @@ -354,6 +392,27 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters { iree::io_parameter_index_ptr index_; }; +namespace detail { +// See Fiber::program_isolates_. +struct ProgramIsolate { + ProgramIsolate(iree::vm_context_ptr parent_context) + : parent_context(std::move(parent_context)) {} + iree::vm_context_ptr parent_context; + std::vector fork_contexts; + + // Acquires an isolate for the given fiber. This will return a context which + // may be the original program context or may be a forked child that is + // available for use. It is only valid to call this when isolation != NONE. + static std::pair + AcquireIsolate(Fiber &fiber, iree::vm_context_ptr root_context, + ProgramIsolation isolation); + + // Releases an isolate obtained from a fiber in AcquireIsolate. 
+ static void ReleaseIsolate(Fiber &fiber, iree::vm_context_ptr context, + ProgramIsolate *isolate); +}; +}; // namespace detail + } // namespace shortfin::local #endif // SHORTFIN_LOCAL_PROGRAM_H diff --git a/shortfin/src/shortfin/support/iree_helpers.cc b/shortfin/src/shortfin/support/iree_helpers.cc index 417a9f443..17430bb71 100644 --- a/shortfin/src/shortfin/support/iree_helpers.cc +++ b/shortfin/src/shortfin/support/iree_helpers.cc @@ -86,13 +86,14 @@ error::error(std::string message, iree_status_t failing_status) message_(std::move(message)), failing_status_(failing_status) { message_.append(": "); + AppendStatusMessage(); } -error::error(iree_status_t failing_status) : failing_status_(failing_status) {} -void error::AppendStatus() const noexcept { - if (status_appended_) return; - status_appended_ = false; +error::error(iree_status_t failing_status) : failing_status_(failing_status) { + AppendStatusMessage(); +} +void error::AppendStatusMessage() { iree_allocator_t allocator = iree_allocator_system(); char *status_buffer = nullptr; iree_host_size_t length = 0; diff --git a/shortfin/src/shortfin/support/iree_helpers.h b/shortfin/src/shortfin/support/iree_helpers.h index 446f32f41..f8d3f1398 100644 --- a/shortfin/src/shortfin/support/iree_helpers.h +++ b/shortfin/src/shortfin/support/iree_helpers.h @@ -277,24 +277,21 @@ class SHORTFIN_API error : public std::exception { public: error(std::string message, iree_status_t failing_status); error(iree_status_t failing_status); - error(const error &) = delete; + error(const error &other) + : code_(other.code_), + message_(other.message_), + failing_status_(iree_status_clone(other.failing_status_)) {} error &operator=(const error &) = delete; ~error() { iree_status_ignore(failing_status_); } - const char *what() const noexcept override { - if (!status_appended_) { - AppendStatus(); - } - return message_.c_str(); - }; + const char *what() const noexcept override { return message_.c_str(); }; iree_status_code_t code() const { return code_; } private: - void AppendStatus() const noexcept; + void AppendStatusMessage(); iree_status_code_t code_; - mutable std::string message_; + std::string message_; mutable iree_status_t failing_status_; - mutable bool status_appended_ = false; }; #define SHORTFIN_IMPL_HANDLE_IF_API_ERROR(var, ...) 
\ diff --git a/shortfin/tests/invocation/conftest.py b/shortfin/tests/invocation/conftest.py index c366c7f82..148ae064d 100644 --- a/shortfin/tests/invocation/conftest.py +++ b/shortfin/tests/invocation/conftest.py @@ -22,15 +22,16 @@ def mobilenet_onnx_path(tmp_path_factory): import onnx except ModuleNotFoundError: raise pytest.skip("onnx python package not available") - print("Downloading mobilenet.onnx") parent_dir = tmp_path_factory.mktemp("mobilenet_onnx") orig_onnx_path = parent_dir / "mobilenet_orig.onnx" - urllib.request.urlretrieve( - "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", - orig_onnx_path, - ) upgraded_onnx_path = parent_dir / "mobilenet.onnx" - upgrade_onnx(orig_onnx_path, upgraded_onnx_path) + if not upgraded_onnx_path.exists(): + print("Downloading mobilenet.onnx") + urllib.request.urlretrieve( + "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", + orig_onnx_path, + ) + upgrade_onnx(orig_onnx_path, upgraded_onnx_path) return upgraded_onnx_path @@ -41,15 +42,18 @@ def mobilenet_compiled_cpu_path(mobilenet_onnx_path): import iree.compiler.tools.import_onnx.__main__ as import_onnx except ModuleNotFoundError: raise pytest.skip("iree.compiler packages not available") - print("Compiling mobilenet") mlir_path = mobilenet_onnx_path.parent / "mobilenet.mlir" vmfb_path = mobilenet_onnx_path.parent / "mobilenet_cpu.vmfb" - args = import_onnx.parse_arguments(["-o", str(mlir_path), str(mobilenet_onnx_path)]) - import_onnx.main(args) - tools.compile_file( - str(mlir_path), - output_file=str(vmfb_path), - target_backends=["llvm-cpu"], - input_type="onnx", - ) + if not vmfb_path.exists(): + print("Compiling mobilenet") + args = import_onnx.parse_arguments( + ["-o", str(mlir_path), str(mobilenet_onnx_path)] + ) + import_onnx.main(args) + tools.compile_file( + str(mlir_path), + output_file=str(vmfb_path), + target_backends=["llvm-cpu"], + input_type="onnx", + ) return vmfb_path diff --git a/shortfin/tests/invocation/mobilenet_program_test.py b/shortfin/tests/invocation/mobilenet_program_test.py index 4275fe9e2..84903fb8f 100644 --- a/shortfin/tests/invocation/mobilenet_program_test.py +++ b/shortfin/tests/invocation/mobilenet_program_test.py @@ -5,6 +5,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import array +import asyncio +import time import functools import pytest @@ -21,38 +23,89 @@ def lsys(): @pytest.fixture -def fiber(lsys): +def fiber0(lsys): return lsys.create_fiber() @pytest.fixture -def device(fiber): - return fiber.device(0) +def device(fiber0): + return fiber0.device(0) -def test_invoke_mobilenet(lsys, fiber, mobilenet_compiled_cpu_path): - device = fiber.device(0) +@pytest.fixture +def mobilenet_program_function( + lsys, mobilenet_compiled_cpu_path +) -> tuple[sf.ProgramFunction]: + program_module = lsys.load_module(mobilenet_compiled_cpu_path) + program = sf.Program([program_module], devices=lsys.devices) + main_function = program["module.torch-jit-export"] + return main_function + + +def get_mobilenet_ref_input(device) -> sfnp.device_array: dummy_data = array.array( "f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224)) ) - program_module = lsys.load_module(mobilenet_compiled_cpu_path) - program = sf.Program([program_module], fiber=fiber) - main_function = program["module.torch-jit-export"] + device_input = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32) + staging_input = device_input.for_transfer() 
+ with staging_input.map(discard=True) as m: + m.fill(dummy_data) + device_input.copy_from(staging_input) + return device_input + + +async def assert_mobilenet_ref_output(device, device_output): + host_output = device_output.for_transfer() + host_output.copy_from(device_output) + await device + flat_output = host_output.items + absmean = functools.reduce( + lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 + ) + print("RESULT:", absmean) + assert absmean == pytest.approx(5.01964943873882) + + +def test_invoke_mobilenet(lsys, fiber0, mobilenet_program_function): + device = fiber0.device(0) async def main(): - device_input = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32) - staging_input = device_input.for_transfer() - with staging_input.map(discard=True) as m: - m.fill(dummy_data) - device_input.copy_from(staging_input) - (device_output,) = await main_function(device_input) - host_output = device_output.for_transfer() - host_output.copy_from(device_output) - await device - flat_output = host_output.items - absmean = functools.reduce( - lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 - ) - assert absmean == pytest.approx(5.01964943873882) + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function(device_input, fiber=fiber0) + await assert_mobilenet_ref_output(device, device_output) + + lsys.run(main()) + + +def test_invoke_mobilenet_multi_fiber(lsys, mobilenet_program_function): + class InferProcess(sf.Process): + async def run(self): + start_time = time.time() + + def duration(): + return round((time.time() - start_time) * 1000.0) + + print(f"{self}: Start") + device = self.fiber.device(0) + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function( + device_input, fiber=self.fiber + ) + print(f"{self}: Program complete (+{duration()}ms)") + await assert_mobilenet_ref_output(device, device_output) + print(f"{self} End (+{duration()}ms)") + + async def main(): + start_time = time.time() + + def duration(): + return round((time.time() - start_time) * 1000.0) + + fibers = [lsys.create_fiber() for _ in range(5)] + print("Fibers:", fibers) + processes = [InferProcess(fiber=f).launch() for f in fibers] + print("Waiting for processes:", processes) + await asyncio.gather(*processes) + print(f"All processes complete: (+{duration()}ms)") lsys.run(main())
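The kv_cache_test.py additions above double as usage documentation for the reworked cache API. Below is a condensed sketch only, not part of the patch: sizes are arbitrary, and it assumes the sharktank package with this patch applied plus the same public names the tests import (DirectKVCache, reshard_split, replicate, unshard). The PagedKVCache path is analogous, with page_ids threaded through write/write_timestep/read as in test_sharded_paged.

```python
import torch

from sharktank.layers import DirectKVCache  # assumed exported, as in the tests' star import
from sharktank.ops import replicate, reshard_split, unshard

# Illustrative sizes only; attn_head_count must be divisible by shard_count.
bs, seq_len, heads, head_dim = 2, 16, 8, 16
blocks, shards = 3, 4

cache = DirectKVCache(
    block_seq_stride=4,
    transformer_block_count=blocks,
    attn_head_count=heads,
    attn_head_dim=head_dim,
    seq_length=seq_len,
    shard_count=shards,  # new: allocations come back as SplitPrimitiveTensors
    dtype=torch.float32,
    device=None,
)
state = cache.allocate(bs=bs)

# K/V partitions are split across shards on the head dimension (dim=2).
k = reshard_split(torch.ones(bs, 8, heads, head_dim), dim=2, count=shards)
v = reshard_split(torch.zeros(bs, 8, heads, head_dim), dim=2, count=shards)
cache.write(state, cache_partitions=[k, v], transformer_block_index=0)

# Decode step: write one timestep at ragged positions (positions replicated per shard).
pos = replicate(torch.full((bs,), 8, dtype=torch.int64), shards)
k1 = reshard_split(torch.ones(bs, 1, heads, head_dim), dim=2, count=shards)
v1 = reshard_split(torch.zeros(bs, 1, heads, head_dim), dim=2, count=shards)
cache.write_timestep(
    state,
    cache_partitions=[k1, v1],
    transformer_block_index=0,
    seq_positions=pos,
)

# read() now takes seq_len and returns the linearized k/v for the block.
k_read, v_read = cache.read(
    state,
    read_into_partitions=[
        torch.empty(bs, 9, heads, head_dim),
        torch.empty(bs, 9, heads, head_dim),
    ],
    transformer_block_index=0,
    seq_len=9,  # 8 prefill rows + 1 decode row
)
print(unshard(k_read).shape)  # torch.Size([2, 9, 8, 16])
```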
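On the shortfin side, the binding changes alter the invocation convention: sf.Program is constructed against a list of devices instead of a fiber, and the fiber is supplied per call, which is what lets one Program serve many fibers through the per-fiber context forks implemented in program.cc. The sketch below is illustrative only: run_once and the vmfb path are hypothetical, while every API call mirrors mobilenet_program_test.py above.

```python
import shortfin as sf
import shortfin.array as sfnp


async def run_once(lsys, fiber, vmfb_path):
    # Programs now bind devices at load time; there is no fiber= argument here.
    module = lsys.load_module(vmfb_path)
    program = sf.Program([module], devices=lsys.devices)
    fn = program["module.torch-jit-export"]  # export name from the mobilenet test

    # Inputs are staged onto the fiber's device as before (left unfilled here;
    # see get_mobilenet_ref_input above for staging real data).
    device = fiber.device(0)
    arg = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32)

    # The fiber is passed per invocation, so the same Program can be shared
    # across fibers; under PER_FIBER isolation each fiber gets its own forked
    # VM context.
    (result,) = await fn(arg, fiber=fiber)
    return result
```

Driving it looks like `lsys.run(run_once(lsys, lsys.create_fiber(), path))`; test_invoke_mobilenet_multi_fiber above exercises the same function concurrently from several fibers.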