146 lines
3.6 KiB
Go
146 lines
3.6 KiB
Go
|
package repository
|
|||
|
|
|||
|
import (
|
|||
|
"backend/pkg/library/errs"
|
|||
|
"backend/pkg/permission/domain/entity"
|
|||
|
"backend/pkg/permission/domain/repository"
|
|||
|
"context"
|
|||
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
|||
|
"strings"
|
|||
|
"time"
|
|||
|
)
|
|||
|
|
|||
|
// Token Repository Implementation
|
|||
|
|
|||
|
type TokenRepositoryParam struct {
|
|||
|
Redis *redis.Redis
|
|||
|
}
|
|||
|
|
|||
|
type TokenRepository struct {
|
|||
|
Redis *redis.Redis
|
|||
|
}
|
|||
|
|
|||
|
// NewTokenRepository 創建令牌倉庫實例
|
|||
|
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
|
|||
|
return &TokenRepository{
|
|||
|
Redis: param.Redis,
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func (r *TokenRepository) Create(ctx context.Context, token *entity.Token) error {
|
|||
|
// 驗證數據
|
|||
|
if err := token.Validate(); err != nil {
|
|||
|
return errs.InvalidFormat(err.Error())
|
|||
|
}
|
|||
|
|
|||
|
token.CreateTime = time.Now()
|
|||
|
token.UpdateTime = time.Now()
|
|||
|
|
|||
|
// 在 Redis 中存儲 access token
|
|||
|
accessKey := "token:access:" + token.AccessToken
|
|||
|
refreshKey := "token:refresh:" + token.RefreshToken
|
|||
|
|
|||
|
// 設置過期時間
|
|||
|
expiry := int(time.Until(token.ExpiresAt).Seconds())
|
|||
|
if expiry <= 0 {
|
|||
|
return errs.InvalidFormat("token already expired")
|
|||
|
}
|
|||
|
|
|||
|
// 存儲 access token
|
|||
|
err := r.Redis.SetexCtx(ctx, accessKey, token.UID+":"+token.ClientID+":"+token.DeviceID, expiry)
|
|||
|
if err != nil {
|
|||
|
return errs.DatabaseErr(err.Error())
|
|||
|
}
|
|||
|
|
|||
|
// 存儲 refresh token (較長的過期時間)
|
|||
|
refreshExpiry := expiry * 7 // refresh token 過期時間是 access token 的 7 倍
|
|||
|
err = r.Redis.SetexCtx(ctx, refreshKey, token.UID+":"+token.ClientID+":"+token.DeviceID, refreshExpiry)
|
|||
|
if err != nil {
|
|||
|
return errs.DatabaseErr(err.Error())
|
|||
|
}
|
|||
|
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
func (r *TokenRepository) GetByAccessToken(ctx context.Context, accessToken string) (*entity.Token, error) {
|
|||
|
key := "token:access:" + accessToken
|
|||
|
value, err := r.Redis.GetCtx(ctx, key)
|
|||
|
if err != nil {
|
|||
|
if err == redis.Nil {
|
|||
|
return nil, errs.NotFound("access_token")
|
|||
|
}
|
|||
|
return nil, errs.DatabaseErr(err.Error())
|
|||
|
}
|
|||
|
|
|||
|
// 解析值
|
|||
|
parts := strings.Split(value, ":")
|
|||
|
if len(parts) != 3 {
|
|||
|
return nil, errs.InvalidFormat("invalid token format")
|
|||
|
}
|
|||
|
|
|||
|
return &entity.Token{
|
|||
|
UID: parts[0],
|
|||
|
ClientID: parts[1],
|
|||
|
DeviceID: parts[2],
|
|||
|
AccessToken: accessToken,
|
|||
|
}, nil
|
|||
|
}
|
|||
|
|
|||
|
func (r *TokenRepository) GetByRefreshToken(ctx context.Context, refreshToken string) (*entity.Token, error) {
|
|||
|
key := "token:refresh:" + refreshToken
|
|||
|
value, err := r.Redis.GetCtx(ctx, key)
|
|||
|
if err != nil {
|
|||
|
if err == redis.Nil {
|
|||
|
return nil, errs.NotFound("refresh_token")
|
|||
|
}
|
|||
|
return nil, errs.DatabaseErr(err.Error())
|
|||
|
}
|
|||
|
|
|||
|
// 解析值
|
|||
|
parts := strings.Split(value, ":")
|
|||
|
if len(parts) != 3 {
|
|||
|
return nil, errs.InvalidFormat("invalid token format")
|
|||
|
}
|
|||
|
|
|||
|
return &entity.Token{
|
|||
|
UID: parts[0],
|
|||
|
ClientID: parts[1],
|
|||
|
DeviceID: parts[2],
|
|||
|
RefreshToken: refreshToken,
|
|||
|
}, nil
|
|||
|
}
|
|||
|
|
|||
|
func (r *TokenRepository) Update(ctx context.Context, token *entity.Token) error {
|
|||
|
// 驗證數據
|
|||
|
if err := token.Validate(); err != nil {
|
|||
|
return errs.InvalidFormat(err.Error())
|
|||
|
}
|
|||
|
|
|||
|
token.UpdateTime = time.Now()
|
|||
|
|
|||
|
// 重新存儲 access token
|
|||
|
accessKey := "token:access:" + token.AccessToken
|
|||
|
expiry := int(time.Until(token.ExpiresAt).Seconds())
|
|||
|
if expiry <= 0 {
|
|||
|
return errs.InvalidFormat("token already expired")
|
|||
|
}
|
|||
|
|
|||
|
err := r.Redis.SetexCtx(ctx, accessKey, token.UID+":"+token.ClientID+":"+token.DeviceID, expiry)
|
|||
|
if err != nil {
|
|||
|
return errs.DatabaseErr(err.Error())
|
|||
|
}
|
|||
|
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
func (r *TokenRepository) Delete(ctx context.Context, id bson.ObjectID) error {
|
|||
|
// Redis 版本不需要 ObjectID,這裡留空實現
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
func (r *TokenRepository) DeleteByUserID(ctx context.Context, uid string) error {
|
|||
|
// 可以實現刪除用戶所有 token 的邏輯
|
|||
|
// 這裡簡化實現
|
|||
|
return nil
|
|||
|
}
|