Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
GordonSmith committed Mar 22, 2024
1 parent 8995457 commit 3ff99dd
Show file tree
Hide file tree
Showing 5 changed files with 963 additions and 38 deletions.
55 changes: 23 additions & 32 deletions src/store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,38 @@ namespace cmcpp
uint32_t ptr = cx.opts->realloc(0, 0, dst_alignment, dst_byte_length);
assert(ptr == align_to(ptr, dst_alignment));
assert(ptr + dst_byte_length <= cx.opts->memory.size());
auto [enc_src, enc_len] = encode(src, src_code_units, dst_encoding);
auto enc_len = encodeTo(&cx.opts->memory[ptr], src, src_code_units, dst_encoding);
assert(dst_byte_length == enc_len);
std::memcpy(&cx.opts->memory[ptr], enc_src, enc_len);
return std::make_pair(ptr, enc_len);
}

auto MAX_STRING_BYTE_LENGTH = (1U << 31) - 1;

std::pair<uint32_t, uint32_t> store_string_to_utf8(const CallContext &cx, const char8_t *src, uint32_t src_code_units, uint32_t worst_case_size)
std::pair<uint32_t, uint32_t> store_string_to_utf8(const CallContext &cx, const char8_t *src, uint32_t src_code_units)
{
assert(src_code_units <= MAX_STRING_BYTE_LENGTH);
uint32_t ptr = cx.opts->realloc(0, 0, 1, src_code_units);
assert(ptr + src_code_units <= cx.opts->memory.size());
auto [enc_src, enc_len] = encode(src, src_code_units, GuestEncoding::Utf8);
auto enc_len = encodeTo(&cx.opts->memory[ptr], src, src_code_units, GuestEncoding::Utf8);
assert(src_code_units <= enc_len);
std::memcpy(&cx.opts->memory[ptr], enc_src, enc_len);
if (src_code_units <= enc_len)
if (src_code_units < enc_len)
{
assert(worst_case_size <= MAX_STRING_BYTE_LENGTH);
ptr = cx.opts->realloc(ptr, src_code_units, 1, worst_case_size);
assert(ptr + worst_case_size <= cx.opts->memory.size());
std::memcpy(&cx.opts->memory[ptr + src_code_units], enc_src, enc_len);
if (worst_case_size > enc_len)
{
ptr = cx.opts->realloc(ptr, worst_case_size, 1, enc_len);
assert(ptr + enc_len <= cx.opts->memory.size());
}
assert(enc_len <= MAX_STRING_BYTE_LENGTH);
uint32_t ptr = cx.opts->realloc(ptr, src_code_units, 1, enc_len);
assert(ptr + enc_len <= cx.opts->memory.size());
enc_len = encodeTo(&cx.opts->memory[ptr], src, enc_len, GuestEncoding::Utf8);
}
return std::make_pair(ptr, enc_len);
}

std::pair<uint32_t, uint32_t> store_utf16_to_utf8(const CallContext &cx, const char8_t *src, uint32_t src_code_units)
{
uint32_t worst_case_size = src_code_units * 3;
return store_string_to_utf8(cx, src, src_code_units, worst_case_size);
return store_string_to_utf8(cx, src, src_code_units);
}

std::pair<uint32_t, uint32_t> store_latin1_to_utf8(const CallContext &cx, const char8_t *src, uint32_t src_code_units)
{
uint32_t worst_case_size = src_code_units * 2;
return store_string_to_utf8(cx, src, src_code_units, worst_case_size);
return store_string_to_utf8(cx, src, src_code_units);
}

