Skip to content

Commit

Permalink
Improved function call type-checking.
Browse files Browse the repository at this point in the history
  • Loading branch information
asoffer committed Dec 30, 2023
1 parent f51e39c commit c39050c
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 47 deletions.
1 change: 1 addition & 0 deletions common/language/precedence.xmacro.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ IC_XMACRO_PRECEDENCE_GROUP(Function)
IC_XMACRO_PRECEDENCE_GROUP(Comparison)
IC_XMACRO_PRECEDENCE_GROUP(PlusMinus)
IC_XMACRO_PRECEDENCE_GROUP(Modulus)
IC_XMACRO_PRECEDENCE_GROUP(As)
IC_XMACRO_PRECEDENCE_GROUP(MultiplyDivide)
IC_XMACRO_PRECEDENCE_GROUP(TightUnary)

Expand Down
4 changes: 2 additions & 2 deletions examples/arguments.ic
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ let io ::= import "std.io"
// returns a slice of character slices. That is, a value of type `\\char`.

io.Print("There are ")
io.PrintNum(builtin.arguments().count)
io.PrintNum(builtin.arguments().count as i64)
io.Print(" program argument(s):")

var i: u64 = 0
while (i < builtin.arguments().count) {
io.Print("\n ")
io.PrintNum(i)
io.PrintNum(i as i64)
io.Print(": ")
io.Print(builtin.arguments()[i])
i = i + 1
Expand Down
5 changes: 5 additions & 0 deletions ir/emit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,11 @@ void HandleParseTreeNodeExpressionPrecedenceGroup(ParseNodeIndex index,
context.current_function().append<jasmin::LessThan<int64_t>>();
context.current_function().append<jasmin::Not>();
} break;
case Token::Kind::As: {
context.current_function().append<jasmin::Drop>();
// TODO: Cast. For now most integer casts are correct enough at the jasmin
// level, we can ignore this. (Jasmin debug info would flag it though).
} break;
default: NTH_UNIMPLEMENTED("{}") <<= {operator_node.token};
}
}
Expand Down
175 changes: 131 additions & 44 deletions ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,16 @@ void HandleParseTreeNodeExpressionPrecedenceGroup(
context.PopTypeStack(1 + node.child_count / 2);
context.type_stack().push({current});
} break;
case Token::Kind::As: {
if (context.type_stack().top()[0] !=
type::QualifiedType::Constant(type::Type_)) {
NTH_UNIMPLEMENTED();
}
context.PopTypeStack(1 + node.child_count / 2);
std::optional t = context.EvaluateAs<type::Type>(index - 1);
if (not t) { NTH_UNIMPLEMENTED(); }
context.type_stack().push({type::QualifiedType::Constant(*t)});
} break;
case Token::Kind::Less:
case Token::Kind::Greater:
case Token::Kind::LessEqual:
Expand Down Expand Up @@ -709,22 +719,70 @@ void HandleParseTreeNodeIndexExpression(ParseNodeIndex index,
}
}

