Skip to content

Commit 39e361f

Browse files
committed
ggml-hexagon:add skelton code of offload GGML_OP_SOFT_MAX/GGML_OP_RMS_NORM/GGML_OP_POOL_2D to Hexagon cDSP
1 parent d0aaf11 commit 39e361f

File tree

6 files changed

+151
-372
lines changed

6 files changed

+151
-372
lines changed

ggml/src/ggml-hexagon/ggml-hexagon.cpp

+29-14
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,9 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
598598
{true, GGML_OP_ADD, 2, "ggmlop_dsp_add", ggmlop_dsp_add},
599599
{false, GGML_OP_ADD1, 0, nullptr, nullptr},
600600
{false, GGML_OP_ACC, 0, nullptr, nullptr},
601-
{true, GGML_OP_SUB, 2, "ggmlop_dsp_sub", ggmlop_dsp_sub},
602-
{true, GGML_OP_MUL, 2, "ggmlop_dsp_mul", ggmlop_dsp_mul},
603-
{true, GGML_OP_DIV, 2, "ggmlop_dsp_div", ggmlop_dsp_div},
601+
{false, GGML_OP_SUB, 2, nullptr, nullptr},
602+
{false, GGML_OP_MUL, 2, nullptr, nullptr},
603+
{false, GGML_OP_DIV, 2, nullptr, nullptr},
604604
{false, GGML_OP_SQR, 0, nullptr, nullptr},
605605
{false, GGML_OP_SQRT, 0, nullptr, nullptr},
606606
{false, GGML_OP_LOG, 0, nullptr, nullptr},
@@ -616,7 +616,7 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
616616
{false, GGML_OP_CONCAT, 0, nullptr, nullptr},
617617
{false, GGML_OP_SILU_BACK, 0, nullptr, nullptr},
618618
{false, GGML_OP_NORM, 0, nullptr, nullptr},
619-
{false, GGML_OP_RMS_NORM, 0, nullptr, nullptr},
619+
{true, GGML_OP_RMS_NORM, 1, "ggmlop_dsp_rmsnorm", ggmlop_dsp_rmsnorm},
620620
{false, GGML_OP_RMS_NORM_BACK, 0, nullptr, nullptr},
621621
{false, GGML_OP_GROUP_NORM, 0, nullptr, nullptr},
622622
{false, GGML_OP_L2_NORM, 0, nullptr, nullptr},
@@ -636,7 +636,7 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
636636
{false, GGML_OP_DIAG, 0, nullptr, nullptr},
637637
{false, GGML_OP_DIAG_MASK_INF, 0, nullptr, nullptr},
638638
{false, GGML_OP_DIAG_MASK_ZERO, 0, nullptr, nullptr},
639-
{false, GGML_OP_SOFT_MAX, 0, nullptr, nullptr},
639+
{true, GGML_OP_SOFT_MAX, 1, "ggmlop_dsp_softmax", ggmlop_dsp_softmax},
640640
{false, GGML_OP_SOFT_MAX_BACK, 0, nullptr, nullptr},
641641
{false, GGML_OP_ROPE, 0, nullptr, nullptr},
642642
{false, GGML_OP_ROPE_BACK, 0, nullptr, nullptr},
@@ -646,7 +646,7 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
646646
{false, GGML_OP_IM2COL_BACK, 0, nullptr, nullptr},
647647
{false, GGML_OP_CONV_TRANSPOSE_2D, 0, nullptr, nullptr},
648648
{false, GGML_OP_POOL_1D, 0, nullptr, nullptr},
649-
{false, GGML_OP_POOL_2D, 0, nullptr, nullptr},
649+
{true, GGML_OP_POOL_2D, 1, "ggmlop_dsp_pool2d", ggmlop_dsp_pool2d},
650650
{false, GGML_OP_POOL_2D_BACK, 0, nullptr, nullptr},
651651
{false, GGML_OP_UPSCALE, 0, nullptr, nullptr},
652652
{false, GGML_OP_PAD, 0, nullptr, nullptr},
@@ -694,10 +694,10 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
694694
{false, static_cast<ggml_op>(GGML_UNARY_OP_EXP), 0, nullptr, nullptr}
695695
};
696696