std::pair<uint32_t, uint32_t> store_utf8_to_utf16(const CallContext &cx, const char8_t *src, uint32_t src_code_units)
Expand All @@ -70,11 +61,12 @@ namespace cmcpp
throw std::runtime_error("Pointer misaligned");
if (ptr + worst_case_size > cx.opts->memory.size())
throw std::runtime_error("Out of bounds access");
auto [enc_src, enc_len] = encode(src, src_code_units, GuestEncoding::Utf16le);
std::memcpy(&cx.opts->memory[ptr], enc_src, enc_len);
auto enc_len = encodeTo(&cx.opts->memory[ptr], src, src_code_units, GuestEncoding::Utf16le);
if (enc_len < worst_case_size)
{
uint32_t cleanup_ptr = ptr;
ptr = cx.opts->realloc(ptr, worst_case_size, 2, enc_len);
std::memcpy(&cx.opts->memory[ptr], &cx.opts->memory[ptr], enc_len);
if (ptr != align_to(ptr, 2))
throw std::runtime_error("Pointer misaligned");
if (ptr + enc_len > cx.opts->memory.size())
Expand All @@ -95,7 +87,7 @@ namespace cmcpp
uint32_t dst_byte_length = 0;
for (size_t i = 0; i < src_code_units; ++i)
{
char usv = *(const char8_t *)src;
char8_t usv = *src;
if (static_cast<uint32_t>(usv) < (1 << 8))
{
cx.opts->memory[ptr + dst_byte_length] = static_cast<uint32_t>(usv);
Expand All @@ -116,15 +108,15 @@ namespace cmcpp
cx.opts->memory[ptr + 2 * j] = cx.opts->memory[ptr + j];
cx.opts->memory[ptr + 2 * j + 1] = 0;
}
auto [enc_src, enc_len] = encode(src, src_code_units, GuestEncoding::Utf16le);
std::memcpy(&cx.opts->memory[ptr + 2 * dst_byte_length], enc_src, enc_len);
auto enc_len = encodeTo(&cx.opts->memory[ptr + 2 * dst_byte_length], src, src_code_units, GuestEncoding::Utf16le);
if (worst_case_size > enc_len)
{
ptr = cx.opts->realloc(ptr, worst_case_size, 2, enc_len);
if (ptr != align_to(ptr, 2))
throw std::runtime_error("Pointer misaligned");
if (ptr + enc_len > cx.opts->memory.size())
throw std::runtime_error("Out of bounds access");
// TODO - skipping the truncation for now...
// ptr = cx.opts->realloc(ptr, worst_case_size, 2, enc_len);
// if (ptr != align_to(ptr, 2))
// throw std::runtime_error("Pointer misaligned");
// if (ptr + enc_len > cx.opts->memory.size())
// throw std::runtime_error("Out of bounds access");
}
uint32_t tagged_code_units = static_cast<uint32_t>(enc_len / 2) | UTF16_TAG;
return std::make_pair(ptr, tagged_code_units);
Expand Down Expand Up @@ -154,9 +146,8 @@ namespace cmcpp
if (ptr + src_byte_length > cx.opts->memory.size())
throw std::runtime_error("Not enough memory");

auto [enc_src, enc_len] = encode(src, src_code_units, GuestEncoding::Utf16le);
const uint8_t *enc_src_ptr = static_cast<const uint8_t *>(enc_src);
std::memcpy(&cx.opts->memory[ptr], enc_src_ptr, enc_len);
auto enc_len = encodeTo(&cx.opts->memory[ptr], src, src_code_units, GuestEncoding::Utf16le);
const uint8_t *enc_src_ptr = &cx.opts->memory[ptr];
if (std::any_of(enc_src_ptr, enc_src_ptr + enc_len, [](uint8_t c)
{ return static_cast<unsigned char>(c) >= (1 << 8); }))
{
Expand Down
5 changes: 3 additions & 2 deletions src/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,14 @@ namespace cmcpp
}
}

std::pair<void *, size_t> encode(const char8_t *src, uint32_t byte_len, GuestEncoding encoding)
size_t encodeTo(void *dest, const char8_t *src, uint32_t byte_len, GuestEncoding encoding)
{
switch (encoding)
{
case GuestEncoding::Utf8:
case GuestEncoding::Latin1:
return {const_cast<char8_t *>(src), byte_len};
std::memcpy(dest, src, byte_len);
return byte_len;
case GuestEncoding::Utf16le:
assert(false);
break;
Expand Down
2 changes: 1 addition & 1 deletion src/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace cmcpp
Val despecialize(const Val &v);
ValType discriminant_type(const std::vector<Case> &cases);

std::pair<void *, size_t> encode(const char8_t *src, uint32_t byte_len, GuestEncoding encoding);
size_t encodeTo(void *, const char8_t *src, uint32_t byte_len, GuestEncoding encoding);
uint32_t encode_float_as_i32(float32_t f);
uint64_t encode_float_as_i64(float64_t f);

Expand Down
6 changes: 3 additions & 3 deletions test/wasmtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ TEST_CASE("component-model-cpp")
lvs = cmcpp::lower_values(*cx, {charListList, charListList});
wasmtimeVals = vals2WasmtimeVals(lvs);
wasmtimeVals = list_list_string_append.call(store, wasmtimeVals).unwrap();
cmcppWasmVals = wasmtimeVals2WasmVals(wasmtimeVals);
cmcppVals = cmcpp::lift_values(*cx, cmcppWasmVals, {std::make_pair(cmcpp::ValType::List, cmcpp::ValType::String)});
CHECK(std::get<cmcpp::ListPtr>(cmcppVals[0])->vs.size() == 32);
wret = wasmtimeVals2WasmVals(wasmtimeVals);
cmcppVals = cmcpp::lift_values(*cx, wret, {cmcpp::ValType::String});
CHECK(std::get<cmcpp::StringPtr>(cmcppVals[0])->len == 32);

// Actual ABI Test Code --------------------------------------------
}
Expand Down
Loading

0 comments on commit 3ff99dd

Please sign in to comment.