package repository

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain"
	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/repository"
	"github.com/zeromicro/go-zero/core/stores/redis"
)

// TokenRepositoryParam holds the dependencies required by TokenRepository.
type TokenRepositoryParam struct {
	Redis *redis.Redis
}

// TokenRepository persists tokens in Redis: the serialized token under its
// access-token key, a refresh-token -> token-ID pointer, one-time tickets, and
// the UID / device-ID sets that index token IDs.
type TokenRepository struct {
	TokenRepositoryParam
}

// NewTokenRepository builds a TokenRepository from param.
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepo {
	return &TokenRepository{
		TokenRepositoryParam: param,
	}
}

// ====================== private helpers ======================

// runPipeline executes pipelineFunc inside a Redis pipeline bound to ctx.
func (repo *TokenRepository) runPipeline(ctx context.Context, pipelineFunc func(tx redis.Pipeliner) error) error {
	return repo.Redis.PipelinedCtx(ctx, pipelineFunc)
}

// setToken queues a SET of the serialized token body under the access-token
// key for id, expiring after ttl.
func (repo *TokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, id string, body []byte, ttl time.Duration) error {
	return tx.Set(ctx, domain.GetAccessTokenRedisKey(id), body, ttl).Err()
}

// setRefreshToken queues a SET mapping the token's refresh token to its ID,
// expiring after ttl. It does nothing when the token has no refresh token.
func (repo *TokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, ttl time.Duration) error {
	if token.RefreshToken == "" {
		return nil
	}

	return tx.Set(ctx, domain.GetRefreshTokenRedisKey(token.RefreshToken), token.ID, ttl).Err()
}

// setTokenRelation queues SADDs that add tokenID to the UID set and the
// device-ID set, then (re)sets each set's TTL to ttl.
//
// Note: every call overwrites the set's TTL with this token's ttl, so a newer
// token with a shorter TTL also shortens how long earlier members stay indexed.
func (repo *TokenRepository) setTokenRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
	relationKeys := []string{
		domain.GetUIDTokenRedisKey(uid),
		domain.GetDeviceTokenRedisKey(deviceID),
	}

	for _, key := range relationKeys {
		if err := tx.SAdd(ctx, key, tokenID).Err(); err != nil {
			return fmt.Errorf("failed to create token relaction: %w", err)
		}
		if err := tx.Expire(ctx, key, ttl).Err(); err != nil {
			return fmt.Errorf("failed to set expire: %w", err)
		}
	}

	return nil
}

// decodeToken parses a token body read from Redis. An empty body (what GetCtx
// yields for a missing key) is reported as "token not found".
func (repo *TokenRepository) decodeToken(body string) (entity.Token, error) {
	if body == "" {
		return entity.Token{}, errors.New("failed to found token")
	}

	var token entity.Token
	if err := json.Unmarshal([]byte(body), &token); err != nil {
		return entity.Token{}, fmt.Errorf("failed to unmarshal token JSON: %w", err)
	}

	return token, nil
}

// retrieveToken reads key from Redis and decodes it as an entity.Token.
// Redis errors are returned as-is; a missing key or malformed JSON yields an error.
func (repo *TokenRepository) retrieveToken(ctx context.Context, key string) (entity.Token, error) {
	body, err := repo.Redis.GetCtx(ctx, key)
	if err != nil {
		return entity.Token{}, err
	}

	return repo.decodeToken(body)
}

// tokenKeys returns the Redis keys owned by token: its access-token key and,
// only if it has one, its refresh-token key. Create never writes a refresh key
// for an empty refresh token, so an empty one must not be deleted either.
func (repo *TokenRepository) tokenKeys(token entity.Token) []string {
	keys := []string{domain.GetAccessTokenRedisKey(token.ID)}
	if token.RefreshToken != "" {
		keys = append(keys, domain.GetRefreshTokenRedisKey(token.RefreshToken))
	}

	return keys
}

