383 lines
10 KiB
Go
383 lines
10 KiB
Go
package repository
|
|
|
|
import (
	"context"
	"strconv"
	"testing"
	"time"

	"backend/pkg/permission/domain/entity"

	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/zeromicro/go-zero/core/stores/redis"
)
|
|
|
|
func setupMiniRedis() (*miniredis.Miniredis, *redis.Redis) {
|
|
// 啟動 setupMiniRedis 作為模擬的 Redis 服務
|
|
mr, err := miniredis.Run()
|
|
if err != nil {
|
|
panic("failed to start miniRedis: " + err.Error())
|
|
}
|
|
|
|
// 使用 setupMiniRedis 的地址配置 go-zero Redis 客戶端
|
|
redisConf := redis.RedisConf{
|
|
Host: mr.Addr(),
|
|
Type: "node",
|
|
}
|
|
r := redis.MustNewRedis(redisConf)
|
|
|
|
return mr, r
|
|
}
|
|
|
|
func TestTokenRepository_Blacklist(t *testing.T) {
|
|
mr, r := setupMiniRedis()
|
|
defer mr.Close()
|
|
|
|
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
|
|
ctx := context.Background()
|
|
|
|
t.Run("AddToBlacklist", func(t *testing.T) {
|
|
entry := &entity.BlacklistEntry{
|
|
JTI: "test-jti-123",
|
|
UID: "user123",
|
|
TokenID: "token123",
|
|
Reason: "user logout",
|
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
err := repo.AddToBlacklist(ctx, entry, time.Hour)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify it was added
|
|
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
|
|
assert.NoError(t, err)
|
|
assert.True(t, isBlacklisted)
|
|
})
|
|
|
|
t.Run("IsBlacklisted - not found", func(t *testing.T) {
|
|
isBlacklisted, err := repo.IsBlacklisted(ctx, "non-existent-jti")
|
|
assert.NoError(t, err)
|
|
assert.False(t, isBlacklisted)
|
|
})
|
|
|
|
t.Run("RemoveFromBlacklist", func(t *testing.T) {
|
|
// First add an entry
|
|
entry := &entity.BlacklistEntry{
|
|
JTI: "test-jti-456",
|
|
UID: "user456",
|
|
TokenID: "token456",
|
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
err := repo.AddToBlacklist(ctx, entry, time.Hour)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify it exists
|
|
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
|
|
assert.NoError(t, err)
|
|
assert.True(t, isBlacklisted)
|
|
|
|
// Remove it
|
|
err = repo.RemoveFromBlacklist(ctx, entry.JTI)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify it's gone
|
|
isBlacklisted, err = repo.IsBlacklisted(ctx, entry.JTI)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isBlacklisted)
|
|
})
|
|
|
|
t.Run("GetBlacklistedTokensByUID", func(t *testing.T) {
|
|
uid := "user789"
|
|
|
|
// Add multiple entries for the same user
|
|
entries := []*entity.BlacklistEntry{
|
|
{
|
|
JTI: "jti-1",
|
|
UID: uid,
|
|
TokenID: "token-1",
|
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
},
|
|
{
|
|
JTI: "jti-2",
|
|
UID: uid,
|
|
TokenID: "token-2",
|
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
},
|
|
{
|
|
JTI: "jti-3",
|
|
UID: "different-user",
|
|
TokenID: "token-3",
|
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
},
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
err := repo.AddToBlacklist(ctx, entry, time.Hour)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Get blacklisted tokens for the user
|
|
userEntries, err := repo.GetBlacklistedTokensByUID(ctx, uid)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, userEntries, 2) // Should only get entries for the specific user
|
|
|
|
// Verify all returned entries belong to the correct user
|
|
for _, entry := range userEntries {
|
|
assert.Equal(t, uid, entry.UID)
|
|
}
|
|
})
|
|
|
|
t.Run("AddToBlacklist with zero TTL", func(t *testing.T) {
|
|
entry := &entity.BlacklistEntry{
|
|
JTI: "test-jti-zero-ttl",
|
|
UID: "user-zero-ttl",
|
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
// Test with zero TTL - should calculate from ExpiresAt
|
|
err := repo.AddToBlacklist(ctx, entry, 0)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify it was added
|
|
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
|
|
assert.NoError(t, err)
|
|
assert.True(t, isBlacklisted)
|
|
})
|
|
|
|
t.Run("AddToBlacklist with expired token", func(t *testing.T) {
|
|
entry := &entity.BlacklistEntry{
|
|
JTI: "test-jti-expired",
|
|
UID: "user-expired",
|
|
ExpiresAt: time.Now().Add(-time.Hour).Unix(), // Already expired
|
|
CreatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
// Should not add expired token to blacklist
|
|
err := repo.AddToBlacklist(ctx, entry, 0)
|
|
assert.NoError(t, err) // No error, but token won't be added
|
|
|
|
// Verify it was not added
|
|
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
|
|
assert.NoError(t, err)
|
|
assert.False(t, isBlacklisted)
|
|
})
|
|
}
|
|
|
|
func TestTokenRepository_CreateAndGet(t *testing.T) {
|
|
mr, r := setupMiniRedis()
|
|
defer mr.Close()
|
|
|
|
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
|
|
ctx := context.Background()
|
|
|
|
t.Run("Create and GetAccessTokenByID", func(t *testing.T) {
|
|
now := time.Now()
|
|
token := entity.Token{
|
|
ID: "test-token-123",
|
|
UID: "user123",
|
|
DeviceID: "device123",
|
|
AccessToken: "access-token-123",
|
|
ExpiresIn: 3600,
|
|
AccessCreateAt: now,
|
|
RefreshToken: "refresh-token-123",
|
|
RefreshCreateAt: now,
|
|
RefreshExpiresIn: 86400,
|
|
}
|
|
|
|
// Create token
|
|
err := repo.Create(ctx, token)
|
|
assert.NoError(t, err)
|
|
|
|
// Get token by ID
|
|
retrievedToken, err := repo.GetAccessTokenByID(ctx, token.ID)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, token.ID, retrievedToken.ID)
|
|
assert.Equal(t, token.UID, retrievedToken.UID)
|
|
assert.Equal(t, token.AccessToken, retrievedToken.AccessToken)
|
|
})
|
|
|
|
t.Run("GetAccessTokensByUID", func(t *testing.T) {
|
|
uid := "user456"
|
|
now := time.Now()
|
|
tokens := []entity.Token{
|
|
{
|
|
ID: "token-1",
|
|
UID: uid,
|
|
DeviceID: "device1",
|
|
AccessToken: "access-1",
|
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
|
},
|
|
{
|
|
ID: "token-2",
|
|
UID: uid,
|
|
DeviceID: "device2",
|
|
AccessToken: "access-2",
|
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
|
},
|
|
}
|
|
|
|
// Create tokens
|
|
for _, token := range tokens {
|
|
err := repo.Create(ctx, token)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Get tokens by UID
|
|
retrievedTokens, err := repo.GetAccessTokensByUID(ctx, uid)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, retrievedTokens, 2)
|
|
|
|
// Verify all tokens belong to the user
|
|
for _, token := range retrievedTokens {
|
|
assert.Equal(t, uid, token.UID)
|
|
}
|
|
})
|
|
|
|
t.Run("GetAccessTokenCountByUID", func(t *testing.T) {
|
|
uid := "user789"
|
|
now := time.Now()
|
|
|
|
// Create multiple tokens for the user
|
|
for i := 0; i < 3; i++ {
|
|
token := entity.Token{
|
|
ID: "count-token-" + string(rune(i+'1')),
|
|
UID: uid,
|
|
DeviceID: "device" + string(rune(i+'1')),
|
|
AccessToken: "access-" + string(rune(i+'1')),
|
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
|
}
|
|
err := repo.Create(ctx, token)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Get count
|
|
count, err := repo.GetAccessTokenCountByUID(ctx, uid)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 3, count)
|
|
})
|
|
|
|
t.Run("Delete", func(t *testing.T) {
|
|
token := entity.Token{
|
|
ID: "delete-token",
|
|
UID: "delete-user",
|
|
DeviceID: "delete-device",
|
|
AccessToken: "delete-access",
|
|
RefreshToken: "delete-refresh",
|
|
ExpiresIn: 3600,
|
|
}
|
|
|
|
// Create token
|
|
err := repo.Create(ctx, token)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify it exists
|
|
_, err = repo.GetAccessTokenByID(ctx, token.ID)
|
|
assert.NoError(t, err)
|
|
|
|
// Delete token
|
|
err = repo.Delete(ctx, token)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify it's gone
|
|
_, err = repo.GetAccessTokenByID(ctx, token.ID)
|
|
assert.Error(t, err) // Should return error when not found
|
|
})
|
|
|
|
t.Run("DeleteAccessTokensByUID", func(t *testing.T) {
|
|
uid := "delete-user-uid"
|
|
now := time.Now()
|
|
|
|
// Create multiple tokens for the user
|
|
for i := 0; i < 2; i++ {
|
|
token := entity.Token{
|
|
ID: "delete-uid-token-" + string(rune(i+'1')),
|
|
UID: uid,
|
|
DeviceID: "device" + string(rune(i+'1')),
|
|
AccessToken: "access-" + string(rune(i+'1')),
|
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
|
}
|
|
err := repo.Create(ctx, token)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Verify tokens exist
|
|
count, err := repo.GetAccessTokenCountByUID(ctx, uid)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, count)
|
|
|
|
// Delete all tokens for the user
|
|
err = repo.DeleteAccessTokensByUID(ctx, uid)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify tokens are gone
|
|
count, err = repo.GetAccessTokenCountByUID(ctx, uid)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 0, count)
|
|
})
|
|
}
|
|
|
|
func TestTokenRepository_OneTimeToken(t *testing.T) {
|
|
mr, r := setupMiniRedis()
|
|
defer mr.Close()
|
|
|
|
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
|
|
ctx := context.Background()
|
|
|
|
t.Run("CreateOneTimeToken", func(t *testing.T) {
|
|
now := time.Now()
|
|
// Create one-time token with ticket
|
|
token := entity.Token{
|
|
ID: "one-time-base-token",
|
|
UID: "user123",
|
|
AccessToken: "base-access-token",
|
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
|
}
|
|
|
|
oneTimeKey := "one-time-key-123"
|
|
ticket := entity.Ticket{
|
|
Data: map[string]string{"uid": "user123"},
|
|
Token: token,
|
|
}
|
|
|
|
err := repo.CreateOneTimeToken(ctx, oneTimeKey, ticket, time.Minute)
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
|
|
t.Run("DeleteOneTimeToken", func(t *testing.T) {
|
|
// Create one-time tokens
|
|
keys := []string{"delete-key-1", "delete-key-2"}
|
|
ticket := entity.Ticket{
|
|
Data: map[string]string{"test": "data"},
|
|
Token: entity.Token{ID: "test-token"},
|
|
}
|
|
|
|
for _, key := range keys {
|
|
err := repo.CreateOneTimeToken(ctx, key, ticket, time.Minute)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Delete one-time tokens
|
|
err := repo.DeleteOneTimeToken(ctx, keys, nil)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify they're gone
|
|
for _, key := range keys {
|
|
_, err := repo.GetAccessTokenByOneTimeToken(ctx, key)
|
|
assert.Error(t, err)
|
|
}
|
|
})
|
|
}
|