app-cloudep-permission-server/pkg/usecase/token_test.go

612 lines
18 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package usecase
import (
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/usecase"
mock "code.30cm.net/digimon/app-cloudep-permission-server/pkg/mock/repository"
"context"
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"testing"
"time"
)
// TestTokenUseCase_CreateAccessToken_TableDriven exercises CreateAccessToken
// with a table of inputs. Every produced JWT is parsed back with the same
// secret (only HMAC signing methods are accepted) and its claims are checked
// by a per-case verifier.
func TestTokenUseCase_CreateAccessToken_TableDriven(t *testing.T) {
	// Reference time used to derive each case's token.ExpiresIn.
	now := time.Now()

	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()
	mockTokenRepo := mock.NewMockTokenRepo(mockCtrl)

	uc := NewTokenUseCase(TokenUseCaseParam{
		TokenRepo:      mockTokenRepo,
		RefreshExpires: 2 * time.Minute,
		Expired:        2 * time.Minute,
		Secret:         "gg88g88",
	})

	tests := []struct {
		name      string
		token     entity.Token
		data      any
		secretKey string
		wantErr   bool
		// verifyClaims checks the parsed claims; expectedExpiry is built from
		// token.ExpiresIn (interpreted as Unix nanoseconds).
		verifyClaims func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time)
	}{
		{
			name: "ok",
			token: entity.Token{
				ID:        "token1",
				ExpiresIn: now.Add(1 * time.Hour).UnixNano(),
			},
			data:      map[string]interface{}{"foo": "bar"},
			secretKey: "secret123",
			wantErr:   false,
			verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
				assert.Equal(t, "token1", claims.ID)
				assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry),
					"expected expiry %v, got %v", expectedExpiry, claims.ExpiresAt.Time)
				dataMap, ok := claims.Data.(map[string]interface{})
				assert.True(t, ok, "claims.Data 應為 map[string]interface{}")
				assert.Equal(t, "bar", dataMap["foo"])
			},
		},
		{
			// Same payload shape as "ok", signed with a different secret.
			name: "valid token with different secret",
			token: entity.Token{
				ID:        "token2",
				ExpiresIn: now.Add(2 * time.Hour).UnixNano(),
			},
			data:      map[string]interface{}{"foo": "bar"},
			secretKey: "anotherSecret",
			wantErr:   false,
			verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
				assert.Equal(t, "token2", claims.ID)
				assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry))
				assert.Equal(t, map[string]interface{}{"foo": "bar"}, claims.Data)
			},
		},
		{
			name: "empty secret key",
			token: entity.Token{
				ID:        "token3",
				ExpiresIn: now.Add(30 * time.Minute).UnixNano(),
			},
			data:      map[string]interface{}{"key": "value"},
			secretKey: "",
			wantErr:   false,
			verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) {
				assert.Equal(t, "token3", claims.ID)
				assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry))
				dataMap, ok := claims.Data.(map[string]interface{})
				assert.True(t, ok, "claims.Data 應為 map[string]interface{}")
				assert.Equal(t, "value", dataMap["key"])
			},
		},
		// More cases (e.g. signing failures) can be added here as needed.
	}

	for _, tt := range tests {
		tt := tt // capture range variable (required before Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			jwtStr, err := uc.CreateAccessToken(tt.token, tt.data, tt.secretKey)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			if !assert.NoError(t, err) {
				return
			}

			// Parse the JWT back, accepting only HMAC-signed tokens.
			parsedToken, err := jwt.ParseWithClaims(jwtStr, &entity.Claims{}, func(token *jwt.Token) (interface{}, error) {
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
				}
				return []byte(tt.secretKey), nil
			})
			// Stop on failure: parsedToken may be nil and would panic below.
			if !assert.NoError(t, err) {
				return
			}
			assert.True(t, parsedToken.Valid, "解析後的 JWT 應該有效")

			claims, ok := parsedToken.Claims.(*entity.Claims)
			// Stop on failure: verifyClaims would dereference a nil *entity.Claims.
			if !assert.True(t, ok, "claims 型別錯誤,預期 *entity.Claims, got %T", parsedToken.Claims) {
				return
			}

			// Expected expiry derived from token.ExpiresIn (Unix nanoseconds).
			expectedExpiry := time.Unix(0, tt.token.ExpiresIn)
			tt.verifyClaims(t, claims, expectedExpiry)
		})
	}
}
// TestTokenUseCase_CreateRefreshToken checks that CreateRefreshToken turns an
// access token into the expected hex-encoded digest string.
func TestTokenUseCase_CreateRefreshToken(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	uc := NewTokenUseCase(TokenUseCaseParam{
		TokenRepo:      mock.NewMockTokenRepo(ctrl),
		RefreshExpires: 2 * time.Minute,
		Expired:        2 * time.Minute,
		Secret:         "gg88g88",
	})

	cases := []struct {
		name  string
		input string
		want  string
	}{
		{
			// Hex encoding of SHA256("").
			name:  "empty access token",
			input: "",
			want:  "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
		},
		{
			name:  "normal access token",
			input: "access-token",
			want:  "3f16bed7089f4653e5ef21bfd2824d7f3aaaecc7a598e7e89c580e1606a9cc52",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, uc.CreateRefreshToken(tc.input))
		})
	}
}
// TestTokenUseCase_ParseJWTClaimsByAccessToken is a table-driven test of
// ParseJWTClaimsByAccessToken covering valid tokens (with and without
// validation), a wrong secret, the "none" signing method and a malformed token.
func TestTokenUseCase_ParseJWTClaimsByAccessToken(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	uc := NewTokenUseCase(TokenUseCaseParam{
		TokenRepo:      mock.NewMockTokenRepo(ctrl),
		RefreshExpires: 2 * time.Minute,
		Expired:        2 * time.Minute,
		Secret:         "gg88g88",
	})

	// signHS256 signs claims with HS256 using key, failing the test on error.
	signHS256 := func(t *testing.T, claims jwt.MapClaims, key string) string {
		t.Helper()
		signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(key))
		if err != nil {
			t.Fatalf("failed to sign token: %v", err)
		}
		return signed
	}

	cases := []struct {
		name        string
		makeToken   func(t *testing.T) string // produces the access token under test
		secret      string
		validate    bool
		wantClaims  jwt.MapClaims
		wantErr     bool
		errContains string
	}{
		{
			name: "valid token with validation",
			makeToken: func(t *testing.T) string {
				return signHS256(t, jwt.MapClaims{"sub": "123", "role": "admin"}, "testsecret")
			},
			secret:     "testsecret",
			validate:   true,
			wantClaims: jwt.MapClaims{"sub": "123", "role": "admin"},
		},
		{
			name: "valid token without validation",
			makeToken: func(t *testing.T) string {
				return signHS256(t, jwt.MapClaims{"sub": "123", "role": "admin"}, "testsecret")
			},
			secret:     "testsecret",
			validate:   false,
			wantClaims: jwt.MapClaims{"sub": "123", "role": "admin"},
		},
		{
			name: "invalid secret",
			makeToken: func(t *testing.T) string {
				return signHS256(t, jwt.MapClaims{"sub": "123"}, "testsecret")
			},
			secret:      "wrongsecret",
			validate:    true,
			wantErr:     true,
			errContains: "signature", // the error message should mention the signature
		},
		{
			name: "unexpected signing method",
			makeToken: func(t *testing.T) string {
				// The "none" algorithm only signs when given this sentinel key.
				unsigned := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{"sub": "456"})
				signed, err := unsigned.SignedString(jwt.UnsafeAllowNoneSignatureType)
				if err != nil {
					t.Fatalf("failed to sign token: %v", err)
				}
				return signed
			},
			secret:      "testsecret",
			validate:    true,
			wantErr:     true,
			errContains: "unexpected signing method",
		},
		{
			name:        "malformed token",
			makeToken:   func(t *testing.T) string { return "not-a-token" },
			secret:      "testsecret",
			validate:    true,
			wantErr:     true,
			errContains: "token contains an invalid number of segments",
		},
	}

	for _, tc := range cases {
		tc := tc // capture range variable (required before Go 1.22)
		t.Run(tc.name, func(t *testing.T) {
			got, err := uc.ParseJWTClaimsByAccessToken(tc.makeToken(t), tc.secret, tc.validate)
			if !tc.wantErr {
				assert.NoError(t, err)
				assert.Equal(t, tc.wantClaims, got)
				return
			}
			assert.Error(t, err)
			if err != nil && tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestTokenUseCase_ParseSystemClaimsByAccessToken is a table-driven test of
// ParseSystemClaimsByAccessToken. It checks that a "data" claim holding a map
// is returned as a string map, and that a missing data field, a non-map data
// field and a malformed token all produce errors.
func TestTokenUseCase_ParseSystemClaimsByAccessToken(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	uc := NewTokenUseCase(TokenUseCaseParam{
		TokenRepo:      mock.NewMockTokenRepo(ctrl),
		RefreshExpires: 2 * time.Minute,
		Expired:        2 * time.Minute,
		Secret:         "gg88g88",
	})

	// sign produces an HS256 token over claims with the key "secret".
	sign := func(t *testing.T, claims jwt.MapClaims) string {
		t.Helper()
		out, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("secret"))
		if err != nil {
			t.Fatalf("failed to sign token: %v", err)
		}
		return out
	}

	cases := []struct {
		name        string
		makeToken   func(t *testing.T) string
		secret      string
		validate    bool
		want        map[string]string // expected converted data; only read on success
		wantErr     bool
		errContains string
	}{
		{
			name: "valid token with correct data map",
			makeToken: func(t *testing.T) string {
				return sign(t, jwt.MapClaims{
					"data": map[string]any{"key1": "value1", "key2": "value2"},
				})
			},
			secret:   "secret",
			validate: true,
			want:     map[string]string{"key1": "value1", "key2": "value2"},
		},
		{
			name: "token missing data field",
			makeToken: func(t *testing.T) string {
				return sign(t, jwt.MapClaims{"other": "something"})
			},
			secret:      "secret",
			validate:    true,
			wantErr:     true,
			errContains: "get data from claim map error",
		},
		{
			name:        "malformed token",
			makeToken:   func(t *testing.T) string { return "not-a-token" },
			secret:      "secret",
			validate:    true,
			wantErr:     true,
			errContains: "token contains an invalid number of segments",
		},
		{
			name: "data field not a map",
			makeToken: func(t *testing.T) string {
				// "data" is a plain string here rather than a map.
				return sign(t, jwt.MapClaims{"data": "not-a-map"})
			},
			secret:      "secret",
			validate:    true,
			wantErr:     true,
			errContains: "get data from claim map error",
		},
	}

	for _, tc := range cases {
		tc := tc // capture range variable (required before Go 1.22)
		t.Run(tc.name, func(t *testing.T) {
			got, err := uc.ParseSystemClaimsByAccessToken(tc.makeToken(t), tc.secret, tc.validate)
			if tc.wantErr {
				assert.Error(t, err)
				if err != nil && tc.errContains != "" {
					assert.Contains(t, err.Error(), tc.errContains)
				}
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tc.want, got)
			}
		})
	}
}
// TestTokenUseCase_newToken checks that newToken populates a token's ID,
// DeviceID, UID and creation timestamps. It covers a request that relies on
// the default expirations (Expires/RefreshExpires = 0) and one that supplies
// explicit values.
func TestTokenUseCase_newToken(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()
	mockTokenRepo := mock.NewMockTokenRepo(mockCtrl)

	uc := TokenUseCase{
		TokenUseCaseParam: TokenUseCaseParam{
			TokenRepo:      mockTokenRepo,
			RefreshExpires: 2 * time.Minute,
			Expired:        2 * time.Minute,
			Secret:         "gg88g88",
		},
	}

	// Reference time for the explicit-expiration request.
	nowRef := time.Now().UTC()

	tests := []struct {
		name    string
		req     *usecase.GenerateTokenRequest
		wantErr bool
	}{
		{
			name: "default expiration used when req.Expires/RefreshExpires are zero",
			req: &usecase.GenerateTokenRequest{
				DeviceID:       "device1",
				UID:            "user1",
				Expires:        0,
				RefreshExpires: 0,
				Data:           map[string]string{"foo": "bar"},
				Role:           "admin",
				Scope:          "read",
				Account:        "account1",
				TokenType:      "access",
			},
			wantErr: false,
		},
		{
			name: "explicit expiration provided",
			req: &usecase.GenerateTokenRequest{
				DeviceID:       "device2",
				UID:            "user2",
				Expires:        nowRef.Add(5 * time.Minute).UnixNano(),
				RefreshExpires: nowRef.Add(10 * time.Minute).UnixNano(),
				Data:           map[string]string{},
				Role:           "user",
				Scope:          "write",
				Account:        "account2",
				TokenType:      "access",
			},
			wantErr: false,
		},
	}

	// NOTE(review): the expiration values newToken stores are not asserted,
	// because how req.Expires/RefreshExpires map onto token fields (units,
	// defaults) is not visible from this file. Add assertions once that
	// contract is confirmed.
	for _, tt := range tests {
		tt := tt // capture range variable (required before Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			token, err := uc.newToken(context.Background(), tt.req)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			// Stop on failure so the field checks below never see a nil token.
			if !assert.NoError(t, err) {
				return
			}

			// Identity fields are carried over from the request.
			assert.NotEmpty(t, token.ID, "token.ID should not be empty")
			assert.Equal(t, tt.req.DeviceID, token.DeviceID)
			assert.Equal(t, tt.req.UID, token.UID)

			// Creation timestamps must be set.
			assert.NotZero(t, token.AccessCreateAt)
			assert.NotZero(t, token.RefreshCreateAt)
		})
	}
}
// TestTokenUseCase_GenerateAccessToken checks GenerateAccessToken against a
// mocked TokenRepo. A failing repository Create call must surface as an error
// containing the repository's message. A successful Create must yield a
// response with a non-zero ExpiresIn.
func TestTokenUseCase_GenerateAccessToken(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()
	mockTokenRepo := mock.NewMockTokenRepo(mockCtrl)

	uc := TokenUseCase{
		TokenUseCaseParam: TokenUseCaseParam{
			TokenRepo:      mockTokenRepo,
			RefreshExpires: 2 * time.Minute,
			Expired:        2 * time.Minute,
			Secret:         "gg88g88",
		},
	}

	tests := []struct {
		name        string
		req         usecase.GenerateTokenRequest
		setup       func() // registers the TokenRepo expectations for this case
		wantErr     bool
		errContains string
	}{
		{
			name: "repository Create error is propagated",
			setup: func() {
				mockTokenRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(fmt.Errorf("token create error: failed to create token"))
			},
			req: usecase.GenerateTokenRequest{
				DeviceID:       "device1",
				UID:            "user1",
				Expires:        0, // rely on the default expiration
				RefreshExpires: 0,
				Data:           map[string]string{"foo": "bar"},
				Role:           "admin",
				Scope:          "read",
				Account:        "account1",
				TokenType:      "access",
			},
			wantErr:     true,
			errContains: "token create error: failed to create token",
		},
		{
			name: "successful generation",
			setup: func() {
				mockTokenRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil)
			},
			req: usecase.GenerateTokenRequest{
				DeviceID:       "device3",
				UID:            "user3",
				Expires:        0,
				RefreshExpires: 0,
				Data:           map[string]string{"foo": "bar"},
				Role:           "member",
				Scope:          "read",
				Account:        "account3",
				TokenType:      "access",
			},
			wantErr: false,
		},
	}

	for _, tt := range tests {
		tt := tt // capture range variable (required before Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			tt.setup()
			resp, err := uc.GenerateAccessToken(context.Background(), tt.req)
			if tt.wantErr {
				assert.Error(t, err)
				if err != nil && tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			// Stop on failure so resp is never dereferenced after an error.
			if !assert.NoError(t, err) {
				return
			}
			// The request leaves Expires at 0, so some default expiration is
			// expected to have been applied.
			assert.NotZero(t, resp.ExpiresIn)
		})
	}
}