diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index ead663f53..73243086a 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -6,14 +6,14 @@ on: paths: - '.github/workflows/ci-sharktank.yml' - 'sharktank/**' - - '*requirements.txt' + - '*requirements*.txt' push: branches: - main paths: - '.github/workflows/ci-sharktank.yml' - 'sharktank/**' - - '*requirements.txt' + - '*requirements*.txt' concurrency: # A PR number if a pull request and otherwise the commit hash. This cancels @@ -52,7 +52,7 @@ jobs: id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} - name: Install pip deps run: | diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 37a24eda2..1b1d153aa 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -46,7 +46,7 @@ jobs: id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} - name: Install sharktank deps run: | diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..e736fe3bd --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,10 @@ +# Used for managing pre-commit flows. +pre-commit + +# Type checking +mypy==1.8.0 +types-requests==2.31.0.20240125 + +# Testing +pytest==8.0.0 +pytest-xdist==3.5.0 diff --git a/requirements.txt b/requirements.txt index 0198314f8..cc2edf876 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,30 +1,4 @@ -# Runtime deps. -gguf==0.6.0 -numpy==1.26.3 -onnx==1.15.0 - -# Model deps. -huggingface-hub==0.22.2 -transformers==4.40.0 -sentencepiece==0.2.0 - -# It is expected that you have installed a PyTorch version/variant specific -# to your needs, so we only include a minimum version spec. -# TODO: Use a versioned release once 2.3.0 drops. -torch>=2.3.0.dev1 - -# Used for managing pre-commit flows. -pre-commit - -# Type checking -mypy==1.8.0 -types-requests==2.31.0.20240125 - -# Testing -parameterized -pytest==8.0.0 -pytest-xdist==3.5.0 - -# Serving deps. -fastapi==0.112.2 -uvicorn==0.30.6 +-r sharktank/requirements.txt +-r sharktank/requirements-tests.txt +-r shortfin/requirements-tests.txt +-r requirements-dev.txt diff --git a/sharktank/requirements-tests.txt b/sharktank/requirements-tests.txt index d7266a5e8..4be48fdde 100644 --- a/sharktank/requirements-tests.txt +++ b/sharktank/requirements-tests.txt @@ -1 +1,3 @@ datasets==3.0.0 +parameterized +pytest==8.0.0 diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 6b21f239f..ad231d524 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -1 +1,16 @@ -gguf +# Runtime deps. +gguf==0.6.0 +numpy==1.26.3 + +# Model deps. +huggingface-hub==0.22.2 +transformers==4.40.0 +datasets + +# It is expected that you have installed a PyTorch version/variant specific +# to your needs, so we only include a minimum version spec. +torch>=2.3.0 + +# Serving deps. +fastapi==0.112.2 +uvicorn==0.30.6 diff --git a/sharktank/setup.py b/sharktank/setup.py index ab6e92d33..8ffcf3984 100644 --- a/sharktank/setup.py +++ b/sharktank/setup.py @@ -99,7 +99,6 @@ def initialize_options(self): extras_require={ "testing": [ f"pytest{get_version_spec('pytest')}", - f"pytest-xdist{get_version_spec('pytest-xdist')}", ], }, cmdclass={"build": BuildCommand}, diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index ce6f0864a..0436c0008 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -116,35 +116,38 @@ def setup_cache(model, shard_count): page_count=hp.context_length // llama_config.block_seq_stride ) page_dim = torch.export.Dim("page") + dynamic_shapes = [{0: page_dim}] + unpacked = cache_state + arg_affinities = {} + shard_dim = None + + # Need to unpacke that state when sharded + if llama_config.tensor_parallelism_size > 1: + shard_dim = cache_state[0].shard_dim + + unpacked = [[shard._data for shard in cs.shards] for cs in cache_state] + dynamic_shapes = [ + [ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes + ] + + for i in range(llama_config.tensor_parallelism_size): + arg_affinities[i] = DeviceAffinity(str(i)) + + return unpacked, shard_dim, dynamic_shapes, arg_affinities + elif model.config.kv_cache_type == "direct": cache_state = model.cache.allocate(bs=1) # Direct cache dimensions: # 2 * transformer_block_count of... # [bs, seq_length, attn_head_count, attn_head_dim] dynamic_shapes = [None] + arg_affinities = {} + shard_dim = None + return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities else: raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") - unpacked = cache_state - dynamic_shapes = dynamic_shapes - arg_affinities = {} - shard_dim = None - - # Need to unpacke that state when sharded - if llama_config.tensor_parallelism_size > 1: - shard_dim = cache_state[0].shard_dim - - unpacked = [[shard._data for shard in cs.shards] for cs in cache_state] - dynamic_shapes = [ - [ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes - ] - - for i in range(llama_config.tensor_parallelism_size): - arg_affinities[i] = DeviceAffinity(str(i)) - - return torch.stack(unpacked), shard_dim, dynamic_shapes, arg_affinities - def repack_cache(cache, shard_dim): return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache] @@ -184,7 +187,13 @@ def generate_batch_prefill(bs: int): arg_device=arg_affinities, ) def _(model, tokens, seq_lens, seq_block_ids, cs): - cache_tensors = torch.unbind(cs) + if ( + model.config.tensor_parallelism_size == 1 + and model.config.kv_cache_type == "direct" + ): + cache_tensors = torch.unbind(cs) + else: + cache_tensors = cs sl = tokens.shape[1] input_mask = model.input_mask(seq_lens, sl) diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 63c58e860..7a09995a8 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -95,10 +95,9 @@ def input_mask( def decode_attention_mask(self, boolean_input_mask: torch.Tensor): dtype = self.attention_dtype - numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype) - numeric_mask.masked_fill_( - boolean_input_mask, self._maximally_negative_value(dtype) - ) + numeric_mask = torch.where( + boolean_input_mask, self._maximally_negative_value(dtype), 0 + ).to(dtype) return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device) def attention_mask( @@ -127,9 +126,10 @@ def attention_mask( dtype = self.attention_dtype _, batch_seq_len = input_mask.shape causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len] - boolean_mask = causal_mask + input_mask[:, None, None, :] - numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype) - numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype)) + boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :]) + numeric_mask = torch.where( + boolean_mask, self._maximally_negative_value(dtype), 0 + ).to(dtype) return numeric_mask.to(self.device) def extract_tokens_from_logits( diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt index 4e2648a87..b3c2ee24f 100644 --- a/shortfin/CMakeLists.txt +++ b/shortfin/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception cmake_minimum_required(VERSION 3.29) diff --git a/shortfin/build_tools/cmake/shortfin_library.cmake b/shortfin/build_tools/cmake/shortfin_library.cmake index 26a31101b..872e24838 100644 --- a/shortfin/build_tools/cmake/shortfin_library.cmake +++ b/shortfin/build_tools/cmake/shortfin_library.cmake @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception set(SHORTFIN_DEFAULT_COPTS # General clang and GCC options application to C and C++. diff --git a/shortfin/dev_me.py b/shortfin/dev_me.py index be02d67fa..ca6916767 100755 --- a/shortfin/dev_me.py +++ b/shortfin/dev_me.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # dev_me.py # diff --git a/shortfin/python/CMakeLists.txt b/shortfin/python/CMakeLists.txt index adf9d7879..d125416af 100644 --- a/shortfin/python/CMakeLists.txt +++ b/shortfin/python/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # shortfin publishes multiple python packages: - _shortfin: Trampoline # __init__.py which looks at environment variables to load an appropriate native diff --git a/shortfin/src/CMakeLists.txt b/shortfin/src/CMakeLists.txt index 5e7c1d8e5..e27318764 100644 --- a/shortfin/src/CMakeLists.txt +++ b/shortfin/src/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(shortfin) diff --git a/shortfin/src/shortfin/CMakeLists.txt b/shortfin/src/shortfin/CMakeLists.txt index 1bea0003b..058e0e336 100644 --- a/shortfin/src/shortfin/CMakeLists.txt +++ b/shortfin/src/shortfin/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(array) add_subdirectory(local) diff --git a/shortfin/src/shortfin/array/CMakeLists.txt b/shortfin/src/shortfin/array/CMakeLists.txt index d40eed23f..48ab33590 100644 --- a/shortfin/src/shortfin/array/CMakeLists.txt +++ b/shortfin/src/shortfin/array/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception shortfin_cc_component( NAME diff --git a/shortfin/src/shortfin/local/CMakeLists.txt b/shortfin/src/shortfin/local/CMakeLists.txt index 9f51c78bb..250bd79a2 100644 --- a/shortfin/src/shortfin/local/CMakeLists.txt +++ b/shortfin/src/shortfin/local/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(systems) diff --git a/shortfin/src/shortfin/local/systems/CMakeLists.txt b/shortfin/src/shortfin/local/systems/CMakeLists.txt index 3ec5f17a6..b2bcbef23 100644 --- a/shortfin/src/shortfin/local/systems/CMakeLists.txt +++ b/shortfin/src/shortfin/local/systems/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception set(_SYSTEM_COMPONENTS) diff --git a/shortfin/src/shortfin/support/CMakeLists.txt b/shortfin/src/shortfin/support/CMakeLists.txt index 9cb0d2b45..cbf171894 100644 --- a/shortfin/src/shortfin/support/CMakeLists.txt +++ b/shortfin/src/shortfin/support/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception shortfin_cc_component( NAME