backend/pkg/permission/repository/token_blacklist_test.go

383 lines
10 KiB
Go
Raw Normal View History

2025-10-06 08:28:39 +00:00
package repository
import (
"context"
"testing"
"time"
"backend/pkg/permission/domain/entity"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// setupMiniRedis starts an in-memory miniredis server to stand in for a real
// Redis and returns it together with a go-zero Redis client that talks to it
// in single-node ("node") mode.
//
// It panics if the miniredis server cannot be started. The caller owns the
// returned server and should Close it when the test is done.
func setupMiniRedis() (*miniredis.Miniredis, *redis.Redis) {
	server, runErr := miniredis.Run()
	if runErr != nil {
		panic("failed to start miniRedis: " + runErr.Error())
	}

	// Point the go-zero client at the miniredis listener address.
	client := redis.MustNewRedis(redis.RedisConf{
		Type: "node",
		Host: server.Addr(),
	})

	return server, client
}
// TestTokenRepository_Blacklist exercises the JTI blacklist operations of
// TokenRepository (AddToBlacklist, IsBlacklisted, RemoveFromBlacklist and
// GetBlacklistedTokensByUID) against an in-memory miniredis instance.
//
// All subtests share one Redis instance and run in order, so each one uses
// its own JTI/UID values to avoid interfering with the others.
func TestTokenRepository_Blacklist(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
	ctx := context.Background()

	t.Run("AddToBlacklist", func(t *testing.T) {
		now := time.Now()
		entry := &entity.BlacklistEntry{
			JTI:       "test-jti-123",
			UID:       "user123",
			TokenID:   "token123",
			Reason:    "user logout",
			ExpiresAt: now.Add(time.Hour).Unix(),
			CreatedAt: now.Unix(),
		}

		err := repo.AddToBlacklist(ctx, entry, time.Hour)
		assert.NoError(t, err)

		// Verify it was added.
		isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.True(t, isBlacklisted)
	})

	t.Run("IsBlacklisted - not found", func(t *testing.T) {
		isBlacklisted, err := repo.IsBlacklisted(ctx, "non-existent-jti")
		assert.NoError(t, err)
		assert.False(t, isBlacklisted)
	})

	t.Run("RemoveFromBlacklist", func(t *testing.T) {
		now := time.Now()
		entry := &entity.BlacklistEntry{
			JTI:       "test-jti-456",
			UID:       "user456",
			TokenID:   "token456",
			ExpiresAt: now.Add(time.Hour).Unix(),
			CreatedAt: now.Unix(),
		}

		err := repo.AddToBlacklist(ctx, entry, time.Hour)
		assert.NoError(t, err)

		// Precondition: the entry must be present before removal, otherwise
		// the final "not blacklisted" check would pass vacuously.
		isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.True(t, isBlacklisted)

		err = repo.RemoveFromBlacklist(ctx, entry.JTI)
		assert.NoError(t, err)

		// Verify it's gone.
		isBlacklisted, err = repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.False(t, isBlacklisted)
	})

	t.Run("GetBlacklistedTokensByUID", func(t *testing.T) {
		uid := "user789"
		now := time.Now()

		// Two entries for uid plus one for another user that must be filtered out.
		entries := []*entity.BlacklistEntry{
			{
				JTI:       "jti-1",
				UID:       uid,
				TokenID:   "token-1",
				ExpiresAt: now.Add(time.Hour).Unix(),
				CreatedAt: now.Unix(),
			},
			{
				JTI:       "jti-2",
				UID:       uid,
				TokenID:   "token-2",
				ExpiresAt: now.Add(time.Hour).Unix(),
				CreatedAt: now.Unix(),
			},
			{
				JTI:       "jti-3",
				UID:       "different-user",
				TokenID:   "token-3",
				ExpiresAt: now.Add(time.Hour).Unix(),
				CreatedAt: now.Unix(),
			},
		}
		for _, entry := range entries {
			err := repo.AddToBlacklist(ctx, entry, time.Hour)
			assert.NoError(t, err)
		}

		userEntries, err := repo.GetBlacklistedTokensByUID(ctx, uid)
		// On error the returned entries are not meaningful; stop here.
		if !assert.NoError(t, err) {
			return
		}
		assert.Len(t, userEntries, 2) // Should only get entries for the specific user

		jtis := make([]string, 0, len(userEntries))
		for _, entry := range userEntries {
			assert.Equal(t, uid, entry.UID)
			jtis = append(jtis, entry.JTI)
		}
		// A count check alone would also accept duplicates or wrong entries
		// (e.g. jti-1 twice); pin the exact set of JTIs, order-insensitively.
		assert.ElementsMatch(t, []string{"jti-1", "jti-2"}, jtis)
	})

	t.Run("AddToBlacklist with zero TTL", func(t *testing.T) {
		now := time.Now()
		entry := &entity.BlacklistEntry{
			JTI:       "test-jti-zero-ttl",
			UID:       "user-zero-ttl",
			ExpiresAt: now.Add(time.Hour).Unix(),
			CreatedAt: now.Unix(),
		}

		// With a zero TTL the repository is expected to derive the TTL from
		// ExpiresAt. Only the entry's presence is asserted here, not the TTL value.
		err := repo.AddToBlacklist(ctx, entry, 0)
		assert.NoError(t, err)

		isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.True(t, isBlacklisted)
	})

	t.Run("AddToBlacklist with expired token", func(t *testing.T) {
		now := time.Now()
		entry := &entity.BlacklistEntry{
			JTI:       "test-jti-expired",
			UID:       "user-expired",
			ExpiresAt: now.Add(-time.Hour).Unix(), // Already expired
			CreatedAt: now.Unix(),
		}

		// An already-expired token should be accepted without error but not stored.
		err := repo.AddToBlacklist(ctx, entry, 0)
		assert.NoError(t, err)

		isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.False(t, isBlacklisted)
	})
}
// TestTokenRepository_CreateAndGet exercises the access-token lifecycle of
// TokenRepository against an in-memory miniredis instance: creating tokens,
// reading them back by ID and by UID, counting them per UID, and deleting
// them individually or per UID.
//
// All subtests share one Redis instance, so each uses distinct IDs/UIDs.
//
// NOTE(review): ExpiresIn/RefreshExpiresIn are small relative values
// (3600, 86400) in some subtests and absolute Unix timestamps
// (now.Add(...).Unix()) in others. These look like different units. Confirm
// which one Create expects and make the fixtures consistent.
func TestTokenRepository_CreateAndGet(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
	ctx := context.Background()

	t.Run("Create and GetAccessTokenByID", func(t *testing.T) {
		now := time.Now()
		token := entity.Token{
			ID:               "test-token-123",
			UID:              "user123",
			DeviceID:         "device123",
			AccessToken:      "access-token-123",
			ExpiresIn:        3600,
			AccessCreateAt:   now,
			RefreshToken:     "refresh-token-123",
			RefreshCreateAt:  now,
			RefreshExpiresIn: 86400,
		}

		if !assert.NoError(t, repo.Create(ctx, token)) {
			return
		}

		retrievedToken, err := repo.GetAccessTokenByID(ctx, token.ID)
		// assert.NoError does not stop the test. Return explicitly so the
		// field reads below never run on a result that came with an error.
		if !assert.NoError(t, err) {
			return
		}
		assert.Equal(t, token.ID, retrievedToken.ID)
		assert.Equal(t, token.UID, retrievedToken.UID)
		assert.Equal(t, token.AccessToken, retrievedToken.AccessToken)
	})

	t.Run("GetAccessTokensByUID", func(t *testing.T) {
		uid := "user456"
		now := time.Now()
		tokens := []entity.Token{
			{
				ID:               "token-1",
				UID:              uid,
				DeviceID:         "device1",
				AccessToken:      "access-1",
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			},
			{
				ID:               "token-2",
				UID:              uid,
				DeviceID:         "device2",
				AccessToken:      "access-2",
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			},
		}
		for _, token := range tokens {
			err := repo.Create(ctx, token)
			assert.NoError(t, err)
		}

		retrievedTokens, err := repo.GetAccessTokensByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Len(t, retrievedTokens, 2)

		// Every returned token must belong to the requested user.
		for _, token := range retrievedTokens {
			assert.Equal(t, uid, token.UID)
		}
	})

	t.Run("GetAccessTokenCountByUID", func(t *testing.T) {
		uid := "user789"
		now := time.Now()

		// Explicit suffixes instead of rune arithmetic (string(rune(i+'1'))),
		// which stops producing digits once there are more than 9 items.
		for _, suffix := range []string{"1", "2", "3"} {
			token := entity.Token{
				ID:               "count-token-" + suffix,
				UID:              uid,
				DeviceID:         "device" + suffix,
				AccessToken:      "access-" + suffix,
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			}
			err := repo.Create(ctx, token)
			assert.NoError(t, err)
		}

		count, err := repo.GetAccessTokenCountByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Equal(t, 3, count)
	})

	t.Run("Delete", func(t *testing.T) {
		token := entity.Token{
			ID:           "delete-token",
			UID:          "delete-user",
			DeviceID:     "delete-device",
			AccessToken:  "delete-access",
			RefreshToken: "delete-refresh",
			ExpiresIn:    3600,
		}

		err := repo.Create(ctx, token)
		assert.NoError(t, err)

		// Precondition: the token must exist, otherwise the final
		// "not found" check would pass vacuously.
		_, err = repo.GetAccessTokenByID(ctx, token.ID)
		assert.NoError(t, err)

		err = repo.Delete(ctx, token)
		assert.NoError(t, err)

		_, err = repo.GetAccessTokenByID(ctx, token.ID)
		assert.Error(t, err) // Should return error when not found
	})

	t.Run("DeleteAccessTokensByUID", func(t *testing.T) {
		uid := "delete-user-uid"
		now := time.Now()

		for _, suffix := range []string{"1", "2"} {
			token := entity.Token{
				ID:               "delete-uid-token-" + suffix,
				UID:              uid,
				DeviceID:         "device" + suffix,
				AccessToken:      "access-" + suffix,
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			}
			err := repo.Create(ctx, token)
			assert.NoError(t, err)
		}

		// Precondition: both tokens are stored before the bulk delete.
		count, err := repo.GetAccessTokenCountByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Equal(t, 2, count)

		err = repo.DeleteAccessTokensByUID(ctx, uid)
		assert.NoError(t, err)

		count, err = repo.GetAccessTokenCountByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Equal(t, 0, count)
	})
}
// TestTokenRepository_OneTimeToken covers creating and deleting one-time
// tokens in TokenRepository. Each one-time token stores an entity.Ticket under
// a caller-chosen key, with a one-minute TTL in these tests.
func TestTokenRepository_OneTimeToken(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	ctx := context.Background()
	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}

	t.Run("CreateOneTimeToken", func(t *testing.T) {
		now := time.Now()

		// The ticket wraps a base token plus arbitrary string data.
		ticket := entity.Ticket{
			Data: map[string]string{"uid": "user123"},
			Token: entity.Token{
				ID:               "one-time-base-token",
				UID:              "user123",
				AccessToken:      "base-access-token",
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			},
		}

		assert.NoError(t, repo.CreateOneTimeToken(ctx, "one-time-key-123", ticket, time.Minute))
	})

	t.Run("DeleteOneTimeToken", func(t *testing.T) {
		keys := []string{"delete-key-1", "delete-key-2"}
		ticket := entity.Ticket{
			Data:  map[string]string{"test": "data"},
			Token: entity.Token{ID: "test-token"},
		}

		for i := range keys {
			assert.NoError(t, repo.CreateOneTimeToken(ctx, keys[i], ticket, time.Minute))
		}

		assert.NoError(t, repo.DeleteOneTimeToken(ctx, keys, nil))

		// Every deleted key should now fail to resolve.
		// NOTE(review): this test never asserts that the lookup succeeds
		// *before* deletion, so the check would also pass if the lookup always
		// failed for these keys — consider adding that precondition.
		for i := range keys {
			_, lookupErr := repo.GetAccessTokenByOneTimeToken(ctx, keys[i])
			assert.Error(t, lookupErr)
		}
	})
}