From 10abb79287a53b97c63c4a3b439e8174c719a7b1 Mon Sep 17 00:00:00 2001 From: Dominik Grewe Date: Fri, 26 Jun 2015 11:58:53 +0100 Subject: [PATCH] Make FFI tensor map more robust. The 'default' tensor map (in generic/Tensor.c) makes sure that the input tensors are of the same type. The FFI implementation (for the contiguous case) doesn't do that. It should do so for consistency and also to avoid segfaults when passing in a CudaTensor, for example. --- FFI.lua | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/FFI.lua b/FFI.lua index 1780ee89..904302aa 100644 --- a/FFI.lua +++ b/FFI.lua @@ -1,4 +1,19 @@ local ok, ffi = pcall(require, 'ffi') + +local function checkArgument(condition, fn, ud, msg, level) + local level = level or 3 + if not condition then + error("bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")", level) + end +end + +local function checkArgumentType(expected, actual, fn, ud, level) + local level = level or 3 + if expected ~= actual then + checkArgument(false, fn, ud, expected .. " expected, got " .. actual, level + 1) + end +end + if ok then local Real2real = { Byte='unsigned char', @@ -115,6 +130,9 @@ typedef struct THRealTensor rawset(Tensor, "map", function(self, src, func) + checkArgument(torch.isTensor(src), "map", 1, "tensor expected") + checkArgumentType(self:type(), src:type(), "map", 1) + if self:isContiguous() and src:isContiguous() and self.data and src.data then local self_d = self:data() local src_d = src:data() @@ -136,6 +154,11 @@ typedef struct THRealTensor rawset(Tensor, "map2", function(self, src1, src2, func) + checkArgument(torch.isTensor(src1), "map", 1, "tensor expected") + checkArgument(torch.isTensor(src2), "map", 2, "tensor expected") + checkArgumentType(self:type(), src1:type(), "map", 1) + checkArgumentType(self:type(), src2:type(), "map", 2) + if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then local self_d = self:data() local src1_d = src1:data()