Skip to content

Commit

Permalink
Add 'future' type
Browse files Browse the repository at this point in the history
  • Loading branch information
lukewagner committed Oct 6, 2024
1 parent 9a698e3 commit 2b648b3
Showing 1 changed file with 75 additions and 6 deletions.
81 changes: 75 additions & 6 deletions design/mvp/canonical-abi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ class BorrowType(ValType):
class StreamType(ValType):
t: ValType

@dataclass
class FutureType(ValType):
t: ValType

### CallContext

class CallContext:
Expand Down Expand Up @@ -202,7 +206,7 @@ class CanonicalOptions:

class ComponentInstance:
resources: ResourceTables
waitables: Table[Subtask|StreamHandle]
waitables: Table[Subtask|StreamHandle|FutureHandle]
num_tasks: int
may_leave: bool
backpressure: bool
Expand All @@ -211,7 +215,7 @@ class ComponentInstance:

def __init__(self):
self.resources = ResourceTables()
self.waitables = Table[Subtask|StreamHandle]()
self.waitables = Table[Subtask|StreamHandle|FutureHandle]()
self.num_tasks = 0
self.may_leave = True
self.backpressure = False
Expand Down Expand Up @@ -308,6 +312,8 @@ class EventCode(IntEnum):
YIELDED = 4
STREAM_READ = 5
STREAM_WRITE = 6
FUTURE_READ = 7
FUTURE_WRITE = 8

EventTuple = tuple[EventCode, int, int]
EventCallback = Callable[[], Optional[EventTuple]]
Expand Down Expand Up @@ -669,11 +675,23 @@ def maybe_writer_handle_index(self, inst):
return self.cx.task.inst.waitables.array.index(self)
return None

LowerCallback = Callable[[any], None]

class Future:
get: Callable[OnBlockCallback], Awaitable[Optional[any]]
cancel: Callable[[OnBlockCallback], Awaitable]
maybe_writer_handle_index: Callable[[ComponentInstance], Optional[int]]

def __init__(self, impl):
self.read = impl.read
self.cancel = impl.cancel
self.maybe_writer_handle_index = impl.maybe_writer_handle_index

### Type utilities

def contains_async(t):
match t:
case StreamType():
case StreamType() | FutureType():
return True
case PrimValType() | OwnType() | BorrowType():
return False
Expand Down Expand Up @@ -721,7 +739,7 @@ def alignment(t):
case VariantType(cases) : return alignment_variant(cases)
case FlagsType(labels) : return alignment_flags(labels)
case OwnType() | BorrowType() : return 4
case StreamType() : return 4
case StreamType() | FutureType() : return 4

def alignment_list(elem_type, maybe_length):
if maybe_length is not None:
Expand Down Expand Up @@ -778,7 +796,7 @@ def elem_size(t):
case VariantType(cases) : return elem_size_variant(cases)
case FlagsType(labels) : return elem_size_flags(labels)
case OwnType() | BorrowType() : return 4
case StreamType() : return 4
case StreamType() | FutureType() : return 4

def elem_size_list(elem_type, maybe_length):
if maybe_length is not None:
Expand Down Expand Up @@ -838,6 +856,7 @@ def load(cx, ptr, t):
case FlagsType(labels) : return load_flags(cx, ptr, labels)
case OwnType() : return lift_own(cx, load_int(cx, ptr, 4), t)
case BorrowType() : return lift_borrow(cx, load_int(cx, ptr, 4), t)
case FutureType(t) : return lift_future(cx, load_int(cx, ptr, 4), t)
case StreamType(t) : return lift_stream(cx, load_int(cx, ptr, 4), t)

def load_int(cx, ptr, nbytes, signed = False):
Expand Down Expand Up @@ -992,6 +1011,9 @@ def lift_borrow(cx, i, t):
cx.add_lender(h)
return h.rep

def lift_future(cx, i t):
TODO

