Skip to content

Commit

Permalink
Add dynamic casts
Browse files Browse the repository at this point in the history
  • Loading branch information
Victorious3 committed Sep 28, 2024
1 parent 7dcbab7 commit 2e8f433
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 16 deletions.
128 changes: 126 additions & 2 deletions src/compiler.pr
Original file line number Diff line number Diff line change
Expand Up @@ -2166,10 +2166,34 @@ def convert_to(loc: &Value, value: Value, tpe: &typechecking::Type, state: &Stat
return value
}
}

let cast = [src = value.tpe, dst = tpe] !toolchain::Cast
if state.module.dyn_casts.contains(cast) {
let fun = state.module.dyn_casts(cast)

var dst = cast.dst
if not is_ref(cast.dst) {
dst = pointer(dst)
}

let dst_value = state.alloca(dst, loc)
state.call(fun.type_name, null, [value, dst_value], loc)

if is_ref(cast.dst) {
let val = state.load(cast.dst, dst_value)
return val
} else {
let ptr = state.load(pointer(cast.dst), dst_value)
let res = state.load(cast.dst, ptr)
res.addr = ptr
return res
}
}

let left = tpe.tpe if is_ref(tpe) else tpe
let right = value.tpe.tpe if is_ref(value.tpe) else value.tpe
let embed_field = get_embed_field(left, value.tpe, state)

// Try to convert to embedded struct / reference
if is_struct(right) and ((is_struct(left) or is_interface(left)) and embed_field.length > 0) {
var elem = NO_VALUE
Expand Down Expand Up @@ -8062,6 +8086,105 @@ export def create_dyn_dispatch(dyn_dispatch: &Vector(&typechecking::Type), state
}
}

export def create_dyn_casts(state: &State) {
let dyn_casts = state.module.dyn_casts
for var cast in @dyn_casts.keys() {
let function = predeclare_function(dyn_casts(cast), state.module)
create_dyn_cast_function(function, cast, state)
consteval::const_module.result.functions(function.name) = function
}
}

def create_dyn_cast_function(function: &Function, cast: toolchain::Cast, state: &State) {
// Setup function
function.block = make_block()
function.forward_declare = false
let previous_block = state.current_block
state.current_block = function.block

var dst_arg = pointer(cast.dst)
if not is_ref(cast.dst) {
dst_arg = pointer(dst_arg)
}

let src_value = [ kind = ValueKind::LOCAL, tpe = cast.src, name = "src.value"] !Value
let dst_value = [ kind = ValueKind::LOCAL, tpe = dst_arg, name = "dst.value" ] !Value

// Extract type
let ref_tpe = state.extract_value(
typechecking::pointer(builtins::Type_),
src_value,
[2]
)
let ref_tpe_deref = state.load(builtins::Type_, ref_tpe)
let tpe_value = state.extract_value(
typechecking::pointer(builtins::Type_),
ref_tpe_deref,
[4]
)
let tpe_deref = state.load(builtins::Type_, tpe_value)
let tpe_id = state.extract_value(builtins::int64_, tpe_deref, [14])

// Switch
let switch_values = vector::make(SwitchValue)
let swtch = make_insn(InsnKind::SWITCH)
swtch.value.switch_ = [
value = tpe_id,
switch_values = switch_values
] !InsnSwitch
push_insn(swtch, state)

let hashes = set::make(uint64)
let keys = map::keys(typechecking::types_map)
for var i in 0..keys.size {
let type_entry = typechecking::types_map(keys(i))

if not is_ref(type_entry.tpe) { continue }
let contains = typechecking::contains(type_entry.tpe)

if not contains.contains(cast.dst) and not (cast.dst.is_ref() and contains.contains(cast.dst.tpe)) { continue }

let hash = md5::high(md5::md5(debug::type_to_str(type_entry.tpe.tpe, full_name = true)))
if hashes.contains(hash) { continue }
hashes.add(hash)

let if_true = make_label(state)
push_label(if_true, state)

if cast.dst.is_ref() {
var res = convert_to(null !&Value, src_value, type_entry.tpe, state)
res = convert_to(null !&Value, res, cast.dst, state)
state.store(dst_value, res)
} else {
var res = convert_to(null !&Value, src_value, type_entry.tpe, state)
res = convert_to(null !&Value, res, cast.dst, state)
state.store(dst_value, @res.addr)
}

state.ret(NO_VALUE)

let svalue = [
label_ = if_true,
value = [ kind = ValueKind::INT, i = hash, tpe = builtins::int64_ ] !Value
] !SwitchValue
switch_values.push(svalue)
}

let end_label = make_label(state)
push_label(end_label, state)
swtch.value.switch_.otherwise = end_label

// TODO Abort with message!
import_cstd_function("abort", state)
state.call("abort", null, [] ![Value])

state.module.imported.add(function.name)
push_insn(make_insn(InsnKind::UNREACHABLE), state)

// Reset state
state.current_block = previous_block
}

