diff --git a/FFI.lua b/FFI.lua index 1780ee89..904302aa 100644 --- a/FFI.lua +++ b/FFI.lua @@ -1,4 +1,19 @@ local ok, ffi = pcall(require, 'ffi') + +local function checkArgument(condition, fn, ud, msg, level) + local level = level or 3 + if not condition then + error("bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")", level) + end +end + +local function checkArgumentType(expected, actual, fn, ud, level) + local level = level or 3 + if expected ~= actual then + checkArgument(false, fn, ud, expected .. " expected, got " .. actual, level + 1) + end +end + if ok then local Real2real = { Byte='unsigned char', @@ -115,6 +130,9 @@ typedef struct THRealTensor rawset(Tensor, "map", function(self, src, func) + checkArgument(torch.isTensor(src), "map", 1, "tensor expected") + checkArgumentType(self:type(), src:type(), "map", 1) + if self:isContiguous() and src:isContiguous() and self.data and src.data then local self_d = self:data() local src_d = src:data() @@ -136,6 +154,11 @@ typedef struct THRealTensor rawset(Tensor, "map2", function(self, src1, src2, func) + checkArgument(torch.isTensor(src1), "map", 1, "tensor expected") + checkArgument(torch.isTensor(src2), "map", 2, "tensor expected") + checkArgumentType(self:type(), src1:type(), "map", 1) + checkArgumentType(self:type(), src2:type(), "map", 2) + if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then local self_d = self:data() local src1_d = src1:data()