Slc/add SyncBN ops and modify ci #221

Merged
merged 40 commits on Aug 17, 2023
Commits
40 commits
296e675
add torch.module.to
wey-code Jul 21, 2023
695dc4d
Merge branch 'main' of https://github.com/DeepLink-org/dipu into main
wey-code Jul 25, 2023
a83575d
Merge branch 'main' of https://github.com/DeepLink-org/dipu into main
wey-code Jul 26, 2023
b025b6a
add bn stats
wey-code Jul 26, 2023
3ff3e90
add SyncBN ops
wey-code Aug 3, 2023
5110a46
update DIOPI
wey-code Aug 3, 2023
92ab438
update test case comment
wey-code Aug 4, 2023
e4f8c49
update
wey-code Aug 4, 2023
c13c2ea
add bn elemt
wey-code Aug 5, 2023
685f067
update DIOPI
wey-code Aug 5, 2023
9c2e8a9
modify ci cuda models
wey-code Aug 5, 2023
cba4ee2
add td-hm_hrnet
wey-code Aug 5, 2023
b001fcd
fix bug
wey-code Aug 5, 2023
4e11bf8
add shuffle_net
wey-code Aug 5, 2023
837d773
update kinetics400
wey-code Aug 5, 2023
9b4d1a1
test DI-engine
wey-code Aug 7, 2023
97aea06
fix bug
wey-code Aug 7, 2023
6d94c5b
update DIOPI
wey-code Aug 7, 2023
01cfd52
install package
wey-code Aug 7, 2023
878faa3
test all models
wey-code Aug 7, 2023
18ce86b
add exit(1)
wey-code Aug 7, 2023
77004e6
update
wey-code Aug 7, 2023
3db201d
add set -e
wey-code Aug 7, 2023
aebfeb4
Merge branch 'slc/add_bn_stats' of https://github.com/DeepLink-org/di…
wey-code Aug 7, 2023
e11a56c
lint
wey-code Aug 7, 2023
ecf6961
test stable diffusion
wey-code Aug 8, 2023
5436009
install mmagic
wey-code Aug 8, 2023
adee87e
Merge branch 'main' of https://github.com/DeepLink-org/dipu into slc/…
wey-code Aug 8, 2023
24dba8e
simplify test case
wey-code Aug 8, 2023
6a31820
pip install transformer
wey-code Aug 8, 2023
c775cab
pip install accelerate
wey-code Aug 8, 2023
74d8b7b
test all cuda models
wey-code Aug 9, 2023
5d5a0a8
test cuda all models
wey-code Aug 11, 2023
39ab000
Merge branch 'slc/add_bn_stats' of https://github.com/DeepLink-org/di…
wey-code Aug 16, 2023
5df4665
fix lint
wey-code Aug 16, 2023
eca54dd
lint
wey-code Aug 16, 2023
c21ab4a
Update ci_one_iter.sh
wey-code Aug 17, 2023
d88a920
Merge pull request #235 from DeepLink-org/slc/modify_cuda_ci
wey-code Aug 17, 2023
dd3c372
update DIOPI
wey-code Aug 17, 2023
a38696c
Merge branch 'slc/add_bn_stats' of https://github.com/DeepLink-org/di…
wey-code Aug 17, 2023
1 change: 1 addition & 0 deletions .github/workflows/main.yml
@@ -33,6 +33,7 @@ jobs:
steps:
- name: clone repo
run: |
set -e
cd ${GITHUB_WORKSPACE} && rm -rf DIPU DIPU_DIOPI && git clone https://github.com/DeepLink-org/DIPU.git && cd DIPU
if [ $GITHUB_EVENT_NAME == "pull_request" ]; then
echo "${{ github.base_ref }} "
43 changes: 43 additions & 0 deletions scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1716,3 +1716,46 @@
custom_code_at_the_beginning: |
::diopiSize_t dimDiopiSize = toDiopiSize(dim);
interface: diopiAmax(ctx, out, self, dimDiopiSize, keepdim)

- schema: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)
custom_code_at_the_beginning: |
auto shape = input.size(1);
auto out0 = at::empty({shape}, input.options().dtype(at::kFloat));
auto out1 = at::empty({shape}, input.options().dtype(at::kFloat));
interface: diopiBatchNormStats(ctx, out0, out1, input, eps)

- schema: batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor)
custom_code_at_the_beginning: |
auto shape = input.size(1);
auto out0 = at::empty({shape}, input.options().dtype(at::kFloat));
auto out1 = at::empty({shape}, input.options().dtype(at::kFloat));
interface: diopiBatchNormGatherStatsWithCounts(ctx, out0, out1, input, mean, invstd, const_cast<diopiTensorHandle_t>(running_mean), const_cast<diopiTensorHandle_t>(running_var), momentum, eps, counts)

- schema: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)
custom_code_at_the_beginning: |
auto shape = input.size(1);
at::Tensor out0;
at::Tensor out1;
at::Tensor out2;
at::Tensor out3;
if(input_g){
out0 = at::empty({shape}, input.options().dtype(at::kFloat));
out1 = at::empty({shape}, input.options().dtype(at::kFloat));
}
if(weight_g){
out2 = at::empty({shape}, input.options().dtype(at::kFloat));
}
if(bias_g){
out3 = at::empty({shape}, input.options().dtype(at::kFloat));
}
interface: diopiBatchNormBackwardReduce(ctx, out0, out1, out2, out3, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g)

- schema: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor
custom_code_at_the_beginning: |
auto out = at::empty_like(grad_out);
interface: diopiBatchNormBackwardElemt(ctx, out, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);

- schema: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor
custom_code_at_the_beginning: |
auto out = at::empty_like(input);
interface: diopiBatchNormElemt(ctx, out, input, weight, bias, mean, invstd, eps);
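
These schemas cover the primitives that torch.nn.SyncBatchNorm dispatches to: per-rank statistics, gathering statistics across ranks, and element-wise normalization. Below is a minimal, illustrative single-process sketch of how the forward path composes them; it assumes torch_dipu is installed from this repo and that the vendor's DIOPI backend implements the new ops (variable names are not from the PR).

import torch
import torch_dipu  # registers the DIPU device (assumption: built from this repo)

device = torch.device(torch_dipu.dipu.diputype)
x = torch.rand(2, 8, 32, 56, 56, device=device)
weight = torch.rand(8, device=device)
bias = torch.rand(8, device=device)
eps, momentum = 1e-5, 0.1

# per-rank statistics (diopiBatchNormStats)
mean, invstd = torch.batch_norm_stats(x, eps)

# in real SyncBN the per-rank stats are all-gathered across ranks; with a
# single process we stack the local result to mimic the gathered shape
mean_all = mean.unsqueeze(0)
invstd_all = invstd.unsqueeze(0)
count_all = torch.full((1,), x.numel() // x.size(1), dtype=torch.float32, device=device)

running_mean = torch.zeros(8, device=device)
running_var = torch.ones(8, device=device)

# global statistics (diopiBatchNormGatherStatsWithCounts)
global_mean, global_invstd = torch.batch_norm_gather_stats_with_counts(
    x, mean_all, invstd_all, running_mean, running_var, momentum, eps, count_all)

# normalize element-wise (diopiBatchNormElemt)
y = torch.batch_norm_elemt(x, weight, bias, global_mean, global_invstd, eps)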
2 changes: 2 additions & 0 deletions scripts/ci/ci_one_iter.sh
@@ -1,6 +1,8 @@
#!/bin/bash

function clone_needed_repo() {

set -e
# clone some repositories

#define some version
158 changes: 158 additions & 0 deletions tests/test_ops/archived/test_SyncBN.py
@@ -0,0 +1,158 @@
import unittest
import torch
import os
os.environ['DIPU_MOCK_CUDA'] = "False"
import torch_dipu
dipu = torch_dipu.dipu.diputype
device_cuda = torch.device("cuda")
assert device_cuda.type == "cuda"
device_dipu = torch.device(dipu)


# Currently these test cases only support CUDA; on other devices the tests are skipped.
# TODO: save and load baseline data so that the test cases can also run on other devices.
class TestSchema(unittest.TestCase):

def test_batch_norm_stats(self):
if (torch.cuda.is_available() is False):
return
x_cuda = torch.randn([5, 5]).to(device_cuda)
z1_mean, z1_invstd = torch.batch_norm_stats(x_cuda, 1e-5)
x_dipu = x_cuda.to(device_dipu)
z2_mean, z2_invstd = torch.batch_norm_stats(x_dipu, 1e-5)

self.assertTrue(torch.allclose(z1_mean.cpu(), z2_mean.cpu()))
self.assertTrue(torch.allclose(z1_invstd.cpu(), z2_invstd.cpu()))

def test_batch_norm_gather_stats_with_counts(self):
if (torch.cuda.is_available() is False):
return
workpiece = 7
input = torch.rand(2, 8, 32, 56, 56)
mean_all = torch.rand(workpiece, 8)
invstd_all = torch.rand(workpiece, 8)
running_mean = torch.rand(8)
running_var = torch.rand(8)
momentum = 1e-4
eps = 1e-5
count_all = torch.rand(workpiece * 8)
res1 = self._test_bng(input, mean_all, invstd_all, running_mean, running_var, momentum, eps, count_all, device_cuda)
res2 = self._test_bng(input, mean_all, invstd_all, running_mean, running_var, momentum, eps, count_all, device_dipu)
self._test_res(res1, res2)

def test_batch_norm_backward_elemt(self):
if (torch.cuda.is_available() is False):
return
input = torch.rand(2, 8, 32, 56, 56)
mean = torch.rand(8)
invstd = torch.rand(8)
weight = torch.rand(8)
grad_out = torch.rand(2, 8, 32, 56, 56)
sum_dy = torch.rand(8)
sum_dy_xmu = torch.rand(8)
count_tensor = torch.tensor([5, 5, 4, 4, 3, 1, 5, 7], dtype=torch.int32)
res1 = self._test_bnbe(input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor, grad_out, device_cuda)
res2 = self._test_bnbe(input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor, grad_out, device_dipu)
self._test_res(res1, res2)

def test_batch_norm_elemt(self):
if (torch.cuda.is_available() is False):
return
input = torch.rand(2, 8, 32, 56, 56)
mean = torch.rand(8)
invstd = torch.rand(8)
weight = torch.rand(8)
bias = torch.rand(8)
eps = 1e-5
res1 = self._test_bne(input, weight, bias, mean, invstd, eps, device_cuda)
res2 = self._test_bne(input, weight, bias, mean, invstd, eps, device_dipu)
self._test_res(res1, res2)

def test_batch_norm_backward_reduce(self):
if (torch.cuda.is_available() is False):
return
input = torch.rand(2, 8, 32, 56, 56)
mean = torch.rand(8)
invstd = torch.rand(8)
weight = torch.rand(8)
grad_out = torch.rand(2, 8, 32, 56, 56)
res1 = self._test_bnbr(input, mean, invstd, weight, grad_out, device_cuda)
res2 = self._test_bnbr(input, mean, invstd, weight, grad_out, device_dipu)
self._test_res(res1, res2)

def _test_bnbr(self, input, mean_all, invstd_all, weight, grad_out, device):
input_d = input.to(device)
mean_all_d = mean_all.to(device)
invstd_all_d = invstd_all.to(device)
weight_d = weight.to(device)
grad_out_d = grad_out.to(device)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
grad_out_d,
input_d,
mean_all_d,
invstd_all_d,
weight_d,
True,
True,
True,
)
return [sum_dy, sum_dy_xmu, grad_weight, grad_bias]

def _test_bne(self, input, weight, bias, mean, invstd, eps, device):
input_d = input.to(device)
mean_d = mean.to(device)
invstd_d = invstd.to(device)
weight_d = weight.to(device)
bias_d = bias.to(device)
out = torch.batch_norm_elemt(input_d, weight_d, bias_d, mean_d, invstd_d, eps)
return [out]

def _test_bng(self, input, mean_all, invstd_all, running_mean, running_var, momentum, eps, count_all, device):
input_d = input.to(device)
mean_all_d = mean_all.to(device)
invstd_all_d = invstd_all.to(device)
running_mean_d = running_mean.to(device)
running_var_d = running_var.to(device)
momentum_d = momentum
eps_d = eps
count_all_d = count_all.to(device)
mean_d, invstd_d = torch.batch_norm_gather_stats_with_counts(
input_d,
mean_all_d,
invstd_all_d,
running_mean_d,
running_var_d,
momentum_d,
eps_d,
count_all_d
)
return [mean_d, invstd_d]

def _test_bnbe(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor, grad_out, device):
input_d = input.to(device)
mean_d = mean.to(device)
invstd_d = invstd.to(device)
weight_d = weight.to(device)
grad_out_d = grad_out.to(device)
sum_dy_d = sum_dy.to(device)
sum_dy_xmu_d = sum_dy_xmu.to(device)
count_tensor_d = count_tensor.to(device)
grad_input = torch.batch_norm_backward_elemt(
grad_out_d,
input_d,
mean_d,
invstd_d,
weight_d,
sum_dy_d,
sum_dy_xmu_d,
count_tensor_d
)
return [grad_input]

def _test_res(self, res1, res2):
for i in range(len(res1)):
self.assertTrue(torch.allclose(res1[i].cpu(), res2[i].cpu()))


if __name__ == '__main__':
unittest.main()
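
The backward primitives exercised above chain together the same way inside SyncBatchNorm's backward: the per-element gradients are first reduced to per-channel terms, then the input gradient is recovered element-wise. A minimal single-process sketch under the same assumptions as the forward example (illustrative names, vendor DIOPI support for the new ops required):

import torch
import torch_dipu  # assumption: DIPU build that includes the new DIOPI batch-norm ops

device = torch.device(torch_dipu.dipu.diputype)
x = torch.rand(2, 8, 32, 56, 56, device=device)
grad_out = torch.rand_like(x)
weight = torch.rand(8, device=device)
eps = 1e-5

mean, invstd = torch.batch_norm_stats(x, eps)

# reduce per-element gradients to per-channel terms (diopiBatchNormBackwardReduce)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
    grad_out, x, mean, invstd, weight, True, True, True)

# elements contributing per channel on this "rank"; int32, as in the test above
count = torch.tensor([x.numel() // x.size(1)], dtype=torch.int32, device=device)

# recover the input gradient (diopiBatchNormBackwardElemt)
grad_input = torch.batch_norm_backward_elemt(
    grad_out, x, mean, invstd, weight, sum_dy, sum_dy_xmu, count)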
2 changes: 1 addition & 1 deletion third_party/DIOPI
Submodule DIOPI updated 61 files
+0 −1 .gitattributes
+20 −3 .github/workflows/main.yml
+3 −0 .gitignore
+2 −2 README.md
+13 −4 adaptor/codegen/gen.py
+50 −2 adaptor/codegen/op_template.py
+0 −1 diopi_test/csrc/litert.cpp
+1 −0 diopi_test/python/conformance/config.py
+2 −0 diopi_test/python/conformance/conformance_test.py
+1,011 −148 diopi_test/python/conformance/diopi_configs.py
+123 −46 diopi_test/python/conformance/diopi_functions.py
+1 −1 diopi_test/python/conformance/diopi_runtime.py
+35 −3 diopi_test/python/conformance/gen_data.py
+0 −0 diopi_test/python/op_time.dat
+12 −2 impl/CMakeLists.txt
+1 −1 impl/ascend/CMakeLists.txt
+179 −102 impl/ascend/common/acloprunner.hpp
+123 −0 impl/ascend/common/utils.cpp
+62 −0 impl/ascend/functions/activation.cpp
+42 −0 impl/ascend/functions/batch_norm.cpp
+135 −0 impl/ascend/functions/binary.cpp
+0 −60 impl/ascend/functions/binary_op.cpp
+20 −0 impl/ascend/functions/cast.cpp
+59 −0 impl/ascend/functions/cat.cpp
+30 −23 impl/ascend/functions/conv2d.cpp
+45 −0 impl/ascend/functions/copy.cpp
+61 −0 impl/ascend/functions/dropout.cpp
+6 −0 impl/ascend/functions/error.cpp
+7 −1 impl/ascend/functions/fill.cpp
+22 −0 impl/ascend/functions/flip.cpp
+27 −0 impl/ascend/functions/less.cpp
+46 −0 impl/ascend/functions/linear.cpp
+96 −0 impl/ascend/functions/loss.cpp
+24 −0 impl/ascend/functions/masked_fill.cpp
+31 −0 impl/ascend/functions/one_hot.cpp
+57 −0 impl/ascend/functions/pool.cpp
+75 −0 impl/ascend/functions/reduce.cpp
+39 −0 impl/ascend/functions/threshold.cpp
+45 −0 impl/ascend/functions/unary.cpp
+9 −3 impl/ascend/test/conform_test.cpp
+4 −5 impl/camb/CMakeLists.txt
+21 −44 impl/camb/cmake/FindNeuware.cmake
+6 −0 impl/camb/cnnl_helper.cpp
+0 −1 impl/camb/cnnl_helper.hpp
+1 −1 impl/camb/common/basic_op.cpp
+0 −1 impl/camb/common/common.hpp
+37 −15 impl/camb/common/debug.hpp
+53 −23 impl/camb/common/dtype_cast.cpp
+1 −1 impl/camb/convert_config.yaml
+151 −28 impl/camb/device_configs.py
+84 −11 impl/camb/diopi_helper.hpp
+66 −0 impl/camb/functions/adadelta.cpp
+22 −10 impl/camb/functions/bitwise.cpp
+1 −1 impl/camb/functions/fill.cpp
+43 −0 impl/camb/functions/polar.cpp
+5 −3 impl/scripts/build_impl.sh
+0 −3 impl/topsrider/prebuilt/libops.so
+1 −1 impl/torch/CMakeLists.txt
+80 −0 impl/torch/functions.cpp
+4 −2 proto/include/diopi/diopirt.h
+87 −0 proto/include/diopi/functions.h