Skip to content

Commit 8e75dd3

Browse files
committed
[UNDERTOW-2425] At ServletOutputStreamImpl synchronized workflow (listener = null), prevent the buffer.flip() from not being cleared after an error during attempts to write.
Also, at ServletPrintWriter, verify if no progress is being made when attempting to encode returns overflow after flushing, and mark error even if there are remaining bytes in the buffer. Signed-off-by: Flavia Rainone <[email protected]>
1 parent bf0cfcb commit 8e75dd3

File tree

3 files changed

+147
-100
lines changed

3 files changed

+147
-100
lines changed

servlet/src/main/java/io/undertow/servlet/spec/ServletOutputStreamImpl.java

+133-95
Original file line numberDiff line numberDiff line change
@@ -178,34 +178,14 @@ private void writeTooLargeForBuffer(byte[] b, int off, int len, ByteBuffer buffe
178178
int rem = buffer.remaining();
179179
buffer.put(b, bytesWritten + off, rem);
180180
buffer.flip();
181-
bytesWritten += rem;
182-
int bufferCount = 1;
183-
for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE; ++i) {
184-
PooledByteBuffer pooled = bufferPool.allocate();
185-
pooledBuffers[bufferCount - 1] = pooled;
186-
buffers[bufferCount++] = pooled.getBuffer();
187-
ByteBuffer cb = pooled.getBuffer();
188-
int toWrite = len - bytesWritten;
189-
if (toWrite > cb.remaining()) {
190-
rem = cb.remaining();
191-
cb.put(b, bytesWritten + off, rem);
192-
cb.flip();
193-
bytesWritten += rem;
194-
} else {
195-
cb.put(b, bytesWritten + off, toWrite);
196-
bytesWritten = len;
197-
cb.flip();
198-
break;
199-
}
200-
}
201-
Channels.writeBlocking(channel, buffers, 0, bufferCount);
202-
while (bytesWritten < len) {
203-
//ok, it did not fit, loop and loop and loop until it is done
204-
bufferCount = 0;
205-
for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE + 1; ++i) {
206-
ByteBuffer cb = buffers[i];
207-
cb.clear();
208-
bufferCount++;
181+
try {
182+
bytesWritten += rem;
183+
int bufferCount = 1;
184+
for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE; ++i) {
185+
PooledByteBuffer pooled = bufferPool.allocate();
186+
pooledBuffers[bufferCount - 1] = pooled;
187+
buffers[bufferCount++] = pooled.getBuffer();
188+
ByteBuffer cb = pooled.getBuffer();
209189
int toWrite = len - bytesWritten;
210190
if (toWrite > cb.remaining()) {
211191
rem = cb.remaining();
@@ -219,9 +199,38 @@ private void writeTooLargeForBuffer(byte[] b, int off, int len, ByteBuffer buffe
219199
break;
220200
}
221201
}
222-
Channels.writeBlocking(channel, buffers, 0, bufferCount);
202+
writeBlocking(buffers, 0, bufferCount, bytesWritten);
203+
// at this point, we know that all buffers[i] have 0 bytes remaining(), so it is safe to loop next just
204+
// until we reach len, even if we stop before reaching the end of buffers array
205+
while (bytesWritten < len) {
206+
int oldBytesWritten = bytesWritten;
207+
//ok, it did not fit, loop and loop and loop until it is done
208+
bufferCount = 0;
209+
for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE + 1; ++i) {
210+
ByteBuffer cb = buffers[i];
211+
cb.clear();
212+
bufferCount++;
213+
int toWrite = len - bytesWritten;
214+
if (toWrite > cb.remaining()) {
215+
rem = cb.remaining();
216+
cb.put(b, bytesWritten + off, rem);
217+
cb.flip();
218+
bytesWritten += rem;
219+
} else {
220+
cb.put(b, bytesWritten + off, toWrite);
221+
bytesWritten = len;
222+
cb.flip();
223+
// safe to break, all buffers that come next have zero remaining() bytes and hence
224+
// won't affect the next writeBlocking call
225+
break;
226+
}
227+
}
228+
writeBlocking(buffers, 0, bufferCount, bytesWritten - oldBytesWritten);
229+
}
230+
} finally {
231+
if (buffer != null)
232+
buffer.compact();
223233
}
224-
buffer.clear();
225234
} finally {
226235
for (int i = 0; i < pooledBuffers.length; ++i) {
227236
PooledByteBuffer p = pooledBuffers[i];
@@ -245,29 +254,36 @@ private void writeAsync(byte[] b, int off, int len) throws IOException {
245254
buffer.put(b, off, len);
246255
} else {
247256
buffer.flip();
248-
final ByteBuffer userBuffer = ByteBuffer.wrap(b, off, len);
249-
final ByteBuffer[] bufs = new ByteBuffer[]{buffer, userBuffer};
250-
long toWrite = Buffers.remaining(bufs);
251-
long res;
252-
long written = 0;
253-
createChannel();
254-
setFlags(FLAG_WRITE_STARTED);
255-
do {
256-
res = channel.write(bufs);
257-
written += res;
258-
if (res == 0) {
259-
//write it out with a listener
260-
//but we need to copy any extra data
261-
final ByteBuffer copy = ByteBuffer.allocate(userBuffer.remaining());
262-
copy.put(userBuffer);
263-
copy.flip();
264-
265-
this.buffersToWrite = new ByteBuffer[]{buffer, copy};
266-
clearFlags(FLAG_READY);
267-
return;
257+
boolean clearBuffer = true;
258+
try {
259+
final ByteBuffer userBuffer = ByteBuffer.wrap(b, off, len);
260+
final ByteBuffer[] bufs = new ByteBuffer[]{buffer, userBuffer};
261+
long toWrite = Buffers.remaining(bufs);
262+
long res;
263+
long written = 0;
264+
createChannel();
265+
setFlags(FLAG_WRITE_STARTED);
266+
do {
267+
res = channel.write(bufs);
268+
written += res;
269+
if (res == 0) {
270+
//write it out with a listener
271+
//but we need to copy any extra data
272+
final ByteBuffer copy = ByteBuffer.allocate(userBuffer.remaining());
273+
copy.put(userBuffer);
274+
copy.flip();
275+
276+
this.buffersToWrite = new ByteBuffer[]{buffer, copy};
277+
clearFlags(FLAG_READY);
278+
clearBuffer = false;
279+
return;
280+
}
281+
} while (written < toWrite);
282+
} finally {
283+
if (clearBuffer && buffer != null) {
284+
buffer.compact();
268285
}
269-
} while (written < toWrite);
270-
buffer.clear();
286+
}
271287
}
272288
} finally {
273289
updateWrittenAsync(len);
@@ -296,7 +312,7 @@ public void write(ByteBuffer[] buffers) throws IOException {
296312
if (channel == null) {
297313
channel = servletRequestContext.getExchange().getResponseChannel();
298314
}
299-
Channels.writeBlocking(channel, buffers, 0, buffers.length);
315+
writeBlocking(buffers, 0, buffers.length, len);
300316
setFlags(FLAG_WRITE_STARTED);
301317
} else {
302318
ByteBuffer buffer = buffer();
@@ -307,14 +323,18 @@ public void write(ByteBuffer[] buffers) throws IOException {
307323
channel = servletRequestContext.getExchange().getResponseChannel();
308324
}
309325
if (buffer.position() == 0) {
310-
Channels.writeBlocking(channel, buffers, 0, buffers.length);
326+
writeBlocking(buffers, 0, buffers.length, len);
311327
} else {
312328
final ByteBuffer[] newBuffers = new ByteBuffer[buffers.length + 1];
313329
buffer.flip();
314-
newBuffers[0] = buffer;
315-
System.arraycopy(buffers, 0, newBuffers, 1, buffers.length);
316-
Channels.writeBlocking(channel, newBuffers, 0, newBuffers.length);
317-
buffer.clear();
330+
try {
331+
newBuffers[0] = buffer;
332+
System.arraycopy(buffers, 0, newBuffers, 1, buffers.length);
333+
writeBlocking(newBuffers, 0, newBuffers.length, len + buffer.remaining());
334+
} finally {
335+
if (buffer != null)
336+
buffer.clear();
337+
}
318338
}
319339
setFlags(FLAG_WRITE_STARTED);
320340
}
@@ -333,30 +353,34 @@ public void write(ByteBuffer[] buffers) throws IOException {
333353
} else {
334354
final ByteBuffer[] bufs = new ByteBuffer[buffers.length + 1];
335355
buffer.flip();
336-
bufs[0] = buffer;
337-
System.arraycopy(buffers, 0, bufs, 1, buffers.length);
338-
long toWrite = Buffers.remaining(bufs);
339-
long res;
340-
long written = 0;
341-
createChannel();
342-
setFlags(FLAG_WRITE_STARTED);
343-
do {
344-
res = channel.write(bufs);
345-
written += res;
346-
if (res == 0) {
347-
//write it out with a listener
348-
//but we need to copy any extra data
349-
//TODO: should really allocate from the pool here
350-
final ByteBuffer copy = ByteBuffer.allocate((int) Buffers.remaining(buffers));
351-
Buffers.copy(copy, buffers, 0, buffers.length);
352-
copy.flip();
353-
this.buffersToWrite = new ByteBuffer[]{buffer, copy};
354-
clearFlags(FLAG_READY);
355-
channel.resumeWrites();
356-
return;
357-
}
358-
} while (written < toWrite);
359-
buffer.clear();
356+
try {
357+
bufs[0] = buffer;
358+
System.arraycopy(buffers, 0, bufs, 1, buffers.length);
359+
long toWrite = Buffers.remaining(bufs);
360+
long res;
361+
long written = 0;
362+
createChannel();
363+
setFlags(FLAG_WRITE_STARTED);
364+
do {
365+
res = channel.write(bufs);
366+
written += res;
367+
if (res == 0) {
368+
//write it out with a listener
369+
//but we need to copy any extra data
370+
//TODO: should really allocate from the pool here
371+
final ByteBuffer copy = ByteBuffer.allocate((int) Buffers.remaining(buffers));
372+
Buffers.copy(copy, buffers, 0, buffers.length);
373+
copy.flip();
374+
this.buffersToWrite = new ByteBuffer[] { buffer, copy };
375+
clearFlags(FLAG_READY);
376+
channel.resumeWrites();
377+
return;
378+
}
379+
} while (written < toWrite);
380+
} finally {
381+
if (buffer != null)
382+
buffer.compact();
383+
}
360384
}
361385
} finally {
362386
updateWrittenAsync(len);
@@ -515,14 +539,18 @@ public void flushInternal() throws IOException {
515539
//if the write fails we just compact, rather than changing the ready state
516540
setFlags(FLAG_WRITE_STARTED);
517541
buffer.flip();
518-
long res;
519-
do {
520-
res = channel.write(buffer);
521-
} while (buffer.hasRemaining() && res != 0);
522-
if (!buffer.hasRemaining()) {
523-
channel.flush();
542+
try {
543+
long res;
544+
do {
545+
res = channel.write(buffer);
546+
} while (buffer.hasRemaining() && res != 0);
547+
if (!buffer.hasRemaining()) {
548+
channel.flush();
549+
}
550+
} finally {
551+
if (buffer != null)
552+
buffer.compact();
524553
}
525-
buffer.compact();
526554
}
527555
}
528556

@@ -579,14 +607,18 @@ private void writeBufferBlocking(final boolean writeFinal) throws IOException {
579607
channel = servletRequestContext.getExchange().getResponseChannel();
580608
}
581609
buffer.flip();
582-
while (buffer.hasRemaining()) {
583-
int result = writeFinal ? channel.writeFinal(buffer) : channel.write(buffer);
584-
if (result == 0) {
585-
channel.awaitWritable();
610+
try {
611+
while (buffer.hasRemaining()) {
612+
int result = writeFinal ? channel.writeFinal(buffer) : channel.write(buffer);
613+
if (result == 0) {
614+
channel.awaitWritable();
615+
}
586616
}
617+
} finally {
618+
if (buffer != null)
619+
buffer.compact();
620+
setFlags(FLAG_WRITE_STARTED);
587621
}
588-
buffer.clear();
589-
setFlags(FLAG_WRITE_STARTED);
590622
}
591623

592624
/**
@@ -964,4 +996,10 @@ private void clearFlags(int flags) {
964996
} while (!stateUpdater.compareAndSet(this, old, old & ~flags));
965997
}
966998

999+
private void writeBlocking(ByteBuffer[] buffers, int offs, int len, int bytesToWrite) throws IOException {
1000+
int totalWritten = 0;
1001+
do {
1002+
totalWritten += Channels.writeBlocking(channel, buffers, 0, len);
1003+
} while (totalWritten < bytesToWrite);
1004+
}
9671005
}

servlet/src/main/java/io/undertow/servlet/spec/ServletPrintWriter.java

+7-3
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ public void close() {
103103
underflow = null;
104104
}
105105
if (charsetEncoder != null) {
106+
int remaining = 0;
106107
do {
108+
// before we get the underlying buffer, we need to flush outputStream
107109
ByteBuffer out = outputStream.underlyingBuffer();
108110
if (out == null) {
109111
//servlet output stream has already been closed
@@ -113,11 +115,13 @@ public void close() {
113115
CoderResult result = charsetEncoder.encode(buffer, out, true);
114116
if (result.isOverflow()) {
115117
outputStream.flushInternal();
116-
if (out.remaining() == 0) {
118+
if (out.remaining() == remaining) {
119+
// no progress in flush
117120
outputStream.close();
118121
error = true;
119122
return;
120-
}
123+
} else
124+
remaining = out.remaining();
121125
} else {
122126
done = true;
123127
}
@@ -177,7 +181,7 @@ public void write(final CharBuffer input) {
177181
outputStream.updateWritten(writtenLength);
178182
if (result.isOverflow() || !buffer.hasRemaining()) {
179183
outputStream.flushInternal();
180-
if (!buffer.hasRemaining()) {
184+
if (buffer.remaining() == remaining) {
181185
error = true;
182186
return;
183187
}

spotbugs-exclude.xml

+7-2
Original file line numberDiff line numberDiff line change
@@ -925,8 +925,13 @@
925925
</Match>
926926
<Match>
927927
<Bug pattern="RCN_REDUNDANT_NULLCHECK_WOULD_HAVE_BEEN_A_NPE"/>
928-
<Class name="io.undertow.client.http.HttpClientConnection$ClientReadListener"/>
929-
<Method name="handleEvent"/>
928+
<Or>
929+
<And>
930+
<Class name="io.undertow.client.http.HttpClientConnection$ClientReadListener"/>
931+
<Method name="handleEvent"/>
932+
</And>
933+
<Class name="io.undertow.servlet.spec.ServletOutputStreamImpl"/>
934+
</Or>
930935
</Match>
931936
<!-- ignore benchmarks -->
932937
<Match>

0 commit comments

Comments
 (0)