diff --git a/.gitignore b/.gitignore index 7e5f744..107e7e0 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ tests/functional/tflowcontrol tests/functional/tcancel tests/functional/tcancelremote tests/functional/tmisc +tests/functional/tgracefulclose src/hyperx/client src/hyperx/server src/hyperx/clientserver diff --git a/hyperx.nimble b/hyperx.nimble index 8671e45..fc6d4b9 100644 --- a/hyperx.nimble +++ b/hyperx.nimble @@ -64,6 +64,7 @@ task functest, "Func test": exec "nim c -r -d:release tests/functional/tcancel.nim" exec "nim c -r -d:release tests/functional/tcancelremote.nim" exec "nim c -r -d:release tests/functional/tmisc.nim" + exec "nim c -r -d:release tests/functional/tgracefulclose.nim" task funcserveinsec, "Func Serve Insecure": exec "nim c -r -d:release tests/functional/tserverinsecure.nim" diff --git a/src/hyperx/client.nim b/src/hyperx/client.nim index 0938417..046ceea 100644 --- a/src/hyperx/client.nim +++ b/src/hyperx/client.nim @@ -33,7 +33,8 @@ export ClientContext, HyperxConnError, HyperxStrmError, - HyperxError + HyperxError, + isGracefulClose var sslContext {.threadvar.}: SslContext diff --git a/src/hyperx/clientserver.nim b/src/hyperx/clientserver.nim index 5cbaddb..e699a20 100644 --- a/src/hyperx/clientserver.nim +++ b/src/hyperx/clientserver.nim @@ -129,11 +129,12 @@ type hostname*: string port: Port isConnected*: bool + isGracefulShutdown: bool headersEnc, headersDec: DynHeaders streams: Streams recvMsgs: QueueAsync[Frame] streamOpenedMsgs*: QueueAsync[Stream] - currStreamId, maxPeerStrmIdSeen: StreamId + currStreamId: StreamId peerMaxConcurrentStreams: uint32 peerWindowSize: uint32 peerWindow: int32 # can be negative @@ -159,13 +160,13 @@ proc newClient*( hostname: hostname, port: port, isConnected: false, + isGracefulShutdown: false, headersEnc: initDynHeaders(stgHeaderTableSize.int), headersDec: initDynHeaders(stgHeaderTableSize.int), streams: initStreams(), - currStreamId: 1.StreamId, + currStreamId: 0.StreamId, recvMsgs: 
newQueue[Frame](10), streamOpenedMsgs: newQueue[Stream](10), - maxPeerStrmIdSeen: 0.StreamId, peerMaxConcurrentStreams: stgInitialMaxConcurrentStreams, peerWindow: stgInitialWindowSize.int32, peerWindowSize: stgInitialWindowSize, @@ -209,12 +210,20 @@ func openMainStream(client: ClientContext): Stream {.raises: [StreamsClosedError doAssert frmSidMain.StreamId notin client.streams result = client.streams.open(frmSidMain.StreamId, client.peerWindowSize.int32) -func openStream(client: ClientContext): Stream {.raises: [StreamsClosedError].} = +func openStream(client: ClientContext): Stream {.raises: [StreamsClosedError, GracefulShutdownError].} = # XXX some error if max sid is reached # XXX error if maxStreams is reached - result = client.streams.open(client.currStreamId, client.peerWindowSize.int32) - # client uses odd numbers, and server even numbers - client.currStreamId += 2.StreamId + doAssert client.typ == ctClient + check not client.isGracefulShutdown, newGracefulShutdownError() + var sid = client.currStreamId.uint32 + sid += (if sid == 0: 1 else: 2) + result = client.streams.open(StreamId sid, client.peerWindowSize.int32) + client.currStreamId = StreamId sid + +func maxPeerStreamIdSeen(client: ClientContext): StreamId {.raises: [].} = + case client.typ + of ctClient: StreamId 0 + of ctServer: client.currStreamId when defined(hyperxStats): func echoStats*(client: ClientContext) = @@ -377,7 +386,6 @@ const serverHandshakeBlob = handshakeBlob(ctServer) proc handshakeNaked(client: ClientContext) {.async.} = doAssert client.isConnected debugInfo "handshake" - # we need to do this before sending any other frame let strm = client.openMainStream() doAssert strm.id == frmSidMain.StreamId check not client.sock.isClosed, newConnClosedError() @@ -417,10 +425,6 @@ func doTransitionRecv(s: Stream, frm: Frame) {.raises: [ConnError, StrmError].} raise newConnError(errStreamClosed) raise newConnError(errProtocolError) s.state = nextState - #if oldState == strmIdle: - # # XXX do 
this elsewhere not here - # # XXX close streams < s.id in idle state - # discard proc readUntilEnd(client: ClientContext, frm: Frame) {.async.} = ## Read continuation frames until ``END_HEADERS`` flag is set @@ -521,7 +525,7 @@ proc recvTask(client: ClientContext) {.async.} = # XXX close queues client.error = newConnError(err.code) await client.sendSilently newGoAwayFrame( - client.maxPeerStrmIdSeen.int, err.code.int + client.maxPeerStreamIdSeen.int, err.code.int ) #client.close() raise err @@ -613,10 +617,15 @@ proc consumeMainStream(client: ClientContext, frm: Frame) {.async.} = if not strm.pingSig.isClosed: strm.pingSig.trigger() of frmtGoAway: - # XXX close streams lower than Last-Stream-ID - # XXX don't allow new streams creation - # the connection is still ok for streams lower than Last-Stream-ID - discard + client.isGracefulShutdown = true + client.error ?= newConnError frm.errorCode() + # streams are never created by ctServer, + # so there are no streams to close + if client.typ == ctClient: + let sid = frm.lastStreamId() + for strm in values client.streams: + if strm.id.uint32 > sid: + client.streams.close(strm.id) else: doAssert frm.typ notin connFrmAllowed raise newConnError(errProtocolError) @@ -641,19 +650,20 @@ proc recvDispatcherNaked(client: ClientContext) {.async.} = await consumeMainStream(client, frm) continue check frm.typ in frmStreamAllowed, newConnError(errProtocolError) + check frm.sid.int mod 2 != 0, newConnError(errProtocolError) if client.typ == ctServer and - frm.sid.StreamId > client.maxPeerStrmIdSeen and - frm.sid.int mod 2 != 0: + frm.sid.StreamId > client.currStreamId: check client.streams.len <= stgServerMaxConcurrentStreams, newConnError(errProtocolError) - client.maxPeerStrmIdSeen = frm.sid.StreamId - # we do not store idle streams, so no need to close them - let strm = client.streams.open(frm.sid.StreamId, client.peerWindowSize.int32) - await client.streamOpenedMsgs.put strm - if client.typ == ctClient and - frm.sid.StreamId > 
client.maxPeerStrmIdSeen and - frm.sid.int mod 2 == 0: - client.maxPeerStrmIdSeen = frm.sid.StreamId + if client.isGracefulShutdown: + await client.send newGoAwayFrame( + client.maxPeerStreamIdSeen.int, errNoError.int + ) + else: + client.currStreamId = frm.sid.StreamId + # we do not store idle streams, so no need to close them + let strm = client.streams.open(frm.sid.StreamId, client.peerWindowSize.int32) + await client.streamOpenedMsgs.put strm if frm.typ == frmtHeaders: headers.setLen 0 client.hpackDecode(headers, frm.payload) @@ -667,13 +677,18 @@ proc recvDispatcherNaked(client: ClientContext) {.async.} = check frm.windowSizeInc > 0, newConnError(errProtocolError) if frm.typ == frmtPushPromise: check client.typ == ctClient, newConnError(errProtocolError) - # Process headers even if the stream - # does not exist + # Process headers even if the stream does not exist if frm.sid.StreamId notin client.streams: if frm.typ == frmtData: - client.windowPending -= frm.payloadLen.int - check frm.typ in {frmtRstStream, frmtWindowUpdate}, - newConnError errStreamClosed + client.windowProcessed += frm.payloadLen.int + if client.windowProcessed > stgWindowSize.int div 2: + client.windowUpdateSig.trigger() + if client.typ == ctServer and + frm.sid.StreamId > client.currStreamId: + doAssert client.isGracefulShutdown + else: + check frm.typ in {frmtRstStream, frmtWindowUpdate}, + newConnError errStreamClosed debugInfo "stream not found " & $frm.sid.int continue var stream = client.streams.get frm.sid.StreamId @@ -701,7 +716,7 @@ proc recvDispatcher(client: ClientContext) {.async.} = if client.isConnected: client.error = newConnError(err.code) await client.sendSilently newGoAwayFrame( - client.maxPeerStrmIdSeen.int, err.code.int + client.maxPeerStreamIdSeen.int, err.code.int ) raise err except StrmError: @@ -1023,7 +1038,7 @@ proc recvTask(strm: ClientStream) {.async.} = if client.isConnected: client.error = newConnError(err.code) await client.sendSilently newGoAwayFrame( - 
client.maxPeerStrmIdSeen.int, err.code.int + client.maxPeerStreamIdSeen.int, err.code.int ) raise err except StrmError as err: @@ -1222,16 +1237,19 @@ template with*(strm: ClientStream, body: untyped): untyped = strm.windowEnd() await failSilently(recvFut) -proc ping(strm: ClientStream) {.async.} = - # this is done for rst pings; only one stream ping +proc ping(client: ClientContext, strm: Stream) {.async.} = + # this is done for rst and go-away pings; only one stream ping # will ever be in progress - if strm.stream.pingSig.len > 0: - await strm.stream.pingSig.waitFor() + if strm.pingSig.len > 0: + await strm.pingSig.waitFor() else: - let sig = strm.stream.pingSig.waitFor() - await strm.client.send newPingFrame(strm.stream.id.uint32) + let sig = strm.pingSig.waitFor() + await client.send newPingFrame(strm.id.uint32) await sig +proc ping(strm: ClientStream) {.async.} = + await strm.client.ping(strm.stream) + proc cancel*(strm: ClientStream, code: ErrorCode) {.async.} = ## This may never return until the stream/conn is closed. 
## This can be called multiple times concurrently, @@ -1246,6 +1264,24 @@ proc cancel*(strm: ClientStream, code: ErrorCode) {.async.} = strm.stream.error ?= newStrmError(errStreamClosed) strm.close() +proc gracefulClose*(client: ClientContext) {.async.} = + # returning early is ok + if client.isGracefulShutdown: + return + # fail silently because it's best effort, + # setting isGracefulShutdown is the only important thing + await failSilently client.send newGoAwayFrame( + int32.high, errNoError.int + ) + await failSilently client.ping client.streams.get(StreamId 0) + client.isGracefulShutdown = true + await failSilently client.send newGoAwayFrame( + client.maxPeerStreamIdSeen.int, errNoError.int + ) + +proc isGracefulClose*(client: ClientContext): bool {.raises: [].} = + result = client.isGracefulShutdown + when defined(hyperxTest): proc putRecvTestData*(client: ClientContext, data: seq[byte]) {.async.} = await client.sock.putRecvData data diff --git a/src/hyperx/errors.nim b/src/hyperx/errors.nim index bc1b6e3..4b0045c 100644 --- a/src/hyperx/errors.nim +++ b/src/hyperx/errors.nim @@ -56,21 +56,27 @@ type ConnClosedError* = object of HyperxConnError ConnError* = object of HyperxConnError code*: ErrorCode + GracefulShutdownError* = ConnError StrmError* = object of HyperxStrmError typ*: HyperxErrTyp code*: ErrorCode - QueueError* = object of HyperxError - QueueClosedError* = object of QueueError + QueueClosedError* = object of HyperxError func newHyperxConnError*(msg: string): ref HyperxConnError {.raises: [].} = result = (ref HyperxConnError)(msg: msg) -func newConnClosedError*(): ref ConnClosedError {.raises: [].} = +func newConnClosedError*: ref ConnClosedError {.raises: [].} = result = (ref ConnClosedError)(msg: "Connection Closed") func newConnError*(errCode: ErrorCode): ref ConnError {.raises: [].} = result = (ref ConnError)(code: errCode, msg: "Connection Error: " & $errCode) +func newConnError*(errCode: uint32): ref ConnError {.raises: [].} = + result = (ref 
ConnError)( + code: errCode.toErrorCode, + msg: "Connection Error: " & $errCode.toErrorCode + ) + func newStrmError*(errCode: ErrorCode, typ = hxLocalErr): ref StrmError {.raises: [].} = let msg = case typ of hxLocalErr: "Stream Error: " & $errCode @@ -90,3 +96,8 @@ func newErrorOrDefault*(err, default: ref StrmError): ref StrmError {.raises: [] return newError(err) else: return default + +func newGracefulShutdownError*(): ref GracefulShutdownError {.raises: [].} = + result = (ref GracefulShutdownError)( + code: errNoError, msg: "Connection Error: " & $errNoError + ) diff --git a/src/hyperx/frame.nim b/src/hyperx/frame.nim index a7e34c5..9245215 100644 --- a/src/hyperx/frame.nim +++ b/src/hyperx/frame.nim @@ -360,12 +360,20 @@ func windowSizeInc*(frm: Frame): uint {.raises: [].} = result.clearBit 31 # clear reserved byte func errorCode*(frm: Frame): uint32 {.raises: [].} = - doAssert frm.typ == frmtRstStream - result = 0 - result += frm.s[frmHeaderSize+0].uint32 shl 24 - result += frm.s[frmHeaderSize+1].uint32 shl 16 - result += frm.s[frmHeaderSize+2].uint32 shl 8 - result += frm.s[frmHeaderSize+3].uint32 + result = 0'u32 + case frm.typ + of frmtRstStream: + result += frm.s[frmHeaderSize+0].uint32 shl 24 + result += frm.s[frmHeaderSize+1].uint32 shl 16 + result += frm.s[frmHeaderSize+2].uint32 shl 8 + result += frm.s[frmHeaderSize+3].uint32 + of frmtGoAway: + result += frm.s[frmHeaderSize+4].uint32 shl 24 + result += frm.s[frmHeaderSize+5].uint32 shl 16 + result += frm.s[frmHeaderSize+6].uint32 shl 8 + result += frm.s[frmHeaderSize+7].uint32 + else: + doAssert false func pingData*(frm: Frame): uint32 {.raises: [].} = # note we ignore the last 4 bytes @@ -376,16 +384,13 @@ func pingData*(frm: Frame): uint32 {.raises: [].} = result += frm.s[frmHeaderSize+2].uint32 shl 8 result += frm.s[frmHeaderSize+3].uint32 -# XXX add padding field and padding as payload -#func setPadding*(frm: Frame, n: FrmPadding) = -# doAssert frm.typ in {frmtData, frmtHeaders, frmtPushPromise} 
- -#func add*(frm: Frame, payload: openArray[byte]) = -# frm.s.add payload -# frm.setPayloadLen FrmPayloadLen(frm.rawLen-frmHeaderSize) - -#template payload*(frm: Frame): untyped = -# toOpenArray(frm.s, frmHeaderSize, frm.s.len-1) +func lastStreamId*(frm: Frame): uint32 = + doAssert frm.typ == frmtGoAway + result = 0'u32 + result += frm.s[frmHeaderSize+0].uint32 shl 24 + result += frm.s[frmHeaderSize+1].uint32 shl 16 + result += frm.s[frmHeaderSize+2].uint32 shl 8 + result += frm.s[frmHeaderSize+3].uint32 func `$`*(frm: Frame): string {.raises: [].} = result = "" diff --git a/src/hyperx/server.nim b/src/hyperx/server.nim index 566507a..53ee5d7 100644 --- a/src/hyperx/server.nim +++ b/src/hyperx/server.nim @@ -34,7 +34,9 @@ export ClientContext, HyperxConnError, HyperxStrmError, - HyperxError + HyperxError, + gracefulClose, + isGracefulClose var sslContext {.threadvar.}: SslContext diff --git a/src/hyperx/stream.nim b/src/hyperx/stream.nim index 83bd82c..9e6ce16 100644 --- a/src/hyperx/stream.nim +++ b/src/hyperx/stream.nim @@ -4,6 +4,7 @@ import ./frame import ./value import ./signal import ./errors +import ./utils # Section 5.1 type @@ -198,8 +199,12 @@ proc close*(stream: Stream) {.raises: [].} = stream.peerWindowUpdateSig.close() stream.pingSig.close() +type StreamsClosedError* = object of QueueClosedError + +func newStreamsClosedError*(msg: string): ref StreamsClosedError {.raises: [].} = + result = (ref StreamsClosedError)(msg: msg) + type - StreamsClosedError* = object of HyperxError Streams* = object t: Table[StreamId, Stream] isClosed: bool @@ -232,8 +237,7 @@ func open*( peerWindow: int32 ): Stream {.raises: [StreamsClosedError].} = doAssert sid notin s.t, $sid.int - if s.isClosed: - raise newException(StreamsClosedError, "Cannot open stream") + check not s.isClosed, newStreamsClosedError("Cannot open stream") result = newStream(sid, peerWindow) s.t[sid] = result diff --git a/tests/functional/tgracefulclose.nim b/tests/functional/tgracefulclose.nim new 
file mode 100644 index 0000000..0beabb8 --- /dev/null +++ b/tests/functional/tgracefulclose.nim @@ -0,0 +1,98 @@ +{.define: ssl.} +{.define: hyperxSanityCheck.} + +import std/asyncdispatch +import ../../src/hyperx/client +import ../../src/hyperx/limiter +#import ../../src/hyperx/errors +import ./tutils.nim +from ../../src/hyperx/clientserver import stgWindowSize + +# Since this relies on a sleep, it can be flaky +# but if it fails, there's something wrong; +# it's just hard to make it consistently fail; +# checked output should be a fairly random number; +# output like "checked 5000" (clients*streams) means streams are not +# created in between server receiving the poison pill +# and client ACK'ing the ping + +const clientsCount = 50 +const strmsInFlight = 100 +const dataFrameLen = 123 +#const dataFrameLen = stgWindowSize.int * 2 + 123 +const theData = newString(dataFrameLen) + +proc send(strm: ClientStream, poison = false) {.async.} = + var headers = @[ + (":method", "POST"), + (":scheme", "https"), + (":path", "/file/"), + (":authority", "foo.bar"), + ("user-agent", "HyperX/0.1"), + ("content-type", "text/plain"), + ("x-no-echo-headers", "true") + ] + if poison: + headers.add ("x-graceful-close-remote", "true") + await strm.sendHeaders(headers, finish = false) + # cannot sleep before sending headers + if poison: + await sleepAsync(10_000) + let data = newStringref theData & $strm.stream.id.int + await strm.sendBody(data, finish = true) + +proc recv(strm: ClientStream) {.async.} = + var data = newStringref() + await strm.recvHeaders(data) + doAssert data[] == ":status: 200\r\n" + data[].setLen 0 + while not strm.recvEnded: + await strm.recvBody(data) + doAssert data[] == theData & $strm.stream.id.int + +proc spawnStream( + client: ClientContext, + checked: ref int, + poison = false +) {.async.} = + if poison: + await sleepAsync(10_000) + if client.isGracefulClose: + inc checked[] + return + let strm = client.newClientStream() + with strm: + let sendFut = 
strm.send(poison) + let recvFut = strm.recv() + await recvFut + await sendFut + inc checked[] + +proc spawnClient( + checked: ref int +) {.async.} = + var client = newClient(localHost, localPort) + with client: + let lt = newLimiter(strmsInFlight) + await lt.spawn spawnStream(client, checked, poison = true) + while not client.isGracefulClose: + doAssert client.isConnected + await lt.spawn spawnStream(client, checked) + await lt.join() + +proc main() {.async.} = + let checked = new(int) + checked[] = 0 + var clients = newSeq[Future[void]]() + for _ in 0 .. clientsCount-1: + clients.add spawnClient(checked) + for clientFut in clients: + await clientFut + doAssert checked[] >= clientsCount #* strmsInFlight + echo "checked ", $checked[] + +(proc = + waitFor main() + doAssert not hasPendingOperations() + echo "ok" +)() diff --git a/tests/functional/tserver.nim b/tests/functional/tserver.nim index 0119cf5..78fc8d8 100644 --- a/tests/functional/tserver.nim +++ b/tests/functional/tserver.nim @@ -27,11 +27,17 @@ proc processStream(strm: ClientStream) {.async.} = @[("x-trailer", "bye")], finish = true ) await strm.cancel(errCancel) - await strm.sendBody(data, finish = strm.recvEnded) + if "x-graceful-close-remote" in data[]: + await strm.client.gracefulClose() + if "x-no-echo-headers" notin data[]: + await strm.sendBody(data, finish = strm.recvEnded) while not strm.recvEnded: data[].setLen 0 await strm.recvBody(data) await strm.sendBody(data, finish = strm.recvEnded) + if not strm.sendEnded: + data[].setLen 0 + await strm.sendBody(data, finish = true) #GC_fullCollect() proc serve*(server: ServerContext) {.async.} =