guard/internal/logic/new_token_logic_test.go

200 lines
5.3 KiB
Go

package logic
import (
"ark-permission/internal/entity"
"github.com/golang-jwt/jwt/v4"
"testing"
"time"
)
// func TestNewTokenLogic_NewToken(t *testing.T) {
// // mock
// ctrl := gomock.NewController(t)
// defer ctrl.Finish()
//
// tokenMockRepo := repoMock.NewMockTokenRepository(ctrl)
// mockValidate := libMock.NewMockValidate(ctrl)
//
// sc := svc.ServiceContext{
// TokenRedisRepo: tokenMockRepo,
// Validate: mockValidate,
// }
//
// l := NewNewTokenLogic(context.Background(), &sc)
//
// tests := []struct {
// name string
// input *permission.AuthorizationReq
// setupMocks func()
// expectError bool
// expected *permission.TokenResp
// }{
// {
// name: "Valid token request",
// input: &permission.AuthorizationReq{
// GrantType: "authorization_code",
// DeviceId: "device123",
// Scope: "read",
// Expires: 3600,
// IsRefreshToken: false,
// Data: map[string]string{
// "uid": "user123",
// },
// },
// setupMocks: func() {
// mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(nil)
// tokenMockRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, token entity.Token) {
// token.AccessToken = "access_token"
// })
// generateAccessTokenFunc = func(token entity.Token, data any, sign string) (string, error) {
// return "access_token", nil
// }
// generateRefreshTokenFunc = func(accessToken string) string {
// return "refresh_token"
// }
// },
// expectError: false,
// expected: &permission.TokenResp{
// AccessToken: "access_token",
// TokenType: domain.TokenTypeBearer,
// ExpiresIn: 3600,
// RefreshToken: "",
// },
// },
// {
// name: "Validation error",
// input: &permission.AuthorizationReq{
// GrantType: "invalid_grant",
// DeviceId: "device123",
// Scope: "read",
// Expires: 3600,
// IsRefreshToken: false,
// Data: map[string]string{
// "uid": "user123",
// },
// },
// setupMocks: func() {
// mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(errors.New("invalid grant type"))
// },
// expectError: true,
// expected: nil,
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// tt.setupMocks()
//
// resp, err := l.NewToken(tt.input)
// if tt.expectError {
// assert.Error(t, err)
// } else {
// assert.NoError(t, err)
// assert.Equal(t, tt.expected, resp)
// }
// })
// }
// }
// TestGenerateAccessToken tests generateAccessToken. Every case checks that
// generation succeeds or fails as expected. Cases with shouldVerify set also
// parse the token back with the same signing key and check the ID, issuer and
// data claims.
func TestGenerateAccessToken(t *testing.T) {
	// Table of test cases.
	tests := []struct {
		name         string
		token        entity.Token
		data         any
		sign         string
		shouldFail   bool
		shouldVerify bool
	}{
		{
			name: "Valid token with admin role",
			token: entity.Token{
				ID:        "123",
				ExpiresIn: int(time.Now().Add(time.Hour * 24).Unix()),
			},
			data:         map[string]string{"role": "admin"},
			sign:         "secret",
			shouldFail:   false,
			shouldVerify: true,
		},
		{
			name: "Expired token",
			token: entity.Token{
				ID:        "456",
				ExpiresIn: int(time.Now().Add(-time.Hour * 24).Unix()), // expiry time already in the past
			},
			data:       map[string]string{"role": "user"},
			sign:       "secret",
			shouldFail: false, // generation should not fail; expiry is typically checked at verification time
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tokenString, err := generateAccessToken(tt.token, tt.data, tt.sign)
			if (err != nil) != tt.shouldFail {
				t.Errorf("generateAccessToken() error = %v, shouldFail %v", err, tt.shouldFail)
				return
			}
			if !tt.shouldVerify {
				return
			}

			// Verify the generated token by parsing it with the same signing key.
			parsedToken, err := jwt.ParseWithClaims(tokenString, &entity.Claims{}, func(token *jwt.Token) (any, error) {
				return []byte(tt.sign), nil
			})
			if err != nil {
				t.Errorf("Error parsing token: %v", err)
				return
			}
			claims, ok := parsedToken.Claims.(*entity.Claims)
			if !ok || !parsedToken.Valid {
				t.Errorf("Invalid token claims")
				return
			}
			if claims.ID != tt.token.ID {
				t.Errorf("Expected ID %v, got %v", tt.token.ID, claims.ID)
			}
			if claims.Issuer != "permission" {
				t.Errorf("Expected Issuer 'permission', got %v", claims.Issuer)
			}

			wantData, ok := tt.data.(map[string]string)
			if !ok {
				t.Fatalf("test case data must be map[string]string, got %T", tt.data)
			}
			// The data claim comes back from parsing as map[string]any, not the
			// map[string]string that was passed in. Use comma-ok assertions so a
			// type mismatch is reported as a failure instead of panicking.
			gotData, ok := claims.Data.(map[string]any)
			if !ok {
				t.Errorf("Expected data claim of type map[string]any, got %T", claims.Data)
				return
			}
			for k, v := range wantData {
				if gotData[k] != v {
					t.Errorf("Expected data %v, got %v", v, gotData[k])
				}
			}
		})
	}
}
// TestGenerateRefreshToken tests generateRefreshToken against fixed expected
// outputs. The expected values appear to be hex-encoded SHA-256 digests of the
// access token: the empty-input value is the well-known SHA-256 of the empty
// string.
func TestGenerateRefreshToken(t *testing.T) {
	// Table of test cases.
	tests := []struct {
		accessToken string
		expected    string
	}{
		{
			accessToken: "test_access_token",
			expected:    "4993552f2cc6c4e57fa5738f9b161a1a4051c8370cddb32514c8f6f4c797801f",
		},
		{
			accessToken: "another_test_access_token",
			expected:    "8361833e9a11f829f2be9a00f1939b5a72408ff829451169f3b223c41768cfa2",
		},
		{
			accessToken: "",
			expected:    "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
		},
	}
	for _, tt := range tests {
		t.Run(tt.accessToken, func(t *testing.T) {
			got := generateRefreshToken(tt.accessToken)
			if got != tt.expected {
				// %q makes the empty-input case visible in the failure output.
				t.Errorf("generateRefreshToken(%q) = %s; want %s", tt.accessToken, got, tt.expected)
			}
		})
	}
}