app-cloudep-permission-server/pkg/repository/token_test.go

1751 lines
44 KiB
Go
Raw Permalink Normal View History

2025-02-12 01:51:46 +00:00
package repository
import (
"context"
"encoding/json"
"errors"
2025-02-13 11:06:51 +00:00
"testing"
"time"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
2025-02-12 01:51:46 +00:00
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// setupMiniRedis starts an in-memory miniredis server and returns it together
// with a go-zero Redis client pointed at that server. It panics when the
// server cannot be started; callers in this file close the server themselves.
func setupMiniRedis() (*miniredis.Miniredis, *redis.Redis) {
	server, err := miniredis.Run()
	if err != nil {
		panic("failed to start miniRedis: " + err.Error())
	}

	// Build a single-node client configuration from the fake server's address.
	client := redis.MustNewRedis(redis.RedisConf{
		Host: server.Addr(),
		Type: "node",
	})

	return server, client
}
// TestTokenRepository_Create verifies that Create writes the access token
// body, the refresh-token lookup key and the UID/device relation sets with a
// TTL close to the token lifetime, and that an injected Redis error is
// returned to the caller.
func TestTokenRepository_Create(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// Access and refresh token both expire 10 seconds from now.
	sample := entity.Token{
		ID:               "token123",
		UID:              "user123",
		DeviceID:         "device123",
		AccessToken:      "access123",
		ExpiresIn:        time.Now().UTC().Add(10 * time.Second).UnixNano(),
		RefreshToken:     "refresh123",
		RefreshExpiresIn: time.Now().UTC().Add(10 * time.Second).UnixNano(),
	}
	lifetime := 10 * time.Second

	cases := []struct {
		name     string
		token    entity.Token
		redisErr string // error injected via miniredis SetError; empty means none
		wantErr  bool
		errMsg   string
	}{
		{
			name:  "Successful token creation",
			token: sample,
		},
		{
			name:     "Redis Pipeline error",
			token:    sample,
			redisErr: "forced Redis error",
			wantErr:  true,
			errMsg:   "forced Redis error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and always leave the server healthy for the next case.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			err := repo.Create(context.Background(), tc.token)
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)

			accessKey := domain.GetAccessTokenRedisKey(tc.token.ID)
			refreshKey := domain.GetRefreshTokenRedisKey(tc.token.RefreshToken)
			uidKey := domain.GetUIDTokenRedisKey(tc.token.UID)
			deviceKey := domain.GetDeviceTokenRedisKey(tc.token.DeviceID)

			// The access key holds the JSON-encoded token.
			stored, err := server.Get(accessKey)
			assert.NoError(t, err)
			wantBody, _ := json.Marshal(tc.token)
			assert.Equal(t, string(wantBody), stored)

			// The refresh key maps back to the token ID.
			stored, err = server.Get(refreshKey)
			assert.NoError(t, err)
			assert.Equal(t, tc.token.ID, stored)

			// Both relation sets contain the token ID.
			uidMembers, err := server.SMembers(uidKey)
			assert.NoError(t, err)
			assert.Contains(t, uidMembers, tc.token.ID)

			deviceMembers, err := server.SMembers(deviceKey)
			assert.NoError(t, err)
			assert.Contains(t, deviceMembers, tc.token.ID)

			// TTLs are allowed to drift by up to two seconds.
			accessTTL := server.TTL(accessKey)
			assert.InDelta(t, lifetime.Seconds(), accessTTL.Seconds(), 2, "AccessToken TTL 與設置的過期 TTl 應該相近")
			refreshTTL := server.TTL(refreshKey)
			assert.InDelta(t, lifetime.Seconds(), refreshTTL.Seconds(), 2, "Refresh TTL 與 與設置的過期 TTl 應該相近")
		})
	}
}
// TestTokenRepository_retrieveToken seeds Redis with one valid token body and
// one malformed body, then checks that retrieveToken returns the decoded
// token, reports a missing key, and reports a JSON decoding failure.
func TestTokenRepository_retrieveToken(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	// Build a token and store its JSON form under the access-token key.
	now := time.Now().UTC().UnixNano()
	token := entity.Token{
		ID:               "token123",
		UID:              "user123",
		DeviceID:         "device123",
		AccessToken:      "access123",
		ExpiresIn:        time.Now().UTC().Add(3600 * time.Second).UnixNano(),
		AccessCreateAt:   now,
		RefreshToken:     "refresh123",
		RefreshExpiresIn: time.Now().UTC().Add(7200 * time.Second).UnixNano(),
		RefreshCreateAt:  now,
	}

	tokenKey := domain.GetAccessTokenRedisKey(token.ID)
	tokenData, err := json.Marshal(token)
	if err != nil {
		t.Fatalf("marshal token: %v", err)
	}
	if err := mr.Set(tokenKey, string(tokenData)); err != nil {
		t.Fatalf("seed token: %v", err)
	}

	// Store a body that is not valid JSON to trigger the unmarshal error path.
	if err := mr.Set(domain.GetAccessTokenRedisKey("invalid_json"), "{invalid_json}"); err != nil {
		t.Fatalf("seed invalid token: %v", err)
	}

	tests := []struct {
		name    string
		key     string
		want    entity.Token
		wantErr bool
		errMsg  string
	}{
		{
			name:    "ok",
			key:     tokenKey,
			want:    token,
			wantErr: false,
		},
		{
			name:    "Token not found",
			key:     domain.GetAccessTokenRedisKey("nonexistent"),
			want:    entity.Token{},
			wantErr: true,
			errMsg:  "failed to found token",
		},
		{
			name:    "Invalid JSON format",
			key:     domain.GetAccessTokenRedisKey("invalid_json"),
			want:    entity.Token{},
			wantErr: true,
			errMsg:  "failed to unmarshal token JSON: invalid character 'i' looking for beginning of object key string",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.retrieveToken(context.Background(), tt.key)
			if tt.wantErr {
				// Only inspect the message when an error was actually
				// returned; otherwise err.Error() would panic on nil.
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}
			assert.NoError(t, err)

			// Compare the token field by field, timestamps included.
			assert.Equal(t, tt.want.ID, got.ID)
			assert.Equal(t, tt.want.UID, got.UID)
			assert.Equal(t, tt.want.DeviceID, got.DeviceID)
			assert.Equal(t, tt.want.AccessToken, got.AccessToken)
			assert.Equal(t, tt.want.ExpiresIn, got.ExpiresIn)
			assert.Equal(t, tt.want.RefreshToken, got.RefreshToken)
			assert.Equal(t, tt.want.RefreshExpiresIn, got.RefreshExpiresIn)
			assert.Equal(t, tt.want.AccessCreateAt, got.AccessCreateAt)
			assert.Equal(t, tt.want.RefreshCreateAt, got.RefreshCreateAt)
		})
	}
}
// TestTokenRepository_GetTokensBySet seeds one unexpired and one expired token
// referenced from the same set and checks that getTokensBySet returns only the
// unexpired one, and returns no tokens for a set key that does not exist.
func TestTokenRepository_GetTokensBySet(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	now := time.Now().UTC()
	unexpiredToken := entity.Token{
		ID:               "token123",
		UID:              "user123",
		DeviceID:         "device123",
		AccessToken:      "access123",
		ExpiresIn:        now.Add(time.Hour).UnixNano(), // expires one hour from now
		AccessCreateAt:   now.UnixNano(),
		RefreshToken:     "refresh123",
		RefreshExpiresIn: now.Add(2 * time.Hour).UnixNano(),
		RefreshCreateAt:  now.UnixNano(),
	}
	expiredToken := entity.Token{
		ID:               "token456",
		UID:              "user456",
		DeviceID:         "device456",
		AccessToken:      "access456",
		ExpiresIn:        now.Add(-time.Hour).UnixNano(), // expired one hour ago
		AccessCreateAt:   now.Add(-2 * time.Hour).UnixNano(),
		RefreshToken:     "refresh456",
		RefreshExpiresIn: now.Add(-30 * time.Minute).UnixNano(),
		RefreshCreateAt:  now.Add(-90 * time.Minute).UnixNano(),
	}

	// Store both token bodies and register both IDs in the set. Any seeding
	// failure aborts the test instead of letting it pass without running.
	setKey := "permission:token_set"
	for _, token := range []entity.Token{unexpiredToken, expiredToken} {
		body, err := json.Marshal(token)
		if err != nil {
			t.Fatalf("marshal token %s: %v", token.ID, err)
		}
		if err := mr.Set(domain.GetAccessTokenRedisKey(token.ID), string(body)); err != nil {
			t.Fatalf("seed token %s: %v", token.ID, err)
		}
		if _, err := mr.SAdd(setKey, token.ID); err != nil {
			t.Fatalf("seed set member %s: %v", token.ID, err)
		}
	}

	tests := []struct {
		name       string
		setKey     string
		wantTokens []entity.Token
		wantErr    bool
	}{
		{
			name:       "Set contains unexpired and expired tokens",
			setKey:     setKey,
			wantTokens: []entity.Token{unexpiredToken}, // only the unexpired token is expected
			wantErr:    false,
		},
		{
			name:       "Set key not found",
			setKey:     "permission:nonexistent_set",
			wantTokens: nil, // no tokens expected
			wantErr:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.getTokensBySet(context.Background(), tt.setKey)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Stop on a length mismatch so the indexed comparison below
			// cannot run past the end of wantTokens.
			if !assert.Equal(t, len(tt.wantTokens), len(got)) {
				return
			}
			for i, token := range got {
				want := tt.wantTokens[i]
				assert.Equal(t, want.ID, token.ID)
				assert.Equal(t, want.UID, token.UID)
				assert.Equal(t, want.DeviceID, token.DeviceID)
				assert.Equal(t, want.AccessToken, token.AccessToken)
				assert.Equal(t, want.ExpiresIn, token.ExpiresIn)
				assert.Equal(t, want.RefreshToken, token.RefreshToken)
				assert.Equal(t, want.RefreshExpiresIn, token.RefreshExpiresIn)
				assert.Equal(t, want.AccessCreateAt, token.AccessCreateAt)
				assert.Equal(t, want.RefreshCreateAt, token.RefreshCreateAt)
			}
		})
	}
}
// TestTokenRepository_GetCountBySet checks that getCountBySet returns the
// number of members of an existing set and 0 for a set that does not exist.
func TestTokenRepository_GetCountBySet(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	// Seed a set with three members; a seeding failure aborts the test so the
	// count assertion never runs against incomplete data.
	setKey := "permission:token_set"
	for _, id := range []string{"token123", "token456", "token789"} {
		if _, err := mr.SAdd(setKey, id); err != nil {
			t.Fatalf("seed set member %s: %v", id, err)
		}
	}

	tests := []struct {
		name    string
		setKey  string
		want    int
		wantErr bool
	}{
		{
			name:    "Count of existing set",
			setKey:  setKey,
			want:    3, // the seeded set has three members
			wantErr: false,
		},
		{
			name:    "Non-existent set",
			setKey:  "permission:nonexistent_set",
			want:    0, // a missing set counts as empty
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := repo.getCountBySet(context.Background(), tt.setKey)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tt.want, got)
		})
	}
}
// TestTokenRepository_SetRelation runs setTokenRelation inside a Redis
// pipeline and checks that the token ID is added to both the UID set and the
// device set with the requested TTL, and that injected Redis errors surface.
func TestTokenRepository_SetRelation(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// Inputs shared by every case; relation keys expire after 10 seconds.
	const (
		sampleUID    = "user123"
		sampleDevice = "device123"
		sampleToken  = "token123"
	)
	expiry := 10 * time.Second

	cases := []struct {
		name     string
		uid      string
		deviceID string
		tokenID  string
		ttl      time.Duration
		redisErr string // error injected via miniredis SetError; empty means none
		wantErr  bool
		errMsg   string
	}{
		{
			name:     "Valid relation setting",
			uid:      sampleUID,
			deviceID: sampleDevice,
			tokenID:  sampleToken,
			ttl:      expiry,
		},
		{
			name:     "Redis SAdd error",
			uid:      sampleUID,
			deviceID: sampleDevice,
			tokenID:  sampleToken,
			ttl:      expiry,
			redisErr: "forced SAdd error",
			wantErr:  true,
			errMsg:   "forced SAdd error",
		},
		{
			name:     "Redis Expire error",
			uid:      sampleUID,
			deviceID: sampleDevice,
			tokenID:  sampleToken,
			ttl:      expiry,
			redisErr: "forced Expire error",
			wantErr:  true,
			errMsg:   "forced Expire error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			uidKey := domain.GetUIDTokenRedisKey(tc.uid)
			deviceKey := domain.GetDeviceTokenRedisKey(tc.deviceID)

			err := client.Pipelined(func(tx redis.Pipeliner) error {
				return repo.setTokenRelation(context.Background(), tx, tc.uid, tc.deviceID, tc.tokenID, tc.ttl)
			})
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)

			// Both sets must now contain the token ID.
			uidMembers, err := server.SMembers(uidKey)
			assert.NoError(t, err)
			assert.Contains(t, uidMembers, tc.tokenID)

			deviceMembers, err := server.SMembers(deviceKey)
			assert.NoError(t, err)
			assert.Contains(t, deviceMembers, tc.tokenID)

			// Both keys carry exactly the requested TTL.
			uidTTL := server.TTL(uidKey)
			assert.Equal(t, tc.ttl.Seconds(), uidTTL.Seconds())
			deviceTTL := server.TTL(deviceKey)
			assert.Equal(t, tc.ttl.Seconds(), deviceTTL.Seconds())
		})
	}
}
// TestTokenRepository_SetRefreshToken runs setRefreshToken inside a Redis
// pipeline and checks that a non-empty refresh token is stored as a key
// mapping to the token ID with the requested TTL, that an empty refresh token
// produces no error, and that an injected Redis error surfaces.
func TestTokenRepository_SetRefreshToken(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// Every case uses a 10 second expiry.
	expiry := 10 * time.Second

	cases := []struct {
		name     string
		token    entity.Token
		ttl      time.Duration
		redisErr string // error injected via miniredis SetError; empty means none
		wantErr  bool
		errMsg   string
	}{
		{
			name:  "Valid RefreshToken setting",
			token: entity.Token{ID: "token123", RefreshToken: "refresh123"},
			ttl:   expiry,
		},
		{
			name:  "Empty RefreshToken",
			token: entity.Token{ID: "token456", RefreshToken: ""},
			ttl:   expiry,
		},
		{
			name:     "Redis Set error",
			token:    entity.Token{ID: "token789", RefreshToken: "refresh789"},
			ttl:      expiry,
			redisErr: "forced Set error",
			wantErr:  true,
			errMsg:   "forced Set error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			err := client.Pipelined(func(tx redis.Pipeliner) error {
				return repo.setRefreshToken(context.Background(), tx, tc.token, tc.ttl)
			})
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)

			// Nothing further is verified for an empty refresh token.
			if tc.token.RefreshToken == "" {
				return
			}

			refreshKey := domain.GetRefreshTokenRedisKey(tc.token.RefreshToken)
			stored, err := server.Get(refreshKey)
			assert.NoError(t, err)
			assert.Equal(t, tc.token.ID, stored)

			// The key carries exactly the requested TTL.
			remaining := server.TTL(refreshKey)
			assert.Equal(t, tc.ttl.Seconds(), remaining.Seconds())
		})
	}
}
// TestTokenRepository_SetToken runs setToken inside a Redis pipeline and
// checks that the given body is stored under the access-token key with the
// requested TTL, and that an injected Redis error surfaces.
func TestTokenRepository_SetToken(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// The key expires after 10 seconds; the stored body is the token's JSON.
	expiry := 10 * time.Second
	sample := entity.Token{
		ID:           "token123",
		UID:          "user123",
		DeviceID:     "device123",
		AccessToken:  "access123",
		ExpiresIn:    time.Now().UTC().Add(7200 * time.Second).UnixNano(),
		RefreshToken: "refresh123",
	}
	payload, _ := json.Marshal(sample)

	cases := []struct {
		name     string
		token    entity.Token
		body     []byte
		ttl      time.Duration
		redisErr string // error injected via miniredis SetError; empty means none
		wantErr  bool
		errMsg   string
	}{
		{
			name:  "Valid Token setting",
			token: sample,
			body:  payload,
			ttl:   expiry,
		},
		{
			name:     "Redis Set error",
			token:    sample,
			body:     payload,
			ttl:      expiry,
			redisErr: "forced Set error",
			wantErr:  true,
			errMsg:   "forced Set error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			accessKey := domain.GetAccessTokenRedisKey(tc.token.ID)

			err := client.Pipelined(func(tx redis.Pipeliner) error {
				return repo.setToken(context.Background(), tx, tc.token.ID, tc.body, tc.ttl)
			})
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)

			// The stored value is the body that was passed in.
			stored, err := server.Get(accessKey)
			assert.NoError(t, err)
			assert.Equal(t, string(tc.body), stored)

			// The key carries exactly the requested TTL.
			remaining := server.TTL(accessKey)
			assert.Equal(t, tc.ttl.Seconds(), remaining.Seconds())
		})
	}
}
// TestTokenRepository_RunPipeline checks that runPipeline applies the
// callback's commands on success, returns the callback's own error, and
// returns an injected Redis error.
func TestTokenRepository_RunPipeline(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// writeTestKey is the pipeline body shared by the cases that issue a SET.
	writeTestKey := func(tx redis.Pipeliner) error {
		return tx.Set(context.Background(), "testkey", "testvalue", 0).Err()
	}

	cases := []struct {
		name     string
		redisErr string                         // error injected via miniredis SetError; empty means none
		fn       func(tx redis.Pipeliner) error // callback executed inside the pipeline
		wantErr  bool
		errMsg   string
	}{
		{
			name: "Successful Pipeline Execution",
			fn:   writeTestKey,
		},
		{
			name: "Pipeline Function Error",
			fn: func(tx redis.Pipeliner) error {
				// The callback itself fails without issuing any command.
				return errors.New("forced function error")
			},
			wantErr: true,
			errMsg:  "forced function error",
		},
		{
			name:     "Redis Pipeline Error",
			fn:       writeTestKey,
			redisErr: "forced Redis error",
			wantErr:  true,
			errMsg:   "forced Redis error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			err := repo.runPipeline(context.Background(), tc.fn)
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)

			// Only the success case wrote a key that can be read back.
			if tc.name != "Successful Pipeline Execution" {
				return
			}
			stored, err := server.Get("testkey")
			assert.NoError(t, err)
			assert.Equal(t, "testvalue", stored)
		})
	}
}
// Unserializable is a test double whose MarshalJSON always fails; it is used
// to simulate a JSON serialization error.
type Unserializable struct{}

