package logic

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"github.com/stretchr/testify/assert"
	"go.uber.org/mock/gomock"

	"ark-permission/gen_result/pb/permission"
	"ark-permission/internal/domain"
	"ark-permission/internal/entity"
	libMock "ark-permission/internal/mock/lib"
	repoMock "ark-permission/internal/mock/repository"
	"ark-permission/internal/svc"
)

// TestNewTokenLogic_NewToken exercises NewTokenLogic.NewToken against mocked
// validation and token-repository dependencies, with the package-level token
// generators stubbed out so the response is deterministic.
func TestNewTokenLogic_NewToken(t *testing.T) {
	// mocks
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	tokenMockRepo := repoMock.NewMockTokenRepository(ctrl)
	mockValidate := libMock.NewMockValidate(ctrl)

	sc := svc.ServiceContext{
		TokenRedisRepo: tokenMockRepo,
		Validate:       mockValidate,
	}
	l := NewNewTokenLogic(context.Background(), &sc)

	// The cases below replace package-level generator hooks. Restore the
	// originals once this test (and all its subtests) finish so that other
	// tests in the package do not silently run against the stubs.
	origGenerateAccessToken := generateAccessTokenFunc
	origGenerateRefreshToken := generateRefreshTokenFunc
	t.Cleanup(func() {
		generateAccessTokenFunc = origGenerateAccessToken
		generateRefreshTokenFunc = origGenerateRefreshToken
	})

	tests := []struct {
		name        string
		input       *permission.AuthorizationReq
		setupMocks  func()
		expectError bool
		expected    *permission.TokenResp
	}{
		{
			name: "Valid token request",
			input: &permission.AuthorizationReq{
				GrantType:      "authorization_code",
				DeviceId:       "device123",
				Scope:          "read",
				Expires:        3600,
				IsRefreshToken: false,
				Data: map[string]string{
					"uid": "user123",
				},
			},
			setupMocks: func() {
				mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(nil)
				tokenMockRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil)
				generateAccessTokenFunc = func(token entity.Token, data any, sign string) (string, error) {
					return "access_token", nil
				}
				generateRefreshTokenFunc = func(accessToken string) string {
					return "refresh_token"
				}
			},
			expectError: false,
			// RefreshToken is expected to be empty because IsRefreshToken is false,
			// even though the refresh-token generator is stubbed.
			expected: &permission.TokenResp{
				AccessToken:  "access_token",
				TokenType:    domain.TokenTypeBearer,
				ExpiresIn:    3600,
				RefreshToken: "",
			},
		},
		{
			name: "Validation error",
			input: &permission.AuthorizationReq{
				GrantType:      "invalid_grant",
				DeviceId:       "device123",
				Scope:          "read",
				Expires:        3600,
				IsRefreshToken: false,
				Data: map[string]string{
					"uid": "user123",
				},
			},
			setupMocks: func() {
				// No Create expectation: gomock fails the test if the repository
				// is touched after validation rejects the request.
				mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(errors.New("invalid grant type"))
			},
			expectError: true,
			expected:    nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.setupMocks()

			resp, err := l.NewToken(tt.input)
			if tt.expectError {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tt.expected, resp)
		})
	}
}

// TestGenerateAccessToken checks that generateAccessToken produces a JWT that
// can be parsed back with the same signing secret and carries the expected
// ID, issuer and custom data.
func TestGenerateAccessToken(t *testing.T) {
	// test cases
	tests := []struct {
		name         string
		token        entity.Token
		data         any
		sign         string
		shouldFail   bool
		shouldVerify bool
	}{
		{
			name: "Valid token with admin role",
			token: entity.Token{
				ID:        "123",
				ExpiresIn: int(time.Now().Add(time.Hour * 24).Unix()),
			},
			data:         map[string]string{"role": "admin"},
			sign:         "secret",
			shouldFail:   false,
			shouldVerify: true,
		},
		{
			name: "Expired token",
			token: entity.Token{
				ID:        "456",
				ExpiresIn: int(time.Now().Add(-time.Hour * 24).Unix()), // already in the past
			},
			data: map[string]string{"role": "user"},
			sign: "secret",
			// Generation is not expected to fail here: expiry is normally
			// enforced when the token is verified, not when it is created.
			shouldFail: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tokenString, err := generateAccessToken(tt.token, tt.data, tt.sign)
			if (err != nil) != tt.shouldFail {
				t.Errorf("generateAccessToken() error = %v, shouldFail %v", err, tt.shouldFail)
				return
			}
			if !tt.shouldVerify {
				return
			}

			// Parse the generated token back with the same secret.
			parsedToken, err := jwt.ParseWithClaims(tokenString, &entity.Claims{}, func(token *jwt.Token) (any, error) {
				return []byte(tt.sign), nil
			})
			if err != nil {
				t.Errorf("Error parsing token: %v", err)
				return
			}

			claims, ok := parsedToken.Claims.(*entity.Claims)
			if !ok || !parsedToken.Valid {
				t.Errorf("Invalid token claims")
				return
			}

			if claims.ID != tt.token.ID {
				t.Errorf("Expected ID %v, got %v", tt.token.ID, claims.ID)
			}
			if claims.Issuer != "permission" {
				t.Errorf("Expected Issuer 'permission', got %v", claims.Issuer)
			}

			// After a JSON round-trip the custom data decodes as map[string]any,
			// not map[string]string; check that with comma-ok instead of panicking.
			gotData, ok := claims.Data.(map[string]any)
			if !ok {
				t.Errorf("Expected claims.Data to decode as map[string]any, got %T", claims.Data)
				return
			}
			for k, v := range tt.data.(map[string]string) {
				if gotData[k] != v {
					t.Errorf("Expected data[%q] = %v, got %v", k, v, gotData[k])
				}
			}
		})
	}
}

// TestGenerateRefreshToken pins generateRefreshToken's output for known
// access-token inputs, including the empty string.
func TestGenerateRefreshToken(t *testing.T) {
	// test cases
	tests := []struct {
		name        string
		accessToken string
		expected    string
	}{
		{
			name:        "simple access token",
			accessToken: "test_access_token",
			expected:    "4993552f2cc6c4e57fa5738f9b161a1a4051c8370cddb32514c8f6f4c797801f",
		},
		{
			name:        "another access token",
			accessToken: "another_test_access_token",
			expected:    "8361833e9a11f829f2be9a00f1939b5a72408ff829451169f3b223c41768cfa2",
		},
		{
			name:        "empty access token",
			accessToken: "",
			expected:    "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := generateRefreshToken(tt.accessToken)
			if got != tt.expected {
				t.Errorf("generateRefreshToken(%q) = %q; want %q", tt.accessToken, got, tt.expected)
			}
		})
	}
}