From ea195161094a47e480668c1adf033e4e30ddb2b2 Mon Sep 17 00:00:00 2001 From: Mathias Bergqvist Date: Mon, 4 May 2026 20:41:21 +0200 Subject: [PATCH] feat(mcp): origin allowlist middleware --- internal/mcp/origin.go | 27 ++++++++++++++++++++++ internal/mcp/origin_test.go | 45 +++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 internal/mcp/origin.go create mode 100644 internal/mcp/origin_test.go diff --git a/internal/mcp/origin.go b/internal/mcp/origin.go new file mode 100644 index 0000000..3c4f716 --- /dev/null +++ b/internal/mcp/origin.go @@ -0,0 +1,27 @@ +package mcp + +import "net/http" + +// OriginAllowlist returns middleware that rejects requests whose Origin header +// is not in the allowlist. Empty Origin (e.g. server-side curl) is allowed +// because Origin is browser-only by design. +func OriginAllowlist(allowed []string) func(http.Handler) http.Handler { + set := make(map[string]struct{}, len(allowed)) + for _, a := range allowed { + set[a] = struct{}{} + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin == "" { + next.ServeHTTP(w, r) + return + } + if _, ok := set[origin]; !ok { + http.Error(w, "origin not allowed", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/mcp/origin_test.go b/internal/mcp/origin_test.go new file mode 100644 index 0000000..d162b5b --- /dev/null +++ b/internal/mcp/origin_test.go @@ -0,0 +1,45 @@ +package mcp_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "gitea.d-ma.be/mathias/gitea-mcp/internal/mcp" + "github.com/stretchr/testify/assert" +) + +func TestOriginAllowlist(t *testing.T) { + allow := []string{"https://claude.ai", "https://api.anthropic.com"} + called := false + h := mcp.OriginAllowlist(allow)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + cases := []struct { + name string + origin string + wantCode int + wantCalled bool + }{ + {"allowed", "https://claude.ai", 200, true}, + {"allowed-2", "https://api.anthropic.com", 200, true}, + {"forbidden", "https://evil.example", 403, false}, + {"empty allowed (server-side caller)", "", 200, true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + called = false + req := httptest.NewRequest(http.MethodPost, "/", nil) + if tc.origin != "" { + req.Header.Set("Origin", tc.origin) + } + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + assert.Equal(t, tc.wantCode, rr.Code) + assert.Equal(t, tc.wantCalled, called) + }) + } +}