backend/pkg/permission/usecase/token_test.go

1528 lines
39 KiB
Go

package usecase
import (
"backend/internal/config"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/token"
mockRepo "backend/pkg/permission/mock/repository"
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
func TestTokenUseCase_NewToken(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
req entity.AuthorizationReq
mockSetup func(*mockRepo.MockTokenRepository)
wantErr bool
}{
{
name: "成功創建 Access Token",
req: entity.AuthorizationReq{
GrantType: "client_credentials",
Scope: "read",
DeviceID: "device123",
Account: "user123",
Role: "admin",
Data: map[string]string{
"uid": "user123",
},
},
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(nil)
},
wantErr: false,
},
{
name: "成功創建 Refresh Token",
req: entity.AuthorizationReq{
GrantType: "client_credentials",
Scope: "read",
DeviceID: "device123",
Account: "user123",
Role: "admin",
IsRefreshToken: true,
Data: map[string]string{
"uid": "user123",
},
},
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(nil)
},
wantErr: false,
},
{
name: "創建失敗 - Repository 錯誤",
req: entity.AuthorizationReq{
GrantType: "client_credentials",
Scope: "read",
Account: "user123",
Role: "admin",
},
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(errors.New("db error"))
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
tt.mockSetup(mockTokenRepo)
uc := MustTokenUseCase(TokenUseCaseParam{
TokenRepo: mockTokenRepo,
Config: &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-secret",
AccessTokenExpiry: time.Hour,
RefreshTokenExpiry: time.Hour * 24,
RefreshSecret: "refresh-secret",
OneTimeTokenExpiry: time.Minute * 5,
MaxTokensPerUser: 10,
MaxTokensPerDevice: 5,
},
},
})
resp, err := uc.NewToken(ctx, tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, resp.AccessToken)
assert.Equal(t, token.TypeBearer.String(), resp.TokenType)
assert.Greater(t, resp.ExpiresIn, int64(0))
if tt.req.IsRefreshToken {
assert.NotEmpty(t, resp.RefreshToken)
}
}
})
}
}
// TestTokenUseCase_RefreshToken runs RefreshToken against a mocked token
// repository. The table covers a successful refresh (lookup, Create, then
// Delete of the old record), a lookup failure, and a Create failure.
func TestTokenUseCase_RefreshToken(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()
	expires := now.Add(time.Hour).Unix()

	// Stored token record that the mocked repository hands back.
	testToken := entity.Token{
		ID:               "test-id",
		UID:              "user123",
		DeviceID:         "device123",
		AccessToken:      "",
		RefreshToken:     "refresh-token",
		ExpiresIn:        int(expires),
		RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
		AccessCreateAt:   now,
		RefreshCreateAt:  now,
	}

	// Sign a real JWT with the same secret the use case is configured with.
	claims := entity.Claims{
		Data: map[string]string{
			"id":       testToken.ID,
			"uid":      "user123",
			"role":     "admin",
			"scope":    "read",
			"account":  "user123",
			"deviceId": "device123",
		},
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        testToken.ID,
			ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
			Issuer:    "permission",
		},
	}
	accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
	if err != nil {
		// Without a signed fixture every case below would fail for the wrong reason.
		t.Fatalf("signing test access token: %v", err)
	}
	testToken.AccessToken = accessToken

	// newParam assembles the use-case parameters around the given mock.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	tests := []struct {
		name      string
		req       entity.RefreshTokenReq
		mockSetup func(*mockRepo.MockTokenRepository)
		wantErr   bool
	}{
		{
			name: "成功刷新 Token",
			req: entity.RefreshTokenReq{
				Token:    "refresh-token",
				DeviceID: "device123",
				Expires:  0,
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByOneTimeToken(ctx, "refresh-token").
					Return(testToken, nil)
				mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(nil)
				mockTokenRepo.EXPECT().Delete(ctx, testToken).Return(nil)
			},
			wantErr: false,
		},
		{
			name: "刷新失敗 - Token 不存在",
			req: entity.RefreshTokenReq{
				Token:    "invalid-token",
				DeviceID: "device123",
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByOneTimeToken(ctx, "invalid-token").
					Return(entity.Token{}, errors.New("token not found"))
			},
			wantErr: true,
		},
		{
			name: "刷新失敗 - 創建新 Token 失敗",
			req: entity.RefreshTokenReq{
				Token:    "refresh-token",
				DeviceID: "device123",
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByOneTimeToken(ctx, "refresh-token").
					Return(testToken, nil)
				mockTokenRepo.EXPECT().Create(ctx, gomock.Any()).Return(errors.New("create error"))
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
			tt.mockSetup(mockTokenRepo)

			uc := MustTokenUseCase(newParam(mockTokenRepo))
			resp, err := uc.RefreshToken(ctx, tt.req)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}

			// Stop here on an unexpected error: the response is meaningless then.
			if !assert.NoError(t, err) {
				return
			}
			assert.NotEmpty(t, resp.Token)
			assert.NotEmpty(t, resp.OneTimeToken)
			assert.Greater(t, resp.ExpiresIn, int64(0))
		})
	}
}
// TestTokenUseCase_ValidationToken runs ValidationToken against a mocked
// token repository. The table covers a signed token whose record is found, a
// string that is not a JWT, and a signed token whose repository lookup fails.
func TestTokenUseCase_ValidationToken(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()
	expires := now.Add(time.Hour).Unix()

	// Stored token record that the mocked repository hands back.
	testToken := entity.Token{
		ID:               "test-id",
		UID:              "user123",
		DeviceID:         "device123",
		ExpiresIn:        int(expires),
		RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
		AccessCreateAt:   now,
		RefreshCreateAt:  now,
	}

	claims := entity.Claims{
		Data: map[string]string{
			"id":      testToken.ID,
			"uid":     "user123",
			"role":    "admin",
			"scope":   "read",
			"account": "user123",
		},
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        testToken.ID,
			ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
			Issuer:    "permission",
		},
	}
	// Signed with the same secret the use case is configured with below.
	validToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
	if err != nil {
		// Without a signed fixture every case below would fail for the wrong reason.
		t.Fatalf("signing test access token: %v", err)
	}
	testToken.AccessToken = validToken

	// newParam assembles the use-case parameters around the given mock.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	tests := []struct {
		name      string
		req       entity.ValidationTokenReq
		mockSetup func(*mockRepo.MockTokenRepository)
		wantErr   bool
	}{
		{
			name: "成功驗證 Token",
			req: entity.ValidationTokenReq{
				Token: validToken,
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByID(ctx, "test-id").
					Return(testToken, nil)
			},
			wantErr: false,
		},
		{
			name: "驗證失敗 - Token 無效",
			req: entity.ValidationTokenReq{
				Token: "invalid-token",
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				// No expectations: the call is expected to fail before the
				// repository is touched; gomock rejects any unexpected call.
			},
			wantErr: true,
		},
		{
			name: "驗證失敗 - Token 不存在",
			req: entity.ValidationTokenReq{
				Token: validToken,
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByID(ctx, "test-id").
					Return(entity.Token{}, errors.New("token not found"))
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
			tt.mockSetup(mockTokenRepo)

			uc := MustTokenUseCase(newParam(mockTokenRepo))
			resp, err := uc.ValidationToken(ctx, tt.req)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}

			// Stop here on an unexpected error: the response is meaningless then.
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, testToken.ID, resp.Token.ID)
			assert.Equal(t, testToken.UID, resp.Token.UID)
			assert.NotNil(t, resp.Data)
		})
	}
}
// TestTokenUseCase_CancelToken runs CancelToken against a mocked token
// repository. The table covers a signed token that is looked up and deleted,
// a string that is not a JWT, and a signed token whose lookup fails.
func TestTokenUseCase_CancelToken(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()
	expires := now.Add(time.Hour).Unix()

	// Stored token record that the mocked repository hands back.
	testToken := entity.Token{
		ID:               "test-id",
		UID:              "user123",
		DeviceID:         "device123",
		ExpiresIn:        int(expires),
		RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
		AccessCreateAt:   now,
		RefreshCreateAt:  now,
	}

	claims := entity.Claims{
		Data: map[string]string{
			"id":      testToken.ID,
			"uid":     "user123",
			"role":    "admin",
			"scope":   "read",
			"account": "user123",
		},
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        testToken.ID,
			ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
			Issuer:    "permission",
		},
	}
	// Signed with the same secret the use case is configured with below.
	validToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
	if err != nil {
		// Without a signed fixture every case below would fail for the wrong reason.
		t.Fatalf("signing test access token: %v", err)
	}
	testToken.AccessToken = validToken

	// newParam assembles the use-case parameters around the given mock.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	tests := []struct {
		name      string
		req       entity.CancelTokenReq
		mockSetup func(*mockRepo.MockTokenRepository)
		wantErr   bool
	}{
		{
			name: "成功取消 Token",
			req: entity.CancelTokenReq{
				Token: validToken,
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByID(ctx, "test-id").
					Return(testToken, nil)
				mockTokenRepo.EXPECT().Delete(ctx, testToken).Return(nil)
			},
			wantErr: false,
		},
		{
			name: "取消失敗 - Token 無效",
			req: entity.CancelTokenReq{
				Token: "invalid-token",
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				// No expectations: the call is expected to fail before the
				// repository is touched; gomock rejects any unexpected call.
			},
			wantErr: true,
		},
		{
			name: "取消失敗 - Token 不存在",
			req: entity.CancelTokenReq{
				Token: validToken,
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByID(ctx, "test-id").
					Return(entity.Token{}, errors.New("token not found"))
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
			tt.mockSetup(mockTokenRepo)

			uc := MustTokenUseCase(newParam(mockTokenRepo))
			err := uc.CancelToken(ctx, tt.req)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestTokenUseCase_CancelTokens runs CancelTokens against a mocked token
// repository. The table covers deletion by UID only, by ID list only, by both
// together, and a failure of the UID deletion.
func TestTokenUseCase_CancelTokens(t *testing.T) {
	type scenario struct {
		name    string
		req     entity.DoTokenByUIDReq
		expect  func(*mockRepo.MockTokenRepository)
		wantErr bool
	}

	ctx := context.Background()

	// newParam assembles the use-case parameters around the given mock; it is
	// invoked once per subtest so every subtest gets its own Config value.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	scenarios := []scenario{
		{
			name: "成功取消 UID 的所有 Token",
			req:  entity.DoTokenByUIDReq{UID: "user123"},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().DeleteAccessTokensByUID(ctx, "user123").Return(nil)
			},
		},
		{
			name: "成功取消指定 ID 的 Token",
			req:  entity.DoTokenByUIDReq{IDs: []string{"token1", "token2"}},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().DeleteAccessTokenByID(ctx, []string{"token1", "token2"}).Return(nil)
			},
		},
		{
			name: "成功取消 UID 和 ID 的 Token",
			req: entity.DoTokenByUIDReq{
				UID: "user123",
				IDs: []string{"token1"},
			},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().DeleteAccessTokensByUID(ctx, "user123").Return(nil)
				repo.EXPECT().DeleteAccessTokenByID(ctx, []string{"token1"}).Return(nil)
			},
		},
		{
			name: "取消失敗 - UID 刪除錯誤",
			req:  entity.DoTokenByUIDReq{UID: "user123"},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().DeleteAccessTokensByUID(ctx, "user123").Return(errors.New("delete error"))
			},
			wantErr: true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			repo := mockRepo.NewMockTokenRepository(ctrl)
			sc.expect(repo)

			uc := MustTokenUseCase(newParam(repo))
			err := uc.CancelTokens(ctx, sc.req)
			if sc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
func TestTokenUseCase_CancelTokenByDeviceID(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
req entity.DoTokenByDeviceIDReq
mockSetup func(*mockRepo.MockTokenRepository)
wantErr bool
}{
{
name: "成功取消 DeviceID 的所有 Token",
req: entity.DoTokenByDeviceIDReq{
DeviceID: "device123",
},
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
mockTokenRepo.EXPECT().
DeleteAccessTokensByDeviceID(ctx, "device123").
Return(nil)
},
wantErr: false,
},
{
name: "取消失敗 - 刪除錯誤",
req: entity.DoTokenByDeviceIDReq{
DeviceID: "device123",
},
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
mockTokenRepo.EXPECT().
DeleteAccessTokensByDeviceID(ctx, "device123").
Return(errors.New("delete error"))
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
tt.mockSetup(mockTokenRepo)
uc := MustTokenUseCase(TokenUseCaseParam{
TokenRepo: mockTokenRepo,
Config: &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-secret",
AccessTokenExpiry: time.Hour,
RefreshTokenExpiry: time.Hour * 24,
RefreshSecret: "refresh-secret",
OneTimeTokenExpiry: time.Minute * 5,
MaxTokensPerUser: 10,
MaxTokensPerDevice: 5,
},
},
})
err := uc.CancelTokenByDeviceID(ctx, tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestTokenUseCase_GetUserTokensByUID runs GetUserTokensByUID against a
// mocked token repository. The table covers a user with two tokens, a user
// with none, and a repository error; on success only the result length is
// checked.
func TestTokenUseCase_GetUserTokensByUID(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()

	// Both fixture tokens share the same expiry timestamps.
	accessExp := int(now.Add(time.Hour).Unix())
	refreshExp := int(now.Add(time.Hour * 24).Unix())
	storedTokens := []entity.Token{
		{
			ID:               "token1",
			UID:              "user123",
			AccessToken:      "access1",
			RefreshToken:     "refresh1",
			ExpiresIn:        accessExp,
			RefreshExpiresIn: refreshExp,
		},
		{
			ID:               "token2",
			UID:              "user123",
			AccessToken:      "access2",
			RefreshToken:     "refresh2",
			ExpiresIn:        accessExp,
			RefreshExpiresIn: refreshExp,
		},
	}

	// newParam assembles the use-case parameters around the given mock; it is
	// invoked once per subtest so every subtest gets its own Config value.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	cases := []struct {
		name      string
		req       entity.QueryTokenByUIDReq
		expect    func(*mockRepo.MockTokenRepository)
		wantCount int
		wantErr   bool
	}{
		{
			name: "成功獲取 UID 的所有 Token",
			req:  entity.QueryTokenByUIDReq{UID: "user123"},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().GetAccessTokensByUID(ctx, "user123").Return(storedTokens, nil)
			},
			wantCount: 2,
		},
		{
			name: "成功獲取但沒有 Token",
			req:  entity.QueryTokenByUIDReq{UID: "user999"},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().GetAccessTokensByUID(ctx, "user999").Return([]entity.Token{}, nil)
			},
			wantCount: 0,
		},
		{
			name: "獲取失敗 - Repository 錯誤",
			req:  entity.QueryTokenByUIDReq{UID: "user123"},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().GetAccessTokensByUID(ctx, "user123").Return(nil, errors.New("db error"))
			},
			wantCount: 0,
			wantErr:   true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			repo := mockRepo.NewMockTokenRepository(ctrl)
			tc.expect(repo)

			uc := MustTokenUseCase(newParam(repo))
			resp, err := uc.GetUserTokensByUID(ctx, tc.req)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Len(t, resp, tc.wantCount)
		})
	}
}
// TestTokenUseCase_GetUserTokensByDeviceID runs GetUserTokensByDeviceID
// against a mocked token repository, once returning a single stored token and
// once returning an error; on success only the result length is checked.
func TestTokenUseCase_GetUserTokensByDeviceID(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()

	storedTokens := []entity.Token{
		{
			ID:               "token1",
			UID:              "user123",
			DeviceID:         "device123",
			AccessToken:      "access1",
			RefreshToken:     "refresh1",
			ExpiresIn:        int(now.Add(time.Hour).Unix()),
			RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
		},
	}

	// newParam assembles the use-case parameters around the given mock; it is
	// invoked once per subtest so every subtest gets its own Config value.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	cases := []struct {
		name      string
		req       entity.DoTokenByDeviceIDReq
		expect    func(*mockRepo.MockTokenRepository)
		wantCount int
		wantErr   bool
	}{
		{
			name: "成功獲取 DeviceID 的所有 Token",
			req:  entity.DoTokenByDeviceIDReq{DeviceID: "device123"},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().GetAccessTokensByDeviceID(ctx, "device123").Return(storedTokens, nil)
			},
			wantCount: 1,
		},
		{
			name: "獲取失敗 - Repository 錯誤",
			req:  entity.DoTokenByDeviceIDReq{DeviceID: "device123"},
			expect: func(repo *mockRepo.MockTokenRepository) {
				repo.EXPECT().GetAccessTokensByDeviceID(ctx, "device123").Return(nil, errors.New("db error"))
			},
			wantCount: 0,
			wantErr:   true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			repo := mockRepo.NewMockTokenRepository(ctrl)
			tc.expect(repo)

			uc := MustTokenUseCase(newParam(repo))
			resp, err := uc.GetUserTokensByDeviceID(ctx, tc.req)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Len(t, resp, tc.wantCount)
		})
	}
}
// TestTokenUseCase_NewOneTimeToken runs NewOneTimeToken against a mocked
// token repository. The table covers a successful creation, a string that is
// not a JWT, a failed access-token lookup, and a CreateOneTimeToken failure.
func TestTokenUseCase_NewOneTimeToken(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()
	expires := now.Add(time.Hour).Unix()

	// Stored token record that the mocked repository hands back.
	testToken := entity.Token{
		ID:               "test-id",
		UID:              "user123",
		DeviceID:         "device123",
		ExpiresIn:        int(expires),
		RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
		AccessCreateAt:   now,
		RefreshCreateAt:  now,
	}

	claims := entity.Claims{
		Data: map[string]string{
			"id":      testToken.ID,
			"uid":     "user123",
			"role":    "admin",
			"scope":   "read",
			"account": "user123",
		},
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        testToken.ID,
			ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
			Issuer:    "permission",
		},
	}
	// Signed with the same secret the use case is configured with below.
	validToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
	if err != nil {
		// Without a signed fixture every case below would fail for the wrong reason.
		t.Fatalf("signing test access token: %v", err)
	}
	testToken.AccessToken = validToken

	// newParam assembles the use-case parameters around the given mock.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	tests := []struct {
		name      string
		req       entity.CreateOneTimeTokenReq
		mockSetup func(*mockRepo.MockTokenRepository)
		wantErr   bool
	}{
		{
			name: "成功創建 OneTimeToken",
			req: entity.CreateOneTimeTokenReq{
				Token: validToken,
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByID(ctx, "test-id").
					Return(testToken, nil)
				// NOTE(review): the expected TTL is time.Minute although the
				// config sets OneTimeTokenExpiry to five minutes - confirm
				// against the use case that this is intended.
				mockTokenRepo.EXPECT().
					CreateOneTimeToken(ctx, gomock.Any(), gomock.Any(), time.Minute).
					Return(nil)
			},
			wantErr: false,
		},
		{
			name: "創建失敗 - Token 無效",
			req: entity.CreateOneTimeTokenReq{
				Token: "invalid-token",
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				// No expectations: the call is expected to fail before the
				// repository is touched; gomock rejects any unexpected call.
			},
			wantErr: true,
		},
		{
			name: "創建失敗 - Token 不存在",
			req: entity.CreateOneTimeTokenReq{
				Token: validToken,
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByID(ctx, "test-id").
					Return(entity.Token{}, errors.New("token not found"))
			},
			wantErr: true,
		},
		{
			name: "創建失敗 - Repository 錯誤",
			req: entity.CreateOneTimeTokenReq{
				Token: validToken,
			},
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokenByID(ctx, "test-id").
					Return(testToken, nil)
				mockTokenRepo.EXPECT().
					CreateOneTimeToken(ctx, gomock.Any(), gomock.Any(), time.Minute).
					Return(errors.New("create error"))
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
			tt.mockSetup(mockTokenRepo)

			uc := MustTokenUseCase(newParam(mockTokenRepo))
			resp, err := uc.NewOneTimeToken(ctx, tt.req)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}

			// Stop here on an unexpected error: the response is meaningless then.
			if !assert.NoError(t, err) {
				return
			}
			assert.NotEmpty(t, resp.OneTimeToken)
		})
	}
}
func TestTokenUseCase_CancelOneTimeToken(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
req entity.CancelOneTimeTokenReq
mockSetup func(*mockRepo.MockTokenRepository)
wantErr bool
}{
{
name: "成功取消 OneTimeToken",
req: entity.CancelOneTimeTokenReq{
Token: []string{"one-time-token"},
},
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
mockTokenRepo.EXPECT().
DeleteOneTimeToken(ctx, []string{"one-time-token"}, nil).
Return(nil)
},
wantErr: false,
},
{
name: "取消失敗 - Repository 錯誤",
req: entity.CancelOneTimeTokenReq{
Token: []string{"one-time-token"},
},
mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
mockTokenRepo.EXPECT().
DeleteOneTimeToken(ctx, []string{"one-time-token"}, nil).
Return(errors.New("delete error"))
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
tt.mockSetup(mockTokenRepo)
uc := MustTokenUseCase(TokenUseCaseParam{
TokenRepo: mockTokenRepo,
Config: &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-secret",
AccessTokenExpiry: time.Hour,
RefreshTokenExpiry: time.Hour * 24,
RefreshSecret: "refresh-secret",
OneTimeTokenExpiry: time.Minute * 5,
MaxTokensPerUser: 10,
MaxTokensPerDevice: 5,
},
},
})
err := uc.CancelOneTimeToken(ctx, tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestTokenUseCase_ReadTokenBasicData runs ReadTokenBasicData with a signed
// JWT and with a string that is not a JWT. No repository expectations are
// registered in either case, so any repository call fails the test.
func TestTokenUseCase_ReadTokenBasicData(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()
	expires := now.Add(time.Hour).Unix()

	claims := entity.Claims{
		Data: map[string]string{
			"uid":     "user123",
			"role":    "admin",
			"scope":   "read",
			"account": "user123",
		},
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        "test-id",
			ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
			Issuer:    "permission",
		},
	}
	// Signed with the same secret the use case is configured with below.
	validToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
	if err != nil {
		// Without a signed fixture the success case would fail for the wrong reason.
		t.Fatalf("signing test access token: %v", err)
	}

	// newParam assembles the use-case parameters around the given mock.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	tests := []struct {
		name      string
		token     string
		mockSetup func(*mockRepo.MockTokenRepository)
		wantErr   bool
		// checkData, when set, inspects the data returned on success.
		checkData func(*testing.T, map[string]string)
	}{
		{
			name:  "成功讀取 Token 數據",
			token: validToken,
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				// No repository call is expected for ReadTokenBasicData.
			},
			wantErr: false,
			checkData: func(t *testing.T, data map[string]string) {
				assert.Equal(t, "user123", data["uid"])
				assert.Equal(t, "admin", data["role"])
				assert.Equal(t, "read", data["scope"])
			},
		},
		{
			name:  "讀取失敗 - Token 無效",
			token: "invalid-token",
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				// No repository call is expected.
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
			tt.mockSetup(mockTokenRepo)

			uc := MustTokenUseCase(newParam(mockTokenRepo))
			data, err := uc.ReadTokenBasicData(ctx, tt.token)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}

			// Stop here on an unexpected error: the data is meaningless then.
			if !assert.NoError(t, err) {
				return
			}
			assert.NotNil(t, data)
			if tt.checkData != nil {
				tt.checkData(t, data)
			}
		})
	}
}
// TestTokenUseCase_BlacklistToken runs BlacklistToken against a mocked token
// repository. The table covers a signed token that is blacklisted, a string
// that is not a JWT, and an AddToBlacklist failure.
func TestTokenUseCase_BlacklistToken(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()
	expires := now.Add(time.Hour).Unix()

	claims := entity.Claims{
		Data: map[string]string{
			"uid":     "user123",
			"role":    "admin",
			"scope":   "read",
			"account": "user123",
		},
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        "test-jti",
			ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
			Issuer:    "permission",
		},
	}
	// Signed with the same secret the use case is configured with below.
	validToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
	if err != nil {
		// Without a signed fixture the cases below would fail for the wrong reason.
		t.Fatalf("signing test access token: %v", err)
	}

	// newParam assembles the use-case parameters around the given mock.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	tests := []struct {
		name      string
		token     string
		reason    string
		mockSetup func(*mockRepo.MockTokenRepository)
		wantErr   bool
	}{
		{
			name:   "成功將 Token 加入黑名單",
			token:  validToken,
			reason: "user logout",
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					AddToBlacklist(ctx, gomock.Any(), time.Duration(0)).
					Return(nil)
			},
			wantErr: false,
		},
		{
			name:   "加入黑名單失敗 - Token 無效",
			token:  "invalid-token",
			reason: "test",
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				// No expectations: the call is expected to fail before the
				// repository is touched; gomock rejects any unexpected call.
			},
			wantErr: true,
		},
		{
			name:   "加入黑名單失敗 - Repository 錯誤",
			token:  validToken,
			reason: "test",
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					AddToBlacklist(ctx, gomock.Any(), time.Duration(0)).
					Return(errors.New("redis error"))
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
			tt.mockSetup(mockTokenRepo)

			uc := MustTokenUseCase(newParam(mockTokenRepo))
			err := uc.BlacklistToken(ctx, tt.token, tt.reason)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestTokenUseCase_IsTokenBlacklisted runs IsTokenBlacklisted against a
// mocked token repository that reports the JTI as blacklisted, as not
// blacklisted, or fails with an error.
func TestTokenUseCase_IsTokenBlacklisted(t *testing.T) {
	ctx := context.Background()

	// newParam assembles the use-case parameters around the given mock; it is
	// invoked once per subtest so every subtest gets its own Config value.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	cases := []struct {
		name string
		jti  string
		// repoFlag and repoErr are what the mocked IsBlacklisted returns.
		repoFlag        bool
		repoErr         error
		wantBlacklisted bool
		wantErr         bool
	}{
		{
			name:            "Token 在黑名單中",
			jti:             "test-jti",
			repoFlag:        true,
			wantBlacklisted: true,
		},
		{
			name: "Token 不在黑名單中",
			jti:  "test-jti",
		},
		{
			name:    "檢查失敗 - Repository 錯誤",
			jti:     "test-jti",
			repoErr: errors.New("redis error"),
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			repo := mockRepo.NewMockTokenRepository(ctrl)
			repo.EXPECT().
				IsBlacklisted(ctx, "test-jti").
				Return(tc.repoFlag, tc.repoErr)

			uc := MustTokenUseCase(newParam(repo))
			got, err := uc.IsTokenBlacklisted(ctx, tc.jti)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.wantBlacklisted, got)
		})
	}
}
// TestTokenUseCase_BlacklistAllUserTokens runs BlacklistAllUserTokens against
// a mocked token repository. The table covers a user with two stored tokens
// (two AddToBlacklist calls plus a delete by UID), a user with no tokens, and
// a failure while listing the user's tokens.
func TestTokenUseCase_BlacklistAllUserTokens(t *testing.T) {
	ctx := context.Background()
	now := time.Now().UTC()
	expires := now.Add(time.Hour).Unix()

	// newStoredToken builds a stored token record for user123 whose access
	// token is a JWT carrying the given JTI, signed with the same secret the
	// use case is configured with. It aborts the test if signing fails.
	newStoredToken := func(id, jti string) entity.Token {
		claims := entity.Claims{
			Data: map[string]string{
				"uid": "user123",
				"jti": jti,
				"exp": fmt.Sprintf("%d", expires),
			},
			RegisteredClaims: jwt.RegisteredClaims{
				ID:        jti,
				ExpiresAt: jwt.NewNumericDate(time.Unix(expires, 0)),
				Issuer:    "permission",
			},
		}
		signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
		if err != nil {
			t.Fatalf("signing test access token %q: %v", jti, err)
		}
		return entity.Token{
			ID:               id,
			UID:              "user123",
			AccessToken:      signed,
			ExpiresIn:        int(expires),
			RefreshExpiresIn: int(now.Add(time.Hour * 24).Unix()),
		}
	}

	// Fixture tokens returned by the mocked repository.
	testToken1 := newStoredToken("token1", "jti1")
	testToken2 := newStoredToken("token2", "jti2")

	// newParam assembles the use-case parameters around the given mock.
	newParam := func(repo *mockRepo.MockTokenRepository) TokenUseCaseParam {
		return TokenUseCaseParam{
			TokenRepo: repo,
			Config: &config.Config{
				Token: struct {
					AccessSecret       string
					RefreshSecret      string
					AccessTokenExpiry  time.Duration
					RefreshTokenExpiry time.Duration
					OneTimeTokenExpiry time.Duration
					MaxTokensPerUser   int
					MaxTokensPerDevice int
				}{
					AccessSecret:       "test-secret",
					AccessTokenExpiry:  time.Hour,
					RefreshTokenExpiry: time.Hour * 24,
					RefreshSecret:      "refresh-secret",
					OneTimeTokenExpiry: time.Minute * 5,
					MaxTokensPerUser:   10,
					MaxTokensPerDevice: 5,
				},
			},
		}
	}

	tests := []struct {
		name      string
		uid       string
		reason    string
		mockSetup func(*mockRepo.MockTokenRepository)
		wantErr   bool
	}{
		{
			name:   "成功將用戶所有 Token 加入黑名單",
			uid:    "user123",
			reason: "security issue",
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokensByUID(ctx, "user123").
					Return([]entity.Token{testToken1, testToken2}, nil)
				// One blacklist entry is expected per stored token.
				mockTokenRepo.EXPECT().
					AddToBlacklist(ctx, gomock.Any(), time.Duration(0)).
					Return(nil).
					Times(2)
				mockTokenRepo.EXPECT().
					DeleteAccessTokensByUID(ctx, "user123").
					Return(nil)
			},
			wantErr: false,
		},
		{
			name:   "成功但用戶沒有 Token",
			uid:    "user999",
			reason: "test",
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokensByUID(ctx, "user999").
					Return([]entity.Token{}, nil)
				mockTokenRepo.EXPECT().
					DeleteAccessTokensByUID(ctx, "user999").
					Return(nil)
			},
			wantErr: false,
		},
		{
			name:   "失敗 - 獲取用戶 Token 失敗",
			uid:    "user123",
			reason: "test",
			mockSetup: func(mockTokenRepo *mockRepo.MockTokenRepository) {
				mockTokenRepo.EXPECT().
					GetAccessTokensByUID(ctx, "user123").
					Return(nil, errors.New("db error"))
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			mockTokenRepo := mockRepo.NewMockTokenRepository(mockCtrl)
			tt.mockSetup(mockTokenRepo)

			uc := MustTokenUseCase(newParam(mockTokenRepo))
			err := uc.BlacklistAllUserTokens(ctx, tt.uid, tt.reason)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}