package repository

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"backend/pkg/permission/domain/entity"
	"backend/pkg/permission/domain/repository"
	"backend/pkg/permission/domain/token"

	"github.com/zeromicro/go-zero/core/stores/redis"
)

var (
	// errTokenNotFound is returned when a token key is missing or holds an
	// empty value. The text matches the message this file used before, so
	// callers that inspect err.Error() are unaffected.
	errTokenNotFound = errors.New("token not found")

	// errOneTimeTokenExists is returned by CreateOneTimeToken when its key is
	// already taken and therefore the ticket was NOT stored.
	errOneTimeTokenExists = errors.New("one-time token already exists")
)

// TokenRepositoryParam holds the dependencies TokenRepository needs.
type TokenRepositoryParam struct {
	Redis *redis.Redis
}

// TokenRepository is the Redis-backed implementation of
// repository.TokenRepository.
//
// Keys written by this file (builders live in the token package):
//   - token.GetAccessTokenRedisKey(id)        -> JSON-encoded entity.Token
//   - token.RefreshTokenRedisKey(refresh)     -> access token ID
//   - token.UIDTokenRedisKey(uid)             -> set of access token IDs
//   - token.DeviceTokenRedisKey(deviceID)     -> set of access token IDs
//   - token.GetBlacklistRedisKey(jti)         -> JSON-encoded entity.BlacklistEntry
type TokenRepository struct {
	TokenRepositoryParam
}

// MustTokenRepository returns a TokenRepository built from param.
// Despite the "Must" prefix it never panics and does not validate param.
func MustTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
	return &TokenRepository{TokenRepositoryParam: param}
}

// Create stores a new token in Redis in a single pipeline: the access token
// body, the refresh-token -> ID pointer (only when RefreshToken is set), and the
// token ID in its UID and device sets. Every key gets the token's refresh TTL.
func (repo *TokenRepository) Create(ctx context.Context, tokenObj entity.Token) error {
	body, err := json.Marshal(tokenObj)
	if err != nil {
		return err
	}

	refreshTTL := time.Duration(tokenObj.RedisRefreshExpiredSec()) * time.Second

	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		if err := repo.setToken(ctx, tx, tokenObj, body, refreshTTL); err != nil {
			return err
		}
		if err := repo.setRefreshToken(ctx, tx, tokenObj, refreshTTL); err != nil {
			return err
		}
		return repo.setRelation(ctx, tx, tokenObj.UID, tokenObj.DeviceID, tokenObj.ID, refreshTTL)
	})
}

// CreateOneTimeToken stores ticket (JSON-encoded) under key for dt, only if
// the key does not exist yet. It returns an error when dt is not positive or
// when the key is already taken (nothing is written in that case).
//
// NOTE(review): the Redis key is built with token.RefreshTokenRedisKey, so
// one-time tokens share the refresh-token namespace, and
// GetAccessTokenByOneTimeToken reads that value back as an access token ID
// rather than a ticket. Confirm this pairing is intended.
func (repo *TokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
	if dt <= 0 {
		return fmt.Errorf("one-time token ttl must be positive, got %s", dt)
	}

	body, err := json.Marshal(ticket)
	if err != nil {
		return err
	}

	stored, err := repo.Redis.SetnxExCtx(ctx, token.RefreshTokenRedisKey(key), string(body), ttlSeconds(dt))
	if err != nil {
		return err
	}
	if !stored {
		// SETNX refused to overwrite the existing value; reporting success
		// would hand the caller a key whose stored value is not this ticket.
		return errOneTimeTokenExists
	}
	return nil
}

// GetAccessTokenByOneTimeToken resolves oneTimeToken to an access token ID via
// its refresh-token key, then loads that access token.
func (repo *TokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
	id, err := repo.Redis.GetCtx(ctx, token.RefreshTokenRedisKey(oneTimeToken))
	if err != nil {
		return entity.Token{}, err
	}
	if id == "" {
		return entity.Token{}, errTokenNotFound
	}
	return repo.GetAccessTokenByID(ctx, id)
}

// GetAccessTokenByID loads the access token stored under id.
func (repo *TokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
	return repo.get(ctx, token.GetAccessTokenRedisKey(id))
}

