251 lines
5.8 KiB
Go
251 lines
5.8 KiB
Go
|
|
package zitadel
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/rsa"
|
||
|
|
"encoding/base64"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"math/big"
|
||
|
|
"net/http"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/golang-jwt/jwt/v4"
|
||
|
|
)
|
||
|
|
|
||
|
|
// jwksCache holds the RSA verification keys most recently fetched from the
// issuer's JWKS endpoint.
//
// Every field after mu is guarded by mu: key lookups take the read lock and
// refreshes run under the write lock.
type jwksCache struct {
	// mu guards fetchedAt and keys.
	mu sync.RWMutex

	// fetchedAt records when keys was last replaced by a successful refresh.
	// The zero value means no key set has been fetched yet.
	fetchedAt time.Time

	// keys maps a JWK "kid" to its parsed RSA public key.
	keys map[string]*rsa.PublicKey
}
|
||
|
|
|
||
|
|
func (c *Client) jwksURL() string {
|
||
|
|
if c.conf.JWKSUrl != "" {
|
||
|
|
return c.conf.JWKSUrl
|
||
|
|
}
|
||
|
|
return c.issuer + "/oauth/v2/keys"
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Client) VerifyIDToken(ctx context.Context, idToken string) (*IDTokenClaims, error) {
|
||
|
|
if c == nil {
|
||
|
|
return nil, ErrNotConfigured
|
||
|
|
}
|
||
|
|
if idToken == "" {
|
||
|
|
return nil, fmt.Errorf("zitadel: id_token is required")
|
||
|
|
}
|
||
|
|
if c.conf.OAuthClientID == "" {
|
||
|
|
return nil, fmt.Errorf("zitadel: oauth client id is required for id_token verification")
|
||
|
|
}
|
||
|
|
|
||
|
|
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}))
|
||
|
|
token, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
|
||
|
|
kid, ok := t.Header["kid"].(string)
|
||
|
|
if !ok || kid == "" {
|
||
|
|
return nil, fmt.Errorf("zitadel: id_token missing kid")
|
||
|
|
}
|
||
|
|
return c.publicKeyForKID(ctx, kid)
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: %w", ErrInvalidIDToken, err)
|
||
|
|
}
|
||
|
|
if !token.Valid {
|
||
|
|
return nil, ErrInvalidIDToken
|
||
|
|
}
|
||
|
|
|
||
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||
|
|
if !ok {
|
||
|
|
return nil, ErrInvalidIDToken
|
||
|
|
}
|
||
|
|
if err := c.validateIDTokenClaims(claims); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
out := &IDTokenClaims{
|
||
|
|
Sub: stringClaim(claims, "sub"),
|
||
|
|
Email: stringClaim(claims, "email"),
|
||
|
|
EmailVerified: boolClaim(claims, "email_verified"),
|
||
|
|
Name: stringClaim(claims, "name"),
|
||
|
|
Locale: stringClaim(claims, "locale"),
|
||
|
|
}
|
||
|
|
if out.Sub == "" {
|
||
|
|
return nil, ErrInvalidIDToken
|
||
|
|
}
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Client) validateIDTokenClaims(claims jwt.MapClaims) error {
|
||
|
|
iss := stringClaim(claims, "iss")
|
||
|
|
if iss != c.issuer && iss != c.issuer+"/" {
|
||
|
|
return fmt.Errorf("%w: unexpected iss", ErrInvalidIDToken)
|
||
|
|
}
|
||
|
|
if !audienceContains(claims["aud"], c.conf.OAuthClientID) {
|
||
|
|
return fmt.Errorf("%w: unexpected aud", ErrInvalidIDToken)
|
||
|
|
}
|
||
|
|
expRaw, ok := claims["exp"]
|
||
|
|
if !ok {
|
||
|
|
return fmt.Errorf("%w: missing exp", ErrInvalidIDToken)
|
||
|
|
}
|
||
|
|
var expUnix int64
|
||
|
|
switch t := expRaw.(type) {
|
||
|
|
case float64:
|
||
|
|
expUnix = int64(t)
|
||
|
|
case json.Number:
|
||
|
|
v, err := t.Int64()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("%w: invalid exp", ErrInvalidIDToken)
|
||
|
|
}
|
||
|
|
expUnix = v
|
||
|
|
default:
|
||
|
|
return fmt.Errorf("%w: invalid exp", ErrInvalidIDToken)
|
||
|
|
}
|
||
|
|
if time.Now().UTC().Unix() >= expUnix {
|
||
|
|
return fmt.Errorf("%w: token expired", ErrInvalidIDToken)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Client) publicKeyForKID(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
||
|
|
if c.jwks == nil {
|
||
|
|
c.jwks = &jwksCache{keys: make(map[string]*rsa.PublicKey)}
|
||
|
|
}
|
||
|
|
c.jwks.mu.RLock()
|
||
|
|
if key, ok := c.jwks.keys[kid]; ok && time.Since(c.jwks.fetchedAt) < 5*time.Minute {
|
||
|
|
c.jwks.mu.RUnlock()
|
||
|
|
return key, nil
|
||
|
|
}
|
||
|
|
c.jwks.mu.RUnlock()
|
||
|
|
|
||
|
|
c.jwks.mu.Lock()
|
||
|
|
defer c.jwks.mu.Unlock()
|
||
|
|
if key, ok := c.jwks.keys[kid]; ok && time.Since(c.jwks.fetchedAt) < 5*time.Minute {
|
||
|
|
return key, nil
|
||
|
|
}
|
||
|
|
if err := c.refreshJWKS(ctx); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
key, ok := c.jwks.keys[kid]
|
||
|
|
if !ok {
|
||
|
|
return nil, fmt.Errorf("zitadel: jwks kid not found: %s", kid)
|
||
|
|
}
|
||
|
|
return key, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Client) refreshJWKS(ctx context.Context) error {
|
||
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.jwksURL(), http.NoBody)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("zitadel: jwks request: %w", err)
|
||
|
|
}
|
||
|
|
req.Header.Set("Accept", "application/json")
|
||
|
|
|
||
|
|
resp, err := c.http.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("zitadel: jwks request: %w", err)
|
||
|
|
}
|
||
|
|
defer resp.Body.Close()
|
||
|
|
|
||
|
|
raw, err := io.ReadAll(resp.Body)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("zitadel: read jwks body: %w", err)
|
||
|
|
}
|
||
|
|
if resp.StatusCode != http.StatusOK {
|
||
|
|
return fmt.Errorf("zitadel: jwks request: status %d: %s", resp.StatusCode, truncateBody(raw))
|
||
|
|
}
|
||
|
|
|
||
|
|
var payload struct {
|
||
|
|
Keys []struct {
|
||
|
|
Kty string `json:"kty"`
|
||
|
|
Kid string `json:"kid"`
|
||
|
|
N string `json:"n"`
|
||
|
|
E string `json:"e"`
|
||
|
|
} `json:"keys"`
|
||
|
|
}
|
||
|
|
if err := json.Unmarshal(raw, &payload); err != nil {
|
||
|
|
return fmt.Errorf("zitadel: decode jwks: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
keys := make(map[string]*rsa.PublicKey, len(payload.Keys))
|
||
|
|
for _, k := range payload.Keys {
|
||
|
|
if k.Kty != "RSA" || k.Kid == "" || k.N == "" || k.E == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
pub, err := rsaPublicKeyFromModExp(k.N, k.E)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
keys[k.Kid] = pub
|
||
|
|
}
|
||
|
|
if len(keys) == 0 {
|
||
|
|
return fmt.Errorf("zitadel: jwks contains no usable rsa keys")
|
||
|
|
}
|
||
|
|
c.jwks.keys = keys
|
||
|
|
c.jwks.fetchedAt = time.Now().UTC()
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// rsaPublicKeyFromModExp builds an RSA public key from the unpadded base64url
// "n" (modulus) and "e" (exponent) members of a JWK.
//
// It rejects a zero modulus and any exponent outside [2, 2^31-1], the range
// crypto/rsa accepts for verification. The exponent is range-checked on the
// full big integer. big.Int.Int64 is undefined for values that do not fit in
// an int64; in practice it keeps the low 64 bits, so 2^64+3 would otherwise
// be read as 3.
func rsaPublicKeyFromModExp(nB64, eB64 string) (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(nB64)
	if err != nil {
		return nil, fmt.Errorf("zitadel: decode jwks n: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(eB64)
	if err != nil {
		return nil, fmt.Errorf("zitadel: decode jwks e: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("zitadel: invalid jwks modulus")
	}

	// Mirrors crypto/rsa's public-exponent bounds. The upper bound also keeps
	// int(e) safe on 32-bit platforms.
	const minExp, maxExp = 2, 1<<31 - 1
	e := new(big.Int).SetBytes(eBytes)
	if !e.IsInt64() || e.Int64() < minExp || e.Int64() > maxExp {
		return nil, fmt.Errorf("zitadel: invalid jwks exponent")
	}
	return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil
}
|
||
|
|
|
||
|
|
func stringClaim(claims jwt.MapClaims, key string) string {
|
||
|
|
v, ok := claims[key]
|
||
|
|
if !ok || v == nil {
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
switch t := v.(type) {
|
||
|
|
case string:
|
||
|
|
return t
|
||
|
|
default:
|
||
|
|
return fmt.Sprint(t)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func boolClaim(claims jwt.MapClaims, key string) bool {
|
||
|
|
v, ok := claims[key]
|
||
|
|
if !ok || v == nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
switch t := v.(type) {
|
||
|
|
case bool:
|
||
|
|
return t
|
||
|
|
case string:
|
||
|
|
return t == "true"
|
||
|
|
default:
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// audienceContains reports whether the "aud" claim value names want.
//
// aud may be a single string or a list, either []any as decoded from JSON or
// []string. List entries that are not strings never match. Any other type,
// including nil, yields false.
func audienceContains(aud any, want string) bool {
	// isWant is true only for string values equal to want.
	isWant := func(v any) bool {
		s, isString := v.(string)
		return isString && s == want
	}

	switch list := aud.(type) {
	case []any:
		for _, entry := range list {
			if isWant(entry) {
				return true
			}
		}
		return false
	case []string:
		for _, entry := range list {
			if entry == want {
				return true
			}
		}
		return false
	default:
		// A lone string audience, or an unsupported type (never a match).
		return isWant(aud)
	}
}
|