99 lines
2.6 KiB
Go
99 lines
2.6 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
app "haixun-backend/internal/library/errors"
|
|
"haixun-backend/internal/library/errors/code"
|
|
domrepo "haixun-backend/internal/model/auth/domain/repository"
|
|
|
|
goredis "github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// Redis key prefixes used by redisTokenRevokeStore. Every key is one of these
// prefixes followed by a token's JTI.
const (
	// jwtPairPrefix + <jti> holds the JTI of the token issued together with it:
	// the access key stores the refresh JTI and vice versa.
	jwtPairPrefix = "auth:jwt:pair:"
	// jwtBlacklistPrefix + <jti> marks a revoked token; the stored value is "1".
	jwtBlacklistPrefix = "auth:jwt:blacklist:"
)
|
|
|
|
type redisTokenRevokeStore struct {
|
|
client *goredis.Client
|
|
}
|
|
|
|
func NewRedisTokenRevokeStore(client *goredis.Client) domrepo.TokenRevokeStore {
|
|
return &redisTokenRevokeStore{client: client}
|
|
}
|
|
|
|
func (s *redisTokenRevokeStore) SavePair(ctx context.Context, accessJTI, refreshJTI string, accessTTL, refreshTTL time.Duration) error {
|
|
if err := s.requireRedis(); err != nil {
|
|
return err
|
|
}
|
|
if accessJTI == "" || refreshJTI == "" {
|
|
return app.For(code.Auth).InputMissingRequired("jwt pair jti is required")
|
|
}
|
|
if err := s.client.Set(ctx, jwtPairPrefix+accessJTI, refreshJTI, minTTL(accessTTL)).Err(); err != nil {
|
|
return err
|
|
}
|
|
return s.client.Set(ctx, jwtPairPrefix+refreshJTI, accessJTI, minTTL(refreshTTL)).Err()
|
|
}
|
|
|
|
func (s *redisTokenRevokeStore) GetPairedJTI(ctx context.Context, jti string) (string, error) {
|
|
if err := s.requireRedis(); err != nil {
|
|
return "", err
|
|
}
|
|
value, err := s.client.Get(ctx, jwtPairPrefix+jti).Result()
|
|
if err == goredis.Nil {
|
|
return "", nil
|
|
}
|
|
return value, err
|
|
}
|
|
|
|
func (s *redisTokenRevokeStore) DeletePair(ctx context.Context, accessJTI, refreshJTI string) error {
|
|
if err := s.requireRedis(); err != nil {
|
|
return err
|
|
}
|
|
keys := make([]string, 0, 2)
|
|
if accessJTI != "" {
|
|
keys = append(keys, jwtPairPrefix+accessJTI)
|
|
}
|
|
if refreshJTI != "" {
|
|
keys = append(keys, jwtPairPrefix+refreshJTI)
|
|
}
|
|
if len(keys) == 0 {
|
|
return nil
|
|
}
|
|
return s.client.Del(ctx, keys...).Err()
|
|
}
|
|
|
|
func (s *redisTokenRevokeStore) Blacklist(ctx context.Context, jti string, ttl time.Duration) error {
|
|
if err := s.requireRedis(); err != nil {
|
|
return err
|
|
}
|
|
if jti == "" {
|
|
return app.For(code.Auth).InputMissingRequired("jti is required")
|
|
}
|
|
return s.client.Set(ctx, jwtBlacklistPrefix+jti, "1", minTTL(ttl)).Err()
|
|
}
|
|
|
|
func (s *redisTokenRevokeStore) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
|
|
if err := s.requireRedis(); err != nil {
|
|
return false, err
|
|
}
|
|
count, err := s.client.Exists(ctx, jwtBlacklistPrefix+jti).Result()
|
|
return count > 0, err
|
|
}
|
|
|
|
func (s *redisTokenRevokeStore) requireRedis() error {
|
|
if s.client == nil {
|
|
return app.For(code.Auth).DBUnavailable("Redis is not configured")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// minTTL clamps ttl to a floor of one second. Durations of at least one
// second are returned unchanged; anything shorter, including zero and
// negative values, becomes time.Second.
//
// NOTE(review): presumably this keeps a zero or negative TTL away from
// go-redis Set, where 0 means "no expiration" and -1 means KEEPTTL — confirm.
func minTTL(ttl time.Duration) time.Duration {
	if ttl >= time.Second {
		return ttl
	}
	return time.Second
}
|