// GetAccessTokensByUID returns the live access tokens in uid's token set.
// Stale entries found along the way are cleaned up (see getTokensBySet).
func (repo *TokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
	// Same key builder as setRelation/deleteKeysAndRelations write to.
	return repo.getTokensBySet(ctx, token.UIDTokenRedisKey(uid))
}

// GetAccessTokenCountByUID returns the size of uid's token set.
func (repo *TokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
	return repo.getCountBySet(ctx, token.UIDTokenRedisKey(uid))
}

// GetAccessTokensByDeviceID returns the live access tokens in deviceID's token
// set. Stale entries found along the way are cleaned up (see getTokensBySet).
func (repo *TokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
	return repo.getTokensBySet(ctx, token.DeviceTokenRedisKey(deviceID))
}

// GetAccessTokenCountByDeviceID returns the size of deviceID's token set.
func (repo *TokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
	return repo.getCountBySet(ctx, token.DeviceTokenRedisKey(deviceID))
}

// Delete removes tokenObj in one pipeline: its access token key, its refresh
// token key (if it has one), and its ID from the UID and device sets.
func (repo *TokenRepository) Delete(ctx context.Context, tokenObj entity.Token) error {
	return repo.deleteKeysAndRelations(ctx, tokenKeys(tokenObj), tokenObj.UID, tokenObj.DeviceID, tokenObj.ID)
}

// DeleteOneTimeToken deletes the refresh-token-namespace keys for every raw key
// in ids and for the RefreshToken of every token in tokens.
func (repo *TokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
	keys := make([]string, 0, len(ids)+len(tokens))
	for _, id := range ids {
		keys = append(keys, token.RefreshTokenRedisKey(id))
	}
	for _, tokenObj := range tokens {
		if tokenObj.RefreshToken == "" {
			// setRefreshToken never writes a key for an empty refresh token.
			continue
		}
		keys = append(keys, token.RefreshTokenRedisKey(tokenObj.RefreshToken))
	}
	return repo.deleteKeys(ctx, keys...)
}

// DeleteAccessTokenByID deletes every listed access token (body, refresh key
// and UID/device set membership). IDs whose token cannot be loaded are
// skipped. All IDs are processed; the first deletion error, if any, is
// returned.
func (repo *TokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
	var firstErr error
	for _, id := range ids {
		tokenObj, err := repo.GetAccessTokenByID(ctx, id)
		if err != nil {
			// Without the body, the refresh token, UID and device are unknown,
			// so nothing more can be resolved for this ID.
			continue
		}
		if err := repo.Delete(ctx, tokenObj); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// DeleteAccessTokensByUID deletes every live access token of uid, stopping at
// the first failure.
func (repo *TokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
	tokens, err := repo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return err
	}
	for _, tokenObj := range tokens {
		if err := repo.Delete(ctx, tokenObj); err != nil {
			return err
		}
	}
	return nil
}

// DeleteAccessTokensByDeviceID deletes every live access token of deviceID.
// It removes each token from its owner's UID set, deletes the token keys, and
// finally drops the device set itself.
func (repo *TokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
	tokens, err := repo.GetAccessTokensByDeviceID(ctx, deviceID)
	if err != nil {
		return err
	}

	keys := make([]string, 0, len(tokens)*2)
	for _, tokenObj := range tokens {
		keys = append(keys, tokenKeys(tokenObj)...)
	}

	err = repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		for _, tokenObj := range tokens {
			tx.SRem(ctx, token.UIDTokenRedisKey(tokenObj.UID), tokenObj.ID)
		}
		return nil
	})
	if err != nil {
		return err
	}

	if err := repo.deleteKeys(ctx, keys...); err != nil {
		return err
	}

	// The whole device set goes at once instead of SREM-ing each member.
	_, err = repo.Redis.DelCtx(ctx, token.DeviceTokenRedisKey(deviceID))
	return err
}

// ========================================================================

// deleteKeysAndRelations deletes keys and removes tokenID from the UID and
// device sets, all in one pipeline.
func (repo *TokenRepository) deleteKeysAndRelations(ctx context.Context, keys []string, uid, deviceID, tokenID string) error {
	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		tx.SRem(ctx, token.UIDTokenRedisKey(uid), tokenID)
		tx.SRem(ctx, token.DeviceTokenRedisKey(deviceID), tokenID)
		for _, key := range keys {
			tx.Del(ctx, key)
		}
		return nil
	})
}

