302 lines
8.5 KiB
Go
302 lines
8.5 KiB
Go
package repository
|
|
|
|
import (
|
|
"ark-permission/internal/domain"
|
|
"ark-permission/internal/domain/repository"
|
|
"ark-permission/internal/entity"
|
|
ers "code.30cm.net/wanderland/library-go/errors"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
|
)
|
|
|
|
type TokenRepositoryParam struct {
|
|
Store *redis.Redis `name:"redis"`
|
|
}
|
|
|
|
type tokenRepository struct {
|
|
store *redis.Redis
|
|
}
|
|
|
|
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
|
|
return &tokenRepository{
|
|
store: param.Store,
|
|
}
|
|
}
|
|
|
|
// Create persists token and indexes it for later lookup.
//
// In one Redis pipeline it writes:
//   - the JSON-encoded token under its access-token key,
//   - the token ID under its refresh-token key (only when RefreshToken is non-empty),
//   - the token ID into the per-UID and per-device index sets.
//
// The access and refresh keys both expire after token.RedisRefreshExpiredSec()
// seconds. The pipeline is not a MULTI/EXEC transaction, so a failure may leave
// some of the writes applied.
//
// A JSON encoding failure is returned as ers.ArkInternal. A pipeline failure is
// returned as domain.RedisPipLineError. ctx bounds the pipeline round trip.
func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error {
	body, err := json.Marshal(token)
	if err != nil {
		return ers.ArkInternal("json.Marshal token error", err.Error())
	}

	refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second

	// PipelinedCtx (rather than Pipelined) lets caller cancellation and deadlines
	// abort the round trip.
	if err := t.store.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
		if err := t.setToken(ctx, tx, token, body, refreshTTL); err != nil {
			return err
		}
		if err := t.setRefreshToken(ctx, tx, token, refreshTTL); err != nil {
			return err
		}
		return t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL)
	}); err != nil {
		return domain.RedisPipLineError(err.Error())
	}

	return nil
}
|
|
|
|
// Delete revokes token. It removes the access-token key, removes the
// refresh-token key (only when the token has one), and then drops the token ID
// from its per-device and per-UID index sets.
//
// A failure deleting the keys is returned as domain.RedisPipLineError, and in
// that case the index sets are left untouched. Failures removing the ID from the
// index sets are ignored (best-effort): the stale ID no longer resolves to a
// token body and is filtered out when the set is next read.
func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error {
	keys := []string{domain.GetAccessTokenRedisKey(token.ID)}
	// Create only writes a refresh key for a non-empty RefreshToken; mirror that here
	// instead of deleting the meaningless RefreshTokenRedisKey.With("") key.
	if token.RefreshToken != "" {
		keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString())
	}

	if err := t.deleteKeys(ctx, keys...); err != nil {
		return domain.RedisPipLineError(err.Error())
	}

	// Best-effort index cleanup; errors intentionally ignored (see doc comment).
	_, _ = t.store.SremCtx(ctx, domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID)
	_, _ = t.store.SremCtx(ctx, domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)

	return nil
}
|
|
|
|
// GetAccessTokenByID loads and decodes the access token stored under id.
// Any lookup error (Redis failure, missing key, undecodable JSON) is passed
// through unchanged, together with a zero entity.Token.
func (t *tokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
	return t.get(ctx, domain.GetAccessTokenRedisKey(id))
}
|
|
|
|
// DeleteAccessTokensByUID revokes every token currently indexed under uid.
// It stops at, and returns, the first error from the lookup or from Delete;
// tokens already deleted before that point stay deleted.
func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
	tokens, err := t.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return err
	}

	for i := range tokens {
		if delErr := t.Delete(ctx, tokens[i]); delErr != nil {
			return delErr
		}
	}
	return nil
}
|
|
|
|
// DeleteAccessTokenByID revokes each token whose ID is in ids, on a best-effort
// basis. It always returns nil.
//
// IDs whose token body cannot be loaded are skipped: without the stored token
// there is no UID/DeviceID with which to locate the index sets. For the rest,
// the key removal and index cleanup are done by Delete, and its error is dropped
// so that one failure does not stop the remaining IDs.
func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
	for _, id := range ids {
		tok, lookupErr := t.GetAccessTokenByID(ctx, id)
		if lookupErr != nil {
			continue
		}
		_ = t.Delete(ctx, tok)
	}
	return nil
}
|
|
|
|
// GetAccessTokensByUID returns the live access tokens indexed under uid.
//
// The set key is built with domain.UIDTokenRedisKey, the same builder that
// setRelation writes to, that Delete removes from and that GetAccessTokenCountByUID
// counts. Using it here guarantees reads address the same set as writes.
// Expired or orphaned members are pruned as a side effect of getTokensBySet.
func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
	return t.getTokensBySet(ctx, domain.UIDTokenRedisKey.With(uid).ToString())
}
|
|
|
|
// GetAccessTokensByDeviceID returns the live access tokens indexed under
// deviceID. Stale members of the device set may be pruned while reading.
func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
	setKey := domain.DeviceTokenRedisKey.With(deviceID).ToString()
	return t.getTokensBySet(ctx, setKey)
}
|
|
|
|
func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
|
|
|
|
tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID)
|
|
if err != nil {
|
|
return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err))
|
|
}
|
|
|
|
var keys []string
|
|
for _, token := range tokens {
|
|
keys = append(keys, domain.GetAccessTokenRedisKey(token.ID))
|
|
keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString())
|
|
|
|
}
|
|
|
|
err = t.store.Pipelined(func(tx redis.Pipeliner) error {
|
|
for _, token := range tokens {
|
|
_, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := t.deleteKeys(ctx, keys...); err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString())
|
|
return err
|
|
}
|
|
|
|
func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) {
|
|
return t.getCountBySet(domain.DeviceTokenRedisKey.With(deviceID).ToString())
|
|
}
|
|
|
|
func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) {
|
|
return t.getCountBySet(domain.UIDTokenRedisKey.With(uid).ToString())
|
|
}
|
|
|
|
// GetAccessTokenByByOneTimeToken resolves oneTimeToken to its access token.
//
// It reads the value stored under RefreshTokenRedisKey.With(oneTimeToken),
// treats that value as a token ID, and loads the token via GetAccessTokenByID.
//
// Errors:
//   - a Redis read error is returned as domain.RedisError;
//   - a missing key is returned as ers.ResourceNotFound.
//
// NOTE(review): CreateOneTimeToken stores a JSON-encoded ticket under the same
// key namespace, whereas setRefreshToken stores a bare token ID. This function
// expects the latter; confirm which writer its callers pair it with.
func (t *tokenRepository) GetAccessTokenByByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
	key := domain.RefreshTokenRedisKey.With(oneTimeToken).ToString()

	// GetCtx (rather than Get) so caller cancellation and deadlines apply.
	id, err := t.store.GetCtx(ctx, key)
	if err != nil {
		return entity.Token{}, domain.RedisError(fmt.Sprintf("GetAccessTokenByByOneTimeToken store.Get error: %s", err.Error()))
	}

	if id == "" {
		return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key)
	}

	return t.GetAccessTokenByID(ctx, id)
}
|
|
|
|
func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
|
|
var keys []string
|
|
|
|
for _, id := range ids {
|
|
keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString())
|
|
}
|
|
|
|
for _, token := range tokens {
|
|
keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString())
|
|
}
|
|
|
|
return t.deleteKeys(ctx, keys...)
|
|
}
|
|
|
|
// CreateOneTimeToken stores ticket as JSON under RefreshTokenRedisKey.With(key).
// It uses SETNX with an expiry of expires, truncated to whole seconds.
//
// Errors:
//   - a JSON encoding failure is returned as ers.InvalidFormat;
//   - a Redis failure is returned as domain.RedisError.
//
// NOTE(review): SetnxExCtx's bool result (whether the key was actually written)
// is discarded. If key already exists, the existing value is kept and this
// still returns nil. Also, an expires below one second truncates to 0, which
// Redis rejects. Confirm both are acceptable to callers.
func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error {
	body, err := json.Marshal(ticket)
	if err != nil {
		return ers.InvalidFormat("CreateOneTimeToken json.Marshal error", err.Error())
	}

	// SetnxExCtx (rather than SetnxEx) so caller cancellation and deadlines apply.
	if _, err := t.store.SetnxExCtx(ctx, domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds())); err != nil {
		return domain.RedisError(fmt.Sprintf("CreateOneTimeToken store.SetnxEx error: %s", err.Error()))
	}

	return nil
}
|
|
|
|
// -------------------- Private area --------------------
|
|
|
|
// get reads key and decodes its JSON value into an entity.Token.
//
// Errors (the token is always zero when an error is returned):
//   - domain.RedisError when the read fails;
//   - ers.ResourceNotFound when the stored value is empty;
//   - ers.ArkInternal when the value is not valid token JSON.
func (t *tokenRepository) get(ctx context.Context, key string) (entity.Token, error) {
	raw, readErr := t.store.GetCtx(ctx, key)
	switch {
	case readErr != nil:
		return entity.Token{}, domain.RedisError(fmt.Sprintf("token %s not found in redis: %s", key, readErr.Error()))
	case raw == "":
		return entity.Token{}, ers.ResourceNotFound("this token not found")
	}

	var out entity.Token
	if decodeErr := json.Unmarshal([]byte(raw), &out); decodeErr != nil {
		return entity.Token{}, ers.ArkInternal("json.Unmarshal token error", decodeErr.Error())
	}
	return out, nil
}
|
|
|
|
// setToken queues, on tx, a SET of body under token's access-token key with
// expiry ttl. It returns the queued command's error.
func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, ttl time.Duration) error {
	key := domain.GetAccessTokenRedisKey(token.ID)
	cmd := tx.Set(ctx, key, body, ttl)
	return cmd.Err()
}
|
|
|
|
// setRefreshToken queues, on tx, a SET mapping token's refresh-token key to
// token.ID with expiry ttl. It is a no-op returning nil when the token has no
// refresh token.
func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, ttl time.Duration) error {
	if token.RefreshToken == "" {
		return nil
	}

	key := domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()
	return tx.Set(ctx, key, token.ID, ttl).Err()
}
|
|
|
|
// setRelation queues, on tx, SADDs of tokenID into the per-UID set and then the
// per-device set. It returns the first queued command error.
//
// NOTE(review): ttl is currently unused, so both index sets are written without
// an expiry and only shrink when members are explicitly removed. Decide whether
// the sets should expire (taking care not to shorten the lifetime of older,
// longer-lived members).
func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
	indexKeys := [...]string{
		domain.UIDTokenRedisKey.With(uid).ToString(),
		domain.DeviceTokenRedisKey.With(deviceID).ToString(),
	}

	for _, k := range indexKeys {
		if err := tx.SAdd(ctx, k, tokenID).Err(); err != nil {
			return err
		}
	}
	return nil
}
|
|
|
|
// deleteKeys deletes every key in keys using a single pipeline.
//
// One DEL is issued per key, rather than a multi-key DEL, so keys that hash to
// different cluster slots are still accepted. An empty keys list returns nil
// without contacting Redis.
//
// The pipeline's error is returned. NOTE(review): in go-redis a queued command's
// Err() is normally nil until the pipeline executes, so the per-command check
// below is defensive.
func (t *tokenRepository) deleteKeys(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}

	// PipelinedCtx (rather than Pipelined) so caller cancellation and deadlines apply.
	return t.store.PipelinedCtx(ctx, func(tx redis.Pipeliner) error {
		for _, key := range keys {
			if err := tx.Del(ctx, key).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
			}
		}
		return nil
	})
}
|
|
|
|
// getTokensBySet returns the live tokens whose IDs are members of the index set
// at setKey (a per-UID or per-device set).
//
// While reading, it prunes stale members, best-effort (pruning errors are
// ignored, so the same IDs are simply retried on the next read):
//   - IDs whose token body is missing or is not valid token JSON are removed from
//     setKey directly. Without a body there is no UID/DeviceID, so
//     DeleteAccessTokenByID could not locate the index sets and the IDs would
//     otherwise stay in the set forever, inflating SCARD-based counts.
//   - IDs whose token has ExpiresIn earlier than now are revoked through
//     DeleteAccessTokenByID.
//
// IDs whose body read fails with a Redis error are skipped without pruning. This
// keeps a possibly valid token from being dropped out of its index, where bulk
// revocation (e.g. DeleteAccessTokensByUID) could no longer find it.
//
// A failure reading the set itself is returned as domain.RedisError.
//
// NOTE(review): revoking an access-expired token through DeleteAccessTokenByID
// also deletes its refresh-token key; confirm that listing tokens is meant to
// invalidate refresh tokens of expired access tokens.
func (t *tokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
	ids, err := t.store.SmembersCtx(ctx, setKey)
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, nil
		}
		return nil, domain.RedisError(fmt.Sprintf("getTokensBySet store.Get %s error: %v", setKey, err.Error()))
	}

	var (
		tokens  []entity.Token
		orphans []interface{} // members with no decodable body; SREM'd from setKey
		expired []string      // members past ExpiresIn; fully revoked
	)
	now := time.Now().Unix()

	for _, id := range ids {
		body, err := t.store.GetCtx(ctx, domain.GetAccessTokenRedisKey(id))
		if err != nil {
			// Transient read failure: leave the index untouched (see doc comment).
			continue
		}

		var token entity.Token
		if body == "" || json.Unmarshal([]byte(body), &token) != nil {
			orphans = append(orphans, id)
			continue
		}

		if int64(token.ExpiresIn) < now {
			expired = append(expired, id)
			continue
		}

		tokens = append(tokens, token)
	}

	if len(orphans) > 0 {
		_, _ = t.store.SremCtx(ctx, setKey, orphans...)
	}
	if len(expired) > 0 {
		_ = t.DeleteAccessTokenByID(ctx, expired)
	}

	return tokens, nil
}
|
|
|
|
func (t *tokenRepository) getCountBySet(setKey string) (int, error) {
|
|
count, err := t.store.Scard(setKey)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int(count), nil
|
|
}
|