diff --git a/data/expression2/tests/compiler/compiler/restrictions/fn_void_param.txt b/data/expression2/tests/compiler/compiler/restrictions/fn_void_param.txt new file mode 100644 index 0000000000..5128b062e6 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/fn_void_param.txt @@ -0,0 +1,3 @@ +## SHOULD_FAIL:COMPILE + +function test(X:void) {} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/implicit_param.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/implicit_param.txt new file mode 100644 index 0000000000..7cc9e62e44 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/implicit_param.txt @@ -0,0 +1,4 @@ +## SHOULD_FAIL:COMPILE + +# Implicit number fallback is not going to be allowed. +const J = function(X) {} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/return_codepaths.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/return_codepaths.txt new file mode 100644 index 0000000000..0c50cbf810 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/return_codepaths.txt @@ -0,0 +1,6 @@ +## SHOULD_FAIL:COMPILE + +const X = function() { + if (1) { return "str" } + # doesn't return string +} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/return_type_mix.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/return_type_mix.txt new file mode 100644 index 0000000000..6a63dee2bb --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/return_type_mix.txt @@ -0,0 +1,6 @@ +## SHOULD_FAIL:COMPILE + +const X = function() { + if (1) { return "str" } + return 22 +} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/variadic_param.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/variadic_param.txt new file mode 100644 index 0000000000..4fb88eca1d --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/variadic_param.txt @@ -0,0 +1,4 @@ +## SHOULD_FAIL:COMPILE + +const X = function(...A:array) {} +const Y = function(...A:table) {} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/void_param.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/void_param.txt new file mode 100644 index 0000000000..bf686d24c5 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/void_param.txt @@ -0,0 +1,3 @@ +## SHOULD_FAIL:COMPILE + +const X = function(X:void) {} \ No newline at end of file diff --git a/data/expression2/tests/runtime/base/lambdas.txt b/data/expression2/tests/runtime/base/lambdas.txt new file mode 100644 index 0000000000..8e35132c92 --- /dev/null +++ b/data/expression2/tests/runtime/base/lambdas.txt @@ -0,0 +1,87 @@ +## SHOULD_PASS:EXECUTE + +# Returns + +assert( (function() { return 55 })()[number] == 55 ) +assert( (function() { return "str" })()[string] == "str" ) + +# Upvalues + +const Wrapper = function(V:number) { + return function() { + return V + } +} + +const F1 = Wrapper(55)[function] + +if (1) { + if (2) { + local V = 22 + assert(F1()[number] == 55) + } +} + +assert(F1()[number] == 55) +assert(F1()[number] == 55) + +const F2 = Wrapper(1238)[function] + +assert(F2()[number] == 1238) +#local V = 21 +assert(F2()[number] == 1238) + +const IsEven = function(N:number) { + return N % 2 == 0 +} + +const Not = function(N:number) { + return !N +} + +const IsOdd = function(N:number) { + return Not(IsEven(N)[number])[number] +} + +assert(IsOdd(1)[number] == 1) +assert(IsOdd(2)[number] == 0) + +assert( ((function() { return function() { return 55 } })()[function])()[number] == 55 ) + +const Identity = function(N:number) { + return N +} + +assert(Identity(2)[number] == 2) +assert(Identity(2193921)[number] == 2193921) + +local SayMessage = function() {} + +const SetMessage = function(Message:string) { + SayMessage = function() { + return Message + } +} + +SetMessage("There's a snake in my boot!") + +assert( SayMessage()[string] == "There's a snake in my boot!" ) +assert( SayMessage()[string] == "There's a snake in my boot!" ) + +SetMessage("Reach for the sky!") + +assert( SayMessage()[string] == "Reach for the sky!" ) + +const EarlyReturn = function() { + return +} + +Ran = 0 + +function wrapper() { + EarlyReturn() + Ran = 1 +} + +wrapper() +assert(Ran) \ No newline at end of file diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index 624b0de24d..3bbf959e4a 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -225,7 +225,7 @@ local CompileVisitors = { local i = #stmts + 1 stmts[i], traces[i] = stmt, trace - if node:isExpr() and node.variant ~= NodeVariant.ExprStringCall and node.variant ~= NodeVariant.ExprCall and node.variant ~= NodeVariant.ExprMethodCall then + if node:isExpr() and node.variant ~= NodeVariant.ExprDynCall and node.variant ~= NodeVariant.ExprCall and node.variant ~= NodeVariant.ExprMethodCall then self:Warning("This expression has no effect", node.trace) end end @@ -650,7 +650,7 @@ local CompileVisitors = { local existing = {} for i, param in ipairs(data[4]) do if param.type then - local t = self:CheckType(param.type) + local t = self:Assert(self:CheckType(param.type), "Cannot use void as parameter type", param.name.trace) if param.variadic then self:Assert(t == "r" or t == "t", "Variadic parameter must be of type array or table", param.type.trace) variadic_ind, variadic_ty = i, t @@ -684,9 +684,9 @@ local CompileVisitors = { end else if return_type then - self:Assert(fn_data.returns and fn_data.returns[1] == return_type, "Cannot override with differing return type", trace) + self:Assert(fn_data.ret == return_type, "Cannot override with differing return type", trace) else - self:Assert(fn_data.returns == nil, "Cannot override function returning void with differing return type", trace) + self:Assert(fn_data.ret == nil, "Cannot override function returning void with differing return type", trace) end if not self.strict then @@ -697,7 +697,7 @@ local CompileVisitors = { end end - local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 10 or 5 + (self.strict and 0 or 3), attrs = {} } + local fn = { args = param_types, ret = return_type, meta = meta_type, cost = variadic_ty and 10 or 5 + (self.strict and 0 or 3), attrs = {} } local sig = table.concat(param_types, "", 1, #param_types - 1) .. ((variadic_ty and ".." or "") .. (param_types[#param_types] or "")) if meta_type then @@ -824,7 +824,7 @@ local CompileVisitors = { end end) - self:Assert((fn.returns and fn.returns[1]) == return_type, "Function " .. name.value .. " expects to return type (" .. (return_type or "void") .. ") but got type (" .. ((fn.returns and fn.returns[1]) or "void") .. ")", trace) + self:Assert(fn.ret == return_type, "Function " .. name.value .. " expects to return type (" .. (return_type or "void") .. ") but got type (" .. (fn.ret or "void") .. ")", trace) local sig = name.value .. "(" .. (meta_type and (meta_type .. ":") or "") .. sig .. ")" local fn = fn.op @@ -920,10 +920,10 @@ local CompileVisitors = { local name, fn = fn[1], fn[2] - if fn.returns then - self:Assert(fn.returns[1] == ret_ty, "Function " .. name .. " expects return type (" .. (fn.returns[1] or "void") .. ") but was given (" .. (ret_ty or "void") .. ")", trace) + if fn.ret then + self:Assert(fn.ret == ret_ty, "Function " .. name .. " expects return type (" .. (fn.ret or "void") .. ") but was given (" .. (ret_ty or "void") .. ")", trace) else - fn.returns = { ret_ty } + fn.ret = ret_ty end if ret_ty then @@ -1319,6 +1319,68 @@ local CompileVisitors = { end end, + ---@param data { [1]: Parameter[], [2]: Node } + [NodeVariant.ExprFunction] = function(self, trace, data) + ---@type EnvFunction + local fn, param_names, param_types, nargs = { attrs = {} }, {}, {}, #data[1] + + local block = self:Scope(function(scope) + scope.data["function"] = { "", fn } + + for i, param in ipairs(data[1]) do + self:Assert(param.type, "Cannot omit parameter type for lambda, annotate with :", param.name.trace) + param_names[i], param_types[i] = param.name.value, self:Assert(self:CheckType(param.type), "Cannot use void as parameter", param.name.trace) + self:Assert(not param.variadic, "Variadic lambdas are not supported, use an array instead", param.name.trace) + scope:DeclVar(param.name.value, { type = param_types[i], initialized = true, trace_if_unused = param.name.trace }) + end + + local block = self:CompileStmt(data[2]) + + if fn.ret then -- Ensure function either returns or errors + self:Assert(scope.data.dead, "Not all codepaths return a value of type '" .. fn.ret .. "'", trace) + end + + return block + end) + + local ret = fn.ret + local expected_sig = table.concat(param_types) + + self.scope.data.ops = self.scope.data.ops + 25 + + return function(state) + local inherited_scopes, after = {}, state.ScopeID + 1 + for i = 0, state.ScopeID do + inherited_scopes[i] = state.Scopes[i] + end + + return E2Lib.Lambda.new( + expected_sig, + ret, + function(args) + local s_scopes, s_scope, s_scopeid = state.Scopes, state.Scope, state.ScopeID + + local scope = { vclk = {} } + state.Scopes = inherited_scopes + state.ScopeID = after + state.Scopes[after] = scope + state.Scope = scope + + for i = 1, nargs do + scope[param_names[i]] = args[i] + end + + block(state) + + state.ScopeID, state.Scope, state.Scopes = s_scopeid, s_scope, s_scopes + + state.__return__ = false + return state.__returnval__ + end + ) + end, "f" + end, + [NodeVariant.ExprArithmetic] = handleInfixOperation, ---@param data { [1]: Node, [2]: Operator, [3]: self } @@ -1509,7 +1571,7 @@ local CompileVisitors = { rargs[k] = args[k](state) end return fn(state, rargs, types) - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) else self.scope.data.ops = self.scope.data.ops + (fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0) @@ -1526,7 +1588,7 @@ local CompileVisitors = { else state:forceThrow("No such function defined at runtime: " .. full_sig) end - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) end elseif fn_data.attrs["legacy"] then -- Not a user function. Can get function to call at compile time. local fn, largs = fn_data.op, { [1] = {}, [nargs + 2] = types } @@ -1535,7 +1597,7 @@ local CompileVisitors = { end return function(state) ---@param state RuntimeContext return fn(state, largs) - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) else local fn = fn_data.op return function(state) ---@param state RuntimeContext @@ -1545,7 +1607,7 @@ local CompileVisitors = { end return fn(state, rargs, types) - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) end end, @@ -1579,7 +1641,7 @@ local CompileVisitors = { rargs[k + 1] = args[k](state) end return fn(state, rargs, types) - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) else local full_sig = name.value .. "(" .. meta_type .. ":" .. arg_sig .. ")" return function(state) ---@param state RuntimeContext @@ -1594,7 +1656,7 @@ local CompileVisitors = { else state:forceThrow("No such method defined at runtime: " .. full_sig) end - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) end elseif fn_data.attrs["legacy"] then local fn, largs = fn_data.op, { [nargs + 3] = types, [2] = { [1] = meta } } @@ -1604,7 +1666,7 @@ local CompileVisitors = { return function(state) ---@param state RuntimeContext return fn(state, largs) - end, fn_data.returns and fn_data.returns[1] + end, fn_data.ret else local fn = fn_data.op return function(state) ---@param state RuntimeContext @@ -1614,117 +1676,146 @@ local CompileVisitors = { end return fn(state, rargs, types) - end, fn_data.returns and fn_data.returns[1] + end, fn_data.ret end end, ---@param data { [1]: Node, [2]: Node[], [3]: Token? } - [NodeVariant.ExprStringCall] = function (self, trace, data) - local expr = self:CompileExpr(data[1]) + [NodeVariant.ExprDynCall] = function (self, trace, data) + local expr, expr_ty = self:CompileExpr(data[1]) local args, arg_types = {}, {} for i, arg in ipairs(data[2]) do args[i], arg_types[i] = self:CompileExpr(arg) end - local type_sig = table.concat(arg_types) - local arg_sig = "(" .. type_sig .. ")" - local meta_arg_sig = #arg_types >= 1 and ("(" .. arg_types[1] .. ":" .. table.concat(arg_types, "", 2) .. ")") or "()" - local ret_type = data[3] and self:CheckType(data[3]) - local nargs = #args - return function(state) ---@param state RuntimeContext - local rargs = {} - for k = 1, nargs do - rargs[k] = args[k](state) - end + if expr_ty == "s" then + self:Warning("String calls are deprecated. Use lambdas instead. This will be an error on @strict in the future.", trace) + self.scope.data.ops = self.scope.data.ops + 25 - local fn_name = expr(state) - local sig, meta_sig = fn_name .. arg_sig, fn_name .. meta_arg_sig + local type_sig = table.concat(arg_types) + local arg_sig = "(" .. type_sig .. ")" + local meta_arg_sig = #arg_types >= 1 and ("(" .. arg_types[1] .. ":" .. table.concat(arg_types, "", 2) .. ")") or "()" - local fn = state.funcs[sig] or state.funcs[meta_sig] - if fn then -- first check if user defined any functions that match signature - local r = state.funcs_ret[sig] or state.funcs_ret[meta_sig] - if r ~= ret_type then - state:forceThrow( "Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) + local nargs = #args + return function(state) ---@param state RuntimeContext + local rargs = {} + for k = 1, nargs do + rargs[k] = args[k](state) end - return fn(state, rargs, arg_types) - else -- no user defined functions, check builtins - fn = wire_expression2_funcs[sig] or wire_expression2_funcs[meta_sig] - if fn then - local r = fn[2] - if r ~= ret_type and not (ret_type == nil and r == "") then + local fn_name = expr(state) + local sig, meta_sig = fn_name .. arg_sig, fn_name .. meta_arg_sig + + local fn = state.funcs[sig] or state.funcs[meta_sig] + if fn then -- first check if user defined any functions that match signature + local r = state.funcs_ret[sig] or state.funcs_ret[meta_sig] + if r ~= ret_type then state:forceThrow( "Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) end - if fn.attributes.legacy then - local largs = { [1] = {}, [nargs + 2] = arg_types } - for i = 1, nargs do - largs[i + 1] = { [1] = function() return rargs[i] end } + return fn(state, rargs, arg_types) + else -- no user defined functions, check builtins + fn = wire_expression2_funcs[sig] or wire_expression2_funcs[meta_sig] + if fn then + local r = fn[2] + if r ~= ret_type and not (ret_type == nil and r == "") then + state:forceThrow( "Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) end - return fn[3](state, largs, arg_types) - else - return fn[3](state, rargs, arg_types) - end - else -- none found, check variadic builtins - for i = nargs, 0, -1 do - local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "...)" - local fn = wire_expression2_funcs[varsig] - if fn then - local r = fn[2] - if r ~= ret_type and not (ret_type == nil and r == "") then - state:forceThrow("Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) - end - if fn.attributes.legacy then - local largs = { [1] = {}, [nargs + 2] = arg_types } - for i = 1, nargs do - largs[i + 1] = { [1] = function() return rargs[i] end } - end - return fn[3](state, largs, arg_types) - elseif varsig == "array(...)" then -- Need this since can't enforce compile time argument type restrictions on string calls. Woop. Array creation should not be a function.. - local i = 1 - while i <= #arg_types do - local ty = arg_types[i] - if BLOCKED_ARRAY_TYPES[ty] then - table.remove(rargs, i) - table.remove(arg_types, i) - state:forceThrow("Cannot use type " .. ty .. " for argument #" .. i .. " in stringcall array creation") - else - i = i + 1 - end - end + if fn.attributes.legacy then + local largs = { [1] = {}, [nargs + 2] = arg_types } + for i = 1, nargs do + largs[i + 1] = { [1] = function() return rargs[i] end } end - - return fn[3](state, rargs, arg_types) + return fn[3](state, largs, arg_types) else - local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "..r)" - local fn = state.funcs[varsig] - + return fn[3](state, rargs, arg_types) + end + else -- none found, check variadic builtins + for i = nargs, 0, -1 do + local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "...)" + local fn = wire_expression2_funcs[varsig] if fn then - for _, ty in ipairs(arg_types) do -- Just block them entirely. Current method of finding variadics wouldn't allow a proper solution that works with x types. Would need to rewrite all of this which I don't think is worth it when already nobody is going to use this functionality. - if BLOCKED_ARRAY_TYPES[ty] then - state:forceThrow("Cannot pass array into variadic array function") + local r = fn[2] + if r ~= ret_type and not (ret_type == nil and r == "") then + state:forceThrow("Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) + end + + if fn.attributes.legacy then + local largs = { [1] = {}, [nargs + 2] = arg_types } + for i = 1, nargs do + largs[i + 1] = { [1] = function() return rargs[i] end } + end + return fn[3](state, largs, arg_types) + elseif varsig == "array(...)" then -- Need this since can't enforce compile time argument type restrictions on string calls. Woop. Array creation should not be a function.. + local i = 1 + while i <= #arg_types do + local ty = arg_types[i] + if BLOCKED_ARRAY_TYPES[ty] then + table.remove(rargs, i) + table.remove(arg_types, i) + state:forceThrow("Cannot use type " .. ty .. " for argument #" .. i .. " in stringcall array creation") + else + i = i + 1 + end end end - return fn(state, rargs, arg_types) + return fn[3](state, rargs, arg_types) else - local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "..t)" + local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "..r)" local fn = state.funcs[varsig] + if fn then + for _, ty in ipairs(arg_types) do -- Just block them entirely. Current method of finding variadics wouldn't allow a proper solution that works with x types. Would need to rewrite all of this which I don't think is worth it when already nobody is going to use this functionality. + if BLOCKED_ARRAY_TYPES[ty] then + state:forceThrow("Cannot pass array into variadic array function") + end + end + return fn(state, rargs, arg_types) + else + local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "..t)" + local fn = state.funcs[varsig] + if fn then + return fn(state, rargs, arg_types) + end end end end + + state:forceThrow("No such function: " .. fn_name .. arg_sig) end + end + end, ret_type + elseif expr_ty == "f" then + self.scope.data.ops = self.scope.data.ops + 15 -- Since functions are 10 ops, this is pretty lenient. I will decrease this slightly when functions are made static and cheaper. + + local nargs = #args + local sig = table.concat(arg_types) + + return function(state) + ---@type E2Lambda + local f = expr(state) - state:forceThrow("No such function: " .. fn_name .. arg_sig) + if f.arg_sig ~= sig then + state:forceThrow("Incorrect arguments passed to lambda, expected (" .. f.arg_sig .. ") got (" .. sig .. ")") + elseif f.ret ~= ret_type then + state:forceThrow("Expected type " .. (ret_type or "void") .. " from lambda, got " .. (f.ret or "void")) + else + local rargs = {} + for k = 1, nargs do + rargs[k] = args[k](state) + end + return f.fn(rargs) end - end - end, ret_type + end, ret_type + else + self:Error("Cannot call type of " .. expr_ty, trace) + end end, ---@param data { [1]: Token, [2]: Parameter[], [3]: Node } @@ -1873,14 +1964,14 @@ function Compiler:GetFunction(name, types, method) local sig, method_prefix = table.concat(types), method and (method .. ":") or "" local fn = wire_expression2_funcs[name .. "(" .. method_prefix .. sig .. ")"] - if fn then return { op = fn[3], returns = { fn[2] }, args = types, cost = fn[4], attrs = fn.attributes }, false, false end + if fn then return { op = fn[3], ret = fn[2], args = types, cost = fn[4], attrs = fn.attributes }, false, false end local fn, variadic = self:GetUserFunction(name, types, method) if fn then return fn, variadic, true end for i = #sig, 0, -1 do fn = wire_expression2_funcs[name .. "(" .. method_prefix .. sig:sub(1, i) .. "...)"] - if fn then return { op = fn[3], returns = { fn[2] }, args = types, cost = fn[4], attrs = fn.attributes }, true, false end + if fn then return { op = fn[3], ret = fn[2], args = types, cost = fn[4], attrs = fn.attributes }, true, false end end end @@ -1904,15 +1995,15 @@ end ---@return RuntimeOperator function Compiler:Process(ast) for var, type in pairs(self.persist[3]) do - self.scope:DeclVar(var, { initialized = false, trace_if_unused = self.persist[5][var], type = type }) + self.global_scope:DeclVar(var, { initialized = false, trace_if_unused = self.persist[5][var], type = type }) end for var, type in pairs(self.inputs[3]) do - self.scope:DeclVar(var, { initialized = true, trace_if_unused = self.inputs[5][var], type = type }) + self.global_scope:DeclVar(var, { initialized = true, trace_if_unused = self.inputs[5][var], type = type }) end for var, type in pairs(self.outputs[3]) do - self.scope:DeclVar(var, { initialized = false, type = type }) + self.global_scope:DeclVar(var, { initialized = false, type = type }) end return self:CompileStmt(ast) diff --git a/lua/entities/gmod_wire_expression2/base/parser.lua b/lua/entities/gmod_wire_expression2/base/parser.lua index 3e041e3a77..b695a20895 100644 --- a/lua/entities/gmod_wire_expression2/base/parser.lua +++ b/lua/entities/gmod_wire_expression2/base/parser.lua @@ -94,12 +94,13 @@ local NodeVariant = { ExprIndex = 29, -- `[, ?]` ExprGrouped = 30, -- () ExprCall = 31, -- `call()` - ExprStringCall = 32, -- `""()` (Temporary until lambdas are made) + ExprDynCall = 32, -- `Var()` ExprUnaryWire = 33, -- `~Var` `$Var` `->Var` ExprArray = 34, -- `array(1, 2, 3)` or `array(1 = 2, 2 = 3)` ExprTable = 35, -- `table(1, 2, 3)` or `table(1 = 2, "test" = 3)` - ExprLiteral = 36, -- `"test"` `5e2` `4.023` `4j` - ExprIdent = 37 -- `Variable` + ExprFunction = 36, -- `function() {}` + ExprLiteral = 37, -- `"test"` `5e2` `4.023` `4j` + ExprIdent = 38 -- `Variable` } Parser.Variant = NodeVariant @@ -533,8 +534,16 @@ end ---@return Token? function Parser:Type() local type = self:Consume(TokenVariant.LowerIdent) - if type and type.value == "normal" then - type.value = "number" + if type then + if type.value == "normal" then + type.value = "number" + end + else -- workaround to allow "function" as type while also being a keyword + local fn = self:Consume(TokenVariant.Keyword, Keyword.Function) + if fn then + fn.value, fn.variant = "function", TokenVariant.LowerIdent + return fn + end end return type end @@ -885,7 +894,7 @@ function Parser:Expr14() end end - return Node.new(NodeVariant.ExprStringCall, { expr, args, typ }, expr.trace:stitch(self:Prev().trace)) + return Node.new(NodeVariant.ExprDynCall, { expr, args, typ }, expr.trace:stitch(self:Prev().trace)) else break end @@ -915,6 +924,11 @@ function Parser:Expr15() return Node.new(NodeVariant.ExprCall, { fn, self:Arguments() }, fn.trace:stitch(self:Prev().trace)) end + local fn = self:Consume(TokenVariant.Keyword, Keyword.Function) + if fn then + return Node.new(NodeVariant.ExprFunction, { self:Parameters(), self:Assert(self:Block(), "Expected block to follow function") }, fn.trace:stitch(self:Prev().trace)) + end + -- Decimal / Hexadecimal / Binary numbers local num = self:Consume(TokenVariant.Decimal) or self:Consume(TokenVariant.Hexadecimal) or self:Consume(TokenVariant.Binary) if num then diff --git a/lua/entities/gmod_wire_expression2/core/e2lib.lua b/lua/entities/gmod_wire_expression2/core/e2lib.lua index de809a1951..25053fe1b0 100644 --- a/lua/entities/gmod_wire_expression2/core/e2lib.lua +++ b/lua/entities/gmod_wire_expression2/core/e2lib.lua @@ -15,7 +15,7 @@ AddCSLuaFile() ---@class EnvOperator ---@field args TypeSignature[] ----@field returns TypeSignature[] +---@field ret TypeSignature? ---@field op RuntimeOperator ---@field cost integer @@ -70,6 +70,53 @@ function E2Lib.newE2Table() return { n = {}, ntypes = {}, s = {}, stypes = {}, size = 0 } end +---@class E2Lambda +---@field fn fun(args: any[]): any +---@field arg_sig string +---@field ret string +local Function = {} +Function.__index = Function + +function Function.new(args, ret, fn) + return setmetatable({ arg_sig = args, ret = ret, fn = fn }, Function) +end + +E2Lib.Lambda = Function + +--- Call the function without doing any type checking. +--- Only use this when you check self:Args() yourself to ensure you have the correct signature function. +function Function:UnsafeCall(args) + return self.fn(args) +end + +function Function:Call(args, types) + if self.arg_sig == types then + return self.fn(args) + else + error("Incorrect arguments passed to lambda") + end +end + +function Function:Args() + return self.arg_sig +end + +function Function:Ret() + return self.ret +end + +--- If given the correct arguments, returns the inner untyped function you can call. +--- Otherwise, throws an error to the given E2 Context. +---@param arg_sig string +---@param ctx RuntimeContext +function Function:Unwrap(arg_sig, ctx) + if self.arg_sig == arg_sig then + return self.fn + else + ctx:forceThrow("Incorrect function signature passed, expected (" .. arg_sig .. ") got (" .. self.arg_sig .. ")") + end +end + -- Returns a cloned table of the variable given if it is a table. -- TODO: Ditch this system for instead having users provide a function that returns the default value. -- Would be much more efficient and avoid type checks. diff --git a/lua/entities/gmod_wire_expression2/core/functions.lua b/lua/entities/gmod_wire_expression2/core/functions.lua index 217e40dfa4..6b9b2de8c5 100644 --- a/lua/entities/gmod_wire_expression2/core/functions.lua +++ b/lua/entities/gmod_wire_expression2/core/functions.lua @@ -1,30 +1,47 @@ ---[[============================================================ - E2 Function System - By Rusketh - General Operators -============================================================]]-- +--[[ + Lambdas for Expression 2 + Format: fun(args: any[], sig: string): ret_ty string?, ret any + Format: { arg_sig: string, ret: string, fn: fun(args: any[]): any } +]] + +registerType("function", "f", nil, + function(self) self.entity:Error("You may not input a function") end, + function(self) self.entity:Error("You may not output a function") end, + nil, + function(v) + return not istable(v) or getmetatable(v) ~= E2Lib.Lambda + end +) __e2setcost(1) -registerOperator("function", "", "", function(self, args) - local sig, body = args[2], args[3] - self.funcs[sig] = body +e2function number operator_is(function f) + return f and 1 or 0 +end - local cached = self.strfunc_cache[1][sig] - if cached then - self.strfunc_cache[2][ cached[3] ] = nil - self.strfunc_cache[1][sig] = nil +local function splitTypeFast(sig) + local i, r, count, len = 1, {}, 0, #sig + while i <= len do + count = count + 1 + if string.sub(sig, i, i) == "x" then + r[count] = string.sub(sig, i, i + 2) + i = i + 3 + else + r[count] = string.sub(sig, i, i) + i = i + 1 + end end -end) + return r +end -__e2setcost(2) +__e2setcost(5) -registerOperator("return", "", "", function(self, args) - if args[2] then - local op = args[2] - local rv = op[1](self, op) - self.func_rv = rv - end +e2function array function:getParameterTypes() + return splitTypeFast(this.arg_sig) +end + +__e2setcost(1) - error("return",0) -end) +e2function string function:getReturnType() + return this.ret or "" +end \ No newline at end of file diff --git a/lua/entities/gmod_wire_expression2/init.lua b/lua/entities/gmod_wire_expression2/init.lua index 266aca895b..b51dedda50 100644 --- a/lua/entities/gmod_wire_expression2/init.lua +++ b/lua/entities/gmod_wire_expression2/init.lua @@ -166,8 +166,10 @@ function ENT:Execute() end self.GlobalScope.vclk = {} - for k, var in pairs(self.globvars_mut) do - self.GlobalScope[k] = fixDefault(wire_expression_types2[var.type][2]) + if not self.directives.strict then + for k, var in pairs(self.globvars_mut) do + self.GlobalScope[k] = fixDefault(wire_expression_types2[var.type][2]) + end end if self.context.prfcount + self.context.prf - e2_softquota > e2_hardquota then @@ -457,8 +459,10 @@ function ENT:ResetContext() self.globvars_mut[k] = nil end - for k, var in pairs(self.globvars_mut) do - self.GlobalScope[k] = fixDefault(wire_expression_types2[var.type][2]) + if not self.directives.strict then -- Need to disable this so local variables at top scope don't get reset + for k, var in pairs(self.globvars_mut) do + self.GlobalScope[k] = fixDefault(wire_expression_types2[var.type][2]) + end end for k, v in pairs(self.Inputs) do diff --git a/lua/wire/client/text_editor/modes/e2.lua b/lua/wire/client/text_editor/modes/e2.lua index 20cd0da029..3bb47c49d1 100644 --- a/lua/wire/client/text_editor/modes/e2.lua +++ b/lua/wire/client/text_editor/modes/e2.lua @@ -21,12 +21,12 @@ local keywords = { ["case"] = { [true] = true, [false] = true }, ["default"] = { [true] = true, [false] = true }, ["catch"] = { [true] = true, [false] = true }, + ["function"] = { [true] = true, [false] = true }, -- keywords that cannot be followed by a "(": ["else"] = { [true] = true }, ["break"] = { [true] = true }, ["continue"] = { [true] = true }, - ["function"] = { [true] = true }, ["return"] = { [true] = true }, ["local"] = { [true] = true }, ["let"] = { [true] = true }, @@ -590,7 +590,7 @@ function EDITOR:SyntaxColorLine(row) tokenname = istype(sstr) and "typename" or "notfound" elseif keywords[sstr][keyword] then tokenname = "keyword" - if sstr == "foreach" then + if sstr == "foreach" or sstr == "function" then highlightmode = 3 elseif sstr == "return" and self:NextPattern( "void" ) then addToken( "keyword", "return" )