diff --git a/DiskFile.c b/DiskFile.c index 0201093e..59e12668 100644 --- a/DiskFile.c +++ b/DiskFile.c @@ -80,6 +80,6 @@ void torch_DiskFile_init(lua_State *L) luaT_newmetatable(L, "torch.DiskFile", "torch.File", torch_DiskFile_new, torch_DiskFile_free, NULL); - luaL_register(L, NULL, torch_DiskFile__); + luaT_setfuncs(L, torch_DiskFile__, 0); lua_pop(L, 1); } diff --git a/FFI.lua b/FFI.lua index b35770c4..1780ee89 100644 --- a/FFI.lua +++ b/FFI.lua @@ -1,7 +1,5 @@ -if jit then - - local ffi = require 'ffi' - +local ok, ffi = pcall(require, 'ffi') +if ok then local Real2real = { Byte='unsigned char', Char='char', diff --git a/File.c b/File.c index e2f8e425..3a463752 100644 --- a/File.c +++ b/File.c @@ -199,6 +199,6 @@ static const struct luaL_Reg torch_File__ [] = { void torch_File_init(lua_State *L) { luaT_newmetatable(L, "torch.File", NULL, NULL, NULL, NULL); - luaL_register(L, NULL, torch_File__); + luaT_setfuncs(L, torch_File__, 0); lua_pop(L, 1); } diff --git a/File.lua b/File.lua index 6ce18000..1b86171b 100644 --- a/File.lua +++ b/File.lua @@ -19,7 +19,11 @@ local TYPE_TABLE = 3 local TYPE_TORCH = 4 local TYPE_BOOLEAN = 5 local TYPE_FUNCTION = 6 -local TYPE_RECUR_FUNCTION = 7 +local TYPE_RECUR_FUNCTION = 8 +local LEGACY_TYPE_RECUR_FUNCTION = 7 + +-- Lua 5.2 compatibility +local loadstring = loadstring or load function File:isWritableObject(object) local typename = type(object) @@ -138,7 +142,8 @@ function File:writeObject(object) counter = counter + 1 local name,value = debug.getupvalue(object, counter) if not name then break end - table.insert(upvalues, value) + if name == '_ENV' then value = nil end + table.insert(upvalues, {name=name, value=value}) end local dumped = string.dump(object) local stringStorage = torch.CharStorage():string(dumped) @@ -214,7 +219,7 @@ function File:readObject() debug.setupvalue(func, index, upvalue) end return func - elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or typeidx == TYPE_RECUR_FUNCTION then + elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then -- read the index local index = self:readInt() @@ -225,16 +230,22 @@ function File:readObject() end -- otherwise read it - if typeidx == TYPE_RECUR_FUNCTION then - local size = self:readInt() - local dumped = self:readChar(size):string() - local func = loadstring(dumped) - objects[index] = func - local upvalues = self:readObject() - for index,upvalue in ipairs(upvalues) do - debug.setupvalue(func, index, upvalue) - end - return func + if typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then + local size = self:readInt() + local dumped = self:readChar(size):string() + local func = loadstring(dumped) + objects[index] = func + local upvalues = self:readObject() + for index,upvalue in ipairs(upvalues) do + if typeidx == LEGACY_TYPE_RECUR_FUNCTION then + debug.setupvalue(func, index, upvalue) + elseif upvalue.name == '_ENV' then + debug.setupvalue(func, index, _ENV) + else + debug.setupvalue(func, index, upvalue.value) + end + end + return func elseif typeidx == TYPE_TORCH then local version, className, versionNumber version = self:readChar(self:readInt()):string() diff --git a/Generator.c b/Generator.c index 7caf0ce1..06ec6d00 100644 --- a/Generator.c +++ b/Generator.c @@ -24,6 +24,6 @@ void torch_Generator_init(lua_State *L) { luaT_newmetatable(L, torch_Generator, NULL, torch_Generator_new, torch_Generator_free, torch_Generator_factory); - luaL_register(L, NULL, torch_Generator_table_); + luaT_setfuncs(L, torch_Generator_table_, 0); lua_pop(L, 1); } diff --git a/MemoryFile.c b/MemoryFile.c index a153e207..114dbc49 100644 --- a/MemoryFile.c +++ b/MemoryFile.c @@ -56,6 +56,6 @@ void torch_MemoryFile_init(lua_State *L) { luaT_newmetatable(L, "torch.MemoryFile", "torch.File", torch_MemoryFile_new, torch_MemoryFile_free, NULL); - luaL_register(L, NULL, torch_MemoryFile__); + luaT_setfuncs(L, torch_MemoryFile__, 0); lua_pop(L, 1); } diff --git a/PipeFile.c b/PipeFile.c index b18f1642..a47c90d1 100644 --- a/PipeFile.c +++ b/PipeFile.c @@ -38,6 +38,6 @@ void torch_PipeFile_init(lua_State *L) { luaT_newmetatable(L, "torch.PipeFile", "torch.DiskFile", torch_PipeFile_new, torch_PipeFile_free, NULL); - luaL_register(L, NULL, torch_PipeFile__); + luaT_setfuncs(L, torch_PipeFile__, 0); lua_pop(L, 1); } diff --git a/Tensor.lua b/Tensor.lua index fded57e7..cf1d9b7e 100644 --- a/Tensor.lua +++ b/Tensor.lua @@ -7,6 +7,9 @@ local Tensor = {} -- types local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'} +-- Lua 5.2 compatibility +local log10 = math.log10 or function(x) return math.log(x, 10) end + -- tostring() functions for Tensor and Storage local function Storage__printformat(self) if self:size() == 0 then @@ -25,13 +28,13 @@ local function Storage__printformat(self) local tensor = torch.DoubleTensor(torch.DoubleStorage(self:size()):copy(self), 1, self:size()):abs() local expMin = tensor:min() if expMin ~= 0 then - expMin = math.floor(math.log10(expMin)) + 1 + expMin = math.floor(log10(expMin)) + 1 else expMin = 1 end local expMax = tensor:max() if expMax ~= 0 then - expMax = math.floor(math.log10(expMax)) + 1 + expMax = math.floor(log10(expMax)) + 1 else expMax = 1 end diff --git a/TensorMath.lua b/TensorMath.lua index 121c5061..52280fa2 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -120,7 +120,8 @@ local function wrap(...) end end end - method:wrap(unpack(args)) + local unpack = unpack or table.unpack + method:wrap(unpack(args)) end local reals = {ByteTensor='unsigned char', @@ -1133,12 +1134,12 @@ static void torch_TensorMath_init(lua_State *L) luaT_pushmetatable(L, "torch.Tensor"); /* register methods */ - luaL_register(L, NULL, m_torch_TensorMath__); + luaT_setfuncs(L, m_torch_TensorMath__, 0); /* register functions into the "torch" field of the tensor metaclass */ lua_pushstring(L, "torch"); lua_newtable(L); - luaL_register(L, NULL, torch_TensorMath__); + luaT_setfuncs(L, torch_TensorMath__, 0); lua_rawset(L, -3); lua_pop(L, 1); } @@ -1157,7 +1158,7 @@ void torch_TensorMath_init(lua_State *L) torch_LongTensorMath_init(L); torch_FloatTensorMath_init(L); torch_DoubleTensorMath_init(L); - luaL_register(L, NULL, torch_TensorMath__); + luaT_setfuncs(L, torch_TensorMath__, 0); } ]]) diff --git a/Timer.c b/Timer.c index 152ff75d..96f792a4 100644 --- a/Timer.c +++ b/Timer.c @@ -165,6 +165,6 @@ static const struct luaL_Reg torch_Timer__ [] = { void torch_Timer_init(lua_State *L) { luaT_newmetatable(L, "torch.Timer", NULL, torch_Timer_new, torch_Timer_free, NULL); - luaL_register(L, NULL, torch_Timer__); + luaT_setfuncs(L, torch_Timer__, 0); lua_pop(L, 1); } diff --git a/generic/Storage.c b/generic/Storage.c index 98fad343..8022e530 100644 --- a/generic/Storage.c +++ b/generic/Storage.c @@ -273,7 +273,7 @@ void torch_Storage_(init)(lua_State *L) { luaT_newmetatable(L, torch_Storage, NULL, torch_Storage_(new), torch_Storage_(free), torch_Storage_(factory)); - luaL_register(L, NULL, torch_Storage_(_)); + luaT_setfuncs(L, torch_Storage_(_), 0); lua_pop(L, 1); } diff --git a/generic/Tensor.c b/generic/Tensor.c index e5f03a3a..680c7d17 100644 --- a/generic/Tensor.c +++ b/generic/Tensor.c @@ -1278,7 +1278,7 @@ void torch_Tensor_(init)(lua_State *L) { luaT_newmetatable(L, torch_Tensor, NULL, torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory)); - luaL_register(L, NULL, torch_Tensor_(_)); + luaT_setfuncs(L, torch_Tensor_(_), 0); lua_pop(L, 1); } diff --git a/generic/TensorOperator.c b/generic/TensorOperator.c index 6acb8f4f..f6fe4f1f 100644 --- a/generic/TensorOperator.c +++ b/generic/TensorOperator.c @@ -166,7 +166,7 @@ static const struct luaL_Reg torch_TensorOperator_(_) [] = { void torch_TensorOperator_(init)(lua_State *L) { luaT_pushmetatable(L, torch_Tensor); - luaL_register(L, NULL, torch_TensorOperator_(_)); + luaT_setfuncs(L, torch_TensorOperator_(_), 0); lua_pop(L, 1); } diff --git a/init.c b/init.c index 4d953f60..ad2b257d 100644 --- a/init.c +++ b/init.c @@ -56,7 +56,7 @@ int luaopen_libtorch(lua_State *L) lua_newtable(L); lua_pushvalue(L, -1); - lua_setfield(L, LUA_GLOBALSINDEX, "torch"); + lua_setglobal(L, "torch"); torch_File_init(L); diff --git a/init.lua b/init.lua index 9480b775..76e08d07 100644 --- a/init.lua +++ b/init.lua @@ -6,6 +6,9 @@ if not string.gfind then string.gfind = string.gmatch end +if not table.unpack then + table.unpack = unpack +end require "paths" paths.require "libtorch" diff --git a/lib/luaT/luaT.c b/lib/luaT/luaT.c index 7b85ce3b..ee74c929 100644 --- a/lib/luaT/luaT.c +++ b/lib/luaT/luaT.c @@ -45,6 +45,24 @@ void luaT_free(lua_State *L, void *ptr) free(ptr); } +void luaT_setfuncs(lua_State *L, const luaL_Reg *l, int nup) +{ +#if LUA_VERSION_NUM == 501 + luaL_checkstack(L, nup+1, "too many upvalues"); + for (; l->name != NULL; l++) { /* fill the table with given functions */ + int i; + lua_pushstring(L, l->name); + for (i = 0; i < nup; i++) /* copy upvalues to the top */ + lua_pushvalue(L, -(nup+1)); + lua_pushcclosure(L, l->func, nup); /* closure with those upvalues */ + lua_settable(L, -(nup + 3)); + } + lua_pop(L, nup); /* remove upvalues */ +#else + luaL_setfuncs(L, l, nup); +#endif +} + void luaT_stackdump(lua_State *L) { int i; @@ -159,11 +177,16 @@ const char *luaT_typenameid(lua_State *L, const char *tname) } static const char cdataname[] = "" - "local _, ffi = pcall(require, 'ffi')\n" - "if ffi then\n" + "local ok, ffi = pcall(require, 'ffi')\n" + "if ok then\n" " local id2name = {}\n" " return function(cdata, name)\n" - " local id = tonumber(ffi.typeof(cdata))\n" + " local id\n" + " if jit then\n" + " id = tonumber(ffi.typeof(cdata))\n" + " else\n" + " id = tostring(ffi.typeof(cdata))\n" + " end\n" " if id then\n" " if name then\n" " id2name[id] = name\n" @@ -208,9 +231,47 @@ static const char* luaT_cdataname(lua_State *L, int ud, const char *tname) return tname; } +static void* CDATA_MT_KEY = &CDATA_MT_KEY; +static const char cdatamt[] = "" + "local ok, ffi = pcall(require, 'ffi')\n" + "if ok and not jit then\n" + " return ffi.debug().cdata_mt\n" + "else\n" + " return {}\n" + "end\n"; + +static int luaT_iscdata(lua_State *L, int ud) +{ + int type = lua_type(L, ud); + if(type == 10) + return 1; + if(type != LUA_TUSERDATA) + return 0; + if(!lua_getmetatable(L, ud)) + return 0; + + lua_pushlightuserdata(L, CDATA_MT_KEY); + lua_rawget(L, LUA_REGISTRYINDEX); + if (lua_isnil(L, -1)) + { + // initialize cdata metatable + lua_pop(L, 1); + if(luaL_dostring(L, cdatamt)) + luaL_error(L, "internal error (could not load cdata mt): %s", lua_tostring(L, -1)); + + lua_pushlightuserdata(L, CDATA_MT_KEY); + lua_pushvalue(L, -2); + lua_rawset(L, LUA_REGISTRYINDEX); + } + + int iscdata = lua_rawequal(L, -1, -2); + lua_pop(L, 2); + return iscdata; +} + const char* luaT_typename(lua_State *L, int ud) { - if(lua_type(L, ud) == 10) + if(luaT_iscdata(L, ud)) return luaT_cdataname(L, ud, NULL); else if(lua_getmetatable(L, ud)) { @@ -405,7 +466,7 @@ void luaT_registeratname(lua_State *L, const struct luaL_Reg *methods, const cha lua_rawget(L, idx); } - luaL_register(L, NULL, methods); + luaT_setfuncs(L, methods, 0); lua_pop(L, 1); } @@ -451,9 +512,9 @@ int luaT_lua_newmetatable(lua_State *L) luaL_argcheck(L, lua_isnoneornil(L, 5) || lua_isfunction(L, 5), 5, "factory function or nil expected"); if(is_in_module) - lua_getfield(L, LUA_GLOBALSINDEX, module_name); + lua_getglobal(L, module_name); else - lua_pushvalue(L, LUA_GLOBALSINDEX); + lua_pushglobaltable(L); if(!lua_istable(L, 6)) luaL_error(L, "while creating metatable %s: bad argument #1 (%s is an invalid module name)", tname, module_name); @@ -689,6 +750,8 @@ int luaT_lua_pushudata(lua_State *L) if(lua_type(L, 1) == 10) udata = *((void**)lua_topointer(L, 1)); + else if(luaT_iscdata(L, 1)) + udata = ((void**)lua_topointer(L, 1))[4]; else if(lua_isnumber(L, 1)) udata = (void*)(long)lua_tonumber(L, 1); else @@ -763,7 +826,21 @@ int luaT_lua_isequal(lua_State *L) int luaT_lua_pointer(lua_State *L) { - if(lua_isuserdata(L, 1)) + if(lua_type(L, 1) == 10) /* luajit cdata */ + { + /* we want the pointer holded by cdata */ + /* not the pointer on the cdata object */ + const void* ptr = *((void**)lua_topointer(L, 1)); + lua_pushnumber(L, (long)(ptr)); + return 1; + } + else if (luaT_iscdata(L, 1)) /* luaffi cdata */ + { + void** ptr = (void**)lua_touserdata(L, 1); + lua_pushnumber(L, (long)(ptr[4])); + return 1; + } + else if(lua_isuserdata(L, 1)) { void **ptr; luaL_argcheck(L, luaT_typename(L, 1), 1, "Torch object expected"); @@ -777,14 +854,6 @@ int luaT_lua_pointer(lua_State *L) lua_pushnumber(L, (long)(ptr)); return 1; } - else if(lua_type(L, 1) == 10) /* cdata */ - { - /* we want the pointer holded by cdata */ - /* not the pointer on the cdata object */ - const void* ptr = *((void**)lua_topointer(L, 1)); - lua_pushnumber(L, (long)(ptr)); - return 1; - } else if(lua_isstring(L, 1)) { const char* ptr = lua_tostring(L, 1); @@ -802,7 +871,7 @@ int luaT_lua_setenv(lua_State *L) if(!lua_isfunction(L, 1) && !lua_isuserdata(L, 1)) luaL_typerror(L, 1, "function or userdata"); luaL_checktype(L, 2, LUA_TTABLE); - lua_setfenv(L, 1); + lua_setuservalue(L, 1); return 0; } @@ -810,7 +879,9 @@ int luaT_lua_getenv(lua_State *L) { if(!lua_isfunction(L, 1) && !lua_isuserdata(L, 1)) luaL_typerror(L, 1, "function or userdata"); - lua_getfenv(L, 1); + lua_getuservalue(L, 1); + if (lua_isnil(L, -1)) + lua_newtable(L); return 1; } @@ -826,7 +897,7 @@ int luaT_lua_version(lua_State *L) { luaL_checkany(L, 1); - if(lua_type(L, 1) == 10) + if(luaT_iscdata(L, 1)) { const char *tname = luaT_cdataname(L, 1, NULL); if(tname) diff --git a/lib/luaT/luaT.h b/lib/luaT/luaT.h index 5e8dd2f6..4c25fc1c 100644 --- a/lib/luaT/luaT.h +++ b/lib/luaT/luaT.h @@ -32,6 +32,18 @@ extern "C" { # define LUAT_API LUA_EXTERNC #endif +#if LUA_VERSION_NUM == 501 +# define lua_pushglobaltable(L) lua_pushvalue(L, LUA_GLOBALSINDEX) +# define lua_setuservalue lua_setfenv +# define lua_getuservalue lua_getfenv +#else +# define lua_objlen lua_rawlen +static int luaL_typerror(lua_State *L, int narg, const char *tname) +{ + return luaL_error(L, "%s expected, got %s", tname, luaL_typename(L, narg)); +} +#endif + /* C functions */ @@ -39,6 +51,8 @@ LUAT_API void* luaT_alloc(lua_State *L, long size); LUAT_API void* luaT_realloc(lua_State *L, void *ptr, long size); LUAT_API void luaT_free(lua_State *L, void *ptr); +LUAT_API void luaT_setfuncs(lua_State *L, const luaL_Reg *l, int nup); + LUAT_API const char* luaT_newmetatable(lua_State *L, const char *tname, const char *parenttname, lua_CFunction constructor, lua_CFunction destructor, lua_CFunction factory); diff --git a/random.lua b/random.lua index 6740866a..a6f0c3da 100644 --- a/random.lua +++ b/random.lua @@ -46,7 +46,7 @@ void torch_random_init(lua_State *L) torch_Generator_init(L); torch_Generator_new(L); lua_setfield(L, -2, "_gen"); - luaL_register(L, NULL, random__); + luaT_setfuncs(L, random__, 0); } ]]) diff --git a/test/test.lua b/test/test.lua index 8eb68b92..a9fe7668 100644 --- a/test/test.lua +++ b/test/test.lua @@ -5,6 +5,10 @@ local torchtest = {} local msize = 100 local precision +-- Lua 5.2 compatibility +local loadstring = loadstring or load +local unpack = unpack or table.unpack + local function maxdiff(x,y) local d = x-y if x:type() == 'torch.DoubleTensor' or x:type() == 'torch.FloatTensor' then diff --git a/test/test_writeObject.lua b/test/test_writeObject.lua index 36be49a7..90392f95 100644 --- a/test/test_writeObject.lua +++ b/test/test_writeObject.lua @@ -28,6 +28,28 @@ function tests.test_can_write_a_nil_closure() myTester:assert(copyClosure() == closure(), 'the closures should give same output') end +function tests.test_nil_upvalues_in_closure() + local a = 1 + local b + local c = 2 + local function closure() + if not b then return c end + return a + end + + local copyClosure = serializeAndDeserialize(closure) + myTester:assert(copyClosure() == closure(), 'the closures should give same output') +end + +function tests.test_global_function_in_closure() + local x = "5" + local function closure(str) + return tonumber(str .. x) + end + + local copyClosure = serializeAndDeserialize(closure) + myTester:assert(copyClosure("3") == closure("3"), 'the closures should give same output') +end function tests.test_a_recursive_closure() local foo diff --git a/test/timeSort.lua b/test/timeSort.lua index 3d798979..baa0b7de 100644 --- a/test/timeSort.lua +++ b/test/timeSort.lua @@ -10,7 +10,8 @@ cmd:option('-r', 20, 'Number of repetitions') local options = cmd:parse(arg or {}) function main() - local pow10 = torch.linspace(1,math.log10(options.N), options.p) + local log10 = math.log10 or function(x) return math.log(x, 10) end + local pow10 = torch.linspace(1,log10(options.N), options.p) local num_sizes = options.p local num_reps = options.r @@ -128,11 +129,11 @@ function main() gnuplot.xlabel('N') gnuplot.ylabel('Speed-up Factor (s)') gnuplot.figprint('benchmarkRatio.png') - + torch.save('benchmark.t7', { new_rnd=new_rnd, new_srt=new_srt, - new_cst=new_cst, + new_cst=new_cst, old_rnd=old_rnd, old_srt=old_srt, old_cst=old_cst, diff --git a/utils.c b/utils.c index aa5603a3..acc77f7b 100644 --- a/utils.c +++ b/utils.c @@ -113,7 +113,7 @@ static int torch_lua_getdefaulttensortype(lua_State *L) const char* torch_getdefaulttensortype(lua_State *L) { - lua_getfield(L, LUA_GLOBALSINDEX, "torch"); + lua_getglobal(L, "torch"); if(lua_istable(L, -1)) { lua_getfield(L, -1, "Tensor"); @@ -214,5 +214,5 @@ static const struct luaL_Reg torch_utils__ [] = { void torch_utils_init(lua_State *L) { - luaL_register(L, NULL, torch_utils__); + luaT_setfuncs(L, torch_utils__, 0); }