package usecase import ( "context" "fmt" "strconv" "time" "backend/internal/config" "backend/pkg/library/errs" "backend/pkg/library/errs/code" "backend/pkg/permission/domain/entity" "backend/pkg/permission/domain/repository" "backend/pkg/permission/domain/token" "backend/pkg/permission/domain/usecase" "github.com/segmentio/ksuid" "github.com/zeromicro/go-zero/core/logx" ) type TokenUseCaseParam struct { TokenRepo repository.TokenRepository Config *config.Config } type TokenUseCase struct { TokenUseCaseParam } func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) { claims, err := parseClaims(token, use.Config.Token.AccessSecret, false) if err != nil { return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "parseClaims", req: token, err: err, message: "validate token claims error", errorCode: code.TokenValidateError, }) } return claims, nil } func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase { return &TokenUseCase{ param, } } // ============================================ token ============================================ func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) { tokenObj, err := use.newToken(ctx, &req) if err != nil { return entity.TokenResp{}, err } err = use.TokenRepo.Create(ctx, *tokenObj) if err != nil { return entity.TokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.Create", req: req, err: err, message: "failed to create token", errorCode: code.TokenCreateError, }) } return entity.TokenResp{ AccessToken: tokenObj.AccessToken, TokenType: token.TypeBearer.String(), ExpiresIn: int64(tokenObj.ExpiresIn), RefreshToken: tokenObj.RefreshToken, }, nil } func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) { // 準備建立 Token 所需 now := time.Now().UTC() expires := req.Expires refreshExpires := req.Expires if expires <= 0 { // 將時間加上 n 秒 sec := use.Config.Token.AccessTokenExpiry // 獲取 Unix 時間戳 expires = now.Add(sec).Unix() refreshExpires = expires } // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 if req.IsRefreshToken { // 獲取 Unix 時間戳 refresh := use.Config.Token.RefreshTokenExpiry refreshExpires = now.Add(refresh).Unix() } token := entity.Token{ ID: ksuid.New().String(), DeviceID: req.DeviceID, ExpiresIn: int(expires), RefreshExpiresIn: int(refreshExpires), AccessCreateAt: now, RefreshCreateAt: now, } tc := make(tokenClaims) if req.Data != nil { for k, v := range req.Data { tc[k] = v } } tc.SetRole(req.Role) tc.SetID(token.ID) tc.SetScope(req.Scope) tc.SetAccount(req.Account) token.UID = tc.UID() if req.DeviceID != "" { tc.SetDeviceID(req.DeviceID) } var err error token.AccessToken, err = accessTokenGenerator(token, tc, use.Config.Token.AccessSecret) if err != nil { return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "accessTokenGenerator", req: req, err: err, message: "failed to generator access token", errorCode: code.TokenCreateError, }) } if req.IsRefreshToken { token.RefreshToken = refreshTokenGenerator(token.AccessToken) } return &token, nil } func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) { // Step 1: 檢查 refresh token tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token) if err != nil { return entity.RefreshTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.GetAccessTokenByOneTimeToken", req: req, err: err, message: "failed to get access token", errorCode: code.TokenValidateError, }) } // Step 2: 提取 Claims Data claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false) if err != nil { return entity.RefreshTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "extractClaims", req: req, err: err, message: "failed to extract claims", errorCode: code.TokenValidateError, }) } // Step 3: 創建新 token credentials := token.ClientCredentials newToken, err := use.newToken(ctx, &entity.AuthorizationReq{ GrantType: credentials.ToString(), Scope: req.Scope, DeviceID: req.DeviceID, Data: claimsData, Expires: req.Expires, IsRefreshToken: true, Account: req.DeviceID, }) if err != nil { return entity.RefreshTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "use.newToken", req: req, err: err, message: "failed to create new token", errorCode: code.TokenValidateError, }) } if err := use.TokenRepo.Create(ctx, *newToken); err != nil { return entity.RefreshTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.Create", req: req, err: err, message: "failed to create new token", errorCode: code.TokenValidateError, }) } // Step 4: 刪除舊 token 並創建新 token if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil { return entity.RefreshTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.Delete", req: req, err: err, message: "failed to delete old token", errorCode: code.TokenValidateError, }) } // 返回新的 Token 響應 return entity.RefreshTokenResp{ Token: newToken.AccessToken, OneTimeToken: newToken.RefreshToken, ExpiresIn: int64(newToken.ExpiresIn), TokenType: token.TypeBearer.String(), }, nil } func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error { claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "CancelToken extractClaims", req: req, err: err, message: "failed to get token claims", errorCode: code.TokenValidateError, }) } token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID()) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo GetAccessTokenByID", req: req, err: err, message: fmt.Sprintf("failed to get token claims :%s", claims.ID()), errorCode: code.TokenValidateError, }) } err = use.TokenRepo.Delete(ctx, token) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo Delete", req: req, err: err, message: fmt.Sprintf("failed to delete token :%s", token.ID), errorCode: code.TokenValidateError, }) } return nil } func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) { claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, true) if err != nil { return entity.ValidationTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "parseClaims", req: req, err: err, message: "validate token claims error", errorCode: code.TokenValidateError, }) } token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID()) if err != nil { return entity.ValidationTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.GetAccessTokenByID", req: req, err: err, message: fmt.Sprintf("failed to get token :%s", claims.ID()), errorCode: code.TokenValidateError, }) } return entity.ValidationTokenResp{ Token: entity.Token{ ID: token.ID, UID: token.UID, DeviceID: token.DeviceID, AccessCreateAt: token.AccessCreateAt, AccessToken: token.AccessToken, ExpiresIn: token.ExpiresIn, RefreshToken: token.RefreshToken, RefreshExpiresIn: token.RefreshExpiresIn, RefreshCreateAt: token.RefreshCreateAt, }, Data: claims, }, nil } func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error { if req.UID != "" { err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.DeleteAccessTokensByUID", req: req, err: err, message: "failed to cancel tokens by uid", errorCode: code.TokenValidateError, }) } } if len(req.IDs) > 0 { err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.DeleteAccessTokenByID", req: req, err: err, message: "failed to cancel tokens by token ids", errorCode: code.TokenValidateError, }) } } return nil } func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error { err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.DeleteAccessTokensByDeviceID", req: req, err: err, message: "failed to cancel token by device id", errorCode: code.TokenValidateError, }) } return nil } func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) { uidTokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID) if err != nil { return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.GetAccessTokensByDeviceID", req: req, err: err, message: "failed to get token by device id", errorCode: code.TokenNotFound, }) } tokens := make([]*entity.TokenResp, 0, len(uidTokens)) for _, v := range uidTokens { tokens = append(tokens, &entity.TokenResp{ AccessToken: v.AccessToken, TokenType: token.TypeBearer.String(), ExpiresIn: int64(v.ExpiresIn), RefreshToken: v.RefreshToken, }) } return tokens, nil } func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) { uidTokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID) if err != nil { return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.GetAccessTokensByUID", req: req, err: err, message: "failed to get token by uid", errorCode: code.TokenNotFound, }) } tokens := make([]*entity.TokenResp, 0, len(uidTokens)) for _, v := range uidTokens { tokens = append(tokens, &entity.TokenResp{ AccessToken: v.AccessToken, TokenType: token.TypeBearer.String(), ExpiresIn: int64(v.ExpiresIn), RefreshToken: v.RefreshToken, }) } return tokens, nil } func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) { // 驗證Token claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false) if err != nil { return entity.CreateOneTimeTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "parseClaims", req: req, err: err, message: "failed to get token claims", errorCode: code.OneTimeTokenError, }) } tokenObj, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID()) if err != nil { return entity.CreateOneTimeTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.GetAccessTokenByID", req: req, err: err, message: "failed to get token by id", errorCode: code.OneTimeTokenError, }) } oneTimeToken := refreshTokenGenerator(ksuid.New().String()) key := token.TicketKeyPrefix + oneTimeToken if err = use.TokenRepo.CreateOneTimeToken(ctx, key, entity.Ticket{ Data: claims, Token: tokenObj, }, time.Minute); err != nil { return entity.CreateOneTimeTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.CreateOneTimeToken", req: req, err: err, message: "create one time token error", errorCode: code.OneTimeTokenError, }) } return entity.CreateOneTimeTokenResp{ OneTimeToken: oneTimeToken, }, nil } func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error { err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "TokenRepo.DeleteOneTimeToken", req: req, err: err, message: "failed to del one time token by token", errorCode: code.OneTimeTokenError, }) } return nil } type wrapTokenErrorReq struct { funcName string req any err error message string errorCode uint32 } // wrapTokenError 將錯誤信息封裝到 errs.LibError 中 func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error { logFields := []logx.LogField{ {Key: "req", Value: param.req}, {Key: "func", Value: param.funcName}, {Key: "err", Value: param.err.Error()}, } logx.WithContext(ctx).Errorw(param.message, logFields...) wrappedErr := errs.NewError( code.CatToken, code.CatToken, param.errorCode, param.message, ).Wrap(param.err) return wrappedErr } // BlacklistToken 將 JWT token 加入黑名單 (立即撤銷) func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error { // 解析 JWT 獲取完整的 claims claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "BlacklistToken.parseToken", req: token, err: err, message: "failed to parse token claims", errorCode: code.InvalidJWT, }) } // 獲取 JTI (JWT ID) jti, exists := claimMap["jti"] if !exists { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "BlacklistToken.getJTI", req: token, err: entity.ErrInvalidJTI, message: "token missing JTI claim", errorCode: code.InvalidJWT, }) } jtiStr, ok := jti.(string) if !ok { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "BlacklistToken.convertJTI", req: token, err: entity.ErrInvalidJTI, message: "JTI claim is not a string", errorCode: code.InvalidJWT, }) } // 獲取 UID (可能在 data 中) var uid string if dataInterface, exists := claimMap["data"]; exists { if dataMap, ok := dataInterface.(map[string]interface{}); ok { if uidInterface, exists := dataMap["uid"]; exists { uid, _ = uidInterface.(string) } } } // 獲取過期時間 exp, exists := claimMap["exp"] if !exists { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "BlacklistToken.getExp", req: token, err: entity.ErrTokenExpired, message: "token missing exp claim", errorCode: code.TokenExpired, }) } // 將 exp 轉換為 int64 (JWT 中通常是 float64) var expInt int64 switch v := exp.(type) { case float64: expInt = int64(v) case int64: expInt = v case string: parsedExp, err := strconv.ParseInt(v, 10, 64) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "BlacklistToken.parseExp", req: token, err: err, message: "failed to parse exp claim", errorCode: code.TokenExpired, }) } expInt = parsedExp default: return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "BlacklistToken.convertExp", req: token, err: fmt.Errorf("exp claim is not a valid type: %T", exp), message: "exp claim type conversion failed", errorCode: code.TokenExpired, }) } // 創建黑名單條目 blacklistEntry := &entity.BlacklistEntry{ JTI: jtiStr, UID: uid, ExpiresAt: expInt, CreatedAt: time.Now().Unix(), } // 添加到黑名單 err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算 if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "BlacklistToken.AddToBlacklist", req: jtiStr, err: err, message: "failed to add token to blacklist", errorCode: code.TokenCreateError, }) } logx.WithContext(ctx).Infow("token blacklisted", logx.Field("jti", jtiStr), logx.Field("uid", uid), logx.Field("reason", reason)) return nil } // IsTokenBlacklisted 檢查 JWT token 是否在黑名單中 func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) { isBlacklisted, err := use.TokenRepo.IsBlacklisted(ctx, jti) if err != nil { return false, use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "IsTokenBlacklisted", req: jti, err: err, message: "failed to check blacklist status", errorCode: code.TokenValidateError, }) } return isBlacklisted, nil } // BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出) func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error { // 獲取用戶的所有 token tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "BlacklistAllUserTokens.GetAccessTokensByUID", req: uid, err: err, message: "failed to get user tokens", errorCode: code.TokenValidateError, }) } // 為每個 token 創建黑名單條目 for _, token := range tokens { // 解析 token 獲取 JTI 和過期時間 claims, err := parseClaims(token.AccessToken, use.Config.Token.AccessSecret, false) if err != nil { logx.WithContext(ctx).Errorw("failed to parse token for blacklisting", logx.Field("uid", uid), logx.Field("tokenID", token.ID), logx.Field("error", err)) continue // 跳過無效的 token,繼續處理其他 token } jti, exists := claims["jti"] if !exists || jti == "" { logx.WithContext(ctx).Errorw("token missing JTI claim", logx.Field("uid", uid), logx.Field("tokenID", token.ID)) continue } exp, exists := claims["exp"] if !exists { logx.WithContext(ctx).Errorw("token missing exp claim", logx.Field("uid", uid), logx.Field("tokenID", token.ID)) continue } // 將 exp 字符串轉換為 int64 expInt, err := strconv.ParseInt(exp, 10, 64) if err != nil { logx.WithContext(ctx).Errorw("failed to parse exp claim", logx.Field("uid", uid), logx.Field("tokenID", token.ID), logx.Field("error", err)) continue } // 創建黑名單條目 blacklistEntry := &entity.BlacklistEntry{ JTI: jti, UID: uid, ExpiresAt: expInt, CreatedAt: time.Now().Unix(), } // 添加到黑名單 err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算 if err != nil { logx.WithContext(ctx).Errorw("failed to add token to blacklist", logx.Field("uid", uid), logx.Field("jti", jti), logx.Field("error", err)) // 繼續處理其他 token,不要因為一個失敗就停止 } } // 刪除用戶的所有 token 記錄 err = use.TokenRepo.DeleteAccessTokensByUID(ctx, uid) if err != nil { logx.WithContext(ctx).Errorw("failed to delete user tokens", logx.Field("uid", uid), logx.Field("error", err)) // 這不是致命錯誤,因為 token 已經被加入黑名單 } logx.WithContext(ctx).Infow("all user tokens blacklisted", logx.Field("uid", uid), logx.Field("tokenCount", len(tokens)), logx.Field("reason", reason)) return nil }