Skip to content

Commit

Permalink
Merge pull request torch#252 from d11/torch_range
Browse files Browse the repository at this point in the history
Tweak torch.range to be more numerically robust.
  • Loading branch information
soumith committed Jun 11, 2015
2 parents 3dbedc5 + 9ac965a commit 5f41136
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
6 changes: 3 additions & 3 deletions TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,9 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
wrap("range",
cname("range"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=real},
{name=real},
{name=real, default=1}})
{name=accreal},
{name=accreal},
{name=accreal, default=1}})

wrap("randperm",
cname("randperm"),
Expand Down
4 changes: 2 additions & 2 deletions lib/TH/generic/THTensorMath.c
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ void THTensor_(eye)(THTensor *r_, long n, long m)
}


void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step)
void THTensor_(range)(THTensor *r_, accreal xmin, accreal xmax, accreal step)
{
long size;
real i = 0;
Expand All @@ -1151,7 +1151,7 @@ void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step)
THArgCheck(((step > 0) && (xmax >= xmin)) || ((step < 0) && (xmax <= xmin))
, 2, "upper bound and larger bound incoherent with step sign");

size = (long)((xmax-xmin)/step+1);
size = (long)((xmax/step - xmin/step)+1);

THTensor_(resize1d)(r_, size);

Expand Down
2 changes: 1 addition & 1 deletion lib/TH/generic/THTensorMath.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ TH_API void THTensor_(zeros)(THTensor *r_, THLongStorage *size);
TH_API void THTensor_(ones)(THTensor *r_, THLongStorage *size);
TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k);
TH_API void THTensor_(eye)(THTensor *r_, long n, long m);
TH_API void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step);
TH_API void THTensor_(range)(THTensor *r_, accreal xmin, accreal xmax, accreal step);
TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, long n);

TH_API void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size);
Expand Down
12 changes: 12 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,18 @@ function torchtest.rangeequalbounds()
torch.range(mxx,1,1,1)
mytester:asserteq(maxdiff(mx,mxx),0,'torch.range value for equal bounds step')
end
function torchtest.rangefloat()
local mx = torch.FloatTensor():range(0.6, 0.9, 0.1)
mytester:asserteq(mx:size(1), 4, 'wrong size for FloatTensor range')
mx = torch.FloatTensor():range(1, 10, 0.3)
mytester:asserteq(mx:size(1), 31, 'wrong size for FloatTensor range')
end
function torchtest.rangedouble()
local mx = torch.DoubleTensor():range(0.6, 0.9, 0.1)
mytester:asserteq(mx:size(1), 4, 'wrong size for DoubleTensor range')
mx = torch.DoubleTensor():range(1, 10, 0.3)
mytester:asserteq(mx:size(1), 31, 'wrong size for DoubleTensor range')
end
function torchtest.randperm()
local t=os.time()
torch.manualSeed(t)
Expand Down

0 comments on commit 5f41136

Please sign in to comment.