package auth

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/lestrrat-go/jwx/v2/jwk"
	"github.com/lestrrat-go/jwx/v2/jwt"
)

// Validator validates Bearer JWTs issued by a Dex (OIDC) authorization server.
// Audience is optional; leave it empty to skip audience validation.
//
// The JWKS cache created by NewValidator refreshes in the background until
// the Validator is closed; call Close when the Validator is no longer needed.
type Validator struct {
	issuer   string
	audience string
	jwksURI  string
	cache    *jwk.Cache
	// stop cancels the context passed to jwk.NewCache, which bounds the
	// lifetime of the cache's background refresher.
	stop context.CancelFunc
}

// discoveryTimeout bounds each network round trip NewValidator performs
// (discovery document, initial JWKS fetch) so startup cannot hang forever on
// an unresponsive issuer.
const discoveryTimeout = 10 * time.Second

// NewValidator fetches the OIDC discovery document for issuerURL, extracts
// jwks_uri, seeds the JWKS cache, and returns a ready Validator.
//
// issuerURL is compared verbatim against each token's "iss" claim, so it must
// be the exact issuer string Dex puts in its tokens. A trailing "/" is stripped
// only when building the discovery URL (OpenID Connect Discovery 1.0 §4).
// When no issuer is configured (e.g. DEX_ISSUER_URL is unset) the caller
// should skip construction entirely; an empty issuerURL is rejected.
//
// Callers should call Close on the returned Validator when done with it.
func NewValidator(issuerURL, audience string) (*Validator, error) {
	if issuerURL == "" {
		return nil, fmt.Errorf("oidc discovery: empty issuer url")
	}

	jwksURI, err := discoverJWKSURI(issuerURL)
	if err != nil {
		return nil, err
	}

	// The cache's background refresher runs until this context is cancelled,
	// so every failure path below must cancel it to avoid leaking it.
	ctx, cancel := context.WithCancel(context.Background())
	cache := jwk.NewCache(ctx)
	if err := cache.Register(jwksURI, jwk.WithMinRefreshInterval(time.Hour)); err != nil {
		cancel()
		return nil, fmt.Errorf("register jwks cache: %w", err)
	}

	refreshCtx, cancelRefresh := context.WithTimeout(ctx, discoveryTimeout)
	_, err = cache.Refresh(refreshCtx, jwksURI)
	cancelRefresh()
	if err != nil {
		cancel()
		return nil, fmt.Errorf("initial jwks fetch: %w", err)
	}

	return &Validator{
		issuer:   issuerURL,
		audience: audience,
		jwksURI:  jwksURI,
		cache:    cache,
		stop:     cancel,
	}, nil
}

// discoverJWKSURI fetches the OIDC discovery document for issuerURL and
// returns its non-empty jwks_uri. The request is bounded by discoveryTimeout.
func discoverJWKSURI(issuerURL string) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
	defer cancel()

	discoveryURL := strings.TrimSuffix(issuerURL, "/") + "/.well-known/openid-configuration"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
	if err != nil {
		return "", fmt.Errorf("fetch oidc discovery: %w", err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("fetch oidc discovery: %w", err)
	}
	defer resp.Body.Close() //nolint:errcheck

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("oidc discovery: status %d", resp.StatusCode)
	}

	var doc struct {
		JWKSURI string `json:"jwks_uri"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
		return "", fmt.Errorf("decode oidc discovery: %w", err)
	}
	if doc.JWKSURI == "" {
		return "", fmt.Errorf("oidc discovery: empty jwks_uri")
	}
	return doc.JWKSURI, nil
}

// Close stops the JWKS cache's background refresher. Calling it more than
// once is harmless (a context.CancelFunc is idempotent); the Validator should
// not be used after Close.
func (v *Validator) Close() {
	if v.stop != nil {
		v.stop()
	}
}
func (v *Validator) Validate(ctx context.Context, rawToken string) (string, error) { keySet, err := v.cache.Get(ctx, v.jwksURI) if err != nil { return "", fmt.Errorf("get jwks: %w", err) } opts := []jwt.ParseOption{ jwt.WithKeySet(keySet), jwt.WithValidate(true), jwt.WithIssuer(v.issuer), } if v.audience != "" { opts = append(opts, jwt.WithAudience(v.audience)) } tok, err := jwt.ParseString(rawToken, opts...) if err != nil { return "", fmt.Errorf("validate jwt: %w", err) } return tok.Subject(), nil }