From 9cf1907b8d74a102fb7456f8ef9103a7ac1f7582 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Thu, 30 Apr 2015 15:00:56 -0400 Subject: [PATCH] Fix serialization of closures in Lua 5.2 This fixes test_can_write_a_nil_closure in Lua 5.2. Added a test case that affected Lua 5.1 and Lua 5.2. Support serialization of closures that reference the environment. --- File.lua | 34 +++++++++++++++++++++------------- test/test_writeObject.lua | 22 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/File.lua b/File.lua index a7a6d62e..1b86171b 100644 --- a/File.lua +++ b/File.lua @@ -19,7 +19,8 @@ local TYPE_TABLE = 3 local TYPE_TORCH = 4 local TYPE_BOOLEAN = 5 local TYPE_FUNCTION = 6 -local TYPE_RECUR_FUNCTION = 7 +local TYPE_RECUR_FUNCTION = 8 +local LEGACY_TYPE_RECUR_FUNCTION = 7 -- Lua 5.2 compatibility local loadstring = loadstring or load @@ -141,7 +142,8 @@ function File:writeObject(object) counter = counter + 1 local name,value = debug.getupvalue(object, counter) if not name then break end - table.insert(upvalues, value) + if name == '_ENV' then value = nil end + table.insert(upvalues, {name=name, value=value}) end local dumped = string.dump(object) local stringStorage = torch.CharStorage():string(dumped) @@ -217,7 +219,7 @@ function File:readObject() debug.setupvalue(func, index, upvalue) end return func - elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or typeidx == TYPE_RECUR_FUNCTION then + elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then -- read the index local index = self:readInt() @@ -228,16 +230,22 @@ function File:readObject() end -- otherwise read it - if typeidx == TYPE_RECUR_FUNCTION then - local size = self:readInt() - local dumped = self:readChar(size):string() - local func = loadstring(dumped) - objects[index] = func - local upvalues = self:readObject() - for index,upvalue in ipairs(upvalues) do - debug.setupvalue(func, index, upvalue) - end - return func + if typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then + local size = self:readInt() + local dumped = self:readChar(size):string() + local func = loadstring(dumped) + objects[index] = func + local upvalues = self:readObject() + for index,upvalue in ipairs(upvalues) do + if typeidx == LEGACY_TYPE_RECUR_FUNCTION then + debug.setupvalue(func, index, upvalue) + elseif upvalue.name == '_ENV' then + debug.setupvalue(func, index, _ENV) + else + debug.setupvalue(func, index, upvalue.value) + end + end + return func elseif typeidx == TYPE_TORCH then local version, className, versionNumber version = self:readChar(self:readInt()):string() diff --git a/test/test_writeObject.lua b/test/test_writeObject.lua index 36be49a7..90392f95 100644 --- a/test/test_writeObject.lua +++ b/test/test_writeObject.lua @@ -28,6 +28,28 @@ function tests.test_can_write_a_nil_closure() myTester:assert(copyClosure() == closure(), 'the closures should give same output') end +function tests.test_nil_upvalues_in_closure() + local a = 1 + local b + local c = 2 + local function closure() + if not b then return c end + return a + end + + local copyClosure = serializeAndDeserialize(closure) + myTester:assert(copyClosure() == closure(), 'the closures should give same output') +end + +function tests.test_global_function_in_closure() + local x = "5" + local function closure(str) + return tonumber(str .. x) + end + + local copyClosure = serializeAndDeserialize(closure) + myTester:assert(copyClosure("3") == closure("3"), 'the closures should give same output') +end function tests.test_a_recursive_closure() local foo