diff --git a/.context/mcp.json b/.context/mcp.json index 8da20c9..4b4dd4b 100644 --- a/.context/mcp.json +++ b/.context/mcp.json @@ -16,7 +16,10 @@ }, "infra": { "type": "http", - "url": "https://infra-mcp.d-ma.be/mcp" + "url": "https://infra-mcp.d-ma.be/mcp", + "headers": { + "Authorization": "Bearer ${INFRA_MCP_TOKEN}" + } } } } diff --git a/cmd/gitea-mcp/main.go b/cmd/gitea-mcp/main.go index 5d4032b..e977044 100644 --- a/cmd/gitea-mcp/main.go +++ b/cmd/gitea-mcp/main.go @@ -2,10 +2,12 @@ package main import ( "context" - "encoding/json" "log/slog" "net/http" "os" + "strings" + + chassisauth "gitea.d-ma.be/mathias/mcp-chassis/auth" "gitea.d-ma.be/mathias/gitea-mcp/internal/allowlist" "gitea.d-ma.be/mathias/gitea-mcp/internal/auth" @@ -27,7 +29,7 @@ func main() { ctx := context.Background() - jwtValidator, err := auth.NewJWTValidator(ctx, cfg.DexIssuerURL, cfg.MCPAudience) + jwtValidator, err := chassisauth.NewJWTValidator(ctx, cfg.DexIssuerURL, cfg.MCPAudience) if err != nil { logger.Warn("jwt validator init failed; JWT auth disabled", "err", err) } @@ -78,9 +80,17 @@ func main() { Sessions: mcp.NewSessionStore(), }) + // resourceMetadataURL is only emitted in the WWW-Authenticate challenge + // when both MCPResourceURL and a Dex issuer are wired; empty disables + // the challenge so static-only clients aren't pushed into OAuth discovery. + var resourceMetadataURL string + if cfg.MCPResourceURL != "" && cfg.DexIssuerURL != "" { + resourceMetadataURL = strings.TrimRight(cfg.MCPResourceURL, "/") + "/.well-known/oauth-protected-resource" + } + mux := http.NewServeMux() mux.Handle("/mcp", mcp.OriginAllowlist(cfg.OriginAllowlist)( - auth.BearerMiddleware(jwtValidator, cfg.StaticToken, + chassisauth.BearerMiddleware(cfg.StaticToken, jwtValidator, "gitea", resourceMetadataURL, auth.CallerMiddleware(mcpSrv), ), )) @@ -88,21 +98,10 @@ func main() { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) }) - mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - w.Header().Set("Content-Type", "application/json") - payload := map[string]any{ - "resource": cfg.MCPResourceURL, - "authorization_servers": []string{}, - } - if cfg.DexIssuerURL != "" { - payload["authorization_servers"] = []string{cfg.DexIssuerURL} - } - _ = json.NewEncoder(w).Encode(payload) - }) + if cfg.DexIssuerURL != "" { + mux.HandleFunc("GET /.well-known/oauth-protected-resource", + chassisauth.ProtectedResourceHandler(cfg.MCPResourceURL, cfg.DexIssuerURL)) + } addr := ":" + cfg.Port logger.Info("gitea-mcp starting", "addr", addr, "version", "0.1.0") diff --git a/go.mod b/go.mod index 7a47925..eb7cd6d 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( ) require ( + gitea.d-ma.be/mathias/mcp-chassis v0.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/goccy/go-json v0.10.3 // indirect diff --git a/go.sum b/go.sum index b46ec4d..1b55aba 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +gitea.d-ma.be/mathias/mcp-chassis v0.1.0 h1:8RXO34+n7Vu8HnUMagars6fc4oemqRpMu7MVtjaj4qY= +gitea.d-ma.be/mathias/mcp-chassis v0.1.0/go.mod h1:ajbLlwr2L7FAN3TBU39KucZkKJM02wTbKbDKDEW2YvE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/auth/bearer.go b/internal/auth/bearer.go deleted file mode 100644 index 55d5901..0000000 --- a/internal/auth/bearer.go +++ /dev/null @@ -1,42 +0,0 @@ -package auth - -import ( - "crypto/subtle" - "net/http" - "strings" -) - -// BearerMiddleware authenticates requests via the Authorization header. -// -// A request is allowed when: -// -// 1. The Bearer token is a valid JWT issued by the configured Dex OIDC server, or -// 2. The Bearer token matches staticToken (constant-time compare). -// -// Any other case — including missing or empty Authorization header — returns 401. -// -// The Gitea service PAT is intentionally NOT used to authenticate the caller: -// it is only used by the Gitea client for upstream API calls. Decoupling the -// two prevents the MCP endpoint from being reachable anonymously when a service -// PAT happens to be configured. -func BearerMiddleware(jwtValidator *JWTValidator, staticToken string, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - bearer, hasBearer := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ") - if !hasBearer || bearer == "" { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - - if jwtValidator.Validate(r.Context(), bearer) { - next.ServeHTTP(w, r) - return - } - - if staticToken != "" && subtle.ConstantTimeCompare([]byte(bearer), []byte(staticToken)) == 1 { - next.ServeHTTP(w, r) - return - } - - http.Error(w, "unauthorized", http.StatusUnauthorized) - }) -} diff --git a/internal/auth/bearer_test.go b/internal/auth/bearer_test.go deleted file mode 100644 index ddc1d92..0000000 --- a/internal/auth/bearer_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package auth_test - -import ( - "net/http" - "net/http/httptest" - "testing" - - "gitea.d-ma.be/mathias/gitea-mcp/internal/auth" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func okHandler(called *bool) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - if called != nil { - *called = true - } - w.WriteHeader(http.StatusOK) - }) -} - -func TestBearerMiddleware_NoAuthHeader(t *testing.T) { - srv := httptest.NewServer(auth.BearerMiddleware(nil, "", okHandler(nil))) - defer srv.Close() - - resp, err := http.Post(srv.URL+"/mcp", "application/json", nil) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) -} - -func TestBearerMiddleware_NoAuthHeader_RejectsEvenWhenStaticConfigured(t *testing.T) { - // A configured staticToken must not allow unauthenticated callers through. - srv := httptest.NewServer(auth.BearerMiddleware(nil, "any-static", okHandler(nil))) - defer srv.Close() - - resp, err := http.Post(srv.URL+"/mcp", "application/json", nil) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) -} - -func TestBearerMiddleware_EmptyBearer(t *testing.T) { - srv := httptest.NewServer(auth.BearerMiddleware(nil, "static", okHandler(nil))) - defer srv.Close() - - req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil) - req.Header.Set("Authorization", "Bearer ") - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) -} - -func TestBearerMiddleware_StaticToken_Valid(t *testing.T) { - const staticToken = "my-static-token" - called := false - srv := httptest.NewServer(auth.BearerMiddleware(nil, staticToken, okHandler(&called))) - defer srv.Close() - - req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil) - req.Header.Set("Authorization", "Bearer "+staticToken) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.True(t, called) -} - -func TestBearerMiddleware_StaticToken_Invalid(t *testing.T) { - srv := httptest.NewServer(auth.BearerMiddleware(nil, "correct-token", okHandler(nil))) - defer srv.Close() - - req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil) - req.Header.Set("Authorization", "Bearer wrong-token") - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) -} - -func TestBearerMiddleware_UnknownBearer_NoStatic_NoJWT(t *testing.T) { - srv := httptest.NewServer(auth.BearerMiddleware(nil, "", okHandler(nil))) - defer srv.Close() - - req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil) - req.Header.Set("Authorization", "Bearer random-unknown-token") - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) -} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go deleted file mode 100644 index 94367cc..0000000 --- a/internal/auth/jwt.go +++ /dev/null @@ -1,79 +0,0 @@ -package auth - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jwt" -) - -// JWTValidator validates bearer tokens as JWTs issued by a Dex OIDC server. -// A nil JWTValidator always returns false — JWT validation is disabled. -type JWTValidator struct { - issuer string - aud string - cache *jwk.Cache - jwksURI string -} - -// NewJWTValidator creates a validator by fetching the OIDC discovery document -// from issuerURL. Returns nil, nil when issuerURL is empty (disabled). -func NewJWTValidator(ctx context.Context, issuerURL, audience string) (*JWTValidator, error) { - if issuerURL == "" { - return nil, nil - } - - resp, err := http.Get(issuerURL + "/.well-known/openid-configuration") - if err != nil { - return nil, fmt.Errorf("fetch oidc discovery: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - var doc struct { - JWKSURI string `json:"jwks_uri"` - } - if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { - return nil, fmt.Errorf("decode oidc discovery: %w", err) - } - - cache := jwk.NewCache(ctx) - if err := cache.Register(doc.JWKSURI, jwk.WithRefreshInterval(time.Hour)); err != nil { - return nil, fmt.Errorf("register jwks uri: %w", err) - } - // warm the cache immediately so first request doesn't block - if _, err := cache.Refresh(ctx, doc.JWKSURI); err != nil { - return nil, fmt.Errorf("warm jwks cache: %w", err) - } - - return &JWTValidator{ - issuer: issuerURL, - aud: audience, - cache: cache, - jwksURI: doc.JWKSURI, - }, nil -} - -// Validate returns true if rawToken is a valid JWT signed by the OIDC server. -func (v *JWTValidator) Validate(ctx context.Context, rawToken string) bool { - if v == nil { - return false - } - keySet, err := v.cache.Get(ctx, v.jwksURI) - if err != nil { - return false - } - opts := []jwt.ParseOption{ - jwt.WithKeySet(keySet), - jwt.WithIssuer(v.issuer), - jwt.WithValidate(true), - } - if v.aud != "" { - opts = append(opts, jwt.WithAudience(v.aud)) - } - _, err = jwt.Parse([]byte(rawToken), opts...) - return err == nil -}