431 lines
11 KiB
Go
Executable File
431 lines
11 KiB
Go
Executable File
package repository
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"time"
|
||
|
||
"backend/pkg/permission/domain/entity"
|
||
"backend/pkg/permission/domain/repository"
|
||
"backend/pkg/permission/domain/token"
|
||
|
||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||
)
|
||
|
||
// TokenRepositoryParam token 需要的參數
|
||
type TokenRepositoryParam struct {
|
||
Redis *redis.Redis
|
||
}
|
||
|
||
// TokenRepository 通知
|
||
type TokenRepository struct {
|
||
TokenRepositoryParam
|
||
}
|
||
|
||
func MustTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
|
||
return &TokenRepository{
|
||
param,
|
||
}
|
||
}
|
||
|
||
// Create stores a new token in Redis using a single pipeline:
//   - the JSON-encoded token under its access-token key,
//   - the refresh-token -> token-ID mapping (only when RefreshToken is set),
//   - the token ID in the owning UID's and device's index sets.
//
// Every write uses the same TTL, derived from RedisRefreshExpiredSec (seconds).
func (repo *TokenRepository) Create(ctx context.Context, tokenObj entity.Token) error {
	payload, err := json.Marshal(tokenObj)
	if err != nil {
		return err
	}

	ttl := time.Second * time.Duration(tokenObj.RedisRefreshExpiredSec())

	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		if err := repo.setToken(ctx, tx, tokenObj, payload, ttl); err != nil {
			return err
		}
		if err := repo.setRefreshToken(ctx, tx, tokenObj, ttl); err != nil {
			return err
		}
		return repo.setRelation(ctx, tx, tokenObj.UID, tokenObj.DeviceID, tokenObj.ID, ttl)
	})
}
|
||
|
||
// CreateOneTimeToken stores ticket (JSON-encoded) under the refresh-token key
// derived from key, but only if that key does not already exist. The entry
// expires after dt, truncated to whole seconds.
//
// NOTE(review): the SETNX "was set" flag is discarded. If the key already
// exists, nothing is written and nil is still returned.
// NOTE(review): GetAccessTokenByOneTimeToken reads this same key and treats
// its value as an access-token ID, but this method writes a JSON-encoded
// entity.Ticket. Confirm the two formats agree.
func (repo *TokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
	encoded, err := json.Marshal(ticket)
	if err != nil {
		return err
	}

	redisKey := token.RefreshTokenRedisKey(key)
	if _, err = repo.Redis.SetnxExCtx(ctx, redisKey, string(encoded), int(dt.Seconds())); err != nil {
		return err
	}

	return nil
}
|
||
|
||
// GetAccessTokenByOneTimeToken resolves a one-time token to its access token.
// The value stored under the one-time token's refresh-token key is used as the
// access-token ID and looked up via GetAccessTokenByID. An empty stored value
// is reported as "token not found".
func (repo *TokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
	// Use the context-aware call so the caller's cancellation and deadlines
	// apply to this lookup.
	id, err := repo.Redis.GetCtx(ctx, token.RefreshTokenRedisKey(oneTimeToken))
	if err != nil {
		return entity.Token{}, err
	}

	if id == "" {
		return entity.Token{}, fmt.Errorf("token not found")
	}

	return repo.GetAccessTokenByID(ctx, id)
}
|
||
|
||
// GetAccessTokenByID loads and decodes the access token stored under the
// access-token key for id.
func (repo *TokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
	accessKey := token.GetAccessTokenRedisKey(id)
	return repo.get(ctx, accessKey)
}
|
||
|
||
// GetAccessTokensByUID returns the live (present and not yet expired) access
// tokens indexed under uid.
//
// It reads the set at token.UIDTokenRedisKey(uid). That is the key setRelation
// writes to and the key deleteKeysAndRelations and GetAccessTokenCountByUID
// use, so listing, counting and deletion all agree on a single key.
func (repo *TokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
	return repo.getTokensBySet(ctx, token.UIDTokenRedisKey(uid))
}
|
||
|
||
// GetAccessTokenCountByUID reports how many token IDs are in uid's index set.
// The value is the raw set cardinality, so it may include IDs whose
// access-token entries no longer exist.
func (repo *TokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
	setKey := token.UIDTokenRedisKey(uid)
	return repo.getCountBySet(ctx, setKey)
}
|
||
|
||
// GetAccessTokensByDeviceID returns the live (present and not yet expired)
// access tokens indexed under deviceID.
func (repo *TokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
	setKey := token.DeviceTokenRedisKey(deviceID)
	return repo.getTokensBySet(ctx, setKey)
}
|
||
|
||
// GetAccessTokenCountByDeviceID reports how many token IDs are in deviceID's
// index set. The value is the raw set cardinality, so it may include IDs whose
// access-token entries no longer exist.
func (repo *TokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
	setKey := token.DeviceTokenRedisKey(deviceID)
	return repo.getCountBySet(ctx, setKey)
}
|
||
|
||
// Delete removes tokenObj's access-token entry and refresh-token mapping, and
// drops its ID from the owning UID's and device's index sets.
func (repo *TokenRepository) Delete(ctx context.Context, tokenObj entity.Token) error {
	accessKey := token.GetAccessTokenRedisKey(tokenObj.ID)
	refreshKey := token.RefreshTokenRedisKey(tokenObj.RefreshToken)

	return repo.deleteKeysAndRelations(
		ctx,
		[]string{accessKey, refreshKey},
		tokenObj.UID,
		tokenObj.DeviceID,
		tokenObj.ID,
	)
}
|
||
|
||
// DeleteOneTimeToken deletes, in one batch, the refresh-token keys derived
// from each raw id in ids and from the RefreshToken of each entry in tokens.
func (repo *TokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
	refreshKeys := make([]string, len(ids), len(ids)+len(tokens))
	for i, id := range ids {
		refreshKeys[i] = token.RefreshTokenRedisKey(id)
	}
	for _, tk := range tokens {
		refreshKeys = append(refreshKeys, token.RefreshTokenRedisKey(tk.RefreshToken))
	}

	return repo.deleteKeys(ctx, refreshKeys...)
}
|
||
|
||
// DeleteAccessTokenByID is a best-effort bulk revoke. For each ID it loads the
// token, which supplies the refresh token, UID and device ID. It then deletes
// the token's keys and its index-set memberships.
//
// IDs whose token cannot be loaded are skipped and deletion errors are
// ignored, so this method always returns nil.
func (repo *TokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
	for _, tokenID := range ids {
		tokenObj, err := repo.GetAccessTokenByID(ctx, tokenID)
		if err == nil {
			// Best effort: a failure for one ID must not stop the others.
			_ = repo.deleteKeysAndRelations(
				ctx,
				[]string{
					token.GetAccessTokenRedisKey(tokenObj.ID),
					token.RefreshTokenRedisKey(tokenObj.RefreshToken),
				},
				tokenObj.UID,
				tokenObj.DeviceID,
				tokenObj.ID,
			)
		}
	}

	return nil
}
|
||
|
||
// DeleteAccessTokensByUID revokes every live token listed under uid, deleting
// them one at a time via Delete and stopping at the first error.
func (repo *TokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
	tokens, err := repo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return err
	}

	for i := range tokens {
		if err = repo.Delete(ctx, tokens[i]); err != nil {
			return err
		}
	}

	return nil
}
|
||
|
||
// DeleteAccessTokensByDeviceID revokes every live token listed under deviceID.
//
// The tokens are first loaded via GetAccessTokensByDeviceID. Then a single
// pipeline bound to ctx:
//   - removes each token's ID from its UID index set,
//   - deletes each token's access-token and refresh-token keys,
//   - deletes the device's own index set.
//
// Using one ctx-bound pipeline means caller cancellation applies to every
// step and the work takes a single round-trip.
func (repo *TokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
	tokens, err := repo.GetAccessTokensByDeviceID(ctx, deviceID)
	if err != nil {
		return err
	}

	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		for _, tokenObj := range tokens {
			tx.SRem(ctx, token.UIDTokenRedisKey(tokenObj.UID), tokenObj.ID)
		}

		for _, tokenObj := range tokens {
			tx.Del(ctx, token.GetAccessTokenRedisKey(tokenObj.ID))
			tx.Del(ctx, token.RefreshTokenRedisKey(tokenObj.RefreshToken))
		}

		// Dropping the whole set also clears any member IDs that were not
		// returned as live tokens above.
		tx.Del(ctx, token.DeviceTokenRedisKey(deviceID))

		// Errors of queued commands surface from the pipeline execution.
		return nil
	})
}
|
||
|
||
// ========================================================================

