From 23d2b28a74f2642040c33f7b2b69ccf4ff184e1d Mon Sep 17 00:00:00 2001 From: Bob Weinand Date: Fri, 16 Feb 2024 23:51:39 +0100 Subject: [PATCH] Make sure closing the UpgradedSocket doesn't close the whole connection --- src/Driver/Client.php | 6 ++++++ src/Driver/Http2Driver.php | 22 ++++++++++++++++++++-- src/Driver/Http3Driver.php | 4 +--- src/Driver/SocketClient.php | 14 ++++++++------ src/Driver/UpgradedSocket.php | 1 - 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/Driver/Client.php b/src/Driver/Client.php index ba869a5a..0d71c699 100644 --- a/src/Driver/Client.php +++ b/src/Driver/Client.php @@ -13,6 +13,12 @@ interface Client extends Closable */ public function getId(): int; + /** + * @inheritDoc + * @param int $close An optional close reason for streams which support a close reason. + */ + public function close(int $reason = 0): void; + /** * @return SocketAddress Remote client address. */ diff --git a/src/Driver/Http2Driver.php b/src/Driver/Http2Driver.php index 0e00e522..125353aa 100644 --- a/src/Driver/Http2Driver.php +++ b/src/Driver/Http2Driver.php @@ -31,6 +31,8 @@ use Amp\Http\Server\Trailers; use Amp\Pipeline\Queue; use Amp\Socket\InternetAddress; +use Amp\Socket\SocketAddress; +use Amp\Socket\TlsInfo; use League\Uri; use Psr\Log\LoggerInterface as PsrLogger; use Revolt\EventLoop; @@ -1160,7 +1162,23 @@ private function upgrade(Request $request, Response $response, int $id): void throw new \Error('Response was not upgraded'); } - $client = $request->getClient(); + $client = new class($request, $id) extends SocketClient { + public function __construct(private Request $request, int $id) { + parent::__construct($request->getClient(), $id); + } + + public function close(int $reason = 0): void { + // Nothing to do here, closing the output stream is enough + } + + public function isClosed(): bool { + return $this->request->getBody()->isClosed(); + } + + public function onClose(\Closure $onClose): void { + $this->request->getBody()->onClose($onClose); + } + }; // The input RequestBody are parsed raw DATA frames - exactly what we need (see CONNECT) $inputStream = new UnbufferedBodyStream($request->getBody()); @@ -1169,7 +1187,7 @@ private function upgrade(Request $request, Response $response, int $id): void // The output of an upgraded connection is just DATA frames $outputPipe = new Pipe(0); - $upgraded = new UpgradedSocket($client, $inputStream, $outputPipe->getSink(), $id); + $upgraded = new UpgradedSocket($client, $inputStream, $outputPipe->getSink()); try { $upgradeHandler($upgraded, $request, $response); diff --git a/src/Driver/Http3Driver.php b/src/Driver/Http3Driver.php index 01add804..ef503c7f 100644 --- a/src/Driver/Http3Driver.php +++ b/src/Driver/Http3Driver.php @@ -208,8 +208,6 @@ private function upgrade(QuicSocket $stream, Request $request, Response $respons throw new \Error('Response was not upgraded'); } - $client = $request->getClient(); - // The input RequestBody are parsed raw DATA frames - exactly what we need (see CONNECT) $inputStream = new UnbufferedBodyStream($request->getBody()); $request->setBody(""); // hide the body from the upgrade handler, it's available in the UpgradedSocket @@ -220,7 +218,7 @@ private function upgrade(QuicSocket $stream, Request $request, Response $respons $settings = $this->parsedSettings->getFuture()->await(); $datagramStream = empty($settings[Http3Settings::H3_DATAGRAM->value]) ? null : new Http3DatagramStream($this->parser->receiveDatagram(...), $this->writer->writeDatagram(...), $this->writer->maxDatagramSize(...), $stream); - $upgraded = new UpgradedSocket($client, $inputStream, $outputPipe->getSink(), $stream->getId(), $datagramStream); + $upgraded = new UpgradedSocket(new SocketClient($stream, $stream->getId()), $inputStream, $outputPipe->getSink(), $datagramStream); try { $upgradeHandler($upgraded, $request, $response); diff --git a/src/Driver/SocketClient.php b/src/Driver/SocketClient.php index 7dff3140..7681f1cf 100644 --- a/src/Driver/SocketClient.php +++ b/src/Driver/SocketClient.php @@ -3,18 +3,20 @@ namespace Amp\Http\Server\Driver; use Amp\Quic\QuicConnection; +use Amp\Quic\QuicSocket; use Amp\Socket\Socket; use Amp\Socket\SocketAddress; use Amp\Socket\TlsInfo; -final class SocketClient implements Client +class SocketClient implements Client { private readonly int $id; public function __construct( - private readonly Socket|QuicConnection $socket, + private readonly Client|Socket|QuicConnection $socket, + int $id = null ) { - $this->id = createClientId(); + $this->id = $id ?? createClientId(); } public function getId(): int @@ -37,9 +39,9 @@ public function getTlsInfo(): ?TlsInfo return $this->socket->getTlsInfo(); } - public function close(): void + public function close(int $reason = 0): void { - $this->socket->close(); + $this->socket->close($reason); } public function onClose(\Closure $onClose): void @@ -54,6 +56,6 @@ public function isClosed(): bool public function isQuicClient(): bool { - return $this->socket instanceof QuicConnection; + return $this->socket instanceof QuicConnection || $this->socket instanceof QuicSocket; } } diff --git a/src/Driver/UpgradedSocket.php b/src/Driver/UpgradedSocket.php index 167531c1..90abc4a9 100644 --- a/src/Driver/UpgradedSocket.php +++ b/src/Driver/UpgradedSocket.php @@ -27,7 +27,6 @@ public function __construct( private readonly Client $client, private readonly ReadableStream $readableStream, private readonly WritableStream $writableStream, - public readonly int $id = 0, public readonly ?DatagramStream $datagramClient = null, ) { }