package repository

import (
	"context"
	"strconv"
	"testing"
	"time"

	"backend/pkg/permission/domain/entity"
	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
	"github.com/zeromicro/go-zero/core/stores/redis"
)

// setupMiniRedis starts an in-process miniredis server to stand in for Redis
// and returns it together with a go-zero Redis client configured against its
// address. It panics if the server cannot be started (redis.MustNewRedis may
// also panic). Callers are responsible for closing the returned server.
func setupMiniRedis() (*miniredis.Miniredis, *redis.Redis) {
	// Start miniredis as the mock Redis service.
	mr, err := miniredis.Run()
	if err != nil {
		panic("failed to start miniRedis: " + err.Error())
	}

	// Point a go-zero Redis client (single-node mode) at the miniredis address.
	redisConf := redis.RedisConf{
		Host: mr.Addr(),
		Type: "node",
	}
	r := redis.MustNewRedis(redisConf)

	return mr, r
}

// TestTokenRepository_Blacklist exercises the blacklist operations of
// TokenRepository: adding, checking, removing, listing by UID, and the
// zero-TTL / already-expired cases of AddToBlacklist. All subtests share one
// repository and miniredis instance, so each uses distinct JTIs and UIDs.
func TestTokenRepository_Blacklist(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
	ctx := context.Background()

	t.Run("AddToBlacklist", func(t *testing.T) {
		entry := &entity.BlacklistEntry{
			JTI:       "test-jti-123",
			UID:       "user123",
			TokenID:   "token123",
			Reason:    "user logout",
			ExpiresAt: time.Now().Add(time.Hour).Unix(),
			CreatedAt: time.Now().Unix(),
		}

		err := repo.AddToBlacklist(ctx, entry, time.Hour)
		assert.NoError(t, err)

		// Verify it was added
		isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.True(t, isBlacklisted)
	})

	t.Run("IsBlacklisted - not found", func(t *testing.T) {
		isBlacklisted, err := repo.IsBlacklisted(ctx, "non-existent-jti")
		assert.NoError(t, err)
		assert.False(t, isBlacklisted)
	})

	t.Run("RemoveFromBlacklist", func(t *testing.T) {
		// First add an entry
		entry := &entity.BlacklistEntry{
			JTI:       "test-jti-456",
			UID:       "user456",
			TokenID:   "token456",
			ExpiresAt: time.Now().Add(time.Hour).Unix(),
			CreatedAt: time.Now().Unix(),
		}

		err := repo.AddToBlacklist(ctx, entry, time.Hour)
		assert.NoError(t, err)

		// Verify it exists
		isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.True(t, isBlacklisted)

		// Remove it
		err = repo.RemoveFromBlacklist(ctx, entry.JTI)
		assert.NoError(t, err)

		// Verify it's gone
		isBlacklisted, err = repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.False(t, isBlacklisted)
	})

	t.Run("GetBlacklistedTokensByUID", func(t *testing.T) {
		uid := "user789"

		// Add multiple entries for the same user, plus one for a different
		// user that must not be returned.
		entries := []*entity.BlacklistEntry{
			{
				JTI:       "jti-1",
				UID:       uid,
				TokenID:   "token-1",
				ExpiresAt: time.Now().Add(time.Hour).Unix(),
				CreatedAt: time.Now().Unix(),
			},
			{
				JTI:       "jti-2",
				UID:       uid,
				TokenID:   "token-2",
				ExpiresAt: time.Now().Add(time.Hour).Unix(),
				CreatedAt: time.Now().Unix(),
			},
			{
				JTI:       "jti-3",
				UID:       "different-user",
				TokenID:   "token-3",
				ExpiresAt: time.Now().Add(time.Hour).Unix(),
				CreatedAt: time.Now().Unix(),
			},
		}

		for _, entry := range entries {
			err := repo.AddToBlacklist(ctx, entry, time.Hour)
			assert.NoError(t, err)
		}

		// Get blacklisted tokens for the user
		userEntries, err := repo.GetBlacklistedTokensByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Len(t, userEntries, 2) // Should only get entries for the specific user

		// Verify all returned entries belong to the correct user
		for _, entry := range userEntries {
			assert.Equal(t, uid, entry.UID)
		}
	})

	t.Run("AddToBlacklist with zero TTL", func(t *testing.T) {
		entry := &entity.BlacklistEntry{
			JTI:       "test-jti-zero-ttl",
			UID:       "user-zero-ttl",
			ExpiresAt: time.Now().Add(time.Hour).Unix(),
			CreatedAt: time.Now().Unix(),
		}

		// Test with zero TTL - should calculate from ExpiresAt
		err := repo.AddToBlacklist(ctx, entry, 0)
		assert.NoError(t, err)

		// Verify it was added
		isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.True(t, isBlacklisted)
	})

	t.Run("AddToBlacklist with expired token", func(t *testing.T) {
		entry := &entity.BlacklistEntry{
			JTI:       "test-jti-expired",
			UID:       "user-expired",
			ExpiresAt: time.Now().Add(-time.Hour).Unix(), // Already expired
			CreatedAt: time.Now().Unix(),
		}

		// Should not add expired token to blacklist
		err := repo.AddToBlacklist(ctx, entry, 0)
		assert.NoError(t, err) // No error, but token won't be added

		// Verify it was not added
		isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
		assert.NoError(t, err)
		assert.False(t, isBlacklisted)
	})
}

