guard/internal/repository/token.go

457 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"ark-permission/internal/domain"
"ark-permission/internal/domain/repository"
"ark-permission/internal/entity"
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// TokenRepositoryParam carries the dependencies consumed by NewTokenRepository.
type TokenRepositoryParam struct {
	// Store is the Redis client the repository reads and writes through.
	// NOTE(review): the `name:"redis"` tag presumably selects a named instance
	// in a dependency-injection container — confirm against the wiring code.
	Store *redis.Redis `name:"redis"`
}
// tokenRepository is the Redis-backed implementation of
// repository.TokenRepository that NewTokenRepository returns.
type tokenRepository struct {
	// store is the Redis client taken from TokenRepositoryParam.Store.
	store *redis.Redis
}
// NewTokenRepository builds a Redis-backed repository.TokenRepository that
// uses param.Store for every operation.
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
	repo := &tokenRepository{store: param.Store}
	return repo
}
// Create stores a newly issued token in Redis.
//
// The token is JSON-encoded. Then, in a single pipeline, it is:
//   - written under its access-token key,
//   - linked from its refresh-token key (when RefreshToken is non-empty),
//   - added to its UID and device index sets.
//
// Every one of these keys, including the access-token body, uses the
// refresh-token TTL (RedisRefreshExpiredSec, in seconds).
//
// Returns ers.ArkInternal if encoding fails and domain.RedisPipLineError if
// the pipeline fails.
func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error {
	body, err := json.Marshal(token)
	if err != nil {
		return ers.ArkInternal("json.Marshal token error", err.Error())
	}

	pipeErr := t.store.Pipelined(func(tx redis.Pipeliner) error {
		ttl := time.Duration(token.RedisRefreshExpiredSec()) * time.Second

		if err := t.setToken(ctx, tx, token, body, ttl); err != nil {
			return err
		}
		if err := t.setRefreshToken(ctx, tx, token, ttl); err != nil {
			return err
		}
		return t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, ttl)
	})
	if pipeErr != nil {
		return domain.RedisPipLineError(pipeErr.Error())
	}
	return nil
}
// Delete removes one token from Redis in a single pipeline. It deletes:
//   - the token's access-token body,
//   - its refresh-token pointer,
//   - the whole UID index key,
//   - the whole device index key, when the token has a DeviceID.
//
// NOTE(review): the UID and device index keys are deleted outright rather
// than SREM'd. Index entries of any other tokens that share the same UID or
// device are therefore dropped as well. DeleteAccessTokenByID, by contrast,
// only SREMs the token's own ID. Confirm which behavior is intended.
//
// Returns domain.RedisPipLineError if the pipeline fails.
func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error {
	keys := []string{
		domain.GetAccessTokenRedisKey(token.ID),
		domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
		domain.UIDTokenRedisKey.With(token.UID).ToString(),
	}
	if token.DeviceID != "" {
		// The device index delete is queued on the same pipeline as the other
		// keys so it is batched with them, rather than being executed
		// immediately through t.store while the pipeline is being built.
		keys = append(keys, domain.DeviceTokenRedisKey.With(token.DeviceID).ToString())
	}

	err := t.store.Pipelined(func(tx redis.Pipeliner) error {
		for _, key := range keys {
			if err := tx.Del(ctx, key).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
			}
		}
		return nil
	})
	if err != nil {
		return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
	}
	return nil
}
// GetAccessTokenByID loads and decodes the token stored under the access-token
// key for id. The context argument is not used.
func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) {
	key := domain.GetAccessTokenRedisKey(id)
	return t.get(key)
}
// DeleteAccessTokensByUID revokes every token that GetAccessTokensByUID
// returns for uid. It makes one Delete call per token and stops at the first
// failure, returning that error unchanged.
func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
	tokens, err := t.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return err
	}
	for i := range tokens {
		if delErr := t.Delete(ctx, tokens[i]); delErr != nil {
			return delErr
		}
	}
	return nil
}
// DeleteAccessTokenByID revokes each token in ids on a best-effort basis.
//
// For every ID whose token body can still be loaded, one pipeline:
//   - deletes the access-token and refresh-token keys,
//   - SREMs the ID from its device and UID index sets.
//
// IDs that cannot be loaded are skipped. A failed pipeline for one ID does
// not stop the remaining IDs. The method currently always returns nil.
//
// TODO: add proper error handling; failures are currently swallowed.
// NOTE(review): IDs whose token body is already gone are skipped entirely, so
// they are never SREM'd from the UID/device index sets by this method.
func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
	for _, tokenID := range ids {
		token, err := t.GetAccessTokenByID(ctx, tokenID)
		if err != nil {
			// Without the token body we cannot resolve its refresh key, UID
			// or device, so there is nothing more to clean up here.
			continue
		}

		keys := []string{
			domain.GetAccessTokenRedisKey(token.ID),
			domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
		}
		deviceKey := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString()
		uidKey := domain.UIDTokenRedisKey.With(token.UID).ToString()

		// Every command is queued on tx, so the deletes and SREMs go out as
		// one batch. The pipeline error is deliberately ignored (best-effort,
		// see TODO) so the remaining IDs are still processed.
		_ = t.store.Pipelined(func(tx redis.Pipeliner) error {
			for _, key := range keys {
				if err := tx.Del(ctx, key).Err(); err != nil {
					return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
				}
			}
			if err := tx.SRem(ctx, deviceKey, token.ID).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err))
			}
			if err := tx.SRem(ctx, uidKey, token.ID).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
			}
			return nil
		})
	}
	return nil
}
// GetAccessTokensByUID returns the tokens indexed under uid whose access token
// has not expired yet.
//
// Each ID in the UID index set is resolved to its stored token. IDs whose
// token body is missing, and tokens whose ExpiresIn is already in the past,
// are left out of the result. They are passed to DeleteAccessTokenByID for
// best-effort cleanup instead.
//
// A missing index yields (nil, nil). Redis read failures are returned as
// domain.RedisError; undecodable token bodies are returned as ers.ArkInternal.
//
// NOTE(review): ExpiresIn is compared with the current Unix time in seconds,
// so it is assumed to be an absolute expiry timestamp — confirm against
// entity.Token.
func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
	utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid))
	if err != nil {
		// A missing index is treated as "no tokens".
		if errors.Is(err, redis.Nil) {
			return nil, nil
		}
		return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error()))
	}

	now := time.Now().Unix()
	var tokens []entity.Token
	var deleteToken []string
	for _, id := range utKeys {
		tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
		if err != nil && !errors.Is(err, redis.Nil) {
			return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetAccessTokenRedisKey error: %v", err))
		}
		// The store may report a missing key either as redis.Nil or as an
		// empty value with a nil error. Both mean the token body is gone, and
		// an empty body must not reach json.Unmarshal.
		if errors.Is(err, redis.Nil) || tk == "" {
			deleteToken = append(deleteToken, id)
			continue
		}

		var item entity.Token
		if err = json.Unmarshal([]byte(tk), &item); err != nil {
			return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetAccessTokenRedisKey error: %v", err))
		}
		// Expired tokens are scheduled for cleanup, not returned.
		if int64(item.ExpiresIn) < now {
			deleteToken = append(deleteToken, id)
			continue
		}
		tokens = append(tokens, item)
	}

	if len(deleteToken) > 0 {
		// Failure here is acceptable: other getters re-check existence and
		// expiry whenever they read these IDs again.
		_ = t.DeleteAccessTokenByID(ctx, deleteToken)
	}
	return tokens, nil
}
// GetByRefresh resolves a refresh token to its access token.
//
// The refresh-token key holds the access-token ID (written by
// setRefreshToken). That ID is then loaded through GetAccessTokenByID.
//
// Returns:
//   - ers.ResourceNotFound when the refresh key is absent or empty,
//   - ers.ArkInternal when the Redis read itself fails,
//   - otherwise whatever GetAccessTokenByID returns.
func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) {
	key := domain.RefreshTokenRedisKey.With(refreshToken).ToString()
	id, err := t.store.Get(key)
	// Real Redis failures must be classified before the not-found check;
	// otherwise they escape unwrapped.
	if err != nil && !errors.Is(err, redis.Nil) {
		return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.GetByRefresh refresh token error: %v", err))
	}
	if errors.Is(err, redis.Nil) || id == "" {
		return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key)
	}
	return t.GetAccessTokenByID(ctx, id)
}
// DeleteOneTimeToken deletes, in one pipeline, the refresh-namespace keys for:
//   - every raw id in ids, and
//   - the RefreshToken of every entry in tokens.
//
// Keys are deleted in that order: all of ids first, then all of tokens.
// Returns domain.RedisPipLineError if the pipeline fails.
func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
	suffixes := make([]string, 0, len(ids)+len(tokens))
	suffixes = append(suffixes, ids...)
	for _, tk := range tokens {
		suffixes = append(suffixes, tk.RefreshToken)
	}

	pipeErr := t.store.Pipelined(func(tx redis.Pipeliner) error {
		for _, suffix := range suffixes {
			redisKey := domain.RefreshTokenRedisKey.With(suffix).ToString()
			if err := tx.Del(ctx, redisKey).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
			}
		}
		return nil
	})
	if pipeErr != nil {
		return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", pipeErr))
	}
	return nil
}
// CreateOneTimeToken stores ticket as JSON under the refresh-token namespace
// key for key. It uses SetnxEx with a TTL of expires, truncated to whole
// seconds. The context argument is not used.
//
// Returns ers.InvalidFormat if encoding fails and ers.DBError if the Redis
// write fails.
//
// NOTE(review): the boolean result of SetnxEx is discarded. If the key
// already exists (SETNX semantics), nothing is written and nil is still
// returned.
func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error {
	encoded, err := json.Marshal(ticket)
	if err != nil {
		return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error())
	}

	redisKey := domain.RefreshTokenRedisKey.With(key).ToString()
	ttlSeconds := int(expires.Seconds())
	if _, err = t.store.SetnxEx(redisKey, string(encoded), ttlSeconds); err != nil {
		return ers.DBError("CreateOneTimeToken store.set error:", err.Error())
	}
	return nil
}
// GetAccessTokensByDeviceID returns the tokens indexed under deviceID whose
// access token has not expired yet.
//
// Each ID in the device index set is resolved to its stored token. IDs whose
// token body is missing, and tokens whose ExpiresIn is already in the past,
// are left out of the result. They are passed to DeleteAccessTokenByID for
// best-effort cleanup instead.
//
// A missing index yields (nil, nil). Redis read failures are returned as
// domain.RedisError; undecodable token bodies are returned as ers.ArkInternal.
//
// NOTE(review): ExpiresIn is compared with the current Unix time in seconds,
// so it is assumed to be an absolute expiry timestamp — confirm against
// entity.Token.
func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
	utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString())
	if err != nil {
		// A missing index is treated as "no tokens".
		if errors.Is(err, redis.Nil) {
			return nil, nil
		}
		return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error()))
	}

	now := time.Now().Unix()
	var tokens []entity.Token
	var deleteToken []string
	for _, id := range utKeys {
		tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
		if err != nil && !errors.Is(err, redis.Nil) {
			return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get GetAccessTokenRedisKey error: %v", err))
		}
		// The store may report a missing key either as redis.Nil or as an
		// empty value with a nil error. Both mean the token body is gone, and
		// an empty body must not reach json.Unmarshal.
		if errors.Is(err, redis.Nil) || tk == "" {
			deleteToken = append(deleteToken, id)
			continue
		}

		var item entity.Token
		if err = json.Unmarshal([]byte(tk), &item); err != nil {
			return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID json.Unmarshal GetAccessTokenRedisKey error: %v", err))
		}
		// Expired tokens are scheduled for cleanup, not returned.
		if int64(item.ExpiresIn) < now {
			deleteToken = append(deleteToken, id)
			continue
		}
		tokens = append(tokens, item)
	}

	if len(deleteToken) > 0 {
		// Failure here is acceptable: other getters re-check existence and
		// expiry whenever they read these IDs again.
		_ = t.DeleteAccessTokenByID(ctx, deleteToken)
	}
	return tokens, nil
}
// DeleteAccessTokensByDeviceID revokes every token that
// GetAccessTokensByDeviceID returns for deviceID.
//
// In one pipeline it:
//   - deletes each token's access-token and refresh-token keys,
//   - SREMs each token ID from its UID index set,
//   - finally deletes the device index set itself.
//
// Lookup failures are reported as domain.RedisDelError and pipeline failures
// as domain.RedisPipLineError, matching Delete and DeleteOneTimeToken.
func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
	tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID)
	if err != nil {
		return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err))
	}

	deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString()
	err = t.store.Pipelined(func(tx redis.Pipeliner) error {
		// Every command is queued on tx so the whole revocation is sent as one
		// batch. Nothing is executed through t.store mid-pipeline.
		for _, token := range tokens {
			if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
			}
			if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
			}
			if err := tx.SRem(ctx, domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
			}
		}
		if err := tx.Del(ctx, deviceKey).Err(); err != nil {
			return domain.RedisDelError(fmt.Sprintf("store.Del DeviceTokenRedisKey error: %v", err))
		}
		return nil
	})
	if err != nil {
		return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
	}
	return nil
}
// GetAccessTokenCountByDeviceID returns the cardinality of the device index
// set for deviceID. It does not check whether each listed token still exists.
// Redis errors are returned unchanged, together with a zero count.
func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) {
	key := domain.DeviceTokenRedisKey.With(deviceID).ToString()
	n, err := t.store.Scard(key)
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
// GetAccessTokenCountByUID returns the cardinality of the UID index set for
// uid. It does not check whether each listed token still exists. Redis errors
// are returned unchanged, together with a zero count.
func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) {
	key := domain.UIDTokenRedisKey.With(uid).ToString()
	n, err := t.store.Scard(key)
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
// -------------------- Private area --------------------
// get reads key and JSON-decodes its value into an entity.Token.
//
// Returns:
//   - ers.ArkInternal when the Redis read fails for a reason other than
//     redis.Nil, or when the stored value cannot be decoded,
//   - ers.ResourceNotFound when the key is absent (redis.Nil or an empty
//     value).
func (t *tokenRepository) get(key string) (entity.Token, error) {
	body, err := t.store.Get(key)
	// Check real failures first: a failed read also yields an empty body,
	// which must not be misreported as "not found".
	if err != nil && !errors.Is(err, redis.Nil) {
		return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err))
	}
	if errors.Is(err, redis.Nil) || body == "" {
		return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key)
	}

	var token entity.Token
	if err := json.Unmarshal([]byte(body), &token); err != nil {
		// Sprintf does not support %w (only fmt.Errorf does); use %v.
		return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %v", err))
	}
	return token, nil
}
// setToken queues a SET of the encoded token body under its access-token key
// on the pipeline tx, with expiry rTTL. A queueing error is wrapped with
// context via wrapError.
func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error {
	accessKey := domain.GetAccessTokenRedisKey(token.ID)
	if err := tx.Set(ctx, accessKey, body, rTTL).Err(); err != nil {
		return wrapError("tx.Set GetAccessTokenRedisKey error", err)
	}
	return nil
}
// setRefreshToken queues, on the pipeline tx, a SET that maps the token's
// refresh token to its access-token ID, with expiry rTTL. It does nothing
// when the token has no refresh token.
func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error {
	if token.RefreshToken == "" {
		return nil
	}
	refreshKey := domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()
	if err := tx.Set(ctx, refreshKey, token.ID, rTTL).Err(); err != nil {
		return wrapError("tx.Set RefreshToken error", err)
	}
	return nil
}
// setRelation queues, on the pipeline tx, the addition of tokenID to two
// index sets: first the UID set for uid, then the device set for deviceID.
// After each SADD it queues an EXPIRE that resets the whole set's TTL to rttl.
// Entries added earlier therefore share the TTL of the most recent call.
// Errors are returned unwrapped.
func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error {
	indexKeys := [...]string{
		domain.UIDTokenRedisKey.With(uid).ToString(),
		domain.DeviceTokenRedisKey.With(deviceID).ToString(),
	}
	for _, indexKey := range indexKeys {
		if err := tx.SAdd(ctx, indexKey, tokenID).Err(); err != nil {
			return err
		}
		if err := tx.Expire(ctx, indexKey, rttl).Err(); err != nil {
			return err
		}
	}
	return nil
}
// SetUIDToken records token.ID in a JSON-encoded entity.UIDToken map stored
// under the UID key. The map value is the token's refresh-token expiry
// (RefreshTokenExpiresUnix). Entries whose stored expiry is earlier than the
// current Unix time are pruned, and the key is rewritten with a 30-day TTL.
// Every error is wrapped with context via wrapError.
//
// NOTE(review): this reads and writes domain.GetUIDTokenRedisKey(uid) as a
// plain string value, while GetAccessTokensByUID reads the same key with
// SMEMBERS. A Redis key holds a single type, so the two usages cannot coexist.
// Confirm whether this method is still used.
func (t *tokenRepository) SetUIDToken(token entity.Token) error {
	const uidTokenTTLSeconds = 86400 * 30
	key := domain.GetUIDTokenRedisKey(token.UID)

	uidTokens := make(entity.UIDToken)
	raw, err := t.store.Get(key)
	if err != nil && !errors.Is(err, redis.Nil) {
		return wrapError("t.store.Get GetUIDTokenRedisKey error", err)
	}
	if raw != "" {
		if err := json.Unmarshal([]byte(raw), &uidTokens); err != nil {
			return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err)
		}
	}

	now := time.Now().Unix()
	for id, expiresAt := range uidTokens {
		if expiresAt < now {
			delete(uidTokens, id)
		}
	}
	uidTokens[token.ID] = token.RefreshTokenExpiresUnix()

	encoded, err := json.Marshal(uidTokens)
	if err != nil {
		return wrapError("json.Marshal UIDToken error", err)
	}
	if err := t.store.Setex(key, string(encoded), uidTokenTTLSeconds); err != nil {
		return wrapError("t.store.Setex GetUIDTokenRedisKey error", err)
	}
	return nil
}
// wrapError prefixes err with message, producing "message: err". The
// original err stays reachable through errors.Is, errors.As and
// errors.Unwrap.
func wrapError(message string, err error) error {
	const layout = "%s: %w"
	return fmt.Errorf(layout, message, err)
}