// File: template-monorepo/internal/library/zitadel/jwks.go

package zitadel
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"sync"
"time"
"github.com/golang-jwt/jwt/v4"
)
// jwksCache holds the RSA verification keys most recently fetched from the
// issuer's JWKS endpoint, indexed by key ID ("kid").
//
// mu guards both keys and fetchedAt; fetchedAt records when keys was last
// replaced and is used to decide when the set must be refetched.
type jwksCache struct {
	mu        sync.RWMutex
	keys      map[string]*rsa.PublicKey
	fetchedAt time.Time
}
// jwksURL returns the endpoint serving the issuer's JSON Web Key Set: the
// explicitly configured JWKSUrl when one is set, otherwise the
// issuer-relative path "/oauth/v2/keys".
func (c *Client) jwksURL() string {
	if override := c.conf.JWKSUrl; override != "" {
		return override
	}
	return c.issuer + "/oauth/v2/keys"
}
// VerifyIDToken checks an OIDC id_token and returns the subset of its claims
// exposed by IDTokenClaims.
//
// The token must be RS256-signed by a key from the issuer's JWKS (selected
// via the "kid" header) and must pass validateIDTokenClaims; it must also
// carry a non-empty "sub". Parse failures are returned wrapped around
// ErrInvalidIDToken, and a nil receiver yields ErrNotConfigured.
func (c *Client) VerifyIDToken(ctx context.Context, idToken string) (*IDTokenClaims, error) {
	switch {
	case c == nil:
		return nil, ErrNotConfigured
	case idToken == "":
		return nil, fmt.Errorf("zitadel: id_token is required")
	case c.conf.OAuthClientID == "":
		return nil, fmt.Errorf("zitadel: oauth client id is required for id_token verification")
	}

	// Resolve the verification key from the token's kid header; a missing or
	// non-string kid is rejected before any JWKS lookup happens.
	resolveKey := func(tok *jwt.Token) (any, error) {
		kid, _ := tok.Header["kid"].(string)
		if kid == "" {
			return nil, fmt.Errorf("zitadel: id_token missing kid")
		}
		return c.publicKeyForKID(ctx, kid)
	}

	allowed := []string{jwt.SigningMethodRS256.Alg()}
	token, err := jwt.NewParser(jwt.WithValidMethods(allowed)).Parse(idToken, resolveKey)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrInvalidIDToken, err)
	}
	claims, isMapClaims := token.Claims.(jwt.MapClaims)
	if !token.Valid || !isMapClaims {
		return nil, ErrInvalidIDToken
	}
	if err := c.validateIDTokenClaims(claims); err != nil {
		return nil, err
	}

	subject := stringClaim(claims, "sub")
	if subject == "" {
		return nil, ErrInvalidIDToken
	}
	return &IDTokenClaims{
		Sub:           subject,
		Email:         stringClaim(claims, "email"),
		EmailVerified: boolClaim(claims, "email_verified"),
		Name:          stringClaim(claims, "name"),
		Locale:        stringClaim(claims, "locale"),
	}, nil
}
// validateIDTokenClaims applies OIDC claim checks on top of signature
// verification:
//
//   - iss must equal the configured issuer (a single trailing "/" is tolerated);
//   - aud must contain the configured OAuth client ID;
//   - azp, when present, must equal the OAuth client ID (OIDC Core 1.0,
//     section 3.1.3.7);
//   - exp must be present, numeric (float64 or json.Number), and in the future.
//
// Every failure is returned wrapped around ErrInvalidIDToken.
func (c *Client) validateIDTokenClaims(claims jwt.MapClaims) error {
	iss := stringClaim(claims, "iss")
	if iss != c.issuer && iss != c.issuer+"/" {
		return fmt.Errorf("%w: unexpected iss", ErrInvalidIDToken)
	}

	clientID := c.conf.OAuthClientID
	if !audienceContains(claims["aud"], clientID) {
		return fmt.Errorf("%w: unexpected aud", ErrInvalidIDToken)
	}
	// azp names the client the token was actually issued to. Without this
	// check a token minted for a different client that merely lists us in
	// aud would be accepted.
	if _, hasAzp := claims["azp"]; hasAzp && stringClaim(claims, "azp") != clientID {
		return fmt.Errorf("%w: unexpected azp", ErrInvalidIDToken)
	}

	expRaw, ok := claims["exp"]
	if !ok {
		return fmt.Errorf("%w: missing exp", ErrInvalidIDToken)
	}
	var expUnix int64
	// encoding/json yields float64 by default, json.Number when the decoder
	// was configured with UseNumber.
	switch t := expRaw.(type) {
	case float64:
		expUnix = int64(t)
	case json.Number:
		v, err := t.Int64()
		if err != nil {
			return fmt.Errorf("%w: invalid exp", ErrInvalidIDToken)
		}
		expUnix = v
	default:
		return fmt.Errorf("%w: invalid exp", ErrInvalidIDToken)
	}
	if time.Now().Unix() >= expUnix {
		return fmt.Errorf("%w: token expired", ErrInvalidIDToken)
	}
	return nil
}
// jwksKeyTTL is how long a fetched key set is trusted before a lookup forces
// a refetch.
const jwksKeyTTL = 5 * time.Minute

