Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make functions a compile time construct on @strict #2789

Merged
merged 5 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
## SHOULD_FAIL:COMPILE

@strict

function test() {}

function test() {} # ERROR!
117 changes: 117 additions & 0 deletions data/expression2/tests/runtime/base/userfunctions/functions_const.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
## SHOULD_PASS:EXECUTE

@strict

# Ensure functions get called in the first place

Called = 0
function myfunction() {
Called = 1
}

myfunction()

assert(Called)


local X = 500
local Y = 1000
local Z = 5000

# Ensure function scoping doesn't affect outer scope

function test(X, Y, Z) {
assert(X == 1)
assert(Y == 2)
assert(Z == 3)
}

test(1, 2, 3)

assert(X == 500)
assert(Y == 1000)
assert(Z == 5000)

# Ensure functions return properly

function number returning() {
return 5
}

assert(returning() == 5)

function number returning2(X:array) {
return X[1, number] + 5
}

assert(returning2(array(5)) == 10)
assert(returning2(array()) == 5)

function array returningref(X:array) {
return X
}

local A = array()
assert(returningref(A):id() == A:id())

function returnvoid() {
if (1) { return }
error("unreachable")
}

returnvoid()

function void returnvoid2() {
return
}

returnvoid2()

function returnvoid3() {
return void
}

returnvoid3()

# Test recursion

function number recurse(N:number) {
if (N == 1) {
return 5
} else {
return recurse(N - 1) + 1
}
}

assert(recurse(10) == 14, recurse(10):toString())

Sentinel = -1
function recursevoid() {
Sentinel++
if (Sentinel == 0) {
recursevoid()
}
}

recursevoid()

assert(Sentinel == 1)

function number nilInput(X, Y:ranger, Z:vector) {
assert(Z == vec(1, 2, 3))
return 5
}

assert( nilInput(1, noranger(), vec(1, 2, 3)) == 5 )

Ran = 0

if (0) {
function constant() {
Ran = 1
}
}

constant()

assert(Ran)
120 changes: 120 additions & 0 deletions data/expression2/tests/runtime/base/userfunctions/methods_const.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
## SHOULD_PASS:EXECUTE

@strict

# Ensure methods get called in the first place

Called = 0
function number:mymethod() {
Called = 1
}

1:mymethod()

assert(Called)

local This = 10
local X = 500
local Y = 1000
local Z = 5000

# Ensure function scoping doesn't affect outer scope

function number number:method(X, Y, Z) {
assert(This == 500)
assert(X == 1)
assert(Y == 2)
assert(Z == 4)

return 5
}

assert( 500:method(1, 2, 4) == 5 )

assert(This == 10)
assert(X == 500)
assert(Y == 1000)
assert(Z == 5000)

# Ensure functions return properly

function number number:returning() {
return 5
}

assert(1:returning() == 5)

function number number:returning2(X:array) {
return X[1, number] + 5
}

assert(1:returning2(array(5)) == 10)
assert(1:returning2(array()) == 5)

function array number:returningref(X:array) {
return X
}

local A = array()
assert(1:returningref(A):id() == A:id())

function number:returnvoid() {
if (1) { return }
}

1:returnvoid()

function void number:returnvoid2() {
return
}

1:returnvoid2()

function number:returnvoid3() {
return void
}

1:returnvoid3()

# Test recursion

function number number:recurse(N:number) {
if (N == 1) {
return 5
} else {
return This:recurse(N - 1) + 1
}
}

assert(1:recurse(10) == 14, 1:recurse(10):toString())

Sentinel = -1
function number:recursevoid() {
Sentinel++
if (Sentinel == 0) {
This:recursevoid()
}
}

1:recursevoid()

assert(Sentinel == 1)

function number number:nilInput(X, Y:ranger, Z:vector) {
assert(Z == vec(1, 2, 3))
return 5
}

assert( 1:nilInput(1, noranger(), vec(1, 2, 3)) == 5 )

Ran = 0

if (0) {
function number:constant() {
Ran = 1
}
}

1:constant()

assert(Ran)
37 changes: 24 additions & 13 deletions lua/entities/gmod_wire_expression2/base/compiler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ end
---@field persist IODirective
---@field inputs IODirective
---@field outputs IODirective
---@field strict boolean
local Compiler = {}
Compiler.__index = Compiler

