From db85c0f82fc0ebdae197c74012e2b409b4f3ec59 Mon Sep 17 00:00:00 2001 From: David Tschumperle Date: Wed, 11 Sep 2024 13:37:56 +0200 Subject: [PATCH] Implementation of native softmax operator. --- CImg.h | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/CImg.h b/CImg.h index dbcbdc0f..6e17b9fd 100644 --- a/CImg.h +++ b/CImg.h @@ -22379,6 +22379,19 @@ namespace cimg_library { _cimg_mp_const_scalar(is_scalar(arg1)?0:size(arg1)); } + if (!std::strncmp(ss,"softmax(",8)) { // Softmax + _cimg_mp_op("Function 'softmax()'"); + s1 = ss8; while (s10) pos = is_comp_vector(arg1)?arg1:((return_comp = true), vector(p1)); + else _cimg_mp_return(1); + CImg::vector((ulongT)mp_vector_softmax,pos,arg1,p1,arg2).move_to(code); + _cimg_mp_return(pos); + } + if (!std::strncmp(ss,"solve(",6)) { // Solve square linear system _cimg_mp_op("Function 'solve()'"); s1 = ss6; while (s10) pos = is_comp_vector(arg1)?arg1:((return_comp = true), vector(p1)); else { @@ -29241,6 +29254,19 @@ namespace cimg_library { return _mp_arg(1); } + static double mp_vector_softmax(_cimg_math_parser& mp) { + const unsigned int siz = (unsigned int)mp.opcode[3]; + const double temperature = _mp_arg(4); + if (siz>0) { // Vector-valued argument + double *const ptrd = &_mp_arg(1) + 1; + const double *const ptrs = &_mp_arg(2) + 1; + CImg(ptrd,siz,1,1,1,true) = CImg(ptrs,siz,1,1,1,true).get_softmax(temperature); + return cimg::type::nan(); + } + // Scalar-valued argument. + return 1; + } + static double mp_vector_unitnorm(_cimg_math_parser& mp) { const unsigned int siz = (unsigned int)mp.opcode[3]; const double p = _mp_arg(4); @@ -30222,8 +30248,8 @@ namespace cimg_library { CImg res(_width,_height,_depth,_spectrum); const T val_max = max(); Tfloat sum = 0; - cimg_pragma_openmp(parallel reduction(+:sum)) { - cimg_pragma_openmp(for cimg_openmp_if_size(size(),4096)) + cimg_pragma_openmp(parallel reduction(+:sum) cimg_openmp_if_size(size(),4096)) { + cimg_pragma_openmp(for) cimg_rofoff(*this,off) { const Tfloat val = std::exp(((Tfloat)_data[off] - val_max)/temperature); res[off] = val;