diff --git a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs index 42d7ac8f0476..32bd1dd59889 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs @@ -5,7 +5,6 @@ using System.Security.Cryptography.X509Certificates; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; -using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; using Microsoft.Extensions.DependencyInjection; @@ -198,15 +197,6 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConn listenOptions.IsTls = true; listenOptions.HttpsOptions = httpsOptions; - if (httpsOptions.TlsClientHelloBytesCallback is not null) - { - listenOptions.Use(next => - { - var middleware = new TlsListenerMiddleware(next, httpsOptions.TlsClientHelloBytesCallback); - return middleware.OnTlsClientHelloAsync; - }); - } - listenOptions.Use(next => { var middleware = new HttpsConnectionMiddleware(next, httpsOptions, listenOptions.Protocols, loggerFactory, metrics); diff --git a/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs index 4508204d7fe7..7f6b3bf1b197 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs @@ -17,6 +17,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -44,6 +45,9 @@ internal sealed class HttpsConnectionMiddleware private readonly Func>? _tlsCallbackOptions; private readonly object? _tlsCallbackOptionsState; + // Captures raw TLS client hello and invokes a user callback if any + private readonly TlsListener? _tlsListener; + // Internal for testing internal readonly HttpProtocols _httpProtocols; @@ -112,6 +116,11 @@ public HttpsConnectionMiddleware(ConnectionDelegate next, HttpsConnectionAdapter (RemoteCertificateValidationCallback?)null : RemoteCertificateValidationCallback; _sslStreamFactory = s => new SslStream(s, leaveInnerStreamOpen: false, userCertificateValidationCallback: remoteCertificateValidationCallback); + + if (options.TlsClientHelloBytesCallback is not null) + { + _tlsListener = new TlsListener(options.TlsClientHelloBytesCallback); + } } internal HttpsConnectionMiddleware( @@ -162,6 +171,10 @@ public async Task OnConnectionAsync(ConnectionContext context) using var cancellationTokenSource = _ctsPool.Rent(); cancellationTokenSource.CancelAfter(_handshakeTimeout); + if (_tlsListener is not null) + { + await _tlsListener.OnTlsClientHelloAsync(context, cancellationTokenSource.Token); + } if (_tlsCallbackOptions is null) { await DoOptionsBasedHandshakeAsync(context, sslStream, feature, cancellationTokenSource.Token); diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListener.cs similarity index 78% rename from src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs rename to src/Servers/Kestrel/Core/src/Middleware/TlsListener.cs index 01bc75553a09..c7207daf05b3 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListener.cs @@ -7,28 +7,27 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; -internal sealed class TlsListenerMiddleware +internal sealed class TlsListener { - private readonly ConnectionDelegate _next; private readonly Action> _tlsClientHelloBytesCallback; - public TlsListenerMiddleware(ConnectionDelegate next, Action> tlsClientHelloBytesCallback) + public TlsListener(Action> tlsClientHelloBytesCallback) { - _next = next; _tlsClientHelloBytesCallback = tlsClientHelloBytesCallback; } /// /// Sniffs the TLS Client Hello message, and invokes a callback if found. /// - internal async Task OnTlsClientHelloAsync(ConnectionContext connection) + internal async Task OnTlsClientHelloAsync(ConnectionContext connection, CancellationToken cancellationToken) { var input = connection.Transport.Input; ClientHelloParseState parseState = ClientHelloParseState.NotEnoughData; + short recordLength = -1; // remembers the length of TLS record to not re-parse header on every iteration while (true) { - var result = await input.ReadAsync(); + var result = await input.ReadAsync(cancellationToken); var buffer = result.Buffer; try @@ -40,7 +39,7 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) break; } - parseState = TryParseClientHello(buffer, out var clientHelloBytes); + parseState = TryParseClientHello(buffer, ref recordLength, out var clientHelloBytes); if (parseState == ClientHelloParseState.NotEnoughData) { // if no data will be added, and we still lack enough bytes @@ -74,8 +73,6 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) } } } - - await _next(connection); } /// @@ -85,10 +82,25 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection) /// TLS 1.2: https://datatracker.ietf.org/doc/html/rfc5246#section-6.2 /// TLS 1.3: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 /// - private static ClientHelloParseState TryParseClientHello(ReadOnlySequence buffer, out ReadOnlySequence clientHelloBytes) + private static ClientHelloParseState TryParseClientHello(ReadOnlySequence buffer, ref short recordLength, out ReadOnlySequence clientHelloBytes) { clientHelloBytes = default; + // in case bad actor will be sending a TLS client hello one byte at a time + // and we know the expected length of TLS client hello, + // we can check and fail quickly here instead of re-parsing the TLS client hello "header" on each iteration + if (recordLength != -1 && buffer.Length < 5 + recordLength) + { + return ClientHelloParseState.NotEnoughData; + } + + // this means we finally got a full tls record, so we can return without parsing again + if (recordLength != -1) + { + clientHelloBytes = buffer.Slice(0, 5 + recordLength); + return ClientHelloParseState.ValidTlsClientHello; + } + if (buffer.Length < 6) { return ClientHelloParseState.NotEnoughData; @@ -109,7 +121,7 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence } // Record length - if (!reader.TryReadBigEndian(out short recordLength)) + if (!reader.TryReadBigEndian(out recordLength)) { return ClientHelloParseState.NotTlsClientHello; } diff --git a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/Core/test/TlsListenerTests.cs similarity index 81% rename from src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs rename to src/Servers/Kestrel/Core/test/TlsListenerTests.cs index ea3103108ff7..4bce89de208d 100644 --- a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/Core/test/TlsListenerTests.cs @@ -24,11 +24,13 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestPlatform; using Moq; +using Xunit.Sdk; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; -public class TlsListenerMiddlewareTests +public class TlsListenerTests { [Theory] [MemberData(nameof(ValidClientHelloData))] @@ -50,6 +52,80 @@ public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List packets) => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, tlsClientHelloCallbackExpected: false); + [Fact] + public async Task RunTlsClientHelloCallbackTest_WithExtraShortLastingToken() + { + var serviceContext = new TestServiceContext(); + + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); + + var tlsClientHelloCallbackInvoked = false; + var listener = new TlsListener((ctx, data) => { tlsClientHelloCallbackInvoked = true; }); + + var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(3)); + + await writer.WriteAsync(new byte[1] { 0x16 }); + await VerifyThrowsAnyAsync( + async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token), + typeof(OperationCanceledException), typeof(TaskCanceledException)); + Assert.False(tlsClientHelloCallbackInvoked); + } + + [Fact] + public async Task RunTlsClientHelloCallbackTest_WithPreCanceledToken() + { + var serviceContext = new TestServiceContext(); + + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); + + var tlsClientHelloCallbackInvoked = false; + var listener = new TlsListener((ctx, data) => { tlsClientHelloCallbackInvoked = true; }); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + await writer.WriteAsync(new byte[1] { 0x16 }); + await VerifyThrowsAnyAsync( + async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token), + typeof(OperationCanceledException), typeof(TaskCanceledException)); + Assert.False(tlsClientHelloCallbackInvoked); + } + + [Fact] + public async Task RunTlsClientHelloCallbackTest_WithPendingCancellation() + { + var serviceContext = new TestServiceContext(); + + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); + + var tlsClientHelloCallbackInvoked = false; + var listener = new TlsListener((ctx, data) => { tlsClientHelloCallbackInvoked = true; }); + + var cts = new CancellationTokenSource(); + await writer.WriteAsync(new byte[1] { 0x16 }); + var listenerTask = listener.OnTlsClientHelloAsync(transportConnection, cts.Token); + await writer.WriteAsync(new byte[2] { 0x03, 0x01 }); + cts.Cancel(); + + await Assert.ThrowsAsync(async () => await listenerTask); + Assert.False(tlsClientHelloCallbackInvoked); + } + [Fact] public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads() { @@ -66,34 +142,21 @@ public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads() var transport = new DuplexPipe(reader, writer); var transportConnection = new DefaultConnectionContext("test", transport, transport); - var nextMiddlewareInvoked = false; var tlsClientHelloCallbackInvoked = false; - - var middleware = new TlsListenerMiddleware( - next: ctx => - { - nextMiddlewareInvoked = true; - var readResult = ctx.Transport.Input.ReadAsync(); - Assert.Equal(5, readResult.Result.Buffer.Length); - - return Task.CompletedTask; - }, - tlsClientHelloBytesCallback: (ctx, data) => - { - tlsClientHelloCallbackInvoked = true; - } - ); + var listener = new TlsListener((ctx, data) => { tlsClientHelloCallbackInvoked = true; }); await writer.WriteAsync(new byte[1] { 0x16 }); - var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection); + var listenerTask = listener.OnTlsClientHelloAsync(transportConnection, CancellationToken.None); await writer.WriteAsync(new byte[2] { 0x03, 0x01 }); await writer.WriteAsync(new byte[2] { 0x00, 0x20 }); await writer.CompleteAsync(); - await middlewareTask; - Assert.True(nextMiddlewareInvoked); + await listenerTask; Assert.False(tlsClientHelloCallbackInvoked); + var readResult = await reader.ReadAsync(); + Assert.Equal(5, readResult.Buffer.Length); + // ensuring that we have read limited number of times Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 4, $"Expected ReadAsync() to happen about 2-4 times. Actually happened {reader.ReadAsyncCounter} times."); @@ -110,23 +173,11 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( var transport = new DuplexPipe(pipe.Reader, writer); var transportConnection = new DefaultConnectionContext("test", transport, transport); - var nextMiddlewareInvokedActual = false; var tlsClientHelloCallbackActual = false; var fullLength = packets.Sum(p => p.Length); - var middleware = new TlsListenerMiddleware( - next: ctx => - { - nextMiddlewareInvokedActual = true; - if (tlsClientHelloCallbackActual) - { - var readResult = ctx.Transport.Input.ReadAsync(); - Assert.Equal(fullLength, readResult.Result.Buffer.Length); - } - - return Task.CompletedTask; - }, + var listener = new TlsListener( tlsClientHelloBytesCallback: (ctx, data) => { tlsClientHelloCallbackActual = true; @@ -139,9 +190,8 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( // write first packet await writer.WriteAsync(packets[0]); - var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection); + var listenerTask = listener.OnTlsClientHelloAsync(transportConnection, CancellationToken.None); - /* It is a race condition (middleware's loop and writes here). * We don't know specifically how many packets will be read by middleware's loop * (possibly there are even 2 packets - the first and all others combined). @@ -154,10 +204,15 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( await writer.WriteAsync(packet); } await writer.CompleteAsync(); - await middlewareTask; + await listenerTask; - Assert.True(nextMiddlewareInvokedActual); Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); + + if (tlsClientHelloCallbackActual) + { + var readResult = await pipe.Reader.ReadAsync(); + Assert.Equal(fullLength, readResult.Buffer.Length); + } } private async Task RunTlsClientHelloCallbackTest( @@ -171,18 +226,9 @@ private async Task RunTlsClientHelloCallbackTest( var transport = new DuplexPipe(pipe.Reader, writer); var transportConnection = new DefaultConnectionContext("test", transport, transport); - var nextMiddlewareInvokedActual = false; var tlsClientHelloCallbackActual = false; - var middleware = new TlsListenerMiddleware( - next: ctx => - { - nextMiddlewareInvokedActual = true; - var readResult = ctx.Transport.Input.ReadAsync(); - Assert.Equal(packetBytes.Length, readResult.Result.Buffer.Length); - - return Task.CompletedTask; - }, + var listener = new TlsListener( tlsClientHelloBytesCallback: (ctx, data) => { tlsClientHelloCallbackActual = true; @@ -197,10 +243,12 @@ private async Task RunTlsClientHelloCallbackTest( await writer.CompleteAsync(); // call middleware and expect a callback - await middleware.OnTlsClientHelloAsync(transportConnection); + await listener.OnTlsClientHelloAsync(transportConnection, CancellationToken.None); - Assert.True(nextMiddlewareInvokedActual); Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); + + var readResult = await pipe.Reader.ReadAsync(); + Assert.Equal(packetBytes.Length, readResult.Buffer.Length); } public static IEnumerable ValidClientHelloData() @@ -575,4 +623,28 @@ public static IEnumerable InvalidClientHelloData_Segmented() _invalidTlsClientHelloHeader, _invalid3BytesMessage, _invalid9BytesMessage, _invalidUnknownProtocolVersion1, _invalidUnknownProtocolVersion2, _invalidIncorrectHandshakeMessageType }; + + static async Task VerifyThrowsAnyAsync(Func code, params Type[] exceptionTypes) + { + if (exceptionTypes == null || exceptionTypes.Length == 0) + { + throw new ArgumentException("At least one exception type must be provided.", nameof(exceptionTypes)); + } + + try + { + await code(); + } + catch (Exception ex) + { + if (exceptionTypes.Any(type => type.IsInstanceOfType(ex))) + { + return; + } + + throw ThrowsException.ForIncorrectExceptionType(exceptionTypes.First(), ex); + } + + throw ThrowsException.ForNoException(exceptionTypes.First()); + } } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerTests.cs similarity index 97% rename from src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs rename to src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerTests.cs index b57ca2405ba4..f91ea27eae8f 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerTests.cs @@ -21,7 +21,7 @@ namespace InMemory.FunctionalTests; -public class TlsListenerMiddlewareTests : TestApplicationErrorLoggerLoggedTest +public class TlsListenerTests : TestApplicationErrorLoggerLoggedTest { private static readonly X509Certificate2 _x509Certificate2 = TestResources.GetTestCertificate();