
add generator #255

Merged: 29 commits merged into main from caikun/dipu_generator on Aug 30, 2023
Conversation

caikun-pjlab (Contributor)

No description provided.

@@ -37,15 +39,13 @@ class MLUGeneratorImpl : public dipu::DIPUGeneratorImpl {
*/
void init_state() const override {
// resize and set the state tensor.
Collaborator

Looking at the current camb and cuda implementations, could this be changed so that just two basic functions, get_state_size + setState, are enough? Or have only camb implement an initState? The update one could then simply be removed.

caikun-pjlab (Contributor Author), Aug 30, 2023

update_state must exist: once the seed is updated, update_state is needed. init_state can be removed; vendors now only need to implement update_state and set_state.
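
For context, a minimal sketch of what a vendor-side generator could look like under that scheme, overriding only update_state and set_state. The signatures, members, and state size here are assumptions for illustration based on the MLUGeneratorImpl snippet above, not the actual DIPU interface.

// Sketch only: a hypothetical vendor generator implementing just the two hooks
// discussed above; the real DIPUGeneratorImpl signatures may differ.
class VendorGeneratorImpl : public dipu::DIPUGeneratorImpl {
 public:
  explicit VendorGeneratorImpl(at::DeviceIndex device_index)
      : dipu::DIPUGeneratorImpl(device_index) {}

  // Copy a user-supplied RNG state tensor into this generator.
  void set_state(const at::Tensor& new_state) const override {
    TORCH_CHECK(new_state.numel() == kStateSize, "unexpected RNG state size");
    state_ = new_state.clone();
    state_valid_ = true;
  }

  // Rebuild the cached state after the seed changed (e.g. via manual_seed),
  // so the next DIOPI call sees a state derived from the new seed.
  void update_state() const override {
    if (!state_valid_) {
      state_ = at::empty({kStateSize}, at::dtype(at::kByte));
    }
    // ... fill state_ from current_seed() using the vendor's RNG state layout ...
    state_valid_ = true;
  }

 private:
  static constexpr int64_t kStateSize = 16;  // hypothetical state size
  mutable at::Tensor state_;                 // hypothetical cached state tensor
  mutable bool state_valid_ = false;         // hypothetical "state is current" flag
};

The idea, per the comment above, is that update_state runs whenever the seed changes and set_state runs when a caller installs an explicit RNG state, so init_state is no longer required.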

m.def("_is_in_bad_fork", []()->bool { return is_in_bad_fork(); });

m.def("_create_dipu_generator", [](int idx)->at::Generator {
  at::DeviceIndex index = static_cast<at::DeviceIndex>(idx);
Collaborator

Should the default generators be initialized ahead of time? See the logic of THCPModule_initExtension() / getDefaultCUDAGenerator() in csrc/cuda/Module.cpp; otherwise none of the related interfaces can be called.

caikun-pjlab (Contributor Author)

Added up-front initialization of default_generators.
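
For illustration, a sketch of what that up-front initialization might look like, following the THCPModule_initExtension() / getDefaultCUDAGenerator() pattern referenced above. The container name default_generators comes from the comment; the device-count helper and createDIPUGenerator factory are assumptions.

// Sketch only: create one default generator per device at module init so that
// getDefaultDIPUGenerator(index) and _create_dipu_generator always find an
// initialized generator.
static std::vector<at::Generator> default_generators;

void initDefaultGenerators() {
  const int device_count = getDeviceCount();  // assumed device-count helper
  default_generators.resize(device_count);
  for (int idx = 0; idx < device_count; ++idx) {
    default_generators[idx] =
        createDIPUGenerator(static_cast<at::DeviceIndex>(idx));  // assumed factory
  }
}

getDefaultDIPUGenerator(index) can then simply return default_generators[index] after a bounds check, so the Python-facing seed and state interfaces below work before any generator has been created explicitly.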

default_generator.seed()
random_seed = default_generator.initial_seed()
_C._seed(i)
random_seed = _C._initial_seed()
Collaborator

This call will raise an error; it is a get-style (getter) function.

Collaborator

Please also add a test case.

caikun-pjlab (Contributor Author)

Fixed, and a test has been added; the error was caused by not passing the device index.

def create_optional_generator_process_code(arg_name):
    process_template = CodeTemplate(
        """
::diopiGeneratorHandle_t ${arg_name}DiopiGenerator = (${arg_name}.has_value() && ${arg_name}.value().defined()) ? toDiopiGeneratorHandle(${arg_name}) : toDiopiGeneratorHandle(getDefaultDIPUGenerator());
Collaborator

::diopiGeneratorHandle_t ${arg_name}DiopiGenerator = toDiopiGeneratorHandle((${arg_name}.has_value() && ${arg_name}.value().defined()) ? ${arg_name} : getDefaultDIPUGenerator());
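
For illustration, assuming CodeTemplate performs plain ${...} substitution, the suggested template would expand for a hypothetical optional argument named generator to roughly:

// Hypothetical expansion for arg_name == "generator" (illustrative only).
::diopiGeneratorHandle_t generatorDiopiGenerator =
    toDiopiGeneratorHandle((generator.has_value() && generator.value().defined())
                               ? generator
                               : getDefaultDIPUGenerator());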

@@ -592,15 +594,17 @@
- schema: "dropout_impl(Tensor input, float p, bool train, *, Tensor(a!) mask) -> Tensor"
  custom_code_at_the_beginning: |
    at::Tensor out = at::empty_like(input);
    diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator());
Collaborator

Could toDiopiGeneratorHandle() provide a zero-argument overload? Then this line would not need to be written out explicitly.
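
A minimal sketch of the zero-argument overload being suggested, assuming the existing single-argument toDiopiGeneratorHandle and getDefaultDIPUGenerator() shown above:

// Sketch only: convenience overload that falls back to the default DIPU
// generator, so call sites can just write toDiopiGeneratorHandle().
inline ::diopiGeneratorHandle_t toDiopiGeneratorHandle() {
  return toDiopiGeneratorHandle(getDefaultDIPUGenerator());
}

With such an overload, the generated custom_code_at_the_beginning line above would no longer need to spell out getDefaultDIPUGenerator().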

register_op: False
interface: diopiDropout(ctx, out, mask, input, p, train)
interface: diopiDropout(ctx, out, mask, input, p, train, generatorDiopiGenerator)
Collaborator

interface: diopiDropout(ctx, out, mask, input, p, train, getDefalutDiopiGenerator())

@caikun-pjlab merged commit 78362d4 into main on Aug 30, 2023
9 of 11 checks passed
LeungChiNan pushed a commit to DeepLink-org/deeplink.framework.dev that referenced this pull request on Dec 8, 2023
@mrdanielw deleted the caikun/dipu_generator branch on December 11, 2023