Skip to content

Commit

Permalink
diopi support profiler (DeepLink-org#390)
Browse files Browse the repository at this point in the history
* diopi support profiler

* code format

* format code

* optimize according to comments

* add comments
  • Loading branch information
caikun-pjlab authored Sep 13, 2023
1 parent 217212a commit 813255e
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 0 deletions.
4 changes: 4 additions & 0 deletions diopi_test/diopi_stub/csrc/litert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,4 +343,8 @@ DIOPI_RT_API diopiError_t diopiGeneratorSetState(diopiGeneratorHandle_t th, diop
return diopiSuccess;
}

DIOPI_RT_API diopiError_t diopiRecordStart(const char* recordName, void** record) { return diopiSuccess; }

DIOPI_RT_API diopiError_t diopiRecordEnd(void** record) { return diopiSuccess; }

} // extern "C"
9 changes: 9 additions & 0 deletions impl/camb/functions/conv_2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,19 +238,28 @@ diopiError_t diopiConvolution2dBackward(diopiContextHandle_t ctx, diopiTensorHan

if (gradWeightTensor.defined()) {
REQUIRES_TENSOR_BY_DTYPE_OR_NOT(gradWeightTensorTmp, gradWeightTensor, inputTensor.dtype());
void *record = nullptr;
DIOPI_CALL(diopiRecordStart("convBackwardFilter", &record));
DIOPI_CALL(convBackwardFilter(ctx, gradOutputTensor, gradWeightTensorTmp, inputTensor, stride, padding, dilation, groups));
DIOPI_CALL(diopiRecordEnd(&record));
DIOPI_CALL(dataTypeCast(ctx, gradWeightTensor, gradWeightTensorTmp));
}

if (gradInputTensor.defined()) {
REQUIRES_TENSOR_BY_DTYPE_OR_NOT(gradInputTensorTmp, gradInputTensor, inputTensor.dtype());
void *record = nullptr;
DIOPI_CALL(diopiRecordStart("convBackwardData", &record));
DIOPI_CALL(convBackwardData(ctx, gradOutputTensor, gradInputTensorTmp, weightTensor, stride, padding, dilation, groups));
DIOPI_CALL(diopiRecordEnd(&record));
DIOPI_CALL(dataTypeCast(ctx, gradInputTensor, gradInputTensorTmp));
}

if (grad3 != nullptr) {
REQUIRES_TENSOR_BY_DTYPE_OR_NOT(gradBiasTensorTmp, gradBiasTensor, inputTensor.dtype());
void *record = nullptr;
DIOPI_CALL(diopiRecordStart("convBackwardBias", &record));
DIOPI_CALL(convBackwardBias(ctx, gradOutputTensor, gradBiasTensorTmp));
DIOPI_CALL(diopiRecordEnd(&record));
DIOPI_CALL(dataTypeCast(ctx, gradBiasTensor, gradBiasTensorTmp))
}

Expand Down
3 changes: 3 additions & 0 deletions impl/torch/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,11 @@ diopiError_t diopiConvolution2d(diopiContextHandle_t ctx, diopiTensorHandle_t ou
auto atStride = impl::aten::buildAtIntArray(stride);
auto atPadding = impl::aten::buildAtIntArray(padding);
auto atDilation = impl::aten::buildAtIntArray(dilation);
void* record = nullptr;
diopiRecordStart("at::convolution_out", &record);
impl::aten::invokeATenFuncInp(
ctx, at::convolution_out, atOut, atInput, atWeight, atBias, atStride, atPadding, atDilation, false, at::IntArrayRef(0), groups);
diopiRecordEnd(&record);
impl::aten::unsetCurCtx();
return diopiSuccess;
}
Expand Down
7 changes: 7 additions & 0 deletions proto/include/diopi/diopirt.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ extern DIOPI_RT_API DIOPI_ATTR_WEEK diopiError_t diopiRequireBuffer(diopiContext
extern DIOPI_RT_API DIOPI_ATTR_WEEK diopiError_t diopiGeneratorGetState(diopiContextHandle_t ctx, diopiConstGeneratorHandle_t th, diopiTensorHandle_t* data);
extern DIOPI_RT_API DIOPI_ATTR_WEEK diopiError_t diopiGeneratorSetState(diopiGeneratorHandle_t th, diopiConstTensorHandle_t state);

/**
* operations to manipulate profiler record objects.
* Call diopiRecordStart at the beginning of code that you want to profile and call diopiRecordEnd at the end.
**/
extern DIOPI_RT_API DIOPI_ATTR_WEEK diopiError_t diopiRecordStart(const char* record_name, void** record);
extern DIOPI_RT_API DIOPI_ATTR_WEEK diopiError_t diopiRecordEnd(void** record);

#if defined(__cplusplus)
}
#endif
Expand Down

0 comments on commit 813255e

Please sign in to comment.