Skip to content

Commit 0f46a5e

Browse files
committed
feat(streamable_http): elicitation request
1 parent aef7c8d commit 0f46a5e

File tree

2 files changed

+87
-7
lines changed

2 files changed

+87
-7
lines changed

client/client.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JS
478478
return c.handleSamplingRequestTransport(ctx, request)
479479
case string(mcp.MethodElicitationCreate):
480480
return c.handleElicitationRequestTransport(ctx, request)
481+
case string(mcp.MethodPing):
482+
return c.handlePingRequestTransport(ctx, request)
481483
default:
482484
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
483485
}
@@ -579,6 +581,15 @@ func (c *Client) handleElicitationRequestTransport(ctx context.Context, request
579581
return response, nil
580582
}
581583

584+
func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
585+
b, _ := json.Marshal(&mcp.EmptyResult{})
586+
return &transport.JSONRPCResponse{
587+
JSONRPC: mcp.JSONRPC_VERSION,
588+
ID: request.ID,
589+
Result: b,
590+
}, nil
591+
}
592+
582593
func listByPage[T any](
583594
ctx context.Context,
584595
client *Client,

server/streamable_http.go

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
473473
case <-done:
474474
return
475475
}
476+
case elicitationReq := <-session.elicitationRequestChan:
477+
// Send elicitation request to client via SSE
478+
jsonrpcRequest := mcp.JSONRPCRequest{
479+
JSONRPC: "2.0",
480+
ID: mcp.NewRequestId(elicitationReq.requestID),
481+
Request: mcp.Request{
482+
Method: string(mcp.MethodElicitationCreate),
483+
},
484+
Params: elicitationReq.request.Params,
485+
}
486+
select {
487+
case writeChan <- jsonrpcRequest:
488+
case <-done:
489+
return
490+
}
476491
case <-done:
477492
return
478493
}
@@ -612,12 +627,6 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
612627
}
613628
} else if responseMessage.Result != nil {
614629
// Parse result
615-
var result mcp.CreateMessageResult
616-
if err := json.Unmarshal(responseMessage.Result, &result); err != nil {
617-
response.err = fmt.Errorf("failed to parse sampling result: %v", err)
618-
} else {
619-
response.result = &result
620-
}
621630
} else {
622631
response.err = fmt.Errorf("sampling response has neither result nor error")
623632
}
@@ -768,6 +777,13 @@ type samplingResponseItem struct {
768777
err error
769778
}
770779

780+
// Elicitation support types for HTTP transport
781+
type elicitationRequestItem struct {
782+
requestID int64
783+
request mcp.ElicitationRequest
784+
response chan responseItem
785+
}
786+
771787
// streamableHttpSession is a session for streamable-http transport
772788
// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
773789
// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
@@ -780,6 +796,8 @@ type streamableHttpSession struct {
780796

781797
// Sampling support for bidirectional communication
782798
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
799+
elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests
800+
783801
samplingRequests sync.Map // requestID -> pending sampling request context
784802
requestIDCounter atomic.Int64 // for generating unique request IDs
785803
}
@@ -791,6 +809,7 @@ func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, le
791809
tools: toolStore,
792810
logLevels: levels,
793811
samplingRequestChan: make(chan samplingRequestItem, 10),
812+
elicitationRequestChan: make(chan elicitationRequestItem, 10),
794813
}
795814
return s
796815
}
@@ -877,13 +896,63 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
877896
if response.err != nil {
878897
return nil, response.err
879898
}
880-
return response.result, nil
899+
var result mcp.CreateMessageResult
900+
if err := json.Unmarshal(response.result, &result); err != nil {
901+
return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err)
902+
}
903+
return &result, nil
904+
case <-ctx.Done():
905+
return nil, ctx.Err()
906+
}
907+
}
908+
909+
// RequestElicitation implements SessionWithElicitation interface for HTTP transport
910+
func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
911+
// Generate unique request ID
912+
requestID := s.requestIDCounter.Add(1)
913+
914+
// Create response channel for this specific request
915+
responseChan := make(chan responseItem, 1)
916+
917+
// Create the sampling request item
918+
elicitationRequest := elicitationRequestItem{
919+
requestID: requestID,
920+
request: request,
921+
response: responseChan,
922+
}
923+
924+
// Store the pending request
925+
s.requests.Store(requestID, responseChan)
926+
defer s.requests.Delete(requestID)
927+
928+
// Send the sampling request via the channel (non-blocking)
929+
select {
930+
case s.elicitationRequestChan <- elicitationRequest:
931+
// Request queued successfully
932+
case <-ctx.Done():
933+
return nil, ctx.Err()
934+
default:
935+
return nil, fmt.Errorf("elicitation request queue is full - server overloaded")
936+
}
937+
938+
// Wait for response or context cancellation
939+
select {
940+
case response := <-responseChan:
941+
if response.err != nil {
942+
return nil, response.err
943+
}
944+
var result mcp.ElicitationResult
945+
if err := json.Unmarshal(response.result, &result); err != nil {
946+
return nil, fmt.Errorf("failed to unmarshal elicitation response: %v", err)
947+
}
948+
return &result, nil
881949
case <-ctx.Done():
882950
return nil, ctx.Err()
883951
}
884952
}
885953

886954
var _ SessionWithSampling = (*streamableHttpSession)(nil)
955+
var _ SessionWithElicitation = (*streamableHttpSession)(nil)
887956

888957
// --- session id manager ---
889958

0 commit comments

Comments
 (0)