Skip to content

Commit

Permalink
add squeeze/unsqueeze
Browse files Browse the repository at this point in the history
  • Loading branch information
ShigekiKarita committed Jul 8, 2018
1 parent 2e32e3d commit f7bca8e
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion source/grain/chain.d
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ auto tan(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) x) {
}

/// abs
auto abs(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) x) {
auto abs(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) x) {
import grain.functions.unary : Abs;
auto func = new Abs!(T, dim);
return func.applyForward(x);
Expand Down Expand Up @@ -440,3 +440,41 @@ unittest {
assert(hx.gradSliced == numir.view(hgy.sliced, [6, 4]));
// gradCheckChain!(x => x.view([2, 3, -1]))(hx, hgy, 1e-3, 5e-2, 5e-2);
}


/// squeeze/remove redundant size-1 dimension (axis) d
auto squeeze(size_t d, T, size_t dim, alias Storage)(Variable!(T, dim, Storage) x) {
import grain.utility : castArray;
static assert(dim >= 1);
assert(x.shape[d] == 1);
ptrdiff_t[dim-1] s;
s[0..d] = x.shape[0..d].castArray!ptrdiff_t;
s[d..$] = x.shape[d+1..$].castArray!ptrdiff_t;
return x.view(s);
}

///
unittest {
import mir.ndslice;
auto x = iota(3, 4, 1, 5).as!double.slice.variable;
assert(x.squeeze!2.shape == [3, 4, 5]);
}


/// unsqueeze/add redundant size-1 dimension (axis) d
auto unsqueeze(size_t d, T, size_t dim, alias Storage)(Variable!(T, dim, Storage) x) {
import grain.utility : castArray;
static assert(dim >= d);
ptrdiff_t[dim+1] s;
s[0..d] = x.shape[0..d].castArray!ptrdiff_t;
s[d] = 1;
s[d+1..$] = x.shape[d..$].castArray!ptrdiff_t;
return x.view(s);
}

///
unittest {
import mir.ndslice;
auto x = iota(3, 4, 5).as!double.slice.variable;
assert(x.unsqueeze!2.shape == [3, 4, 1, 5]);
}

0 comments on commit f7bca8e

Please sign in to comment.