package zitadel

import (
	"context"
	"crypto/rsa"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"sync"
	"time"

	"github.com/golang-jwt/jwt/v4"
)

// jwksCache holds the RSA public keys from the most recent JWKS fetch,
// indexed by key ID. mu guards fetchedAt and keys.
type jwksCache struct {
	mu        sync.RWMutex
	fetchedAt time.Time
	keys      map[string]*rsa.PublicKey
}

// jwksURL returns the configured JWKS URL, or "<issuer>/oauth/v2/keys" when
// none is configured.
func (c *Client) jwksURL() string {
	if c.conf.JWKSUrl != "" {
		return c.conf.JWKSUrl
	}
	return c.issuer + "/oauth/v2/keys"
}

// VerifyIDToken parses idToken, accepting only RS256 signatures, resolves the
// signing key through the token's "kid" header, checks the iss, aud and exp
// claims, and returns a subset of the claims.
//
// It returns ErrNotConfigured for a nil client. Parse, signature and claim
// failures wrap ErrInvalidIDToken; an empty idToken or a missing OAuth client
// ID yields a plain error.
func (c *Client) VerifyIDToken(ctx context.Context, idToken string) (*IDTokenClaims, error) {
	if c == nil {
		return nil, ErrNotConfigured
	}
	if idToken == "" {
		return nil, fmt.Errorf("zitadel: id_token is required")
	}
	if c.conf.OAuthClientID == "" {
		return nil, fmt.Errorf("zitadel: oauth client id is required for id_token verification")
	}

	parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}))
	token, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
		kid, ok := t.Header["kid"].(string)
		if !ok || kid == "" {
			return nil, fmt.Errorf("zitadel: id_token missing kid")
		}
		return c.publicKeyForKID(ctx, kid)
	})
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrInvalidIDToken, err)
	}
	if !token.Valid {
		return nil, ErrInvalidIDToken
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, ErrInvalidIDToken
	}
	if err := c.validateIDTokenClaims(claims); err != nil {
		return nil, err
	}

	out := &IDTokenClaims{
		Sub:           stringClaim(claims, "sub"),
		Email:         stringClaim(claims, "email"),
		EmailVerified: boolClaim(claims, "email_verified"),
		Name:          stringClaim(claims, "name"),
		Locale:        stringClaim(claims, "locale"),
	}
	// A token without a subject cannot identify a user.
	if out.Sub == "" {
		return nil, ErrInvalidIDToken
	}
	return out, nil
}

// validateIDTokenClaims checks that iss equals the client's issuer (with or
// without a trailing slash), that aud contains the configured OAuth client
// ID, and that exp is present, numeric and still in the future. Every failure
// wraps ErrInvalidIDToken.
func (c *Client) validateIDTokenClaims(claims jwt.MapClaims) error {
	iss := stringClaim(claims, "iss")
	if iss != c.issuer && iss != c.issuer+"/" {
		return fmt.Errorf("%w: unexpected iss", ErrInvalidIDToken)
	}
	if !audienceContains(claims["aud"], c.conf.OAuthClientID) {
		return fmt.Errorf("%w: unexpected aud", ErrInvalidIDToken)
	}

	expRaw, ok := claims["exp"]
	if !ok {
		return fmt.Errorf("%w: missing exp", ErrInvalidIDToken)
	}
	var expUnix int64
	switch t := expRaw.(type) {
	case float64:
		expUnix = int64(t)
	case json.Number:
		v, err := t.Int64()
		if err != nil {
			return fmt.Errorf("%w: invalid exp", ErrInvalidIDToken)
		}
		expUnix = v
	default:
		return fmt.Errorf("%w: invalid exp", ErrInvalidIDToken)
	}
	// The token is treated as expired at the exp second itself.
	if time.Now().UTC().Unix() >= expUnix {
		return fmt.Errorf("%w: token expired", ErrInvalidIDToken)
	}
	return nil
}

// publicKeyForKID returns the RSA public key for kid. A cached key is served
// while the cache is younger than five minutes; otherwise the JWKS is fetched
// again and the lookup retried. An error is returned when the refresh fails or
// the refreshed set does not contain kid.
func (c *Client) publicKeyForKID(ctx context.Context, kid string) (*rsa.PublicKey, error) {
	const ttl = 5 * time.Minute

	// NOTE(review): this lazy initialisation writes c.jwks without holding any
	// lock, so it is only safe if c.jwks is set before concurrent use —
	// confirm against the Client constructor.
	if c.jwks == nil {
		c.jwks = &jwksCache{keys: make(map[string]*rsa.PublicKey)}
	}

	// Fast path: shared lock only.
	c.jwks.mu.RLock()
	if key, ok := c.jwks.keys[kid]; ok && time.Since(c.jwks.fetchedAt) < ttl {
		c.jwks.mu.RUnlock()
		return key, nil
	}
	c.jwks.mu.RUnlock()

	c.jwks.mu.Lock()
	defer c.jwks.mu.Unlock()

	// Re-check under the write lock: another goroutine may have refreshed the
	// cache between RUnlock and Lock.
	if key, ok := c.jwks.keys[kid]; ok && time.Since(c.jwks.fetchedAt) < ttl {
		return key, nil
	}
	if err := c.refreshJWKS(ctx); err != nil {
		return nil, err
	}
	key, ok := c.jwks.keys[kid]
	if !ok {
		return nil, fmt.Errorf("zitadel: jwks kid not found: %s", kid)
	}
	return key, nil
}

