diff --git a/src/linux/init/init.cpp b/src/linux/init/init.cpp index 94fc3cac6..7e2492f28 100644 --- a/src/linux/init/init.cpp +++ b/src/linux/init/init.cpp @@ -1250,6 +1250,8 @@ try _exit(1); } + channel.SetStrictReplyEnd(); + SessionLeaderEntryUtilityVm(channel, Config); }); } @@ -2427,6 +2429,9 @@ Return Value: PollDescriptors[1].events = POLLIN; } + // Enable strict reply-end sequencing before receiving any messages from Windows. + channel.SetStrictReplyEnd(); + for (;;) { auto Result = poll(PollDescriptors.data(), PollDescriptors.size(), -1); diff --git a/src/shared/inc/SocketChannel.h b/src/shared/inc/SocketChannel.h index 154d3148f..fd145a414 100644 --- a/src/shared/inc/SocketChannel.h +++ b/src/shared/inc/SocketChannel.h @@ -63,6 +63,11 @@ class SocketChannel m_exitEvent = std::move(other.m_exitEvent); #endif m_ignore_sequence = other.m_ignore_sequence; + m_strict_request_end = other.m_strict_request_end; + m_strict_reply_end = other.m_strict_reply_end; + m_sent_messages = other.m_sent_messages; + m_next_send_increment = other.m_next_send_increment; + m_received_messages = other.m_received_messages; return *this; } @@ -82,7 +87,7 @@ class SocketChannel #endif template - void SendMessage(gsl::span span) + void SendMessage(gsl::span span, uint32_t nextSendIncrement = 1) { // Ensure that no other thread is using this channel. const std::unique_lock lock{m_sendMutex, std::try_to_lock}; @@ -103,12 +108,30 @@ class SocketChannel THROW_INVALID_ARG_IF(m_name == nullptr || span.size() < sizeof(TMessage)); - m_sent_messages++; + uint32_t sequenceNumber = 0; + if (m_strict_request_end) + { + std::lock_guard sequenceLock{m_sequenceMutex}; + m_sent_messages += m_next_send_increment; + m_next_send_increment = nextSendIncrement; + sequenceNumber = m_sent_messages; + } + else if (m_strict_reply_end) + { + std::lock_guard sequenceLock{m_sequenceMutex}; + m_sent_messages++; + sequenceNumber = m_sent_messages; + } + else + { + m_sent_messages++; + sequenceNumber = m_sent_messages; + } auto* header = gslhelpers::try_get_struct(span); WI_ASSERT(header->MessageSize == span.size()); - header->SequenceNumber = m_sent_messages; + header->SequenceNumber = sequenceNumber; #ifdef WIN32 @@ -150,7 +173,7 @@ class SocketChannel } template - void SendMessage(TMessage& message) + void SendMessage(TMessage& message, uint32_t nextSendIncrement = 1) { // Catch situations where the other SendMessage() method should be used const auto& header = GetMessageHeader(message); @@ -164,7 +187,7 @@ class SocketChannel #endif } - SendMessage(gslhelpers::struct_as_writeable_bytes(message)); + SendMessage(gslhelpers::struct_as_writeable_bytes(message), nextSendIncrement); } template @@ -179,7 +202,7 @@ class SocketChannel } template - std::pair> ReceiveMessageOrClosed(TTimeout timeout = DefaultSocketTimeout) + std::pair> ReceiveMessageOrClosed(TTimeout timeout = DefaultSocketTimeout, uint32_t expectedOffset = 0) { WI_ASSERT(m_name != nullptr); @@ -199,53 +222,124 @@ class SocketChannel #endif } - m_received_messages++; - - auto receivedSpan = ReceiveImpl(TMessage::Type, timeout); - if (receivedSpan.empty()) + for (;;) { -#ifdef WIN32 - if (errno == HCS_E_CONNECTION_TIMEOUT) + m_received_messages++; + + auto receivedSpan = ReceiveImpl(TMessage::Type, timeout); + if (receivedSpan.empty()) { - THROW_HR_MSG(HCS_E_CONNECTION_TIMEOUT, "Timeout: %d, expected type: %hs, channel: %hs", timeout, ToString(TMessage::Type), m_name); + +#ifdef WIN32 + if (errno == HCS_E_CONNECTION_TIMEOUT) + { + THROW_HR_MSG(HCS_E_CONNECTION_TIMEOUT, "Timeout: %d, expected type: %hs, channel: %hs", timeout, ToString(TMessage::Type), m_name); + } +#endif + + return {nullptr, {}}; } + + // Validate sequence + if (!m_ignore_sequence) + { + // Use header only since this could be a stale message of the wrong type. + auto* receivedHeader = gslhelpers::try_get_struct(receivedSpan); + if (receivedHeader == nullptr) + { +#ifdef WIN32 + THROW_HR_MSG(E_UNEXPECTED, "Message too small for header: %zd, channel: %hs", receivedSpan.size(), m_name); +#else + LOG_ERROR("Message too small for header: {}, channel: {}", receivedSpan.size(), m_name); + THROW_ERRNO(EINVAL); #endif + } - return {nullptr, {}}; - } + if (m_strict_request_end) + { + // Skip stale message for strict request end. + std::lock_guard sequenceLock{m_sequenceMutex}; + uint32_t expectedSequence = m_sent_messages + expectedOffset; + auto diff = static_cast(receivedHeader->SequenceNumber - expectedSequence); + if (diff < 0) + { +#ifdef WIN32 + WSL_LOG( + "DiscardStaleResponse", + TraceLoggingValue(m_name, "Name"), + TraceLoggingValue(receivedHeader->SequenceNumber, "StaleSeq"), + TraceLoggingValue(expectedSequence, "ExpectedSeq")); +#else + LOG_WARNING("Discard stale response on channel: {}. StaleSeq: {}, ExpectedSeq: {}", m_name, receivedHeader->SequenceNumber, expectedSequence); +#endif + continue; + } - auto* message = gslhelpers::try_get_struct(receivedSpan); + if (diff != 0) + { +#ifdef WIN32 + THROW_HR_MSG(E_UNEXPECTED, "Unexpected response sequence: %u, expected: %u, channel: %hs", receivedHeader->SequenceNumber, expectedSequence, m_name); +#else + LOG_ERROR("Unexpected response sequence: {}, expected: {}, channel: {}", receivedHeader->SequenceNumber, expectedSequence, m_name); + THROW_ERRNO(EINVAL); +#endif + } + } + else if (m_strict_reply_end) + { + // The send id must catch up to the latest received id. + std::lock_guard sequenceLock{m_sequenceMutex}; + // - 1: m_sent_messages will be ++ before used. + m_sent_messages = receivedHeader->SequenceNumber - 1; + } + else if (receivedHeader->SequenceNumber != m_received_messages) + { + // Ensure consecutive sequence numbers +#ifdef WIN32 + THROW_HR_MSG(E_UNEXPECTED, "Unexpected message sequence: %u, expected: %u, channel: %hs", receivedHeader->SequenceNumber, m_received_messages, m_name); +#else + LOG_ERROR("Unexpected message sequence: {}, expected: {}, channel: {}", receivedHeader->SequenceNumber, m_received_messages, m_name); + THROW_ERRNO(EINVAL); +#endif + } + } - if (message == nullptr) - { + auto* message = gslhelpers::try_get_struct(receivedSpan); + + if (message == nullptr) + { #ifdef WIN32 - THROW_HR_MSG( - E_UNEXPECTED, "Message size is too small: %zd, expected type: %hs, channel: %hs", receivedSpan.size(), ToString(TMessage::Type), m_name); + THROW_HR_MSG( + E_UNEXPECTED, "Message size is too small: %zd, expected type: %hs, channel: %hs", receivedSpan.size(), ToString(TMessage::Type), m_name); #else - LOG_ERROR("MessageSize is too small: {}, expected type: {}, channel: {}", receivedSpan.size(), ToString(TMessage::Type), m_name); - THROW_ERRNO(EINVAL); + LOG_ERROR("MessageSize is too small: {}, expected type: {}, channel: {}", receivedSpan.size(), ToString(TMessage::Type), m_name); + THROW_ERRNO(EINVAL); #endif - } + } - ValidateMessageHeader(GetMessageHeader(*message), TMessage::Type, m_received_messages); + // Validate type + ValidateMessageHeader(GetMessageHeader(*message), TMessage::Type); #ifdef WIN32 - WSL_LOG( - "ReceivedMessage", TraceLoggingValue(m_name, "Name"), TraceLoggingValue(message->PrettyPrint().c_str(), "Content")); + WSL_LOG( + "ReceivedMessage", + TraceLoggingValue(m_name, "Name"), + TraceLoggingValue(message->PrettyPrint().c_str(), "Content")); #else - if (LoggingEnabled()) - { - LOG_INFO("ReceivedMessage on channel: {}: '{}'", m_name, message->PrettyPrint().c_str()); - } + if (LoggingEnabled()) + { + LOG_INFO("ReceivedMessage on channel: {}: '{}'", m_name, message->PrettyPrint().c_str()); + } #endif - return {message, receivedSpan}; + return {message, receivedSpan}; + } } template - TMessage& ReceiveMessage(gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) + TMessage& ReceiveMessage(gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout, uint32_t expectedOffset = 0) { - auto [message, span] = ReceiveMessageOrClosed(timeout); + auto [message, span] = ReceiveMessageOrClosed(timeout, expectedOffset); if (message == nullptr) { #ifdef WIN32 @@ -295,6 +389,18 @@ class SocketChannel m_ignore_sequence = true; } + // This end always requests. Without concurrent requests. + void SetStrictRequestEnd() + { + m_strict_request_end = true; + } + + // This end always replies. Without concurrent requests. + void SetStrictReplyEnd() + { + m_strict_reply_end = true; + } + #ifndef WIN32 static void EnableSocketLogging(bool enable) @@ -321,33 +427,29 @@ class SocketChannel #endif - void ValidateMessageHeader(const MESSAGE_HEADER& header, LX_MESSAGE_TYPE expected, unsigned int expectedSequence) const + void ValidateMessageHeader(const MESSAGE_HEADER& header, LX_MESSAGE_TYPE expected) const { - if (header.MessageSize < sizeof(header) || (expected != LxMiniInitMessageAny && header.MessageType != expected) || - (!m_ignore_sequence && header.SequenceNumber != expectedSequence)) + + if (header.MessageSize < sizeof(header) || (expected != LxMiniInitMessageAny && header.MessageType != expected)) { #ifdef WIN32 THROW_HR_MSG( E_UNEXPECTED, - "Protocol error: Received message size: %u, type: %u, sequence: %u. Expected type: %u, expected sequence: %u, " + "Protocol error: Received message size: %u, type: %u. Expected type: %u, " "channel: %hs", header.MessageSize, header.MessageType, - header.SequenceNumber, expected, - expectedSequence, m_name); #else LOG_ERROR( - "Protocol error: Received message size: {}, type: {}, sequence: {}. Expected type: {}, expected sequence: {}, " - "channel: %s", + "Protocol error: Received message size: {}, type: {}. Expected type: {}, " + "channel: {}", header.MessageSize, header.MessageType, - header.SequenceNumber, expected, - expectedSequence, m_name); THROW_ERRNO(EINVAL); @@ -393,10 +495,14 @@ class SocketChannel #endif uint32_t m_sent_messages = 0; + uint32_t m_next_send_increment = 1; uint32_t m_received_messages = 0; bool m_ignore_sequence = false; + bool m_strict_request_end = false; + bool m_strict_reply_end = false; const char* m_name{}; std::mutex m_sendMutex; std::mutex m_receiveMutex; + std::mutex m_sequenceMutex; }; -} // namespace wsl::shared \ No newline at end of file +} // namespace wsl::shared diff --git a/src/windows/service/exe/LxssCreateProcess.h b/src/windows/service/exe/LxssCreateProcess.h index 8e1c704a4..e46b8080f 100644 --- a/src/windows/service/exe/LxssCreateProcess.h +++ b/src/windows/service/exe/LxssCreateProcess.h @@ -85,16 +85,15 @@ class LxssCreateProcess wsl::shared::MessageWriter message(LxInitCreateProcess); message.WriteString(message->PathIndex, Path); gsl::copy(as_bytes(gsl::span(ArgumentsData)), message.InsertBuffer(message->CommandLineIndex, ArgumentsData.size())); - channel.SendMessage(message.Span()); + channel.SendMessage(message.Span(), 2); - auto readResult = [&]() { - const auto& message = channel.ReceiveMessage>(nullptr, Timeout); - return message.Result; + auto readResult = [&](uint32_t expectedOffset = 0) { + return channel.ReceiveMessage>(nullptr, Timeout, expectedOffset).Result; }; auto processSocket = wsl::windows::common::hvsocket::Connect(RuntimeId, readResult(), terminatingEvent); - const auto execResult = readResult(); + const auto execResult = readResult(1); THROW_HR_IF_MSG(E_FAIL, execResult != 0, "Failed to execute '%hs', error=%d", Path, execResult); return processSocket; diff --git a/src/windows/service/exe/WslCoreInstance.cpp b/src/windows/service/exe/WslCoreInstance.cpp index 74fe3168d..425e8f3a0 100644 --- a/src/windows/service/exe/WslCoreInstance.cpp +++ b/src/windows/service/exe/WslCoreInstance.cpp @@ -378,6 +378,9 @@ void WslCoreInstance::Initialize() // Create a console manager that will be used to manage session leaders. m_consoleManager = ConsoleManager::CreateConsoleManager(m_initChannel); + // Enable strict request-end sequencing. + m_initChannel->GetChannel().SetStrictRequestEnd(); + // Send the initial configuration information to the init daemon. ULONG fixedDrives = 0; if (WI_IsFlagSet(m_configuration.Flags, LXSS_DISTRO_FLAGS_ENABLE_DRIVE_MOUNTING)) @@ -544,7 +547,9 @@ std::shared_ptr WslCoreInstance::WslCorePort::CreateSessionLeader(_In_ const auto& response = m_channel.Transaction(message, nullptr, m_socketTimeout); wil::unique_socket socket = wsl::windows::common::hvsocket::Connect(m_runtimeId, response.Port); - return std::make_shared(socket.release(), m_runtimeId, m_socketTimeout); + auto sessionLeader = std::make_shared(socket.release(), m_runtimeId, m_socketTimeout); + sessionLeader->GetChannel().SetStrictRequestEnd(); + return sessionLeader; } void WslCoreInstance::WslCorePort::DisconnectConsole(_In_ HANDLE)