Skip to content

Commit

Permalink
Implementation of native softmax operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
dtschump committed Sep 11, 2024
1 parent d516cde commit db85c0f
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions CImg.h
Original file line number Diff line number Diff line change
Expand Up @@ -22379,6 +22379,19 @@ namespace cimg_library {
_cimg_mp_const_scalar(is_scalar(arg1)?0:size(arg1));
}

if (!std::strncmp(ss,"softmax(",8)) { // Softmax
_cimg_mp_op("Function 'softmax()'");
s1 = ss8; while (s1<se1 && (*s1!=',' || level[s1 - expr._data]!=clevel1)) ++s1;
arg1 = compile(ss8,s1,depth1,0,block_flags);
arg2 = s1<se1?compile(++s1,se1,depth1,0,block_flags):1;
_cimg_mp_check_type(arg2,2,1,0);
p1 = size(arg1);
if (p1>0) pos = is_comp_vector(arg1)?arg1:((return_comp = true), vector(p1));
else _cimg_mp_return(1);
CImg<ulongT>::vector((ulongT)mp_vector_softmax,pos,arg1,p1,arg2).move_to(code);
_cimg_mp_return(pos);
}

if (!std::strncmp(ss,"solve(",6)) { // Solve square linear system
_cimg_mp_op("Function 'solve()'");
s1 = ss6; while (s1<se1 && (*s1!=',' || level[s1 - expr._data]!=clevel1)) ++s1;
Expand Down Expand Up @@ -22798,7 +22811,7 @@ namespace cimg_library {
s1 = s0; while (s1<se1 && (*s1!=',' || level[s1 - expr._data]!=clevel1)) ++s1;
arg1 = compile(s0,s1,depth1,0,block_flags);
arg2 = s1<se1?compile(++s1,se1,depth1,0,block_flags):2;
_cimg_mp_check_type(arg2,0,1,0);
_cimg_mp_check_type(arg2,2,1,0);
p1 = size(arg1);
if (p1>0) pos = is_comp_vector(arg1)?arg1:((return_comp = true), vector(p1));
else {
Expand Down Expand Up @@ -29241,6 +29254,19 @@ namespace cimg_library {
return _mp_arg(1);
}

static double mp_vector_softmax(_cimg_math_parser& mp) {
const unsigned int siz = (unsigned int)mp.opcode[3];
const double temperature = _mp_arg(4);
if (siz>0) { // Vector-valued argument
double *const ptrd = &_mp_arg(1) + 1;
const double *const ptrs = &_mp_arg(2) + 1;
CImg<doubleT>(ptrd,siz,1,1,1,true) = CImg<doubleT>(ptrs,siz,1,1,1,true).get_softmax(temperature);
return cimg::type<double>::nan();
}
// Scalar-valued argument.
return 1;
}

static double mp_vector_unitnorm(_cimg_math_parser& mp) {
const unsigned int siz = (unsigned int)mp.opcode[3];
const double p = _mp_arg(4);
Expand Down Expand Up @@ -30222,8 +30248,8 @@ namespace cimg_library {
CImg<Tfloat> res(_width,_height,_depth,_spectrum);
const T val_max = max();
Tfloat sum = 0;
cimg_pragma_openmp(parallel reduction(+:sum)) {
cimg_pragma_openmp(for cimg_openmp_if_size(size(),4096))
cimg_pragma_openmp(parallel reduction(+:sum) cimg_openmp_if_size(size(),4096)) {
cimg_pragma_openmp(for)
cimg_rofoff(*this,off) {
const Tfloat val = std::exp(((Tfloat)_data[off] - val_max)/temperature);
res[off] = val;
Expand Down

0 comments on commit db85c0f

Please sign in to comment.