app-cloudep-permission-server/pkg/repository/token.go

333 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/repository"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// TokenRepositoryParam holds the dependencies required to build a
// TokenRepository.
type TokenRepositoryParam struct {
	// Redis is the go-zero Redis client used for every storage operation.
	Redis *redis.Redis
}
// TokenRepository is the Redis-backed repository used to store, look up and
// delete tokens.
type TokenRepository struct {
	// TokenRepositoryParam is embedded so its fields (e.g. Redis) are
	// accessible directly on the repository.
	TokenRepositoryParam
}
// NewTokenRepository builds a TokenRepository from param and returns it as a
// repository.TokenRepo.
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepo {
	repo := new(TokenRepository)
	repo.TokenRepositoryParam = param

	return repo
}
// ====================== Private helper functions =======================
// runPipeline executes pipelineFunc inside a Redis pipeline bound to ctx and
// returns whatever error the pipeline reports.
func (repo *TokenRepository) runPipeline(ctx context.Context, pipelineFunc func(tx redis.Pipeliner) error) error {
	if err := repo.Redis.PipelinedCtx(ctx, pipelineFunc); err != nil {
		return err
	}

	return nil
}
// setToken writes body under the access-token key derived from id, with the
// given ttl, using the supplied pipeline.
func (repo *TokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, id string, body []byte, ttl time.Duration) error {
	accessKey := domain.GetAccessTokenRedisKey(id)
	cmd := tx.Set(ctx, accessKey, body, ttl)

	return cmd.Err()
}
// setRefreshToken stores token.ID under the refresh-token key derived from
// token.RefreshToken, with the given ttl, using the supplied pipeline.
// It is a no-op (returns nil) when the token carries no refresh token.
func (repo *TokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, ttl time.Duration) error {
	if token.RefreshToken != "" {
		refreshKey := domain.GetRefreshTokenRedisKey(token.RefreshToken)
		return tx.Set(ctx, refreshKey, token.ID, ttl).Err()
	}

	return nil
}
// setTokenRelation adds tokenID to the UID set and to the device set, and
// sets ttl as the expiration of each set key, using the supplied pipeline.
//
// The UID set is handled first (SADD then EXPIRE), followed by the device
// set. The first failing step aborts the function with a wrapped error.
func (repo *TokenRepository) setTokenRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
	relationKeys := [...]string{
		domain.GetUIDTokenRedisKey(uid),
		domain.GetDeviceTokenRedisKey(deviceID),
	}

	for _, setKey := range relationKeys {
		// Register the token ID in the relation set.
		if err := tx.SAdd(ctx, setKey, tokenID).Err(); err != nil {
			return fmt.Errorf("failed to create token relaction: %w", err)
		}

		// Apply the expiration to the set key itself.
		if err := tx.Expire(ctx, setKey, ttl).Err(); err != nil {
			return fmt.Errorf("failed to set expire: %w", err)
		}
	}

	return nil
}
// retrieveToken reads the value stored at key and decodes it as a JSON
// entity.Token.
//
// It returns the Redis error unchanged when the read fails, an error when the
// stored value is the empty string, and a wrapped error when the JSON cannot
// be decoded. On any error the returned token is the zero value.
func (repo *TokenRepository) retrieveToken(ctx context.Context, key string) (entity.Token, error) {
	raw, err := repo.Redis.GetCtx(ctx, key)
	switch {
	case err != nil:
		return entity.Token{}, err
	case raw == "":
		return entity.Token{}, fmt.Errorf("failed to found token")
	}

	var decoded entity.Token
	if decodeErr := json.Unmarshal([]byte(raw), &decoded); decodeErr != nil {
		return entity.Token{}, fmt.Errorf("failed to unmarshal token JSON: %w", decodeErr)
	}

	return decoded, nil
}
// getTokensBySet loads every token whose ID is a member of the Redis set
// stored at setKey.
//
// IDs whose token cannot be retrieved, or whose ExpiresIn is earlier than the
// current time in Unix nanoseconds, are left out of the result and handed to
// DeleteAccessTokenByID for cleanup. A redis.Nil error from the set lookup is
// reported as (nil, nil); any other lookup error is returned unchanged.
func (repo *TokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
	// Use the context-aware variant so the caller's ctx is passed to the set
	// lookup, like every other Redis call made through ctx in this function.
	ids, err := repo.Redis.SmembersCtx(ctx, setKey)
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, nil
		}

		return nil, err
	}

	tokens := make([]entity.Token, 0, len(ids))
	var tokensToDelete []string
	now := time.Now().UnixNano()

	for _, id := range ids {
		token, err := repo.retrieveToken(ctx, domain.GetAccessTokenRedisKey(id))
		if err != nil || token.ExpiresIn < now {
			// Unreadable or expired: schedule for cleanup and skip it.
			tokensToDelete = append(tokensToDelete, id)
			continue
		}

		tokens = append(tokens, token)
	}

	// Clean up expired or unreadable tokens. This is best-effort: a cleanup
	// failure must not prevent returning the valid tokens, so the result is
	// intentionally ignored.
	if len(tokensToDelete) > 0 {
		_ = repo.DeleteAccessTokenByID(ctx, tokensToDelete)
	}

	return tokens, nil
}
// getCountBySet returns the number of members in the Redis set stored at
// setKey, or 0 together with the Redis error when the lookup fails.
func (repo *TokenRepository) getCountBySet(ctx context.Context, setKey string) (int, error) {
	size, err := repo.Redis.ScardCtx(ctx, setKey)
	if err == nil {
		return int(size), nil
	}

	return 0, err
}
// deleteKeysAndRelations removes tokenID from the UID set and the device set,
// and deletes every key in keys, all within a single Redis pipeline.
//
// The pipeline is bound to ctx. The results of the individual commands are
// intentionally discarded; only the error returned by the pipeline itself is
// reported to the caller.
func (repo *TokenRepository) deleteKeysAndRelations(ctx context.Context, keys []string, uid, deviceID, tokenID string) error {
	return repo.Redis.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
		// Drop the token ID from the UID and device relation sets.
		_ = tx.SRem(ctx, domain.GetUIDTokenRedisKey(uid), tokenID)
		_ = tx.SRem(ctx, domain.GetDeviceTokenRedisKey(deviceID), tokenID)

		// Delete every requested key.
		for _, key := range keys {
			_ = tx.Del(ctx, key)
		}

		return nil
	})
}
// batchDeleteKeys deletes all given keys in a single Redis pipeline bound to
// ctx. It returns nil without contacting Redis when no keys are given.
func (repo *TokenRepository) batchDeleteKeys(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		// Nothing to delete; avoid opening an empty pipeline.
		return nil
	}

	return repo.Redis.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
		for _, key := range keys {
			if err := tx.Del(ctx, key).Err(); err != nil {
				return err
			}
		}

		return nil
	})
}
// ====================== Public methods =======================
// Create persists a new token.
//
// The token is JSON-encoded and, in one pipeline, written under its access
// key, linked from its refresh key (when it has a refresh token), and
// registered in the UID and device sets. Every entry uses a TTL of
// token.RedisRefreshExpiredSec() seconds.
func (repo *TokenRepository) Create(ctx context.Context, token entity.Token) error {
	payload, err := json.Marshal(token)
	if err != nil {
		return err
	}

	// All keys share the TTL derived from the token's refresh expiry seconds.
	ttl := time.Duration(token.RedisRefreshExpiredSec()) * time.Second

	store := func(tx redis.Pipeliner) error {
		if err := repo.setToken(ctx, tx, token.ID, payload, ttl); err != nil {
			return err
		}

		if err := repo.setRefreshToken(ctx, tx, token, ttl); err != nil {
			return err
		}

		return repo.setTokenRelation(ctx, tx, token.UID, token.DeviceID, token.ID, ttl)
	}

	return repo.runPipeline(ctx, store)
}
// CreateOneTimeToken JSON-encodes ticket and stores it with SetnxExCtx under
// the refresh-token key derived from key, expiring after dt truncated to
// whole seconds.
//
// Only the error from SetnxExCtx is returned; its boolean result is
// discarded.
func (repo *TokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
	payload, err := json.Marshal(ticket)
	if err != nil {
		return err
	}

	redisKey := domain.GetRefreshTokenRedisKey(key)
	ttlSeconds := int(dt.Seconds())

	if _, err = repo.Redis.SetnxExCtx(ctx, redisKey, string(payload), ttlSeconds); err != nil {
		return err
	}

	return nil
}
// GetAccessTokenByOneTimeToken reads and decodes the token stored under the
// refresh-token key derived from oneTimeToken.
func (repo *TokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
	redisKey := domain.GetRefreshTokenRedisKey(oneTimeToken)

	return repo.retrieveToken(ctx, redisKey)
}
// GetAccessTokenByID reads and decodes the token stored under the
// access-token key derived from id.
func (repo *TokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
	redisKey := domain.GetAccessTokenRedisKey(id)

	return repo.retrieveToken(ctx, redisKey)
}
// GetAccessTokensByUID returns the tokens referenced by the UID set derived
// from uid.
func (repo *TokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
	setKey := domain.GetUIDTokenRedisKey(uid)

	return repo.getTokensBySet(ctx, setKey)
}
// GetAccessTokenCountByUID returns the number of members in the UID set
// derived from uid.
func (repo *TokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
	setKey := domain.GetUIDTokenRedisKey(uid)

	return repo.getCountBySet(ctx, setKey)
}
// GetAccessTokensByDeviceID returns the tokens referenced by the device set
// derived from deviceID.
func (repo *TokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
	setKey := domain.GetDeviceTokenRedisKey(deviceID)

	return repo.getTokensBySet(ctx, setKey)
}
// GetAccessTokenCountByDeviceID returns the number of members in the device
// set derived from deviceID.
func (repo *TokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
	setKey := domain.GetDeviceTokenRedisKey(deviceID)

	return repo.getCountBySet(ctx, setKey)
}
// Delete removes the given token: its access-token key, its refresh-token
// key, and its ID from the UID and device sets.
func (repo *TokenRepository) Delete(ctx context.Context, token entity.Token) error {
	accessKey := domain.GetAccessTokenRedisKey(token.ID)
	refreshKey := domain.GetRefreshTokenRedisKey(token.RefreshToken)

	return repo.deleteKeysAndRelations(
		ctx,
		[]string{accessKey, refreshKey},
		token.UID,
		token.DeviceID,
		token.ID,
	)
}
// DeleteAccessTokenByID deletes the tokens identified by ids.
//
// Each ID is looked up first; IDs that cannot be retrieved are skipped, and
// deletion errors are ignored, so this method always returns nil.
func (repo *TokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
	for i := range ids {
		token, err := repo.GetAccessTokenByID(ctx, ids[i])
		if err == nil {
			// Delete removes the access key, the refresh key and both set
			// relations of the token; its error is deliberately dropped.
			_ = repo.Delete(ctx, token)
		}
	}

	return nil
}
// DeleteAccessTokensByUID deletes every token returned by
// GetAccessTokensByUID for uid, stopping at and returning the first error.
func (repo *TokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
	owned, err := repo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return err
	}

	for i := range owned {
		if delErr := repo.Delete(ctx, owned[i]); delErr != nil {
			return delErr
		}
	}

	return nil
}
// DeleteAccessTokensByDeviceID deletes every token returned by
// GetAccessTokensByDeviceID for deviceID.
//
// It proceeds in three steps, returning the first error encountered:
//  1. remove each token ID from the UID set of its owner,
//  2. delete each token's access-token and refresh-token keys,
//  3. delete the device set key itself.
func (repo *TokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
	tokens, err := repo.GetAccessTokensByDeviceID(ctx, deviceID)
	if err != nil {
		return err
	}

	// Pre-allocate: every token contributes two keys.
	keys := make([]string, 0, len(tokens)*2)
	for _, token := range tokens {
		keys = append(keys,
			domain.GetAccessTokenRedisKey(token.ID),
			domain.GetRefreshTokenRedisKey(token.RefreshToken),
		)
	}

	// Remove the token IDs from the UID relation sets. Individual command
	// results are intentionally discarded; the pipeline error is reported.
	if err := repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		for _, token := range tokens {
			_ = tx.SRem(ctx, domain.GetUIDTokenRedisKey(token.UID), token.ID)
		}

		return nil
	}); err != nil {
		return err
	}

	if err := repo.batchDeleteKeys(ctx, keys...); err != nil {
		return err
	}

	// Use the context-aware delete so the caller's ctx is passed to this
	// final command as well.
	_, err = repo.Redis.DelCtx(ctx, domain.GetDeviceTokenRedisKey(deviceID))

	return err
}
// DeleteOneTimeToken deletes several one-time tokens in one batch.
//
// The keys to delete are the refresh-token keys derived from each entry of
// ids, followed by those derived from the RefreshToken of each entry of
// tokens.
func (repo *TokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
	redisKeys := make([]string, 0, len(ids)+len(tokens))

	for i := range ids {
		redisKeys = append(redisKeys, domain.GetRefreshTokenRedisKey(ids[i]))
	}

	for i := range tokens {
		redisKeys = append(redisKeys, domain.GetRefreshTokenRedisKey(tokens[i].RefreshToken))
	}

	return repo.batchDeleteKeys(ctx, redisKeys...)
}