From 7727d70d44769120ca514bcab2aa10616eba080b Mon Sep 17 00:00:00 2001 From: koray kavukcuoglu Date: Sun, 31 May 2015 22:43:34 +0100 Subject: [PATCH] avoid unnecessary transpose in lapackClone --- lib/TH/generic/THTensorLapack.c | 5 +++-- test/test.lua | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/lib/TH/generic/THTensorLapack.c b/lib/TH/generic/THTensorLapack.c index aa1ac581..4e490f16 100644 --- a/lib/TH/generic/THTensorLapack.c +++ b/lib/TH/generic/THTensorLapack.c @@ -20,9 +20,10 @@ static int THTensor_(lapackCloneNrows)(THTensor *r_, THTensor *m, int forced, in else { clone = 1; - /* we need to copy */ THTensor_(resize2d)(r_,m->size[1],nrows); - THTensor_(transpose)(r_,NULL,0,1); + if (r_->stride[0] == nrows && r_->stride[1] == 1) + THTensor_(transpose)(r_,NULL,0,1); + /* we need to copy */ if (m->size[0] == nrows) { THTensor_(copy)(r_,m); } else { diff --git a/test/test.lua b/test/test.lua index dbf66dc2..1b69662b 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1390,6 +1390,24 @@ function torchtest.eig() mytester:assertlt(maxdiff(vv,vvv),1e-12,'torch.eig value') mytester:assertlt(maxdiff(vv,tv),1e-12,'torch.eig value') end +function torchtest.test_symeig() + local xval = torch.rand(100,3) + local cov = torch.mm(xval:t(), xval) + local rese = torch.zeros(3) + local resv = torch.zeros(3,3) + + -- First call to symeig + mytester:assert(resv:isContiguous(), 'resv is not contiguous') -- PASS + torch.symeig(rese, resv, cov:clone(), 'V') + local ahat = resv*torch.diag(rese)*resv:t() + mytester:assertTensorEq(cov, ahat, 1e-8, 'USV\' wrong') -- PASS + + -- Second call to symeig + mytester:assert(not resv:isContiguous(), 'resv is contiguous') -- FAIL + torch.symeig(rese, resv, cov:clone(), 'V') + local ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv:t()) + mytester:assertTensorEq(cov, ahat, 1e-8, 'USV\' wrong') -- FAIL +end function torchtest.svd() if not torch.svd then return end local a=torch.Tensor({{8.79, 6.11, -9.15, 9.57, -3.49, 9.84},