Skip to content

Commit

Permalink
Merge pull request torch#268 from dominikgrewe/map
Browse files Browse the repository at this point in the history
Make FFI tensor map more robust.
  • Loading branch information
soumith committed Jun 26, 2015
2 parents 10c768d + 10abb79 commit 38bbb02
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions FFI.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
local ok, ffi = pcall(require, 'ffi')

local function checkArgument(condition, fn, ud, msg, level)
local level = level or 3
if not condition then
error("bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")", level)
end
end

local function checkArgumentType(expected, actual, fn, ud, level)
local level = level or 3
if expected ~= actual then
checkArgument(false, fn, ud, expected .. " expected, got " .. actual, level + 1)
end
end

if ok then
local Real2real = {
Byte='unsigned char',
Expand Down Expand Up @@ -115,6 +130,9 @@ typedef struct THRealTensor
rawset(Tensor,
"map",
function(self, src, func)
checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
checkArgumentType(self:type(), src:type(), "map", 1)

if self:isContiguous() and src:isContiguous() and self.data and src.data then
local self_d = self:data()
local src_d = src:data()
Expand All @@ -136,6 +154,11 @@ typedef struct THRealTensor
rawset(Tensor,
"map2",
function(self, src1, src2, func)
checkArgument(torch.isTensor(src1), "map", 1, "tensor expected")
checkArgument(torch.isTensor(src2), "map", 2, "tensor expected")
checkArgumentType(self:type(), src1:type(), "map", 1)
checkArgumentType(self:type(), src2:type(), "map", 2)

if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
local self_d = self:data()
local src1_d = src1:data()
Expand Down

0 comments on commit 38bbb02

Please sign in to comment.