96 lines
2.6 KiB
Go
96 lines
2.6 KiB
Go
|
|
package repository
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
redislib "gateway/internal/library/redis"
|
||
|
|
authdomain "gateway/internal/model/auth/domain"
|
||
|
|
domrepo "gateway/internal/model/auth/domain/repository"
|
||
|
|
|
||
|
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||
|
|
)
|
||
|
|
|
||
|
|
type redisTokenRevokeStore struct {
|
||
|
|
client *redis.Redis
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewRedisTokenRevokeStore creates a Redis-backed JWT revoke store.
|
||
|
|
func NewRedisTokenRevokeStore(client *redislib.Client) domrepo.TokenRevokeStore {
|
||
|
|
if client == nil || client.Zero() == nil {
|
||
|
|
panic("auth: redis client is required for token revoke store")
|
||
|
|
}
|
||
|
|
return &redisTokenRevokeStore{client: client.Zero()}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *redisTokenRevokeStore) SavePair(ctx context.Context, accessJTI, refreshJTI string, accessTTL, refreshTTL time.Duration) error {
|
||
|
|
if accessJTI == "" || refreshJTI == "" {
|
||
|
|
return fmt.Errorf("auth: jwt pair jti is required")
|
||
|
|
}
|
||
|
|
accessSec := ttlSeconds(accessTTL)
|
||
|
|
refreshSec := ttlSeconds(refreshTTL)
|
||
|
|
if err := s.client.SetexCtx(ctx, authdomain.JWTPairRedisKey(accessJTI), refreshJTI, accessSec); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return s.client.SetexCtx(ctx, authdomain.JWTPairRedisKey(refreshJTI), accessJTI, refreshSec)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *redisTokenRevokeStore) GetPairedJTI(ctx context.Context, jti string) (string, error) {
|
||
|
|
if jti == "" {
|
||
|
|
return "", fmt.Errorf("auth: jti is required")
|
||
|
|
}
|
||
|
|
val, err := s.client.GetCtx(ctx, authdomain.JWTPairRedisKey(jti))
|
||
|
|
if errors.Is(err, redis.Nil) || val == "" {
|
||
|
|
return "", nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
return val, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *redisTokenRevokeStore) DeletePair(ctx context.Context, accessJTI, refreshJTI string) error {
|
||
|
|
keys := make([]string, 0, 2)
|
||
|
|
if accessJTI != "" {
|
||
|
|
keys = append(keys, authdomain.JWTPairRedisKey(accessJTI))
|
||
|
|
}
|
||
|
|
if refreshJTI != "" {
|
||
|
|
keys = append(keys, authdomain.JWTPairRedisKey(refreshJTI))
|
||
|
|
}
|
||
|
|
if len(keys) == 0 {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
_, err := s.client.DelCtx(ctx, keys...)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *redisTokenRevokeStore) Blacklist(ctx context.Context, jti string, ttl time.Duration) error {
|
||
|
|
if jti == "" {
|
||
|
|
return fmt.Errorf("auth: jti is required")
|
||
|
|
}
|
||
|
|
return s.client.SetexCtx(ctx, authdomain.JWTBlacklistRedisKey(jti), "1", ttlSeconds(ttl))
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *redisTokenRevokeStore) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
|
||
|
|
if jti == "" {
|
||
|
|
return false, fmt.Errorf("auth: jti is required")
|
||
|
|
}
|
||
|
|
exists, err := s.client.ExistsCtx(ctx, authdomain.JWTBlacklistRedisKey(jti))
|
||
|
|
if err != nil {
|
||
|
|
return false, err
|
||
|
|
}
|
||
|
|
return exists, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ttlSeconds converts d into the whole-second expiry passed to SETEX.
//
// Partial seconds are rounded UP, so a key never expires before the span it
// is meant to cover has elapsed. This matters for blacklist entries: if they
// expired early, the revoked token would become usable again briefly.
// Durations below one second (including zero and negative values) are
// clamped to 1, because Redis rejects a non-positive SETEX expiry.
func ttlSeconds(d time.Duration) int {
	// Integer ceiling division avoids float conversion. For negative d the
	// remainder is negative, so no increment happens and the clamp applies.
	sec := d / time.Second
	if d%time.Second > 0 {
		sec++
	}
	if sec < 1 {
		return 1
	}
	return int(sec)
}
|
||
|
|
|
||
|
|
// Compile-time check that *redisTokenRevokeStore satisfies
// domrepo.TokenRevokeStore; the build fails if a method drifts.
var (
	_ domrepo.TokenRevokeStore = (*redisTokenRevokeStore)(nil)
)
|