thread-master/backend/internal/model/auth/usecase/token.go

234 lines
7.4 KiB
Go

package usecase
import (
"context"
"errors"
"fmt"
"time"
"haixun-backend/internal/config"
app "haixun-backend/internal/library/errors"
"haixun-backend/internal/library/errors/code"
domrepo "haixun-backend/internal/model/auth/domain/repository"
domusecase "haixun-backend/internal/model/auth/domain/usecase"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
)
// tokenUseCase issues, rotates, revokes and validates HS256-signed JWT
// access/refresh token pairs.
type tokenUseCase struct {
	// cfg carries the signing secrets and token lifetimes. NewTokenUseCase
	// passes it through normalizeConfig before storing it.
	cfg config.AuthConf
	// revoke stores access/refresh JTI pairings and the revocation blacklist.
	// Every method checks it for nil and skips pairing/revocation work when
	// it is unset.
	revoke domrepo.TokenRevokeStore
}
// NewTokenUseCase builds a domusecase.TokenUseCase backed by HS256 JWTs.
//
// Non-positive lifetimes in cfg are replaced by normalizeConfig's defaults.
// A nil revoke disables pair tracking and revocation checks.
//
// It panics when AccessSecret or RefreshSecret is empty. Falling back to a
// known or hard-coded secret would let anyone forge tokens. Provide the
// secrets via env (HAIXUN_JWT_ACCESS_SECRET / HAIXUN_JWT_REFRESH_SECRET) or
// the dev config file.
func NewTokenUseCase(cfg config.AuthConf, revoke domrepo.TokenRevokeStore) domusecase.TokenUseCase {
	normalized := normalizeConfig(cfg)
	missingSecret := normalized.AccessSecret == "" || normalized.RefreshSecret == ""
	if missingSecret {
		panic("auth: AccessSecret and RefreshSecret must be configured")
	}
	uc := &tokenUseCase{revoke: revoke, cfg: normalized}
	return uc
}
// IssuePair mints a new access/refresh token pair for req.TenantID and
// req.UID. Both tokens carry req.AuthGen.
//
// When a revocation store is configured, the two JTIs are saved as a pair,
// together with each token's remaining lifetime. Refresh and Logout use that
// pairing to find one token from the other.
//
// Errors:
//   - InputMissingRequired if TenantID or UID is empty.
//   - SysInternal if signing fails.
//   - DBError if the pair cannot be saved.
func (u *tokenUseCase) IssuePair(ctx context.Context, req domusecase.IssuePairRequest) (*domusecase.TokenPair, error) {
	if req.TenantID == "" || req.UID == "" {
		return nil, app.For(code.Auth).InputMissingRequired("tenant_id and uid are required")
	}

	accessTok, signErr := u.sign(req, domusecase.TokenTypeAccess, u.cfg.AccessExpireSeconds, u.cfg.AccessSecret)
	if signErr != nil {
		return nil, app.For(code.Auth).SysInternal("sign access token failed").WithCause(signErr)
	}
	refreshTok, signErr := u.sign(req, domusecase.TokenTypeRefresh, u.cfg.RefreshExpireSeconds, u.cfg.RefreshSecret)
	if signErr != nil {
		return nil, app.For(code.Auth).SysInternal("sign refresh token failed").WithCause(signErr)
	}

	if u.revoke != nil {
		accessTTL := time.Until(accessTok.expiresAt)
		refreshTTL := time.Until(refreshTok.expiresAt)
		if saveErr := u.revoke.SavePair(ctx, accessTok.jti, refreshTok.jti, accessTTL, refreshTTL); saveErr != nil {
			return nil, app.For(code.Auth).DBError("save jwt pair failed").WithCause(saveErr)
		}
	}

	pair := &domusecase.TokenPair{
		TokenType:    "Bearer",
		UID:          req.UID,
		AccessToken:  accessTok.raw,
		RefreshToken: refreshTok.raw,
		ExpiresIn:    u.cfg.AccessExpireSeconds,
	}
	return pair, nil
}
// Refresh exchanges a valid, non-revoked refresh token for a brand-new token
// pair (rotation). The new pair inherits TenantID, UID and AuthGen from the
// presented refresh token.
//
// When a revocation store is configured:
//   - The presented refresh token is blacklisted for its remaining lifetime,
//     so it cannot be redeemed again.
//   - The access token it was paired with is revoked on a best-effort basis.
//
// Errors:
//   - InputMissingRequired for an empty token.
//   - AuthUnauthorized for an invalid or revoked token.
//   - DBError when the revocation store fails a required operation.
//   - Any error returned by IssuePair.
func (u *tokenUseCase) Refresh(ctx context.Context, refreshToken string) (*domusecase.TokenPair, error) {
	if refreshToken == "" {
		return nil, app.For(code.Auth).InputMissingRequired("refresh_token is required")
	}
	claims, err := u.parse(refreshToken, domusecase.TokenTypeRefresh, u.cfg.RefreshSecret)
	if err != nil {
		return nil, app.For(code.Auth).AuthUnauthorized("invalid refresh token").WithCause(err)
	}
	if err := u.ensureNotBlacklisted(ctx, claims.ID); err != nil {
		return nil, err
	}
	pair, err := u.IssuePair(ctx, domusecase.IssuePairRequest{
		TenantID: claims.TenantID,
		UID:      claims.UID,
		AuthGen:  claims.AuthGen,
	})
	if err != nil {
		return nil, err
	}
	if u.revoke == nil {
		return pair, nil
	}

	// Retiring the presented refresh token is what makes rotation single-use.
	// If it cannot be blacklisted, fail instead of handing out a new pair
	// while the old refresh token is still redeemable. The caller can retry
	// with the same refresh token, and the pair issued above is never
	// returned.
	//
	// NOTE(review): the ensureNotBlacklisted check and this Blacklist call are
	// not atomic, so two concurrent refreshes with the same token can both
	// succeed. Closing that window needs an atomic "blacklist if absent"
	// operation on TokenRevokeStore.
	if err := u.revoke.Blacklist(ctx, claims.ID, remainingTTL(claims.expiresAt)); err != nil {
		return nil, app.For(code.Auth).DBError("blacklist refresh token failed").WithCause(err)
	}

	// Revoking the previous access token is best-effort. If any of these
	// calls fail, that token simply stays usable until its own expiry.
	// The refresh token that could mint more tokens is already revoked.
	if accessJTI, err := u.revoke.GetPairedJTI(ctx, claims.ID); err == nil && accessJTI != "" {
		_ = u.revoke.Blacklist(ctx, accessJTI, time.Duration(u.cfg.AccessExpireSeconds)*time.Second)
		_ = u.revoke.DeletePair(ctx, accessJTI, claims.ID)
	}
	return pair, nil
}
// Logout revokes the presented access token and, if a pairing is recorded,
// the refresh token issued alongside it. It then deletes the pairing.
//
// When no revocation store is configured, Logout only checks that the token
// is non-empty and returns nil; the token is not parsed.
//
// Errors:
//   - InputMissingRequired for an empty token.
//   - AuthUnauthorized for an invalid access token.
//   - DBError when any revocation-store operation fails.
//
// A failed logout can be retried with the same access token. Logout does not
// reject already-blacklisted tokens, and the pairing is only deleted as the
// final step.
func (u *tokenUseCase) Logout(ctx context.Context, req domusecase.LogoutRequest) error {
	if req.AccessToken == "" {
		return app.For(code.Auth).InputMissingRequired("access token is required")
	}
	if u.revoke == nil {
		return nil
	}
	claims, err := u.parse(req.AccessToken, domusecase.TokenTypeAccess, u.cfg.AccessSecret)
	if err != nil {
		return app.For(code.Auth).AuthUnauthorized("invalid access token").WithCause(err)
	}
	if err := u.revoke.Blacklist(ctx, claims.ID, remainingTTL(claims.expiresAt)); err != nil {
		return app.For(code.Auth).DBError("blacklist access token failed").WithCause(err)
	}
	refreshJTI, err := u.revoke.GetPairedJTI(ctx, claims.ID)
	if err != nil {
		return app.For(code.Auth).DBError("read jwt pair failed").WithCause(err)
	}
	if refreshJTI != "" {
		// A logout that leaves the refresh token redeemable would let the
		// session continue, so this failure is reported rather than ignored.
		// The refresh token's exact expiry is not known here, so it is
		// blacklisted for the full configured refresh lifetime.
		if err := u.revoke.Blacklist(ctx, refreshJTI, time.Duration(u.cfg.RefreshExpireSeconds)*time.Second); err != nil {
			return app.For(code.Auth).DBError("blacklist refresh token failed").WithCause(err)
		}
	}
	if err := u.revoke.DeletePair(ctx, claims.ID, refreshJTI); err != nil {
		return app.For(code.Auth).DBError("delete jwt pair failed").WithCause(err)
	}
	return nil
}
// ParseAccessToken verifies accessToken as an access-type JWT and returns its
// claims. When a revocation store is configured, it also rejects tokens whose
// JTI is blacklisted.
//
// Errors:
//   - AuthUnauthorized for a missing, invalid or revoked token.
//   - DBError when the blacklist lookup fails.
func (u *tokenUseCase) ParseAccessToken(ctx context.Context, accessToken string) (*domusecase.AccessClaims, error) {
	if len(accessToken) == 0 {
		return nil, app.For(code.Auth).AuthUnauthorized("missing access token")
	}
	parsed, parseErr := u.parse(accessToken, domusecase.TokenTypeAccess, u.cfg.AccessSecret)
	if parseErr != nil {
		return nil, app.For(code.Auth).AuthUnauthorized("invalid access token").WithCause(parseErr)
	}
	if revokedErr := u.ensureNotBlacklisted(ctx, parsed.ID); revokedErr != nil {
		return nil, revokedErr
	}
	out := &domusecase.AccessClaims{
		JTI:      parsed.ID,
		TenantID: parsed.TenantID,
		UID:      parsed.UID,
		AuthGen:  parsed.AuthGen,
	}
	return out, nil
}
// ensureNotBlacklisted returns nil when the token identified by jti may still
// be used, and an error otherwise.
//
// The check is skipped (nil is returned) when no revocation store is
// configured or when jti is empty.
//
// Errors:
//   - DBError if the lookup fails.
//   - AuthUnauthorized("token revoked") if jti is blacklisted.
func (u *tokenUseCase) ensureNotBlacklisted(ctx context.Context, jti string) error {
	if u.revoke == nil {
		return nil
	}
	if jti == "" {
		return nil
	}
	revoked, lookupErr := u.revoke.IsBlacklisted(ctx, jti)
	switch {
	case lookupErr != nil:
		return app.For(code.Auth).DBError("check jwt blacklist failed").WithCause(lookupErr)
	case revoked:
		return app.For(code.Auth).AuthUnauthorized("token revoked")
	default:
		return nil
	}
}
// jwtClaims is the claim set signed into every token by sign and decoded by
// parse.
//
// NOTE(review): keep the field order stable. The payload is presumably
// marshalled with encoding/json, which emits fields in declaration order, so
// reordering would change the encoded token bytes.
type jwtClaims struct {
// TenantID and UID identify the subject. parse rejects tokens where either is empty.
TenantID string `json:"tenant_id"`
UID string `json:"uid"`
// Typ holds the domusecase.TokenType (access or refresh). parse rejects a
// token whose Typ differs from the type the caller expects.
Typ string `json:"typ"`
// AuthGen is copied from IssuePairRequest.AuthGen and surfaced unchanged on
// AccessClaims. Its meaning is defined by callers (not visible here).
AuthGen int64 `json:"auth_gen"`
// RegisteredClaims holds the standard claims; sign sets jti, iat and exp.
jwt.RegisteredClaims
}
type (
	// parsedClaims is the verified subset of a token's claims that parse
	// hands back to the use-case methods.
	parsedClaims struct {
		TenantID string
		UID      string
		AuthGen  int64
		// ID is the token's JTI, used as the revocation/blacklist key.
		ID string
		// expiresAt is the token's exp claim, or the zero time if the token
		// carried none.
		expiresAt time.Time
	}

	// signedToken is a freshly signed JWT plus the metadata IssuePair needs
	// to record it in the revocation store.
	signedToken struct {
		raw       string // compact serialized JWT
		jti       string // unique token ID embedded in the claims
		expiresAt time.Time
	}
)
// sign builds a JWT of kind typ for the subject in req and signs it with
// HS256, using secret as the HMAC key.
//
// The token gets:
//   - jti: a fresh uuid.NewString value.
//   - iat: the current UTC time.
//   - exp: iat plus expireSec seconds.
//
// The returned signedToken holds the compact token string, its jti and its
// expiry, so the caller can record the token in the revocation store.
func (u *tokenUseCase) sign(req domusecase.IssuePairRequest, typ domusecase.TokenType, expireSec int64, secret string) (*signedToken, error) {
	issuedAt := time.Now().UTC()
	lifetime := time.Duration(expireSec) * time.Second
	expiry := issuedAt.Add(lifetime)
	id := uuid.NewString()

	registered := jwt.RegisteredClaims{
		ID:        id,
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(expiry),
	}
	payload := jwtClaims{
		TenantID:         req.TenantID,
		UID:              req.UID,
		Typ:              string(typ),
		AuthGen:          req.AuthGen,
		RegisteredClaims: registered,
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
	raw, err := token.SignedString([]byte(secret))
	if err != nil {
		return nil, err
	}
	return &signedToken{raw: raw, jti: id, expiresAt: expiry}, nil
}
// parse verifies raw as an HS256 JWT signed with secret and returns its
// claims. Beyond the signature and the jwt library's standard time checks,
// it requires:
//
//   - The typ claim equals want, so a refresh token cannot pass as an access
//     token or vice versa.
//   - tenant_id and uid are non-empty.
//   - jti and exp are present.
//
// Every token minted by sign carries jti and exp. Requiring them here closes
// two gaps:
//
//   - ensureNotBlacklisted skips an empty jti, so a jti-less token could never
//     be revoked.
//   - golang-jwt v4 treats exp as optional, so an exp-less token would never
//     expire.
//
// Every failure satisfies errors.Is(err, errInvalidToken).
func (u *tokenUseCase) parse(raw string, want domusecase.TokenType, secret string) (*parsedClaims, error) {
	parsed, err := jwt.ParseWithClaims(raw, &jwtClaims{}, func(t *jwt.Token) (any, error) {
		// Pin the algorithm; never let the token header choose how it is verified.
		if t.Method != jwt.SigningMethodHS256 {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return []byte(secret), nil
	})
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errInvalidToken, err)
	}
	claims, ok := parsed.Claims.(*jwtClaims)
	if !ok || !parsed.Valid {
		return nil, errInvalidToken
	}
	if claims.Typ != string(want) || claims.TenantID == "" || claims.UID == "" {
		return nil, errInvalidToken
	}
	if claims.ID == "" || claims.ExpiresAt == nil {
		return nil, errInvalidToken
	}
	return &parsedClaims{
		TenantID:  claims.TenantID,
		UID:       claims.UID,
		AuthGen:   claims.AuthGen,
		ID:        claims.ID,
		expiresAt: claims.ExpiresAt.Time,
	}, nil
}
// normalizeConfig returns a copy of cfg with default token lifetimes filled
// in:
//
//   - A non-positive AccessExpireSeconds becomes 900 (15 minutes).
//   - A non-positive RefreshExpireSeconds becomes 2592000 (30 days).
//
// All other fields pass through unchanged.
func normalizeConfig(cfg config.AuthConf) config.AuthConf {
	const (
		defaultAccessTTLSeconds  = 15 * 60           // 900
		defaultRefreshTTLSeconds = 30 * 24 * 60 * 60 // 2592000
	)
	out := cfg
	if out.AccessExpireSeconds <= 0 {
		out.AccessExpireSeconds = defaultAccessTTLSeconds
	}
	if out.RefreshExpireSeconds <= 0 {
		out.RefreshExpireSeconds = defaultRefreshTTLSeconds
	}
	return out
}
func remainingTTL(expiresAt time.Time) time.Duration {
if expiresAt.IsZero() {
return time.Second
}
ttl := time.Until(expiresAt)
if ttl < time.Second {
return time.Second
}
return ttl
}
var (
	// errInvalidToken is the sentinel that parse returns, or wraps with the
	// underlying cause. Match it with errors.Is.
	errInvalidToken = errors.New("auth: invalid token")
)