// backend/pkg/permission/usecase/token.go
//
// Repository-viewer header: 710 lines, 20 KiB, Go (Raw / Normal View / History).
// Revision timestamp: 2025-10-06 08:28:39 +00:00
package usecase
import (
"context"
"fmt"
"strconv"
"time"
"backend/internal/config"
"backend/pkg/library/errs"
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/token"
"backend/pkg/permission/domain/usecase"
"github.com/segmentio/ksuid"
"github.com/zeromicro/go-zero/core/logx"
)
// TokenUseCaseParam bundles the dependencies a TokenUseCase is built from.
type TokenUseCaseParam struct {
// TokenRepo is the repository used in this file to create, look up and
// delete tokens, one-time tokens and blacklist entries.
TokenRepo repository.TokenRepository
// Config is the application configuration; this file reads its Token
// section (AccessSecret, AccessTokenExpiry, RefreshTokenExpiry).
Config *config.Config
}
// TokenUseCase implements the token use cases. It embeds TokenUseCaseParam,
// so TokenRepo and Config are reachable directly on the receiver.
type TokenUseCase struct {
TokenUseCaseParam
}
// ReadTokenBasicData parses token with the configured access secret and
// returns its claims as a string map. parseClaims is called with its third
// argument set to false. A parse failure is logged and returned wrapped with
// code.TokenValidateError.
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
	parsed, parseErr := parseClaims(token, use.Config.Token.AccessSecret, false)
	if parseErr == nil {
		return parsed, nil
	}

	failure := wrapTokenErrorReq{
		funcName:  "parseClaims",
		req:       token,
		err:       parseErr,
		message:   "validate token claims error",
		errorCode: code.TokenValidateError,
	}
	return nil, use.wrapTokenError(ctx, failure)
}
// MustTokenUseCase builds a TokenUseCase from param and returns it as the
// usecase.TokenUseCase interface. param is stored as given; nothing is
// validated here.
func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
	uc := new(TokenUseCase)
	uc.TokenUseCaseParam = param
	return uc
}
// ============================================ token ============================================
// NewToken generates a token for req, stores it through TokenRepo.Create and
// returns the bearer-token response.
//
// Errors from newToken are returned unchanged (newToken already wraps them);
// a storage failure is logged and wrapped with code.TokenCreateError. On any
// error the returned response is the zero value.
func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) {
	var empty entity.TokenResp

	created, err := use.newToken(ctx, &req)
	if err != nil {
		return empty, err
	}

	if createErr := use.TokenRepo.Create(ctx, *created); createErr != nil {
		return empty, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.Create",
			req:       req,
			err:       createErr,
			message:   "failed to create token",
			errorCode: code.TokenCreateError,
		})
	}

	resp := entity.TokenResp{
		AccessToken:  created.AccessToken,
		TokenType:    token.TypeBearer.String(),
		ExpiresIn:    int64(created.ExpiresIn),
		RefreshToken: created.RefreshToken,
	}
	return resp, nil
}
// newToken builds an in-memory entity.Token for req without persisting it.
//
// Expiry handling:
//   - req.Expires > 0 is used as-is for both ExpiresIn and RefreshExpiresIn.
//   - otherwise ExpiresIn is the Unix timestamp of now + AccessTokenExpiry.
//   - when req.IsRefreshToken is set, RefreshExpiresIn is always the Unix
//     timestamp of now + RefreshTokenExpiry.
//
// The claims start as a copy of req.Data and are then overwritten with the
// role, token ID, scope, account and (when non-empty) device ID from req. A
// refresh token is generated from the access token only when
// req.IsRefreshToken is set. A signing failure is logged and wrapped with
// code.TokenCreateError.
func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) {
	issuedAt := time.Now().UTC()

	// Access expiry: caller-supplied value, or now + configured lifetime.
	accessExp := req.Expires
	if accessExp <= 0 {
		accessExp = issuedAt.Add(use.Config.Token.AccessTokenExpiry).Unix()
	}

	// The refresh expiry follows the access expiry unless a refresh token is
	// requested, in which case it gets the (longer) refresh lifetime.
	refreshExp := accessExp
	if req.IsRefreshToken {
		refreshExp = issuedAt.Add(use.Config.Token.RefreshTokenExpiry).Unix()
	}

	issued := entity.Token{
		ID:               ksuid.New().String(),
		DeviceID:         req.DeviceID,
		ExpiresIn:        int(accessExp),
		RefreshExpiresIn: int(refreshExp),
		AccessCreateAt:   issuedAt,
		RefreshCreateAt:  issuedAt,
	}

	// Ranging over a nil map is a no-op, so no explicit nil check is needed.
	claims := make(tokenClaims)
	for key, value := range req.Data {
		claims[key] = value
	}
	claims.SetRole(req.Role)
	claims.SetID(issued.ID)
	claims.SetScope(req.Scope)
	claims.SetAccount(req.Account)

	// The UID is read back from the claims, i.e. it comes from req.Data or
	// from whatever the setters above stored.
	issued.UID = claims.UID()

	if req.DeviceID != "" {
		claims.SetDeviceID(req.DeviceID)
	}

	accessToken, signErr := accessTokenGenerator(issued, claims, use.Config.Token.AccessSecret)
	issued.AccessToken = accessToken
	if signErr != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "accessTokenGenerator",
			req:       req,
			err:       signErr,
			message:   "failed to generator access token",
			errorCode: code.TokenCreateError,
		})
	}

	if req.IsRefreshToken {
		issued.RefreshToken = refreshTokenGenerator(issued.AccessToken)
	}

	return &issued, nil
}
func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) {
// Step 1: 檢查 refresh token
tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
if err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.GetAccessTokenByOneTimeToken",
req: req,
err: err,
message: "failed to get access token",
errorCode: code.TokenValidateError,
})
}
// Step 2: 提取 Claims Data
claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
if err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "extractClaims",
req: req,
err: err,
message: "failed to extract claims",
errorCode: code.TokenValidateError,
})
}
// Step 3: 創建新 token
credentials := token.ClientCredentials
newToken, err := use.newToken(ctx, &entity.AuthorizationReq{
GrantType: credentials.ToString(),
2025-10-06 13:14:58 +00:00
Scope: claimsData.Scope(),
2025-10-06 08:28:39 +00:00
DeviceID: req.DeviceID,
Data: claimsData,
Expires: req.Expires,
IsRefreshToken: true,
2025-10-06 13:14:58 +00:00
Account: claimsData.Account(),
Role: claimsData.Role(),
2025-10-06 08:28:39 +00:00
})
if err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "use.newToken",
req: req,
err: err,
message: "failed to create new token",
errorCode: code.TokenValidateError,
})
}
if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.Create",
req: req,
err: err,
message: "failed to create new token",
errorCode: code.TokenValidateError,
})
}
// Step 4: 刪除舊 token 並創建新 token
if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.Delete",
req: req,
err: err,
message: "failed to delete old token",
errorCode: code.TokenValidateError,
})
}
// 返回新的 Token 響應
return entity.RefreshTokenResp{
Token: newToken.AccessToken,
OneTimeToken: newToken.RefreshToken,
ExpiresIn: int64(newToken.ExpiresIn),
TokenType: token.TypeBearer.String(),
}, nil
}
// CancelToken revokes the single token carried in req.Token.
//
// The token's claims are parsed (parseClaims is called with its third argument
// set to false), the stored record is fetched by the claims' ID and then
// deleted. Each failure is logged and returned wrapped with
// code.TokenValidateError.
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
	claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "CancelToken extractClaims",
			req:       req,
			err:       err,
			message:   "failed to get token claims",
			errorCode: code.TokenValidateError,
		})
	}

	// "stored" rather than "token" so the imported token package is not shadowed.
	stored, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName: "TokenRepo GetAccessTokenByID",
			req:      req,
			err:      err,
			// This is a repository lookup failure, not a claims failure; the
			// wording matches ValidationToken's message for the same call.
			message:   fmt.Sprintf("failed to get token :%s", claims.ID()),
			errorCode: code.TokenValidateError,
		})
	}

	if err := use.TokenRepo.Delete(ctx, stored); err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo Delete",
			req:       req,
			err:       err,
			message:   fmt.Sprintf("failed to delete token :%s", stored.ID),
			errorCode: code.TokenValidateError,
		})
	}

	return nil
}
// ValidationToken checks req.Token and returns the stored token record
// together with the parsed claims.
//
// parseClaims is called with its third argument set to true (the other callers
// in this file pass false). The record is then fetched by the claims' ID, so a
// token whose record is missing from the repository fails here. Both failures
// are logged and wrapped with code.TokenValidateError; the response is then
// the zero value.
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
	var none entity.ValidationTokenResp

	claims, parseErr := parseClaims(req.Token, use.Config.Token.AccessSecret, true)
	if parseErr != nil {
		return none, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "parseClaims",
			req:       req,
			err:       parseErr,
			message:   "validate token claims error",
			errorCode: code.TokenValidateError,
		})
	}

	tokenID := claims.ID()
	stored, lookupErr := use.TokenRepo.GetAccessTokenByID(ctx, tokenID)
	if lookupErr != nil {
		return none, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokenByID",
			req:       req,
			err:       lookupErr,
			message:   fmt.Sprintf("failed to get token :%s", tokenID),
			errorCode: code.TokenValidateError,
		})
	}

	// Copy the stored record field by field into the response.
	var resp entity.ValidationTokenResp
	resp.Token = entity.Token{
		ID:               stored.ID,
		UID:              stored.UID,
		DeviceID:         stored.DeviceID,
		AccessCreateAt:   stored.AccessCreateAt,
		AccessToken:      stored.AccessToken,
		ExpiresIn:        stored.ExpiresIn,
		RefreshToken:     stored.RefreshToken,
		RefreshExpiresIn: stored.RefreshExpiresIn,
		RefreshCreateAt:  stored.RefreshCreateAt,
	}
	resp.Data = claims
	return resp, nil
}
// CancelTokens deletes tokens in bulk: all tokens of req.UID when it is
// non-empty, and the tokens listed in req.IDs when that list is non-empty.
// Both deletions run when both are given; with neither, the call is a no-op
// that returns nil. A failure is logged and wrapped with
// code.TokenValidateError, and the ID deletion is skipped if the UID deletion
// fails.
func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error {
	if uid := req.UID; uid != "" {
		if delErr := use.TokenRepo.DeleteAccessTokensByUID(ctx, uid); delErr != nil {
			return use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.DeleteAccessTokensByUID",
				req:       req,
				err:       delErr,
				message:   "failed to cancel tokens by uid",
				errorCode: code.TokenValidateError,
			})
		}
	}

	if len(req.IDs) != 0 {
		if delErr := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs); delErr != nil {
			return use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.DeleteAccessTokenByID",
				req:       req,
				err:       delErr,
				message:   "failed to cancel tokens by token ids",
				errorCode: code.TokenValidateError,
			})
		}
	}

	return nil
}
// CancelTokenByDeviceID deletes the access tokens registered for
// req.DeviceID. A repository failure is logged and wrapped with
// code.TokenValidateError.
func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error {
	delErr := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID)
	if delErr == nil {
		return nil
	}

	return use.wrapTokenError(ctx, wrapTokenErrorReq{
		funcName:  "TokenRepo.DeleteAccessTokensByDeviceID",
		req:       req,
		err:       delErr,
		message:   "failed to cancel token by device id",
		errorCode: code.TokenValidateError,
	})
}
// GetUserTokensByDeviceID lists the tokens stored for req.DeviceID as
// bearer-token responses, in the order the repository returns them. The
// result is a non-nil, possibly empty slice. A repository failure is logged
// and wrapped with code.TokenNotFound.
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) {
	stored, lookupErr := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID)
	if lookupErr != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokensByDeviceID",
			req:       req,
			err:       lookupErr,
			message:   "failed to get token by device id",
			errorCode: code.TokenNotFound,
		})
	}

	result := make([]*entity.TokenResp, 0, len(stored))
	for _, item := range stored {
		resp := new(entity.TokenResp)
		resp.AccessToken = item.AccessToken
		resp.TokenType = token.TypeBearer.String()
		resp.ExpiresIn = int64(item.ExpiresIn)
		resp.RefreshToken = item.RefreshToken
		result = append(result, resp)
	}
	return result, nil
}
// GetUserTokensByUID lists the tokens stored for req.UID as bearer-token
// responses, in the order the repository returns them. The result is a
// non-nil, possibly empty slice. A repository failure is logged and wrapped
// with code.TokenNotFound.
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) {
	stored, lookupErr := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID)
	if lookupErr != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokensByUID",
			req:       req,
			err:       lookupErr,
			message:   "failed to get token by uid",
			errorCode: code.TokenNotFound,
		})
	}

	result := make([]*entity.TokenResp, 0, len(stored))
	for _, item := range stored {
		resp := new(entity.TokenResp)
		resp.AccessToken = item.AccessToken
		resp.TokenType = token.TypeBearer.String()
		resp.ExpiresIn = int64(item.ExpiresIn)
		resp.RefreshToken = item.RefreshToken
		result = append(result, resp)
	}
	return result, nil
}
// NewOneTimeToken issues a one-time token bound to the access token in
// req.Token.
//
// The access token's claims are parsed (parseClaims is called with its third
// argument set to false) and its stored record is fetched by the claims' ID.
// A new value is generated from a fresh KSUID, and a ticket holding the claims
// and the stored record is saved under token.TicketKeyPrefix + that value.
// time.Minute is passed as the final argument of CreateOneTimeToken —
// presumably the ticket's lifetime; confirm against the repository.
//
// Every failure is logged and wrapped with code.OneTimeTokenError; the
// response is then the zero value.
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
	var none entity.CreateOneTimeTokenResp

	// Read the claims of the presented access token.
	claims, parseErr := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
	if parseErr != nil {
		return none, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "parseClaims",
			req:       req,
			err:       parseErr,
			message:   "failed to get token claims",
			errorCode: code.OneTimeTokenError,
		})
	}

	// The access token must still have a stored record.
	stored, lookupErr := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
	if lookupErr != nil {
		return none, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokenByID",
			req:       req,
			err:       lookupErr,
			message:   "failed to get token by id",
			errorCode: code.OneTimeTokenError,
		})
	}

	oneTime := refreshTokenGenerator(ksuid.New().String())
	ticket := entity.Ticket{
		Data:  claims,
		Token: stored,
	}

	saveErr := use.TokenRepo.CreateOneTimeToken(ctx, token.TicketKeyPrefix+oneTime, ticket, time.Minute)
	if saveErr != nil {
		return none, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.CreateOneTimeToken",
			req:       req,
			err:       saveErr,
			message:   "create one time token error",
			errorCode: code.OneTimeTokenError,
		})
	}

	return entity.CreateOneTimeTokenResp{OneTimeToken: oneTime}, nil
}
// CancelOneTimeToken removes one-time tokens by passing req.Token and a nil
// second selector to TokenRepo.DeleteOneTimeToken. A repository failure is
// logged and wrapped with code.OneTimeTokenError.
//
// NOTE(review): NewOneTimeToken stores tickets under
// token.TicketKeyPrefix + value, while req.Token is forwarded here unchanged —
// confirm whether the repository adds the prefix itself.
func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error {
	delErr := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil)
	if delErr == nil {
		return nil
	}

	return use.wrapTokenError(ctx, wrapTokenErrorReq{
		funcName:  "TokenRepo.DeleteOneTimeToken",
		req:       req,
		err:       delErr,
		message:   "failed to del one time token by token",
		errorCode: code.OneTimeTokenError,
	})
}
// wrapTokenErrorReq carries everything wrapTokenError needs to log a failure
// and turn it into an errs error.
type wrapTokenErrorReq struct {
// funcName labels the call that failed; logged under the "func" key.
funcName string
// req is the request (or other input) being processed; logged under "req".
req any
// err is the underlying cause; logged under "err" and wrapped by the result.
err error
// message is used both as the log message and as the error message.
message string
// errorCode is the detail code passed to errs.NewError.
errorCode uint32
}
// wrapTokenError 將錯誤信息封裝到 errs.LibError 中
func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error {
logFields := []logx.LogField{
{Key: "req", Value: param.req},
{Key: "func", Value: param.funcName},
{Key: "err", Value: param.err.Error()},
}
2025-10-06 13:14:58 +00:00
2025-10-06 08:28:39 +00:00
logx.WithContext(ctx).Errorw(param.message, logFields...)
2025-10-06 13:14:58 +00:00
2025-10-06 08:28:39 +00:00
wrappedErr := errs.NewError(
code.CatToken,
code.CatToken,
param.errorCode,
param.message,
).Wrap(param.err)
return wrappedErr
}
// BlacklistToken 將 JWT token 加入黑名單 (立即撤銷)
func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error {
// 解析 JWT 獲取完整的 claims
claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.parseToken",
req: token,
err: err,
message: "failed to parse token claims",
errorCode: code.InvalidJWT,
})
}
// 獲取 JTI (JWT ID)
jti, exists := claimMap["jti"]
if !exists {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.getJTI",
req: token,
err: entity.ErrInvalidJTI,
message: "token missing JTI claim",
errorCode: code.InvalidJWT,
})
}
jtiStr, ok := jti.(string)
if !ok {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.convertJTI",
req: token,
err: entity.ErrInvalidJTI,
message: "JTI claim is not a string",
errorCode: code.InvalidJWT,
})
}
// 獲取 UID (可能在 data 中)
var uid string
if dataInterface, exists := claimMap["data"]; exists {
if dataMap, ok := dataInterface.(map[string]interface{}); ok {
if uidInterface, exists := dataMap["uid"]; exists {
uid, _ = uidInterface.(string)
}
}
}
// 獲取過期時間
exp, exists := claimMap["exp"]
if !exists {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.getExp",
req: token,
err: entity.ErrTokenExpired,
message: "token missing exp claim",
errorCode: code.TokenExpired,
})
}
// 將 exp 轉換為 int64 (JWT 中通常是 float64)
var expInt int64
switch v := exp.(type) {
case float64:
expInt = int64(v)
case int64:
expInt = v
case string:
parsedExp, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.parseExp",
req: token,
err: err,
message: "failed to parse exp claim",
errorCode: code.TokenExpired,
})
}
expInt = parsedExp
default:
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.convertExp",
req: token,
err: fmt.Errorf("exp claim is not a valid type: %T", exp),
message: "exp claim type conversion failed",
errorCode: code.TokenExpired,
})
}
// 創建黑名單條目
blacklistEntry := &entity.BlacklistEntry{
JTI: jtiStr,
UID: uid,
ExpiresAt: expInt,
CreatedAt: time.Now().Unix(),
}
// 添加到黑名單
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.AddToBlacklist",
req: jtiStr,
err: err,
message: "failed to add token to blacklist",
errorCode: code.TokenCreateError,
})
}
2025-10-06 13:14:58 +00:00
logx.WithContext(ctx).Infow("token blacklisted",
2025-10-06 08:28:39 +00:00
logx.Field("jti", jtiStr),
logx.Field("uid", uid),
logx.Field("reason", reason))
return nil
}
// IsTokenBlacklisted reports whether the JWT identified by jti is on the
// blacklist, as answered by TokenRepo.IsBlacklisted. A repository failure is
// logged and wrapped with code.TokenValidateError, and false is returned
// alongside it.
func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) {
	listed, lookupErr := use.TokenRepo.IsBlacklisted(ctx, jti)
	if lookupErr == nil {
		return listed, nil
	}

	return false, use.wrapTokenError(ctx, wrapTokenErrorReq{
		funcName:  "IsTokenBlacklisted",
		req:       jti,
		err:       lookupErr,
		message:   "failed to check blacklist status",
		errorCode: code.TokenValidateError,
	})
}
// BlacklistAllUserTokens blacklists every stored token of uid (log out on all
// devices) and then deletes the user's token records.
//
// Processing is best-effort per token: a token whose claims cannot be parsed,
// that lacks a "jti" or "exp" claim, whose exp is not a base-10 integer, or
// whose blacklist entry cannot be stored is logged and skipped so the
// remaining tokens are still handled.
//
// Returns an error (logged, wrapped with code.TokenValidateError) when:
//   - the user's tokens cannot be listed, or
//   - deleting the token records fails AND at least one token could not be
//     blacklisted. In that case some token is neither blacklisted nor removed,
//     so reporting success would be wrong.
//
// A deletion failure on its own is only logged, because every token is
// already on the blacklist. reason is only written to the summary log entry.
func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error {
	// Fetch all tokens of the user.
	tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "BlacklistAllUserTokens.GetAccessTokensByUID",
			req:       uid,
			err:       err,
			message:   "failed to get user tokens",
			errorCode: code.TokenValidateError,
		})
	}

	// blacklisted counts entries actually stored; skipped counts tokens that
	// were not blacklisted for any reason.
	var blacklisted, skipped int

	// Create a blacklist entry for each token.
	for _, tok := range tokens {
		// Parse the token to obtain its JTI and expiry.
		claims, err := parseClaims(tok.AccessToken, use.Config.Token.AccessSecret, false)
		if err != nil {
			skipped++
			logx.WithContext(ctx).Errorw("failed to parse token for blacklisting",
				logx.Field("uid", uid),
				logx.Field("tokenID", tok.ID),
				logx.Field("error", err))
			continue // skip the invalid token and keep processing the others
		}

		jti, exists := claims["jti"]
		if !exists || jti == "" {
			skipped++
			logx.WithContext(ctx).Errorw("token missing JTI claim",
				logx.Field("uid", uid),
				logx.Field("tokenID", tok.ID))
			continue
		}

		exp, exists := claims["exp"]
		if !exists {
			skipped++
			logx.WithContext(ctx).Errorw("token missing exp claim",
				logx.Field("uid", uid),
				logx.Field("tokenID", tok.ID))
			continue
		}

		// Convert the exp string to int64.
		expInt, err := strconv.ParseInt(exp, 10, 64)
		if err != nil {
			skipped++
			logx.WithContext(ctx).Errorw("failed to parse exp claim",
				logx.Field("uid", uid),
				logx.Field("tokenID", tok.ID),
				logx.Field("error", err))
			continue
		}

		// Build the blacklist entry.
		blacklistEntry := &entity.BlacklistEntry{
			JTI:       jti,
			UID:       uid,
			ExpiresAt: expInt,
			CreatedAt: time.Now().Unix(),
		}

		// Store it. A TTL of 0 asks the repository to use its default calculation.
		if err := use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0); err != nil {
			skipped++
			logx.WithContext(ctx).Errorw("failed to add token to blacklist",
				logx.Field("uid", uid),
				logx.Field("jti", jti),
				logx.Field("error", err))
			// Keep going: one failure must not stop the remaining tokens.
			continue
		}
		blacklisted++
	}

	// Delete all token records of the user.
	if err := use.TokenRepo.DeleteAccessTokensByUID(ctx, uid); err != nil {
		if skipped > 0 {
			// Some tokens are neither blacklisted nor deleted, so the logout
			// did not fully succeed. wrapTokenError logs the failure.
			return use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "BlacklistAllUserTokens.DeleteAccessTokensByUID",
				req:       uid,
				err:       err,
				message:   "failed to delete user tokens",
				errorCode: code.TokenValidateError,
			})
		}
		// Not fatal: every token has already been blacklisted.
		logx.WithContext(ctx).Errorw("failed to delete user tokens",
			logx.Field("uid", uid),
			logx.Field("error", err))
	}

	logx.WithContext(ctx).Infow("all user tokens blacklisted",
		logx.Field("uid", uid),
		logx.Field("tokenCount", len(tokens)),
		logx.Field("blacklistedCount", blacklisted),
		logx.Field("reason", reason))
	return nil
}