Skip to content

Commit

Permalink
avoid unnecessary transpose in lapackClone
Browse files Browse the repository at this point in the history
  • Loading branch information
koraykv committed May 31, 2015
1 parent e271689 commit 7727d70
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
5 changes: 3 additions & 2 deletions lib/TH/generic/THTensorLapack.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ static int THTensor_(lapackCloneNrows)(THTensor *r_, THTensor *m, int forced, in
else
{
clone = 1;
/* we need to copy */
THTensor_(resize2d)(r_,m->size[1],nrows);
THTensor_(transpose)(r_,NULL,0,1);
if (r_->stride[0] == nrows && r_->stride[1] == 1)
THTensor_(transpose)(r_,NULL,0,1);
/* we need to copy */
if (m->size[0] == nrows) {
THTensor_(copy)(r_,m);
} else {
Expand Down
18 changes: 18 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,24 @@ function torchtest.eig()
mytester:assertlt(maxdiff(vv,vvv),1e-12,'torch.eig value')
mytester:assertlt(maxdiff(vv,tv),1e-12,'torch.eig value')
end
function torchtest.test_symeig()
local xval = torch.rand(100,3)
local cov = torch.mm(xval:t(), xval)
local rese = torch.zeros(3)
local resv = torch.zeros(3,3)

-- First call to symeig
mytester:assert(resv:isContiguous(), 'resv is not contiguous') -- PASS
torch.symeig(rese, resv, cov:clone(), 'V')
local ahat = resv*torch.diag(rese)*resv:t()
mytester:assertTensorEq(cov, ahat, 1e-8, 'USV\' wrong') -- PASS

-- Second call to symeig
mytester:assert(not resv:isContiguous(), 'resv is contiguous') -- FAIL
torch.symeig(rese, resv, cov:clone(), 'V')
local ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv:t())
mytester:assertTensorEq(cov, ahat, 1e-8, 'USV\' wrong') -- FAIL
end
function torchtest.svd()
if not torch.svd then return end
local a=torch.Tensor({{8.79, 6.11, -9.15, 9.57, -3.49, 9.84},
Expand Down

0 comments on commit 7727d70

Please sign in to comment.