package repository import ( "ark-permission/internal/domain" "ark-permission/internal/domain/repository" "ark-permission/internal/entity" ers "code.30cm.net/wanderland/library-go/errors" "context" "encoding/json" "errors" "fmt" "time" "github.com/zeromicro/go-zero/core/stores/redis" ) type TokenRepositoryParam struct { Store *redis.Redis `name:"redis"` } type tokenRepository struct { store *redis.Redis } func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { return &tokenRepository{ store: param.Store, } } func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error { body, err := json.Marshal(token) if err != nil { return wrapError("json.Marshal token error", err) } err = t.store.Pipelined(func(tx redis.Pipeliner) error { rTTL := token.RefreshTokenExpires() if err := t.setToken(ctx, tx, token, body, rTTL); err != nil { return err } if err := t.setRefreshToken(ctx, tx, token, rTTL); err != nil { return err } if err := t.setDeviceToken(ctx, tx, token, rTTL); err != nil { return err } return nil }) if err != nil { return wrapError("store.Pipelined error", err) } if err := t.SetUIDToken(token); err != nil { return wrapError("SetUIDToken error", err) } return nil } func (t *tokenRepository) GetByAccess(_ context.Context, id string) (entity.Token, error) { return t.get(domain.GetAccessTokenRedisKey(id)) } func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { err := t.store.Pipelined(func(tx redis.Pipeliner) error { keys := []string{ domain.GetAccessTokenRedisKey(token.ID), domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), } for _, key := range keys { if err := tx.Del(ctx, key).Err(); err != nil { return fmt.Errorf("store.Del key error: %w", err) } } if token.DeviceID != "" { key := domain.DeviceTokenRedisKey.With(token.UID).ToString() _, err := t.store.Hdel(key, token.DeviceID) if err != nil { return fmt.Errorf("store.HDel deviceKey error: %w", err) } } return nil }) if err != nil { return fmt.Errorf("store.Pipelined error: %w", err) } return nil } func (t *tokenRepository) get(key string) (entity.Token, error) { body, err := t.store.Get(key) if errors.Is(err, redis.Nil) { return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key) } if err != nil { return entity.Token{}, fmt.Errorf("store.Get tokenTag error: %w", err) } var token entity.Token if err := json.Unmarshal([]byte(body), &token); err != nil { return entity.Token{}, fmt.Errorf("json.Unmarshal token error: %w", err) } return token, nil } func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error { err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err() if err != nil { return wrapError("tx.Set GetAccessTokenRedisKey error", err) } return nil } func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { if token.RefreshToken != "" { err := tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, rTTL).Err() if err != nil { return wrapError("tx.Set RefreshToken error", err) } } return nil } func (t *tokenRepository) setDeviceToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { if token.DeviceID != "" { key := domain.DeviceTokenRedisKey.With(token.UID).ToString() value := fmt.Sprintf("%s-%d", token.ID, token.AccessCreateAt.Add(rTTL).Unix()) err := tx.HSet(ctx, key, token.DeviceID, value).Err() if err != nil { return wrapError("tx.HSet Device Token error", err) } err = tx.Expire(ctx, key, rTTL).Err() if err != nil { return wrapError("tx.Expire Device Token error", err) } } return nil } // SetUIDToken 將 token 資料放進 uid key中 func (t *tokenRepository) SetUIDToken(token entity.Token) error { uidTokens := make(entity.UIDToken) b, err := t.store.Get(domain.GetUIDTokenRedisKey(token.UID)) if err != nil && !errors.Is(err, redis.Nil) { return wrapError("t.store.Get GetUIDTokenRedisKey error", err) } if b != "" { err = json.Unmarshal([]byte(b), &uidTokens) if err != nil { return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err) } } now := time.Now().Unix() for k, t := range uidTokens { if t < now { delete(uidTokens, k) } } uidTokens[token.ID] = token.RefreshTokenExpiresUnix() s, err := json.Marshal(uidTokens) if err != nil { return wrapError("json.Marshal UIDToken error", err) } err = t.store.Setex(domain.GetUIDTokenRedisKey(token.UID), string(s), 86400*30) if err != nil { return wrapError("t.store.Setex GetUIDTokenRedisKey error", err) } return nil } func wrapError(message string, err error) error { return fmt.Errorf("%s: %w", message, err) }