backend/pkg/permission/usecase/token.go

563 lines
15 KiB
Go
Raw Normal View History

2025-10-06 08:28:39 +00:00
package usecase
import (
"context"
"fmt"
"strconv"
"time"
"backend/internal/config"
2025-11-04 09:47:36 +00:00
errs "backend/pkg/library/errors"
2025-10-06 08:28:39 +00:00
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/token"
"backend/pkg/permission/domain/usecase"
"github.com/segmentio/ksuid"
)
type TokenUseCaseParam struct {
TokenRepo repository.TokenRepository
2025-11-04 09:47:36 +00:00
Logger errs.Logger
Config *config.Config
2025-10-06 08:28:39 +00:00
}
// TokenUseCase implements the token use cases. Its dependencies are promoted
// from the embedded TokenUseCaseParam.
type TokenUseCase struct {
	TokenUseCaseParam
}
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
2025-10-22 13:40:31 +00:00
claims, err := ParseClaims(token, use.Config.Token.AccessSecret, false)
2025-10-06 08:28:39 +00:00
if err != nil {
2025-11-04 09:47:36 +00:00
return nil, errs.AuthSigPayloadMismatchError("validate token claims error")
2025-10-06 08:28:39 +00:00
}
return claims, nil
}
// MustTokenUseCase builds a TokenUseCase from param and returns it as the
// domain-level usecase.TokenUseCase interface. It performs no validation of
// param.
func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
	uc := &TokenUseCase{TokenUseCaseParam: param}
	return uc
}
// ============================================ token ============================================
func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) {
tokenObj, err := use.newToken(ctx, &req)
if err != nil {
return entity.TokenResp{}, err
}
err = use.TokenRepo.Create(ctx, *tokenObj)
if err != nil {
2025-11-04 09:47:36 +00:00
return entity.TokenResp{}, errs.DBErrorError("failed to create token")
2025-10-06 08:28:39 +00:00
}
return entity.TokenResp{
AccessToken: tokenObj.AccessToken,
TokenType: token.TypeBearer.String(),
ExpiresIn: int64(tokenObj.ExpiresIn),
RefreshToken: tokenObj.RefreshToken,
}, nil
}
func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) {
// 準備建立 Token 所需
now := time.Now().UTC()
expires := req.Expires
refreshExpires := req.Expires
if expires <= 0 {
// 將時間加上 n 秒
sec := use.Config.Token.AccessTokenExpiry
// 獲取 Unix 時間戳
expires = now.Add(sec).Unix()
refreshExpires = expires
}
// 如果這是一個 Refresh Token 過期時間要比普通的Token 長
if req.IsRefreshToken {
// 獲取 Unix 時間戳
refresh := use.Config.Token.RefreshTokenExpiry
refreshExpires = now.Add(refresh).Unix()
}
token := entity.Token{
ID: ksuid.New().String(),
DeviceID: req.DeviceID,
ExpiresIn: int(expires),
RefreshExpiresIn: int(refreshExpires),
AccessCreateAt: now,
RefreshCreateAt: now,
}
2025-10-22 13:40:31 +00:00
tc := make(TokenClaims)
2025-10-06 08:28:39 +00:00
if req.Data != nil {
for k, v := range req.Data {
tc[k] = v
}
}
tc.SetRole(req.Role)
tc.SetID(token.ID)
tc.SetScope(req.Scope)
2025-10-22 13:40:31 +00:00
tc.SetLoginID(req.Account)
2025-10-06 08:28:39 +00:00
token.UID = tc.UID()
if req.DeviceID != "" {
tc.SetDeviceID(req.DeviceID)
}
var err error
token.AccessToken, err = accessTokenGenerator(token, tc, use.Config.Token.AccessSecret)
if err != nil {
2025-11-04 09:47:36 +00:00
return nil, errs.SysInternalError("failed to generator access token").Wrap(err)
2025-10-06 08:28:39 +00:00
}
if req.IsRefreshToken {
token.RefreshToken = refreshTokenGenerator(token.AccessToken)
}
return &token, nil
}
func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) {
// Step 1: 檢查 refresh token
tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
if err != nil {
2025-11-04 09:47:36 +00:00
return entity.RefreshTokenResp{}, errs.DBErrorError("failed to get access token").Wrap(err)
2025-10-06 08:28:39 +00:00
}
// Step 2: 提取 Claims Data
2025-10-22 13:40:31 +00:00
claimsData, err := ParseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
2025-10-06 08:28:39 +00:00
if err != nil {
2025-11-04 09:47:36 +00:00
return entity.RefreshTokenResp{}, errs.AuthSigPayloadMismatchError("failed to extract claims")
2025-10-06 08:28:39 +00:00
}
// Step 3: 創建新 token
credentials := token.ClientCredentials
newToken, err := use.newToken(ctx, &entity.AuthorizationReq{
GrantType: credentials.ToString(),
2025-10-06 13:14:58 +00:00
Scope: claimsData.Scope(),
2025-10-06 08:28:39 +00:00
DeviceID: req.DeviceID,
Data: claimsData,
Expires: req.Expires,
IsRefreshToken: true,
2025-10-22 13:40:31 +00:00
Account: claimsData.LoginID(),
2025-10-06 13:14:58 +00:00
Role: claimsData.Role(),
2025-10-06 08:28:39 +00:00
})
if err != nil {
2025-11-04 09:47:36 +00:00
return entity.RefreshTokenResp{}, errs.DBErrorError("failed to create new token").Wrap(err)
2025-10-06 08:28:39 +00:00
}
if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
2025-11-04 09:47:36 +00:00
return entity.RefreshTokenResp{}, errs.DBErrorError("failed to create new token").Wrap(err)
2025-10-06 08:28:39 +00:00
}
// Step 4: 刪除舊 token 並創建新 token
if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil {
2025-11-04 09:47:36 +00:00
return entity.RefreshTokenResp{}, errs.DBErrorError("failed to delete old token").Wrap(err)
2025-10-06 08:28:39 +00:00
}
// 返回新的 Token 響應
return entity.RefreshTokenResp{
Token: newToken.AccessToken,
OneTimeToken: newToken.RefreshToken,
ExpiresIn: int64(newToken.ExpiresIn),
TokenType: token.TypeBearer.String(),
}, nil
}
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
2025-10-22 13:40:31 +00:00
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
2025-10-06 08:28:39 +00:00
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.AuthSigPayloadMismatchError("failed to get token claims")
2025-10-06 08:28:39 +00:00
}
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.DBErrorError(fmt.Sprintf("failed to get token claims :%s", claims.ID()))
2025-10-06 08:28:39 +00:00
}
err = use.TokenRepo.Delete(ctx, token)
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.DBErrorError(fmt.Sprintf("failed to delete token :%s", token.ID)).Wrap(err)
2025-10-06 08:28:39 +00:00
}
return nil
}
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
2025-10-22 13:40:31 +00:00
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, true)
2025-10-06 08:28:39 +00:00
if err != nil {
2025-11-04 09:47:36 +00:00
return entity.ValidationTokenResp{}, errs.AuthSigPayloadMismatchError("validate token claims error")
2025-10-06 08:28:39 +00:00
}
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
if err != nil {
2025-11-04 09:47:36 +00:00
return entity.ValidationTokenResp{}, errs.DBErrorError(fmt.Sprintf("failed to get token :%s", claims.ID())).Wrap(err)
2025-10-06 08:28:39 +00:00
}
return entity.ValidationTokenResp{
Token: entity.Token{
ID: token.ID,
UID: token.UID,
DeviceID: token.DeviceID,
AccessCreateAt: token.AccessCreateAt,
AccessToken: token.AccessToken,
ExpiresIn: token.ExpiresIn,
RefreshToken: token.RefreshToken,
RefreshExpiresIn: token.RefreshExpiresIn,
RefreshCreateAt: token.RefreshCreateAt,
},
Data: claims,
}, nil
}
func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error {
if req.UID != "" {
err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID)
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.DBErrorError("failed to cancel tokens by uid").Wrap(err)
2025-10-06 08:28:39 +00:00
}
}
if len(req.IDs) > 0 {
err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs)
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.DBErrorError("failed to cancel tokens by token ids").Wrap(err)
2025-10-06 08:28:39 +00:00
}
}
return nil
}
func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error {
err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID)
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.DBErrorError("failed to cancel tokens by device id").Wrap(err)
2025-10-06 08:28:39 +00:00
}
return nil
}
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) {
uidTokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID)
if err != nil {
2025-11-04 09:47:36 +00:00
return nil, errs.DBErrorError("failed to get tokens by device id").Wrap(err)
2025-10-06 08:28:39 +00:00
}
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
for _, v := range uidTokens {
tokens = append(tokens, &entity.TokenResp{
AccessToken: v.AccessToken,
TokenType: token.TypeBearer.String(),
ExpiresIn: int64(v.ExpiresIn),
RefreshToken: v.RefreshToken,
})
}
return tokens, nil
}
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) {
uidTokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID)
if err != nil {
2025-11-04 09:47:36 +00:00
return nil, errs.DBErrorError("failed to get tokens by uid").Wrap(err)
2025-10-06 08:28:39 +00:00
}
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
for _, v := range uidTokens {
tokens = append(tokens, &entity.TokenResp{
AccessToken: v.AccessToken,
TokenType: token.TypeBearer.String(),
ExpiresIn: int64(v.ExpiresIn),
RefreshToken: v.RefreshToken,
})
}
return tokens, nil
}
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
// 驗證Token
2025-10-22 13:40:31 +00:00
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
2025-10-06 08:28:39 +00:00
if err != nil {
2025-11-04 09:47:36 +00:00
return entity.CreateOneTimeTokenResp{}, errs.AuthSigPayloadMismatchError("failed to get token claims").Wrap(err)
2025-10-06 08:28:39 +00:00
}
tokenObj, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
if err != nil {
2025-11-04 09:47:36 +00:00
return entity.CreateOneTimeTokenResp{}, errs.DBErrorError("failed to get token by id").Wrap(err)
2025-10-06 08:28:39 +00:00
}
oneTimeToken := refreshTokenGenerator(ksuid.New().String())
key := token.TicketKeyPrefix + oneTimeToken
if err = use.TokenRepo.CreateOneTimeToken(ctx, key, entity.Ticket{
Data: claims,
Token: tokenObj,
}, time.Minute); err != nil {
2025-11-04 09:47:36 +00:00
return entity.CreateOneTimeTokenResp{}, errs.DBErrorError("failed to create new one-time token").Wrap(err)
2025-10-06 08:28:39 +00:00
}
return entity.CreateOneTimeTokenResp{
OneTimeToken: oneTimeToken,
}, nil
}
func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error {
err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil)
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.DBErrorError("failed to del one time token by token")
2025-10-06 08:28:39 +00:00
}
return nil
}
// BlacklistToken 將 JWT token 加入黑名單 (立即撤銷)
func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error {
// 解析 JWT 獲取完整的 claims
claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false)
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.AuthSigPayloadMismatchError("failed to parse token claims").Wrap(err)
2025-10-06 08:28:39 +00:00
}
// 獲取 JTI (JWT ID)
jti, exists := claimMap["jti"]
if !exists {
2025-11-04 09:47:36 +00:00
return errs.ResNotFoundError("token missing JTI claim").Wrap(err)
2025-10-06 08:28:39 +00:00
}
jtiStr, ok := jti.(string)
if !ok {
2025-11-04 09:47:36 +00:00
return errs.ResNotFoundError("token missing JTI claim").Wrap(err)
2025-10-06 08:28:39 +00:00
}
// 獲取 UID (可能在 data 中)
var uid string
if dataInterface, exists := claimMap["data"]; exists {
if dataMap, ok := dataInterface.(map[string]interface{}); ok {
if uidInterface, exists := dataMap["uid"]; exists {
uid, _ = uidInterface.(string)
}
}
}
// 獲取過期時間
exp, exists := claimMap["exp"]
if !exists {
2025-11-04 09:47:36 +00:00
return errs.AuthExpiredError("token missing exp claim").Wrap(err)
2025-10-06 08:28:39 +00:00
}
// 將 exp 轉換為 int64 (JWT 中通常是 float64)
var expInt int64
switch v := exp.(type) {
case float64:
expInt = int64(v)
case int64:
expInt = v
case string:
parsedExp, err := strconv.ParseInt(v, 10, 64)
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.SysInternalError("failed to parse exp claim").Wrap(err)
2025-10-06 08:28:39 +00:00
}
expInt = parsedExp
default:
2025-11-04 09:47:36 +00:00
return errs.SysInternalError("exp claim type conversion failed").Wrap(err)
2025-10-06 08:28:39 +00:00
}
// 創建黑名單條目
blacklistEntry := &entity.BlacklistEntry{
JTI: jtiStr,
UID: uid,
ExpiresAt: expInt,
CreatedAt: time.Now().Unix(),
}
// 添加到黑名單
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.DBErrorError("failed to add token to blacklist").Wrap(err)
}
// 記錄成功日誌(如果 Logger 存在)
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{
Key: "jti",
Val: jtiStr,
},
errs.LogField{
Key: "uid",
Val: uid,
},
errs.LogField{
Key: "reason",
Val: reason,
}).Info("token blacklisted")
2025-10-06 08:28:39 +00:00
}
return nil
}
// IsTokenBlacklisted 檢查 JWT token 是否在黑名單中
func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) {
isBlacklisted, err := use.TokenRepo.IsBlacklisted(ctx, jti)
if err != nil {
2025-11-04 09:47:36 +00:00
return false, errs.DBErrorError("failed to check blacklist status").Wrap(err)
2025-10-06 08:28:39 +00:00
}
return isBlacklisted, nil
}
// BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出)
func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error {
// 獲取用戶的所有 token
tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
if err != nil {
2025-11-04 09:47:36 +00:00
return errs.DBErrorError("failed to get user tokens").Wrap(err)
2025-10-06 08:28:39 +00:00
}
// 為每個 token 創建黑名單條目
for _, token := range tokens {
// 解析 token 獲取 JTI 和過期時間
2025-10-22 13:40:31 +00:00
claims, err := ParseClaims(token.AccessToken, use.Config.Token.AccessSecret, false)
2025-10-06 08:28:39 +00:00
if err != nil {
2025-11-04 09:47:36 +00:00
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{
Key: "uid",
Val: uid,
},
errs.LogField{
Key: "tokenID",
Val: token.ID,
},
errs.LogField{
Key: "error",
Val: err,
}).Error("failed to parse token for blacklisting")
}
2025-10-06 08:28:39 +00:00
continue // 跳過無效的 token繼續處理其他 token
}
jti, exists := claims["jti"]
if !exists || jti == "" {
2025-11-04 09:47:36 +00:00
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{
Key: "uid",
Val: uid,
},
errs.LogField{
Key: "tokenID",
Val: token.ID,
}).Error("failed to parse token for blacklisting")
}
2025-10-06 08:28:39 +00:00
continue
}
exp, exists := claims["exp"]
if !exists {
2025-11-04 09:47:36 +00:00
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{
Key: "uid",
Val: uid,
},
errs.LogField{
Key: "tokenID",
Val: token.ID,
}).Error("token missing exp claim")
}
2025-10-06 08:28:39 +00:00
continue
}
// 將 exp 字符串轉換為 int64
expInt, err := strconv.ParseInt(exp, 10, 64)
if err != nil {
2025-11-04 09:47:36 +00:00
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{
Key: "uid",
Val: uid,
},
errs.LogField{
Key: "tokenID",
Val: token.ID,
},
errs.LogField{
Key: "error",
Val: err,
}).Error("failed to parse exp claim")
}
2025-10-06 08:28:39 +00:00
continue
}
// 創建黑名單條目
blacklistEntry := &entity.BlacklistEntry{
JTI: jti,
UID: uid,
ExpiresAt: expInt,
CreatedAt: time.Now().Unix(),
}
// 添加到黑名單
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
if err != nil {
2025-11-04 09:47:36 +00:00
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{
Key: "uid",
Val: uid,
},
errs.LogField{
Key: "jti",
Val: jti,
},
errs.LogField{
Key: "error",
Val: err,
}).Error("failed to add token to blacklist")
}
2025-10-06 08:28:39 +00:00
// 繼續處理其他 token不要因為一個失敗就停止
}
}
// 刪除用戶的所有 token 記錄
err = use.TokenRepo.DeleteAccessTokensByUID(ctx, uid)
if err != nil {
2025-11-04 09:47:36 +00:00
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{
Key: "uid",
Val: uid,
},
errs.LogField{
Key: "error",
Val: err,
}).Error("failed to delete user tokens")
}
2025-10-06 08:28:39 +00:00
}
2025-11-04 09:47:36 +00:00
if use.Logger != nil {
use.Logger.WithFields(
errs.LogField{
Key: "uid",
Val: uid,
},
errs.LogField{
Key: "tokenCount",
Val: len(tokens),
},
errs.LogField{
Key: "reason",
Val: reason,
}).Error("all user tokens blacklisted")
}
2025-10-06 08:28:39 +00:00
return nil
}