// getTokensBySet returns the live tokens whose IDs are members of setKey.
//
// Side effects (best-effort; failures are ignored):
//   - tokens that decode but are expired are fully deleted via Delete;
//   - IDs whose token body is missing or undecodable are removed from setKey
//     (and their access key deleted), since Delete cannot be used without the
//     token's UID/device.
//
// A transient Redis error while reading a token skips that ID without
// modifying anything, so a network blip never unlinks a live token.
func (repo *TokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
	ids, err := repo.Redis.SmembersCtx(ctx, setKey)
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, nil
		}

		return nil, err
	}

	tokens := make([]entity.Token, 0, len(ids))
	var expired []entity.Token
	var orphanIDs []string
	// NOTE(review): assumes entity.Token.ExpiresIn is a Unix timestamp in
	// nanoseconds — confirm against the entity definition.
	now := time.Now().UnixNano()

	for _, id := range ids {
		body, err := repo.Redis.GetCtx(ctx, domain.GetAccessTokenRedisKey(id))
		if err != nil {
			continue
		}

		token, err := repo.decodeToken(body)
		if err != nil {
			orphanIDs = append(orphanIDs, id)
			continue
		}

		if token.ExpiresIn < now {
			expired = append(expired, token)
			continue
		}

		tokens = append(tokens, token)
	}

	// Best-effort cleanup: readers still get the live tokens even if this fails.
	for _, token := range expired {
		_ = repo.Delete(ctx, token)
	}

	if len(orphanIDs) > 0 {
		_ = repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
			for _, id := range orphanIDs {
				tx.SRem(ctx, setKey, id)
				tx.Del(ctx, domain.GetAccessTokenRedisKey(id))
			}

			return nil
		})
	}

	return tokens, nil
}

// getCountBySet returns the cardinality of setKey. This counts raw set members,
// which can include IDs whose token has already expired but not been cleaned up.
func (repo *TokenRepository) getCountBySet(ctx context.Context, setKey string) (int, error) {
	count, err := repo.Redis.ScardCtx(ctx, setKey)
	if err != nil {
		return 0, err
	}

	return int(count), nil
}

// deleteKeysAndRelations removes tokenID from the UID and device-ID relation
// sets and deletes every key in keys, all in one ctx-bound pipeline.
func (repo *TokenRepository) deleteKeysAndRelations(ctx context.Context, keys []string, uid, deviceID, tokenID string) error {
	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		tx.SRem(ctx, domain.GetUIDTokenRedisKey(uid), tokenID)
		tx.SRem(ctx, domain.GetDeviceTokenRedisKey(deviceID), tokenID)

		for _, key := range keys {
			tx.Del(ctx, key)
		}

		return nil
	})
}

// batchDeleteKeys deletes keys in one ctx-bound pipeline. One DEL per key (rather
// than a single multi-key DEL) keeps it usable when keys hash to different slots.
func (repo *TokenRepository) batchDeleteKeys(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}

	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		for _, key := range keys {
			tx.Del(ctx, key)
		}

		return nil
	})
}

// ====================== public methods ======================

// Create stores token, its refresh-token pointer (if any), and its UID/device
// relations in a single pipeline. Every key uses the token's refresh TTL
// (RedisRefreshExpiredSec seconds).
func (repo *TokenRepository) Create(ctx context.Context, token entity.Token) error {
	binToken, err := json.Marshal(token)
	if err != nil {
		return err
	}

	refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second

	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		if err := repo.setToken(ctx, tx, token.ID, binToken, refreshTTL); err != nil {
			return err
		}
		if err := repo.setRefreshToken(ctx, tx, token, refreshTTL); err != nil {
			return err
		}

		return repo.setTokenRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL)
	})
}

// CreateOneTimeToken stores ticket as JSON under the refresh-token key for key,
// expiring after dt (truncated to whole seconds), only if that key does not exist.
//
// NOTE(review): the SETNX result is discarded, so if key already exists nothing
// is written and nil is still returned — confirm callers never reuse keys.
func (repo *TokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
	body, err := json.Marshal(ticket)
	if err != nil {
		return err
	}

	_, err = repo.Redis.SetnxExCtx(ctx, domain.GetRefreshTokenRedisKey(key), string(body), int(dt.Seconds()))

	return err
}

