package repository import ( "ark-permission/internal/domain" "ark-permission/internal/domain/repository" "ark-permission/internal/entity" ers "code.30cm.net/wanderland/library-go/errors" "context" "encoding/json" "errors" "fmt" "time" "github.com/zeromicro/go-zero/core/stores/redis" ) type TokenRepositoryParam struct { Store *redis.Redis `name:"redis"` } type tokenRepository struct { store *redis.Redis } func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { return &tokenRepository{ store: param.Store, } } func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error { body, err := json.Marshal(token) if err != nil { return ers.ArkInternal("json.Marshal token error", err.Error()) } err = t.store.Pipelined(func(tx redis.Pipeliner) error { // rTTL := token.RedisExpiredSec() refreshTTL := token.RedisRefreshExpiredSec() if err := t.setToken(ctx, tx, token, body, time.Duration(refreshTTL)*time.Second); err != nil { return err } if err := t.setRefreshToken(ctx, tx, token, time.Duration(refreshTTL)*time.Second); err != nil { return err } err := t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, time.Duration(refreshTTL)*time.Second) if err != nil { return err } return nil }) if err != nil { return domain.RedisPipLineError(err.Error()) } return nil } func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { err := t.store.Pipelined(func(tx redis.Pipeliner) error { keys := []string{ domain.GetAccessTokenRedisKey(token.ID), domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), domain.UIDTokenRedisKey.With(token.UID).ToString(), } for _, key := range keys { if err := tx.Del(ctx, key).Err(); err != nil { return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } } if token.DeviceID != "" { key := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString() _, err := t.store.Del(key) if err != nil { return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) } } return nil }) if err != nil { return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) } return nil } func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { return t.get(domain.GetAccessTokenRedisKey(id)) } func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { tokens, err := t.GetAccessTokensByUID(ctx, uid) if err != nil { return err } for _, item := range tokens { err := t.Delete(ctx, item) if err != nil { return err } } return nil } // DeleteAccessTokenByID TODO 要做錯誤處理 func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error { for _, tokenID := range ids { token, err := t.GetAccessTokenByID(ctx, tokenID) if err != nil { continue } err = t.store.Pipelined(func(tx redis.Pipeliner) error { keys := []string{ domain.GetAccessTokenRedisKey(token.ID), domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), } for _, key := range keys { if err := tx.Del(ctx, key).Err(); err != nil { return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } } _, err = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) if err != nil { return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) } _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) if err != nil { return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) } return nil }) if err != nil { continue } } return nil } // GetAccessTokensByUID 透過 uid 得到目前未過期的 token func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid)) if err != nil { // 沒有就視為回空 if errors.Is(err, redis.Nil) { return nil, nil } return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error())) } now := time.Now().UTC() var tokens []entity.Token var deleteToken []string for _, id := range utKeys { item := &entity.Token{} tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) if err == nil { err = json.Unmarshal([]byte(tk), item) if err != nil { return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) } tokens = append(tokens, *item) } if errors.Is(err, redis.Nil) { deleteToken = append(deleteToken, id) } if int64(item.ExpiresIn) < now.Unix() { deleteToken = append(deleteToken, id) continue } } if len(deleteToken) > 0 { // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 _ = t.DeleteAccessTokenByID(ctx, deleteToken) } return tokens, nil } func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) { id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString()) if err != nil { return entity.Token{}, err } if errors.Is(err, redis.Nil) || id == "" { return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(refreshToken).ToString()) } if err != nil { return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.GetByRefresh refresh token error: %v", err)) } return t.GetAccessTokenByID(ctx, id) } func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error { err := t.store.Pipelined(func(tx redis.Pipeliner) error { keys := make([]string, 0, len(ids)+len(tokens)) for _, id := range ids { keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString()) } for _, token := range tokens { keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) } for _, key := range keys { if err := tx.Del(ctx, key).Err(); err != nil { return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } } return nil }) if err != nil { return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) } return nil } func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error { body, err := json.Marshal(ticket) if err != nil { return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error()) } _, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds())) if err != nil { return ers.DBError("CreateOneTimeToken store.set error:", err.Error()) } return nil } func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString()) if err != nil { // 沒有就視為回空 if errors.Is(err, redis.Nil) { return nil, nil } return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error())) } now := time.Now().UTC() var tokens []entity.Token var deleteToken []string for _, id := range utKeys { item := &entity.Token{} tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) if err == nil { err = json.Unmarshal([]byte(tk), item) if err != nil { return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) } tokens = append(tokens, *item) } if errors.Is(err, redis.Nil) { deleteToken = append(deleteToken, id) } if int64(item.ExpiresIn) < now.Unix() { deleteToken = append(deleteToken, id) continue } } if len(deleteToken) > 0 { // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 _ = t.DeleteAccessTokenByID(ctx, deleteToken) } return tokens, nil } func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID) if err != nil { return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err)) } err = t.store.Pipelined(func(tx redis.Pipeliner) error { for _, token := range tokens { if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil { return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil { return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) if err != nil { return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) } } _, err := t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString()) if err != nil { return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) } return nil }) if err != nil { return err } return nil } func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { count, err := t.store.Scard(domain.DeviceTokenRedisKey.With(deviceID).ToString()) if err != nil { return 0, err } return int(count), nil } func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { count, err := t.store.Scard(domain.UIDTokenRedisKey.With(uid).ToString()) if err != nil { return 0, err } return int(count), nil } // -------------------- Private area -------------------- func (t *tokenRepository) get(key string) (entity.Token, error) { body, err := t.store.Get(key) if errors.Is(err, redis.Nil) || body == "" { return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key) } if err != nil { return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err)) } var token entity.Token if err := json.Unmarshal([]byte(body), &token); err != nil { return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %w", err)) } return token, nil } func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error { err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err() if err != nil { return wrapError("tx.Set GetAccessTokenRedisKey error", err) } return nil } func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { if token.RefreshToken != "" { err := tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, rTTL).Err() if err != nil { return wrapError("tx.Set RefreshToken error", err) } } return nil } func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error { uidKey := domain.UIDTokenRedisKey.With(uid).ToString() err := tx.SAdd(ctx, uidKey, tokenID).Err() if err != nil { return err } err = tx.Expire(ctx, uidKey, rttl).Err() if err != nil { return err } deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString() err = tx.SAdd(ctx, deviceKey, tokenID).Err() if err != nil { return err } err = tx.Expire(ctx, deviceKey, rttl).Err() if err != nil { return err } return nil } // SetUIDToken 將 token 資料放進 uid key中 func (t *tokenRepository) SetUIDToken(token entity.Token) error { uidTokens := make(entity.UIDToken) b, err := t.store.Get(domain.GetUIDTokenRedisKey(token.UID)) if err != nil && !errors.Is(err, redis.Nil) { return wrapError("t.store.Get GetUIDTokenRedisKey error", err) } if b != "" { err = json.Unmarshal([]byte(b), &uidTokens) if err != nil { return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err) } } now := time.Now().Unix() for k, t := range uidTokens { if t < now { delete(uidTokens, k) } } uidTokens[token.ID] = token.RefreshTokenExpiresUnix() s, err := json.Marshal(uidTokens) if err != nil { return wrapError("json.Marshal UIDToken error", err) } err = t.store.Setex(domain.GetUIDTokenRedisKey(token.UID), string(s), 86400*30) if err != nil { return wrapError("t.store.Setex GetUIDTokenRedisKey error", err) } return nil } func wrapError(message string, err error) error { return fmt.Errorf("%s: %w", message, err) }