From df666f9992b76b8e3c18de2fcbe6707e0c83a8b5 Mon Sep 17 00:00:00 2001 From: Flavia Rainone Date: Tue, 11 Jun 2024 16:59:53 -0300 Subject: [PATCH] [UNDERTOW-2403] Fix race condition in read from buffer at ServletInputStreamImpl: prevent a NPE from being thrown to invoker in case input stream is closed by another thread (which can happen typically in timeouts) Signed-off-by: Flavia Rainone --- .../servlet/spec/ServletInputStreamImpl.java | 87 ++++++++++++------- spotbugs-exclude.xml | 9 ++ 2 files changed, 64 insertions(+), 32 deletions(-) diff --git a/servlet/src/main/java/io/undertow/servlet/spec/ServletInputStreamImpl.java b/servlet/src/main/java/io/undertow/servlet/spec/ServletInputStreamImpl.java index f8eaa452a6..32a0239999 100644 --- a/servlet/src/main/java/io/undertow/servlet/spec/ServletInputStreamImpl.java +++ b/servlet/src/main/java/io/undertow/servlet/spec/ServletInputStreamImpl.java @@ -185,8 +185,7 @@ public int read(final byte[] b, final int off, final int len) throws IOException int copied = Math.min(buffer.remaining(), len); buffer.get(b, off, copied); if (!buffer.hasRemaining()) { - pooled.close(); - pooled = null; + closePoolIfNotNull(); if (listener != null) { readIntoBufferNonBlocking(); } @@ -196,50 +195,83 @@ public int read(final byte[] b, final int off, final int len) throws IOException private void readIntoBuffer() throws IOException { if (pooled == null && !anyAreSet(state, FLAG_FINISHED)) { - pooled = bufferPool.allocate(); - - int res = Channels.readBlocking(channel, pooled.getBuffer()); - pooled.getBuffer().flip(); + final ByteBuffer byteBuffer; + try { + PooledByteBuffer pooled = this.pooled = bufferPool.allocate(); + byteBuffer = pooled.getBuffer(); + } catch (NullPointerException e) { + // check for NPE + FLAG_FINISHED, it indicates a race condition where the buffer was closed and set to null by + // another thread + // instead of paying the price of synchronization, we just ignore the NPE and return, mimicking the code path + // we would follow in case the check in the if statement above returned false + // this is unlikely to happen, it will happen only during timeouts and server shutdowns + if (anyAreSet(state, FLAG_FINISHED)) { + return; + } + throw e; + } + int res = Channels.readBlocking(channel, byteBuffer); + byteBuffer.flip(); if (res == -1) { setFlags(FLAG_FINISHED); - pooled.close(); - pooled = null; + closePoolIfNotNull(); } } } private void readIntoBufferNonBlocking() throws IOException { if (pooled == null && !anyAreSet(state, FLAG_FINISHED)) { - pooled = bufferPool.allocate(); + final ByteBuffer byteBuffer; + try { + PooledByteBuffer pooled = this.pooled = bufferPool.allocate(); + byteBuffer = pooled.getBuffer(); + } catch (NullPointerException e) { + // check for NPE + FLAG_FINISHED, it indicates a race condition where the buffer was closed and set to null by + // another thread + // instead of paying the price of synchronization, we just ignore the NPE and return, mimicking the code path + // we would follow in case the check in the if statement above returned false + // this is unlikely to happen, it will happen only during timeouts and server shutdowns + if (anyAreSet(state, FLAG_FINISHED)) { + return; + } + throw e; + } if (listener == null) { - int res = channel.read(pooled.getBuffer()); - if (res == 0) { - pooled.close(); - pooled = null; + int res = channel.read(byteBuffer); + if (res == 0 && pooled != null) { + closePoolIfNotNull(); return; } pooled.getBuffer().flip(); if (res == -1) { setFlags(FLAG_FINISHED); - pooled.close(); - pooled = null; + closePoolIfNotNull(); } } else { - int res = channel.read(pooled.getBuffer()); + int res = channel.read(byteBuffer); pooled.getBuffer().flip(); if (res == -1) { setFlags(FLAG_FINISHED); - pooled.close(); - pooled = null; + closePoolIfNotNull(); } else if (res == 0) { clearFlags(FLAG_READY); - pooled.close(); - pooled = null; + closePoolIfNotNull(); } } } } + private void closePoolIfNotNull() { + try { + if (pooled != null) { + pooled.close(); + pooled = null; + } + } catch (NullPointerException npe) { + // ignore it, this can happen if reading while another thread shutdown this input stream, caused by a timeout or a jdk shutdown + } + } + @Override public int available() throws IOException { if (anyAreSet(state, FLAG_CLOSED)) { @@ -264,17 +296,11 @@ public void close() throws IOException { try { while (allAreClear(state, FLAG_FINISHED)) { readIntoBuffer(); - if (pooled != null) { - pooled.close(); - pooled = null; - } + closePoolIfNotNull(); // race condition can happen if read bytes reads -1 } } finally { setFlags(FLAG_FINISHED); - if (pooled != null) { - pooled.close(); - pooled = null; - } + closePoolIfNotNull(); channel.shutdownReads(); } } @@ -327,10 +353,7 @@ public void run() { } }); } finally { - if (pooled != null) { - pooled.close(); - pooled = null; - } + closePoolIfNotNull(); IoUtils.safeClose(channel); } } diff --git a/spotbugs-exclude.xml b/spotbugs-exclude.xml index cce71e997e..b386b09876 100644 --- a/spotbugs-exclude.xml +++ b/spotbugs-exclude.xml @@ -928,6 +928,15 @@ + + + + + + + + +