// TestTokenRepository_CreateAndGet exercises token creation and lookup on
// TokenRepository: by ID, by UID, counting by UID, deleting a single token,
// and deleting every token of a UID. Subtests share one repository and
// miniredis instance.
func TestTokenRepository_CreateAndGet(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
	ctx := context.Background()

	t.Run("Create and GetAccessTokenByID", func(t *testing.T) {
		now := time.Now()
		token := entity.Token{
			ID:               "test-token-123",
			UID:              "user123",
			DeviceID:         "device123",
			AccessToken:      "access-token-123",
			ExpiresIn:        3600,
			AccessCreateAt:   now,
			RefreshToken:     "refresh-token-123",
			RefreshCreateAt:  now,
			RefreshExpiresIn: 86400,
		}

		// Create token
		err := repo.Create(ctx, token)
		assert.NoError(t, err)

		// Get token by ID
		retrievedToken, err := repo.GetAccessTokenByID(ctx, token.ID)
		if !assert.NoError(t, err) {
			// The returned token is not meaningful on error (and may be nil);
			// stop instead of dereferencing it below.
			return
		}
		assert.Equal(t, token.ID, retrievedToken.ID)
		assert.Equal(t, token.UID, retrievedToken.UID)
		assert.Equal(t, token.AccessToken, retrievedToken.AccessToken)
	})

	t.Run("GetAccessTokensByUID", func(t *testing.T) {
		uid := "user456"
		now := time.Now()

		// NOTE(review): ExpiresIn/RefreshExpiresIn are absolute Unix timestamps
		// here, while "Create and GetAccessTokenByID" uses relative seconds
		// (3600/86400). Confirm which unit TokenRepository.Create expects.
		tokens := []entity.Token{
			{
				ID:               "token-1",
				UID:              uid,
				DeviceID:         "device1",
				AccessToken:      "access-1",
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			},
			{
				ID:               "token-2",
				UID:              uid,
				DeviceID:         "device2",
				AccessToken:      "access-2",
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			},
		}

		// Create tokens
		for _, token := range tokens {
			err := repo.Create(ctx, token)
			assert.NoError(t, err)
		}

		// Get tokens by UID
		retrievedTokens, err := repo.GetAccessTokensByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Len(t, retrievedTokens, 2)

		// Verify all tokens belong to the user
		for _, token := range retrievedTokens {
			assert.Equal(t, uid, token.UID)
		}
	})

	t.Run("GetAccessTokenCountByUID", func(t *testing.T) {
		uid := "user789"
		now := time.Now()

		// Create multiple tokens for the user (suffixes "1".."3").
		for i := 1; i <= 3; i++ {
			suffix := strconv.Itoa(i)
			token := entity.Token{
				ID:               "count-token-" + suffix,
				UID:              uid,
				DeviceID:         "device" + suffix,
				AccessToken:      "access-" + suffix,
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			}
			err := repo.Create(ctx, token)
			assert.NoError(t, err)
		}

		// Get count
		count, err := repo.GetAccessTokenCountByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Equal(t, 3, count)
	})

	t.Run("Delete", func(t *testing.T) {
		token := entity.Token{
			ID:           "delete-token",
			UID:          "delete-user",
			DeviceID:     "delete-device",
			AccessToken:  "delete-access",
			RefreshToken: "delete-refresh",
			ExpiresIn:    3600,
		}

		// Create token
		err := repo.Create(ctx, token)
		assert.NoError(t, err)

		// Verify it exists
		_, err = repo.GetAccessTokenByID(ctx, token.ID)
		assert.NoError(t, err)

		// Delete token
		err = repo.Delete(ctx, token)
		assert.NoError(t, err)

		// Verify it's gone
		_, err = repo.GetAccessTokenByID(ctx, token.ID)
		assert.Error(t, err) // Should return error when not found
	})

	t.Run("DeleteAccessTokensByUID", func(t *testing.T) {
		uid := "delete-user-uid"
		now := time.Now()

		// Create multiple tokens for the user (suffixes "1".."2").
		for i := 1; i <= 2; i++ {
			suffix := strconv.Itoa(i)
			token := entity.Token{
				ID:               "delete-uid-token-" + suffix,
				UID:              uid,
				DeviceID:         "device" + suffix,
				AccessToken:      "access-" + suffix,
				ExpiresIn:        int(now.Add(time.Hour).Unix()),
				RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
			}
			err := repo.Create(ctx, token)
			assert.NoError(t, err)
		}

		// Verify tokens exist
		count, err := repo.GetAccessTokenCountByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Equal(t, 2, count)

		// Delete all tokens for the user
		err = repo.DeleteAccessTokensByUID(ctx, uid)
		assert.NoError(t, err)

		// Verify tokens are gone
		count, err = repo.GetAccessTokenCountByUID(ctx, uid)
		assert.NoError(t, err)
		assert.Equal(t, 0, count)
	})
}