// deleteKeysAndRelations deletes every key in keys and removes tokenID from
// the UID and device index sets, all in one pipeline.
//
// The pipeline runs with ctx (PipelinedCtx), so the caller's cancellation and
// deadlines apply to its execution.
func (repo *TokenRepository) deleteKeysAndRelations(ctx context.Context, keys []string, uid, deviceID, tokenID string) error {
	return repo.Redis.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
		// Detach the token from its UID and device index sets.
		tx.SRem(ctx, token.UIDTokenRedisKey(uid), tokenID)
		tx.SRem(ctx, token.DeviceTokenRedisKey(deviceID), tokenID)

		for _, key := range keys {
			tx.Del(ctx, key)
		}

		// Errors of queued commands surface from the pipeline execution.
		return nil
	})
}
|
||
|
||
// runPipeline executes the commands that fn queues as one Redis pipeline bound
// to ctx, and returns the pipeline's error (nil on success).
func (repo *TokenRepository) runPipeline(ctx context.Context, fn func(tx redis.Pipeliner) error) error {
	return repo.Redis.PipelinedCtx(ctx, fn)
}
|
||
|
||
// deleteKeys deletes keys in a single pipeline bound to ctx, so the caller's
// cancellation and deadlines apply.
//
// Each key gets its own DEL rather than one multi-key DEL. This avoids
// CROSSSLOT errors if the keys hash to different slots in Redis Cluster.
// An empty key list is a no-op.
func (repo *TokenRepository) deleteKeys(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}

	return repo.Redis.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
		for _, key := range keys {
			if err := tx.Del(ctx, key).Err(); err != nil {
				return err
			}
		}
		return nil
	})
}
|
||
|
||
// setToken queues, on tx, a SET of the serialized token body under the token's
// access-token key, expiring after ttl.
func (repo *TokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, body []byte, ttl time.Duration) error {
	accessKey := token.GetAccessTokenRedisKey(tokenObj.ID)
	cmd := tx.Set(ctx, accessKey, body, ttl)
	return cmd.Err()
}
|
||
|
||
// setRefreshToken queues, on tx, a SET mapping the token's refresh-token key
// to its ID, expiring after ttl. Nothing is queued when the token has no
// refresh token.
func (repo *TokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, ttl time.Duration) error {
	if tokenObj.RefreshToken == "" {
		return nil
	}

	refreshKey := token.RefreshTokenRedisKey(tokenObj.RefreshToken)
	return tx.Set(ctx, refreshKey, tokenObj.ID, ttl).Err()
}
|
||
|
||
// setRelation queues, on tx, the addition of tokenID to the UID index set and
// then to the device index set. After each SADD it queues an EXPIRE that
// (re)sets that set's TTL to ttl.
func (repo *TokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
	indexKeys := [...]string{
		token.UIDTokenRedisKey(uid),
		token.DeviceTokenRedisKey(deviceID),
	}

	for _, setKey := range indexKeys {
		if err := tx.SAdd(ctx, setKey, tokenID).Err(); err != nil {
			return err
		}
		// Each new token refreshes the TTL of the whole index set.
		if err := tx.Expire(ctx, setKey, ttl).Err(); err != nil {
			return err
		}
	}

	return nil
}
|
||
|
||
// get loads the value at key and decodes it as an entity.Token.
//
// An empty value yields a "token not found" error. A payload that fails to
// decode yields an error that wraps the underlying json.Unmarshal error, so
// callers can inspect it with errors.Is / errors.As.
func (repo *TokenRepository) get(ctx context.Context, key string) (entity.Token, error) {
	body, err := repo.Redis.GetCtx(ctx, key)
	if err != nil {
		return entity.Token{}, err
	}

	if body == "" {
		return entity.Token{}, fmt.Errorf("token not found")
	}

	// Named tokenObj so it does not shadow the imported token package.
	var tokenObj entity.Token
	if err := json.Unmarshal([]byte(body), &tokenObj); err != nil {
		return entity.Token{}, fmt.Errorf("json.Unmarshal token error: %w", err)
	}

	return tokenObj, nil
}
|
||
|
||
// getTokensBySet returns the tokens whose IDs are members of the set at
// setKey. It skips, and best-effort cleans up, members that are not live.
//
// For each member ID the access-token body is loaded:
//   - If it cannot be loaded (missing, empty or undecodable), the ID is removed
//     from setKey directly. DeleteAccessTokenByID cannot clean such IDs up,
//     because it needs the token body to locate the related keys.
//   - If its ExpiresIn Unix timestamp is in the past, the token is revoked via
//     DeleteAccessTokenByID.
//
// Cleanup errors are ignored. A redis.Nil error from SMEMBERS yields nil, nil.
func (repo *TokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
	ids, err := repo.Redis.SmembersCtx(ctx, setKey)
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, nil
		}
		return nil, err
	}

	tokens := make([]entity.Token, 0, len(ids))
	var missing, expired []string
	now := time.Now().Unix()

	for _, id := range ids {
		tokenObj, err := repo.get(ctx, token.GetAccessTokenRedisKey(id))
		if err != nil {
			missing = append(missing, id)
			continue
		}

		if int64(tokenObj.ExpiresIn) < now {
			expired = append(expired, id)
			continue
		}

		tokens = append(tokens, tokenObj)
	}

	if len(missing) > 0 {
		// Best effort: a leftover stale ID only costs a wasted lookup on the
		// next call, so an SREM failure is not worth failing the read.
		_ = repo.Redis.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
			for _, id := range missing {
				tx.SRem(ctx, setKey, id)
			}
			return nil
		})
	}

	if len(expired) > 0 {
		// Best effort: DeleteAccessTokenByID never reports failures.
		_ = repo.DeleteAccessTokenByID(ctx, expired)
	}

	return tokens, nil
}
|
||
|
||
// getCountBySet 獲取集合中的元素數量
|
||
func (repo *TokenRepository) getCountBySet(ctx context.Context, setKey string) (int, error) {
|
||
count, err := repo.Redis.ScardCtx(ctx, setKey)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return int(count), nil
|
||
}
|
||
|
||
// AddToBlacklist stores entry (JSON-encoded) under the blacklist key for its
// JTI, with an expiry.
//
// ttl sets how long the entry is kept. When ttl <= 0 it defaults to the time
// remaining until entry.ExpiresAt (a Unix timestamp in seconds). If that time
// has already passed, the token is expired and nothing is written.
//
// The expiry is passed to Redis in whole seconds and is rounded up, not
// truncated. Truncation would turn a sub-second TTL into 0, which is not a
// valid positive expiry. Rounding down could also let the blacklist entry
// expire before the token does.
//
// A nil entry returns an error instead of panicking.
func (repo *TokenRepository) AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error {
	if entry == nil {
		return errors.New("blacklist entry is nil")
	}

	key := token.GetBlacklistRedisKey(entry.JTI)

	data, err := json.Marshal(entry)
	if err != nil {
		return fmt.Errorf("failed to marshal blacklist entry: %w", err)
	}

	if ttl <= 0 {
		ttl = time.Until(time.Unix(entry.ExpiresAt, 0))
		if ttl <= 0 {
			// The token has already expired; no need to blacklist it.
			return nil
		}
	}

	// Ceiling to whole seconds without risking overflow on very large ttl.
	seconds := int(ttl / time.Second)
	if ttl%time.Second != 0 {
		seconds++
	}

	if err := repo.Redis.SetexCtx(ctx, key, string(data), seconds); err != nil {
		return fmt.Errorf("failed to add token to blacklist: %w", err)
	}

	return nil
}
|
||
|
||
// IsBlacklisted reports whether a blacklist entry exists for jti.
func (repo *TokenRepository) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
	found, err := repo.Redis.ExistsCtx(ctx, token.GetBlacklistRedisKey(jti))
	if err != nil {
		return false, fmt.Errorf("failed to check blacklist: %w", err)
	}
	return found, nil
}
|
||
|
||
// RemoveFromBlacklist deletes the blacklist entry for jti. Deleting an entry
// that does not exist is not an error.
func (repo *TokenRepository) RemoveFromBlacklist(ctx context.Context, jti string) error {
	if _, err := repo.Redis.DelCtx(ctx, token.GetBlacklistRedisKey(jti)); err != nil {
		return fmt.Errorf("failed to remove token from blacklist: %w", err)
	}
	return nil
}
|
||
|
||
// GetBlacklistedTokensByUID returns every blacklist entry whose UID equals uid.
//
// There is no per-UID index, so this SCANs all keys matching
// token.BlacklistKeyPrefix+"*" (in batches of 100) and decodes each one. Its
// cost therefore grows with the total number of blacklist keys, not with the
// number belonging to this user. Entries that are missing or cannot be
// decoded are skipped.
//
// Redis SCAN may return the same key more than once during an iteration, so
// keys are de-duplicated to avoid returning the same entry twice.
func (repo *TokenRepository) GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error) {
	pattern := token.BlacklistKeyPrefix + "*"

	var (
		entries []*entity.BlacklistEntry
		cursor  uint64
		seen    = make(map[string]struct{})
	)

	for {
		keys, nextCursor, err := repo.Redis.ScanCtx(ctx, cursor, pattern, 100)
		if err != nil {
			return nil, fmt.Errorf("failed to scan blacklist keys: %w", err)
		}

		for _, key := range keys {
			if _, dup := seen[key]; dup {
				continue
			}
			seen[key] = struct{}{}

			data, err := repo.Redis.GetCtx(ctx, key)
			if err != nil {
				if errors.Is(err, redis.Nil) {
					continue // key expired or vanished between SCAN and GET
				}
				return nil, fmt.Errorf("failed to get blacklist entry: %w", err)
			}

			var entry entity.BlacklistEntry
			if err := json.Unmarshal([]byte(data), &entry); err != nil {
				continue // skip empty or malformed entries
			}

			if entry.UID == uid {
				entries = append(entries, &entry)
			}
		}

		cursor = nextCursor
		if cursor == 0 {
			break
		}
	}

	return entries, nil
}
|