backend/tmp/reborn/repository/cache_repository.go

175 lines
4.5 KiB
Go

package repository
import (
"context"
"encoding/json"
"fmt"
"permission/reborn/config"
"permission/reborn/domain/errors"
"permission/reborn/domain/repository"
"time"
"github.com/redis/go-redis/v9"
)
// cacheRepository is the Redis-backed implementation of
// repository.CacheRepository returned by NewCacheRepository.
type cacheRepository struct {
	// client is the Redis client every cache operation is issued through.
	client *redis.Client
	// config supplies the per-key-type default TTLs applied by Set when the
	// caller passes a zero TTL.
	config config.RedisConfig
}
// NewCacheRepository creates a cache repository that stores entries in Redis
// through client and takes its default TTLs from cfg.
func NewCacheRepository(client *redis.Client, cfg config.RedisConfig) repository.CacheRepository {
	repo := new(cacheRepository)
	repo.client = client
	repo.config = cfg
	return repo
}
// Get returns the raw string value stored under key.
//
// A redis.Nil reply (the key is absent) is reported as errors.ErrNotFound.
// Any other Redis failure is wrapped with errors.ErrCodeCacheError.
func (r *cacheRepository) Get(ctx context.Context, key string) (string, error) {
	result, err := r.client.Get(ctx, key).Result()
	switch {
	case err == nil:
		return result, nil
	case err == redis.Nil:
		return "", errors.ErrNotFound
	default:
		return "", errors.Wrap(errors.ErrCodeCacheError, "failed to get cache", err)
	}
}
// Set stores value under key with the given expiration.
//
// A ttl of exactly 0 is replaced by the default chosen by getDefaultTTL for
// this key. Any other ttl is passed to Redis unchanged. Redis failures are
// wrapped with errors.ErrCodeCacheError.
func (r *cacheRepository) Set(ctx context.Context, key, value string, ttl time.Duration) error {
	expiration := ttl
	if expiration == 0 {
		expiration = r.getDefaultTTL(key)
	}

	if setErr := r.client.Set(ctx, key, value, expiration).Err(); setErr != nil {
		return errors.Wrap(errors.ErrCodeCacheError, "failed to set cache", setErr)
	}
	return nil
}
// Delete removes all of the given keys in a single DEL command.
//
// Calling it with no keys is a no-op that issues no Redis command. Redis
// failures are wrapped with errors.ErrCodeCacheError.
func (r *cacheRepository) Delete(ctx context.Context, keys ...string) error {
	if len(keys) > 0 {
		if delErr := r.client.Del(ctx, keys...).Err(); delErr != nil {
			return errors.Wrap(errors.ErrCodeCacheError, "failed to delete cache", delErr)
		}
	}
	return nil
}
// Exists reports whether key is currently present in Redis.
//
// On a Redis failure it returns false together with an error wrapped with
// errors.ErrCodeCacheError.
func (r *cacheRepository) Exists(ctx context.Context, key string) (bool, error) {
	cmd := r.client.Exists(ctx, key)
	if cmdErr := cmd.Err(); cmdErr != nil {
		return false, errors.Wrap(errors.ErrCodeCacheError, "failed to check cache exists", cmdErr)
	}
	return cmd.Val() > 0, nil
}
// GetObject reads the value stored under key and JSON-decodes it into dest.
//
// Errors from Get (including errors.ErrNotFound) are returned unchanged.
// A decoding failure is wrapped with errors.ErrCodeCacheError.
func (r *cacheRepository) GetObject(ctx context.Context, key string, dest interface{}) error {
	raw, getErr := r.Get(ctx, key)
	if getErr != nil {
		return getErr
	}

	if decodeErr := json.Unmarshal([]byte(raw), dest); decodeErr != nil {
		return errors.Wrap(errors.ErrCodeCacheError, "failed to unmarshal cache object", decodeErr)
	}
	return nil
}
// SetObject JSON-encodes value and stores the result under key via Set, so a
// zero ttl selects the per-key default.
//
// An encoding failure is wrapped with errors.ErrCodeCacheError. Errors from
// Set are returned unchanged.
func (r *cacheRepository) SetObject(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
	encoded, encodeErr := json.Marshal(value)
	if encodeErr != nil {
		return errors.Wrap(errors.ErrCodeCacheError, "failed to marshal cache object", encodeErr)
	}
	payload := string(encoded)
	return r.Set(ctx, key, payload, ttl)
}
// DeletePattern removes every key matching the Redis glob pattern.
//
// Keys are discovered with the cursor-based SCAN command and deleted one
// batch at a time. Memory use is therefore bounded by a single SCAN reply
// rather than growing with the total number of matching keys.
//
// An empty pattern matches nothing and returns nil. go-redis omits the
// MATCH clause when the pattern is "", which would otherwise make SCAN
// enumerate (and this method delete) every key in the database.
//
// If SCAN or DEL fails part-way, keys from earlier batches have already been
// removed. SCAN failures are wrapped with errors.ErrCodeCacheError, and DEL
// failures are returned as produced by Delete.
func (r *cacheRepository) DeletePattern(ctx context.Context, pattern string) error {
	if pattern == "" {
		return nil
	}

	// scanBatchSize is the COUNT hint sent with SCAN. Redis may return more
	// or fewer keys per call.
	const scanBatchSize = 100

	var cursor uint64
	for {
		keys, next, err := r.client.Scan(ctx, cursor, pattern, scanBatchSize).Result()
		if err != nil {
			return errors.Wrap(errors.ErrCodeCacheError, "failed to scan cache keys", err)
		}
		// SCAN can legitimately return an empty page. Delete is a no-op for
		// zero keys, so no special case is needed here.
		if err := r.Delete(ctx, keys...); err != nil {
			return err
		}
		// A returned cursor of 0 marks the end of the full iteration.
		if next == 0 {
			return nil
		}
		cursor = next
	}
}
// getDefaultTTL picks the default expiration for key based on its type:
// the permission tree and permission list keys, per-user permission keys,
// per-role permission keys, or (for anything else) five minutes.
func (r *cacheRepository) getDefaultTTL(key string) time.Duration {
	// NOTE(review): the permission list shares PermissionTreeTTL; confirm
	// this is intended and not a missing dedicated setting.
	if key == repository.CacheKeyPermissionTree || key == repository.CacheKeyPermissionList {
		return r.config.PermissionTreeTTL
	}
	if isUserPermissionKey(key) {
		return r.config.UserPermissionTTL
	}
	if isRolePermissionKey(key) {
		return r.config.RolePolicyTTL
	}
	return 5 * time.Minute
}
// isUserPermissionKey reports whether key is a per-user permission key: it
// starts with repository.CacheKeyUserPermissionPrefix and has at least one
// byte after that prefix.
func isUserPermissionKey(key string) bool {
	return keyHasIDAfterPrefix(key, repository.CacheKeyUserPermissionPrefix)
}

