feat: update v1.2 (#31)
byshiue authored Aug 16, 2022
1 parent d76f2c2 commit 22dba92
Showing 94 changed files with 310,103 additions and 977 deletions.
76 changes: 53 additions & 23 deletions CMakeLists.txt
@@ -33,6 +33,8 @@ project(tritonfastertransformerbackend LANGUAGES C CXX)
#
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
option(BUILD_MULTI_GPU "Enable multi GPU support" ON)

set(TRITON_PYTORCH_INCLUDE_PATHS "" CACHE PATH "Paths to Torch includes")
set(TRITON_PYTORCH_LIB_PATHS "" CACHE PATH "Paths to Torch libraries")

@@ -44,8 +46,6 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()

set(BUILD_MULTI_GPU "ON")
message("-- Enable BUILD_MULTI_GPU")
set(USE_TRITONSERVER_DATATYPE "ON")
message("-- Enable USE_TRITONSERVER_DATATYPE")

@@ -56,10 +56,15 @@ find_package(Python3 REQUIRED COMPONENTS Development)

find_package(FasterTransformer)
find_package(CUDA 10.1 REQUIRED)
find_package(MPI REQUIRED)
find_package(NCCL REQUIRED)

message(STATUS "Found MPI (include: ${MPI_INCLUDE_DIRS}, library: ${MPI_LIBRARIES})")
if (BUILD_MULTI_GPU)
message(STATUS "Enable BUILD_MULTI_GPU.")
find_package(MPI REQUIRED)
find_package(NCCL REQUIRED)
message(STATUS "Found MPI (include: ${MPI_INCLUDE_DIRS}, library: ${MPI_LIBRARIES})")
message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
else()
message(STATUS "Disable BUILD_MULTI_GPU.")
endif()

if (${CUDA_VERSION} GREATER_EQUAL 11.0)
message(STATUS "Add DCUDA11_MODE")
@@ -97,12 +102,19 @@ FetchContent_Declare(
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-ft
GIT_REPOSITORY https://github.com/NVIDIA/FasterTransformer.git
GIT_TAG main
GIT_SHALLOW ON
)
if (EXISTS ${FT_DIR})
FetchContent_Declare(
repo-ft
SOURCE_DIR ${FT_DIR}
)
else()
FetchContent_Declare(
repo-ft
GIT_REPOSITORY https://github.com/NVIDIA/FasterTransformer.git
GIT_TAG v5.1
GIT_SHALLOW ON
)
endif()
FetchContent_MakeAvailable(repo-common repo-core repo-backend repo-ft)

#
@@ -128,12 +140,10 @@ add_library(

#find_package(CUDAToolkit REQUIRED)
find_package(CUDA 10.1 REQUIRED)
find_package(MPI REQUIRED)
##find_package(NCCL REQUIRED)
#if (${CUDA_VERSION} GREATER_EQUAL 11.0)
message(STATUS "Add DCUDA11_MODE")
add_definitions("-DCUDA11_MODE")
#endif()
if (${CUDA_VERSION} GREATER_EQUAL 11.0)
message(STATUS "Add DCUDA11_MODE")
add_definitions("-DCUDA11_MODE")
endif()

set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})

@@ -148,7 +158,6 @@ target_include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/src
${TRITON_PYTORCH_INCLUDE_PATHS}
${Python3_INCLUDE_DIRS}
${MPI_INCLUDE_PATH}
${repo-ft_SOURCE_DIR}
${repo-core_SOURCE_DIR}/include
)
@@ -157,8 +166,6 @@ target_link_directories(
triton-fastertransformer-backend
PRIVATE
${CUDA_PATH}/lib64
${MPI_Libraries}
/usr/local/mpi/lib
)

target_compile_features(triton-fastertransformer-backend PRIVATE cxx_std_14)
@@ -210,14 +217,37 @@ target_link_libraries(
triton-backend-utils # from repo-backend
transformer-shared # from repo-ft
${TRITON_PYTORCH_LDFLAGS}
${NCCL_LIBRARIES}
${MPI_LIBRARIES}
-lcublas
-lcublasLt
-lcudart
-lcurand
)

if (BUILD_MULTI_GPU)
target_compile_definitions(
triton-fastertransformer-backend
PUBLIC
BUILD_MULTI_GPU
)
target_include_directories(
triton-fastertransformer-backend
PRIVATE
${MPI_INCLUDE_PATH}
)
target_link_directories(
triton-fastertransformer-backend
PRIVATE
${MPI_Libraries}
/usr/local/mpi/lib
)
target_link_libraries(
triton-fastertransformer-backend
PRIVATE
${NCCL_LIBRARIES}
${MPI_LIBRARIES}
)
endif()

if(${TRITON_ENABLE_GPU})
target_link_libraries(
triton-fastertransformer-backend
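
The reworked CMakeLists.txt makes multi-GPU support an explicit `BUILD_MULTI_GPU` option (ON by default), so MPI and NCCL are only pulled in when it is enabled. Below is a minimal configure sketch, assuming a standard out-of-source build; the build directory, install prefix, and the remaining flags are illustrative and not taken from this commit.

```bash
# Illustrative configure/build sketch; only BUILD_MULTI_GPU comes from this commit,
# the other paths and flags are assumptions.
cmake -S . -B build \
      -D CMAKE_BUILD_TYPE=Release \
      -D BUILD_MULTI_GPU=OFF \
      -D CMAKE_INSTALL_PREFIX=/opt/tritonserver
cmake --build build --parallel --target install
```
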
87 changes: 70 additions & 17 deletions README.md
@@ -28,24 +28,39 @@

# FasterTransformer Backend

The Triton backend for the [FasterTransformer](https://github.com/NVIDIA/FasterTransformer). This repository provides a script and recipe to run the highly optimized transformer-based encoder and decoder component, and it is tested and maintained by NVIDIA. In the FasterTransformer v4.0, it supports multi-gpu inference on GPT-3 model. This backend integrates FasterTransformer into Triton to use giant GPT-3 model serving by Triton. In the below example, we will show how to use the FasterTransformer backend in Triton to run inference on a GPT-3 model with 345M parameters trained by [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). In latest beta release, FasterTransformer backend supports the multi-node multi-GPU inference on T5 with the model of huggingface.
This is the Triton backend for [FasterTransformer](https://github.com/NVIDIA/FasterTransformer). This repository provides a script and recipe to run the highly optimized transformer-based encoder and decoder components; it is tested and maintained by NVIDIA. Since FasterTransformer v4.0, it supports multi-GPU inference on the GPT-3 model, and this backend integrates FasterTransformer into Triton so that giant GPT-3 models can be served by Triton. In the example below, we show how to use the FasterTransformer backend in Triton to run inference on a GPT-3 model with 345M parameters trained by [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). In the latest release, the FasterTransformer backend also supports multi-node, multi-GPU inference on T5 with Hugging Face models.

Note that this is a research and prototyping tool, not a formal product or maintained framework. Users can learn more about Triton backends in the [backend repo](https://github.com/triton-inference-server/backend). Ask questions or report problems on the [issues page](https://github.com/triton-inference-server/fastertransformer_backend/issues) of this FasterTransformer backend repo.

## Table Of Contents

- [FasterTransformer Backend](#fastertransformer-backend)
- [Table Of Contents](#table-of-contents)
- [Support matrix](#support-matrix)
- [Introduction](#introduction)
- [Setup](#setup)
- [Prepare docker images](#prepare-docker-images)
- [Rebuilding FasterTransformer backend (optional)](#rebuilding-fastertransformer-backend-optional)
- [NCCL_LAUNCH_MODE](#nccl_launch_mode)
- [GPUs Topology](#gpus-topology)
- [MPI Launching with Tensor Parallel size and Pipeline Parallel Size Setting](#mpi-launching-with-tensor-parallel-size-and-pipeline-parallel-size-setting)
- [Model-Parallelism and Triton-Multiple-Model-Instances](#model-parallelism-and-triton-multiple-model-instances)
- [Run inter-node (T x P > GPUs per Node) models](#run-inter-node-t-x-p--gpus-per-node-models)
- [Run intra-node (T x P <= GPUs per Node) models](#run-intra-node-t-x-p--gpus-per-node-models)
- [Specify Multiple Model Instances](#specify-multiple-model-instances)
- [Multi-Node Inference](#multi-node-inference)
- [Request examples](#request-examples)
- [Changelog](#changelog)

## Support matrix

| Models | FP16 | BF16 | Tensor parallel | Pipeline parallel |
| -------- | ---- | ---- | --------------- | ----------------- |
| GPT | Yes | Yes | Yes | Yes |
| GPT-J | Yes | Yes | Yes | Yes |
| T5 | Yes | Yes | Yes | Yes |
| GPT-NeoX | Yes | Yes | Yes | Yes |
| BERT | Yes | Yes | Yes | Yes |

## Introduction

The FasterTransformer backend aims to integrate FasterTransformer into Triton, leveraging the efficiency of FasterTransformer and the serving capabilities of Triton. To run the GPT-3 model, we need to solve the following two issues: 1. How to run the auto-regressive model? 2. How to run the model with multiple GPUs and multiple nodes?
@@ -84,10 +99,10 @@ For the issue of running the model with multi-gpu and multi-node, FasterTransfor
## Setup

```bash
git clone https://github.com/triton-inference-server/fastertransformer_backend.git
cd fastertransformer_backend
export WORKSPACE=$(pwd)
export SRC_MODELS_DIR=${WORKSPACE}/models
export TRITON_MODELS_STORE=${WORKSPACE}/triton-model-store
export CONTAINER_VERSION=22.03
export CONTAINER_VERSION=22.07
export TRITON_DOCKER_IMAGE=triton_with_ft:${CONTAINER_VERSION}
```

@@ -97,12 +112,6 @@ The current official Triton Inference Server docker image doesn't contain
the FasterTransformer backend, so users must prepare their own docker image using the command below:

```bash
cd ${WORKSPACE}
git clone https://github.com/triton-inference-server/fastertransformer_backend
git clone https://github.com/triton-inference-server/server.git # We need some tools when we test this backend
git clone https://github.com/NVIDIA/FasterTransformer.git # Used for convert the checkpoint and triton output
ln -s server/qa/common .
cd fastertransformer_backend
docker build --rm \
--build-arg TRITON_VERSION=${CONTAINER_VERSION} \
-t ${TRITON_DOCKER_IMAGE} \
@@ -180,22 +189,50 @@ If your current machine/nodes are fully connected through PCIE or even across NU
If you encounter timeouts or hangs, first check the topology and try to use a DGX V100 or DGX A100, where the GPUs are connected by NVLink.
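
A quick way to inspect the topology is `nvidia-smi topo -m`, which prints the GPU interconnect matrix; `NV#` entries indicate NVLink links, while `PHB`, `NODE`, and `SYS` indicate PCIe or cross-NUMA paths.

```bash
# Print the GPU-to-GPU interconnect matrix to verify NVLink vs. PCIe/NUMA paths.
nvidia-smi topo -m
```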


## MPI Launching with Tensor Parallel size and Pipeline Parallel Size Setting

## Model-Parallelism and Triton-Multiple-Model-Instances
We use MPI to start single-node/multi-node servers.

- N: Number of MPI Processes/Number of Nodes
- T: Tensor Parallel Size. Default 8
- T: Tensor Parallel Size. Default 1
- P: Pipeline Parallel Size. Default 1

`total number of gpus = num_gpus_per_node x N = T x P`
Multiple model instances on the same GPUs share the weights, so no redundant weight memory is allocated.

### Run inter-node (T x P > GPUs per Node) models
- `total number of GPUs = num_gpus_per_node x N = T x P`.
- only a single model instance is supported

### Run intra-node (T x P <= GPUs per Node) models
- `total number of visible GPUs must be evenly divisible by T x P`. Note that you can control this by setting `CUDA_VISIBLE_DEVICES` (see the launch sketch below).
- `total number of visible GPUs must be <= T x P x Instance Count`. This avoids unnecessary CUDA memory allocation on unused GPUs.
- multiple model instances can be run on the same GPU groups or on different GPU groups.

The backend will first try to assign different GPU groups to different model instances. If there are no empty GPUs left, multiple model instances will be assigned to the same GPU groups.

For example, with 8 GPUs and 8 model instances (T = 2, P = 1), the model instances will be distributed to GPU groups [0, 1], [2, 3], [4, 5], [6, 7], [0, 1], [2, 3], [4, 5], [6, 7].
- weights are shared among model instances in the same GPU groups. In the example above, instance 0 and instance 4 will share the same weights, and the other pairs behave similarly.
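
As a minimal single-node launch sketch (the `tritonserver` binary path and the use of `mpirun` here are assumptions based on the standard Triton container layout, not commands taken from this repository), T = 2, P = 1 with four visible GPUs could be started like this:

```bash
# Illustrative only: expose 4 GPUs so that T x P = 2 evenly divides the visible GPU count.
# The tritonserver path and model repository are assumptions for this sketch.
export CUDA_VISIBLE_DEVICES=0,1,2,3
mpirun -n 1 --allow-run-as-root \
    /opt/tritonserver/bin/tritonserver \
    --model-repository=${TRITON_MODELS_STORE}
```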

**Note** that we currently do not support the case that different nodes have different number of GPUs.
### Specify Multiple Model Instances

Set `count` here to start multiple model instances. Note that `KIND_CPU` is the only valid choice here, because the backend needs to take full control of how model instances are distributed across the visible GPUs.

```json
instance_group [
{
count: 8
kind: KIND_CPU
}
]
```

### Multi-Node Inference

We currently do not support the case where different nodes have different numbers of GPUs.

We start one MPI process per node. If you need to run on three nodes, launch three MPI processes, one per node.
Remember to change `tensor_para_size` and `pipeline_para_size` if you run on multiple nodes.

We do suggest tensor_para_size = number of gpus in one node (e.g. 8 for DGX A100), and pipeline_para_size = number of nodes (2 for two nodes). Other model configuration in config.pbtxt should be modified as normal.
We suggest tensor_para_size = number of GPUs in one node (e.g. 8 for DGX A100) and pipeline_para_size = number of nodes (e.g. 2 for two nodes). Other model configuration in config.pbtxt should be modified as usual.
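
As a rough multi-node sketch (hostnames, slot counts, and paths are illustrative, and clusters often use their own launcher such as `srun` instead), two nodes with one MPI process each could be launched like this:

```bash
# Illustrative only: one MPI process per node on two hosts.
# Hostnames and paths are assumptions; set tensor_para_size/pipeline_para_size in
# config.pbtxt to match the layout (e.g. 8 and 2 for two DGX A100 nodes).
mpirun -np 2 -H node1:1,node2:1 --allow-run-as-root \
    /opt/tritonserver/bin/tritonserver \
    --model-repository=${TRITON_MODELS_STORE}
```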

## Request examples

@@ -205,6 +242,22 @@ Specifically `tools/issue_request.py` is a simple script that sends a request co
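
A hypothetical invocation of that script (the request-file path is an assumption for illustration, not taken from this commit) could look like:

```bash
# Hypothetical example: the JSON payload path is an assumption.
python3 tools/issue_request.py tools/requests/sample_request.json
```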

## Changelog

Aug 2022
- Support for interactive generation

July 2022
- Support shared context optimization in GPT model
- Support UL2

June 2022
- Support decoupled (streaming) mode.
- Add demo of grpc protocol.
- Support BERT

May 2022
- Support GPT-NeoX.
- Support optional input. (triton version must be after 22.05)

April 2022
- Support bfloat16 inference in GPT model.
- Support Nemo Megatron T5 and Megatron-LM T5 model.
14 changes: 14 additions & 0 deletions all_models/bert/fastertransformer/1/1-gpu/config.ini
@@ -0,0 +1,14 @@
[bert]
model_name = bert
position_embedding_type = absolute
hidden_size = 768
num_layer = 12
head_num = 12
size_per_head = 64
activation_type = gelu
inter_size = 3072
max_position_embeddings = 512
layer_norm_eps = 1e-12
weight_data_type = fp32
tensor_para_size = 1

14 changes: 14 additions & 0 deletions all_models/bert/fastertransformer/1/2-gpu/config.ini
@@ -0,0 +1,14 @@
[bert]
model_name = bert
position_embedding_type = absolute
hidden_size = 768
num_layer = 12
head_num = 12
size_per_head = 64
activation_type = gelu
inter_size = 3072
max_position_embeddings = 512
layer_norm_eps = 1e-12
weight_data_type = fp32
tensor_para_size = 2

