diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetCore.cs b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetCore.cs index ada9d6e6..cd136922 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetCore.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetCore.cs @@ -184,11 +184,15 @@ public override void Write(byte[] buffer, int offset, int count) public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - return new ValueTask(_owner.FinishWrapNonAppDataAsync(buffer, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise())); + var promise = _owner._asyncWritePromises.TryDequeue(out var queuedPromise) ? queuedPromise : _owner.CapturedContext.NewPromise(); + return new ValueTask(_owner.FinishWrapAsync(buffer, promise)); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - => _owner.FinishWrapNonAppDataAsync(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); + { + var promise = _owner._asyncWritePromises.TryDequeue(out var queuedPromise) ? queuedPromise : _owner.CapturedContext.NewPromise(); + return _owner.FinishWrapAsync(buffer, offset, count, promise); + } } } } diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs index 96f98b51..db83326a 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs @@ -148,7 +148,7 @@ private int ReadFromInput(byte[] destination, int destinationOffset, int destina public override void Write(byte[] buffer, int offset, int count) => _owner.FinishWrap(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.VoidPromise()); public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - => _owner.FinishWrapNonAppDataAsync(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); + => _owner.FinishWrapAsync(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); private static readonly Action s_writeCompleteCallback = (t, s) => HandleChannelWriteComplete(t, s); diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetStandard20.cs b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetStandard20.cs index 688110d7..d219dacc 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetStandard20.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetStandard20.cs @@ -111,7 +111,11 @@ private int ReadFromInput(byte[] destination, int destinationOffset, int destina public override void Write(byte[] buffer, int offset, int count) => _owner.FinishWrap(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - => _owner.FinishWrapNonAppDataAsync(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); + { + var promiseQueue = _owner._asyncWritePromises; + var promise = promiseQueue.Count > 0 ? promiseQueue.Dequeue() : _owner.CapturedContext.NewPromise(); + return _owner.FinishWrapAsync(buffer, offset, count, promise); + } } } } diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.Writer.cs b/src/DotNetty.Handlers/Tls/TlsHandler.Writer.cs index 026b350f..3f09fbd6 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.Writer.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.Writer.cs @@ -23,6 +23,7 @@ namespace DotNetty.Handlers.Tls { using System; + using System.Collections.Generic; using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; @@ -43,6 +44,11 @@ partial class TlsHandler private IPromise _lastContextWritePromise; private volatile int v_wrapDataSize = TlsUtils.MAX_PLAINTEXT_LENGTH; + +#if NETCOREAPP_2_0_GREATER || NETSTANDARD_2_0_GREATER || NETSTANDARD2_0 + private readonly Queue _asyncWritePromises = new Queue(10); + private Task _lastAsyncWrite; +#endif /// /// Gets or Sets the number of bytes to pass to each call. @@ -173,7 +179,18 @@ private void Wrap(IChannelHandlerContext context) _lastContextWritePromise = promise; if (buf.IsReadable()) { +#if NETCOREAPP_2_0_GREATER || NETSTANDARD_2_0_GREATER + var asyncWrite = WriteAsync(buf, promise); + if (!asyncWrite.IsCompleted) + { + var asyncWriteTask = asyncWrite.AsTask(); + asyncWriteTask.Ignore(); + _lastAsyncWrite = asyncWriteTask; + } + buf = null; //prevent buf from releasing synchronously +#else _ = buf.ReadBytes(_sslStream, readableBytes); // this leads to FinishWrap being called 0+ times +#endif } else if (promise != null) { @@ -182,15 +199,12 @@ private void Wrap(IChannelHandlerContext context) } catch (Exception exc) { - promise.TrySetException(exc); - // SslStream has been closed already. - // Any further write attempts should be denied. - _pendingUnencryptedWrites?.ReleaseAndFailAll(exc); + OnWriteFailure(exc, promise); throw; } finally { - buf.Release(); + buf?.Release(); buf = null; promise = null; _lastContextWritePromise = null; @@ -243,7 +257,7 @@ private void FinishWrap(byte[] buffer, int offset, int count, IPromise promise) } #if NETCOREAPP || NETSTANDARD_2_0_GREATER - private Task FinishWrapNonAppDataAsync(in ReadOnlyMemory buffer, IPromise promise) + private Task FinishWrapAsync(in ReadOnlyMemory buffer, IPromise promise) { var capturedContext = CapturedContext; Task future; @@ -260,7 +274,7 @@ private Task FinishWrapNonAppDataAsync(in ReadOnlyMemory buffer, IPromise } #endif - private Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count, IPromise promise) + private Task FinishWrapAsync(byte[] buffer, int offset, int count, IPromise promise) { var capturedContext = CapturedContext; var future = capturedContext.WriteAndFlushAsync(Unpooled.WrappedBuffer(buffer, offset, count), promise); @@ -269,20 +283,52 @@ private Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count, IPr } #if NETCOREAPP || NETSTANDARD_2_0_GREATER - private static async ValueTask LinkOutcome(ValueTask valueTask, IPromise promise) + private async ValueTask WriteAsync(IByteBuffer buf, IPromise promise) { + var lastAsyncWrite = _lastAsyncWrite; + if (lastAsyncWrite != null && !lastAsyncWrite.IsCompletedSuccessfully) + { + try + { + await lastAsyncWrite; + } + catch (Exception ex) + { + //handle failure and propagate to the next pending write + buf.Release(); + promise.TrySetException(ex); + throw; + } + } + try { - await valueTask; - promise.TryComplete(); + _asyncWritePromises.Enqueue(promise); + var mem = buf.GetReadableMemory(); + await _sslStream.WriteAsync(mem, CancellationToken.None); // this leads to FinishWrapAsync being called 0+ times + buf.AdvanceReader(mem.Length); } catch (Exception ex) { - promise.TrySetException(ex); + //handle failure and propagate to the next pending write + OnWriteFailure(ex, promise); + throw; + } + finally + { + buf.Release(); } } #endif + private void OnWriteFailure(Exception ex, IPromise promise) + { + promise.TrySetException(ex); + // SslStream has been closed already. + // Any further write attempts should be denied. + _pendingUnencryptedWrites?.ReleaseAndFailAll(ex); + } + [MethodImpl(MethodImplOptions.NoInlining)] private static InvalidOperationException NewPendingWritesNullException() {