// isRolePermissionKey reports whether key is a per-role permission key: it
// starts with repository.CacheKeyRolePermissionPrefix and has at least one
// byte after that prefix.
func isRolePermissionKey(key string) bool {
	return keyHasIDAfterPrefix(key, repository.CacheKeyRolePermissionPrefix)
}

// keyHasIDAfterPrefix reports whether key begins with prefix and is strictly
// longer than it. A key equal to the bare prefix is rejected.
func keyHasIDAfterPrefix(key, prefix string) bool {
	n := len(prefix)
	if len(key) <= n {
		return false
	}
	return key[:n] == prefix
}
// InvalidateUserPermission clears the cached permissions of one user by
// deleting the key formed from CacheKeyUserPermissionPrefix followed by uid.
func (r *cacheRepository) InvalidateUserPermission(ctx context.Context, uid string) error {
	return r.Delete(ctx, fmt.Sprintf("%s%s", repository.CacheKeyUserPermissionPrefix, uid))
}
// InvalidateRolePermission clears the cached permissions of one role by
// deleting the key formed from CacheKeyRolePermissionPrefix followed by
// roleUID.
func (r *cacheRepository) InvalidateRolePermission(ctx context.Context, roleUID string) error {
	return r.Delete(ctx, fmt.Sprintf("%s%s", repository.CacheKeyRolePermissionPrefix, roleUID))
}
// InvalidateAllPermissions clears every permission-related cache entry: the
// permission tree and permission list keys, then all per-user and all
// per-role permission keys.
//
// The tree and list entries are exact key names (they are compared with ==
// when choosing default TTLs), so they are removed with one direct DEL.
// Routing them through DeletePattern would SCAN the entire keyspace just to
// find them, and would misread any glob metacharacters in the names. Only
// the prefix families need a pattern scan.
//
// The first failure is returned immediately and later steps are skipped.
func (r *cacheRepository) InvalidateAllPermissions(ctx context.Context) error {
	if err := r.Delete(ctx,
		repository.CacheKeyPermissionTree,
		repository.CacheKeyPermissionList,
	); err != nil {
		return err
	}

	prefixPatterns := []string{
		repository.CacheKeyUserPermissionPrefix + "*",
		repository.CacheKeyRolePermissionPrefix + "*",
	}
	for _, pattern := range prefixPatterns {
		if err := r.DeletePattern(ctx, pattern); err != nil {
			return err
		}
	}
	return nil
}