add generator #255

Merged · 29 commits · Aug 30, 2023

Commits (29)
e2dfb10
add generator
caikun-pjlab Aug 16, 2023
bfc0f00
camb support generator
caikun-pjlab Aug 17, 2023
36c68ed
add normal
caikun-pjlab Aug 17, 2023
145f1d8
add other random op
caikun-pjlab Aug 18, 2023
a01ce60
merge main
caikun-pjlab Aug 18, 2023
bf61cdb
add cuda generator
caikun-pjlab Aug 19, 2023
be1d6e1
autogen support generator
caikun-pjlab Aug 19, 2023
a613389
fix camb generator
caikun-pjlab Aug 19, 2023
219b67d
fix bug
caikun-pjlab Aug 20, 2023
5714707
optimize code
caikun-pjlab Aug 21, 2023
6f3c377
remove useless header and log
caikun-pjlab Aug 21, 2023
2a3a461
fix format
caikun-pjlab Aug 21, 2023
c950879
add torch.Generator mock
caikun-pjlab Aug 21, 2023
74c5f65
fix generator testcase
caikun-pjlab Aug 21, 2023
9fe3add
update diopi
caikun-pjlab Aug 23, 2023
acc751a
Merge branch 'main' into caikun/dipu_generator
caikun-pjlab Aug 23, 2023
68cf118
update diopi
caikun-pjlab Aug 23, 2023
7a230a9
update parameter
caikun-pjlab Aug 24, 2023
83549be
update diopi
caikun-pjlab Aug 24, 2023
1092346
update diopi
caikun-pjlab Aug 24, 2023
0fd5fe2
release generator before release memory
caikun-pjlab Aug 24, 2023
c0af287
Merge branch 'main' into caikun/dipu_generator
caikun-pjlab Aug 24, 2023
0fe9e76
dropout support generator
caikun-pjlab Aug 28, 2023
246814d
update diopi
caikun-pjlab Aug 29, 2023
4a1f9ed
update DIOPI
caikun-pjlab Aug 29, 2023
fb0c6f5
merge main
caikun-pjlab Aug 29, 2023
b309e48
update diopi
caikun-pjlab Aug 29, 2023
2be91ad
fix comments
caikun-pjlab Aug 29, 2023
2d54167
fix test and compile
caikun-pjlab Aug 29, 2023
24 changes: 21 additions & 3 deletions scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
@@ -249,6 +249,12 @@ def get_function_optional_scalar_args_from_schema(schema):
return re.findall('Scalar *\? +([\w\d_]+)', param_list)


def get_function_optional_generator_args_from_schema(schema):
param_list = schema[schema.find('(') + 1 : schema.find('->')].strip()
param_list = param_list[0:param_list.rfind(')')]
return re.findall('Generator *\? +([\w\d_]+)', param_list)


def get_function_int_array_args_from_schema(schema):
param_list = create_param_list_from_schema(schema)
int_arrays = []
@@ -515,9 +521,20 @@ def create_device_check_code(fun_config):
if len(tensors) > 0:
code += "}"


return code

def create_optional_generator_process_code(arg_name):
process_template = CodeTemplate(
"""
::diopiGeneratorHandle_t ${arg_name}DiopiGenerator = (${arg_name}.has_value() && ${arg_name}.value().defined()) ? toDiopiGeneratorHandle(${arg_name}) : toDiopiGeneratorHandle(getDefaultDIPUGenerator());
[Collaborator review comment] Suggested alternative:
::diopiGeneratorHandle_t ${arg_name}DiopiGenerator = toDiopiGeneratorHandle((${arg_name}.has_value() && ${arg_name}.value().defined()) ? ${arg_name} : getDefaultDIPUGenerator());

"""
)
process_code = process_template.substitute(
arg_name=[arg_name],
)
return process_code


file_template = CodeTemplate(diopi_wrapper_file_template_content)

fun_template = CodeTemplate(diopi_wrapper_function_template_content)
@@ -575,8 +592,9 @@ def functions_code_gen(fun_config):
attrs_process_code += create_optional_scalar_process_code(scalar_param)
diopi_fun_call_code = re.sub('([,\(] *&? *)' + scalar_param.strip() + '( *[,\)])', R'\1' + f"{scalar_param}DiopiScalarPtr" + R'\2', diopi_fun_call_code)



for generator_param in get_function_optional_generator_args_from_schema(fun_config['schema']):
attrs_process_code += create_optional_generator_process_code(generator_param)
diopi_fun_call_code = re.sub('([,\(] *&? *)' + generator_param.strip() + '( *[,\)])', R'\1' + f"{generator_param}DiopiGenerator" + R'\2', diopi_fun_call_code)

int_array_list = get_function_int_array_args_from_schema(fun_config['schema'])
attrs_process_code += create_int_array_process_code(int_array_list)
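As a reading aid (not part of the PR), here is a minimal, self-contained sketch of how the two new pieces in autogen_diopi_wrapper.py work together: the schema regex collects the names of optional `Generator?` arguments, and the `re.sub` in `functions_code_gen` rewrites the DIOPI call so those arguments refer to the generated `<name>DiopiGenerator` handle. The `uniform_` schema and `diopiUniformInp` call are borrowed from the YAML below.

import re

def get_function_optional_generator_args_from_schema(schema):
    # Same logic as the helper added above: take the text between '(' and '->',
    # drop the trailing ')', and collect every "Generator?" parameter name.
    param_list = schema[schema.find('(') + 1 : schema.find('->')].strip()
    param_list = param_list[0:param_list.rfind(')')]
    return re.findall(r'Generator *\? +([\w\d_]+)', param_list)

schema = ("uniform_(Tensor(a!) self, float from=0, float to=1, *, "
          "Generator? generator=None) -> Tensor(a!)")
diopi_fun_call_code = "diopiUniformInp(ctx, self, from, to, generator)"

for generator_param in get_function_optional_generator_args_from_schema(schema):
    # Mirrors the substitution in functions_code_gen: rename the argument in
    # the DIOPI call to the <name>DiopiGenerator handle created for it.
    diopi_fun_call_code = re.sub(
        r'([,\(] *&? *)' + generator_param.strip() + r'( *[,\)])',
        r'\1' + f"{generator_param}DiopiGenerator" + r'\2',
        diopi_fun_call_code)

print(diopi_fun_call_code)
# -> diopiUniformInp(ctx, self, from, to, generatorDiopiGenerator)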
35 changes: 21 additions & 14 deletions scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -385,11 +385,13 @@

- schema: "randperm.out(int n, *, Tensor(a!) out) -> Tensor(a!)"
autocompare: disable
interface: diopiRandperm(ctx, out, n)
custom_code_at_the_beginning: |
diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
interface: diopiRandperm(ctx, out, n, generatorDiopiGenerator)

- schema: "randperm.generator_out(int n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)"
autocompare: disable
interface: diopiRandperm(ctx, out, n)
interface: diopiRandperm(ctx, out, n, generator)

- schema: "aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)"
custom_code_at_the_beginning: |
@@ -592,15 +594,17 @@
- schema: "dropout_impl(Tensor input, float p, bool train, *, Tensor(a!) mask) -> Tensor"
custom_code_at_the_beginning: |
at::Tensor out = at::empty_like(input);
diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
[Collaborator review comment] Could toDiopiGeneratorHandle() provide a no-argument overload? Then this line would not need to be written out explicitly.

register_op: False
interface: diopiDropout(ctx, out, mask, input, p, train)
interface: diopiDropout(ctx, out, mask, input, p, train, generatorDiopiGenerator)
[Collaborator review comment] Suggested alternative:
interface: diopiDropout(ctx, out, mask, input, p, train, getDefalutDiopiGenerator())


- schema: "dropout(Tensor input, float p, bool train) -> Tensor"
autocompare: disable
custom_code_at_the_beginning: |
auto mask = at::empty(input.sizes(), input.options().dtype(at::kByte));
at::Tensor out = at::empty_like(input);
interface: diopiDropout(ctx, out, mask, input, p, train)
diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
interface: diopiDropout(ctx, out, mask, input, p, train, generatorDiopiGenerator)
outs: [mask]
autograd: True
saved_data: [p, mask]
@@ -620,14 +624,17 @@
return outputs;

- schema: "dropout__impl(Tensor(a!) self, Tensor(b!) mask, float p, bool train) -> Tensor(a!)"
custom_code_at_the_beginning: |
diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
register_op: False
interface: diopiDropoutInp(ctx, self, mask, p, train)
interface: diopiDropoutInp(ctx, self, mask, p, train, generatorDiopiGenerator)

- schema: "dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)"
custom_code_at_the_beginning: |
auto mask = at::empty(self.sizes(), self.options().dtype(at::kByte));
diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
outs: [mask]
interface: diopiDropoutInp(ctx, self, mask, p, train)
interface: diopiDropoutInp(ctx, self, mask, p, train, generatorDiopiGenerator)
autograd: True
forward_process_code: |
auto mask = at::empty(self.sizes(), self.options().dtype(at::kByte));
@@ -918,7 +925,7 @@
interface: diopiRsqrt(ctx, out, self)

- schema: "uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)"
interface: diopiUniformInp(ctx, self, from, to)
interface: diopiUniformInp(ctx, self, from, to, generator)

- schema: "tril(Tensor self, int diagonal=0) -> Tensor"
custom_code_at_the_beginning: |
@@ -937,10 +944,10 @@
else if (self.dim() == 1) {
out = at::empty({num_samples,}, self.options().dtype(at::kLong));
}
interface: diopiMultinomial(ctx, out, self, num_samples, replacement)
interface: diopiMultinomial(ctx, out, self, num_samples, replacement, generator)

- schema: "multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)"
interface: diopiMultinomial(ctx, out, self, num_samples, replacement)
interface: diopiMultinomial(ctx, out, self, num_samples, replacement, generator)

- schema: "roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"
custom_code_at_the_beginning: |
@@ -1000,15 +1007,15 @@

- schema: "random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)"
autocompare: disable
interface: diopiRandomInp(ctx, self, 0, nullptr)
interface: diopiRandomInp(ctx, self, 0, nullptr, generator)

- schema: "random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)"
autocompare: disable
interface: diopiRandomInp(ctx, self, 0, &to)
interface: diopiRandomInp(ctx, self, 0, &to, generator)

- schema: "random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)"
autocompare: disable
interface: "diopiRandomInp(ctx, self, from, to.has_value() ? &to.value() : nullptr)"
interface: "diopiRandomInp(ctx, self, from, to.has_value() ? &to.value() : nullptr, generator)"

- schema: "nonzero(Tensor self) -> Tensor"
custom_code_at_the_beginning: |
@@ -1338,7 +1345,7 @@

- schema: "normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)"
autocompare: disable
interface: diopiNormalInp(ctx, self, mean, std)
interface: diopiNormalInp(ctx, self, mean, std, generator)

- schema: "mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"
interface: diopiMm(ctx, out, self, mat2)
@@ -1749,4 +1756,4 @@
- schema: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor
custom_code_at_the_beginning: |
auto out = at::empty_like(input);
interface: diopiBatchNormElemt(ctx, out, input, weight, bias, mean, invstd, eps);
interface: diopiBatchNormElemt(ctx, out, input, weight, bias, mean, invstd, eps);
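Note (illustration only, not part of the PR): for schemas above that declare `Generator? generator=None` (uniform_, multinomial, random_, normal_), the autogen emits a handle-selection line that falls back to the default DIPU generator when no generator is supplied, which is why the YAML can simply pass `generator` into the DIOPI interface. A rough sketch of that substitution, using Python's string.Template as a stand-in for the project's CodeTemplate (an assumption made only to keep the snippet self-contained):

from string import Template

# Stand-in for the template string shown in create_optional_generator_process_code.
process_template = Template(
    "::diopiGeneratorHandle_t ${arg_name}DiopiGenerator = "
    "(${arg_name}.has_value() && ${arg_name}.value().defined()) "
    "? toDiopiGeneratorHandle(${arg_name}) "
    ": toDiopiGeneratorHandle(getDefaultDIPUGenerator());"
)

print(process_template.substitute(arg_name="generator"))
# Emits (wrapped here for readability):
# ::diopiGeneratorHandle_t generatorDiopiGenerator =
#     (generator.has_value() && generator.value().defined())
#     ? toDiopiGeneratorHandle(generator)
#     : toDiopiGeneratorHandle(getDefaultDIPUGenerator());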
2 changes: 2 additions & 0 deletions scripts/autogen_diopi_wrapper/diopi_wrapper_template.py
@@ -20,6 +20,8 @@

namespace dipu::native {

using dipu::diopi_helper::toDiopiGeneratorHandle;

inline bool checkDiopiReturnValue() {
static bool enable = std::getenv("DIPU_DISABLE_CHECK_DIOPI_RETURN_VALUE") == nullptr;
return enable;
180 changes: 180 additions & 0 deletions tests/test_ops/archived/test_generator.py
@@ -0,0 +1,180 @@
# Copyright (c) 2023, DeepLink.
import torch
import torch_dipu

from torch_dipu.testing._internal.common_utils import create_common_tensor, TestCase, run_tests


class TestGenerator(TestCase):
def test_python_api(self):
torch.seed()
torch.cuda.seed_all()
torch.cuda.random.seed_all()
torch.cuda.manual_seed_all(1)
rngs = torch.cuda.get_rng_state_all()
torch.cuda.set_rng_state_all(rngs)
torch.manual_seed(1)
assert torch.cuda.initial_seed() == 1
assert torch.initial_seed() == 1
for i in range(torch.cuda.device_count()):
torch.cuda.manual_seed(i)

state = torch.cuda.get_rng_state(0)
new_state = torch.ones_like(state)
torch.cuda.set_rng_state(new_state, 0)
current_state = torch.cuda.get_rng_state(0)
assert torch.allclose(current_state, torch.tensor(1, device=current_state.device, dtype=current_state.dtype))

def test_torch_generator(self):
gen = torch.Generator()
assert gen.device.type == 'cpu'
gen.manual_seed(1)
assert gen.initial_seed() == 1

gen = torch.Generator("cpu")
assert gen.device.type == 'cpu'

gen = torch.Generator("cuda")
assert gen.device.type == 'xpu'

gen = torch.Generator("cuda:0")
assert gen.device == torch.device('xpu:0')

gen = torch.Generator("dipu")
assert gen.device.type == 'xpu'
gen.manual_seed(1)
assert gen.initial_seed() == 1

def test_randn_with_generator(self):
gen = torch.Generator()
gen.manual_seed(1)
data1 = torch.randn(2, 3, generator = gen)
gen.manual_seed(1)
data2 = torch.randn(2, 3, generator = gen)
assert torch.allclose(data1, data2)
data2 = torch.randn(2, 3, generator = gen)
assert not torch.allclose(data1, data2)

gen = torch.Generator('cuda')
gen.manual_seed(1)
data1 = torch.randn(2, 3, generator = gen, device = 'cuda')
gen.manual_seed(1)
data2 = torch.randn(2, 3, generator = gen, device = 'cuda')
assert torch.allclose(data1, data2)
data2 = torch.randn(2, 3, generator = gen, device = 'cuda')
assert not torch.allclose(data1, data2)

def test_uniform_(self):
t1 = torch.arange(0, 100, dtype=torch.float32).cuda()
t2 = t1.clone()
torch.manual_seed(1)
t1.uniform_()
torch.manual_seed(1)
t2.uniform_()
assert torch.allclose(t1, t2)
t2.uniform_()
assert not torch.allclose(t1, t2)
print("uniform_ allclose success")

def test_normal_(self):
t1 = torch.arange(0, 100, dtype=torch.float32).cuda()
t2 = t1.clone()
torch.manual_seed(1)
t1.normal_()
torch.manual_seed(1)
t2.normal_()
assert torch.allclose(t1, t2)
t2.normal_()
assert not torch.allclose(t1, t2)
print("normal_ allclose success")

def test_random_(self):
t1 = torch.arange(0, 100, dtype=torch.float32).cuda()
t2 = t1.clone()
torch.manual_seed(1)
t1.random_(0, 100)
torch.manual_seed(1)
t2.random_(0, 100)
assert torch.allclose(t1, t2)
t2.random_(0, 100)
assert not torch.allclose(t1, t2)

torch.manual_seed(1)
t1.random_()
torch.manual_seed(1)
t2.random_()
assert torch.allclose(t1, t2)
t2.random_()
assert not torch.allclose(t1, t2)
print("random_ allclose success")

def test_multinomial(self):
data = torch.arange(0, 100, dtype=torch.float).cuda()
torch.manual_seed(1)
data1 = torch.multinomial(data, 10)
torch.manual_seed(1)
data2 = torch.multinomial(data, 10)
assert torch.allclose(data1, data2)
data2 = torch.multinomial(data, 10)
assert not torch.allclose(data1, data2)
print("multinomial allclose success")

def test_randn(self):
torch.manual_seed(1)
t1 = torch.randn(100, device='cuda')
torch.manual_seed(1)
t2 = torch.randn(100, device='cuda')
assert torch.allclose(t1, t2)
t2 = torch.randn(100, device='cuda')
assert not torch.allclose(t1, t2)
print("randn allclose success")

def test_randperm(self):
if torch_dipu.dipu.vendor_type == "MLU":
return

torch.manual_seed(1)
t1 = torch.randperm(100, device='cuda')
torch.manual_seed(1)
t2 = torch.randperm(100, device='cuda')
assert torch.allclose(t1, t2)
t2 = torch.randperm(100, device='cuda')
assert not torch.allclose(t1, t2)
print("randperm allclose success")

def test_dropout(self):
m = torch.nn.Dropout(p=0.2).cuda()
input = torch.randn(20, 16).cuda()
torch.manual_seed(1)
t1 = m(input)
torch.manual_seed(1)
t2 = m(input)
assert torch.allclose(t1, t2)
t2 = m(input)
assert not torch.allclose(t1, t2)
print("dropout allclose success")

def test_dropout_(self):
m = torch.nn.Dropout(p=0.2, inplace=True).cuda()
input = torch.randn(20, 16).cuda()
p = 0.2
torch.manual_seed(1)
t1 = input.clone()
m(t1)
torch.manual_seed(1)
t2 = input.clone()
m(t2)
assert torch.allclose(t1, t2)
t2 = input.clone()
m(t2)
assert not torch.allclose(t1, t2)
print("dropout_ allclose success")

def test_default_generators(self):
assert len(torch.cuda.default_generators) > 0
torch.cuda.default_generators[0].manual_seed(1)
assert torch.cuda.default_generators[0].initial_seed() == 1


if __name__ == "__main__":
run_tests()
2 changes: 1 addition & 1 deletion third_party/DIOPI
Submodule DIOPI updated 45 files
+1 −1 CODEOWNERS
+17 −0 diopi_test/csrc/export_runtime.cpp
+58 −3 diopi_test/csrc/litert.cpp
+2 −0 diopi_test/include/conform_test.h
+15 −0 diopi_test/include/litert.hpp
+27 −10 diopi_test/python/conformance/conformance_test.py
+792 −209 diopi_test/python/conformance/diopi_configs.py
+54 −38 diopi_test/python/conformance/diopi_functions.py
+14 −8 diopi_test/python/conformance/diopi_runtime.py
+23 −20 impl/ascend/common/acloprunner.hpp
+74 −32 impl/ascend/common/utils.cpp
+20 −2 impl/ascend/functions/activation.cpp
+3 −4 impl/ascend/functions/conv2d.cpp
+57 −0 impl/ascend/functions/layer_norm.cpp
+4 −1 impl/ascend/functions/linear.cpp
+181 −121 impl/ascend/functions/loss.cpp
+49 −0 impl/ascend/functions/norm.cpp
+0 −19 impl/ascend/functions/pool.cpp
+0 −14 impl/ascend/functions/reduce.cpp
+8 −2 impl/ascend/functions/transpose.cpp
+0 −4 impl/camb/common/contiguous.cpp
+6 −1 impl/camb/common/dtype_cast.cpp
+238 −45 impl/camb/device_configs.py
+9 −46 impl/camb/diopi_helper.cpp
+2 −3 impl/camb/diopi_helper.hpp
+13 −14 impl/camb/functions/adaptive_pooling.cpp
+2 −2 impl/camb/functions/cdist.cpp
+12 −21 impl/camb/functions/dropout.cpp
+169 −0 impl/camb/functions/groupnorm.cpp
+144 −29 impl/camb/functions/loss.cpp
+9 −3 impl/camb/functions/multinomial.cpp
+13 −22 impl/camb/functions/normal.cpp
+10 −5 impl/camb/functions/random.cpp
+1 −1 impl/camb/functions/randperm.cpp
+4 −4 impl/camb/functions/scatter.cpp
+45 −0 impl/camb/functions/sgn.cpp
+12 −20 impl/camb/functions/uniform.cpp
+14 −33 impl/camb/functions_mmcv/roi_align_mlu.cpp
+1 −1 impl/camb/test/CMakeLists.txt
+25 −0 impl/camb/test/conform_test.cpp
+64 −31 impl/torch/functions.cpp
+35 −1 impl/torch/helper.hpp
+13 −0 impl/torch/test/conform_test.cpp
+7 −0 proto/include/diopi/diopirt.h
+32 −14 proto/include/diopi/functions.h
5 changes: 5 additions & 0 deletions torch_dipu/__init__.py
@@ -22,6 +22,7 @@
from .profiler.profiler import dipu_profiler, dipu_kineto_available
from .dipu.dataloader import apply_dataloader_patch
from .dipu.optim import apply_optim_patch
from .dipu.generator import apply_generator_patch

# mock device functions in generated/python_variable_methods.cpp
def apply_tensor_method_patch():
@@ -81,6 +82,9 @@ def apply_torch_function_patch():
if hasattr(torch.cuda, attr):
setattr(torch.cuda, attr, getattr(dipu, attr))

if attr in torch.cuda.random.__all__ and hasattr(torch.cuda.random, attr):
setattr(torch.cuda.random, attr, getattr(dipu, attr))


# temp solution, need redesign storage
def apply_temp_patch():
@@ -104,6 +108,7 @@ def apply_patches():
apply_temp_patch()
apply_dataloader_patch()
apply_optim_patch()
apply_generator_patch()


apply_patches()
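A rough sketch of the patching idea extended in apply_torch_function_patch, with a hypothetical fake_dipu namespace standing in for the torch_dipu.dipu module (assumption for illustration only): the same attribute is redirected on torch.cuda and, when it appears in torch.cuda.random.__all__, on torch.cuda.random as well, so seeding calls through either path reach the DIPU implementation.

import types
import torch

# Hypothetical stand-in for torch_dipu.dipu; the real patch uses the actual module.
fake_dipu = types.SimpleNamespace(
    manual_seed_all=lambda seed: print(f"dipu manual_seed_all({seed})"))

attr = "manual_seed_all"
if hasattr(torch.cuda, attr):
    setattr(torch.cuda, attr, getattr(fake_dipu, attr))
# The new lines in this diff extend the patch to torch.cuda.random as well.
if attr in torch.cuda.random.__all__ and hasattr(torch.cuda.random, attr):
    setattr(torch.cuda.random, attr, getattr(fake_dipu, attr))

torch.cuda.manual_seed_all(1)         # routed to the stand-in
torch.cuda.random.manual_seed_all(2)  # routed to the stand-in as well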
2 changes: 2 additions & 0 deletions torch_dipu/csrc_dipu/base/DIPUGlobals.cpp
@@ -1,6 +1,7 @@
#include "DIPUGlobals.h"
#include "csrc_dipu/runtime/core/allocator/DIPUCachingAllocator.h"
#include "csrc_dipu/runtime/core/DIPUEventPool.h"
#include "csrc_dipu/runtime/core/DIPUGeneratorImpl.h"
#include "csrc_dipu/aten/RegisterDIPU.hpp"
#include <iostream>
#include <ctime>
@@ -24,6 +25,7 @@ void initResource() {
}

void releaseAllResources() {
releaseAllGenerator();
releaseAllDeviceMem();
releaseAllEvent();
devproxy::finalizeVendor();