template-monorepo/internal/model/auth/usecase/token_usecase.go

264 lines
7.9 KiB
Go

package usecase
import (
"context"
"errors"
"fmt"
"time"
authconfig "gateway/internal/model/auth/config"
domrepo "gateway/internal/model/auth/domain/repository"
domusecase "gateway/internal/model/auth/domain/usecase"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
)
// tokenUseCase is the unexported implementation returned by MustTokenUseCase.
// It signs and parses JWTs with the secrets and lifetimes held in cfg and,
// when revoke is non-nil, records token pairs and blacklisted JTIs through it.
type tokenUseCase struct {
// cfg is the auth configuration after Defaults() has been applied.
cfg authconfig.Config
// revoke is optional; every method checks it for nil before using it.
revoke domrepo.TokenRevokeStore
}
// TokenUseCaseParam carries the dependencies MustTokenUseCase needs to build a
// TokenUseCase.
type TokenUseCaseParam struct {
// Config is the auth configuration. MustTokenUseCase applies Defaults() to it
// and panics if the result does not report Enabled().
Config authconfig.Config
// Revoke is the store used for pair tracking and blacklisting. It may be nil:
// the use case checks for nil before every use.
Revoke domrepo.TokenRevokeStore
}
// MustTokenUseCase builds the TokenUseCase implementation from param.
//
// The supplied Config is first passed through Defaults(). If the resulting
// configuration reports Enabled() == false the function panics with
// "auth: JWT secrets are required". param.Revoke is stored as-is and may be
// nil.
func MustTokenUseCase(param TokenUseCaseParam) domusecase.TokenUseCase {
	resolved := param.Config.Defaults()
	if resolved.Enabled() {
		return &tokenUseCase{cfg: resolved, revoke: param.Revoke}
	}
	// Must-style constructor: an unusable configuration is a startup error.
	panic("auth: JWT secrets are required")
}
// IssuePair signs a fresh access/refresh token pair for the tenant and user
// named in req.
//
// Errors:
//   - InputMissingRequired when req is nil or has an empty TenantID or UID.
//   - SysInternal when either token cannot be signed.
//   - DBError when a revoke store is configured and SavePair fails.
//
// On success the pair reports ExpiresIn as the configured access lifetime and
// TokenType "Bearer".
func (uc *tokenUseCase) IssuePair(ctx context.Context, req *domusecase.IssuePairRequest) (*domusecase.TokenPair, error) {
	incomplete := req == nil || req.TenantID == "" || req.UID == ""
	if incomplete {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}

	accessTok, err := uc.sign(req, domusecase.TokenTypeAccess, uc.cfg.AccessExpire, uc.cfg.AccessSecret)
	if err != nil {
		return nil, errb.SysInternal("sign access token failed").WithCause(err)
	}

	refreshTok, err := uc.sign(req, domusecase.TokenTypeRefresh, uc.cfg.RefreshExpire, uc.cfg.RefreshSecret)
	if err != nil {
		return nil, errb.SysInternal("sign refresh token failed").WithCause(err)
	}

	if store := uc.revoke; store != nil {
		// Register both JTIs together with the time each token has left to live.
		saveErr := store.SavePair(
			ctx,
			accessTok.jti,
			refreshTok.jti,
			time.Until(accessTok.expiresAt),
			time.Until(refreshTok.expiresAt),
		)
		if saveErr != nil {
			return nil, errb.DBError("save jwt pair failed").WithCause(saveErr)
		}
	}

	issued := &domusecase.TokenPair{
		AccessToken:  accessTok.raw,
		RefreshToken: refreshTok.raw,
		ExpiresIn:    uc.cfg.AccessExpire,
		TokenType:    "Bearer",
	}
	return issued, nil
}
// Refresh exchanges a valid refresh token for a new token pair and revokes the
// old one.
//
// Sequence: parse refreshToken with the refresh secret, reject it if its JTI
// is blacklisted, issue a new pair carrying the same TenantID, UID and AuthGen,
// and then, only when a revoke store is configured, blacklist the old refresh
// JTI for its remaining lifetime, blacklist the access JTI paired with it (if
// one is recorded) for a full access lifetime, and delete the pair record.
//
// Errors:
//   - InputMissingRequired when refreshToken is empty.
//   - AuthUnauthorized when the token is invalid or revoked.
//   - SysInternal when parsing fails with an error that is not errInvalidToken.
//   - DBError when any revoke-store call fails. Note that the new pair has
//     already been issued at that point but is not returned.
func (uc *tokenUseCase) Refresh(ctx context.Context, refreshToken string) (*domusecase.TokenPair, error) {
	if refreshToken == "" {
		return nil, errb.InputMissingRequired("refresh_token is required")
	}

	old, err := uc.parse(refreshToken, domusecase.TokenTypeRefresh, uc.cfg.RefreshSecret)
	if err != nil {
		if !errors.Is(err, errInvalidToken) {
			return nil, errb.SysInternal("parse refresh token failed").WithCause(err)
		}
		return nil, errb.AuthUnauthorized("invalid refresh token").WithCause(err)
	}

	if err := uc.ensureNotBlacklisted(ctx, old.ID); err != nil {
		return nil, err
	}

	reissue := domusecase.IssuePairRequest{
		TenantID: old.TenantID,
		UID:      old.UID,
		AuthGen:  old.AuthGen,
	}
	fresh, err := uc.IssuePair(ctx, &reissue)
	if err != nil {
		return nil, err
	}

	store := uc.revoke
	if store == nil {
		// Nothing to revoke against; the new pair is all there is to do.
		return fresh, nil
	}

	// The refresh token is single-use: retire it for the time it has left.
	if err := store.Blacklist(ctx, old.ID, remainingTTL(old.expiresAt)); err != nil {
		return nil, errb.DBError("blacklist refresh token failed").WithCause(err)
	}

	pairedAccess, err := store.GetPairedJTI(ctx, old.ID)
	if err != nil {
		return nil, errb.DBError("read jwt pair failed").WithCause(err)
	}
	if pairedAccess == "" {
		return fresh, nil
	}

	// The access token's own expiry is not known here, so it is blacklisted
	// for a full configured access lifetime.
	accessTTL := time.Duration(uc.cfg.AccessExpire) * time.Second
	if err := store.Blacklist(ctx, pairedAccess, accessTTL); err != nil {
		return nil, errb.DBError("blacklist access token failed").WithCause(err)
	}
	if err := store.DeletePair(ctx, pairedAccess, old.ID); err != nil {
		return nil, errb.DBError("delete jwt pair failed").WithCause(err)
	}
	return fresh, nil
}
// Logout revokes the access token in req and the refresh token paired with it.
//
// The access JTI is blacklisted for its remaining lifetime. If the revoke
// store has a paired refresh JTI, that one is blacklisted for a full
// configured refresh lifetime. DeletePair is then called with the access JTI
// and whatever paired JTI was found, which may be empty.
//
// Errors:
//   - InputMissingRequired when req is nil or has no AccessToken.
//   - SysNotImplemented when no revoke store is configured.
//   - AuthUnauthorized when the access token is invalid.
//   - SysInternal when parsing fails with an error that is not errInvalidToken.
//   - DBError when any revoke-store call fails.
func (uc *tokenUseCase) Logout(ctx context.Context, req *domusecase.LogoutRequest) error {
	switch {
	case req == nil || req.AccessToken == "":
		return errb.InputMissingRequired("access token is required")
	case uc.revoke == nil:
		return errb.SysNotImplemented("token revoke store not configured")
	}
	store := uc.revoke

	current, err := uc.parse(req.AccessToken, domusecase.TokenTypeAccess, uc.cfg.AccessSecret)
	if err != nil {
		if !errors.Is(err, errInvalidToken) {
			return errb.SysInternal("parse access token failed").WithCause(err)
		}
		return errb.AuthUnauthorized("invalid access token").WithCause(err)
	}

	if err := store.Blacklist(ctx, current.ID, remainingTTL(current.expiresAt)); err != nil {
		return errb.DBError("blacklist access token failed").WithCause(err)
	}

	pairedRefresh, err := store.GetPairedJTI(ctx, current.ID)
	if err != nil {
		return errb.DBError("read jwt pair failed").WithCause(err)
	}

	if pairedRefresh != "" {
		// The refresh token's own expiry is not known here, so it is
		// blacklisted for a full configured refresh lifetime.
		refreshTTL := time.Duration(uc.cfg.RefreshExpire) * time.Second
		if err := store.Blacklist(ctx, pairedRefresh, refreshTTL); err != nil {
			return errb.DBError("blacklist refresh token failed").WithCause(err)
		}
	}

	// Runs even when no paired refresh JTI was found.
	if err := store.DeletePair(ctx, current.ID, pairedRefresh); err != nil {
		return errb.DBError("delete jwt pair failed").WithCause(err)
	}
	return nil
}
// ParseAccessToken validates accessToken and returns the identity it carries.
//
// The token is parsed with the access secret, must be of the access type, and
// must not be blacklisted. The returned claims copy TenantID, UID and AuthGen
// from the token and expose its JTI.
//
// Errors:
//   - AuthUnauthorized when the token is empty, invalid or revoked.
//   - SysInternal when parsing fails with an error that is not errInvalidToken.
//   - DBError when the blacklist lookup fails.
func (uc *tokenUseCase) ParseAccessToken(ctx context.Context, accessToken string) (*domusecase.AccessClaims, error) {
	if accessToken == "" {
		return nil, errb.AuthUnauthorized("missing access token")
	}

	verified, err := uc.parse(accessToken, domusecase.TokenTypeAccess, uc.cfg.AccessSecret)
	switch {
	case err == nil:
		// Signature and claims are fine; continue with the blacklist check.
	case errors.Is(err, errInvalidToken):
		return nil, errb.AuthUnauthorized("invalid access token").WithCause(err)
	default:
		return nil, errb.SysInternal("parse access token failed").WithCause(err)
	}

	if revokedErr := uc.ensureNotBlacklisted(ctx, verified.ID); revokedErr != nil {
		return nil, revokedErr
	}

	result := &domusecase.AccessClaims{
		TenantID: verified.TenantID,
		UID:      verified.UID,
		AuthGen:  verified.AuthGen,
		JTI:      verified.ID,
	}
	return result, nil
}
// ensureNotBlacklisted reports whether the token identified by jti has been
// revoked.
//
// It returns nil without consulting the store when no revoke store is
// configured or jti is empty. Otherwise it returns a DBError if the lookup
// fails, AuthUnauthorized("token revoked") if the JTI is blacklisted, and nil
// if it is not.
func (uc *tokenUseCase) ensureNotBlacklisted(ctx context.Context, jti string) error {
	store := uc.revoke
	if store == nil || jti == "" {
		return nil
	}

	revoked, err := store.IsBlacklisted(ctx, jti)
	switch {
	case err != nil:
		return errb.DBError("check jwt blacklist failed").WithCause(err)
	case revoked:
		return errb.AuthUnauthorized("token revoked")
	default:
		return nil
	}
}
// errInvalidToken is the sentinel that parse returns, directly or wrapped, for
// every token it rejects. Callers test for it with errors.Is to map the
// failure to an unauthorized error.
var errInvalidToken = errors.New("auth: invalid token")
// jwtClaims is the claim set encoded into, and decoded from, the tokens this
// file signs: the custom fields below plus the embedded registered claims, of
// which sign populates ID, IssuedAt and ExpiresAt.
type jwtClaims struct {
// TenantID and UID identify the token's subject; parse rejects a token in
// which either is empty.
TenantID string `json:"tenant_id"`
UID string `json:"uid"`
// Typ holds the token type as a string; parse compares it with the type the
// caller expects.
Typ string `json:"typ"`
// AuthGen is copied from the issue request and carried through unchanged;
// this file does not interpret it.
AuthGen int64 `json:"auth_gen"`
jwt.RegisteredClaims
}
// parsedClaims is the subset of a validated token's claims that parse hands
// back to the use-case methods.
type parsedClaims struct {
TenantID string
UID string
AuthGen int64
// ID is the token's JTI (the registered ID claim); it is the key used for
// blacklist and pair lookups.
ID string
// expiresAt is the token's ExpiresAt time, or the zero time when the claim
// was absent.
expiresAt time.Time
}
// signedToken is the result of sign: the encoded token plus the metadata
// IssuePair passes to the revoke store.
type signedToken struct {
// raw is the signed, serialized JWT string.
raw string
// jti is the UUID placed in the token's ID claim.
jti string
// expiresAt is the expiry instant computed at signing time.
expiresAt time.Time
}
// sign creates and signs one HS256 JWT of the given type for the subject in
// req.
//
// The token is issued at the current UTC time, expires expireSec seconds
// later, and carries a newly generated UUID as its ID claim. The configured
// ActiveKID is written to the "kid" header. secret is used as the HMAC key.
// The returned signedToken holds the encoded string together with the JTI and
// the computed expiry; a signing error is returned unwrapped.
func (uc *tokenUseCase) sign(req *domusecase.IssuePairRequest, typ domusecase.TokenType, expireSec int64, secret string) (*signedToken, error) {
	issuedAt := time.Now().UTC()
	lifetime := time.Duration(expireSec) * time.Second
	expiry := issuedAt.Add(lifetime)
	id := uuid.NewString()

	registered := jwt.RegisteredClaims{
		ID:        id,
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(expiry),
	}
	payload := jwtClaims{
		TenantID:         req.TenantID,
		UID:              req.UID,
		Typ:              string(typ),
		AuthGen:          req.AuthGen,
		RegisteredClaims: registered,
	}

	unsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
	unsigned.Header["kid"] = uc.cfg.ActiveKID

	encoded, err := unsigned.SignedString([]byte(secret))
	if err != nil {
		return nil, err
	}

	return &signedToken{
		raw:       encoded,
		jti:       id,
		expiresAt: expiry,
	}, nil
}
// parse verifies raw with secret and returns the claims the use case needs.
//
// Only tokens whose signing method is exactly HS256 are accepted. After the
// library has validated the token, parse additionally requires the typ claim
// to equal want and both tenant_id and uid to be non-empty. Every rejection is
// reported as errInvalidToken; a library error is wrapped together with it so
// errors.Is(err, errInvalidToken) holds for any non-nil result.
//
// The returned expiresAt is the zero time when the token has no exp claim.
func (uc *tokenUseCase) parse(raw string, want domusecase.TokenType, secret string) (*parsedClaims, error) {
	// keyFor hands the HMAC key to the library, but only for HS256 tokens.
	keyFor := func(t *jwt.Token) (any, error) {
		if t.Method == jwt.SigningMethodHS256 {
			return []byte(secret), nil
		}
		return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
	}

	tok, err := jwt.ParseWithClaims(raw, &jwtClaims{}, keyFor)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errInvalidToken, err)
	}

	body, ok := tok.Claims.(*jwtClaims)
	switch {
	case !ok, !tok.Valid:
		return nil, errInvalidToken
	case body.Typ != string(want):
		// Wrong token type, e.g. a refresh token where an access token is expected.
		return nil, errInvalidToken
	case body.TenantID == "", body.UID == "":
		return nil, errInvalidToken
	}

	var expiry time.Time
	if exp := body.ExpiresAt; exp != nil {
		expiry = exp.Time
	}

	result := &parsedClaims{
		TenantID:  body.TenantID,
		UID:       body.UID,
		AuthGen:   body.AuthGen,
		ID:        body.ID,
		expiresAt: expiry,
	}
	return result, nil
}
// remainingTTL returns how long is left until expiresAt, never less than one
// second. A zero expiresAt and an instant that is already past (or under a
// second away) both yield exactly one second.
func remainingTTL(expiresAt time.Time) time.Duration {
	const floor = time.Second

	if !expiresAt.IsZero() {
		if left := time.Until(expiresAt); left >= floor {
			return left
		}
	}
	return floor
}