-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add generator #255
add generator #255
Conversation
@@ -37,15 +39,13 @@ class MLUGeneratorImpl : public dipu::DIPUGeneratorImpl { | |||
*/ | |||
void init_state() const override { | |||
// resize and set the state tensor. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从目前 camb 和 cuda 的实现看, 是否可以改成能获取 get_state_size + setState 两个基础函数就可以了? 或者只是 camb 实现一个 intState? update 的可以直接删掉了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update_state必须存在,当更新了seed之后,就需要update_state;init_state可以删除,现在厂商只需要实现update_state和set_state即可
m.def("_is_in_bad_fork", []()->bool { return is_in_bad_fork(); }); | ||
|
||
m.def("_create_dipu_generator", [](int idx)->at::Generator { | ||
at::DeviceIndex index = static_cast<at::DeviceIndex>(idx); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
default generator 是否要提前初始化出来,参考 csrc/cuda/Module.cpp THCPModule_initExtension() getDefaultCUDAGenerator() 的逻辑。 否则相关接口都调不通
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
增加了default_generators的提前初始化
torch_dipu/dipu/random_dipu.py
Outdated
default_generator.seed() | ||
random_seed = default_generator.initial_seed() | ||
_C._seed(i) | ||
random_seed = _C._initial_seed() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个调用会报错,这是个get 型函数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
再加下测例吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修复,并增加了测试,没有传递device index导致的报错
def create_optional_generator_process_code(arg_name): | ||
process_template = CodeTemplate( | ||
""" | ||
::diopiGeneratorHandle_t ${arg_name}DiopiGenerator = (${arg_name}.has_value() && ${arg_name}.value().defined()) ? toDiopiGeneratorHandle(${arg_name}) : toDiopiGeneratorHandle(getDefaultDIPUGenerator()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
::diopiGeneratorHandle_t ${arg_name}DiopiGenerator = toDiopiGeneratorHandle((${arg_name}.has_value() && ${arg_name}.value().defined()) ? ${arg_name} : getDefaultDIPUGenerator());
@@ -592,15 +594,17 @@ | |||
- schema: "dropout_impl(Tensor input, float p, bool train, *, Tensor(a!) mask) -> Tensor" | |||
custom_code_at_the_beginning: | | |||
at::Tensor out = at::empty_like(input); | |||
diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
toDiopiGeneratorHandle() 是否可以提供一个空参数的版本?这行就可以不用显示写出来了
register_op: False | ||
interface: diopiDropout(ctx, out, mask, input, p, train) | ||
interface: diopiDropout(ctx, out, mask, input, p, train, generatorDiopiGenerator) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
interface: diopiDropout(ctx, out, mask, input, p, train, getDefalutDiopiGenerator())
No description provided.