Skip to content

Commit

Permalink
Adding a simple test for storage views.
Browse files Browse the repository at this point in the history
  • Loading branch information
zakattacktwitter committed May 26, 2015
1 parent f938d15 commit 3e1601b
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1906,14 +1906,14 @@ function torchtest.maskedCopy()
mytester:assertTensorEq(dest, dest2, 0.000001, "maskedCopy error")

-- make source bigger than number of 1s in mask
src = torch.randn(nDest)
src = torch.randn(nDest)
local ok = pcall(dest.maskedCopy, dest, mask, src)
mytester:assert(ok, "maskedCopy incorrect complaint when"
mytester:assert(ok, "maskedCopy incorrect complaint when"
.. " src is bigger than mask's one count")

src = torch.randn(nCopy - 1) -- make src smaller. this should fail
local ok = pcall(dest.maskedCopy, dest, mask, src)
mytester:assert(not ok, "maskedCopy not erroring when"
mytester:assert(not ok, "maskedCopy not erroring when"
.. " src is smaller than mask's one count")
end

Expand Down Expand Up @@ -2229,6 +2229,18 @@ function torchtest.serialize()
mytester:assertTensorEq(tensObj, torch.deserializeFromStorage(serStorage), 1e-10)
end

function torchtest.storageview()
local s1 = torch.LongStorage({3, 4, 5})
local s2 = torch.LongStorage(s1, 2)

mytester:assert(s2:size() == 2, "should be size 2")
mytester:assert(s2[1] == s1[2], "should have 4 at position 1")
mytester:assert(s2[2] == s1[3], "should have 5 at position 2")

s2[1] = 13
mytester:assert(13 == s1[2], "should have 13 at position 1")
end

function torch.test(tests)
math.randomseed(os.time())
if torch.getdefaulttensortype() == 'torch.FloatTensor' then
Expand Down

0 comments on commit 3e1601b

Please sign in to comment.