@@ -598,9 +598,9 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
     {true,  GGML_OP_ADD, 2, "ggmlop_dsp_add", ggmlop_dsp_add},
     {false, GGML_OP_ADD1, 0, nullptr, nullptr},
     {false, GGML_OP_ACC, 0, nullptr, nullptr},
-    {true,  GGML_OP_SUB, 2, "ggmlop_dsp_sub", ggmlop_dsp_sub},
-    {true,  GGML_OP_MUL, 2, "ggmlop_dsp_mul", ggmlop_dsp_mul},
-    {true,  GGML_OP_DIV, 2, "ggmlop_dsp_div", ggmlop_dsp_div},
+    {false, GGML_OP_SUB, 2, nullptr, nullptr},
+    {false, GGML_OP_MUL, 2, nullptr, nullptr},
+    {false, GGML_OP_DIV, 2, nullptr, nullptr},
     {false, GGML_OP_SQR, 0, nullptr, nullptr},
     {false, GGML_OP_SQRT, 0, nullptr, nullptr},
     {false, GGML_OP_LOG, 0, nullptr, nullptr},
@@ -616,7 +616,7 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
     {false, GGML_OP_CONCAT, 0, nullptr, nullptr},
     {false, GGML_OP_SILU_BACK, 0, nullptr, nullptr},
     {false, GGML_OP_NORM, 0, nullptr, nullptr},
-    {false, GGML_OP_RMS_NORM, 0, nullptr, nullptr},
+    {true,  GGML_OP_RMS_NORM, 1, "ggmlop_dsp_rmsnorm", ggmlop_dsp_rmsnorm},
     {false, GGML_OP_RMS_NORM_BACK, 0, nullptr, nullptr},
     {false, GGML_OP_GROUP_NORM, 0, nullptr, nullptr},
     {false, GGML_OP_L2_NORM, 0, nullptr, nullptr},
@@ -636,7 +636,7 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
     {false, GGML_OP_DIAG, 0, nullptr, nullptr},
     {false, GGML_OP_DIAG_MASK_INF, 0, nullptr, nullptr},
     {false, GGML_OP_DIAG_MASK_ZERO, 0, nullptr, nullptr},
-    {false, GGML_OP_SOFT_MAX, 0, nullptr, nullptr},
+    {true,  GGML_OP_SOFT_MAX, 1, "ggmlop_dsp_softmax", ggmlop_dsp_softmax},
     {false, GGML_OP_SOFT_MAX_BACK, 0, nullptr, nullptr},
     {false, GGML_OP_ROPE, 0, nullptr, nullptr},
     {false, GGML_OP_ROPE_BACK, 0, nullptr, nullptr},
@@ -646,7 +646,7 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
     {false, GGML_OP_IM2COL_BACK, 0, nullptr, nullptr},
     {false, GGML_OP_CONV_TRANSPOSE_2D, 0, nullptr, nullptr},
     {false, GGML_OP_POOL_1D, 0, nullptr, nullptr},
-    {false, GGML_OP_POOL_2D, 0, nullptr, nullptr},
+    {true,  GGML_OP_POOL_2D, 1, "ggmlop_dsp_pool2d", ggmlop_dsp_pool2d},
     {false, GGML_OP_POOL_2D_BACK, 0, nullptr, nullptr},
     {false, GGML_OP_UPSCALE, 0, nullptr, nullptr},
     {false, GGML_OP_PAD, 0, nullptr, nullptr},
@@ -694,10 +694,10 @@ static constexpr const hexagon_op_caps ggmlhexagon_k_op_caps[] = {
     {false, static_cast<ggml_op>(GGML_UNARY_OP_EXP), 0, nullptr, nullptr}
 };
 
-static_assert(ggmlhexagon_k_op_caps[GGML_OP_NONE].supported,     "GGML_OP_NONE is not true");
-static_assert(ggmlhexagon_k_op_caps[GGML_OP_ADD].supported,      "GGML_OP_ADD is not true");
-static_assert(ggmlhexagon_k_op_caps[GGML_OP_MUL].supported,      "GGML_OP_MUL is not true");
-static_assert(ggmlhexagon_k_op_caps[GGML_OP_MUL_MAT].supported,  "GGML_OP_MUL_MAT is not true");
+static_assert(ggmlhexagon_k_op_caps[GGML_OP_NONE].supported,     "GGML_OP_NONE is not true");
+static_assert(ggmlhexagon_k_op_caps[GGML_OP_ADD].supported,      "GGML_OP_ADD is not true");
+static_assert(ggmlhexagon_k_op_caps[GGML_OP_MUL_MAT].supported,  "GGML_OP_MUL_MAT is not true");
+static_assert(ggmlhexagon_k_op_caps[GGML_OP_SOFT_MAX].supported, "GGML_OP_SOFT_MAX is not true");
 static_assert(std::size(ggmlhexagon_k_op_caps) == (static_cast<size_t>(GGML_OP_COUNT) + static_cast<size_t>(GGML_UNARY_OP_COUNT)),
               "pls check ggmlhexagon_k_op_caps and ensure is corresponding to latest ggml.h");
 
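For orientation, each row of ggmlhexagon_k_op_caps pairs a ggml op with the symbol name of its cDSP stub and a host-side function pointer. Below is a minimal sketch of what the entry type likely looks like; the field names and the exact pointer signature are assumptions inferred from the call site op_func(ctx->ggmlop_handle, &dsptensor_0, &dsptensor_1, &dsptensor_2) further down, not copied from the PR:

// Sketch only: field names and the remote_handle64 parameter are inferred
// from this diff, not taken from the PR's headers.
typedef int (*ggmlhexagon_op_func_t)(remote_handle64 handle,   // FastRPC session handle (assumed)
                                     const dsptensor * src0,
                                     const dsptensor * src1,
                                     dsptensor * dst);

struct hexagon_op_caps {
    bool                  supported;       // true -> op can be offloaded to the cDSP
    ggml_op               op;              // ggml op id (unary ops are cast into this field)
    size_t                input_params;    // number of input tensors the DSP kernel consumes
    const char *          hexagon_op_name; // stub symbol on the DSP side, e.g. "ggmlop_dsp_softmax"
    ggmlhexagon_op_func_t op_func;         // host-side FastRPC stub to invoke
};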
@@ -5018,6 +5018,7 @@ static void ggmlhexagon_compute(ggml_backend_hexagon_context * ctx, struct ggml_
     dsptensor_0.nb[3] = src0->nb[3];
 
     if (2 == input_tensor_count) {
+        GGML_ASSERT(nullptr != src1);
         dsptensor_1.data     = src1->data;
         dsptensor_1.type     = src1->type;
         dsptensor_1.data_len = ggml_nbytes(src1);
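The dsptensor_* locals mirror the ggml tensors across the FastRPC boundary. A hedged sketch of the struct, reconstructed only from the fields this function assigns; the real IDL-generated type in the PR may differ and carry more members:

// Sketch only: every field below appears in an assignment in
// ggmlhexagon_compute; ne[] is assumed by analogy with ggml_tensor.
struct dsptensor {
    void *  data;                                      // host pointer to the tensor payload
    int32_t type;                                      // ggml_type of the payload
    int64_t data_len;                                  // ggml_nbytes() of the tensor
    int64_t ne[4];                                     // element counts per dimension (assumed)
    size_t  nb[4];                                     // strides in bytes per dimension
    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; // raw op parameters from ggml_tensor
};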
@@ -5047,6 +5048,8 @@ static void ggmlhexagon_compute(ggml_backend_hexagon_context * ctx, struct ggml_
     dsptensor_2.nb[2] = dst->nb[2];
     dsptensor_2.nb[3] = dst->nb[3];
 
+    memcpy(dsptensor_2.op_params, dst->op_params, GGML_MAX_OP_PARAMS / sizeof(int32_t));
+
     hexagon_error = op_func(ctx->ggmlop_handle, &dsptensor_0, &dsptensor_1, &dsptensor_2);
     if (AEE_SUCCESS != hexagon_error) {
         GGMLHEXAGON_LOG_WARN("ggmlop %s computation fail on cdsp", ggml_op_name(op->op));
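One thing worth flagging on the new memcpy: GGML_MAX_OP_PARAMS is a byte count (64 in ggml.h), so GGML_MAX_OP_PARAMS / sizeof(int32_t) is 16, and the call as written copies 16 bytes, i.e. only the first four int32 words of op_params; passing GGML_MAX_OP_PARAMS would transfer the whole array. That matters because GGML_OP_POOL_2D packs seven int32 words of parameters, which do not fit in four. The copy is needed because the DSP kernels added here read their parameters out of these words; for example, ggml stores the rms_norm epsilon as a float in word 0, which a kernel could recover roughly like this (hedged sketch, not the PR's actual ggmlop_dsp_rmsnorm body):

#include <string.h>

// Hedged sketch of DSP-side parameter decoding; the real kernel bodies are
// not part of this hunk.
static float dsp_rmsnorm_eps(const dsptensor * dst) {
    float eps;
    memcpy(&eps, &dst->op_params[0], sizeof(float)); // ggml packs rms_norm eps as a float in word 0
    return eps;
}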
@@ -5078,19 +5081,31 @@ static bool ggmlhexagon_can_handle_op_through_cdsp(ggml_backend_dev_t dev, const
             if (!ggml_are_same_shape(src0, src1)) {
                 return false;
             }
-
-            // TODO: offload quantize GGML_OP_ADD to cDSP
-            return ggmlhexagon_same_types(ctx, op_tensor);
+            return (src0->type == GGML_TYPE_F32) && (src1->type == GGML_TYPE_F32) && (op_tensor->type == GGML_TYPE_F32);
         }
         case GGML_OP_MUL_MAT:
         {
             ggmlhexagon_dump_op_info(op_tensor);
             if (1 == g_hexagon_appcfg.enable_q_mulmat)
-                return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_Q6_K
+                return (src0->type == GGML_TYPE_F32
+                        || src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q8_0
+                        || src0->type == GGML_TYPE_Q6_K || src0->type == GGML_TYPE_Q8_K
                        ) && (src1->type == GGML_TYPE_F32) && (op_tensor->type == GGML_TYPE_F32);
             else
                 return (src0->type == GGML_TYPE_F32) && (src1->type == GGML_TYPE_F32) && (op_tensor->type == GGML_TYPE_F32);
         }
+        case GGML_OP_SOFT_MAX: {
+            if (!ggml_is_contiguous(op_tensor))
+                return false;
+            if (!ggml_are_same_shape(src0, op_tensor))
+                return false;
+        }
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_POOL_2D:
+        {
+
+            ggmlhexagon_dump_op_info(op_tensor);
+        }
         default:
             break;
     }
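Two observations on the new switch cases. First, GGML_OP_SOFT_MAX has no return of its own when its checks pass, so it falls through into the GGML_OP_RMS_NORM / GGML_OP_POOL_2D cases and then into default: break, presumably relying on a shared type check after the switch; if the fall-through is unintended, each case needs its own return (and an explicit [[fallthrough]] would silence compiler warnings either way). Second, for completeness, a hedged sketch of how a capability table like ggmlhexagon_k_op_caps is typically consulted during dispatch; ggmlhexagon_get_op_caps and ggmlhexagon_op_has_stub are hypothetical helpers, not the PR's code:

// Hypothetical helper: unary ops occupy the GGML_UNARY_OP_COUNT slots that
// follow the GGML_OP_COUNT regular entries (see the static_assert above).
static const hexagon_op_caps & ggmlhexagon_get_op_caps(const ggml_tensor * op_tensor) {
    if (op_tensor->op == GGML_OP_UNARY) {
        return ggmlhexagon_k_op_caps[GGML_OP_COUNT + ggml_get_unary_op(op_tensor)];
    }
    return ggmlhexagon_k_op_caps[op_tensor->op];
}

// An op is offloadable only if the table marks it supported and a stub exists.
static bool ggmlhexagon_op_has_stub(const ggml_tensor * op_tensor) {
    const hexagon_op_caps & caps = ggmlhexagon_get_op_caps(op_tensor);
    return caps.supported && caps.op_func != nullptr;
}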