Skip to content

Commit

Permalink
Make FFI tensor map more robust.
Browse files Browse the repository at this point in the history
The 'default' tensor map (in generic/Tensor.c) makes sure that the
input tensors are of the same type. The FFI implementation (for the
contiguous case) doesn't do that. It should do so for consistency and also
to avoid segfaults when passing in a CudaTensor, for example.
  • Loading branch information
dominikgrewe committed Jun 26, 2015
1 parent 10c768d commit 10abb79
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions FFI.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
local ok, ffi = pcall(require, 'ffi')

local function checkArgument(condition, fn, ud, msg, level)
local level = level or 3
if not condition then
error("bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")", level)
end
end

local function checkArgumentType(expected, actual, fn, ud, level)
local level = level or 3
if expected ~= actual then
checkArgument(false, fn, ud, expected .. " expected, got " .. actual, level + 1)
end
end

if ok then
local Real2real = {
Byte='unsigned char',
Expand Down Expand Up @@ -115,6 +130,9 @@ typedef struct THRealTensor
rawset(Tensor,
"map",
function(self, src, func)
checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
checkArgumentType(self:type(), src:type(), "map", 1)

if self:isContiguous() and src:isContiguous() and self.data and src.data then
local self_d = self:data()
local src_d = src:data()
Expand All @@ -136,6 +154,11 @@ typedef struct THRealTensor
rawset(Tensor,
"map2",
function(self, src1, src2, func)
checkArgument(torch.isTensor(src1), "map", 1, "tensor expected")
checkArgument(torch.isTensor(src2), "map", 2, "tensor expected")
checkArgumentType(self:type(), src1:type(), "map", 1)
checkArgumentType(self:type(), src2:type(), "map", 2)

if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
local self_d = self:data()
local src1_d = src1:data()
Expand Down

0 comments on commit 10abb79

Please sign in to comment.