// refreshJWKS downloads the JWKS document from jwksURL and replaces the cached
// key set with the RSA keys it contains. Entries that are not RSA or lack a
// kid, n or e are skipped; an undecodable RSA entry aborts the refresh. It
// writes c.jwks.keys and c.jwks.fetchedAt without locking; its caller in this
// file holds c.jwks.mu for writing.
func (c *Client) refreshJWKS(ctx context.Context) error {
	// Upper bound on the response body, so a misbehaving endpoint cannot make
	// this call buffer an arbitrarily large payload.
	const maxBody = 1 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.jwksURL(), http.NoBody)
	if err != nil {
		return fmt.Errorf("zitadel: jwks request: %w", err)
	}
	req.Header.Set("Accept", "application/json")

	resp, err := c.http.Do(req)
	if err != nil {
		return fmt.Errorf("zitadel: jwks request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so it can be quoted in the
	// error message for non-200 responses.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, maxBody))
	if err != nil {
		return fmt.Errorf("zitadel: read jwks body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("zitadel: jwks request: status %d: %s", resp.StatusCode, truncateBody(raw))
	}

	var payload struct {
		Keys []struct {
			Kty string `json:"kty"`
			Kid string `json:"kid"`
			N   string `json:"n"`
			E   string `json:"e"`
		} `json:"keys"`
	}
	if err := json.Unmarshal(raw, &payload); err != nil {
		return fmt.Errorf("zitadel: decode jwks: %w", err)
	}

	keys := make(map[string]*rsa.PublicKey, len(payload.Keys))
	for _, k := range payload.Keys {
		if k.Kty != "RSA" || k.Kid == "" || k.N == "" || k.E == "" {
			continue
		}
		pub, err := rsaPublicKeyFromModExp(k.N, k.E)
		if err != nil {
			return err
		}
		keys[k.Kid] = pub
	}
	if len(keys) == 0 {
		return fmt.Errorf("zitadel: jwks contains no usable rsa keys")
	}

	c.jwks.keys = keys
	// Deliberately not .UTC(): that strips the monotonic clock reading, which
	// would make the time.Since TTL check sensitive to wall-clock adjustments.
	c.jwks.fetchedAt = time.Now()
	return nil
}

// rsaPublicKeyFromModExp builds an RSA public key from the unpadded
// base64url-encoded modulus nB64 and exponent eB64. It returns an error when
// either value fails to decode, when the modulus is zero, or when the exponent
// is zero or does not fit in an int.
func rsaPublicKeyFromModExp(nB64, eB64 string) (*rsa.PublicKey, error) {
	const maxInt = int64(^uint(0) >> 1)

	nBytes, err := base64.RawURLEncoding.DecodeString(nB64)
	if err != nil {
		return nil, fmt.Errorf("zitadel: decode jwks n: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(eB64)
	if err != nil {
		return nil, fmt.Errorf("zitadel: decode jwks e: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("zitadel: invalid jwks modulus")
	}

	eBig := new(big.Int).SetBytes(eBytes)
	// Int64 is undefined for values that do not fit, so check the range first
	// instead of inspecting a possibly meaningless result.
	if !eBig.IsInt64() {
		return nil, fmt.Errorf("zitadel: invalid jwks exponent")
	}
	e := eBig.Int64()
	if e <= 0 || e > maxInt {
		return nil, fmt.Errorf("zitadel: invalid jwks exponent")
	}
	return &rsa.PublicKey{N: n, E: int(e)}, nil
}

// stringClaim returns claims[key] as a string. A missing or nil claim yields
// ""; a non-string value is rendered with fmt.Sprint.
func stringClaim(claims jwt.MapClaims, key string) string {
	v, ok := claims[key]
	if !ok || v == nil {
		return ""
	}
	switch t := v.(type) {
	case string:
		return t
	default:
		return fmt.Sprint(t)
	}
}

// boolClaim returns claims[key] as a bool. It accepts a JSON boolean or the
// exact string "true"; anything else, including a missing claim, is false.
func boolClaim(claims jwt.MapClaims, key string) bool {
	v, ok := claims[key]
	if !ok || v == nil {
		return false
	}
	switch t := v.(type) {
	case bool:
		return t
	case string:
		return t == "true"
	default:
		return false
	}
}

// audienceContains reports whether aud, given as a single string, a []any or
// a []string, contains want. Any other type yields false.
func audienceContains(aud any, want string) bool {
	switch t := aud.(type) {
	case string:
		return t == want
	case []any:
		for _, item := range t {
			if s, ok := item.(string); ok && s == want {
				return true
			}
		}
	case []string:
		for _, s := range t {
			if s == want {
				return true
			}
		}
	}
	return false
}