Skip to content

Commit

Permalink
Replace GenericFunction with DependentFunction.
Browse files Browse the repository at this point in the history
  • Loading branch information
asoffer committed Dec 23, 2023
1 parent cbb2142 commit a7cefab
Show file tree
Hide file tree
Showing 11 changed files with 35 additions and 115 deletions.
1 change: 0 additions & 1 deletion common/language/type_kind.xmacro.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ IC_XMACRO_TYPE_KIND(Function)
IC_XMACRO_TYPE_KIND(Slice)
IC_XMACRO_TYPE_KIND(Pointer)
IC_XMACRO_TYPE_KIND(BufferPointer)
IC_XMACRO_TYPE_KIND(GenericFunction)
IC_XMACRO_TYPE_KIND(Opaque)
IC_XMACRO_TYPE_KIND(DependentFunction)

Expand Down
31 changes: 9 additions & 22 deletions ir/builtin_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,6 @@ nth::NoDestructor<IrFunction> Slice([] {
return f;
}());

nth::NoDestructor<IrFunction> ForeignType([] {
IrFunction f(3, 1);
f.append<jasmin::Swap>();
f.append<jasmin::Drop>();
f.append<jasmin::Swap>();
f.append<jasmin::Drop>();
f.append<jasmin::Return>();
return f;
}());

nth::NoDestructor<std::vector<std::string>> BuiltinNamesImpl;

} // namespace
Expand Down Expand Up @@ -122,18 +112,15 @@ Module BuiltinModule() {
{type::Slice(type::Char)}),
*Slice);

// TODO: There's something wrong with registration happening after this point.
m.Insert(
Identifier("foreign"),
{.qualified_type = type::QualifiedType::Constant(type::GenericFunction(
type::Evaluation::RequireCompileTime, &*ForeignType)),
.value = {&*Foreign}});
global_function_registry.Register(
FunctionId(ModuleId::Builtin(), LocalFunctionId(next_id++)),
&*ForeignType);

global_function_registry.Register(
FunctionId(ModuleId::Builtin(), LocalFunctionId(next_id++)), &*Foreign);
Register(
"foreign",
type::Dependent(type::DependentTerm::Function(
type::DependentTerm::Value(
TypeErasedValue(type::Type_, {type::Type_})),
type::DependentTerm::DeBruijnIndex(0)),
type::DependentParameterMapping(
{type::DependentParameterMapping::Index::Value(1)})),
*Foreign);