// TestTokenRepository_OneTimeToken exercises creating one-time tokens (keys
// mapped to a Ticket) and deleting them, after which lookups by those keys
// are expected to fail.
func TestTokenRepository_OneTimeToken(t *testing.T) {
	mr, r := setupMiniRedis()
	defer mr.Close()

	repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
	ctx := context.Background()

	t.Run("CreateOneTimeToken", func(t *testing.T) {
		now := time.Now()

		// Create one-time token with ticket
		token := entity.Token{
			ID:               "one-time-base-token",
			UID:              "user123",
			AccessToken:      "base-access-token",
			ExpiresIn:        int(now.Add(time.Hour).Unix()),
			RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
		}

		oneTimeKey := "one-time-key-123"
		ticket := entity.Ticket{
			Data:  map[string]string{"uid": "user123"},
			Token: token,
		}

		err := repo.CreateOneTimeToken(ctx, oneTimeKey, ticket, time.Minute)
		assert.NoError(t, err)
	})

	t.Run("DeleteOneTimeToken", func(t *testing.T) {
		// Create one-time tokens
		keys := []string{"delete-key-1", "delete-key-2"}
		ticket := entity.Ticket{
			Data:  map[string]string{"test": "data"},
			Token: entity.Token{ID: "test-token"},
		}

		for _, key := range keys {
			err := repo.CreateOneTimeToken(ctx, key, ticket, time.Minute)
			assert.NoError(t, err)
		}

		// Delete one-time tokens
		err := repo.DeleteOneTimeToken(ctx, keys, nil)
		assert.NoError(t, err)

		// Verify they're gone
		for _, key := range keys {
			_, err := repo.GetAccessTokenByOneTimeToken(ctx, key)
			assert.Error(t, err)
		}
	})
}