191 lines
5.0 KiB
Go
191 lines
5.0 KiB
Go
package repository
|
|
|
|
import (
|
|
"ark-permission/internal/domain"
|
|
"ark-permission/internal/domain/repository"
|
|
"ark-permission/internal/entity"
|
|
ers "code.30cm.net/wanderland/library-go/errors"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
|
)
|
|
|
|
type TokenRepositoryParam struct {
|
|
Store *redis.Redis `name:"redis"`
|
|
}
|
|
|
|
type tokenRepository struct {
|
|
store *redis.Redis
|
|
}
|
|
|
|
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
|
|
return &tokenRepository{
|
|
store: param.Store,
|
|
}
|
|
}
|
|
|
|
func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error {
|
|
body, err := json.Marshal(token)
|
|
if err != nil {
|
|
return ers.ArkInternal("json.Marshal token error", err.Error())
|
|
}
|
|
|
|
err = t.store.Pipelined(func(tx redis.Pipeliner) error {
|
|
rTTL := token.RefreshTokenExpires()
|
|
|
|
if err := t.setToken(ctx, tx, token, body, rTTL); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := t.setRefreshToken(ctx, tx, token, rTTL); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := t.setDeviceToken(ctx, tx, token, rTTL); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return domain.RedisPipLineError(err.Error())
|
|
}
|
|
|
|
if err := t.SetUIDToken(token); err != nil {
|
|
return ers.ArkInternal("SetUIDToken error", err.Error())
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *tokenRepository) GetByAccess(_ context.Context, id string) (entity.Token, error) {
|
|
return t.get(domain.GetAccessTokenRedisKey(id))
|
|
}
|
|
|
|
func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error {
|
|
err := t.store.Pipelined(func(tx redis.Pipeliner) error {
|
|
keys := []string{
|
|
domain.GetAccessTokenRedisKey(token.ID),
|
|
domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
|
|
}
|
|
|
|
for _, key := range keys {
|
|
if err := tx.Del(ctx, key).Err(); err != nil {
|
|
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
|
|
}
|
|
}
|
|
|
|
if token.DeviceID != "" {
|
|
key := domain.DeviceTokenRedisKey.With(token.UID).ToString()
|
|
_, err := t.store.Hdel(key, token.DeviceID)
|
|
if err != nil {
|
|
return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *tokenRepository) get(key string) (entity.Token, error) {
|
|
body, err := t.store.Get(key)
|
|
if errors.Is(err, redis.Nil) || body == "" {
|
|
return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key)
|
|
}
|
|
|
|
if err != nil {
|
|
return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err))
|
|
}
|
|
|
|
var token entity.Token
|
|
if err := json.Unmarshal([]byte(body), &token); err != nil {
|
|
return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %w", err))
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error {
|
|
err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err()
|
|
if err != nil {
|
|
return wrapError("tx.Set GetAccessTokenRedisKey error", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error {
|
|
if token.RefreshToken != "" {
|
|
err := tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, rTTL).Err()
|
|
if err != nil {
|
|
return wrapError("tx.Set RefreshToken error", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *tokenRepository) setDeviceToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error {
|
|
if token.DeviceID != "" {
|
|
key := domain.DeviceTokenRedisKey.With(token.UID).ToString()
|
|
value := fmt.Sprintf("%s-%d", token.ID, token.AccessCreateAt.Add(rTTL).Unix())
|
|
err := tx.HSet(ctx, key, token.DeviceID, value).Err()
|
|
if err != nil {
|
|
return wrapError("tx.HSet Device Token error", err)
|
|
}
|
|
err = tx.Expire(ctx, key, rTTL).Err()
|
|
if err != nil {
|
|
return wrapError("tx.Expire Device Token error", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetUIDToken 將 token 資料放進 uid key中
|
|
func (t *tokenRepository) SetUIDToken(token entity.Token) error {
|
|
uidTokens := make(entity.UIDToken)
|
|
b, err := t.store.Get(domain.GetUIDTokenRedisKey(token.UID))
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|
return wrapError("t.store.Get GetUIDTokenRedisKey error", err)
|
|
}
|
|
|
|
if b != "" {
|
|
err = json.Unmarshal([]byte(b), &uidTokens)
|
|
if err != nil {
|
|
return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err)
|
|
}
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
for k, t := range uidTokens {
|
|
if t < now {
|
|
delete(uidTokens, k)
|
|
}
|
|
}
|
|
|
|
uidTokens[token.ID] = token.RefreshTokenExpiresUnix()
|
|
s, err := json.Marshal(uidTokens)
|
|
if err != nil {
|
|
return wrapError("json.Marshal UIDToken error", err)
|
|
}
|
|
|
|
err = t.store.Setex(domain.GetUIDTokenRedisKey(token.UID), string(s), 86400*30)
|
|
if err != nil {
|
|
return wrapError("t.store.Setex GetUIDTokenRedisKey error", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// wrapError prefixes err with message, producing "message: <err>". Because it
// uses %w, err remains reachable through errors.Is / errors.As / errors.Unwrap.
func wrapError(message string, err error) error {
	const layout = "%s: %w"
	wrapped := fmt.Errorf(layout, message, err)
	return wrapped
}
|