Skip to content
5 changes: 5 additions & 0 deletions src/linux/init/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,8 @@ try
_exit(1);
}

channel.SetStrictReplyEnd();

SessionLeaderEntryUtilityVm(channel, Config);
});
}
Expand Down Expand Up @@ -2427,6 +2429,9 @@ Return Value:
PollDescriptors[1].events = POLLIN;
}

// Enable strict reply-end sequencing before receiving any messages from Windows.
channel.SetStrictReplyEnd();

for (;;)
{
auto Result = poll(PollDescriptors.data(), PollDescriptors.size(), -1);
Expand Down
194 changes: 150 additions & 44 deletions src/shared/inc/SocketChannel.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ class SocketChannel
m_exitEvent = std::move(other.m_exitEvent);
#endif
m_ignore_sequence = other.m_ignore_sequence;
m_strict_request_end = other.m_strict_request_end;
m_strict_reply_end = other.m_strict_reply_end;
m_sent_messages = other.m_sent_messages;
m_next_send_increment = other.m_next_send_increment;
m_received_messages = other.m_received_messages;

return *this;
}
Expand All @@ -82,7 +87,7 @@ class SocketChannel
#endif

template <typename TMessage>
void SendMessage(gsl::span<gsl::byte> span)
void SendMessage(gsl::span<gsl::byte> span, uint32_t nextSendIncrement = 1)
{
// Ensure that no other thread is using this channel.
const std::unique_lock<std::mutex> lock{m_sendMutex, std::try_to_lock};
Expand All @@ -103,12 +108,30 @@ class SocketChannel

THROW_INVALID_ARG_IF(m_name == nullptr || span.size() < sizeof(TMessage));

m_sent_messages++;
uint32_t sequenceNumber = 0;
if (m_strict_request_end)
{
std::lock_guard<std::mutex> sequenceLock{m_sequenceMutex};
m_sent_messages += m_next_send_increment;
m_next_send_increment = nextSendIncrement;
sequenceNumber = m_sent_messages;
}
else if (m_strict_reply_end)
{
std::lock_guard<std::mutex> sequenceLock{m_sequenceMutex};
m_sent_messages++;
sequenceNumber = m_sent_messages;
}
else
{
m_sent_messages++;
sequenceNumber = m_sent_messages;
}

auto* header = gslhelpers::try_get_struct<MESSAGE_HEADER>(span);
WI_ASSERT(header->MessageSize == span.size());

header->SequenceNumber = m_sent_messages;
header->SequenceNumber = sequenceNumber;

#ifdef WIN32

Expand Down Expand Up @@ -150,7 +173,7 @@ class SocketChannel
}

template <typename TMessage>
void SendMessage(TMessage& message)
void SendMessage(TMessage& message, uint32_t nextSendIncrement = 1)
{
// Catch situations where the other SendMessage() method should be used
const auto& header = GetMessageHeader(message);
Expand All @@ -164,7 +187,7 @@ class SocketChannel
#endif
}

SendMessage<TMessage>(gslhelpers::struct_as_writeable_bytes(message));
SendMessage<TMessage>(gslhelpers::struct_as_writeable_bytes(message), nextSendIncrement);
}

template <typename TResult>
Expand All @@ -179,7 +202,7 @@ class SocketChannel
}

template <typename TMessage>
std::pair<TMessage*, gsl::span<gsl::byte>> ReceiveMessageOrClosed(TTimeout timeout = DefaultSocketTimeout)
std::pair<TMessage*, gsl::span<gsl::byte>> ReceiveMessageOrClosed(TTimeout timeout = DefaultSocketTimeout, uint32_t expectedOffset = 0)
{
WI_ASSERT(m_name != nullptr);

Expand All @@ -199,53 +222,124 @@ class SocketChannel
#endif
}

m_received_messages++;

