package usecase import ( "context" "fmt" "strconv" "time" "backend/internal/config" errs "backend/pkg/library/errors" "backend/pkg/permission/domain/entity" "backend/pkg/permission/domain/repository" "backend/pkg/permission/domain/token" "backend/pkg/permission/domain/usecase" "github.com/segmentio/ksuid" ) type TokenUseCaseParam struct { TokenRepo repository.TokenRepository Logger errs.Logger Config *config.Config } type TokenUseCase struct { TokenUseCaseParam } func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) { claims, err := ParseClaims(token, use.Config.Token.AccessSecret, false) if err != nil { return nil, errs.AuthSigPayloadMismatchError("validate token claims error") } return claims, nil } func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase { return &TokenUseCase{ param, } } // ============================================ token ============================================ func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) { tokenObj, err := use.newToken(ctx, &req) if err != nil { return entity.TokenResp{}, err } err = use.TokenRepo.Create(ctx, *tokenObj) if err != nil { return entity.TokenResp{}, errs.DBErrorError("failed to create token") } return entity.TokenResp{ AccessToken: tokenObj.AccessToken, TokenType: token.TypeBearer.String(), ExpiresIn: int64(tokenObj.ExpiresIn), RefreshToken: tokenObj.RefreshToken, }, nil } func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) { // 準備建立 Token 所需 now := time.Now().UTC() expires := req.Expires refreshExpires := req.Expires if expires <= 0 { // 將時間加上 n 秒 sec := use.Config.Token.AccessTokenExpiry // 獲取 Unix 時間戳 expires = now.Add(sec).Unix() refreshExpires = expires } // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 if req.IsRefreshToken { // 獲取 Unix 時間戳 refresh := use.Config.Token.RefreshTokenExpiry refreshExpires = now.Add(refresh).Unix() } token := entity.Token{ ID: ksuid.New().String(), DeviceID: req.DeviceID, ExpiresIn: int(expires), RefreshExpiresIn: int(refreshExpires), AccessCreateAt: now, RefreshCreateAt: now, } tc := make(TokenClaims) if req.Data != nil { for k, v := range req.Data { tc[k] = v } } tc.SetRole(req.Role) tc.SetID(token.ID) tc.SetScope(req.Scope) tc.SetLoginID(req.Account) token.UID = tc.UID() if req.DeviceID != "" { tc.SetDeviceID(req.DeviceID) } var err error token.AccessToken, err = accessTokenGenerator(token, tc, use.Config.Token.AccessSecret) if err != nil { return nil, errs.SysInternalError("failed to generator access token").Wrap(err) } if req.IsRefreshToken { token.RefreshToken = refreshTokenGenerator(token.AccessToken) } return &token, nil } func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) { // Step 1: 檢查 refresh token tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token) if err != nil { return entity.RefreshTokenResp{}, errs.DBErrorError("failed to get access token").Wrap(err) } // Step 2: 提取 Claims Data claimsData, err := ParseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false) if err != nil { return entity.RefreshTokenResp{}, errs.AuthSigPayloadMismatchError("failed to extract claims") } // Step 3: 創建新 token credentials := token.ClientCredentials newToken, err := use.newToken(ctx, &entity.AuthorizationReq{ GrantType: credentials.ToString(), Scope: claimsData.Scope(), DeviceID: req.DeviceID, Data: claimsData, Expires: req.Expires, IsRefreshToken: true, Account: claimsData.LoginID(), Role: claimsData.Role(), }) if err != nil { return entity.RefreshTokenResp{}, errs.DBErrorError("failed to create new token").Wrap(err) } if err := use.TokenRepo.Create(ctx, *newToken); err != nil { return entity.RefreshTokenResp{}, errs.DBErrorError("failed to create new token").Wrap(err) } // Step 4: 刪除舊 token 並創建新 token if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil { return entity.RefreshTokenResp{}, errs.DBErrorError("failed to delete old token").Wrap(err) } // 返回新的 Token 響應 return entity.RefreshTokenResp{ Token: newToken.AccessToken, OneTimeToken: newToken.RefreshToken, ExpiresIn: int64(newToken.ExpiresIn), TokenType: token.TypeBearer.String(), }, nil } func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error { claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false) if err != nil { return errs.AuthSigPayloadMismatchError("failed to get token claims") } token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID()) if err != nil { return errs.DBErrorError(fmt.Sprintf("failed to get token claims :%s", claims.ID())) } err = use.TokenRepo.Delete(ctx, token) if err != nil { return errs.DBErrorError(fmt.Sprintf("failed to delete token :%s", token.ID)).Wrap(err) } return nil } func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) { claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, true) if err != nil { return entity.ValidationTokenResp{}, errs.AuthSigPayloadMismatchError("validate token claims error") } token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID()) if err != nil { return entity.ValidationTokenResp{}, errs.DBErrorError(fmt.Sprintf("failed to get token :%s", claims.ID())).Wrap(err) } return entity.ValidationTokenResp{ Token: entity.Token{ ID: token.ID, UID: token.UID, DeviceID: token.DeviceID, AccessCreateAt: token.AccessCreateAt, AccessToken: token.AccessToken, ExpiresIn: token.ExpiresIn, RefreshToken: token.RefreshToken, RefreshExpiresIn: token.RefreshExpiresIn, RefreshCreateAt: token.RefreshCreateAt, }, Data: claims, }, nil } func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error { if req.UID != "" { err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID) if err != nil { return errs.DBErrorError("failed to cancel tokens by uid").Wrap(err) } } if len(req.IDs) > 0 { err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs) if err != nil { return errs.DBErrorError("failed to cancel tokens by token ids").Wrap(err) } } return nil } func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error { err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID) if err != nil { return errs.DBErrorError("failed to cancel tokens by device id").Wrap(err) } return nil } func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) { uidTokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID) if err != nil { return nil, errs.DBErrorError("failed to get tokens by device id").Wrap(err) } tokens := make([]*entity.TokenResp, 0, len(uidTokens)) for _, v := range uidTokens { tokens = append(tokens, &entity.TokenResp{ AccessToken: v.AccessToken, TokenType: token.TypeBearer.String(), ExpiresIn: int64(v.ExpiresIn), RefreshToken: v.RefreshToken, }) } return tokens, nil } func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) { uidTokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID) if err != nil { return nil, errs.DBErrorError("failed to get tokens by uid").Wrap(err) } tokens := make([]*entity.TokenResp, 0, len(uidTokens)) for _, v := range uidTokens { tokens = append(tokens, &entity.TokenResp{ AccessToken: v.AccessToken, TokenType: token.TypeBearer.String(), ExpiresIn: int64(v.ExpiresIn), RefreshToken: v.RefreshToken, }) } return tokens, nil } func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) { // 驗證Token claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false) if err != nil { return entity.CreateOneTimeTokenResp{}, errs.AuthSigPayloadMismatchError("failed to get token claims").Wrap(err) } tokenObj, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID()) if err != nil { return entity.CreateOneTimeTokenResp{}, errs.DBErrorError("failed to get token by id").Wrap(err) } oneTimeToken := refreshTokenGenerator(ksuid.New().String()) key := token.TicketKeyPrefix + oneTimeToken if err = use.TokenRepo.CreateOneTimeToken(ctx, key, entity.Ticket{ Data: claims, Token: tokenObj, }, time.Minute); err != nil { return entity.CreateOneTimeTokenResp{}, errs.DBErrorError("failed to create new one-time token").Wrap(err) } return entity.CreateOneTimeTokenResp{ OneTimeToken: oneTimeToken, }, nil } func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error { err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil) if err != nil { return errs.DBErrorError("failed to del one time token by token") } return nil } // BlacklistToken 將 JWT token 加入黑名單 (立即撤銷) func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error { // 解析 JWT 獲取完整的 claims claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false) if err != nil { return errs.AuthSigPayloadMismatchError("failed to parse token claims").Wrap(err) } // 獲取 JTI (JWT ID) jti, exists := claimMap["jti"] if !exists { return errs.ResNotFoundError("token missing JTI claim").Wrap(err) } jtiStr, ok := jti.(string) if !ok { return errs.ResNotFoundError("token missing JTI claim").Wrap(err) } // 獲取 UID (可能在 data 中) var uid string if dataInterface, exists := claimMap["data"]; exists { if dataMap, ok := dataInterface.(map[string]interface{}); ok { if uidInterface, exists := dataMap["uid"]; exists { uid, _ = uidInterface.(string) } } } // 獲取過期時間 exp, exists := claimMap["exp"] if !exists { return errs.AuthExpiredError("token missing exp claim").Wrap(err) } // 將 exp 轉換為 int64 (JWT 中通常是 float64) var expInt int64 switch v := exp.(type) { case float64: expInt = int64(v) case int64: expInt = v case string: parsedExp, err := strconv.ParseInt(v, 10, 64) if err != nil { return errs.SysInternalError("failed to parse exp claim").Wrap(err) } expInt = parsedExp default: return errs.SysInternalError("exp claim type conversion failed").Wrap(err) } // 創建黑名單條目 blacklistEntry := &entity.BlacklistEntry{ JTI: jtiStr, UID: uid, ExpiresAt: expInt, CreatedAt: time.Now().Unix(), } // 添加到黑名單 err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算 if err != nil { return errs.DBErrorError("failed to add token to blacklist").Wrap(err) } // 記錄成功日誌(如果 Logger 存在) if use.Logger != nil { use.Logger.WithFields( errs.LogField{ Key: "jti", Val: jtiStr, }, errs.LogField{ Key: "uid", Val: uid, }, errs.LogField{ Key: "reason", Val: reason, }).Info("token blacklisted") } return nil } // IsTokenBlacklisted 檢查 JWT token 是否在黑名單中 func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) { isBlacklisted, err := use.TokenRepo.IsBlacklisted(ctx, jti) if err != nil { return false, errs.DBErrorError("failed to check blacklist status").Wrap(err) } return isBlacklisted, nil } // BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出) func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error { // 獲取用戶的所有 token tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid) if err != nil { return errs.DBErrorError("failed to get user tokens").Wrap(err) } // 為每個 token 創建黑名單條目 for _, token := range tokens { // 解析 token 獲取 JTI 和過期時間 claims, err := ParseClaims(token.AccessToken, use.Config.Token.AccessSecret, false) if err != nil { if use.Logger != nil { use.Logger.WithFields( errs.LogField{ Key: "uid", Val: uid, }, errs.LogField{ Key: "tokenID", Val: token.ID, }, errs.LogField{ Key: "error", Val: err, }).Error("failed to parse token for blacklisting") } continue // 跳過無效的 token,繼續處理其他 token } jti, exists := claims["jti"] if !exists || jti == "" { if use.Logger != nil { use.Logger.WithFields( errs.LogField{ Key: "uid", Val: uid, }, errs.LogField{ Key: "tokenID", Val: token.ID, }).Error("failed to parse token for blacklisting") } continue } exp, exists := claims["exp"] if !exists { if use.Logger != nil { use.Logger.WithFields( errs.LogField{ Key: "uid", Val: uid, }, errs.LogField{ Key: "tokenID", Val: token.ID, }).Error("token missing exp claim") } continue } // 將 exp 字符串轉換為 int64 expInt, err := strconv.ParseInt(exp, 10, 64) if err != nil { if use.Logger != nil { use.Logger.WithFields( errs.LogField{ Key: "uid", Val: uid, }, errs.LogField{ Key: "tokenID", Val: token.ID, }, errs.LogField{ Key: "error", Val: err, }).Error("failed to parse exp claim") } continue } // 創建黑名單條目 blacklistEntry := &entity.BlacklistEntry{ JTI: jti, UID: uid, ExpiresAt: expInt, CreatedAt: time.Now().Unix(), } // 添加到黑名單 err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算 if err != nil { if use.Logger != nil { use.Logger.WithFields( errs.LogField{ Key: "uid", Val: uid, }, errs.LogField{ Key: "jti", Val: jti, }, errs.LogField{ Key: "error", Val: err, }).Error("failed to add token to blacklist") } // 繼續處理其他 token,不要因為一個失敗就停止 } } // 刪除用戶的所有 token 記錄 err = use.TokenRepo.DeleteAccessTokensByUID(ctx, uid) if err != nil { if use.Logger != nil { use.Logger.WithFields( errs.LogField{ Key: "uid", Val: uid, }, errs.LogField{ Key: "error", Val: err, }).Error("failed to delete user tokens") } } if use.Logger != nil { use.Logger.WithFields( errs.LogField{ Key: "uid", Val: uid, }, errs.LogField{ Key: "tokenCount", Val: len(tokens), }, errs.LogField{ Key: "reason", Val: reason, }).Error("all user tokens blacklisted") } return nil }