Skip to content

Commit

Permalink
Add 'name' and 'mode' arguments to LuaRuntime.{eval,execute,compile} …
Browse files Browse the repository at this point in the history
…to allow finer control of allowed input and debug output. (#252)

These match the same arguments of the Lua `load()` function.

Closes #248
  • Loading branch information
scoder authored Dec 11, 2023
1 parent 042267d commit 27ff097
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 18 deletions.
75 changes: 60 additions & 15 deletions lupa/_lupa.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -393,36 +393,74 @@ cdef class LuaRuntime:
raise
return 0

def eval(self, lua_code, *args):
@cython.final
cdef bytes _source_encode(self, string):
if isinstance(string, unicode):
return (<unicode>string).encode(self._source_encoding)
elif isinstance(string, bytes):
return <bytes> string
elif isinstance(string, bytearray):
return bytes(string)

raise TypeError(f"Expected string, got {type(string)}")

def eval(self, lua_code, *args, name=None, mode=None):
"""Evaluate a Lua expression passed in a string.
The 'name' argument can be used to override the name printed in error messages.
The 'mode' argument specifies the input type. By default, both source code and
pre-compiled byte code is allowed (mode='bt'). It can be restricted to source
code with mode='t' and to byte code with mode='b'. This has no effect on Lua 5.1.
"""
assert self._state is not NULL
if isinstance(lua_code, unicode):
lua_code = (<unicode>lua_code).encode(self._source_encoding)
return run_lua(self, b'return ' + lua_code, args)
name_b = self._source_encode(name) if name is not None else None
mode_b = _asciiOrNone(mode)
return run_lua(self, b'return ' + self._source_encode(lua_code), name_b, mode_b, args)

def execute(self, lua_code, *args):
def execute(self, lua_code, *args, name=None, mode=None):
"""Execute a Lua program passed in a string.
The 'name' argument can be used to override the name printed in error messages.
The 'mode' argument specifies the input type. By default, both source code and
pre-compiled byte code is allowed (mode='bt'). It can be restricted to source
code with mode='t' and to byte code with mode='b'. This has no effect on Lua 5.1.
"""
assert self._state is not NULL
if isinstance(lua_code, unicode):
lua_code = (<unicode>lua_code).encode(self._source_encoding)
return run_lua(self, lua_code, args)
name_b = self._source_encode(name) if name is not None else None
mode_b = _asciiOrNone(mode)
return run_lua(self, self._source_encode(lua_code), name_b, mode_b, args)

def compile(self, lua_code):
def compile(self, lua_code, name=None, mode=None):
"""Compile a Lua program into a callable Lua function.
The 'name' argument can be used to override the name printed in error messages.
The 'mode' argument specifies the input type. By default, both source code and
pre-compiled byte code is allowed (mode='bt'). It can be restricted to source
code with mode='t' and to byte code with mode='b'. This has no effect on Lua 5.1.
"""
assert self._state is not NULL
cdef const char *err
if isinstance(lua_code, unicode):
lua_code = (<unicode>lua_code).encode(self._source_encoding)
cdef const char * c_name = b'<python>'
cdef const char * c_mode = NULL

lua_code_bytes = self._source_encode(lua_code)
if name is not None:
name_b = self._source_encode(name)
c_name = name_b
if mode is not None:
mode_b = _asciiOrNone(mode)
c_mode = mode_b

L = self._state
lock_runtime(self)
old_top = lua.lua_gettop(L)
cdef size_t size
cdef const char *err
try:
check_lua_stack(L, 1)
status = lua.luaL_loadbuffer(L, lua_code, len(lua_code), b'<python>')
status = lua.luaL_loadbufferx(L, lua_code_bytes, len(lua_code_bytes), c_name, c_mode)
if status == 0:
return py_from_lua(self, L, -1)
else:
Expand Down Expand Up @@ -1719,14 +1757,21 @@ cdef build_lua_error_message(LuaRuntime runtime, lua_State* L, int stack_index=-

# calling into Lua

cdef run_lua(LuaRuntime runtime, bytes lua_code, tuple args):
cdef run_lua(LuaRuntime runtime, bytes lua_code, bytes name, bytes mode, tuple args):
"""Run Lua code with arguments"""
cdef lua_State* L = runtime._state
cdef const char* c_name = b'<python>'
cdef const char* c_mode = NULL
if name is not None:
c_name = name
if mode is not None:
c_mode = mode

lock_runtime(runtime)
old_top = lua.lua_gettop(L)
try:
check_lua_stack(L, 1)
if lua.luaL_loadbuffer(L, lua_code, len(lua_code), '<python>'):
if lua.luaL_loadbufferx(L, lua_code, len(lua_code), c_name, c_mode):
error = build_lua_error_message(runtime, L)
if error.startswith("not enough memory"):
raise LuaMemoryError(error)
Expand Down
8 changes: 5 additions & 3 deletions lupa/luaapi.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,10 @@ cdef extern from "lauxlib.h" nogil:
int luaL_ref (lua_State *L, int t)
void luaL_unref (lua_State *L, int t, int ref)

int luaL_loadfile (lua_State *L, char *filename)
int luaL_loadbuffer (lua_State *L, char *buff, size_t sz, char *name)
int luaL_loadstring (lua_State *L, char *s)
int luaL_loadfile (lua_State *L, const char *filename)
int luaL_loadbuffer (lua_State *L, const char *buff, size_t sz, const char *name)
int luaL_loadbufferx (lua_State *L, const char *buff, size_t sz, const char *name, const char *mode)
int luaL_loadstring (lua_State *L, const char *s)

lua_State *luaL_newstate ()

Expand Down Expand Up @@ -450,6 +451,7 @@ cdef extern from * nogil:
#if LUA_VERSION_NUM < 502
#define lua_tointegerx(L, i, isnum) (*(isnum) = lua_isnumber(L, i), lua_tointeger(L, i))
#define luaL_loadbufferx(L, buff, sz, name, mode) (((void)mode), luaL_loadbuffer(L, buff, sz, name))
#endif
#if LUA_VERSION_NUM >= 504
Expand Down
22 changes: 22 additions & 0 deletions lupa/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ def test_eval_args(self):
def test_eval_args_multi(self):
self.assertEqual((1, 2, 3), self.lua.eval('...', 1, 2, 3))

def test_eval_name_mode(self):
self.assertEqual(2, self.lua.eval('1+1', name='plus', mode='t'))

def test_eval_mode_error(self):
if self.lupa.LUA_VERSION < (5, 2):
raise unittest.SkipTest("needs lua 5.2+")
self.assertRaises(self.lupa.LuaSyntaxError, self.lua.eval, '1+1', name='plus', mode='b')

def test_eval_error(self):
self.assertRaises(self.lupa.LuaError, self.lua.eval, '<INVALIDCODE>')

Expand Down Expand Up @@ -156,6 +164,14 @@ def test_eval_error_message_decoding(self):
def test_execute(self):
self.assertEqual(2, self.lua.execute('return 1+1'))

def test_execute_mode(self):
self.assertEqual(2, self.lua.execute('return 1+1', name='return_plus', mode='t'))

def test_execute_mode_error(self):
if self.lupa.LUA_VERSION < (5, 2):
raise unittest.SkipTest("needs lua 5.2+")
self.assertRaises(self.lupa.LuaSyntaxError, self.lua.execute, 'return 1+1', name='plus', mode='b')

def test_execute_function(self):
self.assertEqual(3, self.lua.execute('f = function(i) return i+1 end; return f(2)'))

Expand Down Expand Up @@ -919,6 +935,12 @@ def f(*args, **kwargs):
def test_compile(self):
lua_func = self.lua.compile('return 1 + 2')
self.assertEqual(lua_func(), 3)
lua_func = self.lua.compile('return 3 + 2', mode='t')
self.assertEqual(lua_func(), 5)
lua_func = self.lua.compile('return 1 + 3', name='huhu')
self.assertEqual(lua_func(), 4)
lua_func = self.lua.compile('return 2 + 3', name='huhu', mode='t')
self.assertEqual(lua_func(), 5)
self.assertRaises(self.lupa.LuaSyntaxError, self.lua.compile, 'function awd()')


Expand Down

0 comments on commit 27ff097

Please sign in to comment.