struct Call {
struct InvocationSuccess {};
struct ParameterArgumentCountMismatch {
size_t parameters;
size_t arguments;
};
struct InvalidBinding{
size_t index;
type::Type parameter;
type::Type argument;
};
using InvocationResult =
std::variant<InvocationSuccess, ParameterArgumentCountMismatch, InvalidBinding>;

struct CallArguments {
struct Argument {
ParseNodeIndex index;
type::QualifiedType qualified_type;

type::Type type() const { return qualified_type.type(); }
};

InvocationResult Invoke(type::FunctionType fn_type) {
auto const& parameters = *fn_type.parameters();
// TODO: Properly implement function call type-checking.
if (parameters.size() != arguments.size()) {
return ParameterArgumentCountMismatch{.parameters = parameters.size(),
.arguments = arguments.size()};
}

for (size_t i = 0; i < arguments.size(); ++i) {
if (not type::ImplicitCast(arguments[i].type(), parameters[i].type)) {
return InvalidBinding{
.index = i,
.parameter = parameters[i].type,
.argument = arguments[i].type(),
};
}
}

return InvocationSuccess{};
}

jasmin::InstructionSpecification MakeInstructionSpecification(
type::FunctionType fn_type) const {
jasmin::InstructionSpecification spec{.parameters = 1, .returns = 0};
auto iter = (*fn_type.parameters()).begin();
for (size_t i = 0; i < std::distance(postfix_start, arguments.end()); ++i) {
spec.parameters += type::JasminSize(iter->type);
++iter;
}

for (type::Type r : fn_type.returns()) {
spec.returns += type::JasminSize(r);
}
return spec;
}

Argument callee;
std::vector<Argument> arguments;
std::vector<Argument>::const_iterator postfix_start;
};

void HandleParseTreeNodeCallExpression(ParseNodeIndex index, IrContext& context,
diag::DiagnosticConsumer& diag) {
Call call;

bool PopulateCall(ParseNodeIndex index, IrContext& context,
CallArguments& call) {
auto& type_stack = context.type_stack();
auto nodes = context.ChildIndices(index);
int postfix_count = -1;
Expand All @@ -733,7 +791,7 @@ void HandleParseTreeNodeCallExpression(ParseNodeIndex index, IrContext& context,
auto const& child_node = context.Node(*node_iter);
if (child_node.kind == ParseNode::Kind::PrefixInvocationArgumentEnd) {
NTH_REQUIRE((v.debug), not call.arguments.empty());
postfix = false;
postfix = false;
call.callee = call.arguments.back();
call.arguments.pop_back();
continue;
Expand All @@ -748,63 +806,92 @@ void HandleParseTreeNodeCallExpression(ParseNodeIndex index, IrContext& context,
if (qts[0].type() == type::Error) {
for (; node_iter != nodes.end(); ++node_iter) { type_stack.pop(); }
type_stack.push({type::QualifiedType::Unqualified(type::Error)});
return;
return false;
}
call.arguments.push_back({.index = *node_iter, .qualified_type = qts[0]});
type_stack.pop();
if (postfix) { ++postfix_count; }
}

if (postfix) {
call.callee = call.arguments.back();
call.arguments.pop_back();
}
std::reverse(call.arguments.begin(), call.arguments.end());
call.postfix_start = call.arguments.end() - postfix_count;
return true;
}

void HandleParseTreeNodeCallExpression(ParseNodeIndex index, IrContext& context,
diag::DiagnosticConsumer& diag) {
CallArguments call;
if (not PopulateCall(index, context, call)) { return; }
auto node = context.Node(index);

if (call.callee.type().kind() == type::Type::Kind::Function) {
auto fn_type = call.callee.type().AsFunction();
auto const& parameters = *fn_type.parameters();
// TODO: Properly implement function call type-checking.
if (parameters.size() == call.arguments.size()) {
auto type_iter = call.arguments.begin();
auto [iter, inserted] = context.emit.instruction_spec.try_emplace(index);
auto& spec = iter->second;
NTH_REQUIRE((v.harden), inserted);
++spec.parameters;
for (auto iter = call.postfix_start; iter != call.arguments.end();
++iter) {
spec.parameters += type::JasminSize(iter->type());
}
auto const& returns = fn_type.returns();
std::vector<type::QualifiedType> return_qts;
for (type::Type r : returns) {
spec.returns += type::JasminSize(r);
return_qts.push_back(type::QualifiedType::Unqualified(r));
}
context.type_stack().push(return_qts);

switch (fn_type.evaluation()) {
case type::Evaluation::PreferCompileTime: NTH_UNIMPLEMENTED();
case type::Evaluation::RequireCompileTime: {
nth::interval range = context.emit.tree.subtree_range(index);
nth::stack<jasmin::Value> value_stack;
context.emit.Evaluate(range, value_stack, returns);
auto module_id = context.EvaluateAs<ModuleId>(index);
if (module_id == ModuleId::Invalid()) {
InvocationResult result = call.Invoke(fn_type);
bool success = std::visit(
[&](auto const& r) {
constexpr auto t = nth::type<decltype(r)>.decayed();
if constexpr (t == nth::type<InvocationSuccess>) {
return true;
} else if constexpr (t == nth::type<ParameterArgumentCountMismatch>) {
diag.Consume({
diag::Header(diag::MessageKind::Error),
diag::Text(InterpolateString<"No module found named \"{}\"">(
*context.EvaluateAs<std::string_view>(index - 1))),
diag::Text(InterpolateString<
"Incorrect number of arguments passed to function: "
"Expected {}, but {} were provided.">(r.parameters,
r.arguments)),
diag::SourceQuote(context.Node(index).token),
});
} else if constexpr (t == nth::type<InvalidBinding>) {
diag.Consume({
diag::Header(diag::MessageKind::Error),
diag::Text(InterpolateString<
"Argument at position {} cannot be passed to "
"function. Expected a {} but argument has type {}">(
r.index, r.parameter, r.argument)),
diag::SourceQuote(context.Node(index).token),
});
return;
}
} break;
case type::Evaluation::PreferRuntime:
case type::Evaluation::RequireRuntime: break;
}
} else {
NTH_UNIMPLEMENTED("{} {}") <<= {parameters, call.arguments.size()};
return false;
},
result);
if (not success) {
context.type_stack().push(
{type::QualifiedType::Unqualified(type::Error)});
return;
}
// TODO: Properly implement function call type-checking.
auto [iter, inserted] = context.emit.instruction_spec.try_emplace(
index, call.MakeInstructionSpecification(fn_type));
NTH_REQUIRE((v.harden), inserted);
auto const& returns = fn_type.returns();
std::vector<type::QualifiedType> return_qts;
for (type::Type r : returns) {
return_qts.push_back(type::QualifiedType::Unqualified(r));
}
context.type_stack().push(return_qts);

switch (fn_type.evaluation()) {
case type::Evaluation::PreferCompileTime: NTH_UNIMPLEMENTED();
case type::Evaluation::RequireCompileTime: {
nth::interval range = context.emit.tree.subtree_range(index);
nth::stack<jasmin::Value> value_stack;
context.emit.Evaluate(range, value_stack, returns);
auto module_id = context.EvaluateAs<ModuleId>(index);
if (module_id == ModuleId::Invalid()) {
diag.Consume({
diag::Header(diag::MessageKind::Error),
diag::Text(InterpolateString<"No module found named \"{}\"">(
*context.EvaluateAs<std::string_view>(index - 1))),
});
return;
}
} break;
case type::Evaluation::PreferRuntime:
case type::Evaluation::RequireRuntime: break;
}
} else if (call.callee.type().kind() == type::Type::Kind::DependentFunction) {
auto& spec = context.emit.instruction_spec[index];
Expand Down
1 change: 1 addition & 0 deletions lexer/token_kind.xmacro.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ IC_XMACRO_TOKEN_KIND_KEYWORD(Enum, "enum")
IC_XMACRO_TOKEN_KIND_KEYWORD(Interface, "interface")
IC_XMACRO_TOKEN_KIND_KEYWORD(Extend, "extend")
IC_XMACRO_TOKEN_KIND_KEYWORD(With, "with")
IC_XMACRO_TOKEN_KIND_KEYWORD(As, "as")

IC_XMACRO_TOKEN_KIND_TERMINAL_EXPRESSION(True, "true")
IC_XMACRO_TOKEN_KIND_TERMINAL_EXPRESSION(False, "false")
Expand Down
1 change: 1 addition & 0 deletions parse/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ void Parser::HandleTryPrefix(ParseTree& tree) {
void Parser::HandleTryInfix(ParseTree& tree) {
Precedence p = Precedence::Loosest();
switch (current_token().kind()) {
case Token::Kind::As: p = Precedence::As(); break;
case Token::Kind::Star: p = Precedence::MultiplyDivide(); break;
#define IC_XMACRO_TOKEN_KIND_BINARY_OPERATOR(kind, symbol, precedence_group) \
case Token::Kind::kind: \
Expand Down
2 changes: 1 addition & 1 deletion toolchain/stdlib/io.ic
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ let stdio ::= import "std.compat.c.stdio"

let Print ::= fn(let s: \char) -> () {
let printf ::= builtin.foreign("printf", ([*]char, c.int, [*]char) -> c.int)
printf("%.*s".data, s.count, s.data)
printf("%.*s".data, s.count as i32, s.data)
}

let PrintNum ::= fn(let n: i64) -> () {
Expand Down
7 changes: 7 additions & 0 deletions type/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ bool ImplicitCast(Type from, Type to) {
if (to.kind() == Type::Kind::Primitive and Numeric(to.AsPrimitive())) {
return true;
}
} else if (from == U8) {
return to == I16 or to == I64 or to == I32 or to == U16 or to == U32 or
to == U64;
} else if (from == U16) {
return to == I32 or to == I64 or to == U32 or to == U64;
} else if (from == U32) {
return to == I64 or to == U64;
} else if (from == NullType) {
return to.kind() == Type::Kind::Pointer or
to.kind() == Type::Kind::BufferPointer;
Expand Down

0 comments on commit c39050c

Please sign in to comment.