Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ public Mono<Void> notifyClients(String method, Object params) {
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
final McpTransportContext transportContext = this.contextExtractor.extract(request);

String requestURI = request.getRequestURI();
if (!requestURI.endsWith(sseEndpoint)) {
Expand Down Expand Up @@ -239,7 +240,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
writer);

// Create a new session using the session factory
McpServerSession session = sessionFactory.create(sessionTransport);
McpServerSession session = sessionFactory.create(transportContext, sessionTransport);
this.sessions.put(sessionId, session);

// Send initial endpoint event
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
new TypeRef<McpSchema.InitializeRequest>() {
});
McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
.startSession(initializeRequest);
.startSession(transportContext, initializeRequest);
this.sessions.put(init.session().getId(), init.session());

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public List<String> protocolVersions() {
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
// Create a single session for the stdio connection
var transport = new StdioMcpSessionTransport();
this.session = sessionFactory.create(transport);
this.session = sessionFactory.create(null, transport);
transport.initProcessing();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package io.modelcontextprotocol.spec;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpNotificationHandler;
import io.modelcontextprotocol.server.McpRequestHandler;

Expand Down Expand Up @@ -45,11 +46,33 @@ public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout,

@Override
public McpStreamableServerSession.McpStreamableServerSessionInit startSession(
McpSchema.InitializeRequest initializeRequest) {
final McpSchema.InitializeRequest initializeRequest) {
final String sessionId = generateSessionId(null, initializeRequest);
return new McpStreamableServerSession.McpStreamableServerSessionInit(
new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(),
new McpStreamableServerSession(sessionId, initializeRequest.capabilities(),
initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers),
this.initRequestHandler.handle(initializeRequest));
}

@Override
public McpStreamableServerSession.McpStreamableServerSessionInit startSession(
final McpTransportContext mcpTransportContext, final McpSchema.InitializeRequest initializeRequest) {
final String sessionId = generateSessionId(mcpTransportContext, initializeRequest);
return new McpStreamableServerSession.McpStreamableServerSessionInit(
new McpStreamableServerSession(sessionId, initializeRequest.capabilities(),
initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers),
this.initRequestHandler.handle(initializeRequest));
}

/**
* An extensibility point to generate session IDs differently.
* @param mcpTransportContext transport context
* @param initializeRequest initialization request
* @return generated session ID
*/
protected String generateSessionId(McpTransportContext mcpTransportContext,
McpSchema.InitializeRequest initializeRequest) {
return UUID.randomUUID().toString();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,16 @@ public interface Factory {
*/
McpServerSession create(McpServerTransport sessionTransport);

/**
* Creates a new 1:1 representation of the client-server interaction.
* @param mcpTransportContext the transport context associated with the client.
* @param sessionTransport the transport to use for communication with the client.
* @return a new server session.
*/
default McpServerSession create(McpTransportContext mcpTransportContext, McpServerTransport sessionTransport) {
return create(sessionTransport);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,20 @@ public interface Factory {
* @param initializeRequest the initialization request from the client
* @return a composite allowing the session to start
*/
@Deprecated
McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest);

/**
* Given an initialize request, create a composite for the session initialization
* @param mcpTransportContext the transport context for the initialization request
* @param initializeRequest the initialization request from the client
* @return a composite allowing the session to start
*/
default McpStreamableServerSessionInit startSession(McpTransportContext mcpTransportContext,
McpSchema.InitializeRequest initializeRequest) {
return startSession(initializeRequest);
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public MockMcpServerTransport getTransport() {
@Override
public void setSessionFactory(Factory sessionFactory) {

session = sessionFactory.create(transport);
session = sessionFactory.create(null, transport);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
Expand Down Expand Up @@ -67,7 +68,8 @@ void setUp() {
sessionFactory = mock(McpServerSession.Factory.class);

// Configure mock behavior
when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession);
when(sessionFactory.create(any(McpTransportContext.class), any(McpServerTransport.class)))
.thenReturn(mockSession);
when(mockSession.closeGracefully()).thenReturn(Mono.empty());
when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
.body(Flux.<ServerSentEvent<?>>create(sink -> {
WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink);

McpServerSession session = sessionFactory.create(sessionTransport);
McpServerSession session = sessionFactory.create(transportContext, sessionTransport);
String sessionId = session.getId();

logger.debug("Created new SSE connection for session: {}", sessionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(),
typeReference);
McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
.startSession(initializeRequest);
.startSession(transportContext, initializeRequest);
sessions.put(init.session().getId(), init.session());
return init.initResult().map(initializeResult -> {
McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}

McpTransportContext mcpTransportContext = this.contextExtractor.extract(request);
String sessionId = UUID.randomUUID().toString();
logger.debug("Creating new SSE connection for session: {}", sessionId);

Expand All @@ -271,7 +272,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
});

WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder);
McpServerSession session = sessionFactory.create(sessionTransport);
McpServerSession session = sessionFactory.create(mcpTransportContext, sessionTransport);
this.sessions.put(sessionId, session);

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ private ServerResponse handlePost(ServerRequest request) {
new TypeRef<McpSchema.InitializeRequest>() {
});
McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
.startSession(initializeRequest);
.startSession(transportContext, initializeRequest);
this.sessions.put(init.session().getId(), init.session());

try {
Expand Down