Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream cancel fixes #23

Merged
merged 5 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ tests/functional/tserver
tests/functional/tconcurrent
tests/functional/tconcurrentdata
tests/functional/tflowcontrol
tests/functional/tcancel
src/hyperx/client
src/hyperx/server
src/hyperx/clientserver
src/hyperx/queue
src/hyperx/lock
src/hyperx/signal
src/hyperx/value
src/hyperx/untestable
src/hyperx/frame
src/hyperx/stream
Expand Down
2 changes: 2 additions & 0 deletions hyperx.nimble
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ task test, "Test":
exec "nim c -r src/hyperx/utils.nim"
exec "nim c -r src/hyperx/queue.nim"
exec "nim c -r src/hyperx/signal.nim"
exec "nim c -r src/hyperx/value.nim"
exec "nim c -r src/hyperx/stream.nim"
exec "nim c -r src/hyperx/frame.nim"
exec "nim c -r -f -d:hyperxTest -d:ssl src/hyperx/testutils.nim"
Expand Down Expand Up @@ -59,6 +60,7 @@ task functest, "Func test":
exec "nim c -r -d:release tests/functional/tconcurrent.nim"
exec "nim c -r -d:release tests/functional/tconcurrentdata.nim"
exec "nim c -r -d:release tests/functional/tflowcontrol.nim"
exec "nim c -r -d:release tests/functional/tcancel.nim"

task h2spec, "h2spec test":
exec "./h2spec --tls --port 8783 --strict"
Expand Down
21 changes: 14 additions & 7 deletions src/hyperx/clientserver.nim
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import pkg/hpack
import ./frame
import ./stream
import ./queue
import ./value
import ./signal
import ./errors
import ./utils
Expand Down Expand Up @@ -668,6 +669,8 @@ proc recvDispatcherNaked(client: ClientContext) {.async.} =
# Process headers even if the stream
# does not exist
if frm.sid.StreamId notin client.streams:
if frm.typ == frmtData:
client.windowPending -= frm.payloadLen.int
check frm.typ in {frmtRstStream, frmtWindowUpdate},
newConnError errStreamClosed
debugInfo "stream not found " & $frm.sid.int
Expand Down Expand Up @@ -892,14 +895,16 @@ proc write(strm: ClientStream, frm: Frame): Future[void] =
proc read(stream: Stream): Future[Frame] {.async.} =
var frm: Frame
while true:
frm = await stream.msgs.pop()
#frm = await stream.msgs.pop()
frm = await stream.msgs.get()
#stream.msgs.getDone()
doAssert stream.id == frm.sid.StreamId
doAssert frm.typ in frmStreamAllowed
# this can raise stream/conn error
stream.doTransitionRecv frm
if frm.typ == frmtRstStream:
for frm2 in stream.msgs:
stream.doTransitionRecv frm2
#for frm2 in stream.msgs:
# stream.doTransitionRecv frm2
stream.error = newStrmError(frm.errorCode, hxRemoteErr)
stream.close()
raise newStrmError(frm.errorCode, hxRemoteErr)
Expand Down Expand Up @@ -1058,10 +1063,10 @@ proc recvBodyNaked(strm: ClientStream, data: ref string) {.async.} =
let bodyL = strm.bodyRecv.len
data[].add strm.bodyRecv
strm.bodyRecv.setLen 0
if not client.isConnected:
# this avoids raising when sending a window update
# if the conn is closed. Unsure if it's useful
return
#if not client.isConnected:
# # this avoids raising when sending a window update
# # if the conn is closed. Unsure if it's useful
# return
client.windowProcessed += bodyL
stream.windowProcessed += bodyL
doAssert stream.windowPending >= stream.windowProcessed
Expand Down Expand Up @@ -1225,6 +1230,8 @@ proc cancel*(strm: ClientStream, code: ErrorCode) {.async.} =
await failSilently strm.writeRst(code)
await failSilently strm.ping()
finally:
if strm.stream.error == nil:
strm.stream.error = newStrmError(errStreamClosed)
strm.close()