def lift_stream(cx, i, elem_type):
h = cx.inst.waitables.get(i)
trap_if(not isinstance(h, StreamHandle))
Expand Down Expand Up @@ -1033,6 +1055,7 @@ def store(cx, v, t, ptr):
case FlagsType(labels) : store_flags(cx, v, ptr, labels)
case OwnType() : store_int(cx, lower_own(cx, v, t), ptr, 4)
case BorrowType() : store_int(cx, lower_borrow(cx, v, t), ptr, 4)
case FutureType(t) : store_int(cx, lower_future(cx, v, t), ptr, 4)
case StreamType(t) : store_int(cx, lower_stream(cx, v, t), ptr, 4)

def store_int(cx, v, ptr, nbytes, signed = False):
Expand Down Expand Up @@ -1296,6 +1319,9 @@ def lower_borrow(cx, rep, t):
cx.need_to_drop += 1
return cx.inst.resources.add(t.rt, h)

def lower_future(cx, future, t):
TODO

def lower_stream(cx, stream, elem_type):
assert(isinstance(stream, Stream))
if (i := stream.maybe_writer_handle_index(cx.inst)):
Expand Down Expand Up @@ -1360,7 +1386,7 @@ def flatten_type(t):
case VariantType(cases) : return flatten_variant(cases)
case FlagsType(labels) : return ['i32']
case OwnType() | BorrowType() : return ['i32']
case StreamType() : return ['i32']
case StreamType() | FutureType() : return ['i32']

def flatten_list(elem_type, maybe_length):
if maybe_length is not None:
Expand Down Expand Up @@ -1427,6 +1453,7 @@ def lift_flat(cx, vi, t):
case FlagsType(labels) : return lift_flat_flags(vi, labels)
case OwnType() : return lift_own(cx, vi.next('i32'), t)
case BorrowType() : return lift_borrow(cx, vi.next('i32'), t)
case FutureType(t) : return lift_future(cx, vi.next('i32'), t)
case StreamType(t) : return lift_stream(cx, vi.next('i32'), t)

def lift_flat_unsigned(vi, core_width, t_width):
Expand Down Expand Up @@ -1519,6 +1546,7 @@ def lower_flat(cx, v, t):
case FlagsType(labels) : return lower_flat_flags(v, labels)
case OwnType() : return [lower_own(cx, v, t)]
case BorrowType() : return [lower_borrow(cx, v, t)]
case FutureType(t) : return [lower_future(cx, v, t)]
case StreamType(t) : return [lower_stream(cx, v, t)]

def lower_flat_signed(i, core_bits):
Expand Down Expand Up @@ -1837,6 +1865,47 @@ async def stream_cancel_copy(StreamHandleT, sync, task, i):
h.copying_buffer = None
return flat_results

### 🔀 `canon future.new`

async def canon_future_new(t, task):
trap_if(not task.inst.may_leave)
h = WritableFutureHandle(t)
return [ task.inst.waitables.add(h) ]

### 🔀 `canon future.read` and `canon future.write`

async def canon_future_read(task, i, ptr):
return await future_copy(ReadableFutureHandle, WritableBuffer,
task, i, ptr, EventCode.FUTURE_READ)

async def canon_future_write(task, i, ptr):
return await future_copy(WritableFutureHandle, ReadableBuffer,
task, i, ptr, EventCode.FUTURE_WRITE)

async def future_copy(FutureHandleT, BufferT, task, i, ptr, event_code):
trap_if(not task.inst.may_leave)
h = task.inst.waitables.get(i)
trap_if(not isinstance(h, FutureHandleT))
trap_if(not h.cx)
trap_if(h.future.done())
buffer = BufferT(h.cx, h.t, ptr, length = 1)
async def do_copy(on_block):
await h.copy(buffer, on_block)
if h.copying:
def future_event():
return (event_code, i, pack_future_result(buffer))
h.cx.task.notify(stream_event)
match await call_and_handle_blocking(do_copy):
case Blocked():
h.copying = buffer
flat_results = [BLOCKED]
case Returned():
flat_results = [pack_future_result(buffer)]
return flat_results

def pack_future_result(buffer):
TODO

### 🔀 `canon waitable.drop`

async def canon_waitable_drop(task, i):
Expand Down

0 comments on commit 2b648b3

Please sign in to comment.