backend/pkg/permission/repository/token_model.go

431 lines
11 KiB
Go
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/token"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// TokenRepositoryParam holds the dependencies needed to build a TokenRepository.
type TokenRepositoryParam struct {
	// Redis is the client used for every token, relation and blacklist key.
	Redis *redis.Redis
}

// TokenRepository stores access tokens, refresh-token mappings, per-UID and
// per-device token sets, and blacklist entries in Redis.
type TokenRepository struct {
	TokenRepositoryParam
}

// MustTokenRepository wraps param in a TokenRepository and returns it as a
// repository.TokenRepository. It performs no validation of param.
func MustTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
	repo := &TokenRepository{TokenRepositoryParam: param}
	return repo
}
// Create stores a new token in Redis through a single pipeline.
//
// It writes three things, all with the token's refresh TTL
// (RedisRefreshExpiredSec seconds):
//   - the JSON-encoded token under its access-token key;
//   - the refresh-token -> token-ID mapping, when a refresh token is present;
//   - the token ID in the per-UID and per-device sets.
func (repo *TokenRepository) Create(ctx context.Context, token entity.Token) error {
	payload, err := json.Marshal(token)
	if err != nil {
		return err
	}
	ttl := time.Second * time.Duration(token.RedisRefreshExpiredSec())

	write := func(tx redis.Pipeliner) error {
		if err := repo.setToken(ctx, tx, token, payload, ttl); err != nil {
			return err
		}
		if err := repo.setRefreshToken(ctx, tx, token, ttl); err != nil {
			return err
		}
		return repo.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, ttl)
	}
	return repo.runPipeline(ctx, write)
}
// CreateOneTimeToken stores ticket, JSON-encoded, under
// token.RefreshTokenRedisKey(key) with a lifetime of dt.
//
// The write uses SET NX, so an existing value is never overwritten. A
// collision is reported as an error rather than silently ignored, because the
// caller would otherwise hand out a key that resolves to someone else's ticket.
//
// dt must be positive. It is rounded up to whole seconds, because a
// sub-second value truncated to 0 would make Redis store the key with no
// expiry at all.
//
// NOTE(review): GetAccessTokenByOneTimeToken reads the value under this same
// key and treats it as a bare access-token ID, whereas this method stores a
// JSON document — confirm which format is intended.
func (repo *TokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
	if dt <= 0 {
		return fmt.Errorf("one-time token ttl must be positive, got %v", dt)
	}
	body, err := json.Marshal(ticket)
	if err != nil {
		return err
	}

	seconds := int(dt / time.Second)
	if dt%time.Second != 0 {
		seconds++ // round up so the key always carries a TTL of at least 1s
	}

	ok, err := repo.Redis.SetnxExCtx(ctx, token.RefreshTokenRedisKey(key), string(body), seconds)
	if err != nil {
		return err
	}
	if !ok {
		// Deliberately omit the key itself: it is a bearer secret.
		return fmt.Errorf("one-time token already exists")
	}
	return nil
}
// GetAccessTokenByOneTimeToken resolves oneTimeToken to an access token.
//
// It reads the value stored under token.RefreshTokenRedisKey(oneTimeToken),
// treats that value as an access-token ID, and loads the token with
// GetAccessTokenByID. An empty value is reported as "token not found". The
// lookup uses ctx, so cancellation and deadlines apply.
//
// NOTE(review): CreateOneTimeToken writes a JSON-encoded entity.Ticket under
// this key, not a bare token ID — confirm the two are meant to agree.
func (repo *TokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
	id, err := repo.Redis.GetCtx(ctx, token.RefreshTokenRedisKey(oneTimeToken))
	if err != nil {
		return entity.Token{}, err
	}
	if id == "" {
		return entity.Token{}, fmt.Errorf("token not found")
	}
	return repo.GetAccessTokenByID(ctx, id)
}
// GetAccessTokenByID loads the token stored under the access-token key for id.
func (repo *TokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
	key := token.GetAccessTokenRedisKey(id)
	return repo.get(ctx, key)
}
// GetAccessTokensByUID returns the tokens whose IDs are members of uid's token
// set. Loading is done by getTokensBySet, which skips entries whose token
// cannot be loaded or has expired, and cleans them up on a best-effort basis.
//
// The set is addressed with token.UIDTokenRedisKey. That is the key setRelation
// writes, GetAccessTokenCountByUID counts, and the delete paths SREM from, so
// listing, counting and deleting all operate on the same set.
func (repo *TokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
	return repo.getTokensBySet(ctx, token.UIDTokenRedisKey(uid))
}
// GetAccessTokenCountByUID returns how many token IDs are currently in uid's
// token set. Members are counted as stored; they are not checked for expiry.
func (repo *TokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
	setKey := token.UIDTokenRedisKey(uid)
	return repo.getCountBySet(ctx, setKey)
}
// GetAccessTokensByDeviceID returns the tokens whose IDs are members of
// deviceID's token set, as loaded by getTokensBySet.
func (repo *TokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
	setKey := token.DeviceTokenRedisKey(deviceID)
	return repo.getTokensBySet(ctx, setKey)
}
// GetAccessTokenCountByDeviceID returns how many token IDs are currently in
// deviceID's token set. Members are counted as stored; they are not checked
// for expiry.
func (repo *TokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
	setKey := token.DeviceTokenRedisKey(deviceID)
	return repo.getCountBySet(ctx, setKey)
}
// Delete removes tokenObj's access-token and refresh-token keys and drops its
// ID from the per-UID and per-device token sets.
func (repo *TokenRepository) Delete(ctx context.Context, tokenObj entity.Token) error {
	accessKey := token.GetAccessTokenRedisKey(tokenObj.ID)
	refreshKey := token.RefreshTokenRedisKey(tokenObj.RefreshToken)
	return repo.deleteKeysAndRelations(ctx, []string{accessKey, refreshKey}, tokenObj.UID, tokenObj.DeviceID, tokenObj.ID)
}
// DeleteOneTimeToken deletes the refresh-token-namespaced key for every raw ID
// in ids and for the RefreshToken of every token in tokens.
func (repo *TokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
	keys := make([]string, len(ids)+len(tokens))
	for i, id := range ids {
		keys[i] = token.RefreshTokenRedisKey(id)
	}
	offset := len(ids)
	for i, tokenObj := range tokens {
		keys[offset+i] = token.RefreshTokenRedisKey(tokenObj.RefreshToken)
	}
	return repo.deleteKeys(ctx, keys...)
}
// DeleteAccessTokenByID deletes each token listed in ids on a best-effort basis.
//
// IDs whose token cannot be loaded are skipped, and failures while deleting
// are ignored. The method always returns nil.
func (repo *TokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
	for _, tokenID := range ids {
		if tokenObj, err := repo.GetAccessTokenByID(ctx, tokenID); err == nil {
			// Best effort: one failed delete must not stop the remaining IDs.
			_ = repo.Delete(ctx, tokenObj)
		}
	}
	return nil
}
// DeleteAccessTokensByUID deletes every token returned by GetAccessTokensByUID
// for uid, one Delete per token. It stops at the first failed delete.
func (repo *TokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
	tokens, err := repo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return err
	}
	for i := range tokens {
		if err := repo.Delete(ctx, tokens[i]); err != nil {
			return err
		}
	}
	return nil
}
// DeleteAccessTokensByDeviceID deletes every token listed for deviceID.
//
// For each token returned by GetAccessTokensByDeviceID, it removes the token
// ID from the owning UID's set and deletes the access-token and refresh-token
// keys. It then deletes the device set itself.
//
// All commands are queued in one pipeline bound to ctx, instead of three
// separate calls (the last of which ignored ctx). Each DEL names a single key,
// so no command spans several keys.
func (repo *TokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
	tokens, err := repo.GetAccessTokensByDeviceID(ctx, deviceID)
	if err != nil {
		return err
	}
	return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
		for _, tokenObj := range tokens {
			tx.SRem(ctx, token.UIDTokenRedisKey(tokenObj.UID), tokenObj.ID)
			tx.Del(ctx, token.GetAccessTokenRedisKey(tokenObj.ID))
			tx.Del(ctx, token.RefreshTokenRedisKey(tokenObj.RefreshToken))
		}
		// Drop the whole device set rather than SREM-ing each member.
		tx.Del(ctx, token.DeviceTokenRedisKey(deviceID))
		return nil
	})
}
// ========================================================================
// deleteKeysAndRelations deletes keys and removes tokenID from uid's and
// deviceID's token sets. All commands are queued in one pipeline bound to ctx,
// so cancellation and deadlines apply.
func (repo *TokenRepository) deleteKeysAndRelations(ctx context.Context, keys []string, uid, deviceID, tokenID string) error {
	return repo.Redis.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
		// Queued commands only report results once the pipeline executes, and
		// PipelinedCtx returns that outcome, so nothing is inspected here.
		tx.SRem(ctx, token.UIDTokenRedisKey(uid), tokenID)
		tx.SRem(ctx, token.DeviceTokenRedisKey(deviceID), tokenID)
		for _, key := range keys {
			// One DEL per key keeps every command single-key.
			tx.Del(ctx, key)
		}
		return nil
	})
}
// runPipeline queues fn's commands on a Redis pipeline bound to ctx, executes
// it, and returns whatever error PipelinedCtx reports.
func (repo *TokenRepository) runPipeline(ctx context.Context, fn func(tx redis.Pipeliner) error) error {
	return repo.Redis.PipelinedCtx(ctx, fn)
}
// deleteKeys deletes every key in keys, one DEL per key, batched in a single
// pipeline bound to ctx. It is a no-op when keys is empty.
func (repo *TokenRepository) deleteKeys(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}
	return repo.Redis.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
		// A queued command's Err() is always nil before execution; failures
		// surface from PipelinedCtx instead.
		for _, key := range keys {
			tx.Del(ctx, key)
		}
		return nil
	})
}
// setToken queues a SET of body under tokenObj's access-token key with the
// given ttl.
func (repo *TokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, body []byte, ttl time.Duration) error {
	key := token.GetAccessTokenRedisKey(tokenObj.ID)
	cmd := tx.Set(ctx, key, body, ttl)
	return cmd.Err()
}
// setRefreshToken queues a SET mapping tokenObj's refresh token to its token
// ID with the given ttl. Tokens without a refresh token are skipped.
func (repo *TokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, ttl time.Duration) error {
	if tokenObj.RefreshToken == "" {
		return nil
	}
	key := token.RefreshTokenRedisKey(tokenObj.RefreshToken)
	return tx.Set(ctx, key, tokenObj.ID, ttl).Err()
}
// setRelation queues adding tokenID to uid's token set and then to deviceID's
// token set, re-arming each set's TTL to ttl right after its SADD.
//
// Because the TTL is reset on every write, each set expires together with its
// most recently written token.
func (repo *TokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
	setKeys := []string{
		token.UIDTokenRedisKey(uid),
		token.DeviceTokenRedisKey(deviceID),
	}
	for _, setKey := range setKeys {
		if err := tx.SAdd(ctx, setKey, tokenID).Err(); err != nil {
			return err
		}
		if err := tx.Expire(ctx, setKey, ttl).Err(); err != nil {
			return err
		}
	}
	return nil
}
// errStoredTokenNotFound is returned by get when the key holds no value.
// Its message is "token not found".
var errStoredTokenNotFound = errors.New("token not found")