Expand All @@ -100,7 +101,7 @@ end
function Compiler.from(directives, dvars, includes)
local global_scope = Scope.new()
return setmetatable({
persist = directives.persist, inputs = directives.inputs, outputs = directives.outputs,
persist = directives.persist, inputs = directives.inputs, outputs = directives.outputs, strict = directives.strict,
global_scope = global_scope, scope = global_scope, warnings = {}, registered_events = {}, user_functions = {}, user_methods = {},
delta_vars = dvars or {}, includes = includes or {}
}, Compiler)
Expand Down Expand Up @@ -671,6 +672,10 @@ local CompileVisitors = {
end
end

if self.strict and not self.scope:IsGlobalScope() then
self:Warning("Functions should be in the top scope, nesting them does nothing", trace)
end

local fn_data, lookup_variadic, userfunction = self:GetFunction(name.value, param_types, meta_type)
if fn_data then
if not userfunction then
Expand All @@ -684,12 +689,15 @@ local CompileVisitors = {
self:Assert(fn_data.returns == nil, "Cannot override function returning void with differing return type", trace)
end

-- Tag function if it is ever re-declared. Used as an optimization
fn_data.const = fn_data.op == nil
if not self.strict then
self:Warning("Do not override functions. This is a hard error with @strict.", trace)
else
self:Error("Cannot override existing function '" .. name.value .. "'", trace)
end
end
end

local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 25 or 10, attrs = {} }
local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 10 or 5 + (self.strict and 0 or 3), attrs = {} }
local sig = table.concat(param_types, "", 1, #param_types - 1) .. ((variadic_ty and ".." or "") .. (param_types[#param_types] or ""))

if meta_type then
Expand Down Expand Up @@ -821,9 +829,11 @@ local CompileVisitors = {
local sig = name.value .. "(" .. (meta_type and (meta_type .. ":") or "") .. sig .. ")"
local fn = fn.op

return function(state) ---@param state RuntimeContext
state.funcs[sig] = fn
state.funcs_ret[sig] = return_type
if not self.strict then
return function(state) ---@param state RuntimeContext
state.funcs[sig] = fn
state.funcs_ret[sig] = return_type
end
end
end,

Expand Down Expand Up @@ -1448,7 +1458,6 @@ local CompileVisitors = {
self:Warning("Use of deprecated function (" .. name.value .. ") " .. (type(value) == "string" and value or ""), trace)
end

self.scope.data.ops = self.scope.data.ops + ((fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0))

if fn_data.attrs["noreturn"] then
self.scope.data.dead = true
Expand All @@ -1457,8 +1466,9 @@ local CompileVisitors = {
local nargs = #args
local user_function = self.user_functions[name.value] and self.user_functions[name.value][arg_sig]
if user_function then
-- Calling a user function - chance of being overridden. Also not legacy.
if user_function.const then
if self.strict then -- If @strict, functions are compile time constructs (like events).
self.scope.data.ops = self.scope.data.ops + fn_data.cost

local fn = user_function.op
return function(state)
local rargs = {}
Expand All @@ -1468,6 +1478,8 @@ local CompileVisitors = {
return fn(state, rargs, types)
end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil)
else
self.scope.data.ops = self.scope.data.ops + (fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0)

local full_sig = name.value .. "(" .. arg_sig .. ")"
return function(state) ---@param state RuntimeContext
local rargs = {}
Expand Down Expand Up @@ -1526,16 +1538,15 @@ local CompileVisitors = {
local nargs = #args
local user_method = self.user_methods[meta_type] and self.user_methods[meta_type][name.value] and self.user_methods[meta_type][name.value][arg_sig]
if user_method then
-- Calling a user function - chance of being overridden. Also not legacy.
if user_method.const then
if self.strict then -- If @strict, functions are compile time constructs (like events).
local fn = user_method.op
return function(state)
local rargs = { meta(state) }
for k = 1, nargs do
rargs[k + 1] = args[k](state)
end
return fn(state, rargs, types)
end
end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil)
else
local full_sig = name.value .. "(" .. meta_type .. ":" .. arg_sig .. ")"
return function(state) ---@param state RuntimeContext
Expand Down