Skip to content

Commit

Permalink
Merge branch 'main' into requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
marbre authored Oct 28, 2024
2 parents 39aff93 + 98392d0 commit e6ab34a
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 57 deletions.
49 changes: 29 additions & 20 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,35 +116,38 @@ def setup_cache(model, shard_count):
page_count=hp.context_length // llama_config.block_seq_stride
)
page_dim = torch.export.Dim("page")

dynamic_shapes = [{0: page_dim}]
unpacked = cache_state
arg_affinities = {}
shard_dim = None

# Need to unpacke that state when sharded
if llama_config.tensor_parallelism_size > 1:
shard_dim = cache_state[0].shard_dim

unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
dynamic_shapes = [
[ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
]

for i in range(llama_config.tensor_parallelism_size):
arg_affinities[i] = DeviceAffinity(str(i))

return unpacked, shard_dim, dynamic_shapes, arg_affinities

elif model.config.kv_cache_type == "direct":
cache_state = model.cache.allocate(bs=1)
# Direct cache dimensions:
# 2 * transformer_block_count of...
# [bs, seq_length, attn_head_count, attn_head_dim]
dynamic_shapes = [None]
arg_affinities = {}
shard_dim = None
return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
else:
raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")

unpacked = cache_state
dynamic_shapes = dynamic_shapes
arg_affinities = {}
shard_dim = None

# Need to unpacke that state when sharded
if llama_config.tensor_parallelism_size > 1:
shard_dim = cache_state[0].shard_dim

unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
dynamic_shapes = [
[ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
]

for i in range(llama_config.tensor_parallelism_size):
arg_affinities[i] = DeviceAffinity(str(i))

return torch.stack(unpacked), shard_dim, dynamic_shapes, arg_affinities

def repack_cache(cache, shard_dim):
return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]

Expand Down Expand Up @@ -184,7 +187,13 @@ def generate_batch_prefill(bs: int):
arg_device=arg_affinities,
)
def _(model, tokens, seq_lens, seq_block_ids, cs):
cache_tensors = torch.unbind(cs)
if (
model.config.tensor_parallelism_size == 1
and model.config.kv_cache_type == "direct"
):
cache_tensors = torch.unbind(cs)
else:
cache_tensors = cs

sl = tokens.shape[1]
input_mask = model.input_mask(seq_lens, sl)
Expand Down
14 changes: 7 additions & 7 deletions sharktank/sharktank/layers/causal_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,9 @@ def input_mask(

def decode_attention_mask(self, boolean_input_mask: torch.Tensor):
dtype = self.attention_dtype
numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype)
numeric_mask.masked_fill_(
boolean_input_mask, self._maximally_negative_value(dtype)
)
numeric_mask = torch.where(
boolean_input_mask, self._maximally_negative_value(dtype), 0
).to(dtype)
return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)

def attention_mask(
Expand Down Expand Up @@ -127,9 +126,10 @@ def attention_mask(
dtype = self.attention_dtype
_, batch_seq_len = input_mask.shape
causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len]
boolean_mask = causal_mask + input_mask[:, None, None, :]
numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype)
numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype))
boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :])
numeric_mask = torch.where(
boolean_mask, self._maximally_negative_value(dtype), 0
).to(dtype)
return numeric_mask.to(self.device)

def extract_tokens_from_logits(
Expand Down
6 changes: 3 additions & 3 deletions shortfin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

cmake_minimum_required(VERSION 3.29)

Expand Down
6 changes: 3 additions & 3 deletions shortfin/build_tools/cmake/shortfin_library.cmake
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set(SHORTFIN_DEFAULT_COPTS
# General clang and GCC options application to C and C++.
Expand Down
6 changes: 3 additions & 3 deletions shortfin/dev_me.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# dev_me.py
#
Expand Down
6 changes: 3 additions & 3 deletions shortfin/python/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# shortfin publishes multiple python packages: - _shortfin: Trampoline
# __init__.py which looks at environment variables to load an appropriate native
Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(shortfin)

Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(array)
add_subdirectory(local)
Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/array/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

shortfin_cc_component(
NAME
Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/local/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(systems)

Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/local/systems/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set(_SYSTEM_COMPONENTS)

Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/support/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

shortfin_cc_component(
NAME
Expand Down

0 comments on commit e6ab34a

Please sign in to comment.