Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion pkg/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package http
import (
"context"
"errors"
"strings"
"log/slog"
"net/http"
Comment thread
RossTarrant marked this conversation as resolved.
Outdated

ghcontext "github.com/github/github-mcp-server/pkg/context"
"github.com/github/github-mcp-server/pkg/github"
"github.com/github/github-mcp-server/pkg/http/headers"
"github.com/github/github-mcp-server/pkg/http/middleware"
"github.com/github/github-mcp-server/pkg/http/oauth"
"github.com/github/github-mcp-server/pkg/inventory"
Expand Down Expand Up @@ -226,7 +228,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mcpHandler := mcp.NewStreamableHTTPHandler(func(_ *http.Request) *mcp.Server {
return ghServer
}, &mcp.StreamableHTTPOptions{
Stateless: true,
Stateless: true,
CrossOriginProtection: h.config.CrossOriginProtection,
})

mcpHandler.ServeHTTP(w, r)
Expand Down Expand Up @@ -412,3 +415,39 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche

return b
}

// corsAllowHeaders is the precomputed Access-Control-Allow-Headers value.
var corsAllowHeaders = strings.Join([]string{
"Content-Type",
"Mcp-Session-Id",
"Mcp-Protocol-Version",
"Last-Event-ID",
headers.AuthorizationHeader,
headers.MCPReadOnlyHeader,
headers.MCPToolsetsHeader,
headers.MCPToolsHeader,
headers.MCPExcludeToolsHeader,
headers.MCPFeaturesHeader,
headers.MCPLockdownHeader,
headers.MCPInsidersHeader,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was actually going to say define this inside the SetCorsHeaders (before returning the closure), so it's defined when you create the middleware and then used each time.

Also more importantly, we have a middleware package and this should move there.

Copy link
Copy Markdown
Author

@RossTarrant RossTarrant Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the feedback, this logic now sits in cors.go and defines the allowed headers inside SetCorsHeaders like you suggested

}, ", ")

// SetCorsHeaders is middleware that sets CORS headers to allow browser-based
// MCP clients to connect from any origin. This is safe because the server
// authenticates via bearer tokens (not cookies), so cross-origin requests
// cannot exploit ambient credentials.
func SetCorsHeaders(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
w.Header().Set("Access-Control-Max-Age", "86400")
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, WWW-Authenticate")
w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders)

if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
h.ServeHTTP(w, r)
})
}
149 changes: 149 additions & 0 deletions pkg/http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http/httptest"
"slices"
"sort"
"strings"
"testing"

ghcontext "github.com/github/github-mcp-server/pkg/context"
Expand Down Expand Up @@ -660,3 +661,151 @@ func buildStaticInventoryFromTools(cfg *ServerConfig, tools []inventory.ServerTo
ctx := context.Background()
return inv.AvailableTools(ctx), inv.AvailableResourceTemplates(ctx), inv.AvailablePrompts(ctx)
}

func TestSetCorsHeaders(t *testing.T) {
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := SetCorsHeaders(inner)

t.Run("OPTIONS preflight returns 200 with CORS headers", func(t *testing.T) {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
req.Header.Set("Origin", "http://localhost:6274")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "POST")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Authorization")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Content-Type")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Mcp-Session-Id")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Lockdown")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Insiders")
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "Mcp-Session-Id")
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "WWW-Authenticate")
})

t.Run("POST request includes CORS headers", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Origin", "http://localhost:6274")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
})
}

