710 lines
20 KiB
Go
Executable File
710 lines
20 KiB
Go
Executable File
package usecase
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strconv"
|
||
"time"
|
||
|
||
"backend/internal/config"
|
||
"backend/pkg/library/errs"
|
||
"backend/pkg/library/errs/code"
|
||
"backend/pkg/permission/domain/entity"
|
||
"backend/pkg/permission/domain/repository"
|
||
"backend/pkg/permission/domain/token"
|
||
"backend/pkg/permission/domain/usecase"
|
||
|
||
"github.com/segmentio/ksuid"
|
||
"github.com/zeromicro/go-zero/core/logx"
|
||
)
|
||
|
||
type TokenUseCaseParam struct {
|
||
TokenRepo repository.TokenRepository
|
||
|
||
Config *config.Config
|
||
}
|
||
|
||
type TokenUseCase struct {
|
||
TokenUseCaseParam
|
||
}
|
||
|
||
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
|
||
claims, err := parseClaims(token, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return nil,
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "parseClaims",
|
||
req: token,
|
||
err: err,
|
||
message: "validate token claims error",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
return claims, nil
|
||
}
|
||
|
||
func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
|
||
return &TokenUseCase{
|
||
param,
|
||
}
|
||
}
|
||
|
||
// ============================================ token ============================================
|
||
|
||
func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) {
|
||
tokenObj, err := use.newToken(ctx, &req)
|
||
if err != nil {
|
||
return entity.TokenResp{}, err
|
||
}
|
||
|
||
err = use.TokenRepo.Create(ctx, *tokenObj)
|
||
if err != nil {
|
||
return entity.TokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.Create",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to create token",
|
||
errorCode: code.TokenCreateError,
|
||
})
|
||
}
|
||
|
||
return entity.TokenResp{
|
||
AccessToken: tokenObj.AccessToken,
|
||
TokenType: token.TypeBearer.String(),
|
||
ExpiresIn: int64(tokenObj.ExpiresIn),
|
||
RefreshToken: tokenObj.RefreshToken,
|
||
}, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) {
|
||
// 準備建立 Token 所需
|
||
now := time.Now().UTC()
|
||
expires := req.Expires
|
||
refreshExpires := req.Expires
|
||
|
||
if expires <= 0 {
|
||
// 將時間加上 n 秒
|
||
sec := use.Config.Token.AccessTokenExpiry
|
||
// 獲取 Unix 時間戳
|
||
expires = now.Add(sec).Unix()
|
||
refreshExpires = expires
|
||
}
|
||
|
||
// 如果這是一個 Refresh Token 過期時間要比普通的Token 長
|
||
if req.IsRefreshToken {
|
||
// 獲取 Unix 時間戳
|
||
refresh := use.Config.Token.RefreshTokenExpiry
|
||
refreshExpires = now.Add(refresh).Unix()
|
||
}
|
||
|
||
token := entity.Token{
|
||
ID: ksuid.New().String(),
|
||
DeviceID: req.DeviceID,
|
||
ExpiresIn: int(expires),
|
||
RefreshExpiresIn: int(refreshExpires),
|
||
AccessCreateAt: now,
|
||
RefreshCreateAt: now,
|
||
}
|
||
|
||
tc := make(tokenClaims)
|
||
if req.Data != nil {
|
||
for k, v := range req.Data {
|
||
tc[k] = v
|
||
}
|
||
}
|
||
tc.SetRole(req.Role)
|
||
tc.SetID(token.ID)
|
||
tc.SetScope(req.Scope)
|
||
tc.SetAccount(req.Account)
|
||
|
||
token.UID = tc.UID()
|
||
|
||
if req.DeviceID != "" {
|
||
tc.SetDeviceID(req.DeviceID)
|
||
}
|
||
|
||
var err error
|
||
token.AccessToken, err = accessTokenGenerator(token, tc, use.Config.Token.AccessSecret)
|
||
if err != nil {
|
||
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "accessTokenGenerator",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to generator access token",
|
||
errorCode: code.TokenCreateError,
|
||
})
|
||
}
|
||
|
||
if req.IsRefreshToken {
|
||
token.RefreshToken = refreshTokenGenerator(token.AccessToken)
|
||
}
|
||
|
||
return &token, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) {
|
||
// Step 1: 檢查 refresh token
|
||
tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
|
||
if err != nil {
|
||
return entity.RefreshTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.GetAccessTokenByOneTimeToken",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to get access token",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
// Step 2: 提取 Claims Data
|
||
claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return entity.RefreshTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "extractClaims",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to extract claims",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
// Step 3: 創建新 token
|
||
credentials := token.ClientCredentials
|
||
newToken, err := use.newToken(ctx, &entity.AuthorizationReq{
|
||
GrantType: credentials.ToString(),
|
||
Scope: claimsData.Scope(),
|
||
DeviceID: req.DeviceID,
|
||
Data: claimsData,
|
||
Expires: req.Expires,
|
||
IsRefreshToken: true,
|
||
Account: claimsData.Account(),
|
||
Role: claimsData.Role(),
|
||
})
|
||
if err != nil {
|
||
return entity.RefreshTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "use.newToken",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to create new token",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
|
||
return entity.RefreshTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.Create",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to create new token",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
// Step 4: 刪除舊 token 並創建新 token
|
||
if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil {
|
||
return entity.RefreshTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.Delete",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to delete old token",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
// 返回新的 Token 響應
|
||
return entity.RefreshTokenResp{
|
||
Token: newToken.AccessToken,
|
||
OneTimeToken: newToken.RefreshToken,
|
||
ExpiresIn: int64(newToken.ExpiresIn),
|
||
TokenType: token.TypeBearer.String(),
|
||
}, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
|
||
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "CancelToken extractClaims",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to get token claims",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo GetAccessTokenByID",
|
||
req: req,
|
||
err: err,
|
||
message: fmt.Sprintf("failed to get token claims :%s", claims.ID()),
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
err = use.TokenRepo.Delete(ctx, token)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo Delete",
|
||
req: req,
|
||
err: err,
|
||
message: fmt.Sprintf("failed to delete token :%s", token.ID),
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
|
||
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, true)
|
||
if err != nil {
|
||
return entity.ValidationTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "parseClaims",
|
||
req: req,
|
||
err: err,
|
||
message: "validate token claims error",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||
if err != nil {
|
||
return entity.ValidationTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.GetAccessTokenByID",
|
||
req: req,
|
||
err: err,
|
||
message: fmt.Sprintf("failed to get token :%s", claims.ID()),
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
return entity.ValidationTokenResp{
|
||
Token: entity.Token{
|
||
ID: token.ID,
|
||
UID: token.UID,
|
||
DeviceID: token.DeviceID,
|
||
AccessCreateAt: token.AccessCreateAt,
|
||
AccessToken: token.AccessToken,
|
||
ExpiresIn: token.ExpiresIn,
|
||
RefreshToken: token.RefreshToken,
|
||
RefreshExpiresIn: token.RefreshExpiresIn,
|
||
RefreshCreateAt: token.RefreshCreateAt,
|
||
},
|
||
Data: claims,
|
||
}, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error {
|
||
if req.UID != "" {
|
||
err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.DeleteAccessTokensByUID",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to cancel tokens by uid",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
}
|
||
|
||
if len(req.IDs) > 0 {
|
||
err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.DeleteAccessTokenByID",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to cancel tokens by token ids",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error {
|
||
err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.DeleteAccessTokensByDeviceID",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to cancel token by device id",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) {
|
||
uidTokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID)
|
||
if err != nil {
|
||
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.GetAccessTokensByDeviceID",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to get token by device id",
|
||
errorCode: code.TokenNotFound,
|
||
})
|
||
}
|
||
|
||
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
|
||
for _, v := range uidTokens {
|
||
tokens = append(tokens, &entity.TokenResp{
|
||
AccessToken: v.AccessToken,
|
||
TokenType: token.TypeBearer.String(),
|
||
ExpiresIn: int64(v.ExpiresIn),
|
||
RefreshToken: v.RefreshToken,
|
||
})
|
||
}
|
||
|
||
return tokens, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) {
|
||
uidTokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID)
|
||
if err != nil {
|
||
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.GetAccessTokensByUID",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to get token by uid",
|
||
errorCode: code.TokenNotFound,
|
||
})
|
||
}
|
||
|
||
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
|
||
for _, v := range uidTokens {
|
||
tokens = append(tokens, &entity.TokenResp{
|
||
AccessToken: v.AccessToken,
|
||
TokenType: token.TypeBearer.String(),
|
||
ExpiresIn: int64(v.ExpiresIn),
|
||
RefreshToken: v.RefreshToken,
|
||
})
|
||
}
|
||
|
||
return tokens, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
|
||
// 驗證Token
|
||
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return entity.CreateOneTimeTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "parseClaims",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to get token claims",
|
||
errorCode: code.OneTimeTokenError,
|
||
})
|
||
}
|
||
|
||
tokenObj, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||
if err != nil {
|
||
return entity.CreateOneTimeTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.GetAccessTokenByID",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to get token by id",
|
||
errorCode: code.OneTimeTokenError,
|
||
})
|
||
}
|
||
|
||
oneTimeToken := refreshTokenGenerator(ksuid.New().String())
|
||
key := token.TicketKeyPrefix + oneTimeToken
|
||
if err = use.TokenRepo.CreateOneTimeToken(ctx, key, entity.Ticket{
|
||
Data: claims,
|
||
Token: tokenObj,
|
||
}, time.Minute); err != nil {
|
||
return entity.CreateOneTimeTokenResp{},
|
||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.CreateOneTimeToken",
|
||
req: req,
|
||
err: err,
|
||
message: "create one time token error",
|
||
errorCode: code.OneTimeTokenError,
|
||
})
|
||
}
|
||
|
||
return entity.CreateOneTimeTokenResp{
|
||
OneTimeToken: oneTimeToken,
|
||
}, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error {
|
||
err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "TokenRepo.DeleteOneTimeToken",
|
||
req: req,
|
||
err: err,
|
||
message: "failed to del one time token by token",
|
||
errorCode: code.OneTimeTokenError,
|
||
})
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// wrapTokenErrorReq carries everything wrapTokenError needs to log a failed token
// operation and build the returned error.
type wrapTokenErrorReq struct {
	funcName  string // name of the failing call; logged under "func"
	req       any    // input that triggered the failure; logged verbatim under "req"
	err       error  // underlying cause; its text is logged under "err" and it is wrapped
	message   string // used as both the log message and the error message
	errorCode uint32 // code.* value passed to errs.NewError
}
|
||
|
||
// wrapTokenError 將錯誤信息封裝到 errs.LibError 中
|
||
func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error {
|
||
logFields := []logx.LogField{
|
||
{Key: "req", Value: param.req},
|
||
{Key: "func", Value: param.funcName},
|
||
{Key: "err", Value: param.err.Error()},
|
||
}
|
||
|
||
logx.WithContext(ctx).Errorw(param.message, logFields...)
|
||
|
||
wrappedErr := errs.NewError(
|
||
code.CatToken,
|
||
code.CatToken,
|
||
param.errorCode,
|
||
param.message,
|
||
).Wrap(param.err)
|
||
|
||
return wrappedErr
|
||
}
|
||
|
||
// BlacklistToken 將 JWT token 加入黑名單 (立即撤銷)
|
||
func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error {
|
||
// 解析 JWT 獲取完整的 claims
|
||
claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "BlacklistToken.parseToken",
|
||
req: token,
|
||
err: err,
|
||
message: "failed to parse token claims",
|
||
errorCode: code.InvalidJWT,
|
||
})
|
||
}
|
||
|
||
// 獲取 JTI (JWT ID)
|
||
jti, exists := claimMap["jti"]
|
||
if !exists {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "BlacklistToken.getJTI",
|
||
req: token,
|
||
err: entity.ErrInvalidJTI,
|
||
message: "token missing JTI claim",
|
||
errorCode: code.InvalidJWT,
|
||
})
|
||
}
|
||
|
||
jtiStr, ok := jti.(string)
|
||
if !ok {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "BlacklistToken.convertJTI",
|
||
req: token,
|
||
err: entity.ErrInvalidJTI,
|
||
message: "JTI claim is not a string",
|
||
errorCode: code.InvalidJWT,
|
||
})
|
||
}
|
||
|
||
// 獲取 UID (可能在 data 中)
|
||
var uid string
|
||
if dataInterface, exists := claimMap["data"]; exists {
|
||
if dataMap, ok := dataInterface.(map[string]interface{}); ok {
|
||
if uidInterface, exists := dataMap["uid"]; exists {
|
||
uid, _ = uidInterface.(string)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 獲取過期時間
|
||
exp, exists := claimMap["exp"]
|
||
if !exists {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "BlacklistToken.getExp",
|
||
req: token,
|
||
err: entity.ErrTokenExpired,
|
||
message: "token missing exp claim",
|
||
errorCode: code.TokenExpired,
|
||
})
|
||
}
|
||
|
||
// 將 exp 轉換為 int64 (JWT 中通常是 float64)
|
||
var expInt int64
|
||
switch v := exp.(type) {
|
||
case float64:
|
||
expInt = int64(v)
|
||
case int64:
|
||
expInt = v
|
||
case string:
|
||
parsedExp, err := strconv.ParseInt(v, 10, 64)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "BlacklistToken.parseExp",
|
||
req: token,
|
||
err: err,
|
||
message: "failed to parse exp claim",
|
||
errorCode: code.TokenExpired,
|
||
})
|
||
}
|
||
expInt = parsedExp
|
||
default:
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "BlacklistToken.convertExp",
|
||
req: token,
|
||
err: fmt.Errorf("exp claim is not a valid type: %T", exp),
|
||
message: "exp claim type conversion failed",
|
||
errorCode: code.TokenExpired,
|
||
})
|
||
}
|
||
|
||
// 創建黑名單條目
|
||
blacklistEntry := &entity.BlacklistEntry{
|
||
JTI: jtiStr,
|
||
UID: uid,
|
||
ExpiresAt: expInt,
|
||
CreatedAt: time.Now().Unix(),
|
||
}
|
||
|
||
// 添加到黑名單
|
||
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "BlacklistToken.AddToBlacklist",
|
||
req: jtiStr,
|
||
err: err,
|
||
message: "failed to add token to blacklist",
|
||
errorCode: code.TokenCreateError,
|
||
})
|
||
}
|
||
|
||
logx.WithContext(ctx).Infow("token blacklisted",
|
||
logx.Field("jti", jtiStr),
|
||
logx.Field("uid", uid),
|
||
logx.Field("reason", reason))
|
||
|
||
return nil
|
||
}
|
||
|
||
// IsTokenBlacklisted 檢查 JWT token 是否在黑名單中
|
||
func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) {
|
||
isBlacklisted, err := use.TokenRepo.IsBlacklisted(ctx, jti)
|
||
if err != nil {
|
||
return false, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "IsTokenBlacklisted",
|
||
req: jti,
|
||
err: err,
|
||
message: "failed to check blacklist status",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
return isBlacklisted, nil
|
||
}
|
||
|
||
// BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出)
|
||
func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error {
|
||
// 獲取用戶的所有 token
|
||
tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
|
||
if err != nil {
|
||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
funcName: "BlacklistAllUserTokens.GetAccessTokensByUID",
|
||
req: uid,
|
||
err: err,
|
||
message: "failed to get user tokens",
|
||
errorCode: code.TokenValidateError,
|
||
})
|
||
}
|
||
|
||
// 為每個 token 創建黑名單條目
|
||
for _, token := range tokens {
|
||
// 解析 token 獲取 JTI 和過期時間
|
||
claims, err := parseClaims(token.AccessToken, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
logx.WithContext(ctx).Errorw("failed to parse token for blacklisting",
|
||
logx.Field("uid", uid),
|
||
logx.Field("tokenID", token.ID),
|
||
logx.Field("error", err))
|
||
continue // 跳過無效的 token,繼續處理其他 token
|
||
}
|
||
|
||
jti, exists := claims["jti"]
|
||
if !exists || jti == "" {
|
||
logx.WithContext(ctx).Errorw("token missing JTI claim",
|
||
logx.Field("uid", uid),
|
||
logx.Field("tokenID", token.ID))
|
||
continue
|
||
}
|
||
|
||
exp, exists := claims["exp"]
|
||
if !exists {
|
||
logx.WithContext(ctx).Errorw("token missing exp claim",
|
||
logx.Field("uid", uid),
|
||
logx.Field("tokenID", token.ID))
|
||
continue
|
||
}
|
||
|
||
// 將 exp 字符串轉換為 int64
|
||
expInt, err := strconv.ParseInt(exp, 10, 64)
|
||
if err != nil {
|
||
logx.WithContext(ctx).Errorw("failed to parse exp claim",
|
||
logx.Field("uid", uid),
|
||
logx.Field("tokenID", token.ID),
|
||
logx.Field("error", err))
|
||
continue
|
||
}
|
||
|
||
// 創建黑名單條目
|
||
blacklistEntry := &entity.BlacklistEntry{
|
||
JTI: jti,
|
||
UID: uid,
|
||
ExpiresAt: expInt,
|
||
CreatedAt: time.Now().Unix(),
|
||
}
|
||
|
||
// 添加到黑名單
|
||
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
|
||
if err != nil {
|
||
logx.WithContext(ctx).Errorw("failed to add token to blacklist",
|
||
logx.Field("uid", uid),
|
||
logx.Field("jti", jti),
|
||
logx.Field("error", err))
|
||
// 繼續處理其他 token,不要因為一個失敗就停止
|
||
}
|
||
}
|
||
|
||
// 刪除用戶的所有 token 記錄
|
||
err = use.TokenRepo.DeleteAccessTokensByUID(ctx, uid)
|
||
if err != nil {
|
||
logx.WithContext(ctx).Errorw("failed to delete user tokens",
|
||
logx.Field("uid", uid),
|
||
logx.Field("error", err))
|
||
// 這不是致命錯誤,因為 token 已經被加入黑名單
|
||
}
|
||
|
||
logx.WithContext(ctx).Infow("all user tokens blacklisted",
|
||
logx.Field("uid", uid),
|
||
logx.Field("tokenCount", len(tokens)),
|
||
logx.Field("reason", reason))
|
||
|
||
return nil
|
||
}
|