Skip to content

Use handshake timeout for Tls listener callback #62177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Security.Cryptography.X509Certificates;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
using Microsoft.AspNetCore.Server.Kestrel.Https;
using Microsoft.AspNetCore.Server.Kestrel.Https.Internal;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -198,15 +197,6 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConn
listenOptions.IsTls = true;
listenOptions.HttpsOptions = httpsOptions;

if (httpsOptions.TlsClientHelloBytesCallback is not null)
{
listenOptions.Use(next =>
{
var middleware = new TlsListenerMiddleware(next, httpsOptions.TlsClientHelloBytesCallback);
return middleware.OnTlsClientHelloAsync;
});
}

listenOptions.Use(next =>
{
var middleware = new HttpsConnectionMiddleware(next, httpsOptions, listenOptions.Protocols, loggerFactory, metrics);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

Expand Down Expand Up @@ -44,6 +45,9 @@ internal sealed class HttpsConnectionMiddleware
private readonly Func<TlsHandshakeCallbackContext, ValueTask<SslServerAuthenticationOptions>>? _tlsCallbackOptions;
private readonly object? _tlsCallbackOptionsState;

// Captures raw TLS client hello and invokes a user callback if any
private readonly TlsListener? _tlsListener;

// Internal for testing
internal readonly HttpProtocols _httpProtocols;

Expand Down Expand Up @@ -112,6 +116,11 @@ public HttpsConnectionMiddleware(ConnectionDelegate next, HttpsConnectionAdapter
(RemoteCertificateValidationCallback?)null : RemoteCertificateValidationCallback;

_sslStreamFactory = s => new SslStream(s, leaveInnerStreamOpen: false, userCertificateValidationCallback: remoteCertificateValidationCallback);

if (options.TlsClientHelloBytesCallback is not null)
{
_tlsListener = new TlsListener(options.TlsClientHelloBytesCallback);
}
}

internal HttpsConnectionMiddleware(
Expand Down Expand Up @@ -162,6 +171,10 @@ public async Task OnConnectionAsync(ConnectionContext context)
using var cancellationTokenSource = _ctsPool.Rent();
cancellationTokenSource.CancelAfter(_handshakeTimeout);

if (_tlsListener is not null)
{
await _tlsListener.OnTlsClientHelloAsync(context, cancellationTokenSource.Token);
}
if (_tlsCallbackOptions is null)
{
await DoOptionsBasedHandshakeAsync(context, sslStream, feature, cancellationTokenSource.Token);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,27 @@

namespace Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;

internal sealed class TlsListenerMiddleware
internal sealed class TlsListener
{
private readonly ConnectionDelegate _next;
private readonly Action<ConnectionContext, ReadOnlySequence<byte>> _tlsClientHelloBytesCallback;

public TlsListenerMiddleware(ConnectionDelegate next, Action<ConnectionContext, ReadOnlySequence<byte>> tlsClientHelloBytesCallback)
public TlsListener(Action<ConnectionContext, ReadOnlySequence<byte>> tlsClientHelloBytesCallback)
{
_next = next;
_tlsClientHelloBytesCallback = tlsClientHelloBytesCallback;
}

/// <summary>
/// Sniffs the TLS Client Hello message, and invokes a callback if found.
/// </summary>
internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
internal async Task OnTlsClientHelloAsync(ConnectionContext connection, CancellationToken cancellationToken)
{
var input = connection.Transport.Input;
ClientHelloParseState parseState = ClientHelloParseState.NotEnoughData;
short recordLength = -1; // remembers the length of TLS record to not re-parse header on every iteration

while (true)
{
var result = await input.ReadAsync();
var result = await input.ReadAsync(cancellationToken);
var buffer = result.Buffer;

try
Expand All @@ -40,7 +39,7 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
break;
}

parseState = TryParseClientHello(buffer, out var clientHelloBytes);
parseState = TryParseClientHello(buffer, ref recordLength, out var clientHelloBytes);
if (parseState == ClientHelloParseState.NotEnoughData)
{
// if no data will be added, and we still lack enough bytes
Expand Down Expand Up @@ -74,8 +73,6 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
}
}
}

await _next(connection);
}

/// <summary>
Expand All @@ -85,10 +82,25 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
/// TLS 1.2: https://datatracker.ietf.org/doc/html/rfc5246#section-6.2
/// TLS 1.3: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1
/// </summary>
private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte> buffer, out ReadOnlySequence<byte> clientHelloBytes)
private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte> buffer, ref short recordLength, out ReadOnlySequence<byte> clientHelloBytes)
{
clientHelloBytes = default;

// in case bad actor will be sending a TLS client hello one byte at a time
// and we know the expected length of TLS client hello,
// we can check and fail quickly here instead of re-parsing the TLS client hello "header" on each iteration
if (recordLength != -1 && buffer.Length < 5 + recordLength)
{
return ClientHelloParseState.NotEnoughData;
}

// this means we finally got a full tls record, so we can return without parsing again
if (recordLength != -1)
{
clientHelloBytes = buffer.Slice(0, 5 + recordLength);
return ClientHelloParseState.ValidTlsClientHello;
}

if (buffer.Length < 6)
{
return ClientHelloParseState.NotEnoughData;
Expand All @@ -109,7 +121,7 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte>
}

// Record length
if (!reader.TryReadBigEndian(out short recordLength))
if (!reader.TryReadBigEndian(out recordLength))
{
return ClientHelloParseState.NotTlsClientHello;
}
Expand Down
Loading
Loading