Skip to content

Commit

Permalink
Improve frame creation (#44)
Browse files Browse the repository at this point in the history
this shows no improvement btw, maybe on yasync it does

This revealed a race condition in the stream recv transition. Raises
cause race conditions as they are Future.fail() calls that set callbacks
to continue later. In the meantime an END_STREAM may have been sent.
So the RST won't be sent.

But I think this does not matter. The only time RST won't be sent is if
the stream got closed, which gives the same result: the stream is
closed. However, this makes the h2spec fail.
  • Loading branch information
nitely authored Feb 3, 2025
1 parent 351092d commit d9cde4c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 76 deletions.
137 changes: 61 additions & 76 deletions src/hyperx/clientserver.nim
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ proc newClient*(
headersDec: initDynHeaders(stgHeaderTableSize.int),
streams: initStreams(),
currStreamId: 0.StreamId,
recvMsgs: newQueue[Frame](10),
streamOpenedMsgs: newQueue[Stream](10),
peerMaxConcurrentStreams: stgInitialMaxConcurrentStreams,
peerWindow: stgInitialWindowSize.int32,
Expand All @@ -179,7 +178,6 @@ proc close*(client: ClientContext) {.raises: [HyperxConnError].} =
try:
catch client.sock.close()
finally:
client.recvMsgs.close()
client.streamOpenedMsgs.close()
client.streams.close()
client.peerWindowUpdateSig.close()
Expand Down Expand Up @@ -208,7 +206,6 @@ when defined(hyperxStats):
when defined(hyperxSanityCheck):
func sanityCheckAfterClose(client: ClientContext) {.raises: [].} =
doAssert not client.isConnected
doAssert client.recvMsgs.isClosed
doAssert client.streamOpenedMsgs.isClosed
doAssert client.peerWindowUpdateSig.isClosed
doAssert client.windowUpdateSig.isClosed
Expand Down Expand Up @@ -387,15 +384,18 @@ proc handshake(client: ClientContext) {.async.} =

func doTransitionRecv(
s: Stream, frm: Frame
) {.raises: [HyperxConnError, HyperxStrmError].} =
) {.raises: [HyperxConnError].} =
doAssert frm.sid == s.id
doAssert frm.sid != frmSidMain
doAssert s.state != strmInvalid
check frm.typ in frmStreamAllowed, newConnError(hyxProtocolError)
let nextState = toNextStateRecv(s.state, frm.toStreamEvent)
if nextState == strmInvalid:
if s.state == strmHalfClosedRemote:
raise newStrmError(hyxStreamClosed)
# This used to be a strmError, but it was raicy.
# Since we may send an END_STREAM before
# this propagates, and we cannot send the RST on a closed stream.
raise newConnError(hyxStreamClosed)
if s.state == strmClosed:
raise newConnError(hyxStreamClosed)
raise newConnError(hyxProtocolError)
Expand Down Expand Up @@ -478,41 +478,6 @@ proc read(client: ClientContext, frm: Frame) {.async.} =
debugInfo "Continuation"
await client.readUntilEnd(frm)

proc recvTaskNaked(client: ClientContext) {.async.} =
## Receive frames and dispatch to opened streams
## Meant to be asyncCheck'ed
doAssert client.isConnected
while client.isConnected:
var frm = newFrame()
await client.read frm
await client.recvMsgs.put frm

proc recvTask(client: ClientContext) {.async.} =
try:
await client.recvTaskNaked()
except QueueClosedError:
doAssert not client.isConnected
except HyperxConnError as err:
debugErr2 err
client.error ?= newError err
await client.sendSilently newGoAwayFrame(
client.maxPeerStreamIdSeen, err.code
)
raise err
except OsError, SslError:
let err = getCurrentException()
debugErr2 err
client.error ?= newConnError(err.msg)
raise newConnError(err.msg, err)
except CatchableError as err:
debugErr2 err
raise err
finally:
debugInfo "recvTask exited"
# xxx send goaway NO_ERROR
# await client.sendGoAway(NO_ERROR)
client.close()

const connFrmAllowed = {
frmtSettings,
frmtPing,
Expand Down Expand Up @@ -603,8 +568,10 @@ proc recvDispatcherNaked(client: ClientContext) {.async.} =
## so it needs to be done here. Same for processing the main
## stream messages.
var headers = ""
var frm = newFrame()
while client.isConnected:
let frm = await client.recvMsgs.pop()
frm.clear()
await client.read frm
debugInfo "recv data on stream " & $frm.sid.int
if frm.typ.isUnknown:
continue
Expand Down Expand Up @@ -686,6 +653,11 @@ proc recvDispatcher(client: ClientContext) {.async.} =
except HyperxStrmError:
debugErr getCurrentException()
doAssert false
except OsError, SslError:
let err = getCurrentException()
debugErr2 err
client.error ?= newConnError(err.msg)
raise newConnError(err.msg, err)
except CatchableError as err:
debugErr2 err
raise err
Expand Down Expand Up @@ -736,13 +708,12 @@ proc failSilently(f: Future[void]) {.async.} =
template with*(client: ClientContext, body: untyped): untyped =
discard getGlobalDispatcher() # setup event loop
doAssert not client.isConnected
var recvFut, dispFut, winupFut: Future[void] = nil
var dispFut, winupFut: Future[void] = nil
try:
client.isConnected = true
if client.typ == ctClient:
await client.connect()
await client.handshake()
recvFut = client.recvTask()
dispFut = client.recvDispatcher()
winupFut = client.windowUpdateTask()
block:
Expand All @@ -755,7 +726,6 @@ template with*(client: ClientContext, body: untyped): untyped =
client.close()
# do not bother the user with hyperx errors
# at this point body completed or errored out
await failSilently(recvFut)
await failSilently(dispFut)
await failSilently(winupFut)
when defined(hyperxSanityCheck):
Expand All @@ -776,6 +746,7 @@ type
headersRecv, bodyRecv, trailersRecv: string
headersRecvSig, bodyRecvSig: SignalAsync
bodyRecvLen: int
frm: Frame

func newClientStream*(client: ClientContext, stream: Stream): ClientStream =
ClientStream(
Expand All @@ -791,6 +762,7 @@ func newClientStream*(client: ClientContext, stream: Stream): ClientStream =
headersRecv: "",
headersRecvSig: newSignal(),
trailersRecv: "",
frm: newEmptyFrame()
)

func newClientStream*(client: ClientContext): ClientStream =
Expand Down Expand Up @@ -878,32 +850,25 @@ proc write(strm: ClientStream, frm: Frame): Future[void] =
stream.doTransitionSend frm
result = client.send frm

proc read(stream: Stream): Future[Frame] {.async.} =
var frm: Frame
while true:
frm = await stream.msgs.get()
#stream.msgs.getDone()
doAssert stream.id == frm.sid
doAssert frm.typ in frmStreamAllowed
# this can raise stream/conn error
stream.doTransitionRecv frm
if frm.typ == frmtRstStream:
stream.error = newStrmError(frm.errCode, hyxRemoteErr)
stream.close()
raise newStrmError(frm.errCode, hyxRemoteErr)
if frm.typ == frmtPushPromise:
raise newStrmError hyxProtocolError
if frm.typ == frmtWindowUpdate:
check frm.windowSizeInc > 0, newStrmError hyxProtocolError
check frm.windowSizeInc <= stgMaxWindowSize, newStrmError hyxProtocolError
check stream.peerWindow <= stgMaxWindowSize.int32 - frm.windowSizeInc.int32,
newStrmError hyxFlowControlError
stream.peerWindow += frm.windowSizeInc.int32
if not stream.peerWindowUpdateSig.isClosed:
stream.peerWindowUpdateSig.trigger()
if frm.typ in {frmtHeaders, frmtData}:
break
return frm
proc process(stream: Stream, frm: Frame) {.raises: [HyperxConnError, HyperxStrmError, QueueClosedError].} =
doAssert stream.id == frm.sid
doAssert frm.typ in frmStreamAllowed
# this can raise stream/conn error
stream.doTransitionRecv frm
if frm.typ == frmtRstStream:
stream.error = newStrmError(frm.errCode, hyxRemoteErr)
stream.close()
raise newStrmError(frm.errCode, hyxRemoteErr)
if frm.typ == frmtPushPromise:
raise newStrmError hyxProtocolError
if frm.typ == frmtWindowUpdate:
check frm.windowSizeInc > 0, newStrmError hyxProtocolError
check frm.windowSizeInc <= stgMaxWindowSize, newStrmError hyxProtocolError
check stream.peerWindow <= stgMaxWindowSize.int32 - frm.windowSizeInc.int32,
newStrmError hyxFlowControlError
stream.peerWindow += frm.windowSizeInc.int32
if not stream.peerWindowUpdateSig.isClosed:
stream.peerWindowUpdateSig.trigger()

# this needs to be {.async.} to fail-silently
proc writeRst(strm: ClientStream, code: FrmErrCode) {.async.} =
Expand All @@ -914,12 +879,18 @@ proc writeRst(strm: ClientStream, code: FrmErrCode) {.async.} =
await strm.write newRstStreamFrame(stream.id, code)

proc recvHeadersTaskNaked(strm: ClientStream) {.async.} =
template stream: untyped = strm.stream
doAssert strm.stateRecv == csStateOpened
strm.stateRecv = csStateHeaders
# https://httpwg.org/specs/rfc9113.html#HttpFraming
var frm: Frame
while true:
frm = await strm.stream.read()
while true:
frm = await stream.msgs.get()
stream.msgs.getDone()
stream.process(frm)
if frm.typ in {frmtHeaders, frmtData}:
break
check frm.typ == frmtHeaders, newStrmError hyxProtocolError
validateHeaders(frm.payload, strm.client.typ)
if strm.client.typ == ctServer:
Expand Down Expand Up @@ -950,11 +921,17 @@ func contentLenCheck(strm: ClientStream) {.raises: [HyperxStrmError].} =
)

proc recvBodyTaskNaked(strm: ClientStream) {.async.} =
template stream: untyped = strm.stream
doAssert strm.stateRecv in {csStateHeaders, csStateData}
strm.stateRecv = csStateData
var frm: Frame
while true:
frm = await strm.stream.read()
while true:
frm = await stream.msgs.get()
stream.msgs.getDone()
stream.process(frm)
if frm.typ in {frmtHeaders, frmtData}:
break
# https://www.rfc-editor.org/rfc/rfc9110.html#section-6.5
if frm.typ == frmtHeaders:
strm.trailersRecv.add frm.payload
Expand All @@ -979,15 +956,21 @@ proc recvBodyTaskNaked(strm: ClientStream) {.async.} =
strm.bodyRecvSig.trigger()
strm.bodyRecvSig.close()

proc process(stream: Stream) {.async.} =
var frm: Frame
while true:
frm = await stream.msgs.get()
stream.msgs.getDone()
stream.process(frm)

proc recvTask(strm: ClientStream) {.async.} =
template client: untyped = strm.client
template stream: untyped = strm.stream
try:
await recvHeadersTaskNaked(strm)
if strm.stateRecv != csStateEnded:
await recvBodyTaskNaked(strm)
while true:
discard await stream.read()
await stream.process()
except QueueClosedError:
discard
except HyperxConnError as err:
Expand Down Expand Up @@ -1074,13 +1057,14 @@ proc sendHeadersImpl*(
): Future[void] =
## Headers must be HPACK encoded;
## headers may be trailers
template frm: untyped = strm.frm
doAssert strm.stream.state in strmStateHeaderSendAllowed
doAssert strm.stateSend == csStateOpened or
(strm.stateSend in {csStateHeaders, csStateData} and finish)
if strm.stream.state == strmIdle:
strm.openStream()
strm.stateSend = csStateHeaders
var frm = newFrame()
frm.clear()
frm.add headers
frm.setTyp frmtHeaders
frm.setSid strm.stream.id
Expand Down Expand Up @@ -1112,6 +1096,7 @@ proc sendBodyNaked(
) {.async.} =
template client: untyped = strm.client
template stream: untyped = strm.stream
template frm: untyped = strm.frm
check stream.state in strmStateDataSendAllowed,
newErrorOrDefault(stream.error, newStrmError hyxStreamClosed)
doAssert strm.stateSend in {csStateHeaders, csStateData}
Expand All @@ -1129,7 +1114,7 @@ proc sendBodyNaked(
await client.peerWindowUpdateSig.waitFor()
let peerWindow = min(client.peerWindow, stream.peerWindow)
dataIdxB = min(dataIdxA+min(peerWindow, stgInitialMaxFrameSize.int), L)
var frm = newFrame()
frm.clear()
frm.setTyp frmtData
frm.setSid stream.id
frm.setPayloadLen (dataIdxB-dataIdxA).FrmPayloadLen
Expand Down
11 changes: 11 additions & 0 deletions src/hyperx/value.nim
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ proc get*[T](vala: ValueAsync[T]): Future[T] {.async.} =
doAssert vala.val != nil
result = vala.val
vala.val = nil

proc getDone*[T](vala: ValueAsync[T]) {.raises: [].} =
wakeupSoon vala.putWaiter

proc failSoon(f: Future[void]) {.raises: [].} =
Expand Down Expand Up @@ -83,9 +85,13 @@ when isMainModule:
doAssert q.val == nil
let puts1 = puts()
doAssert (await q.get())[] == 1
q.getDone()
doAssert (await q.get())[] == 2
q.getDone()
doAssert (await q.get())[] == 3
q.getDone()
doAssert (await q.get())[] == 4
q.getDone()
await puts1
waitFor test()
doAssert not hasPendingOperations()
Expand All @@ -94,9 +100,13 @@ when isMainModule:
var q = newValueAsync[ref int]()
proc gets {.async.} =
doAssert (await q.get())[] == 1
q.getDone()
doAssert (await q.get())[] == 2
q.getDone()
doAssert (await q.get())[] == 3
q.getDone()
doAssert (await q.get())[] == 4
q.getDone()
let gets1 = gets()
await q.put newIntRef(1)
doAssert q.val == nil
Expand All @@ -114,6 +124,7 @@ when isMainModule:
var q = newValueAsync[ref int]()
proc gets {.async.} =
doAssert (await q.get())[] == 1
q.getDone()
q.close()
let gets1 = gets()
await q.put newIntRef(1)
Expand Down

0 comments on commit d9cde4c

Please sign in to comment.