2025-02-12 01:51:46 +00:00
|
|
|
|
package repository
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"time"
|
2025-02-13 11:06:51 +00:00
|
|
|
|
|
|
|
|
|
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain"
|
|
|
|
|
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
|
|
|
|
|
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/repository"
|
|
|
|
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
2025-02-12 01:51:46 +00:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// TokenRepositoryParam token 需要的參數
|
|
|
|
|
type TokenRepositoryParam struct {
|
|
|
|
|
Redis *redis.Redis
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TokenRepository 用於操作 token 的存儲庫
|
|
|
|
|
type TokenRepository struct {
|
|
|
|
|
TokenRepositoryParam
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewTokenRepository 初始化並返回一個 TokenRepository 實例
|
|
|
|
|
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepo {
|
|
|
|
|
return &TokenRepository{
|
|
|
|
|
TokenRepositoryParam: param,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ====================== 私有工具函數 =======================
|
|
|
|
|
|
|
|
|
|
// runPipeline 執行 Redis pipeline 操作
|
|
|
|
|
func (repo *TokenRepository) runPipeline(ctx context.Context, pipelineFunc func(tx redis.Pipeliner) error) error {
|
|
|
|
|
return repo.Redis.PipelinedCtx(ctx, pipelineFunc)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// setToken 使用 Redis pipeline 存儲 token 並設定 TTL
|
|
|
|
|
func (repo *TokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, id string, body []byte, ttl time.Duration) error {
|
|
|
|
|
return tx.Set(ctx, domain.GetAccessTokenRedisKey(id), body, ttl).Err()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// setRefreshToken 若 token 中有 refresh token,則存儲之(使用 pipeline)
|
|
|
|
|
func (repo *TokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, ttl time.Duration) error {
|
|
|
|
|
if token.RefreshToken == "" {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return tx.Set(ctx, domain.GetRefreshTokenRedisKey(token.RefreshToken), token.ID, ttl).Err()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// setTokenRelation 在 Redis 中設定 token 與 UID/Device 之間的關聯,並設定過期時間
|
|
|
|
|
func (repo *TokenRepository) setTokenRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
|
|
|
|
|
// 定義需要執行的操作列表
|
|
|
|
|
operations := []struct {
|
|
|
|
|
key string
|
|
|
|
|
op func() error
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
key: domain.GetUIDTokenRedisKey(uid),
|
|
|
|
|
op: func() error {
|
|
|
|
|
return tx.SAdd(ctx, domain.GetUIDTokenRedisKey(uid), tokenID).Err()
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
key: domain.GetDeviceTokenRedisKey(deviceID),
|
|
|
|
|
op: func() error {
|
|
|
|
|
return tx.SAdd(ctx, domain.GetDeviceTokenRedisKey(deviceID), tokenID).Err()
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 執行每個操作,並為對應 key 設置過期時間
|
|
|
|
|
for _, operation := range operations {
|
|
|
|
|
if err := operation.op(); err != nil {
|
|
|
|
|
return fmt.Errorf("failed to create token relaction: %w", err)
|
|
|
|
|
}
|
|
|
|
|
if err := tx.Expire(ctx, operation.key, ttl).Err(); err != nil {
|
|
|
|
|
return fmt.Errorf("failed to set expire: %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// retrieveToken 根據指定 key 從 Redis 中獲取 token
|
|
|
|
|
func (repo *TokenRepository) retrieveToken(ctx context.Context, key string) (entity.Token, error) {
|
|
|
|
|
body, err := repo.Redis.GetCtx(ctx, key)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return entity.Token{}, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if body == "" {
|
|
|
|
|
return entity.Token{}, fmt.Errorf("failed to found token")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var token entity.Token
|
|
|
|
|
if err := json.Unmarshal([]byte(body), &token); err != nil {
|
|
|
|
|
return entity.Token{}, fmt.Errorf("failed to unmarshal token JSON: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return token, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// getTokensBySet 根據集合 key 獲取所有 token
|
|
|
|
|
func (repo *TokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
|
|
|
|
|
ids, err := repo.Redis.Smembers(setKey)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if errors.Is(err, redis.Nil) {
|
|
|
|
|
return nil, nil
|
|
|
|
|
}
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tokens := make([]entity.Token, 0, len(ids))
|
|
|
|
|
var tokensToDelete []string
|
|
|
|
|
now := time.Now().UnixNano()
|
|
|
|
|
for _, id := range ids {
|
|
|
|
|
token, err := repo.retrieveToken(ctx, domain.GetAccessTokenRedisKey(id))
|
|
|
|
|
if err != nil {
|
|
|
|
|
tokensToDelete = append(tokensToDelete, id)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if token.ExpiresIn < now {
|
|
|
|
|
tokensToDelete = append(tokensToDelete, id)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
tokens = append(tokens, token)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 清除過期或錯誤的 token
|
|
|
|
|
if len(tokensToDelete) > 0 {
|
|
|
|
|
_ = repo.DeleteAccessTokenByID(ctx, tokensToDelete)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return tokens, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// getCountBySet 獲取集合中元素的數量
|
|
|
|
|
func (repo *TokenRepository) getCountBySet(ctx context.Context, setKey string) (int, error) {
|
|
|
|
|
count, err := repo.Redis.ScardCtx(ctx, setKey)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return 0, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return int(count), nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// deleteKeysAndRelations 刪除指定的 Redis key 並移除相關關聯(UID 與 DeviceID)
|
|
|
|
|
func (repo *TokenRepository) deleteKeysAndRelations(ctx context.Context, keys []string, uid, deviceID, tokenID string) error {
|
|
|
|
|
err := repo.Redis.Pipelined(func(tx redis.Pipeliner) error {
|
|
|
|
|
// 移除 UID 與 DeviceID 關聯中的 tokenID
|
|
|
|
|
_ = tx.SRem(ctx, domain.GetUIDTokenRedisKey(uid), tokenID)
|
|
|
|
|
_ = tx.SRem(ctx, domain.GetDeviceTokenRedisKey(deviceID), tokenID)
|
|
|
|
|
|
|
|
|
|
// 刪除所有指定的 keys
|
|
|
|
|
for _, key := range keys {
|
|
|
|
|
_ = tx.Del(ctx, key)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// batchDeleteKeys 批量刪除 Redis keys
|
|
|
|
|
func (repo *TokenRepository) batchDeleteKeys(ctx context.Context, keys ...string) error {
|
|
|
|
|
return repo.Redis.Pipelined(func(tx redis.Pipeliner) error {
|
|
|
|
|
for _, key := range keys {
|
|
|
|
|
if err := tx.Del(ctx, key).Err(); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ====================== 公開方法 =======================
|
|
|
|
|
|
|
|
|
|
// Create 創建新的 token
|
|
|
|
|
func (repo *TokenRepository) Create(ctx context.Context, token entity.Token) error {
|
|
|
|
|
binToken, err := json.Marshal(token)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
// 根據 token 設定 refresh 過期秒數計算 TTL
|
|
|
|
|
refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second
|
|
|
|
|
|
|
|
|
|
return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
|
|
|
|
|
if err := repo.setToken(ctx, tx, token.ID, binToken, refreshTTL); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if err := repo.setRefreshToken(ctx, tx, token, refreshTTL); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := repo.setTokenRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// CreateOneTimeToken 創建一次性 token
|
|
|
|
|
func (repo *TokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
|
|
|
|
|
body, err := json.Marshal(ticket)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = repo.Redis.SetnxExCtx(ctx, domain.GetRefreshTokenRedisKey(key), string(body), int(dt.Seconds()))
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetAccessTokenByOneTimeToken 根據一次性 token 獲取 access token
|
|
|
|
|
func (repo *TokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
|
|
|
|
|
return repo.retrieveToken(ctx, domain.GetRefreshTokenRedisKey(oneTimeToken))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetAccessTokenByID 根據 token ID 獲取 access token
|
|
|
|
|
func (repo *TokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
|
|
|
|
|
return repo.retrieveToken(ctx, domain.GetAccessTokenRedisKey(id))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetAccessTokensByUID 根據 UID 獲取所有 access tokens
|
|
|
|
|
func (repo *TokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
|
|
|
|
|
return repo.getTokensBySet(ctx, domain.GetUIDTokenRedisKey(uid))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetAccessTokenCountByUID 根據 UID 獲取 access token 的數量
|
|
|
|
|
func (repo *TokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
|
|
|
|
|
return repo.getCountBySet(ctx, domain.GetUIDTokenRedisKey(uid))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetAccessTokensByDeviceID 根據 DeviceID 獲取所有 access tokens
|
|
|
|
|
func (repo *TokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
|
|
|
|
|
return repo.getTokensBySet(ctx, domain.GetDeviceTokenRedisKey(deviceID))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetAccessTokenCountByDeviceID 根據 DeviceID 獲取 access token 的數量
|
|
|
|
|
func (repo *TokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
|
|
|
|
|
return repo.getCountBySet(ctx, domain.GetDeviceTokenRedisKey(deviceID))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Delete 刪除指定的 token
|
|
|
|
|
func (repo *TokenRepository) Delete(ctx context.Context, token entity.Token) error {
|
|
|
|
|
keys := []string{
|
|
|
|
|
domain.GetAccessTokenRedisKey(token.ID),
|
|
|
|
|
domain.GetRefreshTokenRedisKey(token.RefreshToken),
|
|
|
|
|
}
|
|
|
|
|
return repo.deleteKeysAndRelations(ctx, keys, token.UID, token.DeviceID, token.ID)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DeleteAccessTokenByID 根據 token ID 刪除 access token
|
|
|
|
|
func (repo *TokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
|
|
|
|
|
for _, tokenID := range ids {
|
|
|
|
|
token, err := repo.GetAccessTokenByID(ctx, tokenID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
keys := []string{
|
|
|
|
|
domain.GetAccessTokenRedisKey(token.ID),
|
|
|
|
|
domain.GetRefreshTokenRedisKey(token.RefreshToken),
|
|
|
|
|
}
|
|
|
|
|
_ = repo.deleteKeysAndRelations(ctx, keys, token.UID, token.DeviceID, token.ID)
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DeleteAccessTokensByUID 根據 UID 刪除所有 access tokens
|
|
|
|
|
func (repo *TokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
|
|
|
|
|
tokens, err := repo.GetAccessTokensByUID(ctx, uid)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
for _, token := range tokens {
|
|
|
|
|
if err := repo.Delete(ctx, token); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DeleteAccessTokensByDeviceID 根據 DeviceID 刪除所有 access tokens
|
|
|
|
|
func (repo *TokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
|
|
|
|
|
tokens, err := repo.GetAccessTokensByDeviceID(ctx, deviceID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 預分配 keys:每個 token 包含兩個 key
|
|
|
|
|
keys := make([]string, 0, len(tokens)*2)
|
|
|
|
|
for _, token := range tokens {
|
|
|
|
|
keys = append(keys, domain.GetAccessTokenRedisKey(token.ID))
|
|
|
|
|
keys = append(keys, domain.GetRefreshTokenRedisKey(token.RefreshToken))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 移除 UID 關聯中的 tokenID
|
|
|
|
|
if err := repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
|
|
|
|
|
for _, token := range tokens {
|
|
|
|
|
_ = tx.SRem(ctx, domain.GetUIDTokenRedisKey(token.UID), token.ID)
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := repo.batchDeleteKeys(ctx, keys...); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = repo.Redis.Del(domain.GetDeviceTokenRedisKey(deviceID))
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DeleteOneTimeToken 刪除一次性 token(支持多個 key 一併刪除)
|
|
|
|
|
func (repo *TokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
|
|
|
|
|
totalKeys := len(ids) + len(tokens)
|
|
|
|
|
keys := make([]string, 0, totalKeys)
|
|
|
|
|
|
|
|
|
|
for _, id := range ids {
|
|
|
|
|
keys = append(keys, domain.GetRefreshTokenRedisKey(id))
|
|
|
|
|
}
|
|
|
|
|
for _, token := range tokens {
|
|
|
|
|
keys = append(keys, domain.GetRefreshTokenRedisKey(token.RefreshToken))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return repo.batchDeleteKeys(ctx, keys...)
|
|
|
|
|
}
|