auto receivedSpan = ReceiveImpl(TMessage::Type, timeout);
if (receivedSpan.empty())
for (;;)
{

#ifdef WIN32
if (errno == HCS_E_CONNECTION_TIMEOUT)
m_received_messages++;

auto receivedSpan = ReceiveImpl(TMessage::Type, timeout);
if (receivedSpan.empty())
{
THROW_HR_MSG(HCS_E_CONNECTION_TIMEOUT, "Timeout: %d, expected type: %hs, channel: %hs", timeout, ToString(TMessage::Type), m_name);

#ifdef WIN32
if (errno == HCS_E_CONNECTION_TIMEOUT)
{
THROW_HR_MSG(HCS_E_CONNECTION_TIMEOUT, "Timeout: %d, expected type: %hs, channel: %hs", timeout, ToString(TMessage::Type), m_name);
}
#endif

return {nullptr, {}};
}

// Validate sequence
if (!m_ignore_sequence)
{
// Use header only since this could be a stale message of the wrong type.
auto* receivedHeader = gslhelpers::try_get_struct<MESSAGE_HEADER>(receivedSpan);
if (receivedHeader == nullptr)
{
#ifdef WIN32
THROW_HR_MSG(E_UNEXPECTED, "Message too small for header: %zd, channel: %hs", receivedSpan.size(), m_name);
#else
LOG_ERROR("Message too small for header: {}, channel: {}", receivedSpan.size(), m_name);
THROW_ERRNO(EINVAL);
#endif
}

return {nullptr, {}};
}
if (m_strict_request_end)
{
// Skip stale message for strict request end.
std::lock_guard<std::mutex> sequenceLock{m_sequenceMutex};
uint32_t expectedSequence = m_sent_messages + expectedOffset;
auto diff = static_cast<int32_t>(receivedHeader->SequenceNumber - expectedSequence);
if (diff < 0)
{
#ifdef WIN32
WSL_LOG(
"DiscardStaleResponse",
TraceLoggingValue(m_name, "Name"),
TraceLoggingValue(receivedHeader->SequenceNumber, "StaleSeq"),
TraceLoggingValue(expectedSequence, "ExpectedSeq"));
#else
LOG_WARNING("Discard stale response on channel: {}. StaleSeq: {}, ExpectedSeq: {}", m_name, receivedHeader->SequenceNumber, expectedSequence);
#endif
continue;
}

auto* message = gslhelpers::try_get_struct<TMessage>(receivedSpan);
if (diff != 0)
{
#ifdef WIN32
THROW_HR_MSG(E_UNEXPECTED, "Unexpected response sequence: %u, expected: %u, channel: %hs", receivedHeader->SequenceNumber, expectedSequence, m_name);
#else
LOG_ERROR("Unexpected response sequence: {}, expected: {}, channel: {}", receivedHeader->SequenceNumber, expectedSequence, m_name);
THROW_ERRNO(EINVAL);
#endif
}
}
else if (m_strict_reply_end)
{
// The send id must catch up to the latest received id.
std::lock_guard<std::mutex> sequenceLock{m_sequenceMutex};
// - 1: m_sent_messages will be ++ before used.
m_sent_messages = receivedHeader->SequenceNumber - 1;
}
else if (receivedHeader->SequenceNumber != m_received_messages)
{
// Ensure consecutive sequence numbers
#ifdef WIN32
THROW_HR_MSG(E_UNEXPECTED, "Unexpected message sequence: %u, expected: %u, channel: %hs", receivedHeader->SequenceNumber, m_received_messages, m_name);
#else
LOG_ERROR("Unexpected message sequence: {}, expected: {}, channel: {}", receivedHeader->SequenceNumber, m_received_messages, m_name);
THROW_ERRNO(EINVAL);
#endif
}
}

if (message == nullptr)
{
auto* message = gslhelpers::try_get_struct<TMessage>(receivedSpan);

if (message == nullptr)
{
#ifdef WIN32
THROW_HR_MSG(
E_UNEXPECTED, "Message size is too small: %zd, expected type: %hs, channel: %hs", receivedSpan.size(), ToString(TMessage::Type), m_name);
THROW_HR_MSG(
E_UNEXPECTED, "Message size is too small: %zd, expected type: %hs, channel: %hs", receivedSpan.size(), ToString(TMessage::Type), m_name);
#else
LOG_ERROR("MessageSize is too small: {}, expected type: {}, channel: {}", receivedSpan.size(), ToString(TMessage::Type), m_name);
THROW_ERRNO(EINVAL);
LOG_ERROR("MessageSize is too small: {}, expected type: {}, channel: {}", receivedSpan.size(), ToString(TMessage::Type), m_name);
THROW_ERRNO(EINVAL);
#endif
}
}

ValidateMessageHeader(GetMessageHeader(*message), TMessage::Type, m_received_messages);
// Validate type
ValidateMessageHeader(GetMessageHeader(*message), TMessage::Type);

#ifdef WIN32
WSL_LOG(
"ReceivedMessage", TraceLoggingValue(m_name, "Name"), TraceLoggingValue(message->PrettyPrint().c_str(), "Content"));
WSL_LOG(
"ReceivedMessage",
TraceLoggingValue(m_name, "Name"),
TraceLoggingValue(message->PrettyPrint().c_str(), "Content"));
#else
if (LoggingEnabled())
{
LOG_INFO("ReceivedMessage on channel: {}: '{}'", m_name, message->PrettyPrint().c_str());
}
if (LoggingEnabled())
{
LOG_INFO("ReceivedMessage on channel: {}: '{}'", m_name, message->PrettyPrint().c_str());
}
#endif
return {message, receivedSpan};
return {message, receivedSpan};
}
}

