backend/pkg/permission/repository/token.go

146 lines
3.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"backend/pkg/library/errs"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"context"
"github.com/zeromicro/go-zero/core/stores/redis"
"strings"
"time"
)
// TokenRepositoryParam bundles the dependencies handed to NewTokenRepository.
type TokenRepositoryParam struct {
	// Redis is the client the constructed repository will use for storage.
	Redis *redis.Redis
}
// TokenRepository is the Redis-backed token store returned by
// NewTokenRepository as a repository.TokenRepository.
type TokenRepository struct {
	// Redis holds the "token:access:*" and "token:refresh:*" keys.
	Redis *redis.Redis
}
// NewTokenRepository creates a token repository instance backed by the Redis
// client carried in param.
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
	repo := &TokenRepository{}
	repo.Redis = param.Redis
	return repo
}
// Create persists token in Redis under two keys:
//   - "token:access:<AccessToken>", expiring at token.ExpiresAt;
//   - "token:refresh:<RefreshToken>", kept 7x as long as the access key.
//
// Both keys hold "<UID>:<ClientID>:<DeviceID>", which GetByAccessToken and
// GetByRefreshToken split back apart on ':'.
//
// On success token.CreateTime and token.UpdateTime are set to the same instant.
// It returns errs.InvalidFormat when the token fails Validate, is already
// expired, or has a UID/ClientID/DeviceID containing ':' (such a value could be
// written but never read back), and errs.DatabaseErr when a Redis write fails.
func (r *TokenRepository) Create(ctx context.Context, token *entity.Token) error {
	if err := token.Validate(); err != nil {
		return errs.InvalidFormat(err.Error())
	}
	// The stored value is ':'-delimited and the readers require exactly three
	// parts, so a ':' inside any field would make the token unreadable.
	if strings.Contains(token.UID, ":") || strings.Contains(token.ClientID, ":") || strings.Contains(token.DeviceID, ":") {
		return errs.InvalidFormat("uid, client_id and device_id must not contain ':'")
	}

	// Seconds until access-token expiry; SETEX needs a positive TTL.
	expiry := int(time.Until(token.ExpiresAt).Seconds())
	if expiry <= 0 {
		return errs.InvalidFormat("token already expired")
	}

	// Stamp timestamps only once the token is known to be storable, and use a
	// single instant so a freshly created token has CreateTime == UpdateTime.
	now := time.Now()
	token.CreateTime = now
	token.UpdateTime = now

	accessKey := "token:access:" + token.AccessToken
	refreshKey := "token:refresh:" + token.RefreshToken
	value := token.UID + ":" + token.ClientID + ":" + token.DeviceID

	if err := r.Redis.SetexCtx(ctx, accessKey, value, expiry); err != nil {
		return errs.DatabaseErr(err.Error())
	}

	// The refresh token lives 7x longer than the access token.
	refreshExpiry := expiry * 7
	if err := r.Redis.SetexCtx(ctx, refreshKey, value, refreshExpiry); err != nil {
		// Roll back the access key so a failed Create does not leave a usable
		// access token behind without its refresh token. Best-effort: the
		// write error is what the caller needs to see, so a failed rollback
		// is deliberately ignored.
		_, _ = r.Redis.DelCtx(ctx, accessKey)
		return errs.DatabaseErr(err.Error())
	}
	return nil
}
// GetByAccessToken looks up the owner of accessToken in
// "token:access:<accessToken>".
//
// The returned Token carries only UID, ClientID, DeviceID and AccessToken;
// every other field is left at its zero value.
// It returns errs.NotFound("access_token") when the key is absent or expired,
// errs.InvalidFormat when the stored value is not "<uid>:<client>:<device>",
// and errs.DatabaseErr on any other Redis error.
func (r *TokenRepository) GetByAccessToken(ctx context.Context, accessToken string) (*entity.Token, error) {
	value, err := r.Redis.GetCtx(ctx, "token:access:"+accessToken)
	if err != nil {
		if err == redis.Nil {
			return nil, errs.NotFound("access_token")
		}
		return nil, errs.DatabaseErr(err.Error())
	}
	// go-zero's GetCtx reports a missing key as ("", nil) rather than
	// redis.Nil, so an empty value is the real not-found signal. Values
	// written by this repository always contain two ':' and are never empty.
	if value == "" {
		return nil, errs.NotFound("access_token")
	}

	parts := strings.Split(value, ":")
	if len(parts) != 3 {
		return nil, errs.InvalidFormat("invalid token format")
	}
	return &entity.Token{
		UID:         parts[0],
		ClientID:    parts[1],
		DeviceID:    parts[2],
		AccessToken: accessToken,
	}, nil
}
// GetByRefreshToken looks up the owner of refreshToken in
// "token:refresh:<refreshToken>".
//
// The returned Token carries only UID, ClientID, DeviceID and RefreshToken;
// every other field is left at its zero value.
// It returns errs.NotFound("refresh_token") when the key is absent or expired,
// errs.InvalidFormat when the stored value is not "<uid>:<client>:<device>",
// and errs.DatabaseErr on any other Redis error.
func (r *TokenRepository) GetByRefreshToken(ctx context.Context, refreshToken string) (*entity.Token, error) {
	value, err := r.Redis.GetCtx(ctx, "token:refresh:"+refreshToken)
	if err != nil {
		if err == redis.Nil {
			return nil, errs.NotFound("refresh_token")
		}
		return nil, errs.DatabaseErr(err.Error())
	}
	// go-zero's GetCtx reports a missing key as ("", nil) rather than
	// redis.Nil, so an empty value is the real not-found signal. Values
	// written by this repository always contain two ':' and are never empty.
	if value == "" {
		return nil, errs.NotFound("refresh_token")
	}

	parts := strings.Split(value, ":")
	if len(parts) != 3 {
		return nil, errs.InvalidFormat("invalid token format")
	}
	return &entity.Token{
		UID:          parts[0],
		ClientID:     parts[1],
		DeviceID:     parts[2],
		RefreshToken: refreshToken,
	}, nil
}
// Update rewrites "token:access:<AccessToken>" with the token's owner
// ("<UID>:<ClientID>:<DeviceID>") and resets its TTL to token.ExpiresAt.
// The refresh-token key is not touched.
//
// token.UpdateTime is set only after the write succeeds.
// It returns errs.InvalidFormat when the token fails Validate, is already
// expired, or has a UID/ClientID/DeviceID containing ':' (such a value could be
// written but never read back), and errs.DatabaseErr when the Redis write fails.
func (r *TokenRepository) Update(ctx context.Context, token *entity.Token) error {
	if err := token.Validate(); err != nil {
		return errs.InvalidFormat(err.Error())
	}
	// The stored value is ':'-delimited and the readers require exactly three
	// parts, so a ':' inside any field would make the token unreadable.
	if strings.Contains(token.UID, ":") || strings.Contains(token.ClientID, ":") || strings.Contains(token.DeviceID, ":") {
		return errs.InvalidFormat("uid, client_id and device_id must not contain ':'")
	}

	// Seconds until access-token expiry; SETEX needs a positive TTL.
	expiry := int(time.Until(token.ExpiresAt).Seconds())
	if expiry <= 0 {
		return errs.InvalidFormat("token already expired")
	}

	accessKey := "token:access:" + token.AccessToken
	value := token.UID + ":" + token.ClientID + ":" + token.DeviceID
	if err := r.Redis.SetexCtx(ctx, accessKey, value, expiry); err != nil {
		return errs.DatabaseErr(err.Error())
	}

	// Only advance UpdateTime once the new state is actually persisted.
	token.UpdateTime = time.Now()
	return nil
}
// Delete is a no-op that always reports success. Tokens in this repository
// are keyed by their token strings, not by an ObjectID, so there is nothing to
// look up by id; the method presumably exists only to satisfy
// repository.TokenRepository — TODO confirm against the interface.
//
// NOTE(review): bson is not in this file's import block, so this signature
// does not compile as written; add the mongo-driver bson import that matches
// the interface (e.g. go.mongodb.org/mongo-driver/v2/bson) — TODO confirm
// which driver version the interface uses.
func (r *TokenRepository) Delete(ctx context.Context, id bson.ObjectID) error {
	return nil
}
// DeleteByUserID revokes every access and refresh token issued to uid.
//
// Token keys are "token:access:<tok>" / "token:refresh:<tok>" with value
// "<uid>:<clientID>:<deviceID>", and there is no per-user index, so this walks
// both key families with SCAN and deletes each key whose value belongs to uid.
// Cost is proportional to the total number of stored tokens, not to uid's.
// Keys are read and deleted one at a time so that no multi-key command can
// span hash slots.
//
// NOTE(review): with a Redis Cluster client SCAN may cover only one node —
// confirm the deployment topology, or maintain a per-user key index instead.
//
// It returns errs.InvalidFormat for an empty uid and errs.DatabaseErr on any
// Redis failure; tokens deleted before a failure stay deleted.
func (r *TokenRepository) DeleteByUserID(ctx context.Context, uid string) error {
	if uid == "" {
		return errs.InvalidFormat("uid is required")
	}
	// SCAN COUNT hint: roughly how many keys to examine per round trip.
	const scanBatch = 100
	owner := uid + ":"

	for _, pattern := range []string{"token:access:*", "token:refresh:*"} {
		var cursor uint64
		for {
			keys, next, err := r.Redis.ScanCtx(ctx, cursor, pattern, scanBatch)
			if err != nil {
				return errs.DatabaseErr(err.Error())
			}
			for _, key := range keys {
				value, err := r.Redis.GetCtx(ctx, key)
				if err != nil && err != redis.Nil {
					return errs.DatabaseErr(err.Error())
				}
				// A key that expired between SCAN and GET yields "" and is skipped.
				if !strings.HasPrefix(value, owner) {
					continue
				}
				if _, err := r.Redis.DelCtx(ctx, key); err != nil {
					return errs.DatabaseErr(err.Error())
				}
			}
			// A zero cursor means the iteration over this pattern is complete.
			if next == 0 {
				break
			}
			cursor = next
		}
	}
	return nil
}