func TestCrossOriginProtection(t *testing.T) {
jsonRPCBody := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1"}}}`

newHandler := func(t *testing.T, crossOriginProtection *http.CrossOriginProtection) http.Handler {
t.Helper()

apiHost, err := utils.NewAPIHost("https://api.githubcopilot.com")
require.NoError(t, err)

handler := NewHTTPMcpHandler(
context.Background(),
&ServerConfig{
Version: "test",
CrossOriginProtection: crossOriginProtection,
},
nil,
translations.NullTranslationHelper,
slog.Default(),
apiHost,
WithInventoryFactory(func(_ *http.Request) (*inventory.Inventory, error) {
return inventory.NewBuilder().Build()
}),
WithGitHubMCPServerFactory(func(_ *http.Request, _ github.ToolDependencies, _ *inventory.Inventory, _ *github.MCPServerConfig) (*mcp.Server, error) {
return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil
}),
WithScopeFetcher(allScopesFetcher{}),
)

r := chi.NewRouter()
handler.RegisterMiddleware(r)
handler.RegisterRoutes(r)
return r
}

tests := []struct {
name string
crossOriginProtection *http.CrossOriginProtection
secFetchSite string
origin string
expectedStatusCode int
}{
{
name: "SDK default rejects cross-site when no bypass configured",
secFetchSite: "cross-site",
origin: "https://evil.example.com",
expectedStatusCode: http.StatusForbidden,
},
{
name: "SDK default allows same-origin request",
secFetchSite: "same-origin",
expectedStatusCode: http.StatusOK,
},
{
name: "SDK default allows request without Sec-Fetch-Site (native client)",
secFetchSite: "",
expectedStatusCode: http.StatusOK,
},
{
name: "bypass protection allows cross-site request",
crossOriginProtection: func() *http.CrossOriginProtection {
p := http.NewCrossOriginProtection()
p.AddInsecureBypassPattern("/")
return p
}(),
secFetchSite: "cross-site",
origin: "https://example.com",
expectedStatusCode: http.StatusOK,
},
{
name: "bypass protection still allows same-origin request",
crossOriginProtection: func() *http.CrossOriginProtection {
p := http.NewCrossOriginProtection()
p.AddInsecureBypassPattern("/")
return p
}(),
secFetchSite: "same-origin",
expectedStatusCode: http.StatusOK,
},
{
name: "bypass protection allows request without Sec-Fetch-Site (native client)",
crossOriginProtection: func() *http.CrossOriginProtection {
p := http.NewCrossOriginProtection()
p.AddInsecureBypassPattern("/")
return p
}(),
secFetchSite: "",
expectedStatusCode: http.StatusOK,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := newHandler(t, tt.crossOriginProtection)

req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonRPCBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
req.Header.Set(headers.AuthorizationHeader, "Bearer github_pat_xyz")
if tt.secFetchSite != "" {
req.Header.Set("Sec-Fetch-Site", tt.secFetchSite)
}
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}

rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

assert.Equal(t, tt.expectedStatusCode, rr.Code, "unexpected status code; body: %s", rr.Body.String())
})
}
}
15 changes: 15 additions & 0 deletions pkg/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ type ServerConfig struct {

// InsidersMode indicates if we should enable experimental features.
InsidersMode bool

// CrossOriginProtection configures the SDK's cross-origin request protection.
// If nil and using RunHTTPServer, cross-origin requests are allowed (auto-bypass).
// If nil and using the handler as a library, the SDK default (reject) applies.
CrossOriginProtection *http.CrossOriginProtection
}

func RunHTTPServer(cfg ServerConfig) error {
Expand Down Expand Up @@ -159,6 +164,14 @@ func RunHTTPServer(cfg ServerConfig) error {
serverOptions = append(serverOptions, WithScopeFetcher(scopeFetcher))
}

// Bypass cross-origin protection: this server uses bearer tokens, not
// cookies, so CSRF checks are unnecessary.
if cfg.CrossOriginProtection == nil {
p := http.NewCrossOriginProtection()
p.AddInsecureBypassPattern("/")
cfg.CrossOriginProtection = p
}

r := chi.NewRouter()
handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...)
oauthHandler, err := oauth.NewAuthHandler(oauthCfg, apiHost)
Expand All @@ -167,6 +180,8 @@ func RunHTTPServer(cfg ServerConfig) error {
}

r.Group(func(r chi.Router) {
r.Use(SetCorsHeaders)

// Register Middleware First, needs to be before route registration
handler.RegisterMiddleware(r)

Expand Down
Loading