Skip to content

Commit

Permalink
Merge branch 'main' into perplexity-vmfb
Browse files Browse the repository at this point in the history
  • Loading branch information
archana-ramalingam authored Oct 28, 2024
2 parents b220688 + f925a5b commit 2a79eda
Show file tree
Hide file tree
Showing 19 changed files with 102 additions and 93 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci-sharktank.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ on:
paths:
- '.github/workflows/ci-sharktank.yml'
- 'sharktank/**'
- '*requirements.txt'
- '*requirements*.txt'
push:
branches:
- main
paths:
- '.github/workflows/ci-sharktank.yml'
- 'sharktank/**'
- '*requirements.txt'
- '*requirements*.txt'

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
Expand Down Expand Up @@ -52,7 +52,7 @@ jobs:
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}

- name: Install pip deps
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}

- name: Install sharktank deps
run: |
Expand Down
10 changes: 10 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Used for managing pre-commit flows.
pre-commit

# Type checking
mypy==1.8.0
types-requests==2.31.0.20240125

# Testing
pytest==8.0.0
pytest-xdist==3.5.0
34 changes: 4 additions & 30 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,30 +1,4 @@
# Runtime deps.
gguf==0.6.0
numpy==1.26.3
onnx==1.15.0

# Model deps.
huggingface-hub==0.22.2
transformers==4.40.0
sentencepiece==0.2.0

# It is expected that you have installed a PyTorch version/variant specific
# to your needs, so we only include a minimum version spec.
# TODO: Use a versioned release once 2.3.0 drops.
torch>=2.3.0.dev1

# Used for managing pre-commit flows.
pre-commit

# Type checking
mypy==1.8.0
types-requests==2.31.0.20240125

# Testing
parameterized
pytest==8.0.0
pytest-xdist==3.5.0

# Serving deps.
fastapi==0.112.2
uvicorn==0.30.6
-r sharktank/requirements.txt
-r sharktank/requirements-tests.txt
-r shortfin/requirements-tests.txt
-r requirements-dev.txt
2 changes: 2 additions & 0 deletions sharktank/requirements-tests.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
datasets==3.0.0
parameterized
pytest==8.0.0
17 changes: 16 additions & 1 deletion sharktank/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,16 @@
gguf
# Runtime deps.
gguf==0.6.0
numpy==1.26.3

# Model deps.
huggingface-hub==0.22.2
transformers==4.40.0
datasets

# It is expected that you have installed a PyTorch version/variant specific
# to your needs, so we only include a minimum version spec.
torch>=2.3.0

# Serving deps.
fastapi==0.112.2
uvicorn==0.30.6
1 change: 0 additions & 1 deletion sharktank/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def initialize_options(self):
extras_require={
"testing": [
f"pytest{get_version_spec('pytest')}",
f"pytest-xdist{get_version_spec('pytest-xdist')}",
],
},
cmdclass={"build": BuildCommand},
Expand Down
49 changes: 29 additions & 20 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,35 +116,38 @@ def setup_cache(model, shard_count):
page_count=hp.context_length // llama_config.block_seq_stride
)
page_dim = torch.export.Dim("page")

dynamic_shapes = [{0: page_dim}]
unpacked = cache_state
arg_affinities = {}
shard_dim = None

# Need to unpacke that state when sharded
if llama_config.tensor_parallelism_size > 1:
shard_dim = cache_state[0].shard_dim

unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
dynamic_shapes = [
[ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
]

for i in range(llama_config.tensor_parallelism_size):
arg_affinities[i] = DeviceAffinity(str(i))

return unpacked, shard_dim, dynamic_shapes, arg_affinities

elif model.config.kv_cache_type == "direct":
cache_state = model.cache.allocate(bs=1)
# Direct cache dimensions:
# 2 * transformer_block_count of...
# [bs, seq_length, attn_head_count, attn_head_dim]
dynamic_shapes = [None]
arg_affinities = {}
shard_dim = None
return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
else:
raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")

unpacked = cache_state
dynamic_shapes = dynamic_shapes
arg_affinities = {}
shard_dim = None

# Need to unpacke that state when sharded
if llama_config.tensor_parallelism_size > 1:
shard_dim = cache_state[0].shard_dim

unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
dynamic_shapes = [
[ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
]

for i in range(llama_config.tensor_parallelism_size):
arg_affinities[i] = DeviceAffinity(str(i))

return torch.stack(unpacked), shard_dim, dynamic_shapes, arg_affinities

def repack_cache(cache, shard_dim):
return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]

Expand Down Expand Up @@ -184,7 +187,13 @@ def generate_batch_prefill(bs: int):
arg_device=arg_affinities,
)
def _(model, tokens, seq_lens, seq_block_ids, cs):
cache_tensors = torch.unbind(cs)
if (
model.config.tensor_parallelism_size == 1
and model.config.kv_cache_type == "direct"
):
cache_tensors = torch.unbind(cs)
else:
cache_tensors = cs

sl = tokens.shape[1]
input_mask = model.input_mask(seq_lens, sl)
Expand Down
14 changes: 7 additions & 7 deletions sharktank/sharktank/layers/causal_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,9 @@ def input_mask(

def decode_attention_mask(self, boolean_input_mask: torch.Tensor):
dtype = self.attention_dtype
numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype)
numeric_mask.masked_fill_(
boolean_input_mask, self._maximally_negative_value(dtype)
)
numeric_mask = torch.where(
boolean_input_mask, self._maximally_negative_value(dtype), 0
).to(dtype)
return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)

def attention_mask(
Expand Down Expand Up @@ -127,9 +126,10 @@ def attention_mask(
dtype = self.attention_dtype
_, batch_seq_len = input_mask.shape
causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len]
boolean_mask = causal_mask + input_mask[:, None, None, :]
numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype)
numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype))
boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :])
numeric_mask = torch.where(
boolean_mask, self._maximally_negative_value(dtype), 0
).to(dtype)
return numeric_mask.to(self.device)

def extract_tokens_from_logits(
Expand Down
6 changes: 3 additions & 3 deletions shortfin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

cmake_minimum_required(VERSION 3.29)

Expand Down
6 changes: 3 additions & 3 deletions shortfin/build_tools/cmake/shortfin_library.cmake
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set(SHORTFIN_DEFAULT_COPTS
# General clang and GCC options application to C and C++.
Expand Down
6 changes: 3 additions & 3 deletions shortfin/dev_me.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# dev_me.py
#
Expand Down
6 changes: 3 additions & 3 deletions shortfin/python/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# shortfin publishes multiple python packages: - _shortfin: Trampoline
# __init__.py which looks at environment variables to load an appropriate native
Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(shortfin)

Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(array)
add_subdirectory(local)
Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/array/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

shortfin_cc_component(
NAME
Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/local/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(systems)

Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/local/systems/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set(_SYSTEM_COMPONENTS)

Expand Down
6 changes: 3 additions & 3 deletions shortfin/src/shortfin/support/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions. See
# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
# Apache-2.0 WITH LLVM-exception
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

shortfin_cc_component(
NAME
Expand Down

0 comments on commit 2a79eda

Please sign in to comment.