app-cloudep-permission-server/pkg/usecase/token_test.go

612 lines
18 KiB
Go
Raw Permalink Normal View History

2025-02-13 11:06:51 +00:00
package usecase
import (
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/usecase"
mock "code.30cm.net/digimon/app-cloudep-permission-server/pkg/mock/repository"
"context"
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"testing"
"time"
)
// TestTokenUseCase_CreateAccessToken_TableDriven exercises CreateAccessToken
// with a table of inputs. Each case signs a token, parses the resulting JWT
// back with the same secret, and checks the decoded claims.
func TestTokenUseCase_CreateAccessToken_TableDriven(t *testing.T) {
	// Reference instant from which every case derives its expiry.
	now := time.Now()

	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl)
	uc := NewTokenUseCase(TokenUseCaseParam{
		TokenRepo:      mockAutoIDModel,
		RefreshExpires: 2 * time.Minute,
		Expired:        2 * time.Minute,
		Secret:         "gg88g88",
	})

	tests := []struct {
		name         string
		token        entity.Token
		data         any
		secretKey    string
		wantErr      bool
		verifyClaims func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time)
	}{
		{
			name: "ok",
			token: entity.Token{
				ID:        "token1",
				ExpiresIn: now.Add(1 * time.Hour).UnixNano(),
			},
			data:      map[string]any{"foo": "bar"},
			secretKey: "secret123",
			wantErr:   false,
			verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
				assert.Equal(t, "token1", claims.ID)
				// NOTE(review): Before presumably holds because the encoded
				// expiry loses sub-second precision — confirm against
				// CreateAccessToken.
				assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry),
					"expected expiry %v, got %v", expectedExpiry, claims.ExpiresAt.Time)
				dataMap, ok := claims.Data.(map[string]any)
				assert.True(t, ok, "claims.Data 應為 map[string]interface{}")
				assert.Equal(t, "bar", dataMap["foo"])
			},
		},
		{
			// NOTE(review): despite the case name, the payload below is a map,
			// not a string.
			name: "valid token with string data",
			token: entity.Token{
				ID:        "token2",
				ExpiresIn: now.Add(2 * time.Hour).UnixNano(),
			},
			data:      map[string]any{"foo": "bar"},
			secretKey: "anotherSecret",
			wantErr:   false,
			verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
				assert.Equal(t, "token2", claims.ID)
				assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry))
				assert.Equal(t, map[string]any{"foo": "bar"}, claims.Data)
			},
		},
		{
			name: "empty secret key",
			token: entity.Token{
				ID:        "token3",
				ExpiresIn: now.Add(30 * time.Minute).UnixNano(),
			},
			data:      map[string]any{"key": "value"},
			secretKey: "",
			wantErr:   false,
			verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
				assert.Equal(t, "token3", claims.ID)
				assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry))
				dataMap, ok := claims.Data.(map[string]any)
				assert.True(t, ok, "claims.Data 應為 map[string]interface{}")
				assert.Equal(t, "value", dataMap["key"])
			},
		},
		// More cases (for example a simulated signing failure) can be added here.
	}

	for _, tt := range tests {
		tt := tt // per-iteration copy of the range variable
		t.Run(tt.name, func(t *testing.T) {
			jwtStr, err := uc.CreateAccessToken(tt.token, tt.data, tt.secretKey)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			// Stop on failure: nothing below is meaningful without a token.
			if !assert.NoError(t, err) {
				return
			}

			// Parse the JWT back using the secret it was signed with.
			parsedToken, err := jwt.ParseWithClaims(jwtStr, &entity.Claims{}, func(token *jwt.Token) (any, error) {
				// Only HMAC-signed tokens are accepted.
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
				}
				return []byte(tt.secretKey), nil
			})
			// Stop on failure so a nil parsedToken is never dereferenced.
			if !assert.NoError(t, err) {
				return
			}
			assert.True(t, parsedToken.Valid, "解析後的 JWT 應該有效")

			claims, ok := parsedToken.Claims.(*entity.Claims)
			// A failed type assertion leaves claims nil; bail out before
			// verifyClaims reads its fields.
			if !assert.True(t, ok, "claims 型別錯誤,預期 *entity.Claims, got %T", parsedToken.Claims) {
				return
			}

			// token.ExpiresIn is a Unix-nanosecond timestamp; turn it back
			// into a time.Time for the per-case checks.
			expectedExpiry := time.Unix(0, tt.token.ExpiresIn)
			tt.verifyClaims(t, claims, expectedExpiry)
		})
	}
}
// TestTokenUseCase_CreateRefreshToken checks that CreateRefreshToken turns
// each access token into the expected fixed hex string.
func TestTokenUseCase_CreateRefreshToken(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	tokenUC := NewTokenUseCase(TokenUseCaseParam{
		TokenRepo:      mock.NewMockTokenRepo(ctrl),
		RefreshExpires: 2 * time.Minute,
		Expired:        2 * time.Minute,
		Secret:         "gg88g88",
	})

	type refreshCase struct {
		name        string
		accessToken string
		expected    string
	}
	cases := []refreshCase{
		// Hex-encoded SHA256("").
		{"empty access token", "", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"},
		{"normal access token", "access-token", "3f16bed7089f4653e5ef21bfd2824d7f3aaaecc7a598e7e89c580e1606a9cc52"},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy of the range variable
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.expected, tokenUC.CreateRefreshToken(tc.accessToken))
		})
	}
}
// TestTokenUseCase_ParseJWTClaimsByAccessToken runs ParseJWTClaimsByAccessToken
// against a table of freshly generated tokens and compares either the returned
// claims or the error text with the expectation of each case.
func TestTokenUseCase_ParseJWTClaimsByAccessToken(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	tokenUC := NewTokenUseCase(TokenUseCaseParam{
		TokenRepo:      mock.NewMockTokenRepo(ctrl),
		RefreshExpires: 2 * time.Minute,
		Expired:        2 * time.Minute,
		Secret:         "gg88g88",
	})

	// hs256Token builds a generator that signs claims with HS256 and
	// "testsecret", aborting the running test if signing fails.
	hs256Token := func(claims jwt.MapClaims) func(t *testing.T) string {
		return func(t *testing.T) string {
			signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("testsecret"))
			if err != nil {
				t.Fatalf("failed to sign token: %v", err)
			}
			return signed
		}
	}

	// tokenGen produces the access token handed to the parser for each case.
	type parseCase struct {
		name        string
		tokenGen    func(t *testing.T) string
		secret      string
		validate    bool
		wantClaims  jwt.MapClaims
		wantErr     bool
		errContains string
	}
	cases := []parseCase{
		{
			name:       "valid token with validation",
			tokenGen:   hs256Token(jwt.MapClaims{"sub": "123", "role": "admin"}),
			secret:     "testsecret",
			validate:   true,
			wantClaims: jwt.MapClaims{"sub": "123", "role": "admin"},
		},
		{
			name:       "valid token without validation",
			tokenGen:   hs256Token(jwt.MapClaims{"sub": "123", "role": "admin"}),
			secret:     "testsecret",
			validate:   false,
			wantClaims: jwt.MapClaims{"sub": "123", "role": "admin"},
		},
		{
			name:        "invalid secret",
			tokenGen:    hs256Token(jwt.MapClaims{"sub": "123"}),
			secret:      "wrongsecret",
			validate:    true,
			wantErr:     true,
			errContains: "signature", // the error text is expected to mention "signature"
		},
		{
			name: "unexpected signing method",
			tokenGen: func(t *testing.T) string {
				// Build the token with the "none" algorithm; SignedString
				// only accepts jwt.UnsafeAllowNoneSignatureType as its key.
				unsigned := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{"sub": "456"})
				signed, err := unsigned.SignedString(jwt.UnsafeAllowNoneSignatureType)
				if err != nil {
					t.Fatalf("failed to sign token: %v", err)
				}
				return signed
			},
			secret:      "testsecret",
			validate:    true,
			wantErr:     true,
			errContains: "unexpected signing method",
		},
		{
			name:        "malformed token",
			tokenGen:    func(t *testing.T) string { return "not-a-token" },
			secret:      "testsecret",
			validate:    true,
			wantErr:     true,
			errContains: "token contains an invalid number of segments",
		},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy of the range variable
		t.Run(tc.name, func(t *testing.T) {
			got, err := tokenUC.ParseJWTClaimsByAccessToken(tc.tokenGen(t), tc.secret, tc.validate)
			if !tc.wantErr {
				assert.NoError(t, err)
				// The parsed claims must match the expectation exactly.
				assert.Equal(t, tc.wantClaims, got)
				return
			}

			assert.Error(t, err)
			if err == nil || tc.errContains == "" {
				return
			}
			assert.Contains(t, err.Error(), tc.errContains)
		})
	}
}
// TestTokenUseCase_ParseSystemClaimsByAccessToken runs
// ParseSystemClaimsByAccessToken against a table of generated tokens and
// compares the returned map, or the error text, with each case's expectation.
func TestTokenUseCase_ParseSystemClaimsByAccessToken(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	tokenUC := NewTokenUseCase(TokenUseCaseParam{
		TokenRepo:      mock.NewMockTokenRepo(ctrl),
		RefreshExpires: 2 * time.Minute,
		Expired:        2 * time.Minute,
		Secret:         "gg88g88",
	})

	// hs256Token builds a generator that signs claims with HS256 and "secret",
	// aborting the running test if signing fails.
	hs256Token := func(claims jwt.MapClaims) func(t *testing.T) string {
		return func(t *testing.T) string {
			signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("secret"))
			if err != nil {
				t.Fatalf("failed to sign token: %v", err)
			}
			return signed
		}
	}

	// tokenGen produces the access token for a case; want is the map expected
	// back on success.
	type systemClaimsCase struct {
		name        string
		tokenGen    func(t *testing.T) string
		secret      string
		validate    bool
		want        map[string]string
		wantErr     bool
		errContains string
	}
	cases := []systemClaimsCase{
		{
			name: "valid token with correct data map",
			// The "data" claim is a map[string]any.
			tokenGen: hs256Token(jwt.MapClaims{
				"data": map[string]any{
					"key1": "value1",
					"key2": "value2",
				},
			}),
			secret:   "secret",
			validate: true,
			want: map[string]string{
				"key1": "value1",
				"key2": "value2",
			},
		},
		{
			name: "token missing data field",
			// The claims carry no "data" entry at all.
			tokenGen:    hs256Token(jwt.MapClaims{"other": "something"}),
			secret:      "secret",
			validate:    true,
			want:        map[string]string{},
			wantErr:     true,
			errContains: "get data from claim map error",
		},
		{
			name:        "malformed token",
			tokenGen:    func(t *testing.T) string { return "not-a-token" },
			secret:      "secret",
			validate:    true,
			want:        map[string]string{},
			wantErr:     true,
			errContains: "token contains an invalid number of segments",
		},
		{
			name: "data field not a map",
			// "data" is a plain string instead of a map.
			tokenGen:    hs256Token(jwt.MapClaims{"data": "not-a-map"}),
			secret:      "secret",
			validate:    true,
			want:        map[string]string{},
			wantErr:     true,
			errContains: "get data from claim map error",
		},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy of the range variable
		t.Run(tc.name, func(t *testing.T) {
			got, err := tokenUC.ParseSystemClaimsByAccessToken(tc.tokenGen(t), tc.secret, tc.validate)
			if !tc.wantErr {
				assert.NoError(t, err)
				assert.Equal(t, tc.want, got)
				return
			}

			assert.Error(t, err)
			if err == nil || tc.errContains == "" {
				return
			}
			assert.Contains(t, err.Error(), tc.errContains)
		})
	}
}
// TestTokenUseCase_newToken calls newToken with default and explicit expiry
// requests and checks the identity fields, creation timestamps and access
// expiry of the returned token.
func TestTokenUseCase_newToken(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl)
	uc := TokenUseCase{
		TokenUseCaseParam: TokenUseCaseParam{
			TokenRepo:      mockAutoIDModel,
			RefreshExpires: 2 * time.Minute,
			Expired:        2 * time.Minute,
			Secret:         "gg88g88",
		},
	}

	// Reference instant used to derive the explicit expiry values below.
	nowRef := time.Now().UTC()
	explicitExpires := nowRef.Add(5 * time.Minute).UnixNano()
	explicitRefreshExpires := nowRef.Add(10 * time.Minute).UnixNano()

	// The expect*Provided flags say whether the request carries an explicit
	// value; when set, the matching expected* field holds that value.
	// Otherwise newToken is expected to fall back to its configured default.
	tests := []struct {
		name                         string
		req                          *usecase.GenerateTokenRequest
		wantErr                      bool
		expectExpiresProvided        bool
		expectedExpires              int64
		expectRefreshExpiresProvided bool
		expectedRefreshExpires       int64
	}{
		{
			name: "default expiration used when req.Expires/RefreshExpires are zero",
			req: &usecase.GenerateTokenRequest{
				DeviceID:       "device1",
				UID:            "user1",
				Expires:        0,
				RefreshExpires: 0,
				Data:           map[string]string{"foo": "bar"},
				Role:           "admin",
				Scope:          "read",
				Account:        "account1",
				TokenType:      "access",
			},
			wantErr:                      false,
			expectExpiresProvided:        false,
			expectRefreshExpiresProvided: false,
		},
		{
			name: "explicit expiration provided",
			req: &usecase.GenerateTokenRequest{
				DeviceID:       "device2",
				UID:            "user2",
				Expires:        explicitExpires,
				RefreshExpires: explicitRefreshExpires,
				Data:           map[string]string{},
				Role:           "user",
				Scope:          "write",
				Account:        "account2",
				TokenType:      "access",
			},
			wantErr:               false,
			expectExpiresProvided: true,
			// The expected value is the same as req.Expires.
			expectedExpires:              explicitExpires,
			expectRefreshExpiresProvided: true,
			expectedRefreshExpires:       explicitRefreshExpires,
		},
	}

	for _, tt := range tests {
		tt := tt // per-iteration copy of the range variable
		t.Run(tt.name, func(t *testing.T) {
			token, err := uc.newToken(context.Background(), tt.req)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			// Stop on failure: the token is meaningless when err is non-nil.
			if !assert.NoError(t, err) {
				return
			}

			// Identity fields are copied from the request.
			assert.NotEmpty(t, token.ID, "token.ID should not be empty")
			assert.Equal(t, tt.req.DeviceID, token.DeviceID)
			assert.Equal(t, tt.req.UID, token.UID)

			// Creation timestamps must be populated.
			assert.NotZero(t, token.AccessCreateAt)
			assert.NotZero(t, token.RefreshCreateAt)

			// Access expiry: an explicit request value must be kept as-is;
			// otherwise some default must have been filled in.
			// NOTE(review): assumes newToken stores req.Expires unchanged in
			// token.ExpiresIn — confirm against the implementation.
			if tt.expectExpiresProvided {
				assert.Equal(t, tt.expectedExpires, token.ExpiresIn)
			} else {
				assert.NotZero(t, token.ExpiresIn)
			}

			// NOTE(review): expectRefreshExpiresProvided/expectedRefreshExpires
			// are not asserted yet; the refresh-expiry field of entity.Token is
			// not visible from this file. TODO: add the check once confirmed.
		})
	}
}
// TestTokenUseCase_GenerateAccessToken runs GenerateAccessToken against a
// mocked token repository. Each case sets what the repository's Create call
// returns and checks either the resulting error or the response expiry.
func TestTokenUseCase_GenerateAccessToken(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	mockNewMockTokenRepo := mock.NewMockTokenRepo(mockCtrl)
	uc := TokenUseCase{
		TokenUseCaseParam: TokenUseCaseParam{
			TokenRepo:      mockNewMockTokenRepo,
			RefreshExpires: 2 * time.Minute,
			Expired:        2 * time.Minute,
			Secret:         "gg88g88",
		},
	}

	// repoErr is what the mocked Create call returns for the case.
	tests := []struct {
		name        string
		repoErr     error
		req         usecase.GenerateTokenRequest
		wantErr     bool
		errContains string
	}{
		{
			// NOTE(review): despite the case name, the failure injected here
			// comes from the repository's Create call.
			name:    "newToken error from CreateAccessToken",
			repoErr: fmt.Errorf("token create error: failed to create token"),
			req: usecase.GenerateTokenRequest{
				DeviceID:       "device1",
				UID:            "user1",
				Expires:        0, // use the default expiry
				RefreshExpires: 0,
				Data:           map[string]string{"foo": "bar"},
				Role:           "admin",
				Scope:          "read",
				Account:        "account1",
				TokenType:      "access",
			},
			wantErr:     true,
			errContains: "token create error: failed to create token",
		},
		{
			name:    "successful generation",
			repoErr: nil,
			req: usecase.GenerateTokenRequest{
				DeviceID:       "device3",
				UID:            "user3",
				Expires:        0,
				RefreshExpires: 0,
				Data:           map[string]string{"foo": "bar"},
				Role:           "member",
				Scope:          "read",
				Account:        "account3",
				TokenType:      "access",
			},
			wantErr: false,
		},
	}

	for _, tt := range tests {
		tt := tt // per-iteration copy of the range variable
		t.Run(tt.name, func(t *testing.T) {
			// Register the Create expectation for this case.
			mockNewMockTokenRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(tt.repoErr)

			resp, err := uc.GenerateAccessToken(context.Background(), tt.req)
			if tt.wantErr {
				assert.Error(t, err)
				if err != nil && tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			// Stop on failure: resp must not be inspected when err is non-nil.
			if !assert.NoError(t, err) {
				return
			}

			// With no explicit expiry in the request, the response must still
			// carry a non-zero ExpiresIn.
			assert.NotZero(t, resp.ExpiresIn)
		})
	}
}