add generator #255
Merged
29 commits:
e2dfb10  add generator (caikun-pjlab)
bfc0f00  camb support generator (caikun-pjlab)
36c68ed  add normal (caikun-pjlab)
145f1d8  add other random op (caikun-pjlab)
a01ce60  merge main (caikun-pjlab)
bf61cdb  add cuda generator (caikun-pjlab)
be1d6e1  autogen support generator (caikun-pjlab)
a613389  fix camb generator (caikun-pjlab)
219b67d  fix bug (caikun-pjlab)
5714707  optimize code (caikun-pjlab)
6f3c377  remove useless header and log (caikun-pjlab)
2a3a461  fix format (caikun-pjlab)
c950879  add torch.Generator mock (caikun-pjlab)
74c5f65  fix generator testcase (caikun-pjlab)
9fe3add  update diopi (caikun-pjlab)
acc751a  Merge branch 'main' into caikun/dipu_generator (caikun-pjlab)
68cf118  update diopi (caikun-pjlab)
7a230a9  update parameter (caikun-pjlab)
83549be  update diopi (caikun-pjlab)
1092346  update diopi (caikun-pjlab)
0fd5fe2  release generator before release memory (caikun-pjlab)
c0af287  Merge branch 'main' into caikun/dipu_generator (caikun-pjlab)
0fe9e76  dropout support generator (caikun-pjlab)
246814d  update diopi (caikun-pjlab)
4a1f9ed  update DIOPI (caikun-pjlab)
fb0c6f5  merge main (caikun-pjlab)
b309e48  update diopi (caikun-pjlab)
2be91ad  fix comments (caikun-pjlab)
2d54167  fix test and compile (caikun-pjlab)
Files changed:
@@ -385,11 +385,13 @@
 - schema: "randperm.out(int n, *, Tensor(a!) out) -> Tensor(a!)"
   autocompare: disable
-  interface: diopiRandperm(ctx, out, n)
+  custom_code_at_the_beginning: |
+    diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
+  interface: diopiRandperm(ctx, out, n, generatorDiopiGenerator)

 - schema: "randperm.generator_out(int n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)"
   autocompare: disable
-  interface: diopiRandperm(ctx, out, n)
+  interface: diopiRandperm(ctx, out, n, generator)

 - schema: "aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)"
   custom_code_at_the_beginning: |
@@ -592,15 +594,17 @@
 - schema: "dropout_impl(Tensor input, float p, bool train, *, Tensor(a!) mask) -> Tensor"
   custom_code_at_the_beginning: |
     at::Tensor out = at::empty_like(input);
+    diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());

Review comment: Could toDiopiGeneratorHandle() provide a zero-argument overload? Then this line would not need to be written out explicitly.

   register_op: False
-  interface: diopiDropout(ctx, out, mask, input, p, train)
+  interface: diopiDropout(ctx, out, mask, input, p, train, generatorDiopiGenerator)

Review comment (suggested alternative): interface: diopiDropout(ctx, out, mask, input, p, train, getDefalutDiopiGenerator())

 - schema: "dropout(Tensor input, float p, bool train) -> Tensor"
   autocompare: disable
   custom_code_at_the_beginning: |
     auto mask = at::empty(input.sizes(), input.options().dtype(at::kByte));
     at::Tensor out = at::empty_like(input);
-  interface: diopiDropout(ctx, out, mask, input, p, train)
+    diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
+  interface: diopiDropout(ctx, out, mask, input, p, train, generatorDiopiGenerator)
   outs: [mask]
   autograd: True
   saved_data: [p, mask]
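A minimal sketch of the zero-argument overload suggested in the review comments above (a hypothetical helper, not part of this PR; it only composes the two helpers already used in the YAML):

```cpp
// Hypothetical convenience overload (sketch only): forwards the default DIPU generator
// to the existing conversion helper, so the YAML entries would not need the explicit
// custom_code_at_the_beginning line that fetches it.
inline diopiGeneratorHandle_t toDiopiGeneratorHandle() {
  return toDiopiGeneratorHandle(getDefaultDIPUGenerator());
}
```

With such an overload the entries above could be written as interface: diopiDropout(ctx, out, mask, input, p, train, toDiopiGeneratorHandle()).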
@@ -620,14 +624,17 @@
     return outputs;

 - schema: "dropout__impl(Tensor(a!) self, Tensor(b!) mask, float p, bool train) -> Tensor(a!)"
+  custom_code_at_the_beginning: |
+    diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
   register_op: False
-  interface: diopiDropoutInp(ctx, self, mask, p, train)
+  interface: diopiDropoutInp(ctx, self, mask, p, train, generatorDiopiGenerator)

 - schema: "dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)"
   custom_code_at_the_beginning: |
     auto mask = at::empty(self.sizes(), self.options().dtype(at::kByte));
+    diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
   outs: [mask]
-  interface: diopiDropoutInp(ctx, self, mask, p, train)
+  interface: diopiDropoutInp(ctx, self, mask, p, train, generatorDiopiGenerator)
   autograd: True
   forward_process_code: |
     auto mask = at::empty(self.sizes(), self.options().dtype(at::kByte));
@@ -918,7 +925,7 @@
   interface: diopiRsqrt(ctx, out, self)

 - schema: "uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)"
-  interface: diopiUniformInp(ctx, self, from, to)
+  interface: diopiUniformInp(ctx, self, from, to, generator)

 - schema: "tril(Tensor self, int diagonal=0) -> Tensor"
   custom_code_at_the_beginning: |
@@ -937,10 +944,10 @@
     else if (self.dim() == 1) {
      out = at::empty({num_samples,}, self.options().dtype(at::kLong));
     }
-  interface: diopiMultinomial(ctx, out, self, num_samples, replacement)
+  interface: diopiMultinomial(ctx, out, self, num_samples, replacement, generator)

 - schema: "multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)"
-  interface: diopiMultinomial(ctx, out, self, num_samples, replacement)
+  interface: diopiMultinomial(ctx, out, self, num_samples, replacement, generator)

 - schema: "roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"
   custom_code_at_the_beginning: |
@@ -1000,15 +1007,15 @@
 - schema: "random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)"
   autocompare: disable
-  interface: diopiRandomInp(ctx, self, 0, nullptr)
+  interface: diopiRandomInp(ctx, self, 0, nullptr, generator)

 - schema: "random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)"
   autocompare: disable
-  interface: diopiRandomInp(ctx, self, 0, &to)
+  interface: diopiRandomInp(ctx, self, 0, &to, generator)

 - schema: "random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)"
   autocompare: disable
-  interface: "diopiRandomInp(ctx, self, from, to.has_value() ? &to.value() : nullptr)"
+  interface: "diopiRandomInp(ctx, self, from, to.has_value() ? &to.value() : nullptr, generator)"

 - schema: "nonzero(Tensor self) -> Tensor"
   custom_code_at_the_beginning: |
@@ -1338,7 +1345,7 @@
 - schema: "normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)"
   autocompare: disable
-  interface: diopiNormalInp(ctx, self, mean, std)
+  interface: diopiNormalInp(ctx, self, mean, std, generator)

 - schema: "mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"
   interface: diopiMm(ctx, out, self, mat2)
@@ -1749,4 +1756,4 @@
 - schema: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor
   custom_code_at_the_beginning: |
     auto out = at::empty_like(input);
-  interface: diopiBatchNormElemt(ctx, out, input, weight, bias, mean, invstd, eps);
+  interface: diopiBatchNormElemt(ctx, out, input, weight, bias, mean, invstd, eps);

(The removed and added lines are identical; this hunk only changes the trailing newline at the end of the file.)
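All of the YAML hunks above follow one pattern: each random op's DIOPI call now receives a generator handle as its final argument, taken either from the op's own Generator? schema parameter or, when the schema has none, from the default DIPU generator. A rough C++ sketch of that pattern, using only names that appear in the hunks (the function name is hypothetical; the real wrappers are generated from the YAML, not written by hand):

```cpp
// Illustrative sketch only, not the autogen output. For an op whose ATen schema has no
// Generator parameter, the wrapper fetches the default DIPU generator, converts it to a
// DIOPI handle, and appends it to the original argument list.
diopiError_t randpermWithDefaultGenerator(diopiContextHandle_t ctx,
                                          diopiTensorHandle_t out, int64_t n) {
  diopiGeneratorHandle_t gen = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
  return diopiRandperm(ctx, out, n, gen);  // previously: diopiRandperm(ctx, out, n)
}
```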
New test file (Python):

@@ -0,0 +1,180 @@
# Copyright (c) 2023, DeepLink.
import torch
import torch_dipu

from torch_dipu.testing._internal.common_utils import create_common_tensor, TestCase, run_tests


class TestGenerator(TestCase):
    def test_python_api(self):
        torch.seed()
        torch.cuda.seed_all()
        torch.cuda.random.seed_all()
        torch.cuda.manual_seed_all(1)
        rngs = torch.cuda.get_rng_state_all()
        torch.cuda.set_rng_state_all(rngs)
        torch.manual_seed(1)
        assert torch.cuda.initial_seed() == 1
        assert torch.initial_seed() == 1
        for i in range(torch.cuda.device_count()):
            torch.cuda.manual_seed(i)

        state = torch.cuda.get_rng_state(0)
        new_state = torch.ones_like(state)
        torch.cuda.set_rng_state(new_state, 0)
        current_state = torch.cuda.get_rng_state(0)
        assert torch.allclose(current_state, torch.tensor(1, device=current_state.device, dtype=current_state.dtype))

    def test_torch_generator(self):
        gen = torch.Generator()
        assert gen.device.type == 'cpu'
        gen.manual_seed(1)
        assert gen.initial_seed() == 1

        gen = torch.Generator("cpu")
        assert gen.device.type == 'cpu'

        gen = torch.Generator("cuda")
        assert gen.device.type == 'xpu'

        gen = torch.Generator("cuda:0")
        assert gen.device == torch.device('xpu:0')

        gen = torch.Generator("dipu")
        assert gen.device.type == 'xpu'
        gen.manual_seed(1)
        assert gen.initial_seed() == 1

    def test_randn_with_generator(self):
        gen = torch.Generator()
        gen.manual_seed(1)
        data1 = torch.randn(2, 3, generator = gen)
        gen.manual_seed(1)
        data2 = torch.randn(2, 3, generator = gen)
        assert torch.allclose(data1, data2)
        data2 = torch.randn(2, 3, generator = gen)
        assert not torch.allclose(data1, data2)

        gen = torch.Generator('cuda')
        gen.manual_seed(1)
        data1 = torch.randn(2, 3, generator = gen, device = 'cuda')
        gen.manual_seed(1)
        data2 = torch.randn(2, 3, generator = gen, device = 'cuda')
        assert torch.allclose(data1, data2)
        data2 = torch.randn(2, 3, generator = gen, device = 'cuda')
        assert not torch.allclose(data1, data2)

    def test_uniform_(self):
        t1 = torch.arange(0, 100, dtype=torch.float32).cuda()
        t2 = t1.clone()
        torch.manual_seed(1)
        t1.uniform_()
        torch.manual_seed(1)
        t2.uniform_()
        assert torch.allclose(t1, t2)
        t2.uniform_()
        assert not torch.allclose(t1, t2)
        print("uniform_ allclose success")

    def test_normal_(self):
        t1 = torch.arange(0, 100, dtype=torch.float32).cuda()
        t2 = t1.clone()
        torch.manual_seed(1)
        t1.normal_()
        torch.manual_seed(1)
        t2.normal_()
        assert torch.allclose(t1, t2)
        t2.normal_()
        assert not torch.allclose(t1, t2)
        print("normal_ allclose success")

    def test_random_(self):
        t1 = torch.arange(0, 100, dtype=torch.float32).cuda()
        t2 = t1.clone()
        torch.manual_seed(1)
        t1.random_(0, 100)
        torch.manual_seed(1)
        t2.random_(0, 100)
        assert torch.allclose(t1, t2)
        t2.random_(0, 100)
        assert not torch.allclose(t1, t2)

        torch.manual_seed(1)
        t1.random_()
        torch.manual_seed(1)
        t2.random_()
        assert torch.allclose(t1, t2)
        t2.random_()
        assert not torch.allclose(t1, t2)
        print("random_ allclose success")

    def test_multinomial(self):
        data = torch.arange(0, 100, dtype=torch.float).cuda()
        torch.manual_seed(1)
        data1 = torch.multinomial(data, 10)
        torch.manual_seed(1)
        data2 = torch.multinomial(data, 10)
        assert torch.allclose(data1, data2)
        data2 = torch.multinomial(data, 10)
        assert not torch.allclose(data1, data2)
        print("multinomial allclose success")

    def test_randn(self):
        torch.manual_seed(1)
        t1 = torch.randn(100, device='cuda')
        torch.manual_seed(1)
        t2 = torch.randn(100, device='cuda')
        assert torch.allclose(t1, t2)
        t2 = torch.randn(100, device='cuda')
        assert not torch.allclose(t1, t2)
        print("randn allclose success")

    def test_randperm(self):
        if torch_dipu.dipu.vendor_type == "MLU":
            return

        torch.manual_seed(1)
        t1 = torch.randperm(100, device='cuda')
        torch.manual_seed(1)
        t2 = torch.randperm(100, device='cuda')
        assert torch.allclose(t1, t2)
        t2 = torch.randperm(100, device='cuda')
        assert not torch.allclose(t1, t2)
        print("randperm allclose success")

    def test_dropout(self):
        m = torch.nn.Dropout(p=0.2).cuda()
        input = torch.randn(20, 16).cuda()
        torch.manual_seed(1)
        t1 = m(input)
        torch.manual_seed(1)
        t2 = m(input)
        assert torch.allclose(t1, t2)
        t2 = m(input)
        assert not torch.allclose(t1, t2)
        print("dropout allclose success")

    def test_dropout_(self):
        m = torch.nn.Dropout(p=0.2, inplace=True).cuda()
        input = torch.randn(20, 16).cuda()
        p = 0.2
        torch.manual_seed(1)
        t1 = input.clone()
        m(t1)
        torch.manual_seed(1)
        t2 = input.clone()
        m(t2)
        assert torch.allclose(t1, t2)
        t2 = input.clone()
        m(t2)
        assert not torch.allclose(t1, t2)
        print("dropout_ allclose success")

    def test_default_generators(self):
        assert len(torch.cuda.default_generators) > 0
        torch.cuda.default_generators[0].manual_seed(1)
        assert torch.cuda.default_generators[0].initial_seed() == 1


if __name__ == "__main__":
    run_tests()
Submodule DIOPI updated (45 files changed).
Review comment (on the autogen template):
::diopiGeneratorHandle_t ${arg_name}DiopiGenerator = toDiopiGeneratorHandle((${arg_name}.has_value() && ${arg_name}.value().defined()) ? ${arg_name} : getDefaultDIPUGenerator());
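For readability, here is what that template expands to for a schema parameter named generator, assuming the parameter arrives as an optional ATen generator (this is an illustrative expansion of the line above, not code taken from the PR):

```cpp
// Hypothetical expansion of the template for ${arg_name} == "generator".
// When the optional generator is absent or undefined, the default DIPU generator is used.
::diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(
    (generator.has_value() && generator.value().defined()) ? generator
                                                           : getDefaultDIPUGenerator());
```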