guard/internal/repository/token.go

457 lines
12 KiB
Go
Raw Normal View History

2024-08-06 05:59:24 +00:00
package repository
import (
"ark-permission/internal/domain"
"ark-permission/internal/domain/repository"
"ark-permission/internal/entity"
2024-08-08 03:02:13 +00:00
ers "code.30cm.net/wanderland/library-go/errors"
2024-08-06 05:59:24 +00:00
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// TokenRepositoryParam carries the dependencies consumed by NewTokenRepository.
type TokenRepositoryParam struct {
	// Store is the Redis client every token key is read from and written to.
	// NOTE(review): the `name:"redis"` tag presumably selects a named instance
	// in a dependency-injection container — confirm against the wiring code.
	Store *redis.Redis `name:"redis"`
}
// tokenRepository is the Redis-backed implementation of
// repository.TokenRepository handed out by NewTokenRepository.
type tokenRepository struct {
	// store is the Redis client all token, refresh-token and index keys live in.
	store *redis.Redis
}
// NewTokenRepository builds a Redis-backed TokenRepository that uses
// param.Store for all reads and writes.
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
	repo := &tokenRepository{store: param.Store}
	return repo
}
func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error {
body, err := json.Marshal(token)
if err != nil {
2024-08-08 08:10:38 +00:00
return ers.ArkInternal("json.Marshal token error", err.Error())
2024-08-06 05:59:24 +00:00
}
err = t.store.Pipelined(func(tx redis.Pipeliner) error {
2024-08-11 12:21:42 +00:00
// rTTL := token.RedisExpiredSec()
refreshTTL := token.RedisRefreshExpiredSec()
2024-08-06 05:59:24 +00:00
2024-08-11 12:21:42 +00:00
if err := t.setToken(ctx, tx, token, body, time.Duration(refreshTTL)*time.Second); err != nil {
2024-08-06 05:59:24 +00:00
return err
}
2024-08-11 12:21:42 +00:00
if err := t.setRefreshToken(ctx, tx, token, time.Duration(refreshTTL)*time.Second); err != nil {
2024-08-06 05:59:24 +00:00
return err
}
2024-08-11 12:21:42 +00:00
err := t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, time.Duration(refreshTTL)*time.Second)
if err != nil {
2024-08-06 05:59:24 +00:00
return err
}
return nil
})
if err != nil {
2024-08-08 08:10:38 +00:00
return domain.RedisPipLineError(err.Error())
2024-08-06 05:59:24 +00:00
}
2024-08-11 12:21:42 +00:00
return nil
}
// Delete removes a single token. On one pipeline it queues:
//   - deletion of the token's access-token key and refresh-token key,
//   - removal of the token's ID from the per-user index set,
//   - removal of the token's ID from the per-device index set.
//
// Only this token's ID is removed from the index sets. Deleting the whole
// sets would orphan every other live token of the same user/device, which
// would then no longer be returned by GetAccessTokensByUID/DeviceID or
// counted by the Get*Count* methods.
//
// Errors are returned as RedisPipLineError.
func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error {
	err := t.store.Pipelined(func(tx redis.Pipeliner) error {
		keys := []string{
			domain.GetAccessTokenRedisKey(token.ID),
			domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
		}
		for _, key := range keys {
			if err := tx.Del(ctx, key).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
			}
		}

		uidKey := domain.UIDTokenRedisKey.With(token.UID).ToString()
		if err := tx.SRem(ctx, uidKey, token.ID).Err(); err != nil {
			return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
		}

		// setRelation indexes the token under its device key unconditionally,
		// so the removal is unconditional too (SRem on an absent member is a no-op).
		deviceKey := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString()
		if err := tx.SRem(ctx, deviceKey, token.ID).Err(); err != nil {
			return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err))
		}
		return nil
	})
	if err != nil {
		return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
	}
	return nil
}
// GetAccessTokenByID loads the access token stored under the given token ID.
// The context argument is currently unused. Error semantics are those of get.
func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) {
	key := domain.GetAccessTokenRedisKey(id)
	return t.get(key)
}
func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
tokens, err := t.GetAccessTokensByUID(ctx, uid)
if err != nil {
return err
}
for _, item := range tokens {
err := t.Delete(ctx, item)
if err != nil {
return err
}
2024-08-06 05:59:24 +00:00
}
return nil
}
2024-08-11 12:21:42 +00:00
// DeleteAccessTokenByID deletes each token whose ID is listed in ids. For
// each one, a pipeline queues:
//   - deletion of the access-token and refresh-token keys,
//   - removal of the ID from the owning device's and user's index sets.
//
// An ID whose token record cannot be loaded is skipped, since there is nothing
// left to read the refresh token/UID/device from. A failing pipeline does not
// stop the remaining IDs from being processed; the first such failure is
// returned as RedisPipLineError. Returns nil when every pipeline succeeded.
func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
	var firstErr error
	for _, tokenID := range ids {
		token, err := t.GetAccessTokenByID(ctx, tokenID)
		if err != nil {
			continue
		}

		err = t.store.Pipelined(func(tx redis.Pipeliner) error {
			keys := []string{
				domain.GetAccessTokenRedisKey(token.ID),
				domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
			}
			for _, key := range keys {
				if err := tx.Del(ctx, key).Err(); err != nil {
					return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
				}
			}
			// Queue the index updates on the pipeline as well, rather than issuing
			// immediate commands on the store while the pipeline is being built.
			if err := tx.SRem(ctx, domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err))
			}
			if err := tx.SRem(ctx, domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
			}
			return nil
		})
		if err != nil && firstErr == nil {
			firstErr = domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
		}
	}
	return firstErr
}
2024-08-10 01:52:23 +00:00
// GetAccessTokensByUID 透過 uid 得到目前未過期的 token
func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
2024-08-11 12:21:42 +00:00
utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid))
2024-08-10 01:52:23 +00:00
if err != nil {
// 沒有就視為回空
if errors.Is(err, redis.Nil) {
return nil, nil
}
return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error()))
}
2024-08-11 12:21:42 +00:00
now := time.Now().UTC()
2024-08-10 01:52:23 +00:00
var tokens []entity.Token
var deleteToken []string
2024-08-11 12:21:42 +00:00
for _, id := range utKeys {
item := &entity.Token{}
2024-08-10 01:52:23 +00:00
tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
if err == nil {
2024-08-11 12:21:42 +00:00
err = json.Unmarshal([]byte(tk), item)
2024-08-10 01:52:23 +00:00
if err != nil {
return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err))
}
2024-08-11 12:21:42 +00:00
tokens = append(tokens, *item)
2024-08-10 01:52:23 +00:00
}
if errors.Is(err, redis.Nil) {
deleteToken = append(deleteToken, id)
}
2024-08-11 12:21:42 +00:00
if int64(item.ExpiresIn) < now.Unix() {
deleteToken = append(deleteToken, id)
2024-08-10 01:52:23 +00:00
2024-08-11 12:21:42 +00:00
continue
2024-08-10 01:52:23 +00:00
}
}
2024-08-11 12:21:42 +00:00
if len(deleteToken) > 0 {
// 如果失敗也沒關係其他get method撈取時會在判斷是否過期或存在
_ = t.DeleteAccessTokenByID(ctx, deleteToken)
2024-08-10 01:52:23 +00:00
}
2024-08-11 12:21:42 +00:00
return tokens, nil
2024-08-08 03:02:13 +00:00
}
2024-08-10 01:52:23 +00:00
// GetByRefresh resolves a refresh token to the access token it was issued
// with. The refresh-token key holds the token ID (see setRefreshToken), and
// the record is then loaded via GetAccessTokenByID.
//
// Errors:
//   - ArkInternal when the store read fails.
//   - ResourceNotFound when the refresh key is absent or empty.
//   - Otherwise, whatever GetAccessTokenByID returns.
func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) {
	key := domain.RefreshTokenRedisKey.With(refreshToken).ToString()
	id, err := t.store.Get(key)
	// Separate a real store failure from a missing key before looking at id;
	// on failure id is empty too.
	if err != nil && !errors.Is(err, redis.Nil) {
		return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.GetByRefresh refresh token error: %v", err))
	}
	if id == "" {
		return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key)
	}
	return t.GetAccessTokenByID(ctx, id)
}
// DeleteOneTimeToken deletes, in one pipeline, the refresh-token-namespace
// keys for every raw ID in ids and for the RefreshToken of every entry in
// tokens. A failure is returned as RedisPipLineError.
func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
	refreshIDs := make([]string, 0, len(ids)+len(tokens))
	refreshIDs = append(refreshIDs, ids...)
	for _, tk := range tokens {
		refreshIDs = append(refreshIDs, tk.RefreshToken)
	}

	pipeErr := t.store.Pipelined(func(pipe redis.Pipeliner) error {
		for _, rid := range refreshIDs {
			redisKey := domain.RefreshTokenRedisKey.With(rid).ToString()
			if err := pipe.Del(ctx, redisKey).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
			}
		}
		return nil
	})
	if pipeErr != nil {
		return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", pipeErr))
	}
	return nil
}
2024-08-11 12:21:42 +00:00
func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error {
2024-08-10 01:52:23 +00:00
body, err := json.Marshal(ticket)
if err != nil {
return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error())
}
2024-08-11 12:21:42 +00:00
_, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds()))
2024-08-10 01:52:23 +00:00
if err != nil {
return ers.DBError("CreateOneTimeToken store.set error:", err.Error())
}
return nil
}
2024-08-11 12:21:42 +00:00
func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString())
if err != nil {
// 沒有就視為回空
if errors.Is(err, redis.Nil) {
return nil, nil
2024-08-08 03:02:13 +00:00
}
2024-08-11 12:21:42 +00:00
return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error()))
}
now := time.Now().UTC()
var tokens []entity.Token
var deleteToken []string
for _, id := range utKeys {
item := &entity.Token{}
tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
if err == nil {
err = json.Unmarshal([]byte(tk), item)
if err != nil {
return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err))
2024-08-08 03:02:13 +00:00
}
2024-08-11 12:21:42 +00:00
tokens = append(tokens, *item)
2024-08-08 03:02:13 +00:00
}
2024-08-11 12:21:42 +00:00
if errors.Is(err, redis.Nil) {
deleteToken = append(deleteToken, id)
}
if int64(item.ExpiresIn) < now.Unix() {
deleteToken = append(deleteToken, id)
continue
}
}
if len(deleteToken) > 0 {
// 如果失敗也沒關係其他get method撈取時會在判斷是否過期或存在
_ = t.DeleteAccessTokenByID(ctx, deleteToken)
}
return tokens, nil
}
func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err))
}
err = t.store.Pipelined(func(tx redis.Pipeliner) error {
for _, token := range tokens {
if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
_, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
2024-08-08 03:02:13 +00:00
if err != nil {
2024-08-11 12:21:42 +00:00
return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
2024-08-08 03:02:13 +00:00
}
}
2024-08-11 12:21:42 +00:00
_, err := t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString())
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err))
}
2024-08-08 03:02:13 +00:00
return nil
})
if err != nil {
2024-08-11 12:21:42 +00:00
return err
2024-08-08 03:02:13 +00:00
}
return nil
}
2024-08-11 12:21:42 +00:00
// GetAccessTokenCountByDeviceID returns the number of token IDs in the
// device's index set (SCARD). IDs whose records have already expired are
// counted until something removes them from the set.
func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) {
	key := domain.DeviceTokenRedisKey.With(deviceID).ToString()
	n, err := t.store.Scard(key)
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
// GetAccessTokenCountByUID returns the number of token IDs in the user's
// index set (SCARD). IDs whose records have already expired are counted
// until something removes them from the set.
func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) {
	key := domain.UIDTokenRedisKey.With(uid).ToString()
	n, err := t.store.Scard(key)
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
// -------------------- Private area --------------------
2024-08-08 03:02:13 +00:00
func (t *tokenRepository) get(key string) (entity.Token, error) {
body, err := t.store.Get(key)
2024-08-08 08:10:38 +00:00
if errors.Is(err, redis.Nil) || body == "" {
2024-08-08 03:02:13 +00:00
return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key)
}
if err != nil {
2024-08-08 08:10:38 +00:00
return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err))
2024-08-08 03:02:13 +00:00
}
var token entity.Token
if err := json.Unmarshal([]byte(body), &token); err != nil {
2024-08-08 08:10:38 +00:00
return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %w", err))
2024-08-08 03:02:13 +00:00
}
return token, nil
}
2024-08-06 05:59:24 +00:00
// setToken queues a SET of the encoded token body under the token's
// access-token key with TTL rTTL on the given pipeline.
func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error {
	accessKey := domain.GetAccessTokenRedisKey(token.ID)
	if setErr := tx.Set(ctx, accessKey, body, rTTL).Err(); setErr != nil {
		return wrapError("tx.Set GetAccessTokenRedisKey error", setErr)
	}
	return nil
}
// setRefreshToken queues a SET mapping the token's refresh token to its
// token ID, with TTL rTTL, on the given pipeline. It is a no-op when the
// token has no refresh token.
func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error {
	if token.RefreshToken == "" {
		return nil
	}
	refreshKey := domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()
	if setErr := tx.Set(ctx, refreshKey, token.ID, rTTL).Err(); setErr != nil {
		return wrapError("tx.Set RefreshToken error", setErr)
	}
	return nil
}
2024-08-11 12:21:42 +00:00
func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error {
uidKey := domain.UIDTokenRedisKey.With(uid).ToString()
err := tx.SAdd(ctx, uidKey, tokenID).Err()
if err != nil {
return err
}
err = tx.Expire(ctx, uidKey, rttl).Err()
if err != nil {
return err
2024-08-06 05:59:24 +00:00
}
2024-08-11 12:21:42 +00:00
deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString()
err = tx.SAdd(ctx, deviceKey, tokenID).Err()
if err != nil {
return err
}
err = tx.Expire(ctx, deviceKey, rttl).Err()
if err != nil {
return err
}
2024-08-06 05:59:24 +00:00
return nil
}
// SetUIDToken records token in the JSON map (entity.UIDToken) stored as a
// string value at the user's uid key. Steps:
//  1. Load the existing map; an absent key starts an empty map.
//  2. Drop entries whose stored expiry is before now.
//  3. Add token.ID -> token.RefreshTokenExpiresUnix().
//  4. Write the map back with SETEX and a 30-day TTL.
//
// Errors from each step are wrapped with wrapError.
//
// NOTE(review): GetAccessTokensByUID reads the same domain.GetUIDTokenRedisKey
// key with SMEMBERS (a set). If both paths are live for one user, Redis will
// reject one of them with WRONGTYPE — confirm which representation is current.
func (t *tokenRepository) SetUIDToken(token entity.Token) error {
	const uidTokenTTLSeconds = 86400 * 30 // 30 days

	key := domain.GetUIDTokenRedisKey(token.UID)
	uidTokens := make(entity.UIDToken)

	raw, err := t.store.Get(key)
	if err != nil && !errors.Is(err, redis.Nil) {
		return wrapError("t.store.Get GetUIDTokenRedisKey error", err)
	}
	if raw != "" {
		if err := json.Unmarshal([]byte(raw), &uidTokens); err != nil {
			return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err)
		}
	}

	now := time.Now().Unix()
	for tokenID, expiresAt := range uidTokens {
		if expiresAt < now {
			delete(uidTokens, tokenID)
		}
	}
	uidTokens[token.ID] = token.RefreshTokenExpiresUnix()

	encoded, err := json.Marshal(uidTokens)
	if err != nil {
		return wrapError("json.Marshal UIDToken error", err)
	}
	if err := t.store.Setex(key, string(encoded), uidTokenTTLSeconds); err != nil {
		return wrapError("t.store.Setex GetUIDTokenRedisKey error", err)
	}
	return nil
}
// wrapError prefixes err with message as "message: err". The result wraps err
// via %w, so errors.Is and errors.As still see the original error.
func wrapError(message string, err error) error {
	wrapped := fmt.Errorf("%s: %w", message, err)
	return wrapped
}