Merge branch 'main' into update-perplexity-ci-install
archana-ramalingam authored Dec 17, 2024
2 parents a184987 + aab7161 commit c89c9d1
Showing 44 changed files with 1,628 additions and 541 deletions.
44 changes: 31 additions & 13 deletions .github/workflows/ci-libshortfin.yml
@@ -9,15 +9,9 @@ name: CI - shortfin
on:
workflow_dispatch:
pull_request:
paths:
- '.github/workflows/ci-libshortfin.yml'
- 'shortfin/**'
push:
branches:
- main
paths:
- '.github/workflows/ci-libshortfin.yml'
- 'shortfin/**'

permissions:
contents: read
@@ -44,7 +38,7 @@ jobs:
strategy:
fail-fast: false
matrix:
name: ["Ubuntu (Clang)(full)", "Ubuntu (Clang)(host-only)", "Ubuntu (GCC)", "Windows (MSVC)"]
name: ["Ubuntu (Clang)(full)", "Ubuntu (Clang)(host-only)", "Windows (MSVC)"]
python-version: ["3.10", "3.11", "3.12"]
include:
- name: Ubuntu (Clang)(full)
@@ -59,16 +53,21 @@ jobs:
cmake-options:
-DCMAKE_C_COMPILER=clang-18 -DCMAKE_CXX_COMPILER=clang++-18 -DCMAKE_LINKER_TYPE=LLD -DSHORTFIN_HAVE_AMDGPU=OFF -DSHORTFIN_BUILD_STATIC=ON -DSHORTFIN_BUILD_DYNAMIC=ON
additional-packages: clang lld
- name: Ubuntu (GCC)
- name: Ubuntu (GCC 13)
runs-on: ubuntu-24.04
# Only test with GCC 13 and Python 3.12
python-version: "3.12"
cmake-options:
-DCMAKE_C_COMPILER=gcc-13 -DCMAKE_CXX_COMPILER=g++-13
- name: Ubuntu (GCC 14)
runs-on: ubuntu-24.04
# Only test with GCC 14 and Python 3.12
python-version: "3.12"
cmake-options:
-DCMAKE_C_COMPILER=gcc-14 -DCMAKE_CXX_COMPILER=g++-14
- name: Windows (MSVC)
runs-on: windows-2022
exclude:
# Only test Python 3.12 with GCC
- name: Ubuntu (GCC)
python-version: "3.10"
- name: Ubuntu (GCC)
python-version: "3.11"
# TODO: Include additional Python versions for Windows after the build is fixed
- name: Windows (MSVC)
python-version: "3.10"
@@ -152,3 +151,22 @@ jobs:
run: |
ctest --timeout 30 --output-on-failure --test-dir build
pytest -s --durations=10
# Depends on all other jobs to provide an aggregate job status.
ci_libshortfin_summary:
if: always()
runs-on: ubuntu-24.04
needs:
- build-and-test
steps:
- name: Getting failed jobs
run: |
echo '${{ toJson(needs) }}'
FAILED_JOBS="$(echo '${{ toJson(needs) }}' \
| jq --raw-output \
'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \
)"
if [[ "${FAILED_JOBS}" != "" ]]; then
echo "The following jobs failed: ${FAILED_JOBS}"
exit 1
fi
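For readers unfamiliar with the pattern: `toJson(needs)` serializes the `needs` context to a JSON object keyed by job id, each entry carrying a `result` field, and the `jq` filter collects every job whose result is neither success nor skipped. A minimal Python sketch of the same selection logic, on invented sample data:

```python
# Python equivalent of the jq filter above; `needs` mirrors the JSON shape
# produced by toJson(needs) (sample data, not real CI output).
needs = {
    "build-and-test": {"result": "failure"},
}
failed_jobs = ",".join(
    job for job, info in needs.items()
    if info["result"] not in ("success", "skipped")
)
print(failed_jobs)  # -> build-and-test
```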
30 changes: 27 additions & 3 deletions .github/workflows/ci-sharktank.yml
@@ -130,15 +130,19 @@ jobs:
pip freeze
- name: Run tests
# TODO: unify with-t5-data and with-clip-data flags into a single flag
# and make it possible to run only tests that require data.
# TODO: unify with-*-data flags into a single flag and make it possible to run
# only tests that require data.
# We would still want the separate flags as we may end up with data being
# scattered across different CI machines.
run: |
source ${VENV_DIR}/bin/activate
pytest \
--with-clip-data \
--with-flux-data \
--with-t5-data \
sharktank/tests/models/clip/clip_test.py \
sharktank/tests/models/t5/t5_test.py \
sharktank/tests/models/flux/flux_test.py \
--durations=0
@@ -182,3 +186,23 @@ jobs:
run: |
pytest -v sharktank/ -m punet_quick \
--durations=0
# Depends on other jobs to provide an aggregate job status.
# TODO(#584): move test_with_data and test_integration to a pkgci integration test workflow?
ci_sharktank_summary:
if: always()
runs-on: ubuntu-24.04
needs:
- test
steps:
- name: Getting failed jobs
run: |
echo '${{ toJson(needs) }}'
FAILED_JOBS="$(echo '${{ toJson(needs) }}' \
| jq --raw-output \
'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \
)"
if [[ "${FAILED_JOBS}" != "" ]]; then
echo "The following jobs failed: ${FAILED_JOBS}"
exit 1
fi
24 changes: 11 additions & 13 deletions build_tools/python_deploy/write_requirements.py
@@ -53,8 +53,8 @@ def write_requirements(requirements):
metapackage_version = load_version_info(VERSION_FILE_LOCAL)
PACKAGE_VERSION = metapackage_version.get("package-version")

# sharktank_version = load_version_info(VERSION_FILE_SHARKTANK)
# SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version")
sharktank_version = load_version_info(VERSION_FILE_SHARKTANK)
SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version")

shortfin_version = load_version_info(VERSION_FILE_SHORTFIN)
SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version")
@@ -65,13 +65,12 @@ def write_requirements(requirements):
requirements = ""
for package in stable_packages_list:
requirements += package + "\n"
# TODO: Include sharktank as a dependency of future releases
# requirements = (
# "sharktank=="
# + Version(SHARKTANK_PACKAGE_VERSION).base_version
# + args.version_suffix
# + "\n"
# )
requirements = (
"sharktank=="
+ Version(SHARKTANK_PACKAGE_VERSION).base_version
+ args.version_suffix
+ "\n"
)
requirements += (
"shortfin=="
+ Version(SHORTFIN_PACKAGE_VERSION).base_version
@@ -89,10 +88,9 @@ def write_requirements(requirements):
requirements = ""
for package in stable_packages_list:
requirements += package + "==" + STABLE_VERSION_TO_PIN + "\n"
# TODO: Include sharktank as a dependency of future releases
# requirements += (
# "sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n"
# )
requirements += (
"sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n"
)
requirements += "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version

write_requirements(requirements)
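The pinning logic above relies on `Version.base_version` (presumably from the `packaging` library), which strips pre-release and dev segments so that an explicit suffix can be re-attached. A small sketch with hypothetical version strings:

```python
from packaging.version import Version

# base_version drops pre-release/dev segments: "3.1.0rc20241217" -> "3.1.0"
print(Version("3.1.0rc20241217").base_version)  # 3.1.0

# A nightly requirement line is then rebuilt with an explicit suffix:
print("sharktank==" + Version("3.1.0rc20241217").base_version + "rc20241218")
# -> sharktank==3.1.0rc20241218
```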
13 changes: 10 additions & 3 deletions docs/user_guide.md
@@ -34,13 +34,20 @@ Setup your Python environment with the following commands:
# Set up a virtual environment to isolate packages from other envs.
python3.11 -m venv 3.11.venv
source 3.11.venv/bin/activate
```

## Install SHARK and its dependencies

First install a torch version that fulfills your needs:

# Optional: faster installation of torch with just CPU support.
# See other options at https://pytorch.org/get-started/locally/
```bash
# Fast installation of torch with just CPU support.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

## Install SHARK and its dependencies
For other options, see https://pytorch.org/get-started/locally/.

Next install shark-ai:

```bash
pip install shark-ai[apps]
9 changes: 9 additions & 0 deletions sharktank/conftest.py
@@ -97,6 +97,15 @@ def pytest_addoption(parser):
"code. The user is expected to provide the data"
),
)
parser.addoption(
"--with-flux-data",
action="store_true",
default=False,
help=(
"Enable tests that use Flux data like models that is not a part of the source "
"code. The user is expected to provide the data"
),
)
parser.addoption(
"--with-t5-data",
action="store_true",
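Flags registered with `pytest_addoption` are typically consumed via `config.getoption`. A hypothetical sketch of how data-gated tests can be skipped unless the new flag is passed (the marker name and hook placement are assumptions, not sharktank's actual wiring):

```python
import pytest

def pytest_collection_modifyitems(config, items):
    if config.getoption("--with-flux-data"):
        return  # flag given: leave data-dependent tests enabled
    skip_flux = pytest.mark.skip(reason="requires --with-flux-data")
    for item in items:
        if "flux_data" in item.keywords:  # hypothetical marker name
            item.add_marker(skip_flux)
```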
2 changes: 1 addition & 1 deletion sharktank/integration/models/punet/integration_test.py
@@ -67,7 +67,7 @@ def download(filename):

@pytest.fixture(scope="module")
def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
from sharktank.models.punet.tools import import_hf_dataset
from sharktank.tools import import_hf_dataset

dataset = temp_dir / "sdxl_fp16_dataset.irpa"
import_hf_dataset.main(
4 changes: 0 additions & 4 deletions sharktank/requirements.txt
@@ -9,10 +9,6 @@ huggingface-hub==0.22.2
transformers==4.40.0
datasets

# It is expected that you have installed a PyTorch version/variant specific
# to your needs, so we only include a minimum version spec.
torch>=2.3.0

# Serving deps.
fastapi>=0.112.2
uvicorn>=0.30.6
10 changes: 10 additions & 0 deletions sharktank/sharktank/__init__.py
@@ -3,3 +3,13 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import importlib.util

msg = """No module named 'torch'. Follow https://pytorch.org/get-started/locally/#start-locally to install 'torch'.
For example, on Linux to install with CPU support run:
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
"""

if importlib.util.find_spec("torch") is None:
raise ModuleNotFoundError(msg)
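For reference, `importlib.util.find_spec` locates a module without executing it, which keeps this import-time check cheap:

```python
import importlib.util

# Returns a ModuleSpec when the module is installed, None otherwise;
# the target module is not actually imported by the lookup.
print(importlib.util.find_spec("torch") is None)
```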
53 changes: 52 additions & 1 deletion sharktank/sharktank/export.py
@@ -4,11 +4,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Any
from typing import Callable, Optional, Any
import torch
from os import PathLike
import iree.turbine.aot as aot
from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
from .types.tensors import ShardedTensor
from .layers import BaseLayer
from torch.utils._pytree import PyTree, _is_leaf
import functools

@@ -172,3 +175,51 @@ def flat_fn(*args, **kwargs):
)

assert False, "TODO: implement the case when not using an FxProgramsBuilder"


def export_static_model_mlir(
model: BaseLayer,
output_path: PathLike,
function_batch_size_pairs: Optional[dict[Optional[str], list[int]]] = None,
batch_sizes: Optional[list[int]] = None,
):
"""Export a model with no dynamic dimensions.
For the set of provided function name batch sizes pair, the resulting MLIR will
have function names with the below format.
```
<function_name>_bs<batch_size>
```
If `batch_sizes` is given then it defaults to a single function with named
"forward".
The model is required to implement method `sample_inputs`.
"""

assert not (function_batch_size_pairs is not None and batch_sizes is not None)

if batch_sizes is not None:
    function_batch_size_pairs = {None: batch_sizes}

if function_batch_size_pairs is None:
    # Default when neither argument is given: export a single "forward"
    # function with batch size 1.
    function_batch_size_pairs = {None: [1]}

fxb = FxProgramsBuilder(model)

for function, batch_sizes in function_batch_size_pairs.items():
for batch_size in batch_sizes:
args, kwargs = model.sample_inputs(batch_size, function)

@fxb.export_program(
name=f"{function or 'forward'}_bs{batch_size}",
args=args,
kwargs=kwargs,
dynamic_shapes=None,
strict=False,
)
def _(model, **kwargs):
return model(**kwargs)

output = aot.export(fxb)
output.save_mlir(output_path)
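A usage sketch of the new helper; `ToyModel` is invented for illustration and only satisfies the `sample_inputs` contract described above (the import paths follow the ones shown in this diff):

```python
from collections import OrderedDict
import torch
from sharktank.layers import BaseLayer
from sharktank.export import export_static_model_mlir

class ToyModel(BaseLayer):
    def forward(self, x):
        return x + 1

    def sample_inputs(self, batch_size=1, function=None):
        # No positional args; one keyword tensor sized by batch_size.
        return tuple(), OrderedDict(x=torch.zeros(batch_size, 8))

export_static_model_mlir(ToyModel(), "toy.mlir", batch_sizes=[1, 4])
# toy.mlir then contains functions named forward_bs1 and forward_bs4.
```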
24 changes: 18 additions & 6 deletions sharktank/sharktank/layers/base.py
@@ -4,15 +4,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Dict

from typing import Dict, Optional
from collections import OrderedDict
import torch
import torch.nn as nn

from ..types import (
InferenceTensor,
Theta,
)
from ..types import InferenceTensor, Theta, AnyTensor
from ..utils import debugging

__all__ = [
@@ -56,6 +53,21 @@ def assert_not_nan(self, *ts: torch.Tensor):
if torch.isnan(t).any():
raise AssertionError(f"Tensor contains nans! {t}")

def sample_inputs(
self, batch_size: int = 1, function: Optional[str] = None
) -> tuple[tuple[AnyTensor], OrderedDict[str, AnyTensor]]:
"""Return sample inputs that can be used to run the function from the model.
If function is None then layer is treated as the callable.
E.g.
```
args, kwargs = model.sample_inputs()
model(*args, **kwargs)
```
One purpose of this method is to standardize exportation of models to MLIR.
"""
raise NotImplementedError()


class ThetaLayer(BaseLayer):
"Base class for layers that derive parameters from a Theta object."
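To round out the new API, a hypothetical subclass showing how `sample_inputs` can dispatch on the `function` argument when a model exposes more than one entry point ("encode" and "decode" are invented names, and `nn.Module` stands in for `BaseLayer`):

```python
from collections import OrderedDict
import torch
import torch.nn as nn

class ToyCodec(nn.Module):
    def encode(self, x):
        return x.mean(dim=-1)

    def decode(self, z):
        return z.unsqueeze(-1).expand(-1, 8)

    def sample_inputs(self, batch_size=1, function=None):
        if function in (None, "encode"):
            return tuple(), OrderedDict(x=torch.zeros(batch_size, 8))
        if function == "decode":
            return tuple(), OrderedDict(z=torch.zeros(batch_size))
        raise ValueError(f"Unknown function: {function}")

codec = ToyCodec()
args, kwargs = codec.sample_inputs(batch_size=2, function="decode")
codec.decode(*args, **kwargs)  # runs the sampled inputs through the model
```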
