612 lines
18 KiB
Go
612 lines
18 KiB
Go
|
package usecase
|
|||
|
|
|||
|
import (
|
|||
|
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
|
|||
|
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/usecase"
|
|||
|
mock "code.30cm.net/digimon/app-cloudep-permission-server/pkg/mock/repository"
|
|||
|
"context"
|
|||
|
"fmt"
|
|||
|
"github.com/golang-jwt/jwt/v4"
|
|||
|
"github.com/stretchr/testify/assert"
|
|||
|
"go.uber.org/mock/gomock"
|
|||
|
"testing"
|
|||
|
"time"
|
|||
|
)
|
|||
|
|
|||
|
// TestTokenUseCase_CreateAccessToken_TableDriven 透過 table-driven 方式測試 CreateAccessToken
|
|||
|
func TestTokenUseCase_CreateAccessToken_TableDriven(t *testing.T) {
|
|||
|
// 固定發行者設定,測試中會用來驗證 claims.Issuer
|
|||
|
now := time.Now()
|
|||
|
mockCtrl := gomock.NewController(t)
|
|||
|
defer mockCtrl.Finish()
|
|||
|
|
|||
|
mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl)
|
|||
|
uc := NewTokenUseCase(TokenUseCaseParam{
|
|||
|
TokenRepo: mockAutoIDModel,
|
|||
|
RefreshExpires: 2 * time.Minute,
|
|||
|
Expired: 2 * time.Minute,
|
|||
|
Secret: "gg88g88",
|
|||
|
})
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
token entity.Token
|
|||
|
data any
|
|||
|
secretKey string
|
|||
|
wantErr bool
|
|||
|
verifyClaims func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time)
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "ok",
|
|||
|
token: entity.Token{
|
|||
|
ID: "token1",
|
|||
|
ExpiresIn: now.Add(1 * time.Hour).UnixNano(),
|
|||
|
},
|
|||
|
data: map[string]interface{}{"foo": "bar"},
|
|||
|
secretKey: "secret123",
|
|||
|
wantErr: false,
|
|||
|
verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
|
|||
|
assert.Equal(t, "token1", claims.ID)
|
|||
|
|
|||
|
assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry),
|
|||
|
"expected expiry %v, got %v", expectedExpiry, claims.ExpiresAt.Time)
|
|||
|
|
|||
|
dataMap, ok := claims.Data.(map[string]interface{})
|
|||
|
assert.True(t, ok, "claims.Data 應為 map[string]interface{}")
|
|||
|
assert.Equal(t, "bar", dataMap["foo"])
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "valid token with string data",
|
|||
|
token: entity.Token{
|
|||
|
ID: "token2",
|
|||
|
ExpiresIn: now.Add(2 * time.Hour).UnixNano(),
|
|||
|
},
|
|||
|
data: map[string]interface{}{"foo": "bar"},
|
|||
|
secretKey: "anotherSecret",
|
|||
|
wantErr: false,
|
|||
|
verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
|
|||
|
assert.Equal(t, "token2", claims.ID)
|
|||
|
assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry))
|
|||
|
assert.Equal(t, map[string]interface{}{"foo": "bar"}, claims.Data)
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "empty secret key",
|
|||
|
token: entity.Token{
|
|||
|
ID: "token3",
|
|||
|
ExpiresIn: now.Add(30 * time.Minute).UnixNano(),
|
|||
|
},
|
|||
|
data: map[string]interface{}{"key": "value"},
|
|||
|
secretKey: "",
|
|||
|
wantErr: false,
|
|||
|
verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
|
|||
|
assert.Equal(t, "token3", claims.ID)
|
|||
|
assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry))
|
|||
|
dataMap, ok := claims.Data.(map[string]interface{})
|
|||
|
assert.True(t, ok, "claims.Data 應為 map[string]interface{}")
|
|||
|
assert.Equal(t, "value", dataMap["key"])
|
|||
|
},
|
|||
|
},
|
|||
|
// 如有需要,可加入更多測試案例,例如模擬簽名錯誤等情境
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
tt := tt // 捕捉範圍變數
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
|||
|
jwtStr, err := uc.CreateAccessToken(tt.token, tt.data, tt.secretKey)
|
|||
|
if tt.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
|
|||
|
// 解析 JWT
|
|||
|
parsedToken, err := jwt.ParseWithClaims(jwtStr, &entity.Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|||
|
// 驗證簽名方法是否為 HMAC
|
|||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|||
|
}
|
|||
|
return []byte(tt.secretKey), nil
|
|||
|
})
|
|||
|
assert.NoError(t, err)
|
|||
|
assert.True(t, parsedToken.Valid, "解析後的 JWT 應該有效")
|
|||
|
|
|||
|
claims, ok := parsedToken.Claims.(*entity.Claims)
|
|||
|
assert.True(t, ok, "claims 型別錯誤,預期 *entity.Claims, got %T", parsedToken.Claims)
|
|||
|
|
|||
|
// 根據 token.ExpiresIn 計算預期的過期時間
|
|||
|
expectedExpiry := time.Unix(0, tt.token.ExpiresIn)
|
|||
|
// 呼叫 verifyClaims 驗證其它 Claim 資料
|
|||
|
tt.verifyClaims(t, claims, expectedExpiry)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTokenUseCase_CreateRefreshToken(t *testing.T) {
|
|||
|
mockCtrl := gomock.NewController(t)
|
|||
|
defer mockCtrl.Finish()
|
|||
|
|
|||
|
mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl)
|
|||
|
uc := NewTokenUseCase(TokenUseCaseParam{
|
|||
|
TokenRepo: mockAutoIDModel,
|
|||
|
RefreshExpires: 2 * time.Minute,
|
|||
|
Expired: 2 * time.Minute,
|
|||
|
Secret: "gg88g88",
|
|||
|
})
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
accessToken string
|
|||
|
expected string
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "empty access token",
|
|||
|
accessToken: "",
|
|||
|
// SHA256("") 的 hex 編碼結果
|
|||
|
expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
|||
|
},
|
|||
|
{
|
|||
|
name: "normal access token",
|
|||
|
accessToken: "access-token",
|
|||
|
expected: "3f16bed7089f4653e5ef21bfd2824d7f3aaaecc7a598e7e89c580e1606a9cc52",
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
tt := tt // 捕捉變數
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
result := uc.CreateRefreshToken(tt.accessToken)
|
|||
|
assert.Equal(t, tt.expected, result)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// TestTokenUseCase_ParseJWTClaimsByAccessToken 使用 table-driven 方式測試 ParseJWTClaimsByAccessToken
|
|||
|
func TestTokenUseCase_ParseJWTClaimsByAccessToken(t *testing.T) {
|
|||
|
mockCtrl := gomock.NewController(t)
|
|||
|
defer mockCtrl.Finish()
|
|||
|
|
|||
|
mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl)
|
|||
|
uc := NewTokenUseCase(TokenUseCaseParam{
|
|||
|
TokenRepo: mockAutoIDModel,
|
|||
|
RefreshExpires: 2 * time.Minute,
|
|||
|
Expired: 2 * time.Minute,
|
|||
|
Secret: "gg88g88",
|
|||
|
})
|
|||
|
|
|||
|
// 定義測試案例的結構
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
// tokenGen 用來動態產生要解析的 access token
|
|||
|
tokenGen func(t *testing.T) string
|
|||
|
secret string
|
|||
|
validate bool
|
|||
|
wantClaims jwt.MapClaims
|
|||
|
wantErr bool
|
|||
|
errContains string
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "valid token with validation",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
claims := jwt.MapClaims{
|
|||
|
"sub": "123",
|
|||
|
"role": "admin",
|
|||
|
}
|
|||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|||
|
tokenString, err := token.SignedString([]byte("testsecret"))
|
|||
|
if err != nil {
|
|||
|
t.Fatalf("failed to sign token: %v", err)
|
|||
|
}
|
|||
|
return tokenString
|
|||
|
},
|
|||
|
secret: "testsecret",
|
|||
|
validate: true,
|
|||
|
wantClaims: jwt.MapClaims{
|
|||
|
"sub": "123",
|
|||
|
"role": "admin",
|
|||
|
},
|
|||
|
wantErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "valid token without validation",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
claims := jwt.MapClaims{
|
|||
|
"sub": "123",
|
|||
|
"role": "admin",
|
|||
|
}
|
|||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|||
|
tokenString, err := token.SignedString([]byte("testsecret"))
|
|||
|
if err != nil {
|
|||
|
t.Fatalf("failed to sign token: %v", err)
|
|||
|
}
|
|||
|
return tokenString
|
|||
|
},
|
|||
|
secret: "testsecret",
|
|||
|
validate: false,
|
|||
|
wantClaims: jwt.MapClaims{
|
|||
|
"sub": "123",
|
|||
|
"role": "admin",
|
|||
|
},
|
|||
|
wantErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "invalid secret",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
claims := jwt.MapClaims{
|
|||
|
"sub": "123",
|
|||
|
}
|
|||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|||
|
tokenString, err := token.SignedString([]byte("testsecret"))
|
|||
|
if err != nil {
|
|||
|
t.Fatalf("failed to sign token: %v", err)
|
|||
|
}
|
|||
|
return tokenString
|
|||
|
},
|
|||
|
secret: "wrongsecret",
|
|||
|
validate: true,
|
|||
|
wantErr: true,
|
|||
|
errContains: "signature", // 預期錯誤訊息中包含 "signature"
|
|||
|
},
|
|||
|
{
|
|||
|
name: "unexpected signing method",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
claims := jwt.MapClaims{
|
|||
|
"sub": "456",
|
|||
|
}
|
|||
|
// 使用 SigningMethodNone 產生 token
|
|||
|
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
|
|||
|
// 針對 None 演算法,SignedString 需要使用 jwt.UnsafeAllowNoneSignatureType
|
|||
|
tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
|
|||
|
if err != nil {
|
|||
|
t.Fatalf("failed to sign token: %v", err)
|
|||
|
}
|
|||
|
return tokenString
|
|||
|
},
|
|||
|
secret: "testsecret",
|
|||
|
validate: true,
|
|||
|
wantErr: true,
|
|||
|
errContains: "unexpected signing method",
|
|||
|
},
|
|||
|
{
|
|||
|
name: "malformed token",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
return "not-a-token"
|
|||
|
},
|
|||
|
secret: "testsecret",
|
|||
|
validate: true,
|
|||
|
wantErr: true,
|
|||
|
errContains: "token contains an invalid number of segments",
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
// 針對每個測試案例執行測試
|
|||
|
for _, tt := range tests {
|
|||
|
tt := tt // 捕捉迴圈變數
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
// 產生 access token
|
|||
|
accessToken := tt.tokenGen(t)
|
|||
|
|
|||
|
claims, err := uc.ParseJWTClaimsByAccessToken(accessToken, tt.secret, tt.validate)
|
|||
|
if tt.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
if err != nil && tt.errContains != "" {
|
|||
|
assert.Contains(t, err.Error(), tt.errContains)
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
// 驗證解析出來的 claims 是否符合預期
|
|||
|
assert.Equal(t, tt.wantClaims, claims)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTokenUseCase_ParseSystemClaimsByAccessToken(t *testing.T) {
|
|||
|
mockCtrl := gomock.NewController(t)
|
|||
|
defer mockCtrl.Finish()
|
|||
|
|
|||
|
mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl)
|
|||
|
uc := NewTokenUseCase(TokenUseCaseParam{
|
|||
|
TokenRepo: mockAutoIDModel,
|
|||
|
RefreshExpires: 2 * time.Minute,
|
|||
|
Expired: 2 * time.Minute,
|
|||
|
Secret: "gg88g88",
|
|||
|
})
|
|||
|
//table-driven 測試案例
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
tokenGen func(t *testing.T) string // 用來產生 access token
|
|||
|
secret string
|
|||
|
validate bool
|
|||
|
want map[string]string // 預期轉換後的資料
|
|||
|
wantErr bool
|
|||
|
errContains string
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "valid token with correct data map",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
// 建立 claims,其中 "data" 欄位為 map[string]any
|
|||
|
claims := jwt.MapClaims{
|
|||
|
"data": map[string]any{
|
|||
|
"key1": "value1",
|
|||
|
"key2": "value2",
|
|||
|
},
|
|||
|
}
|
|||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|||
|
tokenStr, err := token.SignedString([]byte("secret"))
|
|||
|
if err != nil {
|
|||
|
t.Fatalf("failed to sign token: %v", err)
|
|||
|
}
|
|||
|
return tokenStr
|
|||
|
},
|
|||
|
secret: "secret",
|
|||
|
validate: true,
|
|||
|
want: map[string]string{
|
|||
|
"key1": "value1",
|
|||
|
"key2": "value2",
|
|||
|
},
|
|||
|
wantErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "token missing data field",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
// claims 中不包含 "data"
|
|||
|
claims := jwt.MapClaims{
|
|||
|
"other": "something",
|
|||
|
}
|
|||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|||
|
tokenStr, err := token.SignedString([]byte("secret"))
|
|||
|
if err != nil {
|
|||
|
t.Fatalf("failed to sign token: %v", err)
|
|||
|
}
|
|||
|
return tokenStr
|
|||
|
},
|
|||
|
secret: "secret",
|
|||
|
validate: true,
|
|||
|
want: map[string]string{},
|
|||
|
wantErr: true,
|
|||
|
errContains: "get data from claim map error",
|
|||
|
},
|
|||
|
{
|
|||
|
name: "malformed token",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
return "not-a-token"
|
|||
|
},
|
|||
|
secret: "secret",
|
|||
|
validate: true,
|
|||
|
want: map[string]string{},
|
|||
|
wantErr: true,
|
|||
|
errContains: "token contains an invalid number of segments",
|
|||
|
},
|
|||
|
{
|
|||
|
name: "data field not a map",
|
|||
|
tokenGen: func(t *testing.T) string {
|
|||
|
// 將 "data" 設為一個字串,而非 map
|
|||
|
claims := jwt.MapClaims{
|
|||
|
"data": "not-a-map",
|
|||
|
}
|
|||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|||
|
tokenStr, err := token.SignedString([]byte("secret"))
|
|||
|
if err != nil {
|
|||
|
t.Fatalf("failed to sign token: %v", err)
|
|||
|
}
|
|||
|
return tokenStr
|
|||
|
},
|
|||
|
secret: "secret",
|
|||
|
validate: true,
|
|||
|
want: map[string]string{},
|
|||
|
wantErr: true,
|
|||
|
errContains: "get data from claim map error",
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
tt := tt // 捕捉區域變數
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
|||
|
accessToken := tt.tokenGen(t)
|
|||
|
result, err := uc.ParseSystemClaimsByAccessToken(accessToken, tt.secret, tt.validate)
|
|||
|
if tt.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
if err != nil && tt.errContains != "" {
|
|||
|
assert.Contains(t, err.Error(), tt.errContains)
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
assert.Equal(t, tt.want, result)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTokenUseCase_newToken(t *testing.T) {
|
|||
|
mockCtrl := gomock.NewController(t)
|
|||
|
defer mockCtrl.Finish()
|
|||
|
|
|||
|
mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl)
|
|||
|
uc := TokenUseCase{
|
|||
|
TokenUseCaseParam: TokenUseCaseParam{
|
|||
|
TokenRepo: mockAutoIDModel,
|
|||
|
RefreshExpires: 2 * time.Minute,
|
|||
|
Expired: 2 * time.Minute,
|
|||
|
Secret: "gg88g88",
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
// 取得一個參考時間,用來檢查 default expiration 的結果
|
|||
|
nowRef := time.Now().UTC()
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
req *usecase.GenerateTokenRequest
|
|||
|
// 模擬產生 AccessToken 與 RefreshToken 的函式
|
|||
|
stubAccessToken func(token entity.Token, data map[string]interface{}, secret string) (string, error)
|
|||
|
stubRefreshToken func(accessToken string) string
|
|||
|
wantErr bool
|
|||
|
// 當使用者提供明確的 expires 與 refreshExpires 時,期望的值(否則使用預設)
|
|||
|
expectExpiresProvided bool
|
|||
|
expectedExpires int64
|
|||
|
expectRefreshExpiresProvided bool
|
|||
|
expectedRefreshExpires int64
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "default expiration used when req.Expires/RefreshExpires are zero",
|
|||
|
req: &usecase.GenerateTokenRequest{
|
|||
|
DeviceID: "device1",
|
|||
|
UID: "user1",
|
|||
|
Expires: 0,
|
|||
|
RefreshExpires: 0,
|
|||
|
Data: map[string]string{"foo": "bar"},
|
|||
|
Role: "admin",
|
|||
|
Scope: "read",
|
|||
|
Account: "account1",
|
|||
|
TokenType: "access",
|
|||
|
},
|
|||
|
wantErr: false,
|
|||
|
expectExpiresProvided: false,
|
|||
|
expectRefreshExpiresProvided: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "explicit expiration provided",
|
|||
|
req: func() *usecase.GenerateTokenRequest {
|
|||
|
// 提供明確的 expires 與 refreshExpires
|
|||
|
exp := nowRef.Add(5 * time.Minute).UnixNano()
|
|||
|
refExp := nowRef.Add(10 * time.Minute).UnixNano()
|
|||
|
return &usecase.GenerateTokenRequest{
|
|||
|
DeviceID: "device2",
|
|||
|
UID: "user2",
|
|||
|
Expires: exp,
|
|||
|
RefreshExpires: refExp,
|
|||
|
Data: map[string]string{},
|
|||
|
Role: "user",
|
|||
|
Scope: "write",
|
|||
|
Account: "account2",
|
|||
|
TokenType: "access",
|
|||
|
}
|
|||
|
}(),
|
|||
|
stubAccessToken: func(token entity.Token, data map[string]interface{}, secret string) (string, error) {
|
|||
|
return "access-token", nil
|
|||
|
},
|
|||
|
stubRefreshToken: func(accessToken string) string {
|
|||
|
return "refresh-token"
|
|||
|
},
|
|||
|
wantErr: false,
|
|||
|
expectExpiresProvided: true,
|
|||
|
// 預期值就與 req.Expires 相同
|
|||
|
expectedExpires: func() int64 { return nowRef.Add(5 * time.Minute).UnixNano() }(),
|
|||
|
expectRefreshExpiresProvided: true,
|
|||
|
expectedRefreshExpires: func() int64 { return nowRef.Add(10 * time.Minute).UnixNano() }(),
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
for _, tt := range tests {
|
|||
|
tt := tt // 捕捉範圍變數
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
// 呼叫 newToken 方法
|
|||
|
token, err := uc.newToken(context.Background(), tt.req)
|
|||
|
if tt.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
return
|
|||
|
}
|
|||
|
assert.NoError(t, err)
|
|||
|
// 檢查基本欄位
|
|||
|
assert.NotEmpty(t, token.ID, "token.ID should not be empty")
|
|||
|
assert.Equal(t, tt.req.DeviceID, token.DeviceID)
|
|||
|
assert.Equal(t, tt.req.UID, token.UID)
|
|||
|
|
|||
|
// 驗證建立時間欄位有被設置
|
|||
|
assert.NotZero(t, token.AccessCreateAt)
|
|||
|
assert.NotZero(t, token.RefreshCreateAt)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func TestTokenUseCase_GenerateAccessToken(t *testing.T) {
|
|||
|
mockCtrl := gomock.NewController(t)
|
|||
|
defer mockCtrl.Finish()
|
|||
|
|
|||
|
mockNewMockTokenRepo := mock.NewMockTokenRepo(mockCtrl)
|
|||
|
uc := TokenUseCase{
|
|||
|
TokenUseCaseParam: TokenUseCaseParam{
|
|||
|
TokenRepo: mockNewMockTokenRepo,
|
|||
|
RefreshExpires: 2 * time.Minute,
|
|||
|
Expired: 2 * time.Minute,
|
|||
|
Secret: "gg88g88",
|
|||
|
},
|
|||
|
}
|
|||
|
// 定義 table-driven 測試案例
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
repoErr error
|
|||
|
req usecase.GenerateTokenRequest
|
|||
|
wantErr bool
|
|||
|
errContains string
|
|||
|
setup func()
|
|||
|
// 若成功,預期回傳的 access token 與 refresh token
|
|||
|
expectedAccessToken string
|
|||
|
expectedRefreshToken string
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "newToken error from CreateAccessToken",
|
|||
|
repoErr: nil,
|
|||
|
setup: func() {
|
|||
|
mockNewMockTokenRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(fmt.Errorf("token create error: failed to create token"))
|
|||
|
},
|
|||
|
req: usecase.GenerateTokenRequest{
|
|||
|
DeviceID: "device1",
|
|||
|
UID: "user1",
|
|||
|
Expires: 0, // 使用預設過期時間
|
|||
|
RefreshExpires: 0,
|
|||
|
Data: map[string]string{"foo": "bar"},
|
|||
|
Role: "admin",
|
|||
|
Scope: "read",
|
|||
|
Account: "account1",
|
|||
|
TokenType: "access",
|
|||
|
},
|
|||
|
wantErr: true,
|
|||
|
errContains: "token create error: failed to create token",
|
|||
|
},
|
|||
|
{
|
|||
|
name: "successful generation",
|
|||
|
repoErr: nil,
|
|||
|
setup: func() {
|
|||
|
mockNewMockTokenRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil)
|
|||
|
},
|
|||
|
req: usecase.GenerateTokenRequest{
|
|||
|
DeviceID: "device3",
|
|||
|
UID: "user3",
|
|||
|
Expires: 0,
|
|||
|
RefreshExpires: 0,
|
|||
|
Data: map[string]string{"foo": "bar"},
|
|||
|
Role: "member",
|
|||
|
Scope: "read",
|
|||
|
Account: "account3",
|
|||
|
TokenType: "access",
|
|||
|
},
|
|||
|
wantErr: false,
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
// 針對每個測試案例執行測試
|
|||
|
for _, tt := range tests {
|
|||
|
tt := tt // 捕捉區域變數
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
ctx := context.Background()
|
|||
|
tt.setup()
|
|||
|
resp, err := uc.GenerateAccessToken(ctx, tt.req)
|
|||
|
if tt.wantErr {
|
|||
|
assert.Error(t, err)
|
|||
|
if err != nil && tt.errContains != "" {
|
|||
|
assert.Contains(t, err.Error(), tt.errContains)
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
assert.NoError(t, err)
|
|||
|
// 驗證 ExpiresIn 非零(newToken 會根據當前時間與設定產生過期時間)
|
|||
|
assert.NotZero(t, resp.ExpiresIn)
|
|||
|
})
|
|||
|
}
|
|||
|
}
|