diff --git a/spec/types_spec.lua b/spec/types_spec.lua index 9d87380f..0adca7c1 100644 --- a/spec/types_spec.lua +++ b/spec/types_spec.lua @@ -8,41 +8,41 @@ local types = require "pallene.types" describe("Pallene types", function() it("pretty-prints types", function() - assert.same("{ integer }", types.tostring(types.T.Array(types.T.Integer()))) + assert.same("{ integer }", types.tostring(types.T.Array(types.T.Integer))) assert.same("{ x: float, y: float }", types.tostring( - types.T.Table({x = types.T.Float(), y = types.T.Float()}))) + types.T.Table({x = types.T.Float, y = types.T.Float}))) end) it("is_gc works", function() - assert.falsy(types.is_gc(types.T.Integer())) - assert.truthy(types.is_gc(types.T.String())) - assert.truthy(types.is_gc(types.T.Array(types.T.Integer()))) - assert.truthy(types.is_gc(types.T.Table({x = types.T.Float()}))) + assert.falsy(types.is_gc(types.T.Integer)) + assert.truthy(types.is_gc(types.T.String)) + assert.truthy(types.is_gc(types.T.Array(types.T.Integer))) + assert.truthy(types.is_gc(types.T.Table({x = types.T.Float}))) assert.truthy(types.is_gc(types.T.Function({}, {}))) end) describe("equality", function() it("works for primitive types", function() - assert.truthy(types.equals(types.T.Integer(), types.T.Integer())) - assert.falsy(types.equals(types.T.Integer(), types.T.String())) + assert.truthy(types.equals(types.T.Integer, types.T.Integer)) + assert.falsy(types.equals(types.T.Integer, types.T.String)) end) it("is true for two identical tables", function() local t1 = types.T.Table({ - y = types.T.Integer(), x = types.T.Integer()}) + y = types.T.Integer, x = types.T.Integer}) local t2 = types.T.Table({ - x = types.T.Integer(), y = types.T.Integer()}) + x = types.T.Integer, y = types.T.Integer}) assert.truthy(types.equals(t1, t2)) assert.truthy(types.equals(t2, t1)) end) it("is false for tables with different number of fields", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({x = types.T.Integer(), - y = types.T.Integer()}) - local t3 = types.T.Table({x = types.T.Integer(), - y = types.T.Integer(), z = types.T.Integer()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({x = types.T.Integer, + y = types.T.Integer}) + local t3 = types.T.Table({x = types.T.Integer, + y = types.T.Integer, z = types.T.Integer}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) assert.falsy(types.equals(t2, t3)) @@ -52,39 +52,39 @@ describe("Pallene types", function() end) it("is false for tables with different field names", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({y = types.T.Integer()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({y = types.T.Integer}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) end) it("is false for tables with different field types", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({x = types.T.Float()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({x = types.T.Float}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) end) it("is true for identical functions", function() - local f1 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) + local f1 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) assert.truthy(types.equals(f1, f2)) end) it("is false for functions with different input types", function() - local f1 = types.T.Function({types.T.String(), types.T.Boolean()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.Integer(), types.T.Integer()}, {types.T.Boolean()}) + local f1 = types.T.Function({types.T.String, types.T.Boolean}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.Integer, types.T.Integer}, {types.T.Boolean}) assert.falsy(types.equals(f1, f2)) end) it("is false for functions with different output types", function() - local f1 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Integer()}) + local f1 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Integer}) assert.falsy(types.equals(f1, f2)) end) it("is false for functions with different input arity", function() - local s = types.T.String() + local s = types.T.String local f1 = types.T.Function({}, {s}) local f2 = types.T.Function({s}, {s}) local f3 = types.T.Function({s, s}, {s}) @@ -97,7 +97,7 @@ describe("Pallene types", function() end) it("is false for functions with different output arity", function() - local s = types.T.String() + local s = types.T.String local f1 = types.T.Function({s}, {}) local f2 = types.T.Function({s}, {s}) local f3 = types.T.Function({s}, {s, s}) @@ -123,42 +123,42 @@ describe("Pallene types", function() describe("consistency", function() it("allows 'any' on either side", function() - assert.truthy(types.consistent(types.T.Any(), types.T.Any())) - assert.truthy(types.consistent(types.T.Any(), types.T.Integer())) - assert.truthy(types.consistent(types.T.Integer(), types.T.Any())) + assert.truthy(types.consistent(types.T.Any, types.T.Any)) + assert.truthy(types.consistent(types.T.Any, types.T.Integer)) + assert.truthy(types.consistent(types.T.Integer, types.T.Any)) end) it("allows types with same tag", function() assert.truthy(types.consistent( - types.T.Integer(), - types.T.Integer() + types.T.Integer, + types.T.Integer )) assert.truthy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Array(types.T.Integer()) + types.T.Array(types.T.Integer), + types.T.Array(types.T.Integer) )) assert.truthy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Array(types.T.String()) + types.T.Array(types.T.Integer), + types.T.Array(types.T.String) )) assert.truthy(types.consistent( - types.T.Function({types.T.Integer()}, {types.T.Integer()}), - types.T.Function({types.T.String(), types.T.String()}, {}) + types.T.Function({types.T.Integer}, {types.T.Integer}), + types.T.Function({types.T.String, types.T.String}, {}) )) end) it("forbids different tags", function() assert.falsy(types.consistent( - types.T.Integer(), - types.T.String() + types.T.Integer, + types.T.String )) assert.falsy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Function({types.T.Integer()},{types.T.Integer()}) + types.T.Array(types.T.Integer), + types.T.Function({types.T.Integer},{types.T.Integer}) )) end) end) diff --git a/src/pallene/builtins.lua b/src/pallene/builtins.lua index a433026b..ba44279b 100644 --- a/src/pallene/builtins.lua +++ b/src/pallene/builtins.lua @@ -10,38 +10,38 @@ local builtins = {} -- TODO: It will be easier to read this is we could write down the types using the normal grammar -local ipairs_itertype = T.Function({T.Any(), T.Any()}, {T.Any(), T.Any()}) +local ipairs_itertype = T.Function({T.Any, T.Any}, {T.Any, T.Any}) builtins.functions = { - type = T.Function({ T.Any() }, { T.String() }), - tostring = T.Function({ T.Any() }, { T.String() }), - ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()}) + type = T.Function({ T.Any }, { T.String }), + tostring = T.Function({ T.Any }, { T.String }), + ipairs = T.Function({T.Array(T.Any)}, {ipairs_itertype, T.Any, T.Any}) } builtins.modules = { io = { - write = T.Function({ T.String() }, {}), + write = T.Function({ T.String }, {}), }, math = { - abs = T.Function({ T.Float() }, { T.Float() }), - ceil = T.Function({ T.Float() }, { T.Integer() }), - floor = T.Function({ T.Float() }, { T.Integer() }), - fmod = T.Function({ T.Float(), T.Float() }, { T.Float() }), - exp = T.Function({ T.Float() }, { T.Float() }), - ln = T.Function({ T.Float() }, { T.Float() }), - log = T.Function({ T.Float(), T.Float() }, { T.Float() }), - modf = T.Function({ T.Float() }, { T.Integer(), T.Float() }), - pow = T.Function({ T.Float(), T.Float() }, { T.Float() }), - sqrt = T.Function({ T.Float() }, { T.Float() }), + abs = T.Function({ T.Float }, { T.Float }), + ceil = T.Function({ T.Float }, { T.Integer }), + floor = T.Function({ T.Float }, { T.Integer }), + fmod = T.Function({ T.Float, T.Float }, { T.Float }), + exp = T.Function({ T.Float }, { T.Float }), + ln = T.Function({ T.Float }, { T.Float }), + log = T.Function({ T.Float, T.Float }, { T.Float }), + modf = T.Function({ T.Float }, { T.Integer, T.Float }), + pow = T.Function({ T.Float, T.Float }, { T.Float }), + sqrt = T.Function({ T.Float }, { T.Float }), -- constant numbers - huge = T.Float(), - mininteger = T.Integer(), - maxinteger = T.Integer(), - pi = T.Float(), + huge = T.Float, + mininteger = T.Integer, + maxinteger = T.Integer, + pi = T.Float, }, string = { - char = T.Function({ T.Integer() }, { T.String() }), - sub = T.Function({ T.String(), T.Integer(), T.Integer() }, { T.String() }), + char = T.Function({ T.Integer }, { T.String }), + sub = T.Function({ T.String, T.Integer, T.Integer }, { T.String }), }, } diff --git a/src/pallene/coder.lua b/src/pallene/coder.lua index 7104e959..b6f5fea0 100644 --- a/src/pallene/coder.lua +++ b/src/pallene/coder.lua @@ -344,7 +344,7 @@ function Coder:c_value(value) return C.float(value.value) elseif tag == "ir.Value.String" then local str = value.value - return lua_value(types.T.String(), self:string_upvalue_slot(str)) + return lua_value(types.T.String, self:string_upvalue_slot(str)) elseif tag == "ir.Value.LocalVar" then return self:c_var(value.id) elseif tag == "ir.Value.Upvalue" then @@ -1254,7 +1254,7 @@ gen_cmd["SetTable"] = function(self, cmd, _func) tab = tab, key = key, val = val, - init_keyv = set_stack_slot(types.T.String(), "&keyv", key), + init_keyv = set_stack_slot(types.T.String, "&keyv", key), init_valv = set_stack_slot(src_typ, "&valv", val), -- Here we use set_stack_slot slot on a heap object, because -- we call the barrier by hand outside the if statement. diff --git a/src/pallene/ir.lua b/src/pallene/ir.lua index 365c1b91..dbf13186 100644 --- a/src/pallene/ir.lua +++ b/src/pallene/ir.lua @@ -351,7 +351,7 @@ function ir.clean(cmd) end end if #out == 0 then - return ir.Cmd.Nop() + return ir.Cmd.Nop elseif #out == 1 then return out[1] else @@ -366,7 +366,7 @@ function ir.clean(cmd) local e_empty = (cmd.else_._tag == "ir.Cmd.Nop") if t_empty and e_empty then - return ir.Cmd.Nop() + return ir.Cmd.Nop elseif v._tag == "ir.Value.Bool" and v.value == true then return cmd.then_ elseif v._tag == "ir.Value.Bool" and v.value == false then diff --git a/src/pallene/to_ir.lua b/src/pallene/to_ir.lua index f6c69234..6071ab8a 100644 --- a/src/pallene/to_ir.lua +++ b/src/pallene/to_ir.lua @@ -283,7 +283,7 @@ function ToIR:convert_toplevel(prog_ast) local exports_type = types.T.Table({}) self.module.loc_id_of_exports = ir.add_local(self.func, "$exports", exports_type) table.insert(cmds, ir.Cmd.NewTable(self.func.loc, self.module.loc_id_of_exports, ir.Value.Integer(n_exports))) - table.insert(cmds, ir.Cmd.CheckGC()) + table.insert(cmds, ir.Cmd.CheckGC) -- export the functions for _, f_id in ipairs(self.module.exported_functions) do @@ -324,7 +324,7 @@ function ToIR:convert_stat(cmds, stat) local body = {} local cond = self:exp_to_value(body, stat.condition) local condBool = self:value_is_truthy(body, stat.condition, cond) - table.insert(body, ir.Cmd.If(stat.loc, condBool, ir.Cmd.Nop(), ir.Cmd.Break())) + table.insert(body, ir.Cmd.If(stat.loc, condBool, ir.Cmd.Nop, ir.Cmd.Break)) self:convert_stat(body, stat.block) table.insert(cmds, ir.Cmd.Loop(ir.Cmd.Seq(body))) @@ -333,7 +333,7 @@ function ToIR:convert_stat(cmds, stat) self:convert_stat(body, stat.block) local cond = self:exp_to_value(body, stat.condition) local condBool = self:value_is_truthy(body, stat.condition, cond) - table.insert(body, ir.Cmd.If(stat.loc, condBool, ir.Cmd.Break(), ir.Cmd.Nop())) + table.insert(body, ir.Cmd.If(stat.loc, condBool, ir.Cmd.Break, ir.Cmd.Nop)) table.insert(cmds, ir.Cmd.Loop(ir.Cmd.Seq(body))) elseif tag == "ast.Stat.If" then @@ -398,12 +398,12 @@ function ToIR:convert_stat(cmds, stat) -- the table passed as argument to `ipairs` local arr = ipairs_args[1] - assert(types.equals(arr._type, types.T.Array(types.T.Any()))) + assert(types.equals(arr._type, types.T.Array(types.T.Any))) local v_arr = ir.add_local(self.func, "$xs", arr._type) self:exp_to_assignment(cmds, v_arr, arr) -- local i_num: integer = 1 - local v_inum = ir.add_local(self.func, "$"..decls[1].name.."_num", types.T.Integer()) + local v_inum = ir.add_local(self.func, "$"..decls[1].name.."_num", types.T.Integer) local start = ir.Value.Integer(1) table.insert(cmds, ir.Cmd.Move(stat.loc, v_inum, start)) @@ -411,15 +411,15 @@ function ToIR:convert_stat(cmds, stat) local body = {} -- x_dyn = xs[i_num] - local v_x_dyn = ir.add_local(self.func, "$"..decls[2].name.."_dyn", types.T.Any()) + local v_x_dyn = ir.add_local(self.func, "$"..decls[2].name.."_dyn", types.T.Any) local src_arr = ir.Value.LocalVar(v_arr) local src_i = ir.Value.LocalVar(v_inum) - table.insert(body, ir.Cmd.GetArr(stat.loc, types.T.Any(), v_x_dyn, src_arr, src_i)) + table.insert(body, ir.Cmd.GetArr(stat.loc, types.T.Any, v_x_dyn, src_arr, src_i)) -- if x_dyn == nil then break end - local v_cond_checknil = ir.add_local(self.func, false, types.T.Boolean()) + local v_cond_checknil = ir.add_local(self.func, false, types.T.Boolean) table.insert(body, ir.Cmd.IsNil(stat.loc, v_cond_checknil, ir.Value.LocalVar(v_x_dyn))) - table.insert(body, ir.Cmd.If(stat.loc, ir.Value.LocalVar(v_cond_checknil), ir.Cmd.Break(), ir.Cmd.Nop())) + table.insert(body, ir.Cmd.If(stat.loc, ir.Value.LocalVar(v_cond_checknil), ir.Cmd.Break, ir.Cmd.Nop)) -- local i: T1 = i_num as T1 local v_i = ir.add_local(self.func, decls[1].name, decls[1]._type) @@ -427,7 +427,7 @@ function ToIR:convert_stat(cmds, stat) if decls[1]._type._tag == "types.T.Integer" then table.insert(body, ir.Cmd.Move(stat.loc, v_i, ir.Value.LocalVar(v_inum))) else - table.insert(body, ir.Cmd.ToDyn(stat.loc, types.T.Integer(), v_i, ir.Value.LocalVar(v_inum))) + table.insert(body, ir.Cmd.ToDyn(stat.loc, types.T.Integer, v_i, ir.Value.LocalVar(v_inum))) end -- local x = x_dyn as T2 @@ -476,7 +476,7 @@ function ToIR:convert_stat(cmds, stat) local v_lhs_dyn = {} for _, decl in ipairs(decls) do - local v = ir.add_local(self.func, "$" .. decl.name .. "_dyn", types.T.Any()) + local v = ir.add_local(self.func, "$" .. decl.name .. "_dyn", types.T.Any) table.insert(v_lhs_dyn, v) end @@ -491,9 +491,9 @@ function ToIR:convert_stat(cmds, stat) table.insert(body, ir.Cmd.CallDyn(exps[1].loc, itertype, v_lhs_dyn, ir.Value.LocalVar(v_iter), args)) -- if i == nil then break end - local v_cond = ir.add_local(self.func, false, types.T.Boolean()) + local v_cond = ir.add_local(self.func, false, types.T.Boolean) table.insert(body, ir.Cmd.IsNil(stat.loc, v_cond, ir.Value.LocalVar(v_lhs_dyn[1]))) - table.insert(body, ir.Cmd.If(stat.loc, ir.Value.LocalVar(v_cond), ir.Cmd.Break(), ir.Cmd.Nop())) + table.insert(body, ir.Cmd.If(stat.loc, ir.Value.LocalVar(v_cond), ir.Cmd.Break, ir.Cmd.Nop)) -- cast loop LHS to annotated types. for i, decl in ipairs(decls) do @@ -671,7 +671,7 @@ function ToIR:convert_stat(cmds, stat) table.insert(cmds, ir.Cmd.Return(stat.loc, vals)) elseif tag == "ast.Stat.Break" then - table.insert(cmds, ir.Cmd.Break()) + table.insert(cmds, ir.Cmd.Break) elseif tag == "ast.Stat.Functions" then @@ -842,7 +842,7 @@ end function ToIR:exp_to_value(cmds, exp, is_recursive) local tag = exp._tag if tag == "ast.Exp.Nil" then - return ir.Value.Nil() + return ir.Value.Nil elseif tag == "ast.Exp.Bool" then return ir.Value.Bool(exp.value) @@ -929,7 +929,7 @@ function ToIR:exp_to_assignment(cmds, dst, exp) if typ._tag == "types.T.Array" then local n = ir.Value.Integer(#exp.fields) table.insert(cmds, ir.Cmd.NewArr(loc, dst, n)) - table.insert(cmds, ir.Cmd.CheckGC()) + table.insert(cmds, ir.Cmd.CheckGC) for i, field in ipairs(exp.fields) do assert(field._tag == "ast.Field.List") local av = ir.Value.LocalVar(dst) @@ -942,7 +942,7 @@ function ToIR:exp_to_assignment(cmds, dst, exp) elseif typ._tag == "types.T.Table" then local n = ir.Value.Integer(#exp.fields) table.insert(cmds, ir.Cmd.NewTable(loc, dst, n)) - table.insert(cmds, ir.Cmd.CheckGC()) + table.insert(cmds, ir.Cmd.CheckGC) for _, field in ipairs(exp.fields) do assert(field._tag == "ast.Field.Rec") local tv = ir.Value.LocalVar(dst) @@ -960,7 +960,7 @@ function ToIR:exp_to_assignment(cmds, dst, exp) end table.insert(cmds, ir.Cmd.NewRecord(loc, typ, dst)) - table.insert(cmds, ir.Cmd.CheckGC()) + table.insert(cmds, ir.Cmd.CheckGC) for _, field_name in ipairs(typ.field_names) do local f_exp = assert(field_exps[field_name]) local dv = ir.Value.LocalVar(dst) @@ -982,7 +982,7 @@ function ToIR:exp_to_assignment(cmds, dst, exp) assert(typ.is_upvalue_box) table.insert(cmds, ir.Cmd.NewRecord(loc, typ, dst)) - table.insert(cmds, ir.Cmd.CheckGC()) + table.insert(cmds, ir.Cmd.CheckGC) elseif tag == "ast.Exp.Lambda" then local f_id = self:register_lambda(exp, "$lambda") @@ -1079,11 +1079,11 @@ function ToIR:exp_to_assignment(cmds, dst, exp) elseif bname == "string.char" then assert(#xs == 1) table.insert(cmds, ir.Cmd.BuiltinStringChar(loc, dsts, xs)) - table.insert(cmds, ir.Cmd.CheckGC()) + table.insert(cmds, ir.Cmd.CheckGC) elseif bname == "string.sub" then assert(#xs == 3) table.insert(cmds, ir.Cmd.BuiltinStringSub(loc, dsts, xs)) - table.insert(cmds, ir.Cmd.CheckGC()) + table.insert(cmds, ir.Cmd.CheckGC) elseif bname == "type" then assert(#xs == 1) table.insert(cmds, ir.Cmd.BuiltinType(loc, dsts, xs)) @@ -1169,7 +1169,7 @@ function ToIR:exp_to_assignment(cmds, dst, exp) xs[i] = self:exp_to_value(cmds, x_exp) end table.insert(cmds, ir.Cmd.Concat(loc, dst, xs)) - table.insert(cmds, ir.Cmd.CheckGC()) + table.insert(cmds, ir.Cmd.CheckGC) elseif tag == "ast.Exp.Binop" then local op = exp.op @@ -1252,7 +1252,7 @@ function ToIR:value_is_truthy(cmds, exp, val) if typ._tag == "types.T.Boolean" then return val elseif typ._tag == "types.T.Any" then - local b = ir.add_local(self.func, false, types.T.Boolean()) + local b = ir.add_local(self.func, false, types.T.Boolean) table.insert(cmds, ir.Cmd.IsTruthy(exp.loc, b, val)) return ir.Value.LocalVar(b) elseif typedecl.tag_is_type(typ) then diff --git a/src/pallene/typechecker.lua b/src/pallene/typechecker.lua index e4708f0d..6a93b1f3 100644 --- a/src/pallene/typechecker.lua +++ b/src/pallene/typechecker.lua @@ -170,7 +170,7 @@ end function Typechecker:from_ast_type(ast_typ) local tag = ast_typ._tag if tag == "ast.Type.Nil" then - return types.T.Nil() + return types.T.Nil elseif tag == "ast.Type.Name" then local name = ast_typ.name @@ -231,11 +231,11 @@ function Typechecker:check_program(prog_ast) local module_name = prog_ast.module_name -- 1) Add primitive types to the symbol table - self:add_type_symbol("any", types.T.Any()) - self:add_type_symbol("boolean", types.T.Boolean()) - self:add_type_symbol("float", types.T.Float()) - self:add_type_symbol("integer", types.T.Integer()) - self:add_type_symbol("string", types.T.String()) + self:add_type_symbol("any", types.T.Any) + self:add_type_symbol("boolean", types.T.Boolean) + self:add_type_symbol("float", types.T.Float) + self:add_type_symbol("integer", types.T.Integer) + self:add_type_symbol("string", types.T.String) -- 2) Add builtins to symbol table. -- The order does not matter because they are distinct. @@ -249,7 +249,7 @@ function Typechecker:check_program(prog_ast) local id = mod_name .. "." .. fun_name symbols[fun_name] = typechecker.Symbol.Value(typ, typechecker.Def.Builtin(id)) end - local typ = (mod_name == "string") and types.T.String() or false + local typ = (mod_name == "string") and types.T.String or false self:add_module_symbol(mod_name, typ, symbols) end @@ -424,10 +424,10 @@ function Typechecker:check_stat(stat, is_toplevel) local decl_types = {} for _ = 1, #stat.decls do - table.insert(decl_types, types.T.Any()) + table.insert(decl_types, types.T.Any) end - local itertype = types.T.Function({ types.T.Any(), types.T.Any() }, decl_types) + local itertype = types.T.Function({ types.T.Any, types.T.Any }, decl_types) rhs[1] = self:check_exp_synthesize(rhs[1]) local iteratorfn = rhs[1] @@ -715,7 +715,7 @@ function Typechecker:check_var(var) "expected array but found %s in indexed expression", types.tostring(arr_type)) end - var.k = self:check_exp_verify(var.k, types.T.Integer(), "array index") + var.k = self:check_exp_verify(var.k, types.T.Integer, "array index") var._type = arr_type.elem else @@ -795,19 +795,19 @@ function Typechecker:check_exp_synthesize(exp) local tag = exp._tag if tag == "ast.Exp.Nil" then - exp._type = types.T.Nil() + exp._type = types.T.Nil elseif tag == "ast.Exp.Bool" then - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif tag == "ast.Exp.Integer" then - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif tag == "ast.Exp.Float" then - exp._type = types.T.Float() + exp._type = types.T.Float elseif tag == "ast.Exp.String" then - exp._type = types.T.String() + exp._type = types.T.String elseif tag == "ast.Exp.InitList" then type_error(exp.loc, "missing type hint for initializer") @@ -829,7 +829,7 @@ function Typechecker:check_exp_synthesize(exp) "trying to take the length of a %s instead of an array or string", types.tostring(t)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif op == "-" then if t._tag ~= "types.T.Integer" and t._tag ~= "types.T.Float" then type_error(exp.loc, @@ -843,10 +843,10 @@ function Typechecker:check_exp_synthesize(exp) "trying to bitwise negate a %s instead of an integer", types.tostring(t)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif op == "not" then check_type_is_condition(exp.exp, "'not' operator") - exp._type = types.T.Boolean() + exp._type = types.T.Boolean else typedecl.tag_error(op) end @@ -870,7 +870,7 @@ function Typechecker:check_exp_synthesize(exp) "cannot compare %s and %s using %s", types.tostring(t1), types.tostring(t2), op) end - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif op == "<" or op == ">" or op == "<=" or op == ">=" then if (t1._tag == "types.T.Integer" and t2._tag == "types.T.Integer") or @@ -888,7 +888,7 @@ function Typechecker:check_exp_synthesize(exp) "cannot compare %s and %s using %s", types.tostring(t1), types.tostring(t2), op) end - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif op == "+" or op == "-" or op == "*" or op == "%" or op == "//" then if not is_numeric_type(t1) then @@ -905,11 +905,11 @@ function Typechecker:check_exp_synthesize(exp) if t1._tag == "types.T.Integer" and t2._tag == "types.T.Integer" then - exp._type = types.T.Integer() + exp._type = types.T.Integer else exp.lhs = self:coerce_numeric_exp_to_float(exp.lhs) exp.rhs = self:coerce_numeric_exp_to_float(exp.rhs) - exp._type = types.T.Float() + exp._type = types.T.Float end elseif op == "/" or op == "^" then @@ -926,7 +926,7 @@ function Typechecker:check_exp_synthesize(exp) exp.lhs = self:coerce_numeric_exp_to_float(exp.lhs) exp.rhs = self:coerce_numeric_exp_to_float(exp.rhs) - exp._type = types.T.Float() + exp._type = types.T.Float elseif op == ".." then -- The arguments to '..' must be a strings. We do not allow "any" because Pallene does @@ -937,7 +937,7 @@ function Typechecker:check_exp_synthesize(exp) if t2._tag ~= "types.T.String" then type_error(exp.loc, "cannot concatenate with %s value", types.tostring(t2)) end - exp._type = types.T.String() + exp._type = types.T.String elseif op == "and" or op == "or" then check_type_is_condition(exp.lhs, "first operand of '%s'", op) @@ -955,7 +955,7 @@ function Typechecker:check_exp_synthesize(exp) "right-hand side of bitwise expression is a %s instead of an integer", types.tostring(t2)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer else typedecl.tag_error(op) @@ -992,7 +992,7 @@ function Typechecker:check_exp_synthesize(exp) elseif tag == "ast.Exp.ToFloat" then assert(exp.exp._type._tag == "types.T.Integer") - exp._type = types.T.Float() + exp._type = types.T.Float else typedecl.tag_error(tag) diff --git a/src/pallene/typedecl.lua b/src/pallene/typedecl.lua index 12e7bf46..b11c5668 100644 --- a/src/pallene/typedecl.lua +++ b/src/pallene/typedecl.lua @@ -70,19 +70,26 @@ function typedecl.declare(module, mod_name, type_name, constructors) module[type_name] = {} for cons_name, fields in pairs(constructors) do local tag = make_tag(mod_name, type_name, cons_name) - local function cons(...) - local args = table.pack(...) - if args.n ~= #fields then - error(string.format( - "wrong number of arguments for %s. Expected %d but received %d.", - cons_name, #fields, args.n)) - end - local node = { _tag = tag } - for i, field in ipairs(fields) do - node[field] = args[i] + + local cons + if #fields == 0 then + cons = { _tag = tag } + else + cons = function(...) + local args = table.pack(...) + if args.n ~= #fields then + error(string.format( + "wrong number of arguments for %s. Expected %d but received %d.", + cons_name, #fields, args.n)) + end + local node = { _tag = tag } + for i, field in ipairs(fields) do + node[field] = args[i] + end + return node end - return node end + module[type_name][cons_name] = cons end end