697-
static_assert(ggmlhexagon_k_op_caps[GGML_OP_NONE].supported, "GGML_OP_NONE is not true");
698-
static_assert(ggmlhexagon_k_op_caps[GGML_OP_ADD].supported, "GGML_OP_ADD is not true");
699-
static_assert(ggmlhexagon_k_op_caps[GGML_OP_MUL].supported, "GGML_OP_MUL is not true");
700-
static_assert(ggmlhexagon_k_op_caps[GGML_OP_MUL_MAT].supported, "GGML_OP_MUL_MAT is not true");
697+
static_assert(ggmlhexagon_k_op_caps[GGML_OP_NONE].supported, "GGML_OP_NONE is not true");
698+
static_assert(ggmlhexagon_k_op_caps[GGML_OP_ADD].supported, "GGML_OP_ADD is not true");
699+
static_assert(ggmlhexagon_k_op_caps[GGML_OP_MUL_MAT].supported, "GGML_OP_MUL_MAT is not true");
700+
static_assert(ggmlhexagon_k_op_caps[GGML_OP_SOFT_MAX].supported, "GGML_OP_SOFT_MAX is not true");
701701
static_assert(std::size(ggmlhexagon_k_op_caps) == (static_cast<size_t>(GGML_OP_COUNT) + static_cast<size_t>(GGML_UNARY_OP_COUNT)),
702702
"pls check ggmlhexagon_k_op_caps and ensure is corresponding to latest ggml.h");
703703

@@ -5018,6 +5018,7 @@ static void ggmlhexagon_compute(ggml_backend_hexagon_context * ctx, struct ggml_
50185018
dsptensor_0.nb[3] = src0->nb[3];
50195019

50205020
if (2 == input_tensor_count) {
5021+
GGML_ASSERT(nullptr != src1);
50215022
dsptensor_1.data = src1->data;
50225023
dsptensor_1.type = src1->type;
50235024
dsptensor_1.data_len = ggml_nbytes(src1);
@@ -5047,6 +5048,8 @@ static void ggmlhexagon_compute(ggml_backend_hexagon_context * ctx, struct ggml_
50475048
dsptensor_2.nb[2] = dst->nb[2];
50485049
dsptensor_2.nb[3] = dst->nb[3];
50495050

5051+
memcpy(dsptensor_2.op_params, dst->op_params, GGML_MAX_OP_PARAMS / sizeof(int32_t));
5052+
50505053
hexagon_error = op_func(ctx->ggmlop_handle, &dsptensor_0, &dsptensor_1, &dsptensor_2);
50515054
if (AEE_SUCCESS != hexagon_error) {
50525055
GGMLHEXAGON_LOG_WARN("ggmlop %s computation fail on cdsp", ggml_op_name(op->op));
@@ -5078,19 +5081,31 @@ static bool ggmlhexagon_can_handle_op_through_cdsp(ggml_backend_dev_t dev, const
50785081
if (!ggml_are_same_shape(src0, src1)) {
50795082
return false;
50805083
}
5081-
5082-
//TODO: offload quantize GGML_OP_ADD to cDSP
5083-
return ggmlhexagon_same_types(ctx, op_tensor);
5084+
return (src0->type == GGML_TYPE_F32) && (src1->type == GGML_TYPE_F32) && (op_tensor->type == GGML_TYPE_F32);
50845085
}
50855086
case GGML_OP_MUL_MAT:
50865087
{
50875088
ggmlhexagon_dump_op_info(op_tensor);
50885089
if (1 == g_hexagon_appcfg.enable_q_mulmat)
5089-
return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_Q6_K
5090+
return (src0->type == GGML_TYPE_F32
5091+
|| src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q8_0
5092+
|| src0->type == GGML_TYPE_Q6_K || src0->type == GGML_TYPE_Q8_K
50905093
) && (src1->type == GGML_TYPE_F32) && (op_tensor->type == GGML_TYPE_F32);
50915094
else
50925095
return (src0->type == GGML_TYPE_F32) && (src1->type == GGML_TYPE_F32) && (op_tensor->type == GGML_TYPE_F32);
50935096
}
5097+
case GGML_OP_SOFT_MAX:{
5098+
if (!ggml_is_contiguous(op_tensor))
5099+
return false;
5100+
if (!ggml_are_same_shape(src0, op_tensor))
5101+
return false;
5102+
}
5103+
case GGML_OP_RMS_NORM:
5104+
case GGML_OP_POOL_2D:
5105+
{
5106+
5107+
ggmlhexagon_dump_op_info(op_tensor);
5108+
}
50945109
default:
50955110
break;
50965111
}

0 commit comments

Comments
 (0)