backend/pkg/permission/usecase/token.go

563 lines
15 KiB
Go
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package usecase
import (
"context"
"fmt"
"strconv"
"time"
"backend/internal/config"
errs "backend/pkg/library/errors"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/token"
"backend/pkg/permission/domain/usecase"
"github.com/segmentio/ksuid"
)
// TokenUseCaseParam bundles the dependencies required by TokenUseCase.
type TokenUseCaseParam struct {
	// TokenRepo stores access/refresh tokens, one-time tokens (tickets) and
	// blacklist entries.
	TokenRepo repository.TokenRepository
	// Logger is optional: methods in this file check it for nil before logging.
	Logger errs.Logger
	// Config supplies the token settings read here (Token.AccessSecret,
	// Token.AccessTokenExpiry, Token.RefreshTokenExpiry).
	Config *config.Config
}
// TokenUseCase is the implementation returned by MustTokenUseCase as a
// usecase.TokenUseCase; its dependencies come from the embedded
// TokenUseCaseParam.
type TokenUseCase struct {
	TokenUseCaseParam
}
// ReadTokenBasicData parses token with the configured access secret and
// returns its claims.
//
// ParseClaims is called with validate=false (presumably skipping expiry
// checks, unlike ValidationToken which passes true — confirm in ParseClaims).
// A parse failure is reported as an AuthSigPayloadMismatch error that wraps
// the underlying cause. ctx is currently unused.
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
	claims, err := ParseClaims(token, use.Config.Token.AccessSecret, false)
	if err != nil {
		// Keep the parse failure as the cause so it remains diagnosable.
		return nil, errs.AuthSigPayloadMismatchError("validate token claims error").Wrap(err)
	}
	return claims, nil
}
// MustTokenUseCase builds a TokenUseCase from param and returns it as a
// usecase.TokenUseCase. Despite the "Must" prefix it performs no validation
// and never panics.
func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
	uc := &TokenUseCase{TokenUseCaseParam: param}
	return uc
}
// ============================================ token ============================================
// NewToken issues a token for req, persists it through TokenRepo and returns
// it as a bearer TokenResp.
//
// Errors from newToken are returned unchanged. A persistence failure is
// reported as a DB error wrapping the repository error.
func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) {
	tokenObj, err := use.newToken(ctx, &req)
	if err != nil {
		return entity.TokenResp{}, err
	}
	if err := use.TokenRepo.Create(ctx, *tokenObj); err != nil {
		// Keep the repository error as the cause instead of dropping it.
		return entity.TokenResp{}, errs.DBErrorError("failed to create token").Wrap(err)
	}
	// NOTE(review): newToken stores an absolute Unix timestamp in ExpiresIn,
	// not a duration in seconds — confirm clients expect that.
	return entity.TokenResp{
		AccessToken:  tokenObj.AccessToken,
		TokenType:    token.TypeBearer.String(),
		ExpiresIn:    int64(tokenObj.ExpiresIn),
		RefreshToken: tokenObj.RefreshToken,
	}, nil
}
// newToken builds (without persisting) an entity.Token for req and signs its
// access token.
//
// Expiry:
//   - If req.Expires > 0 it is used verbatim as the access expiry (presumably
//     Unix seconds — confirm with callers). Otherwise the access expiry is
//     now + Config.Token.AccessTokenExpiry, as Unix seconds.
//   - The refresh expiry equals the access expiry. When req.IsRefreshToken is
//     set it is instead now + Config.Token.RefreshTokenExpiry.
//
// Claims start as a copy of req.Data. Role, ID (the new token ID), scope and
// login ID are then applied through the TokenClaims setters. The token UID is
// read back from the claims, and the device-ID claim is set only for a
// non-empty req.DeviceID. A refresh token is generated only when
// req.IsRefreshToken is true.
func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) {
	now := time.Now().UTC()

	accessExpiry := req.Expires
	if accessExpiry <= 0 {
		accessExpiry = now.Add(use.Config.Token.AccessTokenExpiry).Unix()
	}
	refreshExpiry := accessExpiry
	if req.IsRefreshToken {
		// A refresh token must outlive the access token it is issued with.
		refreshExpiry = now.Add(use.Config.Token.RefreshTokenExpiry).Unix()
	}

	tok := entity.Token{
		ID:               ksuid.New().String(),
		DeviceID:         req.DeviceID,
		ExpiresIn:        int(accessExpiry),
		RefreshExpiresIn: int(refreshExpiry),
		AccessCreateAt:   now,
		RefreshCreateAt:  now,
	}

	claims := make(TokenClaims)
	for key, val := range req.Data {
		claims[key] = val
	}
	claims.SetRole(req.Role)
	claims.SetID(tok.ID)
	claims.SetScope(req.Scope)
	claims.SetLoginID(req.Account)
	tok.UID = claims.UID()
	if req.DeviceID != "" {
		claims.SetDeviceID(req.DeviceID)
	}

	access, err := accessTokenGenerator(tok, claims, use.Config.Token.AccessSecret)
	if err != nil {
		return nil, errs.SysInternalError("failed to generator access token").Wrap(err)
	}
	tok.AccessToken = access

	if req.IsRefreshToken {
		tok.RefreshToken = refreshTokenGenerator(tok.AccessToken)
	}
	return &tok, nil
}
// RefreshToken exchanges a refresh (one-time) token for a new access/refresh
// token pair.
//
// Steps:
//  1. Look up the stored token owning req.Token and reject it once its
//     recorded refresh expiry has passed.
//  2. Recover the claims of the stored access token. validate=false is used,
//     presumably because the access token itself may already be expired —
//     confirm in ParseClaims.
//  3. Issue and persist a replacement token carrying those claims.
//  4. Delete the old token so its refresh token cannot be used again.
//
// If step 4 fails, the replacement token from step 3 has already been
// persisted.
func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) {
	// Step 1: look up the token that owns this refresh token.
	tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
	if err != nil {
		return entity.RefreshTokenResp{}, errs.DBErrorError("failed to get access token").Wrap(err)
	}
	// newToken records RefreshExpiresIn as an absolute Unix timestamp. Check it
	// here instead of relying only on storage expiry, so an expired refresh
	// token can never mint new tokens. Zero is treated as "not recorded".
	if exp := int64(tokenObj.RefreshExpiresIn); exp > 0 && time.Now().Unix() >= exp {
		return entity.RefreshTokenResp{}, errs.AuthExpiredError("refresh token expired").Wrap(fmt.Errorf("refresh token expired at %d", exp))
	}

	// Step 2: recover the claims of the old access token.
	claimsData, err := ParseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
	if err != nil {
		return entity.RefreshTokenResp{}, errs.AuthSigPayloadMismatchError("failed to extract claims").Wrap(err)
	}

	// Step 3: issue and persist the replacement token.
	credentials := token.ClientCredentials
	issued, err := use.newToken(ctx, &entity.AuthorizationReq{
		GrantType:      credentials.ToString(),
		Scope:          claimsData.Scope(),
		DeviceID:       req.DeviceID,
		Data:           claimsData,
		Expires:        req.Expires,
		IsRefreshToken: true,
		Account:        claimsData.LoginID(),
		Role:           claimsData.Role(),
	})
	if err != nil {
		// newToken already returns a classified (sys-internal) error for a
		// token-generation failure; do not relabel it as a DB error.
		return entity.RefreshTokenResp{}, err
	}
	if err := use.TokenRepo.Create(ctx, *issued); err != nil {
		return entity.RefreshTokenResp{}, errs.DBErrorError("failed to create new token").Wrap(err)
	}

	// Step 4: delete the old token so its refresh token cannot be reused.
	if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil {
		return entity.RefreshTokenResp{}, errs.DBErrorError("failed to delete old token").Wrap(err)
	}

	return entity.RefreshTokenResp{
		Token:        issued.AccessToken,
		OneTimeToken: issued.RefreshToken,
		ExpiresIn:    int64(issued.ExpiresIn),
		TokenType:    token.TypeBearer.String(),
	}, nil
}
// CancelToken revokes a single token.
//
// req.Token is parsed with the access secret (validate=false, presumably so
// an expired token can still be cancelled — confirm in ParseClaims). The
// stored token is then looked up by the claim ID and deleted. Every failure
// wraps its underlying cause.
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
	claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
	if err != nil {
		return errs.AuthSigPayloadMismatchError("failed to get token claims").Wrap(err)
	}
	tok, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
	if err != nil {
		// The failing step is the token lookup, and the repository error is
		// kept as the cause.
		return errs.DBErrorError(fmt.Sprintf("failed to get token :%s", claims.ID())).Wrap(err)
	}
	if err := use.TokenRepo.Delete(ctx, tok); err != nil {
		return errs.DBErrorError(fmt.Sprintf("failed to delete token :%s", tok.ID)).Wrap(err)
	}
	return nil
}
// ValidationToken validates req.Token and returns a copy of the stored token
// together with the token's claims.
//
// The token is parsed with validate=true (presumably enforcing expiry, unlike
// the validate=false calls elsewhere in this file — confirm in ParseClaims).
// It must also still exist in TokenRepo under its claim ID.
//
// NOTE(review): the blacklist is not consulted here. BlacklistToken adds a
// blacklist entry without deleting the stored token, so a token revoked that
// way still validates unless callers also check IsTokenBlacklisted.
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
	claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, true)
	if err != nil {
		return entity.ValidationTokenResp{}, errs.AuthSigPayloadMismatchError("validate token claims error").Wrap(err)
	}
	tok, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
	if err != nil {
		return entity.ValidationTokenResp{}, errs.DBErrorError(fmt.Sprintf("failed to get token :%s", claims.ID())).Wrap(err)
	}
	return entity.ValidationTokenResp{
		Token: entity.Token{
			ID:               tok.ID,
			UID:              tok.UID,
			DeviceID:         tok.DeviceID,
			AccessCreateAt:   tok.AccessCreateAt,
			AccessToken:      tok.AccessToken,
			ExpiresIn:        tok.ExpiresIn,
			RefreshToken:     tok.RefreshToken,
			RefreshExpiresIn: tok.RefreshExpiresIn,
			RefreshCreateAt:  tok.RefreshCreateAt,
		},
		Data: claims,
	}, nil
}
// CancelTokens deletes stored tokens selected by req.
//
// A non-empty req.UID deletes every token of that user. A non-empty req.IDs
// then deletes those specific tokens. Both filters may apply in one call, and
// an empty request is a no-op. The first repository failure is returned as a
// DB error.
func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error {
	if req.UID != "" {
		if err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID); err != nil {
			return errs.DBErrorError("failed to cancel tokens by uid").Wrap(err)
		}
	}
	if len(req.IDs) == 0 {
		return nil
	}
	if err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs); err != nil {
		return errs.DBErrorError("failed to cancel tokens by token ids").Wrap(err)
	}
	return nil
}
// CancelTokenByDeviceID deletes every stored token bound to req.DeviceID.
// A repository failure is returned as a DB error.
func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error {
	if err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID); err != nil {
		return errs.DBErrorError("failed to cancel tokens by device id").Wrap(err)
	}
	return nil
}
// GetUserTokensByDeviceID lists the stored tokens bound to req.DeviceID as
// bearer TokenResp values, in repository order.
//
// ExpiresIn is copied from the stored token; tokens created by newToken store
// an absolute Unix timestamp there. The result is non-nil (possibly empty) on
// success.
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) {
	stored, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID)
	if err != nil {
		return nil, errs.DBErrorError("failed to get tokens by device id").Wrap(err)
	}

	resp := make([]*entity.TokenResp, 0, len(stored))
	for _, t := range stored {
		item := &entity.TokenResp{
			AccessToken:  t.AccessToken,
			TokenType:    token.TypeBearer.String(),
			ExpiresIn:    int64(t.ExpiresIn),
			RefreshToken: t.RefreshToken,
		}
		resp = append(resp, item)
	}
	return resp, nil
}
// GetUserTokensByUID lists every stored token of req.UID as bearer TokenResp
// values, in repository order.
//
// ExpiresIn is copied from the stored token; tokens created by newToken store
// an absolute Unix timestamp there. The result is non-nil (possibly empty) on
// success.
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) {
	stored, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID)
	if err != nil {
		return nil, errs.DBErrorError("failed to get tokens by uid").Wrap(err)
	}

	out := make([]*entity.TokenResp, 0, len(stored))
	for _, t := range stored {
		out = append(out, &entity.TokenResp{
			AccessToken:  t.AccessToken,
			TokenType:    token.TypeBearer.String(),
			ExpiresIn:    int64(t.ExpiresIn),
			RefreshToken: t.RefreshToken,
		})
	}
	return out, nil
}
// NewOneTimeToken issues a short-lived one-time token (ticket) for the holder
// of req.Token.
//
// The caller's token is parsed with validate=false and looked up by its claim
// ID. Its claims and the stored token are then saved as an entity.Ticket
// under token.TicketKeyPrefix + <one-time token>, with a one-minute TTL.
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
	var resp entity.CreateOneTimeTokenResp

	claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
	if err != nil {
		return resp, errs.AuthSigPayloadMismatchError("failed to get token claims").Wrap(err)
	}
	stored, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
	if err != nil {
		return resp, errs.DBErrorError("failed to get token by id").Wrap(err)
	}

	const ticketTTL = time.Minute
	oneTime := refreshTokenGenerator(ksuid.New().String())
	ticket := entity.Ticket{
		Data:  claims,
		Token: stored,
	}
	if err := use.TokenRepo.CreateOneTimeToken(ctx, token.TicketKeyPrefix+oneTime, ticket, ticketTTL); err != nil {
		return resp, errs.DBErrorError("failed to create new one-time token").Wrap(err)
	}

	resp.OneTimeToken = oneTime
	return resp, nil
}
// CancelOneTimeToken deletes the one-time token req.Token. A repository
// failure is returned as a DB error wrapping the cause.
//
// NOTE(review): NewOneTimeToken stores tickets under
// token.TicketKeyPrefix + <token>, but req.Token is passed here unchanged —
// confirm that DeleteOneTimeToken applies the prefix itself.
func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error {
	if err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil); err != nil {
		// Keep the repository error as the cause instead of dropping it.
		return errs.DBErrorError("failed to del one time token by token").Wrap(err)
	}
	return nil
}
// BlacklistToken revokes a JWT immediately by adding its jti to the blacklist.
//
// The token is parsed with the access secret and validate=false, presumably so
// an already expired token can still be blacklisted — confirm in parseToken.
// The blacklist entry records:
//   - the jti;
//   - the UID from the nested "data" claim (empty if absent);
//   - the exp claim in Unix seconds;
//   - the current time.
//
// reason is only logged.
//
// Errors:
//   - AuthSigPayloadMismatch when the token cannot be parsed;
//   - ResNotFound when the jti claim is missing, not a string, or empty;
//   - AuthExpired when the exp claim is missing;
//   - SysInternal when exp cannot be converted to an int64;
//   - DBError when the blacklist write fails.
func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error {
	claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false)
	if err != nil {
		return errs.AuthSigPayloadMismatchError("failed to parse token claims").Wrap(err)
	}

	// The jti is the blacklist key. Missing, non-string and empty values are
	// rejected alike. err is nil at this point, so there is nothing to wrap.
	jtiStr, _ := claimMap["jti"].(string)
	if jtiStr == "" {
		return errs.ResNotFoundError("token missing JTI claim")
	}

	// The UID, when present, is nested inside the "data" claim.
	var uid string
	if dataMap, ok := claimMap["data"].(map[string]interface{}); ok {
		uid, _ = dataMap["uid"].(string)
	}

	exp, ok := claimMap["exp"]
	if !ok {
		return errs.AuthExpiredError("token missing exp claim")
	}
	var expInt int64
	switch v := exp.(type) {
	case float64:
		// encoding/json decodes JSON numbers into float64 by default.
		expInt = int64(v)
	case int64:
		expInt = v
	case int:
		expInt = int64(v)
	case interface{ Int64() (int64, error) }:
		// Satisfied by json.Number (decoders configured with UseNumber).
		n, convErr := v.Int64()
		if convErr != nil {
			return errs.SysInternalError("failed to parse exp claim").Wrap(convErr)
		}
		expInt = n
	case string:
		n, convErr := strconv.ParseInt(v, 10, 64)
		if convErr != nil {
			return errs.SysInternalError("failed to parse exp claim").Wrap(convErr)
		}
		expInt = n
	default:
		return errs.SysInternalError("exp claim type conversion failed").Wrap(fmt.Errorf("unexpected exp claim type %T", exp))
	}

	blacklistEntry := &entity.BlacklistEntry{
		JTI:       jtiStr,
		UID:       uid,
		ExpiresAt: expInt,
		CreatedAt: time.Now().Unix(),
	}
	// TTL=0 means the repository's default TTL calculation is used.
	if err := use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0); err != nil {
		return errs.DBErrorError("failed to add token to blacklist").Wrap(err)
	}

	if use.Logger != nil {
		use.Logger.WithFields(
			errs.LogField{Key: "jti", Val: jtiStr},
			errs.LogField{Key: "uid", Val: uid},
			errs.LogField{Key: "reason", Val: reason},
		).Info("token blacklisted")
	}
	return nil
}
// IsTokenBlacklisted reports whether the token with the given jti is on the
// blacklist. On a repository failure it returns false together with a DB
// error wrapping the cause.
func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) {
	listed, err := use.TokenRepo.IsBlacklisted(ctx, jti)
	if err == nil {
		return listed, nil
	}
	return false, errs.DBErrorError("failed to check blacklist status").Wrap(err)
}
// BlacklistAllUserTokens revokes every stored token of uid (logout on all
// devices).
//
// Blacklisting is best-effort. Each stored access token is parsed
// (validate=false), and its jti and exp claims are added to the blacklist.
// Tokens that cannot be parsed, lack jti/exp, or fail to be written are
// logged and skipped, so one bad token does not block revoking the rest.
//
// Afterwards all stored token records of uid are deleted, and a failure there
// IS returned. Otherwise the stored records (including refresh tokens usable
// by RefreshToken) would survive while the caller believes the user was
// logged out everywhere. reason is only used for logging.
func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error {
	tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return errs.DBErrorError("failed to get user tokens").Wrap(err)
	}

	blacklisted := 0
	for _, tok := range tokens {
		tokenIDField := errs.LogField{Key: "tokenID", Val: tok.ID}

		claims, err := ParseClaims(tok.AccessToken, use.Config.Token.AccessSecret, false)
		if err != nil {
			use.logBlacklistAllError("failed to parse token for blacklisting", uid,
				tokenIDField, errs.LogField{Key: "error", Val: err})
			continue
		}

		// A missing key yields "", so this covers both absent and empty jti.
		jti := claims["jti"]
		if jti == "" {
			use.logBlacklistAllError("token missing jti claim", uid, tokenIDField)
			continue
		}

		exp, ok := claims["exp"]
		if !ok {
			use.logBlacklistAllError("token missing exp claim", uid, tokenIDField)
			continue
		}
		// ParseClaims yields string values, so exp arrives as a decimal string.
		expiresAt, err := strconv.ParseInt(exp, 10, 64)
		if err != nil {
			use.logBlacklistAllError("failed to parse exp claim", uid,
				tokenIDField, errs.LogField{Key: "error", Val: err})
			continue
		}

		entry := &entity.BlacklistEntry{
			JTI:       jti,
			UID:       uid,
			ExpiresAt: expiresAt,
			CreatedAt: time.Now().Unix(),
		}
		// TTL=0 means the repository's default TTL calculation is used.
		if err := use.TokenRepo.AddToBlacklist(ctx, entry, 0); err != nil {
			// Keep going: one failed write must not stop revoking the others.
			use.logBlacklistAllError("failed to add token to blacklist", uid,
				errs.LogField{Key: "jti", Val: jti}, errs.LogField{Key: "error", Val: err})
			continue
		}
		blacklisted++
	}

	if err := use.TokenRepo.DeleteAccessTokensByUID(ctx, uid); err != nil {
		return errs.DBErrorError("failed to delete user tokens").Wrap(err)
	}

	// Success is an informational event, not an error.
	if use.Logger != nil {
		use.Logger.WithFields(
			errs.LogField{Key: "uid", Val: uid},
			errs.LogField{Key: "tokenCount", Val: len(tokens)},
			errs.LogField{Key: "blacklistedCount", Val: blacklisted},
			errs.LogField{Key: "reason", Val: reason},
		).Info("all user tokens blacklisted")
	}
	return nil
}

// logBlacklistAllError logs msg at error level with a "uid" field followed by
// fields. It is a no-op when no Logger is configured.
func (use *TokenUseCase) logBlacklistAllError(msg, uid string, fields ...errs.LogField) {
	if use.Logger == nil {
		return
	}
	all := append([]errs.LogField{{Key: "uid", Val: uid}}, fields...)
	use.Logger.WithFields(all...).Error(msg)
}