Skip to content

Commit

Permalink
Fix embedded structs
Browse files Browse the repository at this point in the history
  • Loading branch information
Victorious3 committed Sep 23, 2024
1 parent 9fd2acc commit 7dcbab7
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 38 deletions.
98 changes: 65 additions & 33 deletions src/compiler.pr
Original file line number Diff line number Diff line change
Expand Up @@ -2093,17 +2093,40 @@ def convert_to(kind: InsnKind, loc: &Value, value: Value, tpe: &typechecking::Ty
return ret
}

def get_embed_field(left: &typechecking::Type, right: &typechecking::Type, state: &State) -> &typechecking::StructMember {
if is_interface(left) and is_struct(right) {
if typechecking::implements(right, left, state.module, check_embed = false) { return null }

def get_embed_field(left: &typechecking::Type, right: &typechecking::Type, state: &State) -> &Vector(&typechecking::StructMember) {
let res = vector::make(type &typechecking::StructMember)
// Early return if we already implement the interface!
if is_interface(left) and typechecking::implements(right, left, state.module, check_embed = false) { return res }

if is_ref(right) {
// Unwrap reference
right = right.tpe
}
if is_struct(right) {
if is_struct(left) and typechecking::equals(left, right) or
is_interface(left) and typechecking::implements(right, left, state.module, check_embed = false) {
return res
}

for var field in @right.fields {
if field.is_embed and typechecking::implements(field.tpe, left, state.module, check_embed = false) {
return field
if field.is_embed {
if is_struct(left) and typechecking::equals(left, field.tpe) or
is_interface(left) and typechecking::implements(field.tpe, left, state.module, check_embed = false) {

res.push(field)
return res
} else {
let embed = get_embed_field(left, field.tpe, state)
res.push(field)
res.add_all(embed)
return res
}
}
}
}

return null
return res
}

// value gets loaded by this function
Expand Down Expand Up @@ -2145,40 +2168,49 @@ def convert_to(loc: &Value, value: Value, tpe: &typechecking::Type, state: &Stat
}
let left = tpe.tpe if is_ref(tpe) else tpe
let right = value.tpe.tpe if is_ref(value.tpe) else value.tpe
let embed_field = get_embed_field(left, right, state)

let embed_field = get_embed_field(left, value.tpe, state)
// Try to convert to embedded struct / reference
if is_struct(right) and (is_struct(left) or embed_field) {
var is_embed = false
var field: StructMember
if embed_field {
is_embed = true
field = @embed_field
if is_struct(right) and ((is_struct(left) or is_interface(left)) and embed_field.length > 0) {
var elem = NO_VALUE

// Unwrap reference on the value side
if is_ref(value.tpe) {
elem = state.extract_value(pointer(right), load_value(value, loc, state), [1], loc)
} else {
for var f in @right.fields {
if f.is_embed and (equals(f.tpe, tpe) or equals(f.tpe, tpe.tpe)) {
is_embed = true
field = f
break
elem = @value.addr
}

for var i in 0..embed_field.length {
let field = embed_field(i)
if is_ref(field.tpe) {
elem = state.load(elem.tpe.tpe, elem, loc)
elem = state.extract_value(field.tpe, elem, [field.index !int], loc)
if i < embed_field.length - 1 {
// Convert to a pointer for the next iteration
elem = state.extract_value(pointer(field.tpe), elem, [1], loc)
} else {
// Return reference as is
return elem
}
} else {
elem = state.gep(pointer(field.tpe), field.tpe, elem, [make_int_value(0), make_int_value(field.index !int)], loc)
}
}
if is_embed {
var unwrap = load_value(value, loc, state)
// Unwrap reference on the value side
if is_ref(value.tpe) {
let _ref = state.extract_value(pointer(right), unwrap, [1], loc)
unwrap = state.load(right, _ref, loc)
}
// Extract element
var elem = state.extract_value(field.tpe, unwrap, [field.index !int], loc)
// Wrap in reference if needed
if is_ref(tpe) and not is_ref(field.tpe) {
elem = convert_value_to_ref(tpe, elem, loc, state, 1)
}
return elem

var last_field = embed_field(embed_field.length - 1)

let addr = elem
elem = state.load(elem.tpe.tpe, elem, loc)
elem.addr = addr

// Wrap in reference if needed
if is_ref(tpe) and not is_ref(last_field.tpe) {
elem = convert_value_to_ref(tpe, elem, loc, state, 1)
}
return elem
}

if tpe.kind == value.tpe.kind and value.tpe.is_anon and typechecking::is_struct(value.tpe) {
return convert_anon_to_struct(tpe, value, loc, state)
}
Expand Down
16 changes: 15 additions & 1 deletion src/toolchain.pr
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ export type Stage = enum {
BACKEND
}

export type Cast = struct {
src: &typechecking::Type
dst: &typechecking::Type
}
export def hash(cast: Cast) -> int64 {
return combine_hashes(hash(cast.src), hash(cast.dst))
}
export def == (a: Cast, b: Cast) -> bool { return a.src == b.src and a.dst == b.dst }
export def != (a: Cast, b: Cast) -> bool { return not (a == b) }

export type Module = struct {
display_name: Str
filename: Str
Expand All @@ -218,10 +228,13 @@ export type Module = struct {
imports: &Set(Str)
dependants: &Set(weak &Module)
// List of Type
// TODO These should be sets
// This is a list of functions that are generated for dynamic dispatch
dyn_dispatch_consteval: &Vector(&typechecking::Type)
dyn_dispatch: &Vector(&typechecking::Type)

dyn_casts: &Set(Cast)

// This is needed to generate functions from create_destructor
compiler_state: &compiler::State
state: &typechecking::State
Expand Down Expand Up @@ -302,11 +315,12 @@ export def make_module(
scope = scpe,
result = compiler::make_result(),
code = compiler::make_block(),
imported = set::make(),
imported = set::make(Str),
imports = set::make(Str),
dependants = set::make(type weak &Module),
dyn_dispatch_consteval = vector::make(type &typechecking::Type),
dyn_dispatch = vector::make(type &typechecking::Type),
dyn_casts = set::make(Cast),
unresolved = map::make(scope::Ident, type weak &scope::Value),
inlay_hints = vector::make(type &parser::Node),
closures = vector::make(type &scope::Value),
Expand Down
22 changes: 22 additions & 0 deletions src/typechecking.pr
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,21 @@ export def equals(a: &Type, b: &Type) -> bool {
assert
}

def contains(a: &Type) -> &Set(&Type) {
if is_ref_or_weak(a) { a = a.tpe }
let res = set::make(type &Type)
if not is_struct(a) { return res }
for var field in @a.fields {
if field.is_embed {
if is_struct(field.tpe) {
res.add(field.tpe)
res.add_all(contains(field.tpe))
}
}
}
return res
}

def is_setter(mb: StructuralTypeMember) -> bool {
return mb.name.starts_with("__set_") and mb.name.ends_with("__")
}
Expand Down Expand Up @@ -4154,6 +4169,13 @@ def walk_Cast(node: &parser::Node, state: &State) {
}
}

// Dynamic cast
if ltpe.is_ref_or_weak() {
if rtpe.is_struct() and ltpe.tpe != rtpe {
state.module.dyn_casts.add([src = ltpe, dst = rtpe] !Cast)
}
}

if not node.tpe {
node.tpe = rtpe
}
Expand Down
13 changes: 9 additions & 4 deletions std/reflection.pr
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ export type Type = &interface {
}

export def implements(a: Type, b: InterfaceT) -> bool {
return false
return false // TODO
}

export def assignable(a: Type, b: Type) -> bool {
return false
return false // TODO
}

export def contains(a: StructT, b: Type) -> bool {
return false // TODO
}

export def == (a: Type, b: Type) -> bool {
Expand Down Expand Up @@ -304,8 +308,9 @@ def load_types(input: *uint8, size: size_t, num: size_t, strings: *char) {
for var id in @types.keys() {
let tpe = types(id)

let base = *(tpe !BaseType)
let nmembers = fp.read(int)
tpe.type_members = allocate_ref(type Function, nmembers)
base.type_members = allocate_ref(type Function, nmembers)
for var i in 0..nmembers {
let member = [
name = make_slice(strings, fp.read(int)),
Expand All @@ -324,7 +329,7 @@ def load_types(input: *uint8, size: size_t, num: size_t, strings: *char) {

member.arguments = arguments
member.returns = returns
tpe.type_members()(i) = member
base.type_members(i) = member
}

if tpe.type == StructT {
Expand Down

0 comments on commit 7dcbab7

Please sign in to comment.