// runPipeline executes fn as a Redis pipeline bound to ctx.
func (repo *TokenRepository) runPipeline(ctx context.Context, fn func(tx redis.Pipeliner) error) error {
	return repo.Redis.PipelinedCtx(ctx, fn)
}

// deleteKeys deletes keys in one pipeline; it is a no-op for no keys.
func (repo *TokenRepository) deleteKeys(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}
	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		// One DEL per key (as before) rather than a single multi-key DEL:
		// keys from different prefixes need not share a Redis Cluster slot.
		for _, key := range keys {
			tx.Del(ctx, key)
		}
		return nil
	})
}

// setToken queues a SET of the access token body with the given TTL.
func (repo *TokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, body []byte, ttl time.Duration) error {
	return tx.Set(ctx, token.GetAccessTokenRedisKey(tokenObj.ID), body, ttl).Err()
}

// setRefreshToken queues a SET of refresh-token -> access-token-ID, but only
// when the token actually carries a refresh token.
func (repo *TokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, ttl time.Duration) error {
	if tokenObj.RefreshToken == "" {
		return nil
	}
	return tx.Set(ctx, token.RefreshTokenRedisKey(tokenObj.RefreshToken), tokenObj.ID, ttl).Err()
}

// setRelation queues adding tokenID to the UID and device sets and setting
// both sets' TTL to ttl.
//
// NOTE: EXPIRE overwrites the previous TTL, so each new token resets the set's
// lifetime to its own ttl; a set could expire before an older member whose
// ttl was longer.
func (repo *TokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
	uidKey := token.UIDTokenRedisKey(uid)
	if err := tx.SAdd(ctx, uidKey, tokenID).Err(); err != nil {
		return err
	}
	if err := tx.Expire(ctx, uidKey, ttl).Err(); err != nil {
		return err
	}

	deviceKey := token.DeviceTokenRedisKey(deviceID)
	if err := tx.SAdd(ctx, deviceKey, tokenID).Err(); err != nil {
		return err
	}
	return tx.Expire(ctx, deviceKey, ttl).Err()
}

// get loads and decodes the entity.Token stored at key. A missing/empty value
// yields errTokenNotFound; a body that is not valid JSON yields a wrapped
// decode error.
func (repo *TokenRepository) get(ctx context.Context, key string) (entity.Token, error) {
	body, err := repo.Redis.GetCtx(ctx, key)
	if err != nil {
		return entity.Token{}, err
	}
	if body == "" {
		return entity.Token{}, errTokenNotFound
	}

	var tokenObj entity.Token
	if err := json.Unmarshal([]byte(body), &tokenObj); err != nil {
		return entity.Token{}, fmt.Errorf("json.Unmarshal token error: %w", err)
	}
	return tokenObj, nil
}

// getTokensBySet returns the live tokens whose IDs are members of setKey.
//
// While scanning it cleans up, best-effort:
//   - IDs whose token body no longer exists are SREM'd from setKey, so they
//     stop inflating the set and its count;
//   - tokens whose ExpiresIn is earlier than the current Unix time are fully
//     deleted via Delete.
//
// Members whose body cannot be read for any other reason (Redis error,
// undecodable JSON) are skipped and left in place.
func (repo *TokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
	ids, err := repo.Redis.SmembersCtx(ctx, setKey)
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, nil
		}
		return nil, err
	}

	tokens := make([]entity.Token, 0, len(ids))
	var orphanIDs []string
	var expired []entity.Token
	now := time.Now().Unix()
	for _, id := range ids {
		tokenObj, err := repo.get(ctx, token.GetAccessTokenRedisKey(id))
		if err != nil {
			if errors.Is(err, errTokenNotFound) {
				orphanIDs = append(orphanIDs, id)
			}
			continue
		}
		// ExpiresIn is compared with the current Unix time, i.e. treated as an
		// absolute expiry timestamp.
		if int64(tokenObj.ExpiresIn) < now {
			expired = append(expired, tokenObj)
			continue
		}
		tokens = append(tokens, tokenObj)
	}

	// Cleanup failures must not hide the live tokens from the caller.
	for _, tokenObj := range expired {
		_ = repo.Delete(ctx, tokenObj)
	}
	if len(orphanIDs) > 0 {
		// The body is gone, so the token's other set is unknown here; only the
		// set being scanned can be repaired.
		_ = repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
			for _, id := range orphanIDs {
				tx.SRem(ctx, setKey, id)
			}
			return nil
		})
	}
	return tokens, nil
}

