backend/pkg/permission/usecase/token.go

710 lines
20 KiB
Go
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package usecase
import (
"context"
"fmt"
"strconv"
"time"
"backend/internal/config"
"backend/pkg/library/errs"
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/token"
"backend/pkg/permission/domain/usecase"
"github.com/segmentio/ksuid"
"github.com/zeromicro/go-zero/core/logx"
)
type (
	// TokenUseCaseParam bundles the dependencies needed to build a TokenUseCase.
	TokenUseCaseParam struct {
		// TokenRepo stores token records, one-time tokens and blacklist entries.
		TokenRepo repository.TokenRepository
		// Config supplies the token settings read via Config.Token
		// (AccessSecret, AccessTokenExpiry, RefreshTokenExpiry).
		Config *config.Config
	}

	// TokenUseCase is the implementation returned by MustTokenUseCase as a
	// usecase.TokenUseCase; its dependencies are promoted from the embedded
	// TokenUseCaseParam.
	TokenUseCase struct {
		TokenUseCaseParam
	}
)
// ReadTokenBasicData decodes the claims of an access token signed with the
// configured access secret and returns them as a plain string map.
//
// Unlike ValidationToken (which passes true), this passes false as the third
// argument to parseClaims — presumably skipping the stricter validation;
// confirm against parseClaims. A parse failure is logged and returned as a
// code.TokenValidateError.
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
	secret := use.Config.Token.AccessSecret

	claims, parseErr := parseClaims(token, secret, false)
	if parseErr == nil {
		return claims, nil
	}

	return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
		funcName:  "parseClaims",
		req:       token,
		err:       parseErr,
		message:   "validate token claims error",
		errorCode: code.TokenValidateError,
	})
}
// MustTokenUseCase wraps param in a TokenUseCase and returns it behind the
// usecase.TokenUseCase interface.
//
// Despite the conventional "Must" prefix it never panics and performs no
// validation of param (e.g. a nil Config is accepted here and would only fail
// later when a method dereferences it).
func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
	uc := &TokenUseCase{TokenUseCaseParam: param}
	return uc
}
// ============================================ token ============================================
// NewToken issues a token for req, persists it with TokenRepo.Create and
// returns it as a bearer TokenResp.
//
// A refresh token is only present when req.IsRefreshToken is set. ExpiresIn
// is the expiry value newToken stored on the token (an absolute Unix
// timestamp when it is derived from the config), not a relative duration.
// A persistence failure is logged and returned as code.TokenCreateError.
func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) {
	issued, err := use.newToken(ctx, &req)
	if err != nil {
		// newToken has already logged and wrapped this error.
		return entity.TokenResp{}, err
	}

	if err = use.TokenRepo.Create(ctx, *issued); err != nil {
		return entity.TokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.Create",
			req:       req,
			err:       err,
			message:   "failed to create token",
			errorCode: code.TokenCreateError,
		})
	}

	resp := entity.TokenResp{TokenType: token.TypeBearer.String()}
	resp.AccessToken = issued.AccessToken
	resp.ExpiresIn = int64(issued.ExpiresIn)
	resp.RefreshToken = issued.RefreshToken
	return resp, nil
}
// newToken builds (but does not persist) an entity.Token for req. It assigns
// a fresh KSUID, computes the expiry timestamps, assembles the JWT claims and
// signs the access token. A refresh token is derived from the signed access
// token only when req.IsRefreshToken is true.
//
// Expiry rules (all values are Unix seconds):
//   - req.Expires > 0 is used verbatim as both the access and refresh expiry;
//   - otherwise both default to now (UTC) + Config.Token.AccessTokenExpiry;
//   - when req.IsRefreshToken is set, the refresh expiry is replaced by
//     now + Config.Token.RefreshTokenExpiry so it can outlive the access token.
//
// A signing failure is logged and returned as code.TokenCreateError.
func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) {
	now := time.Now().UTC()

	accessExp := req.Expires
	if accessExp <= 0 {
		accessExp = now.Add(use.Config.Token.AccessTokenExpiry).Unix()
	}
	refreshExp := accessExp
	if req.IsRefreshToken {
		refreshExp = now.Add(use.Config.Token.RefreshTokenExpiry).Unix()
	}

	tok := entity.Token{
		ID:               ksuid.New().String(),
		DeviceID:         req.DeviceID,
		ExpiresIn:        int(accessExp),
		RefreshExpiresIn: int(refreshExp),
		AccessCreateAt:   now,
		RefreshCreateAt:  now,
	}

	// Caller-supplied data is copied first so the Set* calls below (presumably
	// writing reserved keys into the same map) take precedence over it.
	// Ranging over a nil req.Data is a no-op.
	claims := make(tokenClaims, len(req.Data))
	for key, val := range req.Data {
		claims[key] = val
	}
	claims.SetRole(req.Role)
	claims.SetID(tok.ID)
	claims.SetScope(req.Scope)
	claims.SetAccount(req.Account)
	tok.UID = claims.UID()
	if req.DeviceID != "" {
		claims.SetDeviceID(req.DeviceID)
	}

	signed, err := accessTokenGenerator(tok, claims, use.Config.Token.AccessSecret)
	if err != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "accessTokenGenerator",
			req:       req,
			err:       err,
			message:   "failed to generator access token",
			errorCode: code.TokenCreateError,
		})
	}
	tok.AccessToken = signed

	if req.IsRefreshToken {
		tok.RefreshToken = refreshTokenGenerator(tok.AccessToken)
	}
	return &tok, nil
}
// RefreshToken exchanges a refresh token for a brand-new token.
//
// Flow:
//  1. load the stored token record referenced by req.Token
//     (TokenRepo.GetAccessTokenByOneTimeToken);
//  2. decode the claims of that record's access token;
//  3. issue a new token carrying the same scope, account, role and custom
//     claim data (with IsRefreshToken set) and persist it;
//  4. delete the old token record.
//
// Lookup/claim failures and old-token deletion failures are returned as
// code.TokenValidateError. A failure to persist the new token is returned as
// code.TokenCreateError, matching NewToken. Errors from newToken are returned
// unchanged because newToken already logs and wraps them (re-wrapping would
// log twice and overwrite its error code).
//
// NOTE(review): if step 4 fails, the new token has already been persisted
// but is not returned to the caller; it stays in the store.
func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) {
	// Step 1: resolve the refresh token to the stored token record.
	tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
	if err != nil {
		return entity.RefreshTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.GetAccessTokenByOneTimeToken",
				req:       req,
				err:       err,
				message:   "failed to get access token",
				errorCode: code.TokenValidateError,
			})
	}

	// Step 2: extract the claims carried by the old access token.
	claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
	if err != nil {
		return entity.RefreshTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "extractClaims",
				req:       req,
				err:       err,
				message:   "failed to extract claims",
				errorCode: code.TokenValidateError,
			})
	}

	// Step 3: issue a replacement token with the same claims and persist it.
	credentials := token.ClientCredentials
	issued, err := use.newToken(ctx, &entity.AuthorizationReq{
		GrantType:      credentials.ToString(),
		Scope:          claimsData.Scope(),
		DeviceID:       req.DeviceID,
		Data:           claimsData,
		Expires:        req.Expires,
		IsRefreshToken: true,
		Account:        claimsData.Account(),
		Role:           claimsData.Role(),
	})
	if err != nil {
		// Already logged and wrapped (code.TokenCreateError) by newToken.
		return entity.RefreshTokenResp{}, err
	}
	if err := use.TokenRepo.Create(ctx, *issued); err != nil {
		return entity.RefreshTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.Create",
				req:       req,
				err:       err,
				message:   "failed to create new token",
				errorCode: code.TokenCreateError,
			})
	}

	// Step 4: delete the old token record so the refresh token cannot be replayed.
	if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil {
		return entity.RefreshTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.Delete",
				req:       req,
				err:       err,
				message:   "failed to delete old token",
				errorCode: code.TokenValidateError,
			})
	}

	return entity.RefreshTokenResp{
		Token:        issued.AccessToken,
		OneTimeToken: issued.RefreshToken,
		ExpiresIn:    int64(issued.ExpiresIn),
		TokenType:    token.TypeBearer.String(),
	}, nil
}
// CancelToken revokes one access token by deleting its stored record.
//
// It decodes req.Token (third parseClaims argument false), loads the stored
// record whose ID matches the claim ID and deletes that record. Each failure
// is logged and returned as code.TokenValidateError.
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
	claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "CancelToken extractClaims",
			req:       req,
			err:       err,
			message:   "failed to get token claims",
			errorCode: code.TokenValidateError,
		})
	}

	tokenID := claims.ID()
	stored, err := use.TokenRepo.GetAccessTokenByID(ctx, tokenID)
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo GetAccessTokenByID",
			req:       req,
			err:       err,
			message:   fmt.Sprintf("failed to get token claims :%s", tokenID),
			errorCode: code.TokenValidateError,
		})
	}

	if err := use.TokenRepo.Delete(ctx, stored); err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo Delete",
			req:       req,
			err:       err,
			message:   fmt.Sprintf("failed to delete token :%s", stored.ID),
			errorCode: code.TokenValidateError,
		})
	}
	return nil
}
// ValidationToken checks an access token and returns its stored record
// together with its decoded claims.
//
// The token is parsed with parseClaims' third argument set to true (the
// stricter mode used only here). Its claim ID must then resolve to a stored
// record. Any failure is logged and returned as code.TokenValidateError.
//
// NOTE(review): this does not consult the blacklist (IsTokenBlacklisted);
// confirm that revocation via BlacklistToken is enforced by the caller.
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
	claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, true)
	if err != nil {
		return entity.ValidationTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "parseClaims",
				req:       req,
				err:       err,
				message:   "validate token claims error",
				errorCode: code.TokenValidateError,
			})
	}

	tokenID := claims.ID()
	stored, err := use.TokenRepo.GetAccessTokenByID(ctx, tokenID)
	if err != nil {
		return entity.ValidationTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.GetAccessTokenByID",
				req:       req,
				err:       err,
				message:   fmt.Sprintf("failed to get token :%s", tokenID),
				errorCode: code.TokenValidateError,
			})
	}

	// Copy the stored record field by field into a fresh entity.Token.
	var snapshot entity.Token
	snapshot.ID = stored.ID
	snapshot.UID = stored.UID
	snapshot.DeviceID = stored.DeviceID
	snapshot.AccessCreateAt = stored.AccessCreateAt
	snapshot.AccessToken = stored.AccessToken
	snapshot.ExpiresIn = stored.ExpiresIn
	snapshot.RefreshToken = stored.RefreshToken
	snapshot.RefreshExpiresIn = stored.RefreshExpiresIn
	snapshot.RefreshCreateAt = stored.RefreshCreateAt

	return entity.ValidationTokenResp{Token: snapshot, Data: claims}, nil
}
// CancelTokens deletes access tokens in bulk.
//
// When req.UID is non-empty, all of that user's tokens are deleted. When
// req.IDs is non-empty, those token IDs are deleted as well (after the UID
// pass). An empty request is a no-op. The first repository error stops the
// call and is returned as code.TokenValidateError.
func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error {
	if req.UID != "" {
		if err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID); err != nil {
			return use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.DeleteAccessTokensByUID",
				req:       req,
				err:       err,
				message:   "failed to cancel tokens by uid",
				errorCode: code.TokenValidateError,
			})
		}
	}

	if len(req.IDs) == 0 {
		return nil
	}

	if err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs); err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.DeleteAccessTokenByID",
			req:       req,
			err:       err,
			message:   "failed to cancel tokens by token ids",
			errorCode: code.TokenValidateError,
		})
	}
	return nil
}
// CancelTokenByDeviceID deletes every access token registered for
// req.DeviceID. A repository error is logged and returned as
// code.TokenValidateError.
func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error {
	if err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID); err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.DeleteAccessTokensByDeviceID",
			req:       req,
			err:       err,
			message:   "failed to cancel token by device id",
			errorCode: code.TokenValidateError,
		})
	}
	return nil
}
// GetUserTokensByDeviceID lists the stored tokens of req.DeviceID as bearer
// TokenResp values, in repository order.
//
// ExpiresIn is copied from the stored record. When the record was created by
// newToken with a config-derived expiry, that is an absolute Unix timestamp.
// A repository error is logged and returned as code.TokenNotFound. On success
// the result is never nil (it is empty when no tokens exist).
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) {
	stored, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID)
	if err != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokensByDeviceID",
			req:       req,
			err:       err,
			message:   "failed to get token by device id",
			errorCode: code.TokenNotFound,
		})
	}

	resps := make([]*entity.TokenResp, len(stored))
	for i, rec := range stored {
		resps[i] = &entity.TokenResp{
			AccessToken:  rec.AccessToken,
			TokenType:    token.TypeBearer.String(),
			ExpiresIn:    int64(rec.ExpiresIn),
			RefreshToken: rec.RefreshToken,
		}
	}
	return resps, nil
}
// GetUserTokensByUID lists the stored tokens of user req.UID as bearer
// TokenResp values, in repository order.
//
// ExpiresIn is copied from the stored record. When the record was created by
// newToken with a config-derived expiry, that is an absolute Unix timestamp.
// A repository error is logged and returned as code.TokenNotFound. On success
// the result is never nil (it is empty when no tokens exist).
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) {
	stored, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID)
	if err != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokensByUID",
			req:       req,
			err:       err,
			message:   "failed to get token by uid",
			errorCode: code.TokenNotFound,
		})
	}

	resps := make([]*entity.TokenResp, len(stored))
	for i, rec := range stored {
		resps[i] = &entity.TokenResp{
			AccessToken:  rec.AccessToken,
			TokenType:    token.TypeBearer.String(),
			ExpiresIn:    int64(rec.ExpiresIn),
			RefreshToken: rec.RefreshToken,
		}
	}
	return resps, nil
}
// NewOneTimeToken mints a short-lived one-time token for the holder of a
// valid access token.
//
// The access token in req.Token is decoded and its stored record loaded.
// Both are then saved as an entity.Ticket under
// token.TicketKeyPrefix + <one-time token>, with a fixed TTL of one minute.
// Any failure is logged and returned as code.OneTimeTokenError.
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
	claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
	if err != nil {
		return entity.CreateOneTimeTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "parseClaims",
				req:       req,
				err:       err,
				message:   "failed to get token claims",
				errorCode: code.OneTimeTokenError,
			})
	}

	stored, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
	if err != nil {
		return entity.CreateOneTimeTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.GetAccessTokenByID",
				req:       req,
				err:       err,
				message:   "failed to get token by id",
				errorCode: code.OneTimeTokenError,
			})
	}

	oneTimeToken := refreshTokenGenerator(ksuid.New().String())
	ticketKey := token.TicketKeyPrefix + oneTimeToken
	ticket := entity.Ticket{
		Data:  claims,
		Token: stored,
	}

	if err := use.TokenRepo.CreateOneTimeToken(ctx, ticketKey, ticket, time.Minute); err != nil {
		return entity.CreateOneTimeTokenResp{},
			use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.CreateOneTimeToken",
				req:       req,
				err:       err,
				message:   "create one time token error",
				errorCode: code.OneTimeTokenError,
			})
	}

	return entity.CreateOneTimeTokenResp{OneTimeToken: oneTimeToken}, nil
}
// CancelOneTimeToken deletes a one-time token. The third argument to
// DeleteOneTimeToken is passed as nil. A repository error is logged and
// returned as code.OneTimeTokenError.
//
// NOTE(review): NewOneTimeToken stores tickets under
// token.TicketKeyPrefix + <token>, but req.Token is passed here unprefixed;
// confirm that DeleteOneTimeToken applies the prefix itself.
func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error {
	if err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil); err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.DeleteOneTimeToken",
			req:       req,
			err:       err,
			message:   "failed to del one time token by token",
			errorCode: code.OneTimeTokenError,
		})
	}
	return nil
}
// wrapTokenErrorReq carries everything wrapTokenError needs to log a failure
// and build the error returned to the caller.
type wrapTokenErrorReq struct {
	// funcName names the failing call; logged under the "func" field.
	funcName string
	// req is the request or input being processed; logged under the "req" field.
	req any
	// err is the underlying error; logged under "err" and wrapped into the result.
	err error
	// message is used both as the log message and as the returned error's message.
	message string
	// errorCode is the code passed to errs.NewError (e.g. code.TokenValidateError).
	errorCode uint32
}
// wrapTokenError logs a token-related failure and converts it into an errs
// library error.
//
// It writes one error-level entry through the context-aware logger:
// param.message with the "req", "func" and "err" fields. It then returns
// errs.NewError for param.errorCode / param.message, wrapping param.err.
//
// A nil param.err is tolerated: the "err" log field is left empty instead of
// panicking on a nil method call, as the previous param.err.Error() did.
func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error {
	errText := ""
	if param.err != nil {
		errText = param.err.Error()
	}

	logx.WithContext(ctx).Errorw(param.message,
		logx.LogField{Key: "req", Value: param.req},
		logx.LogField{Key: "func", Value: param.funcName},
		logx.LogField{Key: "err", Value: errText},
	)

	// NOTE(review): code.CatToken is passed for both of the first two
	// arguments; confirm against errs.NewError's signature whether one of them
	// should be a different scope/category value.
	return errs.NewError(
		code.CatToken,
		code.CatToken,
		param.errorCode,
		param.message,
	).Wrap(param.err)
}
// BlacklistToken revokes a JWT immediately by adding its JTI to the blacklist.
//
// The token is decoded with parseToken (third argument false) into a raw
// claim map, from which it reads:
//   - "jti": required, must be a string;
//   - "exp": required; accepted as float64 (the usual JSON decoding), int64,
//     or a base-10 string;
//   - "data" -> "uid": optional, stored on the entry when it is a string.
//
// The entry is stored with TTL 0, which by this file's convention asks the
// repository to compute the TTL itself. reason is only logged. Decoding and
// JTI errors use code.InvalidJWT, exp errors use code.TokenExpired, and
// storage failures use code.TokenCreateError.
func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error {
	claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false)
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "BlacklistToken.parseToken",
			req:       token,
			err:       err,
			message:   "failed to parse token claims",
			errorCode: code.InvalidJWT,
		})
	}

	// JTI (JWT ID) identifies the token on the blacklist.
	rawJTI, found := claimMap["jti"]
	if !found {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "BlacklistToken.getJTI",
			req:       token,
			err:       entity.ErrInvalidJTI,
			message:   "token missing JTI claim",
			errorCode: code.InvalidJWT,
		})
	}
	jtiStr, isString := rawJTI.(string)
	if !isString {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "BlacklistToken.convertJTI",
			req:       token,
			err:       entity.ErrInvalidJTI,
			message:   "JTI claim is not a string",
			errorCode: code.InvalidJWT,
		})
	}

	// The user ID is optional and lives under the nested "data" claim; any
	// missing or mistyped level leaves uid empty.
	var uid string
	if data, ok := claimMap["data"].(map[string]interface{}); ok {
		uid, _ = data["uid"].(string)
	}

	rawExp, found := claimMap["exp"]
	if !found {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "BlacklistToken.getExp",
			req:       token,
			err:       entity.ErrTokenExpired,
			message:   "token missing exp claim",
			errorCode: code.TokenExpired,
		})
	}

	var expiresAt int64
	switch v := rawExp.(type) {
	case float64:
		expiresAt = int64(v)
	case int64:
		expiresAt = v
	case string:
		parsed, parseErr := strconv.ParseInt(v, 10, 64)
		if parseErr != nil {
			return use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "BlacklistToken.parseExp",
				req:       token,
				err:       parseErr,
				message:   "failed to parse exp claim",
				errorCode: code.TokenExpired,
			})
		}
		expiresAt = parsed
	default:
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "BlacklistToken.convertExp",
			req:       token,
			err:       fmt.Errorf("exp claim is not a valid type: %T", rawExp),
			message:   "exp claim type conversion failed",
			errorCode: code.TokenExpired,
		})
	}

	entry := &entity.BlacklistEntry{
		JTI:       jtiStr,
		UID:       uid,
		ExpiresAt: expiresAt,
		CreatedAt: time.Now().Unix(),
	}

	// TTL 0: let the repository derive the TTL.
	if err := use.TokenRepo.AddToBlacklist(ctx, entry, 0); err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "BlacklistToken.AddToBlacklist",
			req:       jtiStr,
			err:       err,
			message:   "failed to add token to blacklist",
			errorCode: code.TokenCreateError,
		})
	}

	logx.WithContext(ctx).Infow("token blacklisted",
		logx.Field("jti", jtiStr),
		logx.Field("uid", uid),
		logx.Field("reason", reason))
	return nil
}
// IsTokenBlacklisted reports whether the token with the given JTI is on the
// blacklist. A repository error is logged and returned as
// code.TokenValidateError, with false as the result.
func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) {
	blacklisted, err := use.TokenRepo.IsBlacklisted(ctx, jti)
	if err == nil {
		return blacklisted, nil
	}

	return false, use.wrapTokenError(ctx, wrapTokenErrorReq{
		funcName:  "IsTokenBlacklisted",
		req:       jti,
		err:       err,
		message:   "failed to check blacklist status",
		errorCode: code.TokenValidateError,
	})
}
// BlacklistAllUserTokens revokes every token of a user ("log out from all
// devices").
//
// It loads the user's stored tokens and blacklists each one using the "jti"
// and "exp" claims decoded from its access token. It then deletes all of the
// user's token records.
//
// Per-token problems are logged and skipped so that one bad token does not
// block the rest. These are: an unparsable token, a missing jti/exp claim, a
// non-numeric exp, and a blacklist write failure. A failure to delete the
// records is also only logged. The only returned error is a failure to list
// the user's tokens (code.TokenValidateError).
//
// The final log line reports both how many tokens were found (tokenCount)
// and how many were actually blacklisted (blacklistedCount). Previously only
// tokenCount was logged, which overstated the result when tokens were
// skipped.
func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error {
	logger := logx.WithContext(ctx)

	tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "BlacklistAllUserTokens.GetAccessTokensByUID",
			req:       uid,
			err:       err,
			message:   "failed to get user tokens",
			errorCode: code.TokenValidateError,
		})
	}

	blacklisted := 0
	for _, tok := range tokens {
		// Decode the access token to obtain its JTI and expiry.
		claims, err := parseClaims(tok.AccessToken, use.Config.Token.AccessSecret, false)
		if err != nil {
			logger.Errorw("failed to parse token for blacklisting",
				logx.Field("uid", uid),
				logx.Field("tokenID", tok.ID),
				logx.Field("error", err))
			continue // skip the invalid token and keep processing the rest
		}

		jti, ok := claims["jti"]
		if !ok || jti == "" {
			logger.Errorw("token missing JTI claim",
				logx.Field("uid", uid),
				logx.Field("tokenID", tok.ID))
			continue
		}

		exp, ok := claims["exp"]
		if !ok {
			logger.Errorw("token missing exp claim",
				logx.Field("uid", uid),
				logx.Field("tokenID", tok.ID))
			continue
		}

		// Claim values are strings here, so exp must be parsed as base-10.
		expInt, err := strconv.ParseInt(exp, 10, 64)
		if err != nil {
			logger.Errorw("failed to parse exp claim",
				logx.Field("uid", uid),
				logx.Field("tokenID", tok.ID),
				logx.Field("error", err))
			continue
		}

		entry := &entity.BlacklistEntry{
			JTI:       jti,
			UID:       uid,
			ExpiresAt: expInt,
			CreatedAt: time.Now().Unix(),
		}

		// TTL 0: let the repository derive the TTL.
		if err := use.TokenRepo.AddToBlacklist(ctx, entry, 0); err != nil {
			logger.Errorw("failed to add token to blacklist",
				logx.Field("uid", uid),
				logx.Field("jti", jti),
				logx.Field("error", err))
			// Keep going: one failure should not stop the remaining tokens.
			continue
		}
		blacklisted++
	}

	// Best effort: a delete failure is logged, not returned. Tokens that were
	// blacklisted above stay revoked regardless; skipped tokens are not.
	if err := use.TokenRepo.DeleteAccessTokensByUID(ctx, uid); err != nil {
		logger.Errorw("failed to delete user tokens",
			logx.Field("uid", uid),
			logx.Field("error", err))
	}

	logger.Infow("all user tokens blacklisted",
		logx.Field("uid", uid),
		logx.Field("tokenCount", len(tokens)),
		logx.Field("blacklistedCount", blacklisted),
		logx.Field("reason", reason))
	return nil
}