// jwksUnknownKIDBackoff is the minimum spacing between refetches triggered by
// a kid that is absent from the cached set. Without it every token carrying
// a bogus kid would cause an outbound request made while holding the cache's
// write lock, stalling every concurrent verification.
const jwksUnknownKIDBackoff = 10 * time.Second

// jwksLazyInitMu serialises the lazy allocation of Client.jwks so concurrent
// first calls do not race on that field.
var jwksLazyInitMu sync.Mutex

// publicKeyForKID returns the RSA public key whose JWKS "kid" equals kid.
//
// A cached key younger than jwksKeyTTL is returned directly. Otherwise the
// key set is refetched via refreshJWKS while holding the cache's write lock,
// so concurrent callers wait for a single refresh instead of each fetching.
// An unknown kid only triggers a refetch if the last successful fetch is at
// least jwksUnknownKIDBackoff old.
func (c *Client) publicKeyForKID(ctx context.Context, kid string) (*rsa.PublicKey, error) {
	jwksLazyInitMu.Lock()
	if c.jwks == nil {
		c.jwks = &jwksCache{keys: make(map[string]*rsa.PublicKey)}
	}
	cache := c.jwks
	jwksLazyInitMu.Unlock()

	// Fast path: shared lock only.
	cache.mu.RLock()
	key, ok := cache.keys[kid]
	fresh := time.Since(cache.fetchedAt) < jwksKeyTTL
	cache.mu.RUnlock()
	if ok && fresh {
		return key, nil
	}

	cache.mu.Lock()
	defer cache.mu.Unlock()
	// Re-check: another goroutine may have refreshed while we waited.
	age := time.Since(cache.fetchedAt)
	key, ok = cache.keys[kid]
	if ok && age < jwksKeyTTL {
		return key, nil
	}
	if !ok && age < jwksUnknownKIDBackoff {
		return nil, fmt.Errorf("zitadel: jwks kid not found: %s", kid)
	}
	if err := c.refreshJWKS(ctx); err != nil {
		return nil, err
	}
	key, ok = cache.keys[kid]
	if !ok {
		return nil, fmt.Errorf("zitadel: jwks kid not found: %s", kid)
	}
	return key, nil
}
// maxJWKSBodyBytes caps how much of a JWKS response is buffered in memory.
const maxJWKSBodyBytes = 1 << 20