// getCountBySet returns the number of members in setKey (SCARD). Members whose
// token has already expired are still counted until getTokensBySet prunes them.
func (repo *TokenRepository) getCountBySet(ctx context.Context, setKey string) (int, error) {
	count, err := repo.Redis.ScardCtx(ctx, setKey)
	if err != nil {
		return 0, err
	}
	return int(count), nil
}

// tokenKeys returns the keys holding tokenObj's own data: its access token body
// and, when it has one, its refresh-token pointer.
func tokenKeys(tokenObj entity.Token) []string {
	keys := []string{token.GetAccessTokenRedisKey(tokenObj.ID)}
	if tokenObj.RefreshToken != "" {
		keys = append(keys, token.RefreshTokenRedisKey(tokenObj.RefreshToken))
	}
	return keys
}

// ttlSeconds converts d to whole seconds for a Redis EX argument, rounding any
// fractional second up so a positive duration never becomes 0 (Redis rejects
// a zero expire time). Non-positive durations map to 0.
func ttlSeconds(d time.Duration) int {
	if d <= 0 {
		return 0
	}
	return int((d + time.Second - 1) / time.Second)
}

// AddToBlacklist stores entry under its JTI's blacklist key.
//
// When ttl <= 0 the TTL is derived from entry.ExpiresAt (Unix seconds); if
// that moment has already passed, nothing is stored and nil is returned.
func (repo *TokenRepository) AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error {
	if entry == nil {
		return errors.New("blacklist entry is nil")
	}

	if ttl <= 0 {
		ttl = time.Until(time.Unix(entry.ExpiresAt, 0))
		if ttl <= 0 {
			// The token has already expired; blacklisting it is pointless.
			return nil
		}
	}

	data, err := json.Marshal(entry)
	if err != nil {
		return fmt.Errorf("failed to marshal blacklist entry: %w", err)
	}

	key := token.GetBlacklistRedisKey(entry.JTI)
	if err := repo.Redis.SetexCtx(ctx, key, string(data), ttlSeconds(ttl)); err != nil {
		return fmt.Errorf("failed to add token to blacklist: %w", err)
	}
	return nil
}

// IsBlacklisted reports whether a blacklist key exists for jti.
func (repo *TokenRepository) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
	exists, err := repo.Redis.ExistsCtx(ctx, token.GetBlacklistRedisKey(jti))
	if err != nil {
		return false, fmt.Errorf("failed to check blacklist: %w", err)
	}
	return exists, nil
}

// RemoveFromBlacklist deletes the blacklist key for jti.
func (repo *TokenRepository) RemoveFromBlacklist(ctx context.Context, jti string) error {
	if _, err := repo.Redis.DelCtx(ctx, token.GetBlacklistRedisKey(jti)); err != nil {
		return fmt.Errorf("failed to remove token from blacklist: %w", err)
	}
	return nil
}

// GetBlacklistedTokensByUID returns every blacklist entry whose UID equals uid.
// Entries that vanish during the scan or cannot be decoded are skipped.
//
// NOTE(review): this SCANs all blacklist keys and GETs each one, so the cost
// grows with the whole blacklist rather than with this user's entries. A
// per-UID index set would avoid the full scan.
func (repo *TokenRepository) GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error) {
	pattern := token.BlacklistKeyPrefix + "*"

	var entries []*entity.BlacklistEntry
	var cursor uint64
	for {
		keys, next, err := repo.Redis.ScanCtx(ctx, cursor, pattern, 100)
		if err != nil {
			return nil, fmt.Errorf("failed to scan blacklist keys: %w", err)
		}

		for _, key := range keys {
			data, err := repo.Redis.GetCtx(ctx, key)
			if err != nil {
				if errors.Is(err, redis.Nil) {
					continue
				}
				return nil, fmt.Errorf("failed to get blacklist entry: %w", err)
			}
			if data == "" {
				// The key disappeared (e.g. expired) between SCAN and GET.
				continue
			}

			var entry entity.BlacklistEntry
			if err := json.Unmarshal([]byte(data), &entry); err != nil {
				continue
			}
			if entry.UID == uid {
				entries = append(entries, &entry)
			}
		}

		cursor = next
		if cursor == 0 {
			break
		}
	}
	return entries, nil
}