Skip to content

Commit

Permalink
Zmz/matmul12 (DeepLink-org#869)
Browse files Browse the repository at this point in the history
* support matmul
  • Loading branch information
hellozmz authored Jan 17, 2024
1 parent 03056c2 commit bc77f24
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 2 deletions.
3 changes: 3 additions & 0 deletions impl/ascend/convert_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@
- diopiHardswishBackward:
dtype: (float64)->float32

- diopiMatmul:
dtype: (float64)->float32

- diopiAtan:
dtype: (uint8, int8, int32, int16, int64, bool)->float32

Expand Down
2 changes: 1 addition & 1 deletion impl/ascend/device_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@
args=[
{
"ins": ['input'],
"shape": [Skip((128, 49, 128)),Skip((5,)),Skip((128, 4, 49, 32)),Skip((2, 1, 3136, 3136)),Skip((2, 784, 64)),Skip((2, 16, 8, 64)),Skip((2, 31, 6, 40, 512)),],
"shape": [Skip((2, 31, 6, 40, 512)),],
},
]
),
Expand Down
1 change: 1 addition & 0 deletions impl/ascend_npu/ascend_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ ascend:
- diopiScatterScalar
- diopiScatterInpScalar
ascend_npu:
- diopiMatmul
- diopiCastDtype
- diopiCopyInp
- diopiCat
Expand Down
18 changes: 18 additions & 0 deletions impl/ascend_npu/diopi_impl/matmul.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
* @file
* @author DeepLink
* @copyright (c) 2023, DeepLink.
*/

#include "helper.hpp"
#include "op_plugin/AclOpsInterface.h"

namespace OP_IMPL_NS {

/// DIOPI entry point for matrix multiplication on Ascend: computes
/// `input @ other` and writes the result into the caller-allocated `out`.
///
/// @param ctx   DIOPI context handle (device/stream state).
/// @param out   Destination tensor; assumed pre-sized for the matmul result
///              — TODO confirm against the DIOPI spec / callers.
/// @param input Left-hand operand (read-only).
/// @param other Right-hand operand (read-only).
/// @return diopiError_t — NOTE(review): the return statement is produced by
///         the END_CALL_ACL_OP() macro; verify its error mapping in helper.hpp.
diopiError_t diopiMatmul(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t other) {
    // Macro converts each DIOPI handle into an at::Tensor named `<arg>At`
    // (inputAt/outAt/otherAt are used below), presumably also setting up
    // error handling for the ACL call — verify in helper.hpp.
    BEGIN_CALL_ACL_OP(input, out, other);
    // Dispatch to the Ascend ACL matmul kernel; result is written in place
    // into outAt (which aliases `out`).
    acl_op::matmul_out(inputAt, otherAt, outAt);
    END_CALL_ACL_OP();
}

} // namespace OP_IMPL_NS
28 changes: 27 additions & 1 deletion impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2895,7 +2895,24 @@ at::Tensor viewStorage(const at::Tensor input, const c10::IntArrayRef sizes, con
if (st != -1) st *= sizes[i - 1];
}
}
return fromPreAllocated(input.data_ptr() + storageOffset * input.itemsize(), sizes, stridesVec, input.options());

// when shape[0]=-1, fill data
std::vector<int64_t> sizeVec(sizes.size(), 1);
std::copy(sizes.begin(), sizes.end(), sizeVec.begin());
if (!sizes.empty() && sizes[0] == -1) {
bool flag = true;
for (auto i : sizes) {
if (!flag && i < 0) {
TORCH_CHECK(false, "more than one -1, sizes=", sizes);
}
if (i < 0) {
flag = false;
}
}
int count = std::accumulate(sizeVec.begin() + 1, sizeVec.end(), 1, std::multiplies<int>());
sizeVec[0] = input.numel() / count;
}
return fromPreAllocated(input.data_ptr() + storageOffset * input.itemsize(), sizeVec, stridesVec, input.options());
}

c10::List<c10::optional<at::Tensor>> castIntIndicesToLongIndices(const c10::List<c10::optional<at::Tensor>>& indices) {
Expand Down Expand Up @@ -3057,7 +3074,11 @@ at::Tensor wrapper__transpose(const at::Tensor& self, int64_t dim0, int64_t dim1
}

// Forwards aten `_local_scalar_dense` (extract a 0-dim tensor's value as a Scalar) to the NPU native implementation.
at::Scalar wrapper___local_scalar_dense(const at::Tensor& self) { return at_npu::native::NPUNativeFunctions::_local_scalar_dense(self); }
// Forwards aten `mm.out` (2-D matrix multiply into a pre-allocated `out`) to the Ascend ACL kernel.
at::Tensor& wrapper_out_mm_out(const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { return acl_op::mm_out(self, mat2, out); }

// Forwards aten `set_.source_Tensor` (in-place: make `self` share `source`'s storage/metadata) to the NPU native implementation.
at::Tensor& wrapper_source_Tensor_set_(at::Tensor& self, const at::Tensor& source) { return at_npu::native::NPUNativeFunctions::set_(self, source); }
// Forwards aten `bmm.out` (batched matrix multiply into a pre-allocated `out`) to the Ascend ACL kernel.
at::Tensor& wrapper_out_bmm_out(const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { return acl_op::bmm_out(self, mat2, out); }
// Forwards aten `dot` (1-D inner product) to the Ascend ACL kernel; returns a new result tensor.
at::Tensor wrapper__dot(const at::Tensor& self, const at::Tensor& tensor) { return acl_op::dot(self, tensor); }
} // namespace

namespace at {
Expand Down Expand Up @@ -3092,6 +3113,11 @@ TORCH_LIBRARY_IMPL(aten, XLA, m) {
m.impl("repeat", TORCH_FN(wrapper__repeat));
m.impl("transpose.int", TORCH_FN(wrapper__transpose));
m.impl("_local_scalar_dense", TORCH_FN(wrapper___local_scalar_dense));
m.impl("cat", TORCH_FN(wrapper__cat));
m.impl("mm.out", TORCH_FN(wrapper_out_mm_out));
m.impl("set_.source_Tensor", TORCH_FN(wrapper_source_Tensor_set_));
m.impl("dot", TORCH_FN(wrapper__dot));
m.impl("bmm.out", TORCH_FN(wrapper_out_bmm_out));
};

// Catch-all: any aten op on the XLA dispatch key (used here for Ascend) without an explicit
// registration above is routed through the boxed `ascend_diopi_fallback` handler.
TORCH_LIBRARY_IMPL(_, XLA, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&ascend_diopi_fallback>()); }
Expand Down

0 comments on commit bc77f24

Please sign in to comment.