backend/pkg/permission/repository/token.go

146 lines
3.6 KiB
Go
Raw Normal View History

2025-10-03 08:38:12 +00:00
package repository
import (
"backend/pkg/library/errs"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"context"
"github.com/zeromicro/go-zero/core/stores/redis"
"strings"
"time"
)
// Token repository implementation (Redis-backed).

// TokenRepositoryParam holds the dependencies passed to NewTokenRepository.
type TokenRepositoryParam struct {
// Redis is the client the repository will use for all token reads and writes.
Redis *redis.Redis
}
// TokenRepository keeps tokens in Redis. Create writes the keys
// "token:access:<AccessToken>" and "token:refresh:<RefreshToken>", both
// holding the value "<UID>:<ClientID>:<DeviceID>"; the Get methods read
// those keys back and split the value on ':'.
type TokenRepository struct {
// Redis is the go-zero client used for every token read and write.
Redis *redis.Redis
}
// NewTokenRepository creates a token repository backed by the Redis client
// in param. The client is stored as given; nothing is contacted here.
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
	repo := new(TokenRepository)
	repo.Redis = param.Redis
	return repo
}
// Create stores token in Redis under two keys that share the value
// "<UID>:<ClientID>:<DeviceID>":
//
//   - "token:access:<AccessToken>", expiring when token.ExpiresAt is reached;
//   - "token:refresh:<RefreshToken>", living seven times as long.
//
// CreateTime and UpdateTime are both set to the same current timestamp.
//
// Errors:
//   - InvalidFormat if token fails Validate, if UID, ClientID or DeviceID
//     contains ':' (such a value could not be parsed back by the Get
//     methods), or if ExpiresAt is less than one second away.
//   - DatabaseErr if either Redis write fails. When the refresh-token write
//     fails, the access-token key written just before is removed on a
//     best-effort basis so no half-created token is left behind.
func (r *TokenRepository) Create(ctx context.Context, token *entity.Token) error {
	if err := token.Validate(); err != nil {
		return errs.InvalidFormat(err.Error())
	}
	// The stored value is ':'-joined and split back into exactly three parts
	// on read, so a ':' inside any field would make the token unreadable
	// after an apparently successful create.
	if strings.Contains(token.UID, ":") || strings.Contains(token.ClientID, ":") || strings.Contains(token.DeviceID, ":") {
		return errs.InvalidFormat("uid, client_id and device_id must not contain ':'")
	}

	now := time.Now()
	token.CreateTime = now
	token.UpdateTime = now

	// SETEX takes whole seconds; less than one second remaining counts as expired.
	expiry := int(token.ExpiresAt.Sub(now).Seconds())
	if expiry <= 0 {
		return errs.InvalidFormat("token already expired")
	}
	// The refresh token outlives the access token by a fixed factor of 7.
	refreshExpiry := expiry * 7

	accessKey := "token:access:" + token.AccessToken
	refreshKey := "token:refresh:" + token.RefreshToken
	value := token.UID + ":" + token.ClientID + ":" + token.DeviceID

	if err := r.Redis.SetexCtx(ctx, accessKey, value, expiry); err != nil {
		return errs.DatabaseErr(err.Error())
	}
	if err := r.Redis.SetexCtx(ctx, refreshKey, value, refreshExpiry); err != nil {
		// Best-effort rollback; the refresh-token failure is the error that
		// is reported. If ctx is already cancelled this delete fails too and
		// the access key simply lapses on its own TTL.
		_, _ = r.Redis.DelCtx(ctx, accessKey)
		return errs.DatabaseErr(err.Error())
	}
	return nil
}
// GetByAccessToken looks up the token stored under "token:access:<accessToken>".
//
// The stored value is "<UID>:<ClientID>:<DeviceID>"; the returned Token has
// those three fields plus AccessToken populated. Nothing else is kept under
// this key, so all other fields of the returned Token are zero.
//
// Errors: NotFound("access_token") when the key does not exist (including
// after it has expired), InvalidFormat when the stored value is not exactly
// three ':'-separated parts, and DatabaseErr on any other Redis failure.
func (r *TokenRepository) GetByAccessToken(ctx context.Context, accessToken string) (*entity.Token, error) {
	value, err := r.Redis.GetCtx(ctx, "token:access:"+accessToken)
	if err != nil {
		// Kept defensively; go-zero normally maps redis.Nil to ("", nil).
		if err == redis.Nil {
			return nil, errs.NotFound("access_token")
		}
		return nil, errs.DatabaseErr(err.Error())
	}
	// go-zero's GetCtx reports a missing key as an empty string with a nil
	// error. A stored value always contains two ':' separators and so is
	// never empty, which makes "" an unambiguous "not found".
	if value == "" {
		return nil, errs.NotFound("access_token")
	}

	parts := strings.Split(value, ":")
	if len(parts) != 3 {
		return nil, errs.InvalidFormat("invalid token format")
	}
	return &entity.Token{
		UID:         parts[0],
		ClientID:    parts[1],
		DeviceID:    parts[2],
		AccessToken: accessToken,
	}, nil
}
// GetByRefreshToken looks up the token stored under "token:refresh:<refreshToken>".
//
// The stored value is "<UID>:<ClientID>:<DeviceID>"; the returned Token has
// those three fields plus RefreshToken populated. Nothing else is kept under
// this key, so all other fields of the returned Token are zero.
//
// Errors: NotFound("refresh_token") when the key does not exist (including
// after it has expired), InvalidFormat when the stored value is not exactly
// three ':'-separated parts, and DatabaseErr on any other Redis failure.
func (r *TokenRepository) GetByRefreshToken(ctx context.Context, refreshToken string) (*entity.Token, error) {
	value, err := r.Redis.GetCtx(ctx, "token:refresh:"+refreshToken)
	if err != nil {
		// Kept defensively; go-zero normally maps redis.Nil to ("", nil).
		if err == redis.Nil {
			return nil, errs.NotFound("refresh_token")
		}
		return nil, errs.DatabaseErr(err.Error())
	}
	// go-zero's GetCtx reports a missing key as an empty string with a nil
	// error. A stored value always contains two ':' separators and so is
	// never empty, which makes "" an unambiguous "not found".
	if value == "" {
		return nil, errs.NotFound("refresh_token")
	}

	parts := strings.Split(value, ":")
	if len(parts) != 3 {
		return nil, errs.InvalidFormat("invalid token format")
	}
	return &entity.Token{
		UID:          parts[0],
		ClientID:     parts[1],
		DeviceID:     parts[2],
		RefreshToken: refreshToken,
	}, nil
}
// Update rewrites the access-token key "token:access:<AccessToken>" with the
// value "<UID>:<ClientID>:<DeviceID>" and a TTL recomputed from
// token.ExpiresAt, and sets token.UpdateTime to the current time. The
// refresh-token key is not touched.
//
// Errors: InvalidFormat if token fails Validate, if UID, ClientID or
// DeviceID contains ':' (the Get methods could not parse the value back),
// or if ExpiresAt is less than one second away; DatabaseErr if the Redis
// write fails.
func (r *TokenRepository) Update(ctx context.Context, token *entity.Token) error {
	if err := token.Validate(); err != nil {
		return errs.InvalidFormat(err.Error())
	}
	// The stored value is ':'-joined and split back into exactly three parts
	// on read, so a ':' inside any field would make the token unreadable.
	if strings.Contains(token.UID, ":") || strings.Contains(token.ClientID, ":") || strings.Contains(token.DeviceID, ":") {
		return errs.InvalidFormat("uid, client_id and device_id must not contain ':'")
	}

	now := time.Now()
	token.UpdateTime = now

	// SETEX takes whole seconds; less than one second remaining counts as expired.
	expiry := int(token.ExpiresAt.Sub(now).Seconds())
	if expiry <= 0 {
		return errs.InvalidFormat("token already expired")
	}

	accessKey := "token:access:" + token.AccessToken
	value := token.UID + ":" + token.ClientID + ":" + token.DeviceID
	if err := r.Redis.SetexCtx(ctx, accessKey, value, expiry); err != nil {
		return errs.DatabaseErr(err.Error())
	}
	return nil
}
// Delete is intentionally a no-op in this Redis-backed implementation and
// always returns nil: tokens are keyed by their access/refresh token
// strings, not by an ObjectID, so there is nothing to look up by id.
//
// NOTE(review): bson is not listed in this file's import block; confirm the
// import exists, or this file will not compile.
func (r *TokenRepository) Delete(ctx context.Context, id bson.ObjectID) error {
	// Nothing to do: see the doc comment above.
	return nil
}
// DeleteByUserID is currently a simplified stub: it performs no Redis
// operation and always returns nil.
//
// NOTE(review): tokens are stored only under token-string keys with no
// per-user index, so revoking every token for uid would require scanning
// the "token:access:*" / "token:refresh:*" keys or maintaining a uid index
// in Create. Until then, callers relying on this for revocation get no effect.
func (r *TokenRepository) DeleteByUserID(ctx context.Context, uid string) error {
	return nil
}