template <typename TMessage>
TMessage& ReceiveMessage(gsl::span<gsl::byte>* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout)
TMessage& ReceiveMessage(gsl::span<gsl::byte>* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout, uint32_t expectedOffset = 0)
{
auto [message, span] = ReceiveMessageOrClosed<TMessage>(timeout);
auto [message, span] = ReceiveMessageOrClosed<TMessage>(timeout, expectedOffset);
if (message == nullptr)
{
#ifdef WIN32
Expand Down Expand Up @@ -295,6 +389,18 @@ class SocketChannel
m_ignore_sequence = true;
}

// This end always requests. Without concurrent requests.
void SetStrictRequestEnd()
{
m_strict_request_end = true;
}

// This end always replies. Without concurrent requests.
void SetStrictReplyEnd()
{
m_strict_reply_end = true;
}

#ifndef WIN32

static void EnableSocketLogging(bool enable)
Expand All @@ -321,33 +427,29 @@ class SocketChannel

#endif

void ValidateMessageHeader(const MESSAGE_HEADER& header, LX_MESSAGE_TYPE expected, unsigned int expectedSequence) const
void ValidateMessageHeader(const MESSAGE_HEADER& header, LX_MESSAGE_TYPE expected) const
{
if (header.MessageSize < sizeof(header) || (expected != LxMiniInitMessageAny && header.MessageType != expected) ||
(!m_ignore_sequence && header.SequenceNumber != expectedSequence))

if (header.MessageSize < sizeof(header) || (expected != LxMiniInitMessageAny && header.MessageType != expected))
{
#ifdef WIN32

THROW_HR_MSG(
E_UNEXPECTED,
"Protocol error: Received message size: %u, type: %u, sequence: %u. Expected type: %u, expected sequence: %u, "
"Protocol error: Received message size: %u, type: %u. Expected type: %u, "
"channel: %hs",
header.MessageSize,
header.MessageType,
header.SequenceNumber,
expected,
expectedSequence,
m_name);
#else

LOG_ERROR(
"Protocol error: Received message size: {}, type: {}, sequence: {}. Expected type: {}, expected sequence: {}, "
"channel: %s",
"Protocol error: Received message size: {}, type: {}. Expected type: {}, "
"channel: {}",
header.MessageSize,
header.MessageType,
header.SequenceNumber,
expected,
expectedSequence,
m_name);

THROW_ERRNO(EINVAL);
Expand Down Expand Up @@ -393,10 +495,14 @@ class SocketChannel

#endif
uint32_t m_sent_messages = 0;
uint32_t m_next_send_increment = 1;
uint32_t m_received_messages = 0;
bool m_ignore_sequence = false;
bool m_strict_request_end = false;
bool m_strict_reply_end = false;
const char* m_name{};
std::mutex m_sendMutex;
std::mutex m_receiveMutex;
std::mutex m_sequenceMutex;
};
} // namespace wsl::shared
} // namespace wsl::shared
9 changes: 4 additions & 5 deletions src/windows/service/exe/LxssCreateProcess.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,15 @@ class LxssCreateProcess
wsl::shared::MessageWriter<CREATE_PROCESS_MESSAGE> message(LxInitCreateProcess);
message.WriteString(message->PathIndex, Path);
gsl::copy(as_bytes(gsl::span(ArgumentsData)), message.InsertBuffer(message->CommandLineIndex, ArgumentsData.size()));
channel.SendMessage<CREATE_PROCESS_MESSAGE>(message.Span());
channel.SendMessage<CREATE_PROCESS_MESSAGE>(message.Span(), 2);

auto readResult = [&]() {
const auto& message = channel.ReceiveMessage<RESULT_MESSAGE<int32_t>>(nullptr, Timeout);
return message.Result;
auto readResult = [&](uint32_t expectedOffset = 0) {
return channel.ReceiveMessage<RESULT_MESSAGE<int32_t>>(nullptr, Timeout, expectedOffset).Result;
};

auto processSocket = wsl::windows::common::hvsocket::Connect(RuntimeId, readResult(), terminatingEvent);

const auto execResult = readResult();
const auto execResult = readResult(1);
THROW_HR_IF_MSG(E_FAIL, execResult != 0, "Failed to execute '%hs', error=%d", Path, execResult);

return processSocket;
Expand Down
7 changes: 6 additions & 1 deletion src/windows/service/exe/WslCoreInstance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ void WslCoreInstance::Initialize()
// Create a console manager that will be used to manage session leaders.
m_consoleManager = ConsoleManager::CreateConsoleManager(m_initChannel);

// Enable strict request-end sequencing.
m_initChannel->GetChannel().SetStrictRequestEnd();

// Send the initial configuration information to the init daemon.
ULONG fixedDrives = 0;
if (WI_IsFlagSet(m_configuration.Flags, LXSS_DISTRO_FLAGS_ENABLE_DRIVE_MOUNTING))
Expand Down Expand Up @@ -544,7 +547,9 @@ std::shared_ptr<LxssPort> WslCoreInstance::WslCorePort::CreateSessionLeader(_In_
const auto& response = m_channel.Transaction(message, nullptr, m_socketTimeout);

wil::unique_socket socket = wsl::windows::common::hvsocket::Connect(m_runtimeId, response.Port);
return std::make_shared<WslCorePort>(socket.release(), m_runtimeId, m_socketTimeout);
auto sessionLeader = std::make_shared<WslCorePort>(socket.release(), m_runtimeId, m_socketTimeout);
sessionLeader->GetChannel().SetStrictRequestEnd();
return sessionLeader;
}

void WslCoreInstance::WslCorePort::DisconnectConsole(_In_ HANDLE)
Expand Down
Loading