return m;
}
Expand Down
2 changes: 1 addition & 1 deletion ir/emit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ void EmitContext::Push(std::span<jasmin::Value const> vs, type::Type t) {
return;
}
switch (t.kind()) {
case type::Type::Kind::GenericFunction:
case type::Type::Kind::DependentFunction:
case type::Type::Kind::Function: {
NTH_REQUIRE((v.harden), vs.size() == 1);
current_function().append<PushFunction>(vs[0]);
Expand Down
4 changes: 2 additions & 2 deletions ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ namespace ic {

struct Store : jasmin::Instruction<Store> {
static void consume(std::span<jasmin::Value, 2> input, uint8_t size) {
void* location = input[0].as<void*>();
jasmin::Value value = input[1];
jasmin::Value value = input[0];
void* location = input[1].as<void*>();
jasmin::Value::Store(value, location, size);
}
static constexpr std::string_view debug() { return "store"; }
Expand Down
37 changes: 20 additions & 17 deletions ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ void HandleParseTreeNodeCallExpression(ParseNodeIndex index, IrContext& context,
NTH_UNIMPLEMENTED("{} {}") <<= {parameters, node.child_count - 1};
}
} else if (invocable_type.type().kind() ==
type::Type::Kind::GenericFunction) {
type::Type::Kind::DependentFunction) {
auto& spec = context.emit.instruction_spec[index];
++spec.parameters;
std::vector<std::pair<ParseNodeIndex, TypeStack::const_iterator>> argument_indices;
Expand All @@ -780,28 +780,31 @@ void HandleParseTreeNodeCallExpression(ParseNodeIndex index, IrContext& context,
}

std::reverse(argument_indices.begin(), argument_indices.end());
nth::stack<jasmin::Value> value_stack;
std::vector<TypeErasedValue> arguments;

for (auto [index, type_iter] : argument_indices) {
auto t = (*type_iter)[0].type();

auto qt = (*type_iter)[0];
nth::stack<jasmin::Value> value_stack;
nth::interval range = context.emit.tree.subtree_range(index);
context.emit.Evaluate(range, value_stack, {t});
spec.parameters += type::JasminSize(t);
if (qt.constant()) {
context.emit.Evaluate(range, value_stack, {qt.type()});
std::span values = value_stack.top_span(value_stack.size());
arguments.emplace_back(qt.type(),
std::vector(values.begin(), values.end()));
} else {
arguments.emplace_back(qt.type(), std::vector<jasmin::Value>{});
}
spec.parameters += type::JasminSize(qt.type());
}

auto g = invocable_type.type().AsGenericFunction();
jasmin::Execute(g.function(), value_stack);
auto t = value_stack.top().as<type::Type>();
context.type_stack().push({type::QualifiedType::Constant(t)});
NTH_REQUIRE((v.debug), t.kind() == type::Type::Kind::Function);
auto dep = invocable_type.type().AsDependentFunction();
std::optional t = dep(arguments);
context.type_stack().push({type::QualifiedType::Constant(*t)});
NTH_REQUIRE((v.debug), t->kind() == type::Type::Kind::Function);

if (g.evaluation() == type::Evaluation::PreferCompileTime or
g.evaluation() == type::Evaluation::RequireCompileTime) {
nth::stack<jasmin::Value> value_stack;
context.emit.Evaluate(context.emit.tree.subtree_range(index), value_stack,
{t});
}
nth::stack<jasmin::Value> value_stack;
context.emit.Evaluate(context.emit.tree.subtree_range(index), value_stack,
{*t});
} else {
NTH_UNIMPLEMENTED("node = {} invocable_type = {}") <<=
{node, invocable_type};
Expand Down
13 changes: 0 additions & 13 deletions type/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,6 @@ cc_library(
],
)

cc_library(
name = "generic_function",
hdrs = ["generic_function.h"],
deps = [
":basic",
":function",
":type_system_cc_proto",
"//common/language:type_kind",
"@asoffer_nth//nth/debug",
],
)

cc_library(
name = "opaque",
hdrs = ["opaque.h"],
Expand Down Expand Up @@ -190,7 +178,6 @@ cc_library(
":dependent",
":family",
":function",
":generic_function",
":opaque",
":parameters",
":pointer",
Expand Down
2 changes: 1 addition & 1 deletion type/deserialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Type Deserialize(TypeProto const& proto, TypeSystemProto const& ts) {
case TypeProto::BUFFER_POINTER:
return DeserializeBufferPointerType(ts.buffer_pointers(proto.index()), ts);
case TypeProto::OPAQUE: return OpaqueType(proto.index());
default: NTH_UNREACHABLE();
default: NTH_UNREACHABLE("{}") <<= {proto.DebugString()};
}
}

Expand Down
30 changes: 0 additions & 30 deletions type/generic_function.h

This file was deleted.

23 changes: 1 addition & 22 deletions type/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,6 @@ BufferPointerType BufPtr(Type t) {
type_system->buffer_pointee_types.insert(t).first));
}

GenericFunctionType GenericFunction(Evaluation e, IrFunction const* fn) {
return GenericFunctionType(type_system->generic_function_types.index(
type_system->generic_function_types.insert(std::pair(fn, e)).first));
}

Type SliceType::element_type() const {
return type_system->slice_element_types.from_index(data());
}
Expand Down Expand Up @@ -112,16 +107,6 @@ std::vector<Type> const& FunctionType::returns() const {
std::get<1>(type_system->functions.from_index(data())));
}

IrFunction const& GenericFunctionType::function() const {
return *type_system->generic_function_types.from_index(BasicType::data())
.first;
}

Evaluation GenericFunctionType::evaluation() const {
return type_system->generic_function_types.from_index(BasicType::data())
.second;
}

OpaqueType Opaque() { return OpaqueType(opaque_count++); }

size_t JasminSize(Type t) {
Expand All @@ -132,7 +117,6 @@ size_t JasminSize(Type t) {
case Type::Kind::Slice: return 2;
case Type::Kind::Pointer: return 1;
case Type::Kind::BufferPointer: return 1;
case Type::Kind::GenericFunction: return 1;
case Type::Kind::Opaque: NTH_UNREACHABLE("{}") <<= {t};
case Type::Kind::DependentFunction: NTH_UNREACHABLE("{}") <<= {t};
}
Expand Down Expand Up @@ -225,21 +209,16 @@ std::optional<Type> DependentFunctionType::operator()(
case DependentParameterMapping::Index::Kind::Type:
if (not term_copy.bind(
TypeErasedValue(Type_, {values[index.index()].type()}))) {
NTH_LOG("Returned");
return std::nullopt;
}
break;
case DependentParameterMapping::Index::Kind::Value:
if (not term_copy.bind(values[index.index()])) {
NTH_LOG("Returned");
return std::nullopt; }
break;
}
}
if (auto* v = term_copy.evaluate()) {
NTH_LOG("Returned");
return v->value()[0].as<Type>(); }
NTH_LOG("Returned");
if (auto* v = term_copy.evaluate()) { return v->value()[0].as<Type>(); }
return std::nullopt;
}

Expand Down
2 changes: 0 additions & 2 deletions type/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "type/dependent.h"
#include "type/family.h"
#include "type/function.h"
#include "type/generic_function.h"
#include "type/opaque.h"
#include "type/parameters.h"
#include "type/pointer.h"
Expand Down Expand Up @@ -41,7 +40,6 @@ struct TypeSystem {
nth::flyweight_set<Type> slice_element_types;
nth::flyweight_set<Type> pointee_types;
nth::flyweight_set<Type> buffer_pointee_types;
nth::flyweight_set<std::pair<IrFunction const*, Evaluation>> generic_function_types;
nth::flyweight_set<Family> type_families;
nth::flyweight_set<DependentTerm> dependent_terms;
nth::flyweight_set<DependentParameterMapping> dependent_mapping;
Expand Down
5 changes: 1 addition & 4 deletions type/type_system.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ message TypeProto {
SLICE = 4;
POINTER = 5;
BUFFER_POINTER = 6;
GENERIC_FUNCTION = 7;
OPAQUE = 8;
DEPENDENT_PRODUCT = 9;
DEPENDENT_SUM = 10;
OPAQUE = 7;
}
Kind kind = 1;
uint32 index = 2;
Expand Down

0 comments on commit a7cefab

Please sign in to comment.