Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce function returns at compile time #2788

Merged
merged 6 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## SHOULD_FAIL:COMPILE

function string nothing() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## SHOULD_PASS:COMPILE

function string nothing() {
return "something"
}

function number deadcase() {
if (1) {
return 2158129
} else {
return 2321515
}
}

function number switchcase() {
switch (5) {
case 5,
return 2
default,
return 5
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## SHOULD_FAIL:COMPILE

function string failure() {
switch (5) {
case 2,
break
default,
break

# 'break' does not return a value or cause a runtime error, just early returns switch.
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
## SHOULD_FAIL:COMPILE

function string failure() {
switch (5) {
case 2,
return "boowomp"
# no default case, compiler can't guarantee that this always runs, fails to compile.
}
}
69 changes: 42 additions & 27 deletions lua/entities/gmod_wire_expression2/base/compiler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
end, "compiler_quota_check")

---@class ScopeData
---@field dead boolean?
---@field dead "ret"|true?
---@field loop boolean?
---@field switch_case boolean?
---@field function { [1]: string, [2]: EnvFunction}?
Expand Down Expand Up @@ -281,8 +281,10 @@
---@param data { [1]: Node?, [2]: Node }[]
[NodeVariant.If] = function (self, trace, data)
local chain = {} ---@type { [1]: RuntimeOperator?, [2]: RuntimeOperator }[]
local dead, els = true, false

for i, ifeif in ipairs(data) do
self:Scope(function()
self:Scope(function(scope)
if ifeif[1] then -- if or elseif
local expr, expr_ty = self:CompileExpr(ifeif[1])

Expand All @@ -301,11 +303,19 @@
self:CompileStmt(ifeif[2])
}
end

dead = dead and scope.data.dead
else -- else block
chain[i] = { nil, self:CompileStmt(ifeif[2]) }
dead, els = dead and scope.data.dead, true
end
end)
end

if els and dead then -- if (0) { return } else { return } mark any code after as dead
self.scope.data.dead = "ret"
end

return function(state) ---@param state RuntimeContext
for _, data in ipairs(chain) do
local cond, block = data[1], data[2]
Expand Down Expand Up @@ -378,7 +388,7 @@

---@param data { [1]: Token<string>, [2]: Node, [3]: Node, [4]: Node?, [5]: Node } var start stop step block
[NodeVariant.For] = function (self, trace, data)
local var, start, stop, step = data[1], self:CompileExpr(data[2]), self:CompileExpr(data[3]), data[4] and self:CompileExpr(data[4]) or data[4]

Check warning on line 391 in lua/entities/gmod_wire_expression2/base/compiler.lua

View workflow job for this annotation

GitHub Actions / lint

"Unused variable"

Unused variable: step

local block = self:Scope(function(scope)
scope.data.loop = true
Expand Down Expand Up @@ -512,14 +522,16 @@
---@param data { [1]: Node, [2]: {[1]: Node, [2]: Node}[], [3]: Node? }
[NodeVariant.Switch] = function (self, trace, data)
local expr, expr_ty = self:CompileExpr(data[1])
local dead = true

local cases = {} ---@type { [1]: RuntimeOperator, [2]: RuntimeOperator }[]
for i, case in ipairs(data[2]) do
local cond, cond_ty = self:CompileExpr(case[1])
local block
self:Scope(function(scope)
local block = self:Scope(function(scope)
scope.data.switch_case = true
block = self:CompileStmt(case[2])
local b = self:CompileStmt(case[2])
dead = dead and scope.data.dead == "ret"
return b
end)

local eq = self:GetOperator("eq", { expr_ty, cond_ty }, case[1].trace)
Expand All @@ -531,7 +543,16 @@
}
end

local default = data[3] and self:Scope(function() return self:CompileStmt(data[3]) end)
local default = data[3] and self:Scope(function(scope)
local b = self:CompileStmt(data[3])
dead = dead and scope.data.dead == "ret"
return b
end)

if dead and default then -- if all cases dead and has default case, mark scope as dead.
self.scope.data.dead = true
end

local ncases = #cases

return function(state) ---@param state RuntimeContext
Expand All @@ -545,9 +566,9 @@

if state.__break__ then
state.__break__ = false
goto exit

Check warning on line 569 in lua/entities/gmod_wire_expression2/base/compiler.lua

View workflow job for this annotation

GitHub Actions / lint

"Goto"

Don't use labels and gotos unless you're jumping out of multiple loops.
elseif state.__return__ then -- Yes this should only be checked if the switch is inside a function, but I don't care enough about the performance of switch case to add another duplicated 30 lines to the file
goto exit

Check warning on line 571 in lua/entities/gmod_wire_expression2/base/compiler.lua

View workflow job for this annotation

GitHub Actions / lint

"Goto"

Don't use labels and gotos unless you're jumping out of multiple loops.
else -- Fallthrough, run every case until break found.
for j = i + 1, ncases do
cases[j][2](state)
Expand Down Expand Up @@ -668,7 +689,7 @@
end
end

local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = 20, attrs = {} }
local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 25 or 10, attrs = {} }
local sig = table.concat(param_types, "", 1, #param_types - 1) .. ((variadic_ty and ".." or "") .. (param_types[#param_types] or ""))

if meta_type then
Expand Down Expand Up @@ -727,12 +748,8 @@

state.Scopes, state.ScopeID, state.Scope = s_scopes, s_scopeid, s_scope

if state.__return__ then
state.__return__ = false
return state.__returnval__
elseif return_type then
state:forceThrow("Expected function return at runtime of type (" .. return_type .. ")")
end
state.__return__ = false
return state.__returnval__
end
else -- table
function fn.op(state, args, arg_types) ---@param state RuntimeContext
Expand All @@ -758,12 +775,8 @@

state.Scopes, state.ScopeID, state.Scope = s_scopes, s_scopeid, s_scope

if state.__return__ then
state.__return__ = false
return state.__returnval__
elseif return_type then
state:forceThrow("Expected function return at runtime of type (" .. return_type .. ")")
end
state.__return__ = false
return state.__returnval__
end
end
else -- Todo: Make this output a different function when it doesn't early return, and/or has no parameters as an optimization.
Expand All @@ -784,23 +797,23 @@

state.Scopes, state.ScopeID, state.Scope = s_scopes, s_scopeid, s_scope

if state.__return__ then
state.__return__ = false
return state.__returnval__
elseif return_type then
state:forceThrow("Expected function function at runtime of type (" .. return_type .. ")")
end
state.__return__ = false
return state.__returnval__
end
end

block = self:IsolatedScope(function (scope)
self:IsolatedScope(function (scope)
for i, type in ipairs(param_types) do
scope:DeclVar(param_names[i], { type = type, trace_if_unused = data[4][i] and data[4][i].name.trace or trace, initialized = true })
end

scope.data["function"] = { name.value, fn }

return self:CompileStmt(data[5])
block = self:CompileStmt(data[5])

if return_type then -- Ensure function either returns or errors
self:Assert(scope.data.dead, "This function marked to return '" .. data[1].value .. "' must return a value", data[1].trace)
end
end)

self:Assert((fn.returns and fn.returns[1]) == return_type, "Function " .. name.value .. " expects to return type (" .. (return_type or "void") .. ") but got type (" .. ((fn.returns and fn.returns[1]) or "void") .. ")", trace)
Expand Down Expand Up @@ -888,6 +901,8 @@
local fn = self.scope:ResolveData("function")
self:Assert(fn, "Cannot use `return` outside of a function", trace)

self.scope.data.dead = "ret"

local retval, ret_ty
if data then
retval, ret_ty = self:CompileExpr(data)
Expand Down
Loading