563 lines
15 KiB
Go
Executable File
563 lines
15 KiB
Go
Executable File
package usecase
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strconv"
|
||
"time"
|
||
|
||
"backend/internal/config"
|
||
errs "backend/pkg/library/errors"
|
||
"backend/pkg/permission/domain/entity"
|
||
"backend/pkg/permission/domain/repository"
|
||
"backend/pkg/permission/domain/token"
|
||
"backend/pkg/permission/domain/usecase"
|
||
|
||
"github.com/segmentio/ksuid"
|
||
)
|
||
|
||
type TokenUseCaseParam struct {
|
||
TokenRepo repository.TokenRepository
|
||
Logger errs.Logger
|
||
Config *config.Config
|
||
}
|
||
|
||
type TokenUseCase struct {
|
||
TokenUseCaseParam
|
||
}
|
||
|
||
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
|
||
claims, err := ParseClaims(token, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return nil, errs.AuthSigPayloadMismatchError("validate token claims error")
|
||
}
|
||
|
||
return claims, nil
|
||
}
|
||
|
||
func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
|
||
return &TokenUseCase{
|
||
param,
|
||
}
|
||
}
|
||
|
||
// ============================================ token ============================================
|
||
|
||
func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) {
|
||
tokenObj, err := use.newToken(ctx, &req)
|
||
if err != nil {
|
||
return entity.TokenResp{}, err
|
||
}
|
||
|
||
err = use.TokenRepo.Create(ctx, *tokenObj)
|
||
if err != nil {
|
||
return entity.TokenResp{}, errs.DBErrorError("failed to create token")
|
||
}
|
||
|
||
return entity.TokenResp{
|
||
AccessToken: tokenObj.AccessToken,
|
||
TokenType: token.TypeBearer.String(),
|
||
ExpiresIn: int64(tokenObj.ExpiresIn),
|
||
RefreshToken: tokenObj.RefreshToken,
|
||
}, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) {
|
||
// 準備建立 Token 所需
|
||
now := time.Now().UTC()
|
||
expires := req.Expires
|
||
refreshExpires := req.Expires
|
||
|
||
if expires <= 0 {
|
||
// 將時間加上 n 秒
|
||
sec := use.Config.Token.AccessTokenExpiry
|
||
// 獲取 Unix 時間戳
|
||
expires = now.Add(sec).Unix()
|
||
refreshExpires = expires
|
||
}
|
||
|
||
// 如果這是一個 Refresh Token 過期時間要比普通的Token 長
|
||
if req.IsRefreshToken {
|
||
// 獲取 Unix 時間戳
|
||
refresh := use.Config.Token.RefreshTokenExpiry
|
||
refreshExpires = now.Add(refresh).Unix()
|
||
}
|
||
|
||
token := entity.Token{
|
||
ID: ksuid.New().String(),
|
||
DeviceID: req.DeviceID,
|
||
ExpiresIn: int(expires),
|
||
RefreshExpiresIn: int(refreshExpires),
|
||
AccessCreateAt: now,
|
||
RefreshCreateAt: now,
|
||
}
|
||
|
||
tc := make(TokenClaims)
|
||
if req.Data != nil {
|
||
for k, v := range req.Data {
|
||
tc[k] = v
|
||
}
|
||
}
|
||
tc.SetRole(req.Role)
|
||
tc.SetID(token.ID)
|
||
tc.SetScope(req.Scope)
|
||
tc.SetLoginID(req.Account)
|
||
|
||
token.UID = tc.UID()
|
||
|
||
if req.DeviceID != "" {
|
||
tc.SetDeviceID(req.DeviceID)
|
||
}
|
||
|
||
var err error
|
||
token.AccessToken, err = accessTokenGenerator(token, tc, use.Config.Token.AccessSecret)
|
||
if err != nil {
|
||
return nil, errs.SysInternalError("failed to generator access token").Wrap(err)
|
||
}
|
||
|
||
if req.IsRefreshToken {
|
||
token.RefreshToken = refreshTokenGenerator(token.AccessToken)
|
||
}
|
||
|
||
return &token, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) {
|
||
// Step 1: 檢查 refresh token
|
||
tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
|
||
if err != nil {
|
||
return entity.RefreshTokenResp{}, errs.DBErrorError("failed to get access token").Wrap(err)
|
||
}
|
||
|
||
// Step 2: 提取 Claims Data
|
||
claimsData, err := ParseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return entity.RefreshTokenResp{}, errs.AuthSigPayloadMismatchError("failed to extract claims")
|
||
}
|
||
|
||
// Step 3: 創建新 token
|
||
credentials := token.ClientCredentials
|
||
newToken, err := use.newToken(ctx, &entity.AuthorizationReq{
|
||
GrantType: credentials.ToString(),
|
||
Scope: claimsData.Scope(),
|
||
DeviceID: req.DeviceID,
|
||
Data: claimsData,
|
||
Expires: req.Expires,
|
||
IsRefreshToken: true,
|
||
Account: claimsData.LoginID(),
|
||
Role: claimsData.Role(),
|
||
})
|
||
if err != nil {
|
||
return entity.RefreshTokenResp{}, errs.DBErrorError("failed to create new token").Wrap(err)
|
||
}
|
||
|
||
if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
|
||
return entity.RefreshTokenResp{}, errs.DBErrorError("failed to create new token").Wrap(err)
|
||
}
|
||
|
||
// Step 4: 刪除舊 token 並創建新 token
|
||
if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil {
|
||
return entity.RefreshTokenResp{}, errs.DBErrorError("failed to delete old token").Wrap(err)
|
||
}
|
||
|
||
// 返回新的 Token 響應
|
||
return entity.RefreshTokenResp{
|
||
Token: newToken.AccessToken,
|
||
OneTimeToken: newToken.RefreshToken,
|
||
ExpiresIn: int64(newToken.ExpiresIn),
|
||
TokenType: token.TypeBearer.String(),
|
||
}, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
|
||
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return errs.AuthSigPayloadMismatchError("failed to get token claims")
|
||
}
|
||
|
||
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||
if err != nil {
|
||
return errs.DBErrorError(fmt.Sprintf("failed to get token claims :%s", claims.ID()))
|
||
}
|
||
|
||
err = use.TokenRepo.Delete(ctx, token)
|
||
if err != nil {
|
||
return errs.DBErrorError(fmt.Sprintf("failed to delete token :%s", token.ID)).Wrap(err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
|
||
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, true)
|
||
if err != nil {
|
||
return entity.ValidationTokenResp{}, errs.AuthSigPayloadMismatchError("validate token claims error")
|
||
}
|
||
|
||
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||
if err != nil {
|
||
return entity.ValidationTokenResp{}, errs.DBErrorError(fmt.Sprintf("failed to get token :%s", claims.ID())).Wrap(err)
|
||
}
|
||
|
||
return entity.ValidationTokenResp{
|
||
Token: entity.Token{
|
||
ID: token.ID,
|
||
UID: token.UID,
|
||
DeviceID: token.DeviceID,
|
||
AccessCreateAt: token.AccessCreateAt,
|
||
AccessToken: token.AccessToken,
|
||
ExpiresIn: token.ExpiresIn,
|
||
RefreshToken: token.RefreshToken,
|
||
RefreshExpiresIn: token.RefreshExpiresIn,
|
||
RefreshCreateAt: token.RefreshCreateAt,
|
||
},
|
||
Data: claims,
|
||
}, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error {
|
||
if req.UID != "" {
|
||
err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID)
|
||
if err != nil {
|
||
return errs.DBErrorError("failed to cancel tokens by uid").Wrap(err)
|
||
}
|
||
}
|
||
|
||
if len(req.IDs) > 0 {
|
||
err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs)
|
||
if err != nil {
|
||
return errs.DBErrorError("failed to cancel tokens by token ids").Wrap(err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error {
|
||
err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID)
|
||
if err != nil {
|
||
return errs.DBErrorError("failed to cancel tokens by device id").Wrap(err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) {
|
||
uidTokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID)
|
||
if err != nil {
|
||
return nil, errs.DBErrorError("failed to get tokens by device id").Wrap(err)
|
||
}
|
||
|
||
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
|
||
for _, v := range uidTokens {
|
||
tokens = append(tokens, &entity.TokenResp{
|
||
AccessToken: v.AccessToken,
|
||
TokenType: token.TypeBearer.String(),
|
||
ExpiresIn: int64(v.ExpiresIn),
|
||
RefreshToken: v.RefreshToken,
|
||
})
|
||
}
|
||
|
||
return tokens, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) {
|
||
uidTokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID)
|
||
if err != nil {
|
||
return nil, errs.DBErrorError("failed to get tokens by uid").Wrap(err)
|
||
}
|
||
|
||
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
|
||
for _, v := range uidTokens {
|
||
tokens = append(tokens, &entity.TokenResp{
|
||
AccessToken: v.AccessToken,
|
||
TokenType: token.TypeBearer.String(),
|
||
ExpiresIn: int64(v.ExpiresIn),
|
||
RefreshToken: v.RefreshToken,
|
||
})
|
||
}
|
||
|
||
return tokens, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
|
||
// 驗證Token
|
||
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return entity.CreateOneTimeTokenResp{}, errs.AuthSigPayloadMismatchError("failed to get token claims").Wrap(err)
|
||
}
|
||
|
||
tokenObj, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||
if err != nil {
|
||
return entity.CreateOneTimeTokenResp{}, errs.DBErrorError("failed to get token by id").Wrap(err)
|
||
}
|
||
|
||
oneTimeToken := refreshTokenGenerator(ksuid.New().String())
|
||
key := token.TicketKeyPrefix + oneTimeToken
|
||
if err = use.TokenRepo.CreateOneTimeToken(ctx, key, entity.Ticket{
|
||
Data: claims,
|
||
Token: tokenObj,
|
||
}, time.Minute); err != nil {
|
||
return entity.CreateOneTimeTokenResp{}, errs.DBErrorError("failed to create new one-time token").Wrap(err)
|
||
}
|
||
|
||
return entity.CreateOneTimeTokenResp{
|
||
OneTimeToken: oneTimeToken,
|
||
}, nil
|
||
}
|
||
|
||
func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error {
|
||
err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil)
|
||
if err != nil {
|
||
return errs.DBErrorError("failed to del one time token by token")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// BlacklistToken 將 JWT token 加入黑名單 (立即撤銷)
|
||
func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error {
|
||
// 解析 JWT 獲取完整的 claims
|
||
claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
return errs.AuthSigPayloadMismatchError("failed to parse token claims").Wrap(err)
|
||
}
|
||
|
||
// 獲取 JTI (JWT ID)
|
||
jti, exists := claimMap["jti"]
|
||
if !exists {
|
||
return errs.ResNotFoundError("token missing JTI claim").Wrap(err)
|
||
}
|
||
|
||
jtiStr, ok := jti.(string)
|
||
if !ok {
|
||
return errs.ResNotFoundError("token missing JTI claim").Wrap(err)
|
||
}
|
||
|
||
// 獲取 UID (可能在 data 中)
|
||
var uid string
|
||
if dataInterface, exists := claimMap["data"]; exists {
|
||
if dataMap, ok := dataInterface.(map[string]interface{}); ok {
|
||
if uidInterface, exists := dataMap["uid"]; exists {
|
||
uid, _ = uidInterface.(string)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 獲取過期時間
|
||
exp, exists := claimMap["exp"]
|
||
if !exists {
|
||
return errs.AuthExpiredError("token missing exp claim").Wrap(err)
|
||
}
|
||
|
||
// 將 exp 轉換為 int64 (JWT 中通常是 float64)
|
||
var expInt int64
|
||
switch v := exp.(type) {
|
||
case float64:
|
||
expInt = int64(v)
|
||
case int64:
|
||
expInt = v
|
||
case string:
|
||
parsedExp, err := strconv.ParseInt(v, 10, 64)
|
||
if err != nil {
|
||
return errs.SysInternalError("failed to parse exp claim").Wrap(err)
|
||
}
|
||
expInt = parsedExp
|
||
default:
|
||
return errs.SysInternalError("exp claim type conversion failed").Wrap(err)
|
||
}
|
||
|
||
// 創建黑名單條目
|
||
blacklistEntry := &entity.BlacklistEntry{
|
||
JTI: jtiStr,
|
||
UID: uid,
|
||
ExpiresAt: expInt,
|
||
CreatedAt: time.Now().Unix(),
|
||
}
|
||
|
||
// 添加到黑名單
|
||
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
|
||
if err != nil {
|
||
return errs.DBErrorError("failed to add token to blacklist").Wrap(err)
|
||
}
|
||
|
||
// 記錄成功日誌(如果 Logger 存在)
|
||
if use.Logger != nil {
|
||
use.Logger.WithFields(
|
||
errs.LogField{
|
||
Key: "jti",
|
||
Val: jtiStr,
|
||
},
|
||
errs.LogField{
|
||
Key: "uid",
|
||
Val: uid,
|
||
},
|
||
errs.LogField{
|
||
Key: "reason",
|
||
Val: reason,
|
||
}).Info("token blacklisted")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// IsTokenBlacklisted 檢查 JWT token 是否在黑名單中
|
||
func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) {
|
||
isBlacklisted, err := use.TokenRepo.IsBlacklisted(ctx, jti)
|
||
if err != nil {
|
||
return false, errs.DBErrorError("failed to check blacklist status").Wrap(err)
|
||
}
|
||
|
||
return isBlacklisted, nil
|
||
}
|
||
|
||
// BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出)
|
||
func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error {
|
||
// 獲取用戶的所有 token
|
||
tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
|
||
if err != nil {
|
||
return errs.DBErrorError("failed to get user tokens").Wrap(err)
|
||
}
|
||
|
||
// 為每個 token 創建黑名單條目
|
||
for _, token := range tokens {
|
||
// 解析 token 獲取 JTI 和過期時間
|
||
claims, err := ParseClaims(token.AccessToken, use.Config.Token.AccessSecret, false)
|
||
if err != nil {
|
||
if use.Logger != nil {
|
||
use.Logger.WithFields(
|
||
errs.LogField{
|
||
Key: "uid",
|
||
Val: uid,
|
||
},
|
||
errs.LogField{
|
||
Key: "tokenID",
|
||
Val: token.ID,
|
||
},
|
||
errs.LogField{
|
||
Key: "error",
|
||
Val: err,
|
||
}).Error("failed to parse token for blacklisting")
|
||
}
|
||
continue // 跳過無效的 token,繼續處理其他 token
|
||
}
|
||
|
||
jti, exists := claims["jti"]
|
||
if !exists || jti == "" {
|
||
if use.Logger != nil {
|
||
use.Logger.WithFields(
|
||
errs.LogField{
|
||
Key: "uid",
|
||
Val: uid,
|
||
},
|
||
errs.LogField{
|
||
Key: "tokenID",
|
||
Val: token.ID,
|
||
}).Error("failed to parse token for blacklisting")
|
||
}
|
||
continue
|
||
}
|
||
|
||
exp, exists := claims["exp"]
|
||
if !exists {
|
||
if use.Logger != nil {
|
||
use.Logger.WithFields(
|
||
errs.LogField{
|
||
Key: "uid",
|
||
Val: uid,
|
||
},
|
||
errs.LogField{
|
||
Key: "tokenID",
|
||
Val: token.ID,
|
||
}).Error("token missing exp claim")
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 將 exp 字符串轉換為 int64
|
||
expInt, err := strconv.ParseInt(exp, 10, 64)
|
||
if err != nil {
|
||
if use.Logger != nil {
|
||
use.Logger.WithFields(
|
||
errs.LogField{
|
||
Key: "uid",
|
||
Val: uid,
|
||
},
|
||
errs.LogField{
|
||
Key: "tokenID",
|
||
Val: token.ID,
|
||
},
|
||
errs.LogField{
|
||
Key: "error",
|
||
Val: err,
|
||
}).Error("failed to parse exp claim")
|
||
}
|
||
|
||
continue
|
||
}
|
||
|
||
// 創建黑名單條目
|
||
blacklistEntry := &entity.BlacklistEntry{
|
||
JTI: jti,
|
||
UID: uid,
|
||
ExpiresAt: expInt,
|
||
CreatedAt: time.Now().Unix(),
|
||
}
|
||
|
||
// 添加到黑名單
|
||
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
|
||
if err != nil {
|
||
if use.Logger != nil {
|
||
use.Logger.WithFields(
|
||
errs.LogField{
|
||
Key: "uid",
|
||
Val: uid,
|
||
},
|
||
errs.LogField{
|
||
Key: "jti",
|
||
Val: jti,
|
||
},
|
||
errs.LogField{
|
||
Key: "error",
|
||
Val: err,
|
||
}).Error("failed to add token to blacklist")
|
||
}
|
||
// 繼續處理其他 token,不要因為一個失敗就停止
|
||
}
|
||
}
|
||
|
||
// 刪除用戶的所有 token 記錄
|
||
err = use.TokenRepo.DeleteAccessTokensByUID(ctx, uid)
|
||
if err != nil {
|
||
if use.Logger != nil {
|
||
use.Logger.WithFields(
|
||
errs.LogField{
|
||
Key: "uid",
|
||
Val: uid,
|
||
},
|
||
errs.LogField{
|
||
Key: "error",
|
||
Val: err,
|
||
}).Error("failed to delete user tokens")
|
||
}
|
||
}
|
||
|
||
if use.Logger != nil {
|
||
use.Logger.WithFields(
|
||
errs.LogField{
|
||
Key: "uid",
|
||
Val: uid,
|
||
},
|
||
errs.LogField{
|
||
Key: "tokenCount",
|
||
Val: len(tokens),
|
||
},
|
||
errs.LogField{
|
||
Key: "reason",
|
||
Val: reason,
|
||
}).Error("all user tokens blacklisted")
|
||
}
|
||
|
||
return nil
|
||
}
|