// refreshJWKS fetches the issuer's JWKS and replaces the cached key set with
// every usable RSA signing key it contains.
//
// The caller must hold c.jwks.mu for writing, and c.jwks must be non-nil.
// Entries that are not RSA, lack kid/n/e, declare a "use" other than "sig",
// or fail to decode are skipped. One bad entry therefore does not disable
// verification with the others. The cache is only replaced, and fetchedAt
// only advanced, when at least one key was parsed.
func (c *Client) refreshJWKS(ctx context.Context) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.jwksURL(), http.NoBody)
	if err != nil {
		return fmt.Errorf("zitadel: jwks request: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	resp, err := c.http.Do(req)
	if err != nil {
		return fmt.Errorf("zitadel: jwks request: %w", err)
	}
	defer resp.Body.Close()
	// Read one byte past the cap so an oversized body can be detected.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, maxJWKSBodyBytes+1))
	if err != nil {
		return fmt.Errorf("zitadel: read jwks body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("zitadel: jwks request: status %d: %s", resp.StatusCode, truncateBody(raw))
	}
	if len(raw) > maxJWKSBodyBytes {
		return fmt.Errorf("zitadel: jwks body exceeds %d bytes", maxJWKSBodyBytes)
	}

	var payload struct {
		Keys []struct {
			Kty string `json:"kty"`
			Kid string `json:"kid"`
			Use string `json:"use"`
			N   string `json:"n"`
			E   string `json:"e"`
		} `json:"keys"`
	}
	if err := json.Unmarshal(raw, &payload); err != nil {
		return fmt.Errorf("zitadel: decode jwks: %w", err)
	}

	keys := make(map[string]*rsa.PublicKey, len(payload.Keys))
	var firstErr error
	for _, k := range payload.Keys {
		if k.Kty != "RSA" || k.Kid == "" || k.N == "" || k.E == "" {
			continue
		}
		// "use" is optional (RFC 7517 section 4.2); when present, only
		// signature keys may verify tokens.
		if k.Use != "" && k.Use != "sig" {
			continue
		}
		pub, err := rsaPublicKeyFromModExp(k.N, k.E)
		if err != nil {
			if firstErr == nil {
				firstErr = err
			}
			continue
		}
		keys[k.Kid] = pub
	}
	if len(keys) == 0 {
		if firstErr != nil {
			return fmt.Errorf("zitadel: jwks contains no usable rsa keys: %w", firstErr)
		}
		return fmt.Errorf("zitadel: jwks contains no usable rsa keys")
	}
	c.jwks.keys = keys
	c.jwks.fetchedAt = time.Now().UTC()
	return nil
}
// rsaPublicKeyFromModExp builds an RSA public key from the JWK "n" (modulus)
// and "e" (exponent) members, both unpadded base64url big-endian integers.
//
// It returns an error when either value fails to decode, when the modulus is
// zero, or when the exponent is below 2 or does not fit in an int.
func rsaPublicKeyFromModExp(nB64, eB64 string) (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(nB64)
	if err != nil {
		return nil, fmt.Errorf("zitadel: decode jwks n: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(eB64)
	if err != nil {
		return nil, fmt.Errorf("zitadel: decode jwks e: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("zitadel: invalid jwks modulus")
	}

	e := new(big.Int).SetBytes(eBytes)
	// big.Int.Int64 is undefined for values outside int64, so range-check
	// with IsInt64 first. Also reject exponents that do not fit the
	// platform int used by rsa.PublicKey.E. crypto/rsa requires E >= 2.
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > int64(^uint(0)>>1) {
		return nil, fmt.Errorf("zitadel: invalid jwks exponent")
	}
	return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil
}
// stringClaim reads claims[key] as a string. It returns "" when the claim is
// absent or null, the value itself when it is already a string, and the
// fmt.Sprint rendering of any other value (for example a number).
func stringClaim(claims jwt.MapClaims, key string) string {
	raw := claims[key]
	if raw == nil {
		return ""
	}
	if s, isString := raw.(string); isString {
		return s
	}
	return fmt.Sprint(raw)
}
// boolClaim reads claims[key] as a boolean. It is true only for a boolean
// true or the exact string "true"; absent, null, or any other value yields
// false.
func boolClaim(claims jwt.MapClaims, key string) bool {
	switch v := claims[key].(type) {
	case bool:
		return v
	case string:
		return v == "true"
	}
	return false
}
// audienceContains reports whether want appears in an "aud" claim value. The
// claim may be a single string, a []string, or a []any (the shape produced by
// JSON decoding). Non-string elements and any other type never match.
func audienceContains(aud any, want string) bool {
	if single, ok := aud.(string); ok {
		return single == want
	}
	if list, ok := aud.([]string); ok {
		for i := range list {
			if list[i] == want {
				return true
			}
		}
		return false
	}
	items, ok := aud.([]any)
	if !ok {
		return false
	}
	for i := range items {
		if s, isString := items[i].(string); isString && s == want {
			return true
		}
	}
	return false
}