refactor(mcp): compose origin allowlist as middleware, remove duplication
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -23,13 +23,12 @@ func main() {
|
|||||||
// Tool registration happens in Phase 6+; for now, registry is empty.
|
// Tool registration happens in Phase 6+; for now, registry is empty.
|
||||||
|
|
||||||
mcpSrv := mcp.NewServer(mcp.ServerOptions{
|
mcpSrv := mcp.NewServer(mcp.ServerOptions{
|
||||||
Registry: reg,
|
Registry: reg,
|
||||||
OriginAllowlist: cfg.OriginAllowlist,
|
Sessions: mcp.NewSessionStore(),
|
||||||
Sessions: mcp.NewSessionStore(),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.Handle("/mcp", mcpSrv)
|
mux.Handle("/mcp", mcp.OriginAllowlist(cfg.OriginAllowlist)(mcpSrv))
|
||||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
_, _ = w.Write([]byte("ok"))
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
|||||||
@@ -11,9 +11,8 @@ import (
|
|||||||
const ProtocolVersion = "2025-06-18"
|
const ProtocolVersion = "2025-06-18"
|
||||||
|
|
||||||
type ServerOptions struct {
|
type ServerOptions struct {
|
||||||
Registry *registry.Registry
|
Registry *registry.Registry
|
||||||
OriginAllowlist []string
|
Sessions *SessionStore
|
||||||
Sessions *SessionStore
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -28,24 +27,6 @@ func NewServer(opts ServerOptions) *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// Origin allowlist (no-op when allowlist empty or Origin missing)
|
|
||||||
if len(s.opts.OriginAllowlist) > 0 {
|
|
||||||
origin := r.Header.Get("Origin")
|
|
||||||
if origin != "" {
|
|
||||||
ok := false
|
|
||||||
for _, a := range s.opts.OriginAllowlist {
|
|
||||||
if a == origin {
|
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "origin not allowed", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
s.handleGET(w, r)
|
s.handleGET(w, r)
|
||||||
|
|||||||
@@ -17,9 +17,8 @@ func newServer(t *testing.T) *mcp.Server {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
reg := registry.New()
|
reg := registry.New()
|
||||||
return mcp.NewServer(mcp.ServerOptions{
|
return mcp.NewServer(mcp.ServerOptions{
|
||||||
Registry: reg,
|
Registry: reg,
|
||||||
OriginAllowlist: nil,
|
Sessions: mcp.NewSessionStore(),
|
||||||
Sessions: mcp.NewSessionStore(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,6 +67,22 @@ func TestPostWithoutSessionRejected(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusBadRequest, rr.Code)
|
require.Equal(t, http.StatusBadRequest, rr.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerWithOriginAllowlistRejectsBadOrigin(t *testing.T) {
|
||||||
|
srv := mcp.OriginAllowlist([]string{"https://claude.ai"})(newServer(t))
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "initialize",
|
||||||
|
"params": map[string]any{"protocolVersion": "2025-06-18"},
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Origin", "https://evil.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
srv.ServeHTTP(rr, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestToolsListAfterInitialize(t *testing.T) {
|
func TestToolsListAfterInitialize(t *testing.T) {
|
||||||
srv := newServer(t)
|
srv := newServer(t)
|
||||||
init := postJSON(t, srv, map[string]any{
|
init := postJSON(t, srv, map[string]any{
|
||||||
|
|||||||
Reference in New Issue
Block a user