package repository import ( "app-cloudep-permission-server/internal/domain" "app-cloudep-permission-server/internal/domain/repository" "app-cloudep-permission-server/internal/entity" "context" "encoding/json" "errors" "fmt" "time" ers "code.30cm.net/digimon/library-go/errors" "github.com/zeromicro/go-zero/core/stores/redis" ) type TokenRepositoryParam struct { Store *redis.Redis `name:"redis"` } type tokenRepository struct { store *redis.Redis } func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { return &tokenRepository{ store: param.Store, } } func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error { body, err := json.Marshal(token) if err != nil { return ers.ArkInternal("json.Marshal token error", err.Error()) } if err := t.store.Pipelined(func(tx redis.Pipeliner) error { refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second if err := t.setToken(ctx, tx, token, body, refreshTTL); err != nil { return err } if err := t.setRefreshToken(ctx, tx, token, refreshTTL); err != nil { return err } return t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL) }); err != nil { return repository.RedisPipLineError(err.Error()) } return nil } func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { keys := []string{ domain.GetAccessTokenRedisKey(token.ID), domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), } if err := t.deleteKeys(ctx, keys...); err != nil { return repository.RedisPipLineError(err.Error()) } _, _ = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) return nil } func (t *tokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) { token, err := t.get(ctx, domain.GetAccessTokenRedisKey(id)) if err != nil { return entity.Token{}, err } return token, nil } func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { tokens, err := t.GetAccessTokensByUID(ctx, uid) if err != nil { return err } for _, token := range tokens { if err := t.Delete(ctx, token); err != nil { return err } } return nil } func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error { for _, tokenID := range ids { token, err := t.GetAccessTokenByID(ctx, tokenID) if err != nil { continue } keys := []string{ domain.GetAccessTokenRedisKey(token.ID), domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), } if err := t.deleteKeys(ctx, keys...); err != nil { continue } _, _ = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) } return nil } func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { return t.getTokensBySet(ctx, domain.GetUIDTokenRedisKey(uid)) } func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { return t.getTokensBySet(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString()) } func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID) if err != nil { return repository.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err)) } var keys []string for _, token := range tokens { keys = append(keys, domain.GetAccessTokenRedisKey(token.ID)) keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) } err = t.store.Pipelined(func(tx redis.Pipeliner) error { for _, token := range tokens { _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) } return nil }) if err != nil { return err } if err := t.deleteKeys(ctx, keys...); err != nil { return err } _, err = t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString()) return err } func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { return t.getCountBySet(domain.DeviceTokenRedisKey.With(deviceID).ToString()) } func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { return t.getCountBySet(domain.UIDTokenRedisKey.With(uid).ToString()) } func (t *tokenRepository) GetAccessTokenByByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) { id, err := t.store.Get(domain.RefreshTokenRedisKey.With(oneTimeToken).ToString()) if err != nil { return entity.Token{}, repository.RedisError(fmt.Sprintf("GetAccessTokenByByOneTimeToken store.Get error: %s", err.Error())) } if id == "" { return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(oneTimeToken).ToString()) } return t.GetAccessTokenByID(ctx, id) } func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error { var keys []string for _, id := range ids { keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString()) } for _, token := range tokens { keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) } return t.deleteKeys(ctx, keys...) } func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error { body, err := json.Marshal(ticket) if err != nil { return ers.InvalidFormat("CreateOneTimeToken json.Marshal error", err.Error()) } _, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds())) if err != nil { return repository.RedisError(fmt.Sprintf("CreateOneTimeToken store.SetnxEx error: %s", err.Error())) } return nil } // -------------------- Private area -------------------- func (t *tokenRepository) get(ctx context.Context, key string) (entity.Token, error) { body, err := t.store.GetCtx(ctx, key) if err != nil { return entity.Token{}, repository.RedisError(fmt.Sprintf("token %s not found in redis: %s", key, err.Error())) } if body == "" { return entity.Token{}, ers.ResourceNotFound("this token not found") } var token entity.Token if err := json.Unmarshal([]byte(body), &token); err != nil { return entity.Token{}, ers.ArkInternal("json.Unmarshal token error", err.Error()) } return token, nil } func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, ttl time.Duration) error { return tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, ttl).Err() } func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, ttl time.Duration) error { if token.RefreshToken != "" { return tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, ttl).Err() } return nil } func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error { if err := tx.SAdd(ctx, domain.UIDTokenRedisKey.With(uid).ToString(), tokenID).Err(); err != nil { return err } // 設置 UID 鍵的過期時間 if err := tx.Expire(ctx, domain.UIDTokenRedisKey.With(uid).ToString(), ttl).Err(); err != nil { return err } if err := tx.SAdd(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString(), tokenID).Err(); err != nil { return err } // 設置 deviceID 鍵的過期時間 if err := tx.Expire(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString(), ttl).Err(); err != nil { return err } return nil } func (t *tokenRepository) deleteKeys(ctx context.Context, keys ...string) error { return t.store.Pipelined(func(tx redis.Pipeliner) error { for _, key := range keys { if err := tx.Del(ctx, key).Err(); err != nil { return repository.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } } return nil }) } func (t *tokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) { ids, err := t.store.Smembers(setKey) if err != nil { if errors.Is(err, redis.Nil) { return nil, nil } return nil, repository.RedisError(fmt.Sprintf("getTokensBySet store.Get %s error: %v", setKey, err.Error())) } var tokens []entity.Token var deleteTokens []string now := time.Now().Unix() for _, id := range ids { token, err := t.get(ctx, domain.GetAccessTokenRedisKey(id)) if err != nil { deleteTokens = append(deleteTokens, id) continue } if int64(token.ExpiresIn) < now { deleteTokens = append(deleteTokens, id) continue } tokens = append(tokens, token) } if len(deleteTokens) > 0 { _ = t.DeleteAccessTokenByID(ctx, deleteTokens) } return tokens, nil } func (t *tokenRepository) getCountBySet(setKey string) (int, error) { count, err := t.store.Scard(setKey) if err != nil { return 0, err } return int(count), nil }