// MarshalJSON always returns a nil byte slice and a fixed error.
func (Unserializable) MarshalJSON() ([]byte, error) {
	var none []byte
	return none, errors.New("forced JSON marshal error")
}
// TestTokenRepository_CreateOneTimeToken checks that CreateOneTimeToken stores
// the JSON-encoded ticket under the refresh-token key for the given one-time
// key with the requested TTL, and that JSON encoding failures and injected
// Redis errors are returned.
func TestTokenRepository_CreateOneTimeToken(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// Inputs shared by the cases; the one-time token lives for 10 seconds.
	const oneTimeKey = "one_time_key"
	lifetime := 10 * time.Second
	validTicket := entity.Ticket{
		Data: "sample_data",
		Token: entity.Token{
			ID:          "token123",
			AccessToken: "access123",
		},
	}

	cases := []struct {
		name     string
		key      string
		ticket   entity.Ticket
		duration time.Duration
		redisErr string // error injected via miniredis SetError; empty means none
		wantErr  bool
		errMsg   string
	}{
		{
			name:     "Successful one-time token creation",
			key:      oneTimeKey,
			ticket:   validTicket,
			duration: lifetime,
		},
		{
			name: "JSON marshal error",
			key:  oneTimeKey,
			ticket: entity.Ticket{
				// Unserializable makes json.Marshal fail for this ticket.
				Data:  Unserializable{},
				Token: entity.Token{ID: "invalid_token"},
			},
			duration: lifetime,
			wantErr:  true,
			errMsg:   "json: error calling MarshalJSON for type repository.Unserializable: forced JSON marshal error",
		},
		{
			name:     "Redis SetnxEx error",
			key:      oneTimeKey,
			ticket:   validTicket,
			duration: lifetime,
			redisErr: "forced Redis error",
			wantErr:  true,
			errMsg:   "forced Redis error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			err := repo.CreateOneTimeToken(context.Background(), tc.key, tc.ticket, tc.duration)
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)

			// The ticket is expected under the refresh-token key namespace.
			storedKey := domain.GetRefreshTokenRedisKey(tc.key)
			stored, err := server.Get(storedKey)
			assert.NoError(t, err)
			wantBody, _ := json.Marshal(tc.ticket)
			assert.Equal(t, string(wantBody), stored)

			// The key carries exactly the requested TTL.
			remaining := server.TTL(storedKey)
			assert.Equal(t, tc.duration.Seconds(), remaining.Seconds())
		})
	}
}
// TestTokenRepository_GetAccessTokenByOneTimeToken seeds a JSON token body
// under the refresh-token key of a one-time token and checks that
// GetAccessTokenByOneTimeToken decodes it, reports a missing key, and returns
// an injected Redis error.
func TestTokenRepository_GetAccessTokenByOneTimeToken(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	oneTimeToken := "one_time_token_123"
	expectedToken := entity.Token{
		ID:           "token123",
		UID:          "user123",
		DeviceID:     "device123",
		AccessToken:  "access123",
		ExpiresIn:    3600,
		RefreshToken: "refresh123",
	}

	// Seed the JSON body under the one-time token's key. The key is written
	// once (a preliminary write of the bare token ID to the same key was
	// immediately overwritten and has been dropped), and seeding failures
	// abort the test.
	tokenData, err := json.Marshal(expectedToken)
	if err != nil {
		t.Fatalf("marshal expected token: %v", err)
	}
	if err := mr.Set(domain.GetRefreshTokenRedisKey(oneTimeToken), string(tokenData)); err != nil {
		t.Fatalf("seed one-time token: %v", err)
	}

	tests := []struct {
		name         string
		oneTimeToken string
		redisErr     string // error injected via miniredis SetError; empty means none
		expected     entity.Token
		wantErr      bool
		errMsg       string
	}{
		{
			name:         "Successful retrieval of access token by one-time token",
			oneTimeToken: oneTimeToken,
			expected:     expectedToken,
			wantErr:      false,
		},
		{
			name:         "Token not found in Redis",
			oneTimeToken: "nonexistent_token",
			wantErr:      true,
			errMsg:       "failed to found token",
		},
		{
			name:         "Redis Get error",
			oneTimeToken: oneTimeToken,
			redisErr:     "forced Redis error",
			wantErr:      true,
			errMsg:       "forced Redis error",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			mr.SetError(tt.redisErr)
			defer mr.SetError("")

			result, err := repo.GetAccessTokenByOneTimeToken(context.Background(), tt.oneTimeToken)
			if tt.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tt.expected, result)
		})
	}
}
// TestTokenRepository_GetAccessTokensByUID creates two tokens for one user via
// Create and checks that GetAccessTokensByUID returns them, returns an empty
// slice for an unknown user, and returns an injected Redis error.
func TestTokenRepository_GetAccessTokensByUID(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// Two tokens owned by the same user, both expiring in 60 minutes.
	owner := "user123"
	seeded := []entity.Token{
		{
			ID:               "token1",
			UID:              owner,
			DeviceID:         "device1",
			AccessToken:      "access1",
			ExpiresIn:        time.Now().UTC().Add(60 * time.Minute).UnixNano(),
			RefreshExpiresIn: time.Now().UTC().Add(60 * time.Minute).UnixNano(),
			RefreshToken:     "refresh1",
		},
		{
			ID:               "token2",
			UID:              owner,
			DeviceID:         "device2",
			AccessToken:      "access2",
			ExpiresIn:        time.Now().UTC().Add(60 * time.Minute).UnixNano(),
			RefreshExpiresIn: time.Now().UTC().Add(60 * time.Minute).UnixNano(),
			RefreshToken:     "refresh2",
		},
	}
	for i := range seeded {
		assert.NoError(t, repo.Create(context.Background(), seeded[i]))
	}

	cases := []struct {
		name     string
		uid      string
		redisErr string // error injected via miniredis SetError; empty means none
		expected []entity.Token
		wantErr  bool
		errMsg   string
	}{
		{
			name:     "Successful retrieval of tokens by UID",
			uid:      owner,
			expected: seeded,
		},
		{
			name:     "UID not found in Redis",
			uid:      "nonexistent_user",
			expected: []entity.Token{},
		},
		{
			name:     "Redis SMember error",
			uid:      owner,
			redisErr: "forced Redis error",
			wantErr:  true,
			errMsg:   "forced Redis error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			got, err := repo.GetAccessTokensByUID(context.Background(), tc.uid)
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestTokenRepository_GetAccessTokenCountByUID seeds a UID set with three
// token IDs and checks that GetAccessTokenCountByUID returns 3 for that user,
// 0 for an unknown user, and an injected Redis error.
func TestTokenRepository_GetAccessTokenCountByUID(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// Seed the user's token set directly; seeding errors are ignored.
	owner := "user123"
	ownerKey := domain.GetUIDTokenRedisKey(owner)
	for _, id := range []string{"token1", "token2", "token3"} {
		_, _ = server.SAdd(ownerKey, id)
	}

	cases := []struct {
		name     string
		uid      string
		redisErr string // error injected via miniredis SetError; empty means none
		expected int
		wantErr  bool
		errMsg   string
	}{
		{
			name:     "Successful retrieval of token count by UID",
			uid:      owner,
			expected: 3,
		},
		{
			name:     "UID not found in Redis",
			uid:      "nonexistent_user",
			expected: 0,
		},
		{
			name:     "Redis Scard error",
			uid:      owner,
			redisErr: "forced Redis error",
			wantErr:  true,
			errMsg:   "forced Redis error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			got, err := repo.GetAccessTokenCountByUID(context.Background(), tc.uid)
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestTokenRepository_GetAccessTokensByDeviceID seeds two JSON token bodies
// and registers both IDs in one device set, then checks that
// GetAccessTokensByDeviceID returns them, returns an empty slice for an
// unknown device, and returns an injected Redis error.
func TestTokenRepository_GetAccessTokensByDeviceID(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// Two tokens bound to the same device, both expiring in 60 minutes.
	device := "device123"
	deviceSetKey := domain.GetDeviceTokenRedisKey(device)
	seeded := []entity.Token{
		{
			ID:           "token1",
			UID:          "user123",
			DeviceID:     device,
			AccessToken:  "access1",
			ExpiresIn:    time.Now().UTC().Add(60 * time.Minute).UnixNano(),
			RefreshToken: "refresh1",
		},
		{
			ID:           "token2",
			UID:          "user123",
			DeviceID:     device,
			AccessToken:  "access2",
			ExpiresIn:    time.Now().UTC().Add(60 * time.Minute).UnixNano(),
			RefreshToken: "refresh2",
		},
	}

	// Write each body and set membership directly; seeding errors are ignored.
	for i := range seeded {
		body, _ := json.Marshal(seeded[i])
		_ = server.Set(domain.GetAccessTokenRedisKey(seeded[i].ID), string(body))
		_, _ = server.SAdd(deviceSetKey, seeded[i].ID)
	}

	cases := []struct {
		name     string
		deviceID string
		redisErr string // error injected via miniredis SetError; empty means none
		expected []entity.Token
		wantErr  bool
		errMsg   string
	}{
		{
			name:     "Successful retrieval of tokens by Device ID",
			deviceID: device,
			expected: seeded,
		},
		{
			name:     "Device ID not found in Redis",
			deviceID: "nonexistent_device",
			expected: []entity.Token{},
		},
		{
			name:     "Redis SMember error",
			deviceID: device,
			redisErr: "forced Redis error",
			wantErr:  true,
			errMsg:   "forced Redis error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			got, err := repo.GetAccessTokensByDeviceID(context.Background(), tc.deviceID)
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestTokenRepository_GetAccessTokenCountByDeviceID seeds a device set with
// three token IDs and checks that GetAccessTokenCountByDeviceID returns 3 for
// that device, 0 for an unknown device, and an injected Redis error.
func TestTokenRepository_GetAccessTokenCountByDeviceID(t *testing.T) {
	server, client := setupMiniRedis()
	defer server.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: client}}

	// Seed the device's token set directly; seeding errors are ignored.
	device := "device123"
	deviceSetKey := domain.GetDeviceTokenRedisKey(device)
	for _, id := range []string{"token1", "token2", "token3"} {
		_, _ = server.SAdd(deviceSetKey, id)
	}

	cases := []struct {
		name     string
		deviceID string
		redisErr string // error injected via miniredis SetError; empty means none
		expected int
		wantErr  bool
		errMsg   string
	}{
		{
			name:     "Successful retrieval of token count by Device ID",
			deviceID: device,
			expected: 3,
		},
		{
			name:     "Device ID not found in Redis",
			deviceID: "nonexistent_device",
			expected: 0,
		},
		{
			name:     "Redis Scard error",
			deviceID: device,
			redisErr: "forced Redis error",
			wantErr:  true,
			errMsg:   "forced Redis error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			server.SetError(tc.redisErr)
			defer server.SetError("")

			got, err := repo.GetAccessTokenCountByDeviceID(context.Background(), tc.deviceID)
			if tc.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.errMsg)
				}
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestTokenRepository_Delete creates a token through Create and checks that
// Delete removes its access key, refresh key and its membership in the UID and
// device sets, and that an injected Redis error is returned.
func TestTokenRepository_Delete(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	token := entity.Token{
		ID:               "token123",
		UID:              "user123",
		DeviceID:         "device123",
		AccessToken:      "access123",
		RefreshToken:     "refresh123",
		ExpiresIn:        time.Now().UTC().Add(60 * time.Minute).UnixNano(),
		RefreshExpiresIn: time.Now().UTC().Add(60 * time.Minute).UnixNano(),
	}

	// Keys that Delete is expected to clean up for the seeded token.
	accessTokenKey := domain.GetAccessTokenRedisKey(token.ID)
	refreshTokenKey := domain.GetRefreshTokenRedisKey(token.RefreshToken)
	uidKey := domain.GetUIDTokenRedisKey(token.UID)
	deviceIDKey := domain.GetDeviceTokenRedisKey(token.DeviceID)

	// Seed Redis through Create; without the seed the deletion checks below
	// would pass trivially, so a failure aborts the test.
	if err := repo.Create(context.Background(), token); err != nil {
		t.Fatalf("seed token: %v", err)
	}

	tests := []struct {
		name       string
		token      entity.Token
		redisErr   string // error injected via miniredis SetError; empty means none
		wantErr    bool
		errMsg     string
		skipVerify bool // when true, no assertion is made on the outcome of Delete
	}{
		{
			name:    "Successful deletion of token",
			token:   token,
			wantErr: false,
		},
		{
			name:     "Redis delete error",
			token:    token,
			redisErr: "forced Redis delete error",
			wantErr:  true,
			errMsg:   "forced Redis delete error",
		},
		{
			// NOTE(review): this case only checks that Delete does not panic;
			// the original test skipped every assertion here — confirm whether
			// Delete is expected to succeed for an unknown token.
			name:       "Deletion of non-existent token",
			token:      entity.Token{ID: "nonexistent_token", UID: "user123", DeviceID: "device123"},
			wantErr:    false,
			skipVerify: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			mr.SetError(tt.redisErr)
			defer mr.SetError("")

			err := repo.Delete(context.Background(), tt.token)
			if tt.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}
			if tt.skipVerify {
				return
			}
			assert.NoError(t, err)

			// The lookups must fail with ErrKeyNotFound. Compare the returned
			// error against the sentinel; passing the sentinel to assert.Error
			// as the error under test would always succeed.
			_, err = mr.Get(accessTokenKey)
			assert.Equal(t, miniredis.ErrKeyNotFound, err)
			_, err = mr.Get(refreshTokenKey)
			assert.Equal(t, miniredis.ErrKeyNotFound, err)

			// The relation sets held only this token, so they must be gone.
			uidSetMembers, err := mr.SMembers(uidKey)
			assert.Equal(t, miniredis.ErrKeyNotFound, err)
			assert.NotContains(t, uidSetMembers, token.ID)
			deviceIDSetMembers, err := mr.SMembers(deviceIDKey)
			assert.Equal(t, miniredis.ErrKeyNotFound, err)
			assert.NotContains(t, deviceIDSetMembers, token.ID)
		})
	}
}
// TestTokenRepository_DeleteAccessTokensByDeviceID seeds two tokens bound to
// one device and checks that DeleteAccessTokensByDeviceID removes their access
// and refresh keys together with the device set, that an injected Redis error
// is returned, and that an unknown device is not an error.
func TestTokenRepository_DeleteAccessTokensByDeviceID(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	deviceID := "device123"
	tokens := []entity.Token{
		{
			ID:               "token1",
			UID:              "user123",
			DeviceID:         deviceID,
			AccessToken:      "access1",
			RefreshToken:     "refresh1",
			ExpiresIn:        time.Now().UTC().Add(60 * time.Minute).UnixNano(),
			RefreshExpiresIn: time.Now().UTC().Add(60 * time.Minute).UnixNano(),
		},
		{
			ID:               "token2",
			UID:              "user123",
			DeviceID:         deviceID,
			AccessToken:      "access2",
			RefreshToken:     "refresh2",
			ExpiresIn:        time.Now().UTC().Add(60 * time.Minute).UnixNano(),
			RefreshExpiresIn: time.Now().UTC().Add(60 * time.Minute).UnixNano(),
		},
	}

	// Seed the access key with the token's JSON body, as the sibling
	// GetAccessTokensByDeviceID test does, so the stored value can be decoded
	// back into a token. NOTE(review): this assumes the repository parses the
	// stored body when collecting the device's tokens — confirm against the
	// implementation. Seeding failures abort the test.
	deviceKey := domain.GetDeviceTokenRedisKey(deviceID)
	for _, token := range tokens {
		body, err := json.Marshal(token)
		if err != nil {
			t.Fatalf("marshal token %s: %v", token.ID, err)
		}
		if err := mr.Set(domain.GetAccessTokenRedisKey(token.ID), string(body)); err != nil {
			t.Fatalf("seed access token %s: %v", token.ID, err)
		}
		if err := mr.Set(domain.GetRefreshTokenRedisKey(token.RefreshToken), token.ID); err != nil {
			t.Fatalf("seed refresh token %s: %v", token.ID, err)
		}
		if _, err := mr.SAdd(domain.GetUIDTokenRedisKey(token.UID), token.ID); err != nil {
			t.Fatalf("seed uid relation %s: %v", token.ID, err)
		}
		if _, err := mr.SAdd(deviceKey, token.ID); err != nil {
			t.Fatalf("seed device relation %s: %v", token.ID, err)
		}
	}

	tests := []struct {
		name     string
		deviceID string
		redisErr string // error injected via miniredis SetError; empty means none
		wantErr  bool
		errMsg   string
	}{
		{
			name:     "Successful deletion of tokens by Device ID",
			deviceID: deviceID,
			wantErr:  false,
		},
		{
			name:     "GetAccessTokensByDeviceID error",
			deviceID: deviceID,
			redisErr: "forced error in GetAccessTokensByDeviceID",
			wantErr:  true,
			errMsg:   "forced error in GetAccessTokensByDeviceID",
		},
		{
			name:     "Delete non-existent device ID",
			deviceID: "nonexistent_device",
			wantErr:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Install this case's simulated failure (or clear a previous one)
			// and restore a healthy server when the case finishes.
			mr.SetError(tt.redisErr)
			defer mr.SetError("")

			err := repo.DeleteAccessTokensByDeviceID(context.Background(), tt.deviceID)
			if tt.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}
			assert.NoError(t, err)

			// Every seeded access and refresh key must be gone. Compare the
			// returned error against the sentinel; passing the sentinel to
			// assert.Error as the error under test would always succeed.
			for _, token := range tokens {
				_, err = mr.Get(domain.GetAccessTokenRedisKey(token.ID))
				assert.Equal(t, miniredis.ErrKeyNotFound, err)
				_, err = mr.Get(domain.GetRefreshTokenRedisKey(token.RefreshToken))
				assert.Equal(t, miniredis.ErrKeyNotFound, err)
			}

			// The device relation key must be gone as well.
			_, err = mr.Get(deviceKey)
			assert.Equal(t, miniredis.ErrKeyNotFound, err)
		})
	}
}
// TestTokenRepository_DeleteOneTimeToken exercises
// TokenRepository.DeleteOneTimeToken against a miniredis-backed client.
//
// The fixture seeds one string key, built with domain.GetRefreshTokenRedisKey,
// for every entry of ids and for every RefreshToken in tokens. All sub-tests
// share the same miniredis instance and run in declaration order.
func TestTokenRepository_DeleteOneTimeToken(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	// Repository under test, wired to the in-memory Redis.
	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	// Shared test parameters.
	ids := []string{"one_time_token1", "one_time_token2"}
	tokens := []entity.Token{
		{RefreshToken: "refresh_token1"},
		{RefreshToken: "refresh_token2"},
	}

	// Seed Redis so the success case has real keys to delete. A seeding
	// failure would invalidate every case, so abort immediately.
	for _, id := range ids {
		if err := mr.Set(domain.GetRefreshTokenRedisKey(id), "dummy_value"); err != nil {
			t.Fatalf("seeding key for id %q: %v", id, err)
		}
	}
	for _, token := range tokens {
		if err := mr.Set(domain.GetRefreshTokenRedisKey(token.RefreshToken), "dummy_value"); err != nil {
			t.Fatalf("seeding key for refresh token %q: %v", token.RefreshToken, err)
		}
	}

	// Test scenarios.
	tests := []struct {
		name        string
		ids         []string
		tokens      []entity.Token
		prepareFunc func() error // injects a Redis failure before the call
		wantErr     bool
		errMsg      string
	}{
		{
			name:    "Successful deletion of one-time tokens",
			ids:     ids,
			tokens:  tokens,
			wantErr: false,
		},
		{
			name:    "Deletion of non-existent one-time tokens",
			ids:     []string{"nonexistent_id1", "nonexistent_id2"},
			tokens:  []entity.Token{{RefreshToken: "nonexistent_refresh1"}, {RefreshToken: "nonexistent_refresh2"}},
			wantErr: false,
		},
		{
			name:   "Redis delete error",
			ids:    ids,
			tokens: tokens,
			prepareFunc: func() error {
				mr.SetError("forced Redis delete error") // make every Redis command fail
				return nil
			},
			wantErr: true,
			errMsg:  "forced Redis delete error",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Start from a clean state and guarantee the injected error is
			// removed even if this sub-test aborts early.
			mr.SetError("")
			defer mr.SetError("")

			// Apply the per-case fault injection; a failing hook is a broken
			// test setup, not a repository failure.
			if tt.prepareFunc != nil {
				if err := tt.prepareFunc(); err != nil {
					t.Fatalf("prepareFunc: %v", err)
				}
			}

			err := repo.DeleteOneTimeToken(context.Background(), tt.ids, tt.tokens)

			if tt.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}

			assert.NoError(t, err)

			// Every key addressed by this case must be gone from Redis.
			for _, id := range tt.ids {
				_, getErr := mr.Get(domain.GetRefreshTokenRedisKey(id))
				assert.Equal(t, miniredis.ErrKeyNotFound, getErr)
			}
			for _, token := range tt.tokens {
				_, getErr := mr.Get(domain.GetRefreshTokenRedisKey(token.RefreshToken))
				assert.Equal(t, miniredis.ErrKeyNotFound, getErr)
			}
		})
	}
}
// TestTokenRepository_DeleteAccessTokensByUID exercises
// TokenRepository.DeleteAccessTokensByUID against a miniredis-backed client.
//
// The fixture stores, for each token, the access-token key (value:
// AccessToken), the refresh-token key (value: token ID) and adds the token ID
// to the UID set. All sub-tests share the same miniredis instance and run in
// declaration order.
func TestTokenRepository_DeleteAccessTokensByUID(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	// Repository under test, wired to the in-memory Redis.
	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	// Shared test parameters.
	uid := "user123"
	tokens := []entity.Token{
		{
			ID:           "token1",
			UID:          uid,
			DeviceID:     "device1",
			AccessToken:  "access1",
			RefreshToken: "refresh1",
		},
		{
			ID:           "token2",
			UID:          uid,
			DeviceID:     "device2",
			AccessToken:  "access2",
			RefreshToken: "refresh2",
		},
	}

	// Seed Redis. A seeding failure would invalidate every case, so abort
	// immediately instead of discarding the error.
	uidKey := domain.GetUIDTokenRedisKey(uid)
	for _, token := range tokens {
		if err := mr.Set(domain.GetAccessTokenRedisKey(token.ID), token.AccessToken); err != nil {
			t.Fatalf("seeding access token %q: %v", token.ID, err)
		}
		if err := mr.Set(domain.GetRefreshTokenRedisKey(token.RefreshToken), token.ID); err != nil {
			t.Fatalf("seeding refresh token %q: %v", token.RefreshToken, err)
		}
		if _, err := mr.SAdd(uidKey, token.ID); err != nil {
			t.Fatalf("seeding uid set with %q: %v", token.ID, err)
		}
	}

	// Test scenarios.
	tests := []struct {
		name        string
		uid         string
		prepareFunc func() error // injects a Redis failure before the call
		wantErr     bool
		errMsg      string
		// jump gates the post-deletion key verification below.
		// NOTE(review): no case sets it, so that verification is currently
		// skipped for every scenario — confirm whether this is intentional.
		jump bool
	}{
		{
			name:    "Successful deletion of tokens by UID",
			uid:     uid,
			wantErr: false,
		},
		{
			name: "GetAccessTokensByUID error",
			uid:  uid,
			prepareFunc: func() error {
				mr.SetError("forced error in GetAccessTokensByUID") // make every Redis command fail
				return nil
			},
			wantErr: true,
			errMsg:  "forced error in GetAccessTokensByUID",
		},
		{
			name:    "Delete non-existent UID",
			uid:     "nonexistent_uid",
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Start from a clean state and guarantee the injected error is
			// removed even if this sub-test aborts early.
			mr.SetError("")
			defer mr.SetError("")

			// Apply the per-case fault injection; a failing hook is a broken
			// test setup, not a repository failure.
			if tt.prepareFunc != nil {
				if err := tt.prepareFunc(); err != nil {
					t.Fatalf("prepareFunc: %v", err)
				}
			}

			err := repo.DeleteAccessTokensByUID(context.Background(), tt.uid)

			if tt.wantErr {
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}

			assert.NoError(t, err)
			if !tt.jump {
				return
			}

			// Verify that the seeded keys were removed.
			for _, token := range tokens {
				// The access-token and refresh-token keys must no longer
				// exist. Compare the returned error with the sentinel:
				// assert.Error(t, sentinel, err) would always pass because
				// the sentinel itself is a non-nil error.
				_, getErr := mr.Get(domain.GetAccessTokenRedisKey(token.ID))
				assert.Equal(t, miniredis.ErrKeyNotFound, getErr)
				_, getErr = mr.Get(domain.GetRefreshTokenRedisKey(token.RefreshToken))
				assert.Equal(t, miniredis.ErrKeyNotFound, getErr)

				// The UID set must not reference the token any more. A set
				// whose key no longer exists makes SMembers return
				// ErrKeyNotFound, which also satisfies that requirement.
				members, setErr := mr.SMembers(uidKey)
				if errors.Is(setErr, miniredis.ErrKeyNotFound) {
					continue
				}
				assert.NoError(t, setErr)
				assert.NotContains(t, members, token.ID)
			}
		})
	}
}