From 2d6a20f32fd1e2016435a8b126f329cbc92b117c Mon Sep 17 00:00:00 2001 From: Luke Wagner Date: Sat, 21 Sep 2024 22:32:37 -0500 Subject: [PATCH] Add 'stream' type --- design/mvp/Async.md | 28 +- design/mvp/CanonicalABI.md | 1 + design/mvp/Explainer.md | 10 +- design/mvp/canonical-abi/definitions.py | 405 +++++++++- design/mvp/canonical-abi/run_tests.py | 933 ++++++++++++++++++++++-- 5 files changed, 1265 insertions(+), 112 deletions(-) diff --git a/design/mvp/Async.md b/design/mvp/Async.md index 2a44f8c5..ce108b93 100644 --- a/design/mvp/Async.md +++ b/design/mvp/Async.md @@ -419,21 +419,25 @@ For now, this remains a [TODO](#todo) and validation will reject `async`-lifted ## TODO -Native async support is being proposed in progressive chunks. The following -features will be added in future chunks to complete "async" in Preview 3: -* `future`/`stream`/`error`: add for use in function types for finer-grained - concurrency -* `subtask.cancel`: allow a supertask to signal to a subtask that its result is - no longer wanted and to please wrap it up promptly -* allow "tail-calling" a subtask so that the current wasm instance can be torn - down eagerly -* `task.index`+`task.wake`: allow tasks in the same instance to wait on and - wake each other (async condvar-style) +Native async support is being proposed incrementally. The following features +will be added in future chunks roughly in the order list to complete the full +"async" story: * `nonblocking` function type attribute: allow a function to declare in its type that it will not transitively do anything blocking +* add `future` type +* add `error` type that can be included when closing a stream/future +* define what `async` means for `start` functions (top-level await + background + tasks), along with cross-task coordination built-ins +* `subtask.cancel`: allow a supertask to signal to a subtask that its result is + no longer wanted and to please wrap it up promptly +* zero-copy forwarding/splicing and built-in way to "tail-call" a subtask so + that the current wasm instance can be torn down eagerly while preserving + structured concurrency +* some way to say "no more elements are coming for a while" * `recursive` function type attribute: allow a function to be reentered - recursively (instead of trapping) -* enable `async` `start` functions + recursively (instead of trapping) and link inner and outer activations +* allow pipelining multiple `stream.read`/`write` calls +* allow chaining multiple async calls together ("promise pipelining") * integrate with `shared`: define how to lift and lower functions `async` *and* `shared` diff --git a/design/mvp/CanonicalABI.md b/design/mvp/CanonicalABI.md index eee85099..33f61b28 100644 --- a/design/mvp/CanonicalABI.md +++ b/design/mvp/CanonicalABI.md @@ -2108,6 +2108,7 @@ where `$callee` has type `$ft`, validation specifies: * a `memory` is present if required by lifting and is a subtype of `(memory 1)` * a `realloc` is present if required by lifting and has type `(func (param i32 i32 i32 i32) (result i32))` * there is no `post-return` in `$opts` +* if `contains_async($ft)`, then `$opts.async` must be set When instantiating component instance `$inst`: * Define `$f` to be the partially-bound closure: `canon_lower($opts, $ft, $callee)` diff --git a/design/mvp/Explainer.md b/design/mvp/Explainer.md index 8cf2d333..1b6623f2 100644 --- a/design/mvp/Explainer.md +++ b/design/mvp/Explainer.md @@ -1214,10 +1214,12 @@ validated to have parameters matching the callee's return type and empty results. 🔀 The `async` option specifies that the component wants to make (for imports) -or support (for exports) multiple concurrent (asynchronous) calls. This option -can be applied to any component-level function type and changes the derived -Canonical ABI significantly. See the [async explainer](Async.md) for more -details. +or support (for exports) multiple concurrent (asynchronous) calls. This +option can be applied to any component-level function type and changes the +derived Canonical ABI significantly. See the [async explainer](Async.md) for +more details. When a function signature contains a `future` or `stream`, +validation requires the `async` option to be set (since a synchronous call to +a function using these types is likely to deadlock). 🔀 The `(callback ...)` option may only be present in `canon lift` when the `async` option has also been set and specifies a core function that is diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py index ef125509..1823e71a 100644 --- a/design/mvp/canonical-abi/definitions.py +++ b/design/mvp/canonical-abi/definitions.py @@ -17,6 +17,9 @@ class Trap(BaseException): pass class CoreWebAssemblyException(BaseException): pass +def assert_if(cond1, cond2): + assert(not cond1 or cond2) + def trap(): raise Trap() @@ -171,6 +174,14 @@ class OwnType(ValType): class BorrowType(ValType): rt: ResourceType +@dataclass +class StreamType(ValType): + t: ValType + +@dataclass +class FutureType(ValType): + t: ValType + ### CallContext class CallContext: @@ -198,7 +209,7 @@ class CanonicalOptions: class ComponentInstance: resources: ResourceTables - async_subtasks: Table[Subtask] + waitables: Table[Subtask|StreamHandle|FutureHandle] num_tasks: int may_leave: bool backpressure: bool @@ -207,7 +218,7 @@ class ComponentInstance: def __init__(self): self.resources = ResourceTables() - self.async_subtasks = Table[Subtask]() + self.waitables = Table[Subtask|StreamHandle|FutureHandle]() self.num_tasks = 0 self.may_leave = True self.backpressure = False @@ -250,6 +261,8 @@ class Table(Generic[ElemT]): array: list[Optional[ElemT]] free: list[int] + MAX_LENGTH = 2**30 - 1 + def __init__(self): self.array = [None] self.free = [] @@ -266,7 +279,7 @@ def add(self, e): self.array[i] = e else: i = len(self.array) - trap_if(i >= 2**30) + trap_if(i > Table.MAX_LENGTH) self.array.append(e) return i @@ -300,9 +313,13 @@ class EventCode(IntEnum): CALL_RETURNED = CallState.RETURNED CALL_DONE = CallState.DONE YIELDED = 4 + STREAM_READ = 5 + STREAM_WRITE = 6 + FUTURE_READ = 7 + FUTURE_WRITE = 8 -EventTuple = tuple[EventCode, int] -EventCallback = Callable[[], EventTuple] +EventTuple = tuple[EventCode, int, int] +EventCallback = Callable[[], Optional[EventTuple]] OnBlockCallback = Callable[[Awaitable], Any] current_task = asyncio.Lock() @@ -410,11 +427,12 @@ async def wait(self) -> EventTuple: return e def poll(self) -> Optional[EventTuple]: - if self.events: + while self.events: event = self.events.pop(0) if not self.events: self.has_events.clear() - return event() + if (e := event()): + return e return None def notify(self, event: EventCallback): @@ -438,7 +456,7 @@ def return_(self, flat_results): def exit(self): assert(current_task.locked()) - assert(not self.events) + assert(not self.poll()) assert(self.inst.num_tasks >= 1) trap_if(self.on_return) trap_if(self.need_to_drop != 0) @@ -481,10 +499,10 @@ def maybe_notify_supertask(self): self.enqueued = True def subtask_event(): self.enqueued = False - i = self.inst.async_subtasks.array.index(self) + i = self.inst.waitables.array.index(self) if self.state == CallState.DONE: self.release_lenders() - return (EventCode(self.state), i) + return (EventCode(self.state), i, 0) self.task.notify(subtask_event) def on_start(self): @@ -517,7 +535,193 @@ def drop(self): trap_if(self.state != CallState.DONE) self.task.need_to_drop -= 1 -### Despecialization +class Buffer: + MAX_LENGTH = 2**30 - 1 + + def __init__(self, cx, t, ptr, length): + trap_if(length == 0 or length > Buffer.MAX_LENGTH) + trap_if(ptr != align_to(ptr, alignment(t))) + trap_if(ptr + length * elem_size(t) > len(cx.opts.memory)) + self._cx = cx + self._t = t + self._begin = ptr + self._length = length + self._progress = 0 + + def progress(self): + return self._progress + + def remain(self): + assert(self._progress <= self._length) + return self._length - self._progress + +class ReadableBuffer(Buffer): + def lift(self, n): + assert(n <= self.remain()) + ptr = self._begin + self._progress * elem_size(self._t) + vs = load_list_from_valid_range(self._cx, ptr, n, self._t) + self._progress += n + return vs + +class WritableBuffer(Buffer): + def lower(self, vs): + assert(len(vs) <= self.remain()) + ptr = self._begin + self._progress * elem_size(self._t) + store_list_into_valid_range(self._cx, vs, ptr, self._t) + self._progress += len(vs) + +class AsyncValue: + destroyed: Callable[[]] + destroy: Callable[[], None] + read: Callable[[WritableBuffer, OnBlockCallback], Awaitable] + cancel_read: Callable[[WritableBuffer, OnBlockCallback], Awaitable] + maybe_writer_handle_index: Callable[[ComponentInstance], Optional[int]] + + def __init__(self, impl): + self.destroyed = impl.destroyed + self.destroy = impl.destroy + self.read = impl.read + self.cancel_read = impl.cancel_read + self.maybe_writer_handle_index = impl.maybe_writer_handle_index + +class AsyncValueHandle: + async_value: AsyncValue + t: ValType + cx: Optional[CallContext] + copying_buffer: Optional[Buffer] + + def __init__(self, async_value, t, cx): + self.async_value = async_value + self.t = t + self.cx = cx + self.copying_buffer = None + + def drop(self): + trap_if(self.copying_buffer) + if not self.async_value.destroyed(): + self.async_value.destroy() + if self.cx: + self.cx.task.need_to_drop -= 1 + +class ReadableAsyncValueHandle(AsyncValueHandle): + async def copy(self, dst, on_block): + await self.async_value.read(dst, on_block) + async def cancel_copy(self, dst, on_block): + await self.async_value.cancel_read(dst, on_block) + +class WritableAsyncValueHandle(AsyncValueHandle): + destroyed: bool + rendezvous_buffer: Optional[Buffer] + rendezvous_future: Optional[asyncio.Future] + + def __init__(self, t): + super().__init__(AsyncValue(self), t, cx = None) + self.destroyed = False + self.rendezvous_buffer = None + self.rendezvous_future = None + + def destroyed(self): + return self.destroyed + + async def copy(self, src, on_block): + await self.rendezvous('write', src, on_block) + async def read(self, dst, on_block): + await self.rendezvous('read', dst, on_block) + async def rendezvous(self, direction, buffer, on_block): + assert(not self.async_value.destroyed()) + if self.rendezvous_buffer: + ncopy = min(buffer.remain(), self.rendezvous_buffer.remain()) + assert(ncopy > 0) + match direction: + case 'read': buffer.lower(self.rendezvous_buffer.lift(ncopy)) + case 'write': self.rendezvous_buffer.lower(buffer.lift(ncopy)) + if not self.rendezvous_buffer.remain(): + self.rendezvous_buffer = None + if self.rendezvous_future: + self.rendezvous_future.set_result(None) + self.rendezvous_future = None + else: + assert(not (self.rendezvous_buffer or self.rendezvous_future)) + self.rendezvous_buffer = buffer + self.rendezvous_future = asyncio.Future() + await on_block(self.rendezvous_future) + if self.rendezvous_buffer is buffer: + self.rendezvous_buffer = None + + async def cancel_copy(self, src, on_block): + await self.cancel_rendezvous('write', src, on_block) + async def cancel_read(self, dst, on_block): + await self.cancel_rendezvous('read', dst, on_block) + async def cancel_rendezvous(self, direction, buffer, on_block): + assert(not self.async_value.destroyed()) + if self.rendezvous_buffer is buffer: + self.rendezvous_buffer = None + if self.rendezvous_future: + self.rendezvous_future.set_result(None) + self.rendezvous_future = None + + def maybe_writer_handle_index(self, inst): + assert(not self.async_value.destroyed()) + if inst is self.cx.inst: + return self.cx.task.inst.waitables.array.index(self) + return None + + def destroy(self): + assert(not self.destroyed) + self.destroyed = True + self.rendezvous_buffer = None + if self.rendezvous_future: + self.rendezvous_future.set_result(None) + self.rendezvous_future = None + +class StreamHandle: pass +class ReadableStreamHandle(StreamHandle, ReadableAsyncValueHandle): pass +class WritableStreamHandle(StreamHandle, WritableAsyncValueHandle): pass + +class FutureHandle: + async def copy(self, buffer, on_block): + assert(buffer.remain() == 1) + await super().copy(buffer, on_block) + if buffer.remain() == 0 and not self.async_value.destroyed(): + self.async_value.destroy() + + async def cancel_copy(self, buffer, on_block): + await super().cancel_copy(buffer, on_block) + if buffer.remain() == 0 and not self.async_value.destroyed(): + self.async_value.destroy() + + def drop(self): + trap_if(not self.async_value.destroyed()) + super().drop() + +class ReadableFutureHandle(FutureHandle, ReadableAsyncValueHandle): pass +class WritableFutureHandle(FutureHandle, WritableAsyncValueHandle): pass + +### Type utilities + +def contains_async(t): + match t: + case StreamType() | FutureType(): + return True + case PrimValType() | OwnType() | BorrowType(): + return False + case FuncType(): + return any(contains_async(t) for t in t.param_types()) or \ + any(contains_async(t) for t in t.result_types()) + case ListType(t): + return contains_async(t) + case RecordType(fs): + return any(contains_async(f.t) for f in fs) + case TupleType(ts): + return any(contains_async(t) for t in ts) + case VariantType(cs): + return any(contains_async(c.t) for c in cs) + case OptionType(t): + return contains_async(t) + case ResultType(o,e): + return contains_async(o) or contains_async(e) + case _: + assert(False) def despecialize(t): match t: @@ -545,6 +749,7 @@ def alignment(t): case VariantType(cases) : return alignment_variant(cases) case FlagsType(labels) : return alignment_flags(labels) case OwnType() | BorrowType() : return 4 + case StreamType() | FutureType() : return 4 def alignment_list(elem_type, maybe_length): if maybe_length is not None: @@ -601,6 +806,7 @@ def elem_size(t): case VariantType(cases) : return elem_size_variant(cases) case FlagsType(labels) : return elem_size_flags(labels) case OwnType() | BorrowType() : return 4 + case StreamType() | FutureType() : return 4 def elem_size_list(elem_type, maybe_length): if maybe_length is not None: @@ -660,6 +866,8 @@ def load(cx, ptr, t): case FlagsType(labels) : return load_flags(cx, ptr, labels) case OwnType() : return lift_own(cx, load_int(cx, ptr, 4), t) case BorrowType() : return lift_borrow(cx, load_int(cx, ptr, 4), t) + case StreamType(t) : return lift_stream(cx, load_int(cx, ptr, 4), t) + case FutureType(t) : return lift_future(cx, load_int(cx, ptr, 4), t) def load_int(cx, ptr, nbytes, signed = False): return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed) @@ -813,6 +1021,30 @@ def lift_borrow(cx, i, t): cx.add_lender(h) return h.rep +def lift_stream(cx, i, t): + return lift_async_value(ReadableStreamHandle, WritableStreamHandle, cx, i, t) + +def lift_future(cx, i, t): + v = lift_async_value(ReadableFutureHandle, WritableFutureHandle, cx, i, t) + trap_if(v.destroyed()) + return v + +def lift_async_value(ReadableHandleT, WritableHandleT, cx, i, t): + h = cx.inst.waitables.get(i) + trap_if(not isinstance(h, ReadableHandleT|WritableHandleT)) + trap_if(h.t != t) + match h: + case ReadableHandleT(): + trap_if(h.copying_buffer) + h.cx.task.need_to_drop -= 1 + cx.inst.waitables.remove(i) + case WritableHandleT(): + trap_if(h.cx is not None) + assert(not h.copying_buffer) + h.cx = cx + h.cx.task.need_to_drop += 1 + return h.async_value + ### Storing def store(cx, v, t, ptr): @@ -838,6 +1070,8 @@ def store(cx, v, t, ptr): case FlagsType(labels) : store_flags(cx, v, ptr, labels) case OwnType() : store_int(cx, lower_own(cx, v, t), ptr, 4) case BorrowType() : store_int(cx, lower_borrow(cx, v, t), ptr, 4) + case StreamType(t) : store_int(cx, lower_stream(cx, v, t), ptr, 4) + case FutureType(t) : store_int(cx, lower_future(cx, v, t), ptr, 4) def store_int(cx, v, ptr, nbytes, signed = False): cx.opts.memory[ptr : ptr+nbytes] = int.to_bytes(v, nbytes, 'little', signed=signed) @@ -1100,6 +1334,27 @@ def lower_borrow(cx, rep, t): cx.need_to_drop += 1 return cx.inst.resources.add(t.rt, h) +def lower_stream(cx, v, t): + return lower_async_value(ReadableStreamHandle, WritableStreamHandle, cx, v, t) + +def lower_future(cx, v, t): + assert(not v.destroyed()) + return lower_async_value(ReadableFutureHandle, WritableFutureHandle, cx, v, t) + +def lower_async_value(ReadableHandleT, WritableHandleT, cx, v, t): + assert(isinstance(v, AsyncValue)) + if (i := v.maybe_writer_handle_index(cx.inst)): + h = cx.inst.waitables.array[i] + assert(isinstance(h, WritableHandleT)) + h.cx.task.need_to_drop -= 1 + h.cx = None + assert(2**31 > Table.MAX_LENGTH >= i) + return i | (2**31) + else: + h = ReadableHandleT(v, t, cx) + cx.task.need_to_drop += 1 + return cx.inst.waitables.add(h) + ### Flattening MAX_FLAT_PARAMS = 16 @@ -1150,6 +1405,7 @@ def flatten_type(t): case VariantType(cases) : return flatten_variant(cases) case FlagsType(labels) : return ['i32'] case OwnType() | BorrowType() : return ['i32'] + case StreamType() | FutureType() : return ['i32'] def flatten_list(elem_type, maybe_length): if maybe_length is not None: @@ -1216,6 +1472,8 @@ def lift_flat(cx, vi, t): case FlagsType(labels) : return lift_flat_flags(vi, labels) case OwnType() : return lift_own(cx, vi.next('i32'), t) case BorrowType() : return lift_borrow(cx, vi.next('i32'), t) + case StreamType(t) : return lift_stream(cx, vi.next('i32'), t) + case FutureType(t) : return lift_future(cx, vi.next('i32'), t) def lift_flat_unsigned(vi, core_width, t_width): i = vi.next('i' + str(core_width)) @@ -1307,6 +1565,8 @@ def lower_flat(cx, v, t): case FlagsType(labels) : return lower_flat_flags(v, labels) case OwnType() : return [lower_own(cx, v, t)] case BorrowType() : return [lower_borrow(cx, v, t)] + case StreamType(t) : return [lower_stream(cx, v, t)] + case FutureType(t) : return [lower_future(cx, v, t)] def lower_flat_signed(i, core_bits): if i < 0: @@ -1420,10 +1680,10 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_blo ctx = packed_ctx & ~1 if is_yield: await task.yield_() - event, payload = (EventCode.YIELDED, 0) + event, p1, p2 = (EventCode.YIELDED, 0, 0) else: - event, payload = await task.wait() - [packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, payload]) + event, p1, p2 = await task.wait() + [packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, p1, p2]) task.exit() async def call_and_trap_on_throw(callee, task, args): @@ -1438,6 +1698,7 @@ async def canon_lower(opts, ft, callee, task, flat_args): trap_if(not task.inst.may_leave) subtask = Subtask(opts, ft, task, flat_args) if opts.sync: + assert(not contains_async(ft)) await task.call_sync(callee, task, subtask.on_start, subtask.on_return) flat_results = subtask.finish() else: @@ -1448,7 +1709,7 @@ async def do_call(on_block): case Blocked(): subtask.notify_supertask = True task.need_to_drop += 1 - i = task.inst.async_subtasks.add(subtask) + i = task.inst.waitables.add(subtask) flat_results = [pack_async_result(i, subtask.state)] case Returned(): flat_results = [0] @@ -1520,8 +1781,9 @@ async def canon_task_return(task, core_ft, flat_args): async def canon_task_wait(task, ptr): trap_if(not task.inst.may_leave) trap_if(task.opts.callback is not None) - event, payload = await task.wait() - store(task, payload, U32Type(), ptr) + event, p1, p2 = await task.wait() + store(task, p1, U32Type(), ptr) + store(task, p2, U32Type(), ptr + 4) return [event] ### 🔀 `canon task.poll` @@ -1531,7 +1793,7 @@ async def canon_task_poll(task, ptr): ret = task.poll() if ret is None: return [0] - store(task, ret, TupleType([U32Type(), U32Type()]), ptr) + store(task, ret, TupleType([U32Type(), U32Type(), U32Type()]), ptr) return [1] ### 🔀 `canon task.yield` @@ -1542,9 +1804,110 @@ async def canon_task_yield(task): await task.yield_() return [] -### 🔀 `canon subtask.drop` +### 🔀 `canon {stream,future}.new` + +async def canon_stream_new(elem_type, task): + trap_if(not task.inst.may_leave) + h = WritableStreamHandle(elem_type) + return [ task.inst.waitables.add(h) ] + +async def canon_future_new(t, task): + trap_if(not task.inst.may_leave) + h = WritableFutureHandle(t) + return [ task.inst.waitables.add(h) ] + +### 🔀 `canon {stream,future}.{read,write}` + +async def canon_stream_read(task, i, ptr, n): + return await canon_copy(ReadableStreamHandle, WritableBuffer, + task, i, ptr, n, EventCode.STREAM_READ) + +async def canon_stream_write(task, i, ptr, n): + return await canon_copy(WritableStreamHandle, ReadableBuffer, + task, i, ptr, n, EventCode.STREAM_WRITE) + +async def canon_future_read(task, i, ptr): + return await canon_copy(ReadableFutureHandle, WritableBuffer, + task, i, ptr, 1, EventCode.FUTURE_READ) + +async def canon_future_write(task, i, ptr): + return await canon_copy(WritableFutureHandle, ReadableBuffer, + task, i, ptr, 1, EventCode.FUTURE_WRITE) + +async def canon_copy(HandleT, BufferT, task, i, ptr, n, event_code): + trap_if(not task.inst.may_leave) + h = task.inst.waitables.get(i) + trap_if(not isinstance(h, HandleT)) + trap_if(not h.cx) + trap_if(h.copying_buffer) + buffer = BufferT(h.cx, h.t, ptr, n) + if h.async_value.destroyed(): + trap_if(issubclass(HandleT, FutureHandle)) + flat_results = [CLOSED] + else: + async def do_copy(on_block): + await h.copy(buffer, on_block) + def stream_event(): + if h.copying_buffer is buffer: + h.copying_buffer = None + return (event_code, i, copy_result(HandleT, buffer, h)) + else: + return None + h.cx.task.notify(stream_event) + match await call_and_handle_blocking(do_copy): + case Blocked(): + h.copying_buffer = buffer + flat_results = [BLOCKED] + case Returned(): + flat_results = [copy_result(HandleT, buffer, h)] + return flat_results + +def copy_result(HandleT, buffer, h): + assert_if(issubclass(HandleT, FutureHandle), buffer.progress() == 1) + if buffer.progress(): + return buffer.progress() + assert(h.async_value.destroyed()) + return CLOSED + +BLOCKED = 0xffff_ffff +CLOSED = 0x8000_0000 +assert(Buffer.MAX_LENGTH < CLOSED < BLOCKED) + +### 🔀 `canon {stream,future}.cancel-{read,write}` + +async def canon_stream_cancel_read(sync, task, i): + return await canon_cancel_copy(ReadableStreamHandle, sync, task, i) + +async def canon_stream_cancel_write(sync, task, i): + return await canon_cancel_copy(WritableStreamHandle, sync, task, i) + +async def canon_future_cancel_read(sync, task, i): + return await canon_cancel_copy(ReadableFutureHandle, sync, task, i) + +async def canon_future_cancel_write(sync, task, i): + return await canon_cancel_copy(WritableFutureHandle, sync, task, i) + +async def canon_cancel_copy(HandleT, sync, task, i): + trap_if(not task.inst.may_leave) + h = task.inst.waitables.get(i) + trap_if(not isinstance(h, HandleT)) + trap_if(not h.copying_buffer) + if sync: + await task.call_sync(h.cancel_copy, h.copying_buffer) + flat_results = [h.copying_buffer.progress()] + h.copying_buffer = None + else: + match await call_and_handle_blocking(h.cancel_copy, h.copying_buffer): + case Blocked(): + flat_results = [BLOCKED] + case Returned(): + flat_results = [h.copying_buffer.progress()] + h.copying_buffer = None + return flat_results + +### 🔀 `canon waitable.drop` -async def canon_subtask_drop(task, i): +async def canon_waitable_drop(task, i): trap_if(not task.inst.may_leave) - task.inst.async_subtasks.remove(i).drop() + task.inst.waitables.remove(i).drop() return [] diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py index cc0fb688..46b8e59d 100644 --- a/design/mvp/canonical-abi/run_tests.py +++ b/design/mvp/canonical-abi/run_tests.py @@ -29,18 +29,18 @@ def realloc(self, original_ptr, original_size, alignment, new_size): ret = align_to(self.last_alloc, alignment) self.last_alloc = ret + new_size if self.last_alloc > len(self.memory): - print('oom: have {} need {}'.format(len(self.memory), self.last_alloc)) trap() self.memory[ret : ret + original_size] = self.memory[original_ptr : original_ptr + original_size] return ret -def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None): +def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None, sync_task_return = False, sync = True): opts = CanonicalOptions() opts.memory = memory opts.string_encoding = encoding opts.realloc = realloc opts.post_return = post_return - opts.sync = True + opts.sync_task_return = sync_task_return + opts.sync = sync opts.callback = None return opts @@ -59,6 +59,9 @@ def mk_tup_rec(x): return x return { str(i):mk_tup_rec(v) for i,v in enumerate(a) } +def unpack_lower_result(ret): + return (ret & ~(3 << 30), ret >> 30) + def fail(msg): raise BaseException(msg) @@ -361,56 +364,59 @@ def test_flatten(t, params, results): test_flatten(FuncType([U8Type() for _ in range(17)],[]), ['i32' for _ in range(17)], []) test_flatten(FuncType([U8Type() for _ in range(17)],[TupleType([U8Type(),U8Type()])]), ['i32' for _ in range(17)], ['i32','i32']) -def test_roundtrip(t, v): - before = definitions.MAX_FLAT_RESULTS - definitions.MAX_FLAT_RESULTS = 16 - ft = FuncType([t],[t]) - async def callee(task, x): - return x +async def test_roundtrips(): + async def test_roundtrip(t, v): + before = definitions.MAX_FLAT_RESULTS + definitions.MAX_FLAT_RESULTS = 16 - callee_heap = Heap(1000) - callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc) - callee_inst = ComponentInstance() - lifted_callee = partial(canon_lift, callee_opts, callee_inst, ft, callee) + ft = FuncType([t],[t]) + async def callee(task, x): + return x - caller_heap = Heap(1000) - caller_opts = mk_opts(caller_heap.memory, 'utf8', caller_heap.realloc) - caller_inst = ComponentInstance() - caller_task = Task(caller_opts, caller_inst, ft, None, None, None) + callee_heap = Heap(1000) + callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc) + callee_inst = ComponentInstance() + lifted_callee = partial(canon_lift, callee_opts, callee_inst, ft, callee) - flat_args = asyncio.run(caller_task.enter(lambda: [v])) + caller_heap = Heap(1000) + caller_opts = mk_opts(caller_heap.memory, 'utf8', caller_heap.realloc) + caller_inst = ComponentInstance() + caller_task = Task(caller_opts, caller_inst, ft, None, None, None) - return_in_heap = len(flatten_types([t])) > definitions.MAX_FLAT_RESULTS - if return_in_heap: - flat_args += [ caller_heap.realloc(0, 0, alignment(t), elem_size(t)) ] + flat_args = await caller_task.enter(lambda: [v]) - flat_results = asyncio.run(canon_lower(caller_opts, ft, lifted_callee, caller_task, flat_args)) + return_in_heap = len(flatten_types([t])) > definitions.MAX_FLAT_RESULTS + if return_in_heap: + flat_args += [ caller_heap.realloc(0, 0, alignment(t), elem_size(t)) ] - if return_in_heap: - flat_results = [ flat_args[-1] ] + flat_results = await canon_lower(caller_opts, ft, lifted_callee, caller_task, flat_args) - [got] = lift_flat_values(caller_task, definitions.MAX_FLAT_PARAMS, CoreValueIter(flat_results), [t]) - caller_task.exit() + if return_in_heap: + flat_results = [ flat_args[-1] ] - if got != v: - fail("test_roundtrip({},{}) got {}".format(t, v, got)) + [got] = lift_flat_values(caller_task, definitions.MAX_FLAT_PARAMS, CoreValueIter(flat_results), [t]) + caller_task.exit() - definitions.MAX_FLAT_RESULTS = before + if got != v: + fail("test_roundtrip({},{}) got {}".format(t, v, got)) -test_roundtrip(S8Type(), -1) -test_roundtrip(TupleType([U16Type(),U16Type()]), mk_tup(3,4)) -test_roundtrip(ListType(StringType()), [mk_str("hello there")]) -test_roundtrip(ListType(ListType(StringType())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]) -test_roundtrip(ListType(OptionType(TupleType([StringType(),U16Type()]))), [{'some':mk_tup(mk_str("answer"),42)}]) -test_roundtrip(VariantType([CaseType('x', TupleType([U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - StringType()]))]), - {'x': mk_tup(1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, mk_str("wat"))}) - -def test_handles(): + definitions.MAX_FLAT_RESULTS = before + + await test_roundtrip(S8Type(), -1) + await test_roundtrip(TupleType([U16Type(),U16Type()]), mk_tup(3,4)) + await test_roundtrip(ListType(StringType()), [mk_str("hello there")]) + await test_roundtrip(ListType(ListType(StringType())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]) + await test_roundtrip(ListType(OptionType(TupleType([StringType(),U16Type()]))), [{'some':mk_tup(mk_str("answer"),42)}]) + await test_roundtrip(VariantType([CaseType('x', TupleType([U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + StringType()]))]), + {'x': mk_tup(1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, mk_str("wat"))}) + + +async def test_handles(): before = definitions.MAX_FLAT_RESULTS definitions.MAX_FLAT_RESULTS = 16 @@ -502,7 +508,7 @@ def on_return(results): nonlocal got got = results - asyncio.run(canon_lift(opts, inst, ft, core_wasm, None, on_start, on_return, None)) + await canon_lift(opts, inst, ft, core_wasm, None, on_start, on_return, None) assert(len(got) == 3) assert(got[0] == 46) @@ -513,7 +519,6 @@ def on_return(results): assert(len(inst.resources.table(rt).free) == 4) definitions.MAX_FLAT_RESULTS = before -test_handles() async def test_async_to_async(): producer_heap = Heap(10) @@ -549,7 +554,7 @@ async def core_blocking_producer(task, args): [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [44]) await task.wait_on(fut3) return [] - blocking_callee = partial(canon_lift, producer_opts, producer_inst, blocking_ft, core_blocking_producer) + blocking_callee = partial(canon_lift, producer_opts, producer_inst, blocking_ft, core_blocking_producer) consumer_heap = Heap(10) consumer_opts = mk_opts(consumer_heap.memory) @@ -563,31 +568,32 @@ async def consumer(task, args): u8 = consumer_heap.memory[ptr] assert(u8 == 43) [ret] = await canon_lower(consumer_opts, toggle_ft, toggle_callee, task, []) - assert(ret == (1 | (CallState.STARTED << 30))) + subi,state = unpack_lower_result(ret) + assert(state == CallState.STARTED) retp = ptr consumer_heap.memory[retp] = 13 [ret] = await canon_lower(consumer_opts, blocking_ft, blocking_callee, task, [83, retp]) assert(ret == (2 | (CallState.STARTING << 30))) assert(consumer_heap.memory[retp] == 13) fut1.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) - [] = await canon_subtask_drop(task, callidx) - event, callidx = await task.wait() + [] = await canon_waitable_drop(task, callidx) + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_STARTED) assert(callidx == 2) assert(consumer_heap.memory[retp] == 13) fut2.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_RETURNED) assert(callidx == 2) assert(consumer_heap.memory[retp] == 44) fut3.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - [] = await canon_subtask_drop(task, callidx) + [] = await canon_waitable_drop(task, callidx) dtor_fut = asyncio.Future() dtor_value = None @@ -605,10 +611,10 @@ async def dtor(task, args): assert(ret == (2 | (CallState.STARTED << 30))) assert(dtor_value is None) dtor_fut.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == CallState.DONE) assert(callidx == 2) - [] = await canon_subtask_drop(task, callidx) + [] = await canon_waitable_drop(task, callidx) [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [42]) return [] @@ -628,7 +634,6 @@ def on_return(results): assert(len(got) == 1) assert(got[0] == 42) -asyncio.run(test_async_to_async()) async def test_async_callback(): producer_inst = ComponentInstance() @@ -662,22 +667,25 @@ async def consumer(task, args): return [42] async def callback(task, args): - assert(len(args) == 3) + assert(len(args) == 4) if args[0] == 42: assert(args[1] == EventCode.CALL_DONE) assert(args[2] == 1) - await canon_subtask_drop(task, 1) + assert(args[3] == 0) + await canon_waitable_drop(task, 1) return [53] elif args[0] == 52: assert(args[1] == EventCode.YIELDED) assert(args[2] == 0) + assert(args[3] == 0) fut2.set_result(None) return [62] else: assert(args[0] == 62) assert(args[1] == EventCode.CALL_DONE) assert(args[2] == 2) - await canon_subtask_drop(task, 2) + assert(args[3] == 0) + await canon_waitable_drop(task, 2) [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [83]) return [0] @@ -696,7 +704,6 @@ def on_return(results): await canon_lift(opts, consumer_inst, consumer_ft, consumer, None, on_start, on_return) assert(got[0] == 83) -asyncio.run(test_async_callback()) async def test_async_to_sync(): producer_opts = CanonicalOptions() @@ -740,19 +747,19 @@ async def consumer(task, args): fut.set_result(None) assert(producer1_done == False) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) - await canon_subtask_drop(task, callidx) + await canon_waitable_drop(task, callidx) assert(producer1_done == True) assert(producer2_done == False) await canon_task_yield(task) assert(producer2_done == True) - event, callidx = task.poll() + event, callidx, _ = task.poll() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - await canon_subtask_drop(task, callidx) + await canon_waitable_drop(task, callidx) assert(producer2_done == True) assert(task.poll() is None) @@ -771,7 +778,6 @@ def on_return(results): await canon_lift(consumer_opts, consumer_inst, consumer_ft, consumer, None, on_start, on_return) assert(got[0] == 83) -asyncio.run(test_async_to_sync()) async def test_async_backpressure(): producer_opts = CanonicalOptions() @@ -819,18 +825,18 @@ async def consumer(task, args): fut.set_result(None) assert(producer1_done == False) assert(producer2_done == False) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) assert(producer1_done == True) assert(producer2_done == True) - event, callidx = task.poll() + event, callidx, _ = task.poll() assert(event == EventCode.CALL_DONE) assert(callidx == 2) assert(producer2_done == True) - await canon_subtask_drop(task, 1) - await canon_subtask_drop(task, 2) + await canon_waitable_drop(task, 1) + await canon_waitable_drop(task, 2) assert(task.poll() is None) @@ -848,8 +854,6 @@ def on_return(results): await canon_lift(consumer_opts, consumer_inst, consumer_ft, consumer, None, on_start, on_return) assert(got[0] == 84) -if definitions.DETERMINISTIC_PROFILE: - asyncio.run(test_async_backpressure()) async def test_sync_using_wait(): hostcall_opts = mk_opts() @@ -878,16 +882,16 @@ async def core_func(task, args): assert(ret == (2 | (CallState.STARTED << 30))) fut1.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) fut2.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - await canon_subtask_drop(task, 1) - await canon_subtask_drop(task, 2) + await canon_waitable_drop(task, 1) + await canon_waitable_drop(task, 2) return [] @@ -896,6 +900,785 @@ def on_start(): return [] def on_return(results): pass await canon_lift(mk_opts(), inst, ft, core_func, None, on_start, on_return) -asyncio.run(test_sync_using_wait()) + +class HostSource(AsyncValue): + remaining: list[int] + destroy_if_empty: bool + chunk: int + waiting: Optional[asyncio.Future] + eager_cancel: asyncio.Event + + def __init__(self, contents, chunk, destroy_if_empty = True): + self.remaining = contents + self.destroy_if_empty = destroy_if_empty + self.chunk = chunk + self.waiting = None + self.eager_cancel = asyncio.Event() + self.eager_cancel.set() + + def destroyed(self): + return not self.remaining and self.destroy_if_empty + + def wake_waiting(self, cancelled = False): + if self.waiting: + self.waiting.set_result(cancelled) + self.waiting = None + + def destroy(self): + self.remaining = [] + self.destroy_if_empty = True + self.wake_waiting() + + def destroy_once_empty(self): + self.destroy_if_empty = True + if self.destroyed(): + self.wake_waiting() + + async def read(self, dst, on_block): + if not self.remaining: + if self.destroyed(): + return + self.waiting = asyncio.Future() + cancelled = await on_block(self.waiting) + if cancelled or self.destroyed(): + return + assert(self.remaining) + n = min(dst.remain(), len(self.remaining), self.chunk) + dst.lower(self.remaining[:n]) + del self.remaining[:n] + + async def cancel_read(self, dst, on_block): + await on_block(self.eager_cancel.wait()) + self.wake_waiting(True) + + def write(self, vs): + assert(vs and not self.destroyed()) + self.remaining += vs + self.wake_waiting() + + def maybe_writer_handle_index(self, inst): + return None + +class HostSink: + stream: AsyncValue + received: list[int] + chunk: int + write_remain: int + write_event: asyncio.Event + ready_to_consume: asyncio.Event + + def __init__(self, stream, chunk, remain = 2**64): + self.stream = stream + self.received = [] + self.chunk = chunk + self.write_remain = remain + self.write_event = asyncio.Event() + if remain: + self.write_event.set() + self.ready_to_consume = asyncio.Event() + async def read_all(): + while not self.stream.destroyed(): + async def on_block(f): + return await f + await self.write_event.wait() + await self.stream.read(self, on_block) + self.ready_to_consume.set() + asyncio.create_task(read_all()) + + def set_remain(self, n): + self.write_remain = n + if self.write_remain > 0: + self.write_event.set() + + def remain(self): + return self.write_remain + + def lower(self, vs): + self.received += vs + self.ready_to_consume.set() + self.write_remain -= len(vs) + if self.write_remain == 0: + self.write_event.clear() + + async def consume(self, n): + while n > len(self.received): + self.ready_to_consume.clear() + await self.ready_to_consume.wait() + if self.stream.destroyed(): + return None + ret = self.received[:n]; + del self.received[:n] + return ret + +async def test_eager_stream_completion(): + ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())]) + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + async def host_import(task, on_start, on_return, on_block): + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], AsyncValue)) + incoming = HostSink(args[0], chunk=4) + outgoing = HostSource([], chunk=4, destroy_if_empty=False) + on_return([outgoing]) + async def add10(): + while (vs := await incoming.consume(4)): + for i in range(len(vs)): + vs[i] += 10 + outgoing.write(vs) + outgoing.destroy() + asyncio.create_task(add10()) + + src_stream = HostSource([1,2,3,4,5,6,7,8], chunk=4) + def on_start(): + return [src_stream] + + dst_stream = None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = HostSink(results[0], chunk=4) + + async def core_func(task, args): + assert(len(args) == 1) + rsi1 = args[0] + assert(rsi1 == 1) + [wsi1] = await canon_stream_new(U8Type(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [wsi1]) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == 4) + assert(mem[0:4] == b'\x01\x02\x03\x04') + [wsi2] = await canon_stream_new(U8Type(), task) + retp = 12 + [ret] = await canon_lower(opts, ft, host_import, task, [wsi2, retp]) + assert(ret == 0) + rsi2 = mem[retp] + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == definitions.CLOSED) + assert(mem[0:4] == b'\x05\x06\x07\x08') + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, rsi1) + [] = await canon_waitable_drop(task, rsi2) + [] = await canon_waitable_drop(task, wsi1) + [] = await canon_waitable_drop(task, wsi2) + return [] + + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(dst_stream.received == [11,12,13,14,15,16,17,18]) + + +async def test_async_stream_ops(): + ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())]) + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + host_import_incoming = None + host_import_outgoing = None + async def host_import(task, on_start, on_return, on_block): + nonlocal host_import_incoming, host_import_outgoing + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], AsyncValue)) + host_import_incoming = HostSink(args[0], chunk=4, remain = 0) + host_import_outgoing = HostSource([], chunk=4, destroy_if_empty=False) + on_return([host_import_outgoing]) + while not host_import_incoming.stream.destroyed(): + vs = await on_block(host_import_incoming.consume(4)) + for i in range(len(vs)): + vs[i] += 10 + host_import_outgoing.write(vs) + host_import_outgoing.destroy_once_empty() + + src_stream = HostSource([], chunk=4, destroy_if_empty = False) + def on_start(): + return [src_stream] + + dst_stream = None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = HostSink(results[0], chunk=4, remain = 0) + + async def core_func(task, args): + [rsi1] = args + assert(rsi1 == 1) + [wsi1] = await canon_stream_new(U8Type(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [wsi1]) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == definitions.BLOCKED) + src_stream.write([1,2,3,4]) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi1) + assert(p2 == 4) + assert(mem[0:4] == b'\x01\x02\x03\x04') + [wsi2] = await canon_stream_new(U8Type(), task) + retp = 16 + [ret] = await canon_lower(opts, ft, host_import, task, [wsi2, retp]) + subi,state = unpack_lower_result(ret) + assert(state == CallState.RETURNED) + rsi2 = mem[16] + assert(rsi2 == 4) + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == definitions.BLOCKED) + host_import_incoming.set_remain(100) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi2) + assert(p2 == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == definitions.BLOCKED) + dst_stream.set_remain(100) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi1) + assert(p2 == 4) + src_stream.write([5,6,7,8]) + src_stream.destroy_once_empty() + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == definitions.CLOSED) + [] = await canon_waitable_drop(task, rsi1) + assert(mem[0:4] == b'\x05\x06\x07\x08') + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, wsi2) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == definitions.BLOCKED) + event, p1, p2 = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(p1 == subi) + assert(p2 == 0) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi2) + assert(p2 == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == definitions.CLOSED) + [] = await canon_waitable_drop(task, subi) + [] = await canon_waitable_drop(task, rsi2) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, wsi1) + return [] + + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(dst_stream.received == [11,12,13,14,15,16,17,18]) + + +async def test_stream_forward(): + src_stream = HostSource([1,2,3,4], chunk=4) + def on_start(): + return [src_stream] + + dst_stream = None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = results[0] + + async def core_func(task, args): + assert(len(args) == 1) + rsi1 = args[0] + assert(rsi1 == 1) + return [rsi1] + + opts = mk_opts() + inst = ComponentInstance() + ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())]) + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(src_stream is dst_stream) + + +async def test_receive_own_stream(): + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + host_ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())]) + async def host_import(task, on_start, on_return, on_block): + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], AsyncValue)) + on_return(args) + + async def core_func(task, args): + assert(len(args) == 0) + [wsi] = await canon_stream_new(U8Type(), task) + assert(wsi == 1) + retp = 4 + [ret] = await canon_lower(opts, host_ft, host_import, task, [wsi, retp]) + assert(ret == 0) + result = int.from_bytes(mem[retp : retp+4], 'little', signed=False) + assert(result == (wsi | 2**31)) + [] = await canon_waitable_drop(task, wsi) + return [] + + def on_start(): return [] + def on_return(results): assert(len(results) == 0) + ft = FuncType([],[]) + await canon_lift(mk_opts(), inst, ft, core_func, None, on_start, on_return) + + +async def test_host_partial_reads_writes(): + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + src = HostSource([1,2,3,4], chunk=2, destroy_if_empty = False) + source_ft = FuncType([], [StreamType(U8Type())]) + async def host_source(task, on_start, on_return, on_block): + [] = on_start() + on_return([src]) + + dst = None + sink_ft = FuncType([StreamType(U8Type())], []) + async def host_sink(task, on_start, on_return, on_block): + nonlocal dst + [s] = on_start() + dst = HostSink(s, chunk=1, remain=2) + on_return([]) + + async def core_func(task, args): + assert(len(args) == 0) + retp = 4 + [ret] = await canon_lower(opts, source_ft, host_source, task, [retp]) + assert(ret == 0) + rsi = mem[retp] + assert(rsi == 1) + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == 2) + assert(mem[0:2] == b'\x01\x02') + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == 2) + assert(mem[0:2] == b'\x03\x04') + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == definitions.BLOCKED) + src.write([5,6]) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == 2) + [] = await canon_waitable_drop(task, rsi) + + [wsi] = await canon_stream_new(U8Type(), task) + assert(wsi == 1) + [ret] = await canon_lower(opts, sink_ft, host_sink, task, [wsi]) + assert(ret == 0) + mem[0:6] = b'\x01\x02\x03\x04\x05\x06' + [ret] = await canon_stream_write(task, wsi, 0, 6) + assert(ret == 2) + [ret] = await canon_stream_write(task, wsi, 2, 6) + assert(ret == definitions.BLOCKED) + dst.set_remain(4) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi) + assert(p2 == 4) + assert(dst.received == [1,2,3,4,5,6]) + [] = await canon_waitable_drop(task, wsi) + return [] + + opts2 = mk_opts() + inst = ComponentInstance() + def on_start(): return [] + def on_return(results): assert(len(results) == 0) + ft = FuncType([],[]) + await canon_lift(opts2, inst, ft, core_func, None, on_start, on_return) + + +async def test_wasm_to_wasm_stream(): + fut1, fut2, fut3, fut4 = asyncio.Future(), asyncio.Future(), asyncio.Future(), asyncio.Future() + + inst1 = ComponentInstance() + mem1 = bytearray(10) + opts1 = mk_opts(memory=mem1, sync=False) + ft1 = FuncType([], [StreamType(U8Type())]) + async def core_func1(task, args): + assert(not args) + [wsi] = await canon_stream_new(U8Type(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'], []), [wsi]) + + await task.wait_on(fut1) + + mem1[0:4] = b'\x01\x02\x03\x04' + [ret] = await canon_stream_write(task, wsi, 0, 2) + assert(ret == 2) + [ret] = await canon_stream_write(task, wsi, 2, 2) + assert(ret == 2) + + await task.wait_on(fut2) + + mem1[0:8] = b'\x05\x06\x07\x08\x09\x0a\x0b\x0c' + [ret] = await canon_stream_write(task, wsi, 0, 8) + assert(ret == definitions.BLOCKED) + + fut3.set_result(None) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi) + assert(p2 == 4) + + fut4.set_result(None) + + [] = await canon_waitable_drop(task, wsi) + return [] + + func1 = partial(canon_lift, opts1, inst1, ft1, core_func1) + + inst2 = ComponentInstance() + mem2 = bytearray(10) + opts2 = mk_opts(memory=mem2, sync=False) + ft2 = FuncType([], []) + async def core_func2(task, args): + assert(not args) + [] = await canon_task_return(task, CoreFuncType([], []), []) + + retp = 0 + [ret] = await canon_lower(opts2, ft1, func1, task, [retp]) + subi,state = unpack_lower_result(ret) + assert(state== CallState.RETURNED) + rsi = mem2[0] + assert(rsi == 1) + + [ret] = await canon_stream_read(task, rsi, 0, 8) + assert(ret == definitions.BLOCKED) + + fut1.set_result(None) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == 4) + assert(mem2[0:8] == b'\x01\x02\x03\x04\x00\x00\x00\x00') + + fut2.set_result(None) + await task.wait_on(fut3) + + mem2[0:8] = bytes(8) + [ret] = await canon_stream_read(task, rsi, 0, 2) + assert(ret == 2) + assert(mem2[0:6] == b'\x05\x06\x00\x00\x00\x00') + [ret] = await canon_stream_read(task, rsi, 2, 2) + assert(ret == 2) + assert(mem2[0:6] == b'\x05\x06\x07\x08\x00\x00') + + await task.wait_on(fut4) + + [ret] = await canon_stream_read(task, rsi, 0, 2) + assert(ret == definitions.CLOSED) + [] = await canon_waitable_drop(task, rsi) + + event, callidx, _ = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(callidx == subi) + [] = await canon_waitable_drop(task, subi) + return [] + + await canon_lift(opts2, inst2, ft2, core_func2, None, lambda:[], lambda _:()) + + +async def test_borrow_stream(): + rt_inst = ComponentInstance() + rt = ResourceType(rt_inst, None) + + inst1 = ComponentInstance() + mem1 = bytearray(12) + opts1 = mk_opts(memory=mem1) + ft1 = FuncType([StreamType(BorrowType(rt))], []) + async def core_func1(task, args): + [rsi] = args + + [ret] = await canon_stream_read(task, rsi, 4, 2) + assert(ret == definitions.BLOCKED) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == 2) + [ret] = await canon_stream_read(task, rsi, 0, 2) + assert(ret == definitions.CLOSED) + + [] = await canon_waitable_drop(task, rsi) + + h1 = mem1[4] + h2 = mem1[8] + assert(await canon_resource_rep(rt, task, h1) == [42]) + assert(await canon_resource_rep(rt, task, h2) == [43]) + [] = await canon_resource_drop(rt, True, task, h1) + [] = await canon_resource_drop(rt, True, task, h2) + + return [] + + func1 = partial(canon_lift, opts1, inst1, ft1, core_func1) + + inst2 = ComponentInstance() + mem2 = bytearray(10) + sync_opts2 = mk_opts(memory=mem2, sync=True) + async_opts2 = mk_opts(memory=mem2, sync=False) + ft2 = FuncType([], []) + async def core_func2(task, args): + assert(not args) + + [wsi] = await canon_stream_new(BorrowType(rt), task) + [ret] = await canon_lower(async_opts2, ft1, func1, task, [wsi]) + subi,state = unpack_lower_result(ret) + assert(state == CallState.STARTED) + + [h1] = await canon_resource_new(rt, task, 42) + [h2] = await canon_resource_new(rt, task, 43) + mem2[0] = h1 + mem2[4] = h2 + + [ret] = await canon_stream_write(task, wsi, 0, 2) + assert(ret == 2) + [] = await canon_waitable_drop(task, wsi) + + event, p1, _ = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(p1 == subi) + + [] = await canon_waitable_drop(task, subi) + return [] + + await canon_lift(sync_opts2, inst2, ft2, core_func2, None, lambda:[], lambda _:()) + + +async def test_cancel_copy(): + inst = ComponentInstance() + mem = bytearray(10) + lower_opts = mk_opts(memory=mem, sync=False) + + host_ft1 = FuncType([StreamType(U8Type())],[]) + host_sink = None + async def host_func1(task, on_start, on_return, on_block): + nonlocal host_sink + [stream] = on_start() + host_sink = HostSink(stream, 2, remain = 0) + on_return([]) + + host_ft2 = FuncType([], [StreamType(U8Type())]) + host_source = None + async def host_func2(task, on_start, on_return, on_block): + nonlocal host_source + [] = on_start() + host_source = HostSource([], chunk=2, destroy_if_empty = False) + on_return([host_source]) + + async def core_func(task, args): + assert(not args) + + [wsi] = await canon_stream_new(U8Type(), task) + [ret] = await canon_lower(lower_opts, host_ft1, host_func1, task, [wsi]) + assert(ret == 0) + mem[0:4] = b'\x0a\x0b\x0c\x0d' + [ret] = await canon_stream_write(task, wsi, 0, 4) + assert(ret == definitions.BLOCKED) + host_sink.set_remain(2) + got = await host_sink.consume(2) + assert(got == [0xa, 0xb]) + [ret] = await canon_stream_cancel_write(True, task, wsi) + assert(ret == 2) + [] = await canon_waitable_drop(task, wsi) + + [wsi] = await canon_stream_new(U8Type(), task) + [ret] = await canon_lower(lower_opts, host_ft1, host_func1, task, [wsi]) + assert(ret == 0) + mem[0:4] = b'\x01\x02\x03\x04' + [ret] = await canon_stream_write(task, wsi, 0, 4) + assert(ret == definitions.BLOCKED) + host_sink.set_remain(2) + got = await host_sink.consume(2) + assert(got == [1, 2]) + [ret] = await canon_stream_cancel_write(False, task, wsi) + assert(ret == 2) + [] = await canon_waitable_drop(task, wsi) + + retp = 0 + [ret] = await canon_lower(lower_opts, host_ft2, host_func2, task, [retp]) + assert(ret == 0) + rsi = mem[retp] + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == definitions.BLOCKED) + [ret] = await canon_stream_cancel_read(True, task, rsi) + assert(ret == 0) + [] = await canon_waitable_drop(task, rsi) + + retp = 0 + [ret] = await canon_lower(lower_opts, host_ft2, host_func2, task, [retp]) + assert(ret == 0) + rsi = mem[retp] + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == definitions.BLOCKED) + host_source.eager_cancel.clear() + [ret] = await canon_stream_cancel_read(False, task, rsi) + assert(ret == definitions.BLOCKED) + host_source.write([7,8]) + await asyncio.sleep(0) + host_source.eager_cancel.set() + event,p1,p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == 2) + assert(mem[0:2] == b'\x07\x08') + [] = await canon_waitable_drop(task, rsi) + + return [] + + lift_opts = mk_opts() + await canon_lift(lift_opts, inst, FuncType([],[]), core_func, None, lambda:[], lambda _:()) + + +class HostFutureSink: + v: Optional[any] = None + + def remain(self): + return 1 if self.v is None else 0 + + def lower(self, v): + assert(not self.v) + assert(len(v) == 1) + self.v = v[0] + +class HostFutureSource(AsyncValue): + v: Optional[asyncio.Future] + def __init__(self): + self.v = asyncio.Future() + def destroyed(self): + return self.v is None + def destroy(self): + assert(self.v is None) + async def read(self, dst, on_block): + assert(self.v is not None) + v = await on_block(self.v) + if v: + dst.lower([v]) + self.v = None + async def cancel_read(self, dst, on_block): + if self.v and not self.v.done(): + self.v.set_result(None) + self.v = asyncio.Future() + def maybe_writer_handle_index(self, inst): + return None + +async def test_futures(): + inst = ComponentInstance() + mem = bytearray(10) + lower_opts = mk_opts(memory=mem, sync=False) + + host_ft1 = FuncType([FutureType(U8Type())],[FutureType(U8Type())]) + async def host_func(task, on_start, on_return, on_block): + [future] = on_start() + outgoing = HostFutureSource() + on_return([outgoing]) + incoming = HostFutureSink() + await future.read(incoming, on_block) + assert(incoming.v == 42) + outgoing.v.set_result(43) + + async def core_func(task, args): + assert(not args) + [wfi] = await canon_future_new(U8Type(), task) + retp = 0 + [ret] = await canon_lower(lower_opts, host_ft1, host_func, task, [wfi, retp]) + subi,state = unpack_lower_result(ret) + assert(state == CallState.RETURNED) + rfi = mem[retp] + + readp = 0 + [ret] = await canon_future_read(task, rfi, readp) + assert(ret == definitions.BLOCKED) + + writep = 8 + mem[writep] = 42 + [ret] = await canon_future_write(task, wfi, writep) + assert(ret == 1) + + event,p1,p2 = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(p1 == subi) + + event,p1,p2 = await task.wait() + assert(event == EventCode.FUTURE_READ) + assert(p1 == rfi) + assert(p2 == 1) + assert(mem[readp] == 43) + + [] = await canon_waitable_drop(task, subi) + [] = await canon_waitable_drop(task, wfi) + [] = await canon_waitable_drop(task, rfi) + + [wfi] = await canon_future_new(U8Type(), task) + retp = 0 + [ret] = await canon_lower(lower_opts, host_ft1, host_func, task, [wfi, retp]) + subi,state = unpack_lower_result(ret) + assert(state == CallState.RETURNED) + rfi = mem[retp] + + readp = 0 + [ret] = await canon_future_read(task, rfi, readp) + assert(ret == definitions.BLOCKED) + + writep = 8 + mem[writep] = 42 + [ret] = await canon_future_write(task, wfi, writep) + assert(ret == 1) + + event,p1,p2 = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(p1 == subi) + + [ret] = await canon_future_cancel_read(True, task, rfi) + assert(ret == 1) + assert(mem[readp] == 43) + + [] = await canon_waitable_drop(task, subi) + [] = await canon_waitable_drop(task, wfi) + [] = await canon_waitable_drop(task, rfi) + + return [] + + lift_opts = mk_opts() + await canon_lift(lift_opts, inst, FuncType([],[]), core_func, None, lambda:[], lambda _:()) + + +async def run_async_tests(): + await test_roundtrips() + await test_handles() + await test_async_to_async() + await test_async_callback() + await test_async_to_sync() + await test_async_backpressure() + await test_sync_using_wait() + await test_eager_stream_completion() + await test_stream_forward() + await test_receive_own_stream() + await test_host_partial_reads_writes() + await test_async_stream_ops() + await test_wasm_to_wasm_stream() + await test_borrow_stream() + await test_cancel_copy() + await test_futures() + +asyncio.run(run_async_tests()) print("All tests passed")