export let constructors = map::make(type &typechecking::Type)
def create_constructors {
var done = set::make()
Expand Down Expand Up @@ -9617,7 +9740,7 @@ def generate_vtable_function(function: &Function, tpe: &typechecking::Type, stat
] !InsnSwitch
push_insn(swtch, state)

let hashes = set::make(size_t)
let hashes = set::make(uint64)
let keys = map::keys(typechecking::types_map)
for var i in 0..keys.size {
let type_entry = typechecking::types_map(keys(i))
Expand Down Expand Up @@ -10045,6 +10168,7 @@ export def compile(state: &State, is_main: bool, no_cleanup: bool = false) {
// TODO This doesn't work for functions that return multiple parameters
predeclare_functions(state.module)
create_dyn_dispatch(state.module.dyn_dispatch, state)
create_dyn_casts(state)

let ident = parser::make_identifier("__main__")
ident.loc.module = state.module.module
Expand Down
4 changes: 2 additions & 2 deletions src/toolchain.pr
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ export type Module = struct {
dyn_dispatch_consteval: &Vector(&typechecking::Type)
dyn_dispatch: &Vector(&typechecking::Type)

dyn_casts: &Set(Cast)
dyn_casts: &Map(Cast, &typechecking::Type)

// This is needed to generate functions from create_destructor
compiler_state: &compiler::State
Expand Down Expand Up @@ -320,7 +320,7 @@ export def make_module(
dependants = set::make(type weak &Module),
dyn_dispatch_consteval = vector::make(type &typechecking::Type),
dyn_dispatch = vector::make(type &typechecking::Type),
dyn_casts = set::make(Cast),
dyn_casts = map::make(Cast, type &typechecking::Type),
unresolved = map::make(scope::Ident, type weak &scope::Value),
inlay_hints = vector::make(type &parser::Node),
closures = vector::make(type &scope::Value),
Expand Down
42 changes: 30 additions & 12 deletions src/typechecking.pr
Original file line number Diff line number Diff line change
Expand Up @@ -1282,16 +1282,14 @@ export def equals(a: &Type, b: &Type) -> bool {
assert
}

def contains(a: &Type) -> &Set(&Type) {
export def contains(a: &Type) -> &Set(&Type) {
if is_ref_or_weak(a) { a = a.tpe }
let res = set::make(type &Type)
if not is_struct(a) { return res }
for var field in @a.fields {
if field.is_embed {
if is_struct(field.tpe) {
res.add(field.tpe)
res.add_all(contains(field.tpe))
}
res.add(field.tpe)
res.add_all(contains(field.tpe))
}
}
return res
Expand Down Expand Up @@ -4141,6 +4139,12 @@ def walk_Cast(node: &parser::Node, state: &State) {
var ltpe = left.tpe
if not ltpe { return }

// Create necessary type entries
create_type_entry(ltpe)
if is_struct(ltpe) and is_ref(rtpe) {
create_type_entry(reference(ltpe))
}

if left.kind == parser::NodeKind::INTEGER and
is_integer(ltpe) and is_integer(rtpe) {
left.tpe = rtpe
Expand All @@ -4157,6 +4161,27 @@ def walk_Cast(node: &parser::Node, state: &State) {
errors::errorn(left, "Invalid cast")
return
}
} else if ltpe.is_ref_or_weak() and
ltpe.tpe.is_interface() and
not implements(rtpe, ltpe.tpe, state.module) and
(rtpe.is_struct() or (rtpe.is_ref() and rtpe.tpe.is_struct())) {

var dst = pointer(rtpe)
if not is_ref(rtpe) {
dst = pointer(dst)
}

let params = vector::make(NamedParameter)
params.push([ name = "src", _tpe = ltpe ] !NamedParameter)
params.push([ name = "dst", _tpe = dst] !NamedParameter)
let fun = make_function_type_n(
parser::make_identifier("__cast"),
params,
vector::make(type &Type),
state.module)

// Dynamic cast
state.module.dyn_casts([src = ltpe, dst = rtpe,] !Cast) = fun
} else if is_struct(ltpe) or is_struct(rtpe) {
if ltpe.kind != rtpe.kind and not is_ref(rtpe) {
errors::errorn(left, "Invalid cast")
Expand All @@ -4169,13 +4194,6 @@ def walk_Cast(node: &parser::Node, state: &State) {
}
}

// Dynamic cast
if ltpe.is_ref_or_weak() {
if rtpe.is_struct() and ltpe.tpe != rtpe {
state.module.dyn_casts.add([src = ltpe, dst = rtpe] !Cast)
}
}

if not node.tpe {
node.tpe = rtpe
}
Expand Down

0 comments on commit 2e8f433

Please sign in to comment.