Commit 4cbcca3

deepspeed-fork content for 1.16.0
Signed-off-by: SW publisher <[email protected]>
SW publisher authored and Jenkins committed Jun 4, 2024
1 parent ce78a63 commit 4cbcca3
Showing 220 changed files with 10,829 additions and 1,427 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/nv-nightly.yml
@@ -38,10 +38,6 @@ jobs:
git rev-parse --short HEAD
pip install .
- name: Install datasets
run: |
pip install datasets
- name: Install deepspeed
run: |
pip install .[dev,1bit,autotuning,inf]
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
name: check-torchdist
entry: ./scripts/check-torchdist.py
language: python
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py)
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py|deepspeed/runtime/zero/utils.py|deepspeed/tools/pg_sim/ut/base.py|deepspeed/tools/pg_sim/pg.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

- repo: local
4 changes: 4 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -280,6 +280,10 @@ def create_op_builder(self, class_name):
def get_op_builder(self, class_name):
...

@abc.abstractmethod
def get_compile_backend(self):
...

@abc.abstractmethod
def build_extension(self):
...
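For context, a minimal sketch of how the new abstract method is meant to be consumed, assuming get_accelerator() from deepspeed.accelerator and a plain torch.compile call; the wrapper function below is illustrative, not part of this commit:

    import torch
    from deepspeed.accelerator import get_accelerator

    def compile_for_current_accelerator(model):
        # Each accelerator reports its preferred torch.compile backend:
        # "inductor" for the CPU/CUDA/MPS/NPU accelerators below, "hpu_backend" for HPU.
        backend = get_accelerator().get_compile_backend()
        return torch.compile(model, backend=backend)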
3 changes: 3 additions & 0 deletions accelerator/cpu_accelerator.py
@@ -302,3 +302,6 @@ def build_extension(self):

def export_envs(self):
return []

def get_compile_backend(self):
return "inductor"
3 changes: 3 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -360,3 +360,6 @@ def build_extension(self):

def export_envs(self):
return ['NCCL']

def get_compile_backend(self):
return "inductor"
11 changes: 11 additions & 0 deletions accelerator/hpu_accelerator.py
@@ -288,6 +288,17 @@ def get_op_builder(self, class_name):
else:
return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None

def get_compile_backend(self):
return "hpu_backend"

# shall be removed once moving to torch.compile
def wrap_in_hpu_graph(self, module):
if self.hpu.is_lazy():
module = self.hpu.wrap_in_hpu_graph(module)
else:
print("Warning: HPU graphs are not supported in eager mode, ignoring")
return module

def build_extension(self):
from torch.utils.cpp_extension import BuildExtension
return BuildExtension
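
A hedged usage sketch for the new wrap_in_hpu_graph helper: it wraps the module only when the Habana bridge runs in lazy mode and otherwise returns it unchanged. The fallback wrapper below is illustrative, not part of this commit:

    from deepspeed.accelerator import get_accelerator

    def maybe_wrap_in_hpu_graph(module):
        accelerator = get_accelerator()
        # Only the HPU accelerator exposes this helper; other accelerators pass through.
        if hasattr(accelerator, "wrap_in_hpu_graph"):
            module = accelerator.wrap_in_hpu_graph(module)
        return module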
3 changes: 3 additions & 0 deletions accelerator/mps_accelerator.py
@@ -258,3 +258,6 @@ def build_extension(self):

def export_envs(self):
return []

def get_compile_backend(self):
return "inductor"
3 changes: 3 additions & 0 deletions accelerator/npu_accelerator.py
@@ -278,3 +278,6 @@ def build_extension(self):

def export_envs(self):
return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']

def get_compile_backend(self):
return "inductor"
1 change: 1 addition & 0 deletions build.txt
@@ -0,0 +1 @@
+hpu.synapse.v1.16.0
6 changes: 5 additions & 1 deletion csrc/adam/cpu_adam_impl.cpp
@@ -244,13 +244,17 @@ int ds_adam_step(int optimizer_id,
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);

bool bit16_precision = false;
if ((params.options().dtype() == at::kHalf) || (params.options().dtype() == at::kBFloat16))
bit16_precision = true;

opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
nullptr,
(params.options().dtype() == at::kHalf));
bit16_precision);

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
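The change above widens the 16-bit parameter check from fp16 only to fp16 and bf16. A sketch of exercising that path from Python with DeepSpeedCPUAdam, assuming the extension was built with bf16 support (__BFLOAT16__/ENABLE_BFLOAT16):

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    # bf16 parameters kept on CPU; ds_adam_step now treats both torch.half
    # and torch.bfloat16 as 16-bit precision.
    params = [torch.nn.Parameter(torch.randn(1024, dtype=torch.bfloat16))]
    optimizer = DeepSpeedCPUAdam(params, lr=1e-3)

    params[0].grad = torch.randn_like(params[0])
    optimizer.step()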
4 changes: 4 additions & 0 deletions csrc/includes/cpu_adam.h
@@ -23,6 +23,9 @@ typedef __half ds_half_precision_t;
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#elif defined(__BFLOAT16__)
#include <torch/torch.h>
typedef at::BFloat16 ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
@@ -259,6 +262,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
simd_store<span>(_exp_avg + i, momentum_4, false);
simd_store<span>(_exp_avg_sq + i, variance_4, false);
}
// Params are updated only in case of float16, which is currently not supported on HPU
#if defined(__ENABLE_CUDA__)
if (dev_params) {
if (half_precision)
110 changes: 109 additions & 1 deletion csrc/includes/simd.h
@@ -9,6 +9,23 @@
#include <cpuid.h>
#include <x86intrin.h>
#endif
#include <cstdint>
#include <cstring>
#include <type_traits>

template <typename T>
inline T readAs(const void* src)
{
T res;
std::memcpy(&res, src, sizeof(T));
return res;
}

template <typename T>
inline void writeAs(void* dst, const T& val)
{
std::memcpy(dst, &val, sizeof(T));
}

#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__)
@@ -29,12 +46,58 @@
#define SIMD_OR(x, y) _mm512_or_ps(x, y)
#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
#define SIMD_WIDTH 16
#if defined(ENABLE_BFLOAT16)
static __m512 load_16_bf16_as_f32(const void* data)
{
__m256i a = readAs<__m256i>(data); // use memcpy to avoid aliasing
__m512i b = _mm512_cvtepu16_epi32(a); // convert 16 u16 to 16 u32
__m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by
// 16 bits (representing bf16->f32)
return readAs<__m512>(&c); // use memcpy to avoid aliasing
}

static void store_16_f32_as_bf16_nearest(__m512 v, void* data)
{
__m512i u32 = readAs<__m512i>(&v);

// flow assuming non-nan:

// uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
__m512i b = _mm512_srli_epi32(u32, 16);
__m512i lsb_mask = _mm512_set1_epi32(0x00000001);
__m512i c = _mm512_and_si512(b, lsb_mask);
__m512i bias_constant = _mm512_set1_epi32(0x00007fff);
__m512i rounding_bias = _mm512_add_epi32(c, bias_constant);

// uint16_t res = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
__m512i d = _mm512_add_epi32(u32, rounding_bias);
__m512i e = _mm512_srli_epi32(d, 16);
__m256i non_nan_res = _mm512_cvtusepi32_epi16(e);

// handle nan (exp is all 1s and mantissa != 0)
// if ((x & 0x7fffffffU) > 0x7f800000U)
__m512i mask_out_sign = _mm512_set1_epi32(0x7fffffff);
__m512i non_sign_bits = _mm512_and_si512(u32, mask_out_sign);
__m512i nan_threshold = _mm512_set1_epi32(0x7f800000);
__mmask16 nan_mask = _mm512_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT);

// mix in results with nans as needed
__m256i nans = _mm256_set1_epi16(0x7fc0);
__m256i res = _mm256_mask_mov_epi16(non_nan_res, nan_mask, nans);

writeAs(data, res);
}

#define SIMD_LOAD2(x, h) ((h) ? load_16_bf16_as_f32(x) : _mm512_loadu_ps(x))

#define SIMD_STORE2(x, d, h) ((h) ? store_16_f32_as_bf16_nearest(d, x) : _mm512_storeu_ps(x, d))
#else // ENABLE_BFLOAT16
#define SIMD_LOAD2(x, h) \
((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm512_storeu_ps(x, d))
#endif // ENABLE_BFLOAT16

#define INTV __m256i
#elif defined(__AVX256__)
@@ -52,12 +115,57 @@
#define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
#define SIMD_WIDTH 8

#if defined(ENABLE_BFLOAT16)
static __m256 load_8_bf16_as_f32(const float* data)
{
__m128i a = readAs<__m128i>(data); // use memcpy to avoid aliasing
__m256i b = _mm256_cvtepu16_epi32(a); // convert 8 u16 to 8 u32
__m256i c = _mm256_slli_epi32(b, 16); // logical shift left of all u32 by
// 16 bits (representing bf16->f32)
return readAs<__m256>(&c); // use memcpy to avoid aliasing
}

static void store_8_f32_as_bf16_nearest(__m256 v, float* data)
{
__m256i u32 = readAs<__m256i>(&v);

// flow assuming non-nan:

// uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
__m256i b = _mm256_srli_epi32(u32, 16);
__m256i lsb_mask = _mm256_set1_epi32(0x00000001);
__m256i c = _mm256_and_si256(b, lsb_mask);
__m256i bias_constant = _mm256_set1_epi32(0x00007fff);
__m256i rounding_bias = _mm256_add_epi32(c, bias_constant);

// uint16_t res = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
__m256i d = _mm256_add_epi32(u32, rounding_bias);
__m256i e = _mm256_srli_epi32(d, 16);
__m128i non_nan_res = _mm256_cvtusepi32_epi16(e);

// handle nan (exp is all 1s and mantissa != 0)
// if ((x & 0x7fffffffU) > 0x7f800000U)
__m256i mask_out_sign = _mm256_set1_epi32(0x7fffffff);
__m256i non_sign_bits = _mm256_and_si256(u32, mask_out_sign);
__m256i nan_threshold = _mm256_set1_epi32(0x7f800000);
__mmask8 nan_mask = _mm256_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT);

// mix in results with nans as needed
__m128i nans = _mm_set1_epi16(0x7fc0);
__m128i res = _mm_mask_mov_epi16(non_nan_res, nan_mask, nans);

writeAs(data, res);
}
#define SIMD_LOAD2(x, h) ((h) ? load_8_bf16_as_f32(x) : _mm256_loadu_ps(x))

#define SIMD_STORE2(x, d, h) ((h) ? store_8_f32_as_bf16_nearest(d, x) : _mm256_storeu_ps(x, d))
#else // ENABLE_BFLOAT16
#define SIMD_LOAD2(x, h) \
((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x))) : _mm256_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm256_storeu_ps(x, d))

#endif // ENABLE_BFLOAT16
#define INTV __m128i
#endif
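
The scalar logic behind the two vectorized bf16 stores above, reconstructed from the inline comments (round-to-nearest-even with a quiet-NaN fallback); a reference sketch only:

    def f32_bits_to_bf16_bits(u32):
        # NaN: exponent all ones and mantissa non-zero -> bf16 quiet NaN 0x7FC0
        if (u32 & 0x7FFFFFFF) > 0x7F800000:
            return 0x7FC0
        # Round to nearest even, matching the commented formula above.
        rounding_bias = ((u32 >> 16) & 1) + 0x7FFF
        return ((u32 + rounding_bias) >> 16) & 0xFFFF

    def bf16_bits_to_f32_bits(u16):
        # Loading is just a 16-bit shift into the high half of the f32 bit pattern.
        return u16 << 16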

95 changes: 91 additions & 4 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -446,15 +446,15 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
unsigned layer_id,
unsigned num_layers,
at::Tensor& alibi,
float rope_theta)
float rope_theta,
bool is_prompt,
std::optional<at::Tensor> token_idx)
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads));
unsigned hidden_dim = heads * k;

bool is_prompt = (seq_len > 1);

if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
unsigned soft_len = InferenceContext::Instance().current_tokens();
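
The prompt/decode decision moves out of the binding (it used to compute is_prompt = (seq_len > 1) internally) and is now supplied by the caller together with an optional token_idx. A small illustrative helper, not part of this commit, showing the values a caller is now expected to forward:

    def softmax_context_extra_args(seq_len, token_idx=None):
        # Previously derived inside ds_softmax_context; now passed in explicitly.
        is_prompt = seq_len > 1
        return {"is_prompt": is_prompt, "token_idx": token_idx}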

@@ -847,6 +847,87 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
return {norm_output, res_output};
}

template <typename T>
at::Tensor ds_transform4d_0213(at::Tensor& input, int seq_length)
{
auto input_cont = input.contiguous();
unsigned batch_size = input.size(0);
unsigned num_heads = input.size(1);
unsigned seq_length_head_dim = input.size(2);
unsigned head_dim = seq_length_head_dim / seq_length;
unsigned hidden_dim = num_heads * head_dim;

auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();

launch_transform4d_0213<T>(workspace,
(T*)input.data_ptr(),
batch_size,
num_heads,
seq_length,
hidden_dim,
InferenceContext::Instance().GetCurrentStream(),
1);
auto output = at::from_blob(workspace, {batch_size, seq_length, num_heads, head_dim}, options);
return output;
}

template <typename T>
std::vector<at::Tensor> ds_bias_add_transform_0213(at::Tensor& input,
at::Tensor& bias,
int num_heads,
int trans_count)
{
TORCH_CHECK(
trans_count == 1 or trans_count == 3, "trans_count ", trans_count, " is not supported");
auto input_cont = input.contiguous();

unsigned batch_size = input.size(0);
unsigned seq_length = input.size(1);
unsigned value_size = input.size(2);
unsigned hidden_dim = input.size(2) / trans_count;
unsigned head_dim = hidden_dim / num_heads;

auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
auto final_output = workspace;
int num_kv = -1;
int rope_theta = -1;
size_t offset = (batch_size * seq_length * hidden_dim);
launch_bias_add_transform_0213<T>(final_output,
final_output + offset,
final_output + 2 * offset,
(T*)input.data_ptr(),
(T*)bias.data_ptr(),
batch_size,
seq_length,
0, // seq_offset
input.size(1), // all_tokens .. unused?
hidden_dim,
num_heads,
num_kv,
-1, // rotary_dim
false, // rotate_half
false, // rotate_every_two
InferenceContext::Instance().GetCurrentStream(),
trans_count, // trans_count
input.size(1), // max_out_tokens
rope_theta);
return {at::from_blob(final_output, {batch_size, num_heads, seq_length, head_dim}, options),
at::from_blob(
final_output + offset, {batch_size, num_heads, seq_length, head_dim}, options),
at::from_blob(
final_output + 2 * offset, {batch_size, num_heads, seq_length, head_dim}, options)};
}

template <typename T>
void quantized_gemm(void* output,
T* input,
@@ -2010,7 +2091,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \
m.def("dequantize_" #_name, \
&ds_dequantize<_dtype>, \
"DeepSpeed dequantize with " #_name " (CUDA)")
"DeepSpeed dequantize with " #_name " (CUDA)"); \
m.def("transform4d_0213_" #_name, \
&ds_transform4d_0213<_dtype>, \
"DeepSpeed transform4d 0213 with " #_name " (CUDA)"); \
m.def("bias_add_transform_0213_" #_name, \
&ds_bias_add_transform_0213<_dtype>, \
"DeepSpeed bias and transform 0213 with " #_name " (CUDA)")

DEF_OPS(fp32, float);
DEF_OPS(fp16, __half);
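
A hedged sketch of reaching the two newly bound ops from Python. It assumes DeepSpeed's standard InferenceBuilder loader and that the inference workspace has already been set up by a prior DeepSpeed-Inference initialization; shapes follow the binding code above:

    import torch
    from deepspeed.ops.op_builder import InferenceBuilder

    inference_module = InferenceBuilder().load()

    # Fused QKV activations [batch, seq_len, 3 * hidden_dim] and bias [3 * hidden_dim];
    # num_heads=32 and hidden_dim=4096 give head_dim=128.
    qkv = torch.empty(1, 128, 3 * 4096, dtype=torch.half, device="cuda")
    bias = torch.zeros(3 * 4096, dtype=torch.half, device="cuda")

    # trans_count=3 splits the fused tensor into Q, K and V in [b, heads, seq, head_dim] layout.
    q, k, v = inference_module.bias_add_transform_0213_fp16(qkv, bias, 32, 3)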
3 changes: 1 addition & 2 deletions deepspeed/__init__.py
@@ -22,6 +22,7 @@
else:
HAS_TRITON = False

from .utils import log_dist, OnDevice, logger
from . import ops
from . import module_inject

@@ -38,11 +39,9 @@
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer

from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed

from .runtime import zero
from .runtime import DeepSpeedOptimizer, ZeROOptimizer
from .runtime.compiler import is_compile_supported

from .pipe import PipelineModule
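
With is_compile_supported now re-exported at the top level, a brief hedged sketch of gating compilation on it; the helper is illustrative, not part of this commit:

    import torch
    import deepspeed

    def maybe_compile(module):
        # Skip torch.compile on stacks where it is not supported.
        if deepspeed.is_compile_supported():
            return torch.compile(module)
        return module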
(Remaining changed files of the 220 are not shown.)
