264 lines
7.9 KiB
Go
264 lines
7.9 KiB
Go
package usecase
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
authconfig "gateway/internal/model/auth/config"
|
|
domrepo "gateway/internal/model/auth/domain/repository"
|
|
domusecase "gateway/internal/model/auth/domain/usecase"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type tokenUseCase struct {
|
|
cfg authconfig.Config
|
|
revoke domrepo.TokenRevokeStore
|
|
}
|
|
|
|
// TokenUseCaseParam wires TokenUseCase.
|
|
type TokenUseCaseParam struct {
|
|
Config authconfig.Config
|
|
Revoke domrepo.TokenRevokeStore
|
|
}
|
|
|
|
// MustTokenUseCase constructs TokenUseCase.
|
|
func MustTokenUseCase(param TokenUseCaseParam) domusecase.TokenUseCase {
|
|
cfg := param.Config.Defaults()
|
|
if !cfg.Enabled() {
|
|
panic("auth: JWT secrets are required")
|
|
}
|
|
return &tokenUseCase{cfg: cfg, revoke: param.Revoke}
|
|
}
|
|
|
|
// IssuePair signs a new access token and refresh token for req's tenant and
// user.
//
// Both tokens carry req.AuthGen. The access token is signed with the access
// secret and lifetime; the refresh token uses the refresh ones. When a revoke
// store is configured, the two JTIs are recorded together via SavePair, each
// with a TTL equal to its token's remaining lifetime.
//
// ExpiresIn in the result is the configured access lifetime in seconds.
// A nil req, or an empty TenantID or UID, yields an InputMissingRequired error.
func (uc *tokenUseCase) IssuePair(ctx context.Context, req *domusecase.IssuePairRequest) (*domusecase.TokenPair, error) {
	if req == nil || req.TenantID == "" || req.UID == "" {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}

	accessTok, signErr := uc.sign(req, domusecase.TokenTypeAccess, uc.cfg.AccessExpire, uc.cfg.AccessSecret)
	if signErr != nil {
		return nil, errb.SysInternal("sign access token failed").WithCause(signErr)
	}
	refreshTok, signErr := uc.sign(req, domusecase.TokenTypeRefresh, uc.cfg.RefreshExpire, uc.cfg.RefreshSecret)
	if signErr != nil {
		return nil, errb.SysInternal("sign refresh token failed").WithCause(signErr)
	}

	// Link the two JTIs so revoking one side can later find the other.
	if store := uc.revoke; store != nil {
		saveErr := store.SavePair(ctx, accessTok.jti, refreshTok.jti,
			time.Until(accessTok.expiresAt), time.Until(refreshTok.expiresAt))
		if saveErr != nil {
			return nil, errb.DBError("save jwt pair failed").WithCause(saveErr)
		}
	}

	return &domusecase.TokenPair{
		TokenType:    "Bearer",
		AccessToken:  accessTok.raw,
		RefreshToken: refreshTok.raw,
		ExpiresIn:    uc.cfg.AccessExpire,
	}, nil
}
|
|
|
|
// Refresh rotates a refresh token. It validates refreshToken, rejects it if
// it is blacklisted, and issues a new pair for the same tenant, user and
// AuthGen. It then revokes the presented token.
//
// Revocation happens only when a revoke store is configured:
//   - The presented refresh token is blacklisted for its remaining lifetime.
//   - If the store returns a paired access JTI, that access token is
//     blacklisted for the full configured access lifetime.
//   - DeletePair is then called for the two JTIs.
//
// The new pair is minted, and its link saved, before revocation runs. A store
// failure during revocation therefore returns an error after that pair already
// exists; the new tokens are not handed to the caller.
//
// NOTE(review): the blacklist check and the later Blacklist call are not
// atomic, so two concurrent calls with the same refresh token can both succeed.
func (uc *tokenUseCase) Refresh(ctx context.Context, refreshToken string) (*domusecase.TokenPair, error) {
	if refreshToken == "" {
		return nil, errb.InputMissingRequired("refresh_token is required")
	}

	presented, err := uc.parse(refreshToken, domusecase.TokenTypeRefresh, uc.cfg.RefreshSecret)
	switch {
	case errors.Is(err, errInvalidToken):
		return nil, errb.AuthUnauthorized("invalid refresh token").WithCause(err)
	case err != nil:
		return nil, errb.SysInternal("parse refresh token failed").WithCause(err)
	}
	if err = uc.ensureNotBlacklisted(ctx, presented.ID); err != nil {
		return nil, err
	}

	fresh, err := uc.IssuePair(ctx, &domusecase.IssuePairRequest{
		TenantID: presented.TenantID,
		UID:      presented.UID,
		AuthGen:  presented.AuthGen,
	})
	if err != nil {
		return nil, err
	}

	store := uc.revoke
	if store == nil {
		return fresh, nil
	}
	if err = store.Blacklist(ctx, presented.ID, remainingTTL(presented.expiresAt)); err != nil {
		return nil, errb.DBError("blacklist refresh token failed").WithCause(err)
	}
	accessJTI, err := store.GetPairedJTI(ctx, presented.ID)
	if err != nil {
		return nil, errb.DBError("read jwt pair failed").WithCause(err)
	}
	if accessJTI == "" {
		// No linked access token recorded for this refresh token.
		return fresh, nil
	}
	// The remaining lifetime of the old access token is unknown here, so use
	// the full configured lifetime as an upper bound.
	accessTTL := time.Duration(uc.cfg.AccessExpire) * time.Second
	if err = store.Blacklist(ctx, accessJTI, accessTTL); err != nil {
		return nil, errb.DBError("blacklist access token failed").WithCause(err)
	}
	if err = store.DeletePair(ctx, accessJTI, presented.ID); err != nil {
		return nil, errb.DBError("delete jwt pair failed").WithCause(err)
	}
	return fresh, nil
}
|
|
|
|
// Logout revokes the access token in req. If the revoke store links that token
// to a refresh token, the refresh token is revoked too.
//
// The access token is blacklisted for its remaining lifetime. A linked refresh
// token is blacklisted for the full configured refresh lifetime. DeletePair is
// then called with the access JTI and the (possibly empty) refresh JTI.
//
// Errors:
//   - InputMissingRequired when req is nil or has no access token.
//   - SysNotImplemented when no revoke store is configured.
//   - AuthUnauthorized when the access token fails validation.
//   - DBError on any store failure.
func (uc *tokenUseCase) Logout(ctx context.Context, req *domusecase.LogoutRequest) error {
	if req == nil || req.AccessToken == "" {
		return errb.InputMissingRequired("access token is required")
	}
	store := uc.revoke
	if store == nil {
		return errb.SysNotImplemented("token revoke store not configured")
	}

	access, err := uc.parse(req.AccessToken, domusecase.TokenTypeAccess, uc.cfg.AccessSecret)
	switch {
	case errors.Is(err, errInvalidToken):
		return errb.AuthUnauthorized("invalid access token").WithCause(err)
	case err != nil:
		return errb.SysInternal("parse access token failed").WithCause(err)
	}

	if err = store.Blacklist(ctx, access.ID, remainingTTL(access.expiresAt)); err != nil {
		return errb.DBError("blacklist access token failed").WithCause(err)
	}
	refreshJTI, err := store.GetPairedJTI(ctx, access.ID)
	if err != nil {
		return errb.DBError("read jwt pair failed").WithCause(err)
	}
	if refreshJTI != "" {
		refreshTTL := time.Duration(uc.cfg.RefreshExpire) * time.Second
		if err = store.Blacklist(ctx, refreshJTI, refreshTTL); err != nil {
			return errb.DBError("blacklist refresh token failed").WithCause(err)
		}
	}
	if err = store.DeletePair(ctx, access.ID, refreshJTI); err != nil {
		return errb.DBError("delete jwt pair failed").WithCause(err)
	}
	return nil
}
|
|
|
|
// ParseAccessToken validates accessToken as an access-type JWT and checks that
// it has not been blacklisted.
//
// On success it returns the tenant, user, AuthGen and JTI carried by the
// token. An empty or invalid token, or a revoked one, yields an
// AuthUnauthorized error. A blacklist lookup failure yields a DBError.
func (uc *tokenUseCase) ParseAccessToken(ctx context.Context, accessToken string) (*domusecase.AccessClaims, error) {
	if accessToken == "" {
		return nil, errb.AuthUnauthorized("missing access token")
	}

	c, err := uc.parse(accessToken, domusecase.TokenTypeAccess, uc.cfg.AccessSecret)
	switch {
	case errors.Is(err, errInvalidToken):
		return nil, errb.AuthUnauthorized("invalid access token").WithCause(err)
	case err != nil:
		return nil, errb.SysInternal("parse access token failed").WithCause(err)
	}

	if revokedErr := uc.ensureNotBlacklisted(ctx, c.ID); revokedErr != nil {
		return nil, revokedErr
	}

	return &domusecase.AccessClaims{
		JTI:      c.ID,
		TenantID: c.TenantID,
		UID:      c.UID,
		AuthGen:  c.AuthGen,
	}, nil
}
|
|
|
|
// ensureNotBlacklisted returns an AuthUnauthorized error when the revoke
// store reports jti as blacklisted, and a DBError when the lookup fails.
// It returns nil without consulting anything when no store is configured or
// jti is empty.
func (uc *tokenUseCase) ensureNotBlacklisted(ctx context.Context, jti string) error {
	if jti == "" || uc.revoke == nil {
		return nil
	}
	revoked, err := uc.revoke.IsBlacklisted(ctx, jti)
	switch {
	case err != nil:
		return errb.DBError("check jwt blacklist failed").WithCause(err)
	case revoked:
		return errb.AuthUnauthorized("token revoked")
	default:
		return nil
	}
}
|
|
|
|
// errInvalidToken marks every token-validation failure produced by parse.
// Callers use errors.Is against it to map failures to AuthUnauthorized.
var (
	errInvalidToken = errors.New("auth: invalid token")
)
|
|
|
|
type jwtClaims struct {
|
|
TenantID string `json:"tenant_id"`
|
|
UID string `json:"uid"`
|
|
Typ string `json:"typ"`
|
|
AuthGen int64 `json:"auth_gen"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
type parsedClaims struct {
|
|
TenantID string
|
|
UID string
|
|
AuthGen int64
|
|
ID string
|
|
expiresAt time.Time
|
|
}
|
|
|
|
type signedToken struct {
|
|
raw string
|
|
jti string
|
|
expiresAt time.Time
|
|
}
|
|
|
|
func (uc *tokenUseCase) sign(req *domusecase.IssuePairRequest, typ domusecase.TokenType, expireSec int64, secret string) (*signedToken, error) {
|
|
now := time.Now().UTC()
|
|
expiresAt := now.Add(time.Duration(expireSec) * time.Second)
|
|
jti := uuid.NewString()
|
|
claims := jwtClaims{
|
|
TenantID: req.TenantID,
|
|
UID: req.UID,
|
|
Typ: string(typ),
|
|
AuthGen: req.AuthGen,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: jti,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
|
},
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
token.Header["kid"] = uc.cfg.ActiveKID
|
|
raw, err := token.SignedString([]byte(secret))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &signedToken{raw: raw, jti: jti, expiresAt: expiresAt}, nil
|
|
}
|
|
|
|
// parse verifies raw as an HS256-signed JWT under secret and returns its
// validated claims.
//
// Besides signature and time validity, the token must:
//   - carry a "typ" claim equal to want;
//   - carry a non-empty tenant_id and uid;
//   - carry both a jti and an exp.
//
// The jti/exp requirement exists because revocation depends on both claims.
// jwt v4's RegisteredClaims.Valid accepts a token with no exp (it would never
// expire). ensureNotBlacklisted skips tokens with an empty jti. remainingTTL
// sizes a blacklist entry for a token without exp at one second. Together,
// such a token could never be reliably revoked. Tokens produced by sign always
// carry both claims.
//
// Every validation failure is reported as (or wraps) errInvalidToken, so
// callers can use errors.Is to map it to an authentication error.
func (uc *tokenUseCase) parse(raw string, want domusecase.TokenType, secret string) (*parsedClaims, error) {
	parsed, err := jwt.ParseWithClaims(raw, &jwtClaims{}, func(t *jwt.Token) (any, error) {
		// Pin the algorithm so the token header cannot choose how it is verified.
		if t.Method != jwt.SigningMethodHS256 {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return []byte(secret), nil
	})
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errInvalidToken, err)
	}

	claims, ok := parsed.Claims.(*jwtClaims)
	if !ok || !parsed.Valid {
		return nil, errInvalidToken
	}
	if claims.Typ != string(want) {
		return nil, errInvalidToken
	}
	if claims.TenantID == "" || claims.UID == "" {
		return nil, errInvalidToken
	}
	// Revocation keys on jti and sizes blacklist entries from exp; reject
	// tokens missing either rather than accept something that cannot be revoked.
	if claims.ID == "" || claims.ExpiresAt == nil {
		return nil, errInvalidToken
	}

	return &parsedClaims{
		TenantID:  claims.TenantID,
		UID:       claims.UID,
		AuthGen:   claims.AuthGen,
		ID:        claims.ID,
		expiresAt: claims.ExpiresAt.Time,
	}, nil
}
|
|
|
|
func remainingTTL(expiresAt time.Time) time.Duration {
|
|
if expiresAt.IsZero() {
|
|
return time.Second
|
|
}
|
|
ttl := time.Until(expiresAt)
|
|
if ttl < time.Second {
|
|
return time.Second
|
|
}
|
|
return ttl
|
|
}
|