Skip to content

Commit

Permalink
Tweak torch.range to be more numerically robust.
Browse files Browse the repository at this point in the history
Otherwise it can produce mildly surprising behaviour,
in the floating-point setting, due to rounding errors.
  • Loading branch information
d11 committed Jun 3, 2015
1 parent 5958eec commit 9ac965a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
6 changes: 3 additions & 3 deletions TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,9 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
wrap("range",
cname("range"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=real},
{name=real},
{name=real, default=1}})
{name=accreal},
{name=accreal},
{name=accreal, default=1}})

wrap("randperm",
cname("randperm"),
Expand Down
4 changes: 2 additions & 2 deletions lib/TH/generic/THTensorMath.c
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ void THTensor_(eye)(THTensor *r_, long n, long m)
}


void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step)
void THTensor_(range)(THTensor *r_, accreal xmin, accreal xmax, accreal step)
{
long size;
real i = 0;
Expand All @@ -1151,7 +1151,7 @@ void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step)
THArgCheck(((step > 0) && (xmax >= xmin)) || ((step < 0) && (xmax <= xmin))
, 2, "upper bound and larger bound incoherent with step sign");

size = (long)((xmax-xmin)/step+1);
size = (long)((xmax/step - xmin/step)+1);

THTensor_(resize1d)(r_, size);

Expand Down
2 changes: 1 addition & 1 deletion lib/TH/generic/THTensorMath.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ TH_API void THTensor_(zeros)(THTensor *r_, THLongStorage *size);
TH_API void THTensor_(ones)(THTensor *r_, THLongStorage *size);
TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k);
TH_API void THTensor_(eye)(THTensor *r_, long n, long m);
TH_API void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step);
TH_API void THTensor_(range)(THTensor *r_, accreal xmin, accreal xmax, accreal step);
TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, long n);

TH_API void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size);
Expand Down
12 changes: 12 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,18 @@ function torchtest.rangeequalbounds()
torch.range(mxx,1,1,1)
mytester:asserteq(maxdiff(mx,mxx),0,'torch.range value for equal bounds step')
end
function torchtest.rangefloat()
local mx = torch.FloatTensor():range(0.6, 0.9, 0.1)
mytester:asserteq(mx:size(1), 4, 'wrong size for FloatTensor range')
mx = torch.FloatTensor():range(1, 10, 0.3)
mytester:asserteq(mx:size(1), 31, 'wrong size for FloatTensor range')
end
function torchtest.rangedouble()
local mx = torch.DoubleTensor():range(0.6, 0.9, 0.1)
mytester:asserteq(mx:size(1), 4, 'wrong size for DoubleTensor range')
mx = torch.DoubleTensor():range(1, 10, 0.3)
mytester:asserteq(mx:size(1), 31, 'wrong size for DoubleTensor range')
end
function torchtest.randperm()
local t=os.time()
torch.manualSeed(t)
Expand Down

0 comments on commit 9ac965a

Please sign in to comment.