diff --git a/mcp/streamable.go b/mcp/streamable.go index da135374..b8e36553 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -260,284 +260,348 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } if h.opts.CrossOriginProtection != nil { - // Verify the 'Origin' header to protect against CSRF attacks. if err := h.opts.CrossOriginProtection.Check(req); err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } } - // Validate 'Content-Type' header. - if req.Method == http.MethodPost && baseMediaType(req.Header.Get("Content-Type")) != "application/json" { + // [§2.7] of the spec (2025-06-18): validate the MCP-Protocol-Version + // header. If provided, it must be a supported version. If absent, the + // version is unknown (the request may be an initialize for any version). + // + // [§2.7]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion != "" && !slices.Contains(supportedProtocolVersions, protocolVersion) { + http.Error(w, fmt.Sprintf("Bad Request: Unsupported protocol version (supported versions: %s)", strings.Join(supportedProtocolVersions, ",")), http.StatusBadRequest) + return + } + req = req.WithContext(context.WithValue(req.Context(), protocolVersionContextKey{}, protocolVersion)) + + if h.opts.Stateless { + h.serveStateless(w, req) + } else { + h.serveStateful(w, req) + } +} + +// serveStateless handles requests for stateless servers. +// Stateless servers only support POST. Each request creates a temporary +// session that is closed when the request completes. +func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + // RFC 9110 §15.5.6: 405 responses MUST include Allow header. + w.Header().Set("Allow", "POST") + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + if baseMediaType(req.Header.Get("Content-Type")) != "application/json" { http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) return } - // Allow multiple 'Accept' headers. - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax + // Accept must contain both 'application/json' and 'text/event-stream'. jsonOK, streamOK := streamableAccepts(req.Header.Values("Accept")) - - if req.Method == http.MethodGet { - if !streamOK { - http.Error(w, "Accept must contain 'text/event-stream' for GET requests", http.StatusBadRequest) - return - } - } else if (!jsonOK || !streamOK) && req.Method != http.MethodDelete { // TODO: consolidate with handling of http method below. + if !jsonOK || !streamOK { http.Error(w, "Accept must contain both 'application/json' and 'text/event-stream'", http.StatusBadRequest) return } + server := h.getServer(req) + if server == nil { + http.Error(w, "no server available", http.StatusBadRequest) + return + } + + // In stateless mode, the client may provide a session ID for application- + // level state correlation. If absent, generate one. sessionID := req.Header.Get(sessionIDHeader) - var sessInfo *sessionInfo - if sessionID != "" { - h.mu.Lock() - sessInfo = h.sessions[sessionID] - h.mu.Unlock() - if sessInfo == nil && !h.opts.Stateless { - // Unless we're in 'stateless' mode, which doesn't perform any Session-ID - // validation, we require that the session ID matches a known session. - // - // In stateless mode, a temporary transport is be created below. - http.Error(w, "session not found", http.StatusNotFound) - return - } - // Prevent session hijacking: if the session was created with a user ID, - // verify that subsequent requests come from the same user. - if sessInfo != nil && sessInfo.userID != "" { - tokenInfo := auth.TokenInfoFromContext(req.Context()) - if tokenInfo == nil || tokenInfo.UserID != sessInfo.userID { - http.Error(w, "session user mismatch", http.StatusForbidden) - return + if sessionID == "" { + sessionID = server.opts.GetSessionID() + } + + transport := &StreamableServerTransport{ + SessionID: sessionID, + Stateless: true, + EventStore: h.opts.EventStore, + jsonResponse: h.opts.JSONResponse, + logger: h.opts.Logger, + } + + connectOpts, err := h.ephemeralConnectOpts(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + session, err := connectStreamable(req.Context(), server, transport, connectOpts) + if err != nil { + h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) + http.Error(w, "failed connection", http.StatusInternalServerError) + return + } + defer session.Close() + + transport.ServeHTTP(w, req) +} + +// ephemeralConnectOpts peeks at the request body to determine whether it +// contains an initialize or initialized message. If not, default session state +// is constructed so that the session doesn't reject the request. +// It is used for both stateless servers and stateful servers with no session ID. +func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*ServerSessionOptions, error) { + protocolVersion := protocolVersionFromContext(req.Context()) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + + var hasInitialize, hasInitialized bool + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("failed to read body") + } + req.Body.Close() + req.Body = io.NopCloser(bytes.NewBuffer(body)) + msgs, _, err := readBatch(body) + if err == nil { + for _, msg := range msgs { + if r, ok := msg.(*jsonrpc.Request); ok { + switch r.Method { + case methodInitialize: + hasInitialize = true + case notificationInitialized: + hasInitialized = true + } } } } - if req.Method == http.MethodDelete { - if sessionID == "" { - http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) - return - } - if sessInfo != nil { // sessInfo may be nil in stateless mode - // Closing the session also removes it from h.sessions, due to the - // onClose callback. - sessInfo.session.Close() + state := new(ServerSessionState) + if !hasInitialize { + state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion, } - w.WriteHeader(http.StatusNoContent) - return } + if !hasInitialized { + state.InitializedParams = new(InitializedParams) + } + state.LogLevel = "info" + return &ServerSessionOptions{ + State: state, + }, nil +} + +func connectStreamable(ctx context.Context, server *Server, transport *StreamableServerTransport, opts *ServerSessionOptions) (*ServerSession, error) { + s, err := server.Connect(ctx, transport, opts) + if err != nil { + return nil, err + } + transport.connection.toolLookup = server.getServerTool + return s, nil +} +// serveStateful handles requests for stateful servers. +// Stateful servers support GET, POST, and DELETE, and maintain persistent +// sessions keyed by session ID. +func (h *StreamableHTTPHandler) serveStateful(w http.ResponseWriter, req *http.Request) { switch req.Method { - case http.MethodPost, http.MethodGet: - if req.Method == http.MethodGet && (h.opts.Stateless || sessionID == "") { - if h.opts.Stateless { - // Per MCP spec: server MUST return 405 if it doesn't offer SSE stream. - // In stateless mode, GET (SSE streaming) is not supported. - // RFC 9110 §15.5.6: 405 responses MUST include Allow header. - w.Header().Set("Allow", "POST") - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - } else { - // In stateful mode, GET is supported but requires a session ID. - // This is a precondition error, similar to DELETE without session. - http.Error(w, "Bad Request: GET requires an Mcp-Session-Id header", http.StatusBadRequest) - } - return - } + case http.MethodGet: + h.serveStatefulGET(w, req) + case http.MethodPost: + h.serveStatefulPOST(w, req) + case http.MethodDelete: + h.serveStatefulDELETE(w, req) default: // RFC 9110 §15.5.6: 405 responses MUST include Allow header. - if h.opts.Stateless { - w.Header().Set("Allow", "POST") - } else { - w.Header().Set("Allow", "GET, POST, DELETE") - } + w.Header().Set("Allow", "GET, POST, DELETE") http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + } +} + +// lookupSession looks up a session by the Mcp-Session-Id header value. +// It returns the session info and whether the caller should proceed. If ok is +// false, an error response has been written. The sessionID must be non-empty; +// callers are responsible for checking this before calling lookupSession. +func (h *StreamableHTTPHandler) lookupSession(w http.ResponseWriter, req *http.Request, sessionID string) (info *sessionInfo, ok bool) { + h.mu.Lock() + info = h.sessions[sessionID] + h.mu.Unlock() + if info == nil { + http.Error(w, "session not found", http.StatusNotFound) + return nil, false + } + if info.userID != "" { + tokenInfo := auth.TokenInfoFromContext(req.Context()) + if tokenInfo == nil || tokenInfo.UserID != info.userID { + http.Error(w, "session user mismatch", http.StatusForbidden) + return nil, false + } + } + return info, true +} + +// serveStatefulGET handles GET requests for standalone SSE streams. +// GET requires a valid Mcp-Session-Id header. +func (h *StreamableHTTPHandler) serveStatefulGET(w http.ResponseWriter, req *http.Request) { + if _, streamOK := streamableAccepts(req.Header.Values("Accept")); !streamOK { + http.Error(w, "Accept must contain 'text/event-stream' for GET requests", http.StatusBadRequest) return } - // [§2.7] of the spec (2025-06-18) states: - // - // "If using HTTP, the client MUST include the MCP-Protocol-Version: - // HTTP header on all subsequent requests to the MCP - // server, allowing the MCP server to respond based on the MCP protocol - // version. - // - // For example: MCP-Protocol-Version: 2025-06-18 - // The protocol version sent by the client SHOULD be the one negotiated during - // initialization. - // - // For backwards compatibility, if the server does not receive an - // MCP-Protocol-Version header, and has no other way to identify the version - - // for example, by relying on the protocol version negotiated during - // initialization - the server SHOULD assume protocol version 2025-03-26. - // - // If the server receives a request with an invalid or unsupported - // MCP-Protocol-Version, it MUST respond with 400 Bad Request." - // - // Since this wasn't present in the 2025-03-26 version of the spec, this - // effectively means: - // 1. IF the client provides a version header, it must be a supported - // version. - // 2. In stateless mode, where we've lost the state of the initialize - // request, we assume that whatever the client tells us is the truth (or - // assume 2025-03-26 if the client doesn't say anything). - // - // This logic matches the typescript SDK. - // - // [§2.7]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header - protocolVersion := req.Header.Get(protocolVersionHeader) - if protocolVersion == "" { - protocolVersion = protocolVersion20250326 + sessionID := req.Header.Get(sessionIDHeader) + if sessionID == "" { + http.Error(w, "Bad Request: GET requires an Mcp-Session-Id header", http.StatusBadRequest) + return } - if !slices.Contains(supportedProtocolVersions, protocolVersion) { - http.Error(w, fmt.Sprintf("Bad Request: Unsupported protocol version (supported versions: %s)", strings.Join(supportedProtocolVersions, ",")), http.StatusBadRequest) + + sessInfo, ok := h.lookupSession(w, req, sessionID) + if !ok { + return + } + + sessInfo.transport.ServeHTTP(w, req) +} + +// serveStatefulDELETE handles DELETE requests for session termination. +// DELETE requires a valid Mcp-Session-Id header. +func (h *StreamableHTTPHandler) serveStatefulDELETE(w http.ResponseWriter, req *http.Request) { + sessionID := req.Header.Get(sessionIDHeader) + if sessionID == "" { + http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) + return + } + + sessInfo, ok := h.lookupSession(w, req, sessionID) + if !ok { + return + } + + sessInfo.session.Close() + w.WriteHeader(http.StatusNoContent) +} + +// serveStatefulPOST handles POST requests for stateful servers. +// POST may arrive with or without a Mcp-Session-Id header. Without a session +// ID, a new session is created (this is the normal path for the first +// initialize request). +func (h *StreamableHTTPHandler) serveStatefulPOST(w http.ResponseWriter, req *http.Request) { + if baseMediaType(req.Header.Get("Content-Type")) != "application/json" { + http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) + return + } + + jsonOK, streamOK := streamableAccepts(req.Header.Values("Accept")) + if !jsonOK || !streamOK { + http.Error(w, "Accept must contain both 'application/json' and 'text/event-stream'", http.StatusBadRequest) return } - if sessInfo == nil { - server := h.getServer(req) - if server == nil { - // The getServer argument to NewStreamableHTTPHandler returned nil. - http.Error(w, "no server available", http.StatusBadRequest) + sessionID := req.Header.Get(sessionIDHeader) + + // Look up existing session if a session ID was provided. + if sessionID != "" { + sessInfo, ok := h.lookupSession(w, req, sessionID) + if !ok { return } - if sessionID == "" { - // In stateless mode, sessionID may be nonempty even if there's no - // existing transport. - sessionID = server.opts.GetSessionID() - } - transport := &StreamableServerTransport{ - SessionID: sessionID, - Stateless: h.opts.Stateless, - EventStore: h.opts.EventStore, - jsonResponse: h.opts.JSONResponse, - logger: h.opts.Logger, - } + sessInfo.startPOST() + defer sessInfo.endPOST() + sessInfo.transport.ServeHTTP(w, req) + return + } - // Sessions without a session ID are also stateless: there's no way to - // address them. - stateless := h.opts.Stateless || sessionID == "" - // To support stateless mode, we initialize the session with a default - // state, so that it doesn't reject subsequent requests. - var connectOpts *ServerSessionOptions - if stateless { - // Peek at the body to see if it is initialize or initialized. - // We want those to be handled as usual. - var hasInitialize, hasInitialized bool - { - // TODO: verify that this allows protocol version negotiation for - // stateless servers. - body, err := io.ReadAll(req.Body) - if err != nil { - http.Error(w, "failed to read body", http.StatusBadRequest) - return - } - req.Body.Close() - - // Reset the body so that it can be read later. - req.Body = io.NopCloser(bytes.NewBuffer(body)) - - msgs, _, err := readBatch(body) - if err == nil { - for _, msg := range msgs { - if req, ok := msg.(*jsonrpc.Request); ok { - switch req.Method { - case methodInitialize: - hasInitialize = true - case notificationInitialized: - hasInitialized = true - } - } - } - } - } + // No session ID: create a new session. + server := h.getServer(req) + if server == nil { + http.Error(w, "no server available", http.StatusBadRequest) + return + } + sessionID = server.opts.GetSessionID() - // If we don't have InitializeParams or InitializedParams in the request, - // set the initial state to a default value. - state := new(ServerSessionState) - if !hasInitialize { - state.InitializeParams = &InitializeParams{ - ProtocolVersion: protocolVersion, - } - } - if !hasInitialized { - state.InitializedParams = new(InitializedParams) - } - state.LogLevel = "info" - connectOpts = &ServerSessionOptions{ - State: state, - } - } else { - // Cleanup is only required in stateful mode, as transportation is - // not stored in the map otherwise. - connectOpts = &ServerSessionOptions{ - onClose: func() { - h.mu.Lock() - defer h.mu.Unlock() - if info, ok := h.sessions[transport.SessionID]; ok { - info.stopTimer() - delete(h.sessions, transport.SessionID) - if h.onTransportDeletion != nil { - h.onTransportDeletion(transport.SessionID) - } - } - }, - } - } + transport := &StreamableServerTransport{ + SessionID: sessionID, + Stateless: false, + EventStore: h.opts.EventStore, + jsonResponse: h.opts.JSONResponse, + logger: h.opts.Logger, + } - // Pass req.Context() here, to allow middleware to add context values. - // The context is detached in the jsonrpc2 library when handling the - // long-running stream. - session, err := server.Connect(req.Context(), transport, connectOpts) + // Sessions without a session ID (GetSessionID returned "") are ephemeral: + // there's no way to address them, so they are closed after the request. + // This can happen when ServerOptions.GetSessionID is explicitly set to + // return "" to suppress session IDs entirely. It also covers any request + // that arrives before a session exists (e.g. initialize or ping) on a + // server configured this way. + if sessionID == "" { + connectOpts, err := h.ephemeralConnectOpts(req) if err != nil { - http.Error(w, "failed connection", http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusBadRequest) return } - transport.connection.toolLookup = server.getServerTool - // Capture the user ID from the token info to enable session hijacking - // prevention on subsequent requests. - var userID string - if tokenInfo := auth.TokenInfoFromContext(req.Context()); tokenInfo != nil { - userID = tokenInfo.UserID - } - sessInfo = &sessionInfo{ - session: session, - transport: transport, - userID: userID, + session, err := connectStreamable(req.Context(), server, transport, connectOpts) + if err != nil { + h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) + http.Error(w, "failed connection", http.StatusInternalServerError) + return } + defer session.Close() + transport.ServeHTTP(w, req) + return + } - if stateless { - // Stateless mode: close the session when the request exits. - defer session.Close() // close the fake session after handling the request - } else { - // Otherwise, save the transport so that it can be reused - - // Clean up the session when it times out. - // - // Note that the timer here may fire multiple times, but - // sessInfo.session.Close is idempotent. - if h.opts.SessionTimeout > 0 { - sessInfo.timeout = h.opts.SessionTimeout - sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() { - sessInfo.session.Close() - }) - } + connectOpts := &ServerSessionOptions{ + onClose: func() { h.mu.Lock() - h.sessions[transport.SessionID] = sessInfo - h.mu.Unlock() - defer func() { - // If initialization failed, clean up the session (#578). - if session.InitializeParams() == nil { - // Initialization failed. - session.Close() + defer h.mu.Unlock() + if info, ok := h.sessions[transport.SessionID]; ok { + info.stopTimer() + delete(h.sessions, transport.SessionID) + if h.onTransportDeletion != nil { + h.onTransportDeletion(transport.SessionID) } - }() - } + } + }, } - if req.Method == http.MethodPost { - sessInfo.startPOST() - defer sessInfo.endPOST() + // Pass req.Context() here, to allow middleware to add context values. + // The context is detached in the jsonrpc2 library when handling the + // long-running stream. + session, err := connectStreamable(req.Context(), server, transport, connectOpts) + if err != nil { + h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) + http.Error(w, "failed connection", http.StatusInternalServerError) + return + } + // Capture the user ID from the token info to enable session hijacking + // prevention on subsequent requests. + var userID string + if tokenInfo := auth.TokenInfoFromContext(req.Context()); tokenInfo != nil { + userID = tokenInfo.UserID + } + sessInfo := &sessionInfo{ + session: session, + transport: transport, + userID: userID, } + if h.opts.SessionTimeout > 0 { + sessInfo.timeout = h.opts.SessionTimeout + sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() { + sessInfo.session.Close() + }) + } + h.mu.Lock() + h.sessions[transport.SessionID] = sessInfo + h.mu.Unlock() + defer func() { + // If initialization failed, clean up the session (#578). + if session.InitializeParams() == nil { + session.Close() + } + }() + + sessInfo.startPOST() + defer sessInfo.endPOST() sessInfo.transport.ServeHTTP(w, req) } @@ -901,6 +965,18 @@ func (c *streamableServerConn) newStream(ctx context.Context, requests map[jsonr // ID, we avoid having to make this API decision. type idContextKey struct{} +// protocolVersionContextKey stores the protocol version extracted from the +// MCP-Protocol-Version HTTP header for use by lower layers. +type protocolVersionContextKey struct{} + +// protocolVersionFromContext returns the protocol version from the context, or +// the empty string if not set. An empty string means the version is unknown +// (e.g. the header was absent). +func protocolVersionFromContext(ctx context.Context) string { + v, _ := ctx.Value(protocolVersionContextKey{}).(string) + return v +} + // ServeHTTP handles a single HTTP request for the session. func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { if t.connection == nil { @@ -946,9 +1022,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request ctx := req.Context() - // Read the protocol version from the header. For GET requests, this should - // always be present since GET only happens after initialization. - protocolVersion := req.Header.Get(protocolVersionHeader) + protocolVersion := protocolVersionFromContext(ctx) if protocolVersion == "" { protocolVersion = protocolVersion20250326 } @@ -1141,7 +1215,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } - protocolVersion := req.Header.Get(protocolVersionHeader) + protocolVersion := protocolVersionFromContext(req.Context()) if protocolVersion == "" { protocolVersion = protocolVersion20250326 } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 77eed1b0..d3cd9f8d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -2428,13 +2428,13 @@ func TestStreamable405AllowHeader(t *testing.T) { wantAllow: "POST", }, { - // DELETE without session returns 400 Bad Request (not 405) - // because DELETE is a valid method, just requires a session ID. + // In stateless mode, only POST is supported. + // DELETE returns 405, consistent with Allow: POST. name: "DELETE without session stateless", stateless: true, method: "DELETE", - wantStatus: http.StatusBadRequest, - wantAllow: "", // No Allow header for 400 responses + wantStatus: http.StatusMethodNotAllowed, + wantAllow: "POST", }, }