// get loads and JSON-decodes the token stored under key.
//
// It returns errStoredTokenNotFound when the key holds no value. When the
// stored body is not valid token JSON, it returns a decode error that wraps
// the cause.
func (repo *TokenRepository) get(ctx context.Context, key string) (entity.Token, error) {
	body, err := repo.Redis.GetCtx(ctx, key)
	if errors.Is(err, redis.Nil) {
		return entity.Token{}, errStoredTokenNotFound
	}
	if err != nil {
		return entity.Token{}, err
	}
	if body == "" {
		return entity.Token{}, errStoredTokenNotFound
	}
	var tok entity.Token
	if err := json.Unmarshal([]byte(body), &tok); err != nil {
		return entity.Token{}, fmt.Errorf("json.Unmarshal token error: %w", err)
	}
	return tok, nil
}

// getTokensBySet loads every token whose ID is a member of the set at setKey.
//
// Tokens whose ExpiresIn lies in the past are omitted and deleted through
// DeleteAccessTokenByID. IDs whose token body no longer exists are omitted and
// removed from setKey directly, because DeleteAccessTokenByID cannot load them.
// Without that removal they would linger in the set and inflate Scard-based
// counts. All cleanup is best effort and never fails the lookup.
func (repo *TokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
	ids, err := repo.Redis.SmembersCtx(ctx, setKey)
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, nil
		}
		return nil, err
	}

	tokens := make([]entity.Token, 0, len(ids))
	var stale []string // token can still be resolved (or retried) by DeleteAccessTokenByID
	var dangling []any // token body is gone; only the set member can be removed
	now := time.Now().Unix()
	for _, id := range ids {
		tok, err := repo.get(ctx, token.GetAccessTokenRedisKey(id))
		switch {
		case errors.Is(err, errStoredTokenNotFound):
			dangling = append(dangling, id)
		case err != nil:
			// Read or decode failure: keep the best-effort retry via DeleteAccessTokenByID.
			stale = append(stale, id)
		case int64(tok.ExpiresIn) < now:
			stale = append(stale, id)
		default:
			tokens = append(tokens, tok)
		}
	}

	if len(stale) > 0 {
		_ = repo.DeleteAccessTokenByID(ctx, stale) // best effort
	}
	if len(dangling) > 0 {
		_, _ = repo.Redis.SremCtx(ctx, setKey, dangling...) // best effort
	}
	return tokens, nil
}
// getCountBySet returns the number of members (SCARD) of the set at setKey.
func (repo *TokenRepository) getCountBySet(ctx context.Context, setKey string) (int, error) {
	cardinality, err := repo.Redis.ScardCtx(ctx, setKey)
	if err == nil {
		return int(cardinality), nil
	}
	return 0, err
}
// AddToBlacklist stores entry, JSON-encoded, under the blacklist key for
// entry.JTI.
//
// When ttl <= 0, the TTL is derived from entry.ExpiresAt (Unix seconds). If
// that moment has already passed, the token is expired and nothing is written.
// The TTL is rounded up to whole seconds, so a sub-second value never reaches
// SETEX as 0, which would store the key with no expiry at all.
// A nil entry is rejected with an error instead of panicking.
func (repo *TokenRepository) AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error {
	if entry == nil {
		return errors.New("failed to add token to blacklist: nil entry")
	}
	key := token.GetBlacklistRedisKey(entry.JTI)

	data, err := json.Marshal(entry)
	if err != nil {
		return fmt.Errorf("failed to marshal blacklist entry: %w", err)
	}

	if ttl <= 0 {
		ttl = time.Until(time.Unix(entry.ExpiresAt, 0))
		if ttl <= 0 {
			// Already expired: no need to blacklist it.
			return nil
		}
	}

	seconds := int(ttl / time.Second)
	if ttl%time.Second != 0 {
		seconds++
	}

	if err := repo.Redis.SetexCtx(ctx, key, string(data), seconds); err != nil {
		return fmt.Errorf("failed to add token to blacklist: %w", err)
	}
	return nil
}
// IsBlacklisted reports whether a blacklist key currently exists for jti.
func (repo *TokenRepository) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
	found, err := repo.Redis.ExistsCtx(ctx, token.GetBlacklistRedisKey(jti))
	if err != nil {
		return false, fmt.Errorf("failed to check blacklist: %w", err)
	}
	return found, nil
}
// RemoveFromBlacklist deletes the blacklist key for jti. The number of deleted
// keys is ignored, so removing an absent entry is not an error.
func (repo *TokenRepository) RemoveFromBlacklist(ctx context.Context, jti string) error {
	if _, err := repo.Redis.DelCtx(ctx, token.GetBlacklistRedisKey(jti)); err != nil {
		return fmt.Errorf("failed to remove token from blacklist: %w", err)
	}
	return nil
}
// GetBlacklistedTokensByUID returns every blacklist entry whose UID equals uid.
//
// It SCANs all keys matching token.BlacklistKeyPrefix+"*" (count hint 100),
// GETs and JSON-decodes each one, and keeps the entries whose UID matches.
// Keys that report redis.Nil or fail to decode are skipped. The cost grows with
// the size of the whole blacklist, not with the number of matching entries.
func (repo *TokenRepository) GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error) {
	match := token.BlacklistKeyPrefix + "*"
	var entries []*entity.BlacklistEntry

	cursor := uint64(0)
	for {
		batch, next, err := repo.Redis.ScanCtx(ctx, cursor, match, 100)
		if err != nil {
			return nil, fmt.Errorf("failed to scan blacklist keys: %w", err)
		}

		for _, key := range batch {
			raw, err := repo.Redis.GetCtx(ctx, key)
			switch {
			case errors.Is(err, redis.Nil):
				continue // key vanished between SCAN and GET
			case err != nil:
				return nil, fmt.Errorf("failed to get blacklist entry: %w", err)
			}

			entry := new(entity.BlacklistEntry)
			if json.Unmarshal([]byte(raw), entry) != nil {
				continue // not a decodable entry
			}
			if entry.UID == uid {
				entries = append(entries, entry)
			}
		}

		if next == 0 {
			break
		}
		cursor = next
	}
	return entries, nil
}