package usecase import ( "code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity" "code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/usecase" mock "code.30cm.net/digimon/app-cloudep-permission-server/pkg/mock/repository" "context" "fmt" "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "testing" "time" ) // TestTokenUseCase_CreateAccessToken_TableDriven 透過 table-driven 方式測試 CreateAccessToken func TestTokenUseCase_CreateAccessToken_TableDriven(t *testing.T) { // 固定發行者設定,測試中會用來驗證 claims.Issuer now := time.Now() mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl) uc := NewTokenUseCase(TokenUseCaseParam{ TokenRepo: mockAutoIDModel, RefreshExpires: 2 * time.Minute, Expired: 2 * time.Minute, Secret: "gg88g88", }) tests := []struct { name string token entity.Token data any secretKey string wantErr bool verifyClaims func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) }{ { name: "ok", token: entity.Token{ ID: "token1", ExpiresIn: now.Add(1 * time.Hour).UnixNano(), }, data: map[string]interface{}{"foo": "bar"}, secretKey: "secret123", wantErr: false, verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) { assert.Equal(t, "token1", claims.ID) assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry), "expected expiry %v, got %v", expectedExpiry, claims.ExpiresAt.Time) dataMap, ok := claims.Data.(map[string]interface{}) assert.True(t, ok, "claims.Data 應為 map[string]interface{}") assert.Equal(t, "bar", dataMap["foo"]) }, }, { name: "valid token with string data", token: entity.Token{ ID: "token2", ExpiresIn: now.Add(2 * time.Hour).UnixNano(), }, data: map[string]interface{}{"foo": "bar"}, secretKey: "anotherSecret", wantErr: false, verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) { assert.Equal(t, "token2", claims.ID) assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry)) assert.Equal(t, map[string]interface{}{"foo": "bar"}, claims.Data) }, }, { name: "empty secret key", token: entity.Token{ ID: "token3", ExpiresIn: now.Add(30 * time.Minute).UnixNano(), }, data: map[string]interface{}{"key": "value"}, secretKey: "", wantErr: false, verifyClaims: func(t *testing.T, claims *entity.Claims, expectedExpiry time.Time) { assert.Equal(t, "token3", claims.ID) assert.True(t, claims.ExpiresAt.Time.Before(expectedExpiry)) dataMap, ok := claims.Data.(map[string]interface{}) assert.True(t, ok, "claims.Data 應為 map[string]interface{}") assert.Equal(t, "value", dataMap["key"]) }, }, // 如有需要,可加入更多測試案例,例如模擬簽名錯誤等情境 } for _, tt := range tests { tt := tt // 捕捉範圍變數 t.Run(tt.name, func(t *testing.T) { jwtStr, err := uc.CreateAccessToken(tt.token, tt.data, tt.secretKey) if tt.wantErr { assert.Error(t, err) return } assert.NoError(t, err) // 解析 JWT parsedToken, err := jwt.ParseWithClaims(jwtStr, &entity.Claims{}, func(token *jwt.Token) (interface{}, error) { // 驗證簽名方法是否為 HMAC if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(tt.secretKey), nil }) assert.NoError(t, err) assert.True(t, parsedToken.Valid, "解析後的 JWT 應該有效") claims, ok := parsedToken.Claims.(*entity.Claims) assert.True(t, ok, "claims 型別錯誤,預期 *entity.Claims, got %T", parsedToken.Claims) // 根據 token.ExpiresIn 計算預期的過期時間 expectedExpiry := time.Unix(0, tt.token.ExpiresIn) // 呼叫 verifyClaims 驗證其它 Claim 資料 tt.verifyClaims(t, claims, expectedExpiry) }) } } func TestTokenUseCase_CreateRefreshToken(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl) uc := NewTokenUseCase(TokenUseCaseParam{ TokenRepo: mockAutoIDModel, RefreshExpires: 2 * time.Minute, Expired: 2 * time.Minute, Secret: "gg88g88", }) tests := []struct { name string accessToken string expected string }{ { name: "empty access token", accessToken: "", // SHA256("") 的 hex 編碼結果 expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", }, { name: "normal access token", accessToken: "access-token", expected: "3f16bed7089f4653e5ef21bfd2824d7f3aaaecc7a598e7e89c580e1606a9cc52", }, } for _, tt := range tests { tt := tt // 捕捉變數 t.Run(tt.name, func(t *testing.T) { result := uc.CreateRefreshToken(tt.accessToken) assert.Equal(t, tt.expected, result) }) } } // TestTokenUseCase_ParseJWTClaimsByAccessToken 使用 table-driven 方式測試 ParseJWTClaimsByAccessToken func TestTokenUseCase_ParseJWTClaimsByAccessToken(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl) uc := NewTokenUseCase(TokenUseCaseParam{ TokenRepo: mockAutoIDModel, RefreshExpires: 2 * time.Minute, Expired: 2 * time.Minute, Secret: "gg88g88", }) // 定義測試案例的結構 tests := []struct { name string // tokenGen 用來動態產生要解析的 access token tokenGen func(t *testing.T) string secret string validate bool wantClaims jwt.MapClaims wantErr bool errContains string }{ { name: "valid token with validation", tokenGen: func(t *testing.T) string { claims := jwt.MapClaims{ "sub": "123", "role": "admin", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte("testsecret")) if err != nil { t.Fatalf("failed to sign token: %v", err) } return tokenString }, secret: "testsecret", validate: true, wantClaims: jwt.MapClaims{ "sub": "123", "role": "admin", }, wantErr: false, }, { name: "valid token without validation", tokenGen: func(t *testing.T) string { claims := jwt.MapClaims{ "sub": "123", "role": "admin", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte("testsecret")) if err != nil { t.Fatalf("failed to sign token: %v", err) } return tokenString }, secret: "testsecret", validate: false, wantClaims: jwt.MapClaims{ "sub": "123", "role": "admin", }, wantErr: false, }, { name: "invalid secret", tokenGen: func(t *testing.T) string { claims := jwt.MapClaims{ "sub": "123", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte("testsecret")) if err != nil { t.Fatalf("failed to sign token: %v", err) } return tokenString }, secret: "wrongsecret", validate: true, wantErr: true, errContains: "signature", // 預期錯誤訊息中包含 "signature" }, { name: "unexpected signing method", tokenGen: func(t *testing.T) string { claims := jwt.MapClaims{ "sub": "456", } // 使用 SigningMethodNone 產生 token token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) // 針對 None 演算法,SignedString 需要使用 jwt.UnsafeAllowNoneSignatureType tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) if err != nil { t.Fatalf("failed to sign token: %v", err) } return tokenString }, secret: "testsecret", validate: true, wantErr: true, errContains: "unexpected signing method", }, { name: "malformed token", tokenGen: func(t *testing.T) string { return "not-a-token" }, secret: "testsecret", validate: true, wantErr: true, errContains: "token contains an invalid number of segments", }, } // 針對每個測試案例執行測試 for _, tt := range tests { tt := tt // 捕捉迴圈變數 t.Run(tt.name, func(t *testing.T) { // 產生 access token accessToken := tt.tokenGen(t) claims, err := uc.ParseJWTClaimsByAccessToken(accessToken, tt.secret, tt.validate) if tt.wantErr { assert.Error(t, err) if err != nil && tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } return } assert.NoError(t, err) // 驗證解析出來的 claims 是否符合預期 assert.Equal(t, tt.wantClaims, claims) }) } } func TestTokenUseCase_ParseSystemClaimsByAccessToken(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl) uc := NewTokenUseCase(TokenUseCaseParam{ TokenRepo: mockAutoIDModel, RefreshExpires: 2 * time.Minute, Expired: 2 * time.Minute, Secret: "gg88g88", }) //table-driven 測試案例 tests := []struct { name string tokenGen func(t *testing.T) string // 用來產生 access token secret string validate bool want map[string]string // 預期轉換後的資料 wantErr bool errContains string }{ { name: "valid token with correct data map", tokenGen: func(t *testing.T) string { // 建立 claims,其中 "data" 欄位為 map[string]any claims := jwt.MapClaims{ "data": map[string]any{ "key1": "value1", "key2": "value2", }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenStr, err := token.SignedString([]byte("secret")) if err != nil { t.Fatalf("failed to sign token: %v", err) } return tokenStr }, secret: "secret", validate: true, want: map[string]string{ "key1": "value1", "key2": "value2", }, wantErr: false, }, { name: "token missing data field", tokenGen: func(t *testing.T) string { // claims 中不包含 "data" claims := jwt.MapClaims{ "other": "something", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenStr, err := token.SignedString([]byte("secret")) if err != nil { t.Fatalf("failed to sign token: %v", err) } return tokenStr }, secret: "secret", validate: true, want: map[string]string{}, wantErr: true, errContains: "get data from claim map error", }, { name: "malformed token", tokenGen: func(t *testing.T) string { return "not-a-token" }, secret: "secret", validate: true, want: map[string]string{}, wantErr: true, errContains: "token contains an invalid number of segments", }, { name: "data field not a map", tokenGen: func(t *testing.T) string { // 將 "data" 設為一個字串,而非 map claims := jwt.MapClaims{ "data": "not-a-map", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenStr, err := token.SignedString([]byte("secret")) if err != nil { t.Fatalf("failed to sign token: %v", err) } return tokenStr }, secret: "secret", validate: true, want: map[string]string{}, wantErr: true, errContains: "get data from claim map error", }, } for _, tt := range tests { tt := tt // 捕捉區域變數 t.Run(tt.name, func(t *testing.T) { accessToken := tt.tokenGen(t) result, err := uc.ParseSystemClaimsByAccessToken(accessToken, tt.secret, tt.validate) if tt.wantErr { assert.Error(t, err) if err != nil && tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } return } assert.NoError(t, err) assert.Equal(t, tt.want, result) }) } } func TestTokenUseCase_newToken(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() mockAutoIDModel := mock.NewMockTokenRepo(mockCtrl) uc := TokenUseCase{ TokenUseCaseParam: TokenUseCaseParam{ TokenRepo: mockAutoIDModel, RefreshExpires: 2 * time.Minute, Expired: 2 * time.Minute, Secret: "gg88g88", }, } // 取得一個參考時間,用來檢查 default expiration 的結果 nowRef := time.Now().UTC() tests := []struct { name string req *usecase.GenerateTokenRequest // 模擬產生 AccessToken 與 RefreshToken 的函式 stubAccessToken func(token entity.Token, data map[string]interface{}, secret string) (string, error) stubRefreshToken func(accessToken string) string wantErr bool // 當使用者提供明確的 expires 與 refreshExpires 時,期望的值(否則使用預設) expectExpiresProvided bool expectedExpires int64 expectRefreshExpiresProvided bool expectedRefreshExpires int64 }{ { name: "default expiration used when req.Expires/RefreshExpires are zero", req: &usecase.GenerateTokenRequest{ DeviceID: "device1", UID: "user1", Expires: 0, RefreshExpires: 0, Data: map[string]string{"foo": "bar"}, Role: "admin", Scope: "read", Account: "account1", TokenType: "access", }, wantErr: false, expectExpiresProvided: false, expectRefreshExpiresProvided: false, }, { name: "explicit expiration provided", req: func() *usecase.GenerateTokenRequest { // 提供明確的 expires 與 refreshExpires exp := nowRef.Add(5 * time.Minute).UnixNano() refExp := nowRef.Add(10 * time.Minute).UnixNano() return &usecase.GenerateTokenRequest{ DeviceID: "device2", UID: "user2", Expires: exp, RefreshExpires: refExp, Data: map[string]string{}, Role: "user", Scope: "write", Account: "account2", TokenType: "access", } }(), stubAccessToken: func(token entity.Token, data map[string]interface{}, secret string) (string, error) { return "access-token", nil }, stubRefreshToken: func(accessToken string) string { return "refresh-token" }, wantErr: false, expectExpiresProvided: true, // 預期值就與 req.Expires 相同 expectedExpires: func() int64 { return nowRef.Add(5 * time.Minute).UnixNano() }(), expectRefreshExpiresProvided: true, expectedRefreshExpires: func() int64 { return nowRef.Add(10 * time.Minute).UnixNano() }(), }, } for _, tt := range tests { tt := tt // 捕捉範圍變數 t.Run(tt.name, func(t *testing.T) { // 呼叫 newToken 方法 token, err := uc.newToken(context.Background(), tt.req) if tt.wantErr { assert.Error(t, err) return } assert.NoError(t, err) // 檢查基本欄位 assert.NotEmpty(t, token.ID, "token.ID should not be empty") assert.Equal(t, tt.req.DeviceID, token.DeviceID) assert.Equal(t, tt.req.UID, token.UID) // 驗證建立時間欄位有被設置 assert.NotZero(t, token.AccessCreateAt) assert.NotZero(t, token.RefreshCreateAt) }) } } func TestTokenUseCase_GenerateAccessToken(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() mockNewMockTokenRepo := mock.NewMockTokenRepo(mockCtrl) uc := TokenUseCase{ TokenUseCaseParam: TokenUseCaseParam{ TokenRepo: mockNewMockTokenRepo, RefreshExpires: 2 * time.Minute, Expired: 2 * time.Minute, Secret: "gg88g88", }, } // 定義 table-driven 測試案例 tests := []struct { name string repoErr error req usecase.GenerateTokenRequest wantErr bool errContains string setup func() // 若成功,預期回傳的 access token 與 refresh token expectedAccessToken string expectedRefreshToken string }{ { name: "newToken error from CreateAccessToken", repoErr: nil, setup: func() { mockNewMockTokenRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(fmt.Errorf("token create error: failed to create token")) }, req: usecase.GenerateTokenRequest{ DeviceID: "device1", UID: "user1", Expires: 0, // 使用預設過期時間 RefreshExpires: 0, Data: map[string]string{"foo": "bar"}, Role: "admin", Scope: "read", Account: "account1", TokenType: "access", }, wantErr: true, errContains: "token create error: failed to create token", }, { name: "successful generation", repoErr: nil, setup: func() { mockNewMockTokenRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil) }, req: usecase.GenerateTokenRequest{ DeviceID: "device3", UID: "user3", Expires: 0, RefreshExpires: 0, Data: map[string]string{"foo": "bar"}, Role: "member", Scope: "read", Account: "account3", TokenType: "access", }, wantErr: false, }, } // 針對每個測試案例執行測試 for _, tt := range tests { tt := tt // 捕捉區域變數 t.Run(tt.name, func(t *testing.T) { ctx := context.Background() tt.setup() resp, err := uc.GenerateAccessToken(ctx, tt.req) if tt.wantErr { assert.Error(t, err) if err != nil && tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } return } assert.NoError(t, err) // 驗證 ExpiresIn 非零(newToken 會根據當前時間與設定產生過期時間) assert.NotZero(t, resp.ExpiresIn) }) } }