// GetAccessTokenByOneTimeToken reads the refresh-token key for oneTimeToken and
// decodes it as an entity.Token.
//
// NOTE(review): this key namespace holds JSON-encoded entity.Ticket values
// (CreateOneTimeToken) and bare token IDs (setRefreshToken), yet it is decoded
// as entity.Token here — confirm the stored formats are compatible.
func (repo *TokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
	return repo.retrieveToken(ctx, domain.GetRefreshTokenRedisKey(oneTimeToken))
}

// GetAccessTokenByID returns the token stored under the access-token key for id.
func (repo *TokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
	return repo.retrieveToken(ctx, domain.GetAccessTokenRedisKey(id))
}

// GetAccessTokensByUID returns the live tokens indexed under uid, cleaning up
// expired or dangling entries as a side effect.
func (repo *TokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
	return repo.getTokensBySet(ctx, domain.GetUIDTokenRedisKey(uid))
}

// GetAccessTokenCountByUID returns the number of token IDs indexed under uid.
func (repo *TokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
	return repo.getCountBySet(ctx, domain.GetUIDTokenRedisKey(uid))
}

// GetAccessTokensByDeviceID returns the live tokens indexed under deviceID,
// cleaning up expired or dangling entries as a side effect.
func (repo *TokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
	return repo.getTokensBySet(ctx, domain.GetDeviceTokenRedisKey(deviceID))
}

// GetAccessTokenCountByDeviceID returns the number of token IDs indexed under deviceID.
func (repo *TokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
	return repo.getCountBySet(ctx, domain.GetDeviceTokenRedisKey(deviceID))
}

// Delete removes token's access key, its refresh key (if any), and its
// UID/device relation entries.
func (repo *TokenRepository) Delete(ctx context.Context, token entity.Token) error {
	return repo.deleteKeysAndRelations(ctx, repo.tokenKeys(token), token.UID, token.DeviceID, token.ID)
}

// DeleteAccessTokenByID deletes each listed token via Delete. It is best-effort:
// IDs whose token cannot be read are skipped and per-token failures are ignored,
// so it always returns nil.
func (repo *TokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
	for _, tokenID := range ids {
		token, err := repo.GetAccessTokenByID(ctx, tokenID)
		if err != nil {
			continue
		}

		_ = repo.Delete(ctx, token)
	}

	return nil
}

// DeleteAccessTokensByUID deletes every live token indexed under uid, stopping
// at the first failure.
func (repo *TokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
	tokens, err := repo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return err
	}

	for _, token := range tokens {
		if err := repo.Delete(ctx, token); err != nil {
			return err
		}
	}

	return nil
}

// DeleteAccessTokensByDeviceID deletes every live token indexed under deviceID:
// it unlinks each from its UID set, deletes its keys, and finally deletes the
// device set itself — all in one ctx-bound pipeline.
func (repo *TokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
	tokens, err := repo.GetAccessTokensByDeviceID(ctx, deviceID)
	if err != nil {
		return err
	}

	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		for _, token := range tokens {
			tx.SRem(ctx, domain.GetUIDTokenRedisKey(token.UID), token.ID)
			for _, key := range repo.tokenKeys(token) {
				tx.Del(ctx, key)
			}
		}

		tx.Del(ctx, domain.GetDeviceTokenRedisKey(deviceID))

		return nil
	})
}

// DeleteOneTimeToken deletes the refresh-token keys for every id in ids and for
// every token in tokens that carries a refresh token.
func (repo *TokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
	keys := make([]string, 0, len(ids)+len(tokens))
	for _, id := range ids {
		keys = append(keys, domain.GetRefreshTokenRedisKey(id))
	}
	for _, token := range tokens {
		if token.RefreshToken == "" {
			continue
		}
		keys = append(keys, domain.GetRefreshTokenRedisKey(token.RefreshToken))
	}

	return repo.batchDeleteKeys(ctx, keys...)
}