when defined(hyperxTest):
Expand Down
7 changes: 3 additions & 4 deletions src/hyperx/signal.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ proc newSignal*(): SignalAsync {.raises: [].} =
isClosed: false
)

proc waitFor*(sig: SignalAsync) {.async.} =
proc waitFor*(sig: SignalAsync): Future[void] {.raises: [SignalClosedError].} =
if sig.isClosed:
raise newSignalClosedError()
let fut = newFuture[void]()
sig.waiters.addFirst fut
await fut
result = newFuture[void]()
sig.waiters.addFirst result

proc wakeupSoon(f: Future[void]) =
proc wakeup =
Expand Down
48 changes: 37 additions & 11 deletions src/hyperx/stream.nim
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import std/tables

import ./frame
import ./queue
#import ./queue
import ./value
import ./signal
import ./errors

Expand All @@ -15,6 +16,7 @@ type
strmReservedRemote
strmHalfClosedLocal
strmHalfClosedRemote
strmClosedRst
strmInvalid
StreamEvent* = enum
seHeaders
Expand Down Expand Up @@ -130,12 +132,12 @@ func toNextStateRecv*(s: StreamState, e: StreamEvent): StreamState {.raises: [].
of seHeadersEndStream, seRstStream: strmClosed
of sePriority: strmReservedRemote
else: strmInvalid
of strmHalfClosedLocal:
of strmHalfClosedLocal, strmClosedRst:
case e
of seHeadersEndStream,
seDataEndStream,
seRstStream: strmClosed
else: strmHalfClosedLocal
else: s
of strmHalfClosedRemote, strmReservedLocal:
case e
of seRstStream: strmClosed
Expand All @@ -160,11 +162,11 @@ func toNextStateSend*(s: StreamState, e: StreamEvent): StreamState {.raises: [].
case e
of seHeadersEndStream,
seDataEndStream: strmHalfClosedLocal
of seRstStream: strmClosed
of seRstStream: strmClosedRst
else: strmOpen
of strmClosed:
of strmClosed, strmClosedRst:
case e
of sePriority: strmClosed
of sePriority: s
else: strmInvalid
of strmReservedLocal:
case e
Expand All @@ -178,7 +180,12 @@ func toNextStateSend*(s: StreamState, e: StreamEvent): StreamState {.raises: [].
seDataEndStream,
seRstStream: strmClosed
else: strmHalfClosedRemote
of strmHalfClosedLocal, strmReservedRemote:
of strmHalfClosedLocal:
case e
of seRstStream: strmClosedRst
of seWindowUpdate, sePriority: s
else: strmInvalid
of strmReservedRemote:
case e
of seRstStream: strmClosed
of seWindowUpdate, sePriority: s
Expand All @@ -198,7 +205,8 @@ type
Stream* = ref object
id*: StreamId
state*: StreamState
msgs*: QueueAsync[Frame]
#msgs*: QueueAsync[Frame]
msgs*: ValueAsync[Frame]
peerWindow*: int32
peerWindowUpdateSig*: SignalAsync
windowPending*: int
Expand All @@ -211,7 +219,8 @@ proc newStream(id: StreamId, peerWindow: int32): Stream {.raises: [].} =
Stream(
id: id,
state: strmIdle,
msgs: newQueue[Frame](1),
#msgs: newQueue[Frame](1),
msgs: newValueAsync[Frame](),
peerWindow: peerWindow,
peerWindowUpdateSig: newSignal(),
windowPending: 0,
Expand Down Expand Up @@ -259,7 +268,7 @@ func open*(
): Stream {.raises: [StreamsClosedError].} =
doAssert sid notin s.t, $sid.int
if s.isClosed:
raise newException(StreamsClosedError, "Streams is closed")
raise newException(StreamsClosedError, "Cannot open stream")
result = newStream(sid, peerWindow)
s.t[sid] = result

Expand Down Expand Up @@ -317,7 +326,8 @@ when isMainModule:
strmReservedLocal,
strmReservedRemote,
strmHalfClosedLocal,
strmHalfClosedRemote
strmHalfClosedRemote,
strmClosedRst
#strmInvalid
}
block:
Expand Down Expand Up @@ -419,5 +429,21 @@ when isMainModule:
let isValid = toNextStateRecv(state, seData) != strmInvalid
let isValid2 = toNextStateRecv(state, seDataEndStream) != strmInvalid
doAssert isValid == isValid2, $state
block:
for ev in allEvents-{seUnknown,sePriority}:
doAssert toNextStateSend(strmClosedRst, ev) == strmInvalid
doAssert toNextStateSend(strmClosedRst, sePriority) == strmClosedRst
block:
for state in {strmOpen,strmHalfClosedLocal}:
doAssert toNextStateSend(state, seRstStream) == strmClosedRst
for state in allStates-{strmOpen,strmHalfClosedLocal}:
doAssert toNextStateSend(state, seRstStream) in {strmInvalid, strmClosed}
block:
for ev in allEvents-{seUnknown,seRstStream,seHeadersEndStream,seDataEndStream}:
doAssert toNextStateRecv(strmClosedRst, ev) == strmClosedRst
doAssert toNextStateRecv(strmHalfClosedLocal, ev) == strmHalfClosedLocal
for ev in {seRstStream,seHeadersEndStream,seDataEndStream}:
doAssert toNextStateRecv(strmClosedRst, ev) == strmClosed
doAssert toNextStateRecv(strmHalfClosedLocal, ev) == strmClosed

echo "ok"
85 changes: 85 additions & 0 deletions src/hyperx/value.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import std/asyncdispatch

import ./signal
import ./errors

type
ValueAsyncClosedError* = QueueClosedError

func newValueAsyncClosedError(): ref ValueAsyncClosedError {.raises: [].} =
result = (ref ValueAsyncClosedError)(msg: "ValueAsync is closed")

type
ValueAsync*[T] = ref object
sigPut, sigGet: SignalAsync
val: T
isClosed: bool

func newValueAsync*[T](): ValueAsync[T] {.raises: [].} =
ValueAsync[T](
sigPut: newSignal(),
sigGet: newSignal(),
val: nil,
isClosed: false
)

proc put*[T](vala: ValueAsync[T], val: T) {.async.} =
if vala.isClosed:
raise newValueAsyncClosedError()
try:
while vala.val != nil:
await vala.sigPut.waitFor()
vala.val = val
vala.sigGet.trigger()
while vala.val != nil:
await vala.sigPut.waitFor()
except SignalClosedError:
raise newValueAsyncClosedError()

proc get*[T](vala: ValueAsync[T]): Future[T] {.async.} =
if vala.isClosed:
raise newValueAsyncClosedError()
try:
while vala.val == nil:
await vala.sigGet.waitFor()
result = vala.val
vala.val = nil
vala.sigPut.trigger()
except SignalClosedError:
raise newValueAsyncClosedError()

#proc getDone*[T](vala: ValueAsync[T]) =
# vala.val = nil
# try:
# vala.sigPut.trigger()
# except SignalClosedError:
# raise newValueAsyncClosedError()

proc close*[T](vala: ValueAsync[T]) {.raises: [].} =
if vala.isClosed:
return
vala.isClosed = true
vala.sigPut.close()
vala.sigGet.close()

when isMainModule:
func newIntRef(n: int): ref int =
new result
result[] = n
block:
proc test() {.async.} =
var q = newValueAsync[ref int]()
proc puts {.async.} =
await q.put newIntRef(1)
await q.put newIntRef(2)
await q.put newIntRef(3)
await q.put newIntRef(4)
let puts1 = puts()
doAssert (await q.get())[] == 1
doAssert (await q.get())[] == 2
doAssert (await q.get())[] == 3
doAssert (await q.get())[] == 4
await puts1
waitFor test()
doAssert not hasPendingOperations()
echo "ok"
Loading
Loading