1528 lines
39 KiB
Go
1528 lines
39 KiB
Go
package usecase
|
|
|
|
import (
|
|
"backend/internal/config"
|
|
"backend/pkg/permission/domain/entity"
|
|
"backend/pkg/permission/domain/token"
|
|
mockRepo "backend/pkg/permission/mock/repository"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/stretchr/testify/assert"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
func TestTokenUseCase_NewToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.AuthorizationReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功創建 Access Token",
|
|
req: entity.AuthorizationReq{
|
|
GrantType: "client_credentials",
|
|
Scope: "read",
|
|
DeviceID: "device123",
|
|
Account: "user123",
|
|
Role: "admin",
|
|
Data: map[string]string{
|
|
"uid": "user123",
|
|
},
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "成功創建 Refresh Token",
|
|
req: entity.AuthorizationReq{
|
|
GrantType: "client_credentials",
|
|
Scope: "read",
|
|
DeviceID: "device123",
|
|
Account: "user123",
|
|
Role: "admin",
|
|
IsRefreshToken: true,
|
|
Data: map[string]string{
|
|
"uid": "user123",
|
|
},
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "創建失敗 - Repository 錯誤",
|
|
req: entity.AuthorizationReq{
|
|
GrantType: "client_credentials",
|
|
Scope: "read",
|
|
Account: "user123",
|
|
Role: "admin",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(errors.New("db error"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
resp, err := uc.NewToken(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, resp.AccessToken)
|
|
assert.Equal(t, token.TypeBearer.String(), resp.TokenType)
|
|
assert.Greater(t, resp.ExpiresIn, int64(0))
|
|
if tt.req.IsRefreshToken {
|
|
assert.NotEmpty(t, resp.RefreshToken)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_RefreshToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
expires := now.Add(time.Hour).Unix()
|
|
|
|
// 創建一個測試用的 token
|
|
testToken := entity.Token{
|
|
ID: "test-id",
|
|
UID: "user123",
|
|
DeviceID: "device123",
|
|
AccessToken: "",
|
|
RefreshToken: "refresh-token",
|
|
ExpiresIn: int(expires),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
AccessCreateAt: now,
|
|
RefreshCreateAt: now,
|
|
}
|
|
|
|
// 生成實際的 JWT token
|
|
claims := entity.Claims{
|
|
Data: map[string]string{
|
|
"id": testToken.ID,
|
|
"uid": "user123",
|
|
"role": "admin",
|
|
"scope": "read",
|
|
"account": "user123",
|
|
"deviceId": "device123",
|
|
},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: testToken.ID,
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
|
|
Issuer: "permission",
|
|
},
|
|
}
|
|
accessToken, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
|
|
testToken.AccessToken = accessToken
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.RefreshTokenReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功刷新 Token",
|
|
req: entity.RefreshTokenReq{
|
|
Token: "refresh-token",
|
|
DeviceID: "device123",
|
|
Expires: 0,
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByOneTimeToken(ctx, "refresh-token").
|
|
Return(testToken, nil)
|
|
mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(nil)
|
|
mockTokenRepo.EXPECT().Delete(ctx, testToken).Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "刷新失敗 - Token 不存在",
|
|
req: entity.RefreshTokenReq{
|
|
Token: "invalid-token",
|
|
DeviceID: "device123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByOneTimeToken(ctx, "invalid-token").
|
|
Return(entity.Token{}, errors.New("token not found"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "刷新失敗 - 創建新 Token 失敗",
|
|
req: entity.RefreshTokenReq{
|
|
Token: "refresh-token",
|
|
DeviceID: "device123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByOneTimeToken(ctx, "refresh-token").
|
|
Return(testToken, nil)
|
|
mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(errors.New("create error"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
resp, err := uc.RefreshToken(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, resp.Token)
|
|
assert.NotEmpty(t, resp.OneTimeToken)
|
|
assert.Greater(t, resp.ExpiresIn, int64(0))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_ValidationToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
expires := now.Add(time.Hour).Unix()
|
|
|
|
testToken := entity.Token{
|
|
ID: "test-id",
|
|
UID: "user123",
|
|
DeviceID: "device123",
|
|
ExpiresIn: int(expires),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
AccessCreateAt: now,
|
|
RefreshCreateAt: now,
|
|
}
|
|
|
|
claims := entity.Claims{
|
|
Data: map[string]string{
|
|
"id": testToken.ID,
|
|
"uid": "user123",
|
|
"role": "admin",
|
|
"scope": "read",
|
|
"account": "user123",
|
|
},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: testToken.ID,
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
|
|
Issuer: "permission",
|
|
},
|
|
}
|
|
validToken, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
|
|
testToken.AccessToken = validToken
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.ValidationTokenReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功驗證 Token",
|
|
req: entity.ValidationTokenReq{
|
|
Token: validToken,
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByID(ctx, "test-id").
|
|
Return(testToken, nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "驗證失敗 - Token 無效",
|
|
req: entity.ValidationTokenReq{
|
|
Token: "invalid-token",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
// parseClaims will fail, no repo call
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "驗證失敗 - Token 不存在",
|
|
req: entity.ValidationTokenReq{
|
|
Token: validToken,
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByID(ctx, "test-id").
|
|
Return(entity.Token{}, errors.New("token not found"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
resp, err := uc.ValidationToken(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, testToken.ID, resp.Token.ID)
|
|
assert.Equal(t, testToken.UID, resp.Token.UID)
|
|
assert.NotNil(t, resp.Data)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_CancelToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
expires := now.Add(time.Hour).Unix()
|
|
|
|
testToken := entity.Token{
|
|
ID: "test-id",
|
|
UID: "user123",
|
|
DeviceID: "device123",
|
|
ExpiresIn: int(expires),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
AccessCreateAt: now,
|
|
RefreshCreateAt: now,
|
|
}
|
|
|
|
claims := entity.Claims{
|
|
Data: map[string]string{
|
|
"id": testToken.ID,
|
|
"uid": "user123",
|
|
"role": "admin",
|
|
"scope": "read",
|
|
"account": "user123",
|
|
},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: testToken.ID,
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
|
|
Issuer: "permission",
|
|
},
|
|
}
|
|
validToken, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
|
|
testToken.AccessToken = validToken
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.CancelTokenReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功取消 Token",
|
|
req: entity.CancelTokenReq{
|
|
Token: validToken,
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByID(ctx, "test-id").
|
|
Return(testToken, nil)
|
|
mockTokenRepo.EXPECT().Delete(ctx, testToken).Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "取消失敗 - Token 無效",
|
|
req: entity.CancelTokenReq{
|
|
Token: "invalid-token",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
// parseClaims will fail, no repo call
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "取消失敗 - Token 不存在",
|
|
req: entity.CancelTokenReq{
|
|
Token: validToken,
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByID(ctx, "test-id").
|
|
Return(entity.Token{}, errors.New("token not found"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
err := uc.CancelToken(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_CancelTokens(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.DoTokenByUIDReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功取消 UID 的所有 Token",
|
|
req: entity.DoTokenByUIDReq{
|
|
UID: "user123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokensByUID(ctx, "user123").
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "成功取消指定 ID 的 Token",
|
|
req: entity.DoTokenByUIDReq{
|
|
IDs: []string{"token1", "token2"},
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokenByID(ctx, []string{"token1", "token2"}).
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "成功取消 UID 和 ID 的 Token",
|
|
req: entity.DoTokenByUIDReq{
|
|
UID: "user123",
|
|
IDs: []string{"token1"},
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokensByUID(ctx, "user123").
|
|
Return(nil)
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokenByID(ctx, []string{"token1"}).
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "取消失敗 - UID 刪除錯誤",
|
|
req: entity.DoTokenByUIDReq{
|
|
UID: "user123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokensByUID(ctx, "user123").
|
|
Return(errors.New("delete error"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
err := uc.CancelTokens(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_CancelTokenByDeviceID(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.DoTokenByDeviceIDReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功取消 DeviceID 的所有 Token",
|
|
req: entity.DoTokenByDeviceIDReq{
|
|
DeviceID: "device123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokensByDeviceID(ctx, "device123").
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "取消失敗 - 刪除錯誤",
|
|
req: entity.DoTokenByDeviceIDReq{
|
|
DeviceID: "device123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokensByDeviceID(ctx, "device123").
|
|
Return(errors.New("delete error"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
err := uc.CancelTokenByDeviceID(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_GetUserTokensByUID(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
|
|
testTokens := []entity.Token{
|
|
{
|
|
ID: "token1",
|
|
UID: "user123",
|
|
AccessToken: "access1",
|
|
RefreshToken: "refresh1",
|
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
},
|
|
{
|
|
ID: "token2",
|
|
UID: "user123",
|
|
AccessToken: "access2",
|
|
RefreshToken: "refresh2",
|
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.QueryTokenByUIDReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantCount int
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功獲取 UID 的所有 Token",
|
|
req: entity.QueryTokenByUIDReq{
|
|
UID: "user123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokensByUID(ctx, "user123").
|
|
Return(testTokens, nil)
|
|
},
|
|
wantCount: 2,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "成功獲取但沒有 Token",
|
|
req: entity.QueryTokenByUIDReq{
|
|
UID: "user999",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokensByUID(ctx, "user999").
|
|
Return([]entity.Token{}, nil)
|
|
},
|
|
wantCount: 0,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "獲取失敗 - Repository 錯誤",
|
|
req: entity.QueryTokenByUIDReq{
|
|
UID: "user123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokensByUID(ctx, "user123").
|
|
Return(nil, errors.New("db error"))
|
|
},
|
|
wantCount: 0,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
resp, err := uc.GetUserTokensByUID(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Len(t, resp, tt.wantCount)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_GetUserTokensByDeviceID(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
|
|
testTokens := []entity.Token{
|
|
{
|
|
ID: "token1",
|
|
UID: "user123",
|
|
DeviceID: "device123",
|
|
AccessToken: "access1",
|
|
RefreshToken: "refresh1",
|
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.DoTokenByDeviceIDReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantCount int
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功獲取 DeviceID 的所有 Token",
|
|
req: entity.DoTokenByDeviceIDReq{
|
|
DeviceID: "device123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokensByDeviceID(ctx, "device123").
|
|
Return(testTokens, nil)
|
|
},
|
|
wantCount: 1,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "獲取失敗 - Repository 錯誤",
|
|
req: entity.DoTokenByDeviceIDReq{
|
|
DeviceID: "device123",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokensByDeviceID(ctx, "device123").
|
|
Return(nil, errors.New("db error"))
|
|
},
|
|
wantCount: 0,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
resp, err := uc.GetUserTokensByDeviceID(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Len(t, resp, tt.wantCount)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_NewOneTimeToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
expires := now.Add(time.Hour).Unix()
|
|
|
|
testToken := entity.Token{
|
|
ID: "test-id",
|
|
UID: "user123",
|
|
DeviceID: "device123",
|
|
ExpiresIn: int(expires),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
AccessCreateAt: now,
|
|
RefreshCreateAt: now,
|
|
}
|
|
|
|
claims := entity.Claims{
|
|
Data: map[string]string{
|
|
"id": testToken.ID,
|
|
"uid": "user123",
|
|
"role": "admin",
|
|
"scope": "read",
|
|
"account": "user123",
|
|
},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: testToken.ID,
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
|
|
Issuer: "permission",
|
|
},
|
|
}
|
|
validToken, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
|
|
testToken.AccessToken = validToken
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.CreateOneTimeTokenReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功創建 OneTimeToken",
|
|
req: entity.CreateOneTimeTokenReq{
|
|
Token: validToken,
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByID(ctx, "test-id").
|
|
Return(testToken, nil)
|
|
mockTokenRepo.EXPECT().
|
|
CreateOneTimeToken(ctx, gomock.Any(), gomock.Any(), time.Minute).
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "創建失敗 - Token 無效",
|
|
req: entity.CreateOneTimeTokenReq{
|
|
Token: "invalid-token",
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
// parseClaims will fail, no repo call
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "創建失敗 - Token 不存在",
|
|
req: entity.CreateOneTimeTokenReq{
|
|
Token: validToken,
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByID(ctx, "test-id").
|
|
Return(entity.Token{}, errors.New("token not found"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "創建失敗 - Repository 錯誤",
|
|
req: entity.CreateOneTimeTokenReq{
|
|
Token: validToken,
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokenByID(ctx, "test-id").
|
|
Return(testToken, nil)
|
|
mockTokenRepo.EXPECT().
|
|
CreateOneTimeToken(ctx, gomock.Any(), gomock.Any(), time.Minute).
|
|
Return(errors.New("create error"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
resp, err := uc.NewOneTimeToken(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, resp.OneTimeToken)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_CancelOneTimeToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
tests := []struct {
|
|
name string
|
|
req entity.CancelOneTimeTokenReq
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功取消 OneTimeToken",
|
|
req: entity.CancelOneTimeTokenReq{
|
|
Token: []string{"one-time-token"},
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
DeleteOneTimeToken(ctx, []string{"one-time-token"}, nil).
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "取消失敗 - Repository 錯誤",
|
|
req: entity.CancelOneTimeTokenReq{
|
|
Token: []string{"one-time-token"},
|
|
},
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
DeleteOneTimeToken(ctx, []string{"one-time-token"}, nil).
|
|
Return(errors.New("delete error"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
err := uc.CancelOneTimeToken(ctx, tt.req)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_ReadTokenBasicData(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
expires := now.Add(time.Hour).Unix()
|
|
|
|
claims := entity.Claims{
|
|
Data: map[string]string{
|
|
"uid": "user123",
|
|
"role": "admin",
|
|
"scope": "read",
|
|
"account": "user123",
|
|
},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "test-id",
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
|
|
Issuer: "permission",
|
|
},
|
|
}
|
|
validToken, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
checkData func(*testing.T, map[string]string)
|
|
}{
|
|
{
|
|
name: "成功讀取 Token 數據",
|
|
token: validToken,
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
// No repo call needed for ReadTokenBasicData
|
|
},
|
|
wantErr: false,
|
|
checkData: func(t *testing.T, data map[string]string) {
|
|
assert.Equal(t, "user123", data["uid"])
|
|
assert.Equal(t, "admin", data["role"])
|
|
assert.Equal(t, "read", data["scope"])
|
|
},
|
|
},
|
|
{
|
|
name: "讀取失敗 - Token 無效",
|
|
token: "invalid-token",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
// No repo call
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
data, err := uc.ReadTokenBasicData(ctx, tt.token)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, data)
|
|
if tt.checkData != nil {
|
|
tt.checkData(t, data)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_BlacklistToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
expires := now.Add(time.Hour).Unix()
|
|
|
|
claims := entity.Claims{
|
|
Data: map[string]string{
|
|
"uid": "user123",
|
|
"role": "admin",
|
|
"scope": "read",
|
|
"account": "user123",
|
|
},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "test-jti",
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
|
|
Issuer: "permission",
|
|
},
|
|
}
|
|
validToken, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
reason string
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功將 Token 加入黑名單",
|
|
token: validToken,
|
|
reason: "user logout",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
AddToBlacklist(ctx, gomock.Any(), time.Duration(0)).
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "加入黑名單失敗 - Token 無效",
|
|
token: "invalid-token",
|
|
reason: "test",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
// parseToken will fail, no repo call
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "加入黑名單失敗 - Repository 錯誤",
|
|
token: validToken,
|
|
reason: "test",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
AddToBlacklist(ctx, gomock.Any(), time.Duration(0)).
|
|
Return(errors.New("redis error"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
err := uc.BlacklistToken(ctx, tt.token, tt.reason)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_IsTokenBlacklisted(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
tests := []struct {
|
|
name string
|
|
jti string
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantBlacklisted bool
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "Token 在黑名單中",
|
|
jti: "test-jti",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
IsBlacklisted(ctx, "test-jti").
|
|
Return(true, nil)
|
|
},
|
|
wantBlacklisted: true,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Token 不在黑名單中",
|
|
jti: "test-jti",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
IsBlacklisted(ctx, "test-jti").
|
|
Return(false, nil)
|
|
},
|
|
wantBlacklisted: false,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "檢查失敗 - Repository 錯誤",
|
|
jti: "test-jti",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
IsBlacklisted(ctx, "test-jti").
|
|
Return(false, errors.New("redis error"))
|
|
},
|
|
wantBlacklisted: false,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
isBlacklisted, err := uc.IsTokenBlacklisted(ctx, tt.jti)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.wantBlacklisted, isBlacklisted)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenUseCase_BlacklistAllUserTokens(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
expires := now.Add(time.Hour).Unix()
|
|
|
|
// 創建測試用的 tokens
|
|
testToken1 := entity.Token{
|
|
ID: "token1",
|
|
UID: "user123",
|
|
ExpiresIn: int(expires),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
}
|
|
|
|
testToken2 := entity.Token{
|
|
ID: "token2",
|
|
UID: "user123",
|
|
ExpiresIn: int(expires),
|
|
RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
|
|
}
|
|
|
|
// 為每個 token 創建 JWT
|
|
claims1 := entity.Claims{
|
|
Data: map[string]string{
|
|
"uid": "user123",
|
|
"jti": "jti1",
|
|
"exp": fmt.Sprintf("%d", expires),
|
|
},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "jti1",
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
|
|
Issuer: "permission",
|
|
},
|
|
}
|
|
token1JWT, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, claims1).SignedString([]byte("test-secret"))
|
|
testToken1.AccessToken = token1JWT
|
|
|
|
claims2 := entity.Claims{
|
|
Data: map[string]string{
|
|
"uid": "user123",
|
|
"jti": "jti2",
|
|
"exp": fmt.Sprintf("%d", expires),
|
|
},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "jti2",
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
|
|
Issuer: "permission",
|
|
},
|
|
}
|
|
token2JWT, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, claims2).SignedString([]byte("test-secret"))
|
|
testToken2.AccessToken = token2JWT
|
|
|
|
tests := []struct {
|
|
name string
|
|
uid string
|
|
reason string
|
|
mockSetup func(*mockRepo.MockTokenRepository)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "成功將用戶所有 Token 加入黑名單",
|
|
uid: "user123",
|
|
reason: "security issue",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokensByUID(ctx, "user123").
|
|
Return([]entity.Token{testToken1, testToken2}, nil)
|
|
mockTokenRepo.EXPECT().
|
|
AddToBlacklist(ctx, gomock.Any(), time.Duration(0)).
|
|
Return(nil).
|
|
Times(2)
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokensByUID(ctx, "user123").
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "成功但用戶沒有 Token",
|
|
uid: "user999",
|
|
reason: "test",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokensByUID(ctx, "user999").
|
|
Return([]entity.Token{}, nil)
|
|
mockTokenRepo.EXPECT().
|
|
DeleteAccessTokensByUID(ctx, "user999").
|
|
Return(nil)
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "失敗 - 獲取用戶 Token 失敗",
|
|
uid: "user123",
|
|
reason: "test",
|
|
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
|
|
mockTokenRepo.EXPECT().
|
|
GetAccessTokensByUID(ctx, "user123").
|
|
Return(nil, errors.New("db error"))
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
defer mockCtrl.Finish()
|
|
|
|
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
|
|
tt.mockSetup(mockTokenRepo)
|
|
|
|
uc := MustTokenUseCase(TokenUseCaseParam{
|
|
TokenRepo: mockTokenRepo,
|
|
Config: &config.Config{
|
|
Token: struct {
|
|
AccessSecret string
|
|
RefreshSecret string
|
|
AccessTokenExpiry time.Duration
|
|
RefreshTokenExpiry time.Duration
|
|
OneTimeTokenExpiry time.Duration
|
|
MaxTokensPerUser int
|
|
MaxTokensPerDevice int
|
|
}{
|
|
AccessSecret: "test-secret",
|
|
AccessTokenExpiry: time.Hour,
|
|
RefreshTokenExpiry: time.Hour * 24,
|
|
RefreshSecret: "refresh-secret",
|
|
OneTimeTokenExpiry: time.Minute * 5,
|
|
MaxTokensPerUser: 10,
|
|
MaxTokensPerDevice: 5,
|
|
},
|
|
},
|
|
})
|
|
|
|
err := uc.BlacklistAllUserTokens(ctx, tt.uid, tt.reason)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|