guard/internal/repository/token.go

401 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"ark-permission/internal/domain"
"ark-permission/internal/domain/repository"
"ark-permission/internal/entity"
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/zeromicro/go-zero/core/stores/redis"
)
type (
	// TokenRepositoryParam carries the dependencies for NewTokenRepository.
	// NOTE(review): the `name:"redis"` tag is presumably read by a DI container
	// to select the Redis client — confirm against the wiring code.
	TokenRepositoryParam struct {
		Store *redis.Redis `name:"redis"`
	}

	// tokenRepository is the Redis-backed repository.TokenRepository returned
	// by NewTokenRepository.
	tokenRepository struct {
		store *redis.Redis
	}
)
// GetAccessTokenCountByUID returns the number of tokens GetAccessTokensByUID
// currently reports for uid (expired or vanished index entries are excluded
// by that method).
//
// The interface gives no context parameter, so context.Background() is used.
func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) {
	tokens, err := t.GetAccessTokensByUID(context.Background(), uid)
	if err != nil {
		return 0, err
	}
	return len(tokens), nil
}
// GetAccessTokensByDeviceID is not supported yet.
//
// Device entries are written by setDeviceToken into a hash keyed by the UID
// (field = DeviceID), so tokens cannot be looked up from a device ID alone.
// Instead of panicking and taking the whole process down, this reports an
// internal error to the caller.
func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
	return nil, ers.ArkInternal("tokenRepository.GetAccessTokensByDeviceID is not implemented: tokens are not indexed by device ID alone")
}
// GetAccessTokenCountByDeviceID is not supported yet.
//
// Device entries live in a hash keyed by UID (see setDeviceToken), so there is
// no device-ID-only index to count from. Rather than panic, this returns an
// internal error.
func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) {
	return 0, ers.ArkInternal("tokenRepository.GetAccessTokenCountByDeviceID is not implemented: tokens are not indexed by device ID alone")
}
// DeleteAccessTokenByID loads the token stored under id and removes its Redis
// records via Delete. It then drops the ID from the owner's UID index.
//
// A lookup error (including not-found from GetAccessTokenByID) is returned
// unchanged.
func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, id string) error {
	token, err := t.GetAccessTokenByID(ctx, id)
	if err != nil {
		return err
	}
	if err := t.Delete(ctx, token); err != nil {
		return err
	}
	// The token records are already gone; pruning the UID index is
	// best-effort, like the cleanup performed in GetAccessTokensByUID.
	_ = t.DeleteUIDToken(ctx, token.UID, []string{token.ID})
	return nil
}
// DeleteAccessTokensByUID deletes every token GetAccessTokensByUID currently
// reports for uid, then prunes those IDs from the UID index.
//
// It stops at and returns the first error from the lookup or from Delete.
func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
	tokens, err := t.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return err
	}
	if len(tokens) == 0 {
		return nil
	}

	ids := make([]string, 0, len(tokens))
	for _, token := range tokens {
		if err := t.Delete(ctx, token); err != nil {
			return err
		}
		ids = append(ids, token.ID)
	}
	// The token records are already gone; pruning the UID index is
	// best-effort, like the cleanup performed in GetAccessTokensByUID.
	_ = t.DeleteUIDToken(ctx, uid, ids)
	return nil
}
// DeleteAccessTokensByDeviceID is not supported yet.
//
// Device entries are stored in a per-UID hash (see setDeviceToken), so the
// tokens of a device cannot be found without the UID. Use
// DeleteAccessTokenByDeviceIDAndUID instead. Rather than panic, this returns
// an internal error.
func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
	return ers.ArkInternal("tokenRepository.DeleteAccessTokensByDeviceID is not implemented: tokens are not indexed by device ID alone")
}
func (t *tokenRepository) DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error {
// TODO implement me
panic("implement me")
}
// NewTokenRepository builds the Redis-backed repository.TokenRepository around
// the client supplied in param.Store.
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
	repo := new(tokenRepository)
	repo.store = param.Store
	return repo
}
// Create stores a newly issued token.
//
// Inside a single Redis pipeline it queues the following writes, each with the
// TTL returned by token.RefreshTokenExpires():
//
//   - the JSON-encoded token under its access-token key (setToken);
//   - the refresh-token to token-ID mapping, when RefreshToken is set (setRefreshToken);
//   - the UID device-hash entry, when DeviceID is set (setDeviceToken).
//
// After the pipeline succeeds, the token is added to the UID index via
// SetUIDToken, which runs outside the pipeline.
func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error {
	payload, marshalErr := json.Marshal(token)
	if marshalErr != nil {
		return ers.ArkInternal("json.Marshal token error", marshalErr.Error())
	}

	pipeErr := t.store.Pipelined(func(pipe redis.Pipeliner) error {
		ttl := token.RefreshTokenExpires()
		steps := []func() error{
			func() error { return t.setToken(ctx, pipe, token, payload, ttl) },
			func() error { return t.setRefreshToken(ctx, pipe, token, ttl) },
			func() error { return t.setDeviceToken(ctx, pipe, token, ttl) },
		}
		for _, step := range steps {
			if err := step(); err != nil {
				return err
			}
		}
		return nil
	})
	if pipeErr != nil {
		return domain.RedisPipLineError(pipeErr.Error())
	}

	if uidErr := t.SetUIDToken(token); uidErr != nil {
		return ers.ArkInternal("SetUIDToken error", uidErr.Error())
	}
	return nil
}
// // GetAccessTokensByDeviceID 透過 Device ID 得到目前未過期的token
// func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, uid string) ([]repository.DeviceToken, error) {
// data, err := t.store.Hgetall(domain.DeviceTokenRedisKey.With(uid).ToString())
// if err != nil {
// if errors.Is(err, redis.Nil) {
// return nil, nil
// }
//
// return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.HGetAll Device Token error: %v", err.Error()))
// }
//
// ids := make([]repository.DeviceToken, 0, len(data))
// for deviceID, id := range data {
// ids = append(ids, repository.DeviceToken{
// DeviceID: deviceID,
//
// // e0a4f824-41db-4eb2-8e5a-d96966ea1d56-1698083859
// // -11是因為id組成最後11位數是-跟時間戳記
// TokenID: id[:len(id)-11],
// })
// }
// return ids, nil
// }
// GetAccessTokensByUID returns the tokens listed in the UID index of uid whose
// stored expiry has not passed and whose access-token record still exists.
//
// A missing UID index yields (nil, nil). Redis failures are returned as
// RedisError, and undecodable JSON as ArkInternal. Index entries found to be
// expired or dangling are pruned best-effort through DeleteUIDToken.
func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
	raw, err := t.store.Get(domain.GetUIDTokenRedisKey(uid))
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error()))
	}
	// go-zero's Get reports a missing key as ("", nil) rather than redis.Nil,
	// so an empty body is the "no tokens for this uid" case.
	if raw == "" {
		return nil, nil
	}

	uidTokens := make(entity.UIDToken)
	if err := json.Unmarshal([]byte(raw), &uidTokens); err != nil {
		return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err))
	}

	now := time.Now().Unix()
	var tokens []entity.Token
	var stale []string
	for id, expireAt := range uidTokens {
		if expireAt < now {
			stale = append(stale, id)
			continue
		}

		body, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
		if err != nil && !errors.Is(err, redis.Nil) {
			return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetAccessTokenRedisKey error: %v", err.Error()))
		}
		if body == "" {
			// The index still lists an access token that no longer exists.
			stale = append(stale, id)
			continue
		}

		var item entity.Token
		if err := json.Unmarshal([]byte(body), &item); err != nil {
			return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetAccessTokenRedisKey error: %v", err))
		}
		tokens = append(tokens, item)
	}

	if len(stale) > 0 {
		// Best-effort: a failed prune only leaves entries that later reads
		// re-check for expiry and existence.
		_ = t.DeleteUIDToken(ctx, uid, stale)
	}
	return tokens, nil
}
// DeleteUIDToken removes ids from the UID index of uid. It also drops every
// entry whose stored expiry is already in the past.
//
// The remaining index is written back with a 30-day TTL. If nothing remains,
// the key is deleted instead. The ctx parameter is currently unused.
func (t *tokenRepository) DeleteUIDToken(ctx context.Context, uid string, ids []string) error {
	key := domain.GetUIDTokenRedisKey(uid)

	raw, err := t.store.Get(key)
	if err != nil && !errors.Is(err, redis.Nil) {
		return wrapError("t.store.Get GetUIDTokenRedisKey error", err)
	}

	uidTokens := make(entity.UIDToken)
	if raw != "" {
		if err := json.Unmarshal([]byte(raw), &uidTokens); err != nil {
			return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err)
		}
	}

	now := time.Now().Unix()
	for id, expireAt := range uidTokens {
		if expireAt < now {
			delete(uidTokens, id)
		}
	}
	for _, id := range ids {
		delete(uidTokens, id)
	}

	if len(uidTokens) == 0 {
		// Nothing left to index: drop the key rather than storing "{}".
		if _, err := t.store.Del(key); err != nil {
			return wrapError("t.store.Del GetUIDTokenRedisKey error", err)
		}
		return nil
	}

	body, err := json.Marshal(uidTokens)
	if err != nil {
		return wrapError("json.Marshal UIDToken error", err)
	}
	// Setex, not SetnxEx: the key normally already exists, and a SET-if-not-
	// exists would silently skip the write and keep the deleted IDs indexed.
	if err := t.store.Setex(key, string(body), 86400*30); err != nil {
		return wrapError("t.store.Setex GetUIDTokenRedisKey error", err)
	}
	return nil
}
// GetAccessTokenByID loads and decodes the token stored under the access-token
// key for id. The context argument is ignored.
func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) {
	key := domain.GetAccessTokenRedisKey(id)
	return t.get(key)
}
// GetByRefresh resolves a refresh token to the token ID stored under its
// refresh-token key, then loads that token via GetAccessTokenByID.
//
// Redis failures are returned as ArkInternal. An unknown refresh token yields
// ResourceNotFound.
func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) {
	key := domain.RefreshTokenRedisKey.With(refreshToken).ToString()

	id, err := t.store.Get(key)
	// Classify the error before checking for emptiness. The previous early
	// `return err` made both branches below unreachable and leaked raw Redis
	// errors.
	if err != nil && !errors.Is(err, redis.Nil) {
		return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.GetByRefresh refresh token error: %v", err))
	}
	if id == "" {
		return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key)
	}
	return t.GetAccessTokenByID(ctx, id)
}
// DeleteOneTimeToken deletes refresh-token keys in one Redis pipeline. It first
// deletes a key for every entry of ids, then a key for the RefreshToken of every
// element of tokens, in that order.
func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
	pipeErr := t.store.Pipelined(func(pipe redis.Pipeliner) error {
		suffixes := make([]string, 0, len(ids)+len(tokens))
		suffixes = append(suffixes, ids...)
		for i := range tokens {
			suffixes = append(suffixes, tokens[i].RefreshToken)
		}

		for _, suffix := range suffixes {
			redisKey := domain.RefreshTokenRedisKey.With(suffix).ToString()
			if delErr := pipe.Del(ctx, redisKey).Err(); delErr != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", delErr))
			}
		}
		return nil
	})
	if pipeErr != nil {
		return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", pipeErr))
	}
	return nil
}
// CreateOneTimeToken stores ticket as JSON under the ticket key derived from
// key, using SetnxEx with a TTL of expires truncated to whole seconds. The ctx
// parameter is currently unused.
//
// NOTE(review): the boolean result of SetnxEx is discarded. If SetnxEx has
// SET-if-not-exists semantics, an existing key makes this return nil without
// writing the new ticket. Confirm that this is intended.
func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error {
	payload, marshalErr := json.Marshal(ticket)
	if marshalErr != nil {
		return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", marshalErr.Error())
	}

	ticketKey := domain.GetTicketRedisKey(key)
	ttlSeconds := int(expires.Seconds())
	if _, setErr := t.store.SetnxEx(ticketKey, string(payload), ttlSeconds); setErr != nil {
		return ers.DBError("CreateOneTimeToken store.set error:", setErr.Error())
	}
	return nil
}
// Delete removes the Redis records of token in one pipeline:
//
//   - the access-token key;
//   - the refresh-token key;
//   - when DeviceID is set, the DeviceID field of the UID device hash.
//
// The UID index is not modified here.
//
// NOTE(review): the device field is removed without checking that it still
// points at token.ID, so a newer token issued for the same device loses its
// device entry too. Confirm whether that is acceptable.
func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error {
	err := t.store.Pipelined(func(tx redis.Pipeliner) error {
		keys := []string{
			domain.GetAccessTokenRedisKey(token.ID),
			domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
		}
		for _, key := range keys {
			if err := tx.Del(ctx, key).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
			}
		}

		if token.DeviceID != "" {
			key := domain.DeviceTokenRedisKey.With(token.UID).ToString()
			// Queue the HDEL on tx (not t.store.Hdel). That way it is sent
			// with the rest of the pipeline instead of running immediately
			// on its own connection.
			if err := tx.HDel(ctx, key, token.DeviceID).Err(); err != nil {
				return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err))
			}
		}
		return nil
	})
	if err != nil {
		return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
	}
	return nil
}
// get loads the JSON-encoded token stored at key.
//
// It returns ArkInternal for Redis or decoding failures, and ResourceNotFound
// when the key is missing or empty.
func (t *tokenRepository) get(key string) (entity.Token, error) {
	body, err := t.store.Get(key)
	// Check for a real failure first. On error body is "", so testing
	// emptiness first would misreport Redis outages as "not found".
	if err != nil && !errors.Is(err, redis.Nil) {
		return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err))
	}
	if body == "" {
		return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key)
	}

	var token entity.Token
	if err := json.Unmarshal([]byte(body), &token); err != nil {
		// %v, not %w: Sprintf does not support error wrapping.
		return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %v", err))
	}
	return token, nil
}
// setToken queues, on the pipeline tx, a SET of the serialized token body
// under the token's access-token key with expiry rTTL.
func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error {
	accessKey := domain.GetAccessTokenRedisKey(token.ID)
	if setErr := tx.Set(ctx, accessKey, body, rTTL).Err(); setErr != nil {
		return wrapError("tx.Set GetAccessTokenRedisKey error", setErr)
	}
	return nil
}
// setRefreshToken queues, on the pipeline tx, a SET that maps the token's
// refresh-token key to token.ID with expiry rTTL. It does nothing when the
// token has no refresh token.
func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error {
	if token.RefreshToken == "" {
		return nil
	}

	refreshKey := domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()
	if setErr := tx.Set(ctx, refreshKey, token.ID, rTTL).Err(); setErr != nil {
		return wrapError("tx.Set RefreshToken error", setErr)
	}
	return nil
}
// setDeviceToken queues, on the pipeline tx, an HSET into the device hash keyed
// by token.UID. The field is token.DeviceID and the value is
// "<token.ID>-<Unix timestamp of AccessCreateAt+rTTL>". It then refreshes the
// hash's expiry to rTTL. It does nothing when the token has no device ID.
func (t *tokenRepository) setDeviceToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error {
	if token.DeviceID == "" {
		return nil
	}

	hashKey := domain.DeviceTokenRedisKey.With(token.UID).ToString()
	expiresAt := token.AccessCreateAt.Add(rTTL).Unix()
	entry := fmt.Sprintf("%s-%d", token.ID, expiresAt)

	if hsetErr := tx.HSet(ctx, hashKey, token.DeviceID, entry).Err(); hsetErr != nil {
		return wrapError("tx.HSet Device Token error", hsetErr)
	}
	if expireErr := tx.Expire(ctx, hashKey, rTTL).Err(); expireErr != nil {
		return wrapError("tx.Expire Device Token error", expireErr)
	}
	return nil
}
// SetUIDToken records token.ID in the UID index of token.UID. The stored value
// is token.RefreshTokenExpiresUnix().
//
// Index entries whose stored expiry has already passed are removed. The index
// is then rewritten with a 30-day TTL.
//
// NOTE(review): this is a plain Get-then-Setex with no WATCH or Lua script.
// Concurrent writers for the same UID can therefore overwrite each other's
// entries.
func (t *tokenRepository) SetUIDToken(token entity.Token) error {
	indexKey := domain.GetUIDTokenRedisKey(token.UID)

	raw, err := t.store.Get(indexKey)
	if err != nil && !errors.Is(err, redis.Nil) {
		return wrapError("t.store.Get GetUIDTokenRedisKey error", err)
	}

	index := make(entity.UIDToken)
	if raw != "" {
		if err = json.Unmarshal([]byte(raw), &index); err != nil {
			return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err)
		}
	}

	cutoff := time.Now().Unix()
	for tokenID, expireAt := range index {
		if expireAt < cutoff {
			delete(index, tokenID)
		}
	}
	index[token.ID] = token.RefreshTokenExpiresUnix()

	encoded, err := json.Marshal(index)
	if err != nil {
		return wrapError("json.Marshal UIDToken error", err)
	}
	if err = t.store.Setex(indexKey, string(encoded), 86400*30); err != nil {
		return wrapError("t.store.Setex GetUIDTokenRedisKey error", err)
	}
	return nil
}
// wrapError prefixes err with message, producing "message: err". The original
// error stays reachable through errors.Is, errors.As and Unwrap.
func wrapError(message string, err error) error {
	wrapped := fmt.Errorf("%s: %w", message, err)
	return wrapped
}