diff --git a/go.work b/go.work index d9fccaf..90a2422 100644 --- a/go.work +++ b/go.work @@ -4,4 +4,5 @@ use ( ./errors ./validator ./worker_pool + ./jwt ) diff --git a/jwt/claims.go b/jwt/claims.go new file mode 100644 index 0000000..d4a7a62 --- /dev/null +++ b/jwt/claims.go @@ -0,0 +1,68 @@ +package jwt + +type DataClaims map[string]string + +const ( + idCode = "id" + roleCode = "role" + deviceIDCode = "device_id" + scopeCode = "scope" + uidCode = "uid" +) + +// ============ 使用具體的 setter ============ + +// Set 通用的 setter 方法 +func (c DataClaims) Set(key, value string) { + c[key] = value +} + +func (c DataClaims) SetID(id string) { + c.Set(idCode, id) +} + +func (c DataClaims) SetRole(role string) { + c.Set(roleCode, role) +} + +func (c DataClaims) SetDeviceID(deviceID string) { + c.Set(deviceIDCode, deviceID) +} + +func (c DataClaims) SetScope(scope string) { + c.Set(scopeCode, scope) +} + +func (c DataClaims) SetUID(uid string) { + c.Set(uidCode, uid) +} + +// ============ 使用具體的 getter ============ + +func (c DataClaims) Get(key string) string { + if val, ok := c[key]; ok { + return val + } + + return "" +} + +func (c DataClaims) Scope() { + c.Get(scopeCode) +} + +func (c DataClaims) Role() string { + return c.Get(roleCode) +} + +func (c DataClaims) ID() string { + return c.Get(idCode) +} + +func (c DataClaims) DeviceID() string { + return c.Get(deviceIDCode) +} + +func (c DataClaims) UID() string { + return c.Get(uidCode) +} diff --git a/jwt/claims_test.go b/jwt/claims_test.go new file mode 100644 index 0000000..260b8c6 --- /dev/null +++ b/jwt/claims_test.go @@ -0,0 +1,72 @@ +package jwt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDataClaimsSettersAndGetters(t *testing.T) { + tests := []struct { + name string + setterFunc func(c DataClaims, value string) + getterFunc func(c DataClaims) string + value string + expectedVal string + }{ + { + name: "Set and Get ID", + setterFunc: func(c DataClaims, value string) { c.SetID(value) }, + getterFunc: func(c DataClaims) string { return c.ID() }, + value: "12345", + expectedVal: "12345", + }, + { + name: "Set and Get Role", + setterFunc: func(c DataClaims, value string) { c.SetRole(value) }, + getterFunc: func(c DataClaims) string { return c.Role() }, + value: "admin", + expectedVal: "admin", + }, + { + name: "Set and Get Device ID", + setterFunc: func(c DataClaims, value string) { c.SetDeviceID(value) }, + getterFunc: func(c DataClaims) string { return c.DeviceID() }, + value: "device123", + expectedVal: "device123", + }, + { + name: "Set and Get Scope", + setterFunc: func(c DataClaims, value string) { c.SetScope(value) }, + getterFunc: func(c DataClaims) string { return c.Get(scopeCode) }, + value: "read", + expectedVal: "read", + }, + { + name: "Set and Get UID", + setterFunc: func(c DataClaims, value string) { c.SetUID(value) }, + getterFunc: func(c DataClaims) string { return c.UID() }, + value: "user123", + expectedVal: "user123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := DataClaims{} + + // Call the setter function + tt.setterFunc(claims, tt.value) + + // Call the getter function and verify the result + require.Equal(t, tt.expectedVal, tt.getterFunc(claims), "Expected value does not match") + }) + } +} + +func TestDataClaimsGetNonExistentKey(t *testing.T) { + claims := DataClaims{} + + // 對於不存在的鍵,應返回空字串 + require.Equal(t, "", claims.Get("nonexistent_key"), "Should return empty string for non-existent key") +} diff --git a/jwt/define.go b/jwt/define.go new file mode 100644 index 0000000..f8cf828 --- /dev/null +++ b/jwt/define.go @@ -0,0 +1,52 @@ +package jwt + +import ( + "time" + + "github.com/golang-jwt/jwt/v4" +) + +type Token struct { + ID string `json:"id"` + UID string `json:"uid"` + DeviceID string `json:"device_id"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + AccessCreateAt time.Time `json:"access_create_at"` + RefreshToken string `json:"refresh_token"` + RefreshExpiresIn int `json:"refresh_expires_in"` + RefreshCreateAt time.Time `json:"refresh_create_at"` +} + +func (t *Token) AccessTokenExpires() time.Duration { + return time.Duration(t.ExpiresIn) * time.Second +} + +func (t *Token) RefreshTokenExpires() time.Duration { + return time.Duration(t.RefreshExpiresIn) * time.Second +} + +func (t *Token) RefreshTokenExpiresUnix() int64 { + return time.Now().Add(t.RefreshTokenExpires()).Unix() +} + +func (t *Token) IsExpires() bool { + return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now()) +} + +func (t *Token) RedisExpiredSec() int64 { + sec := time.Unix(int64(t.ExpiresIn), 0).Sub(time.Now().UTC()) + + return int64(sec.Seconds()) +} + +func (t *Token) RedisRefreshExpiredSec() int64 { + sec := time.Unix(int64(t.RefreshExpiresIn), 0).Sub(time.Now().UTC()) + + return int64(sec.Seconds()) +} + +type Claims struct { + jwt.RegisteredClaims + Data interface{} `json:"data"` +} diff --git a/jwt/define_test.go b/jwt/define_test.go new file mode 100644 index 0000000..2496e64 --- /dev/null +++ b/jwt/define_test.go @@ -0,0 +1,82 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAccessTokenExpires(t *testing.T) { + token := &Token{ + ExpiresIn: 3600, // 1小時 + } + + expectedDuration := time.Hour + actualDuration := token.AccessTokenExpires() + + require.Equal(t, expectedDuration, actualDuration, "Access token expiration duration should be 1 hour") +} + +func TestRefreshTokenExpires(t *testing.T) { + token := &Token{ + RefreshExpiresIn: 7200, // 2小時 + } + + expectedDuration := 2 * time.Hour + actualDuration := token.RefreshTokenExpires() + + require.Equal(t, expectedDuration, actualDuration, "Refresh token expiration duration should be 2 hours") +} + +func TestRefreshTokenExpiresUnix(t *testing.T) { + token := &Token{ + RefreshExpiresIn: 3600, // 1小時 + } + + expectedUnix := time.Now().Add(1 * time.Hour).Unix() + actualUnix := token.RefreshTokenExpiresUnix() + + // 設定允許範圍,確保結果在1秒的範圍內 + require.InEpsilon(t, expectedUnix, actualUnix, 1, "Refresh token expires Unix time should match the expected time") +} + +func TestIsExpires(t *testing.T) { + // 測試過期情況 + tokenExpired := &Token{ + ExpiresIn: 3600, // 1小時 + AccessCreateAt: time.Now().Add(-2 * time.Hour), // 2小時前生成的 token,應該過期 + } + require.True(t, tokenExpired.IsExpires(), "Token should be expired") + + // 測試未過期情況 + tokenNotExpired := &Token{ + ExpiresIn: 3600, // 1小時 + AccessCreateAt: time.Now().Add(-30 * time.Minute), // 30分鐘前生成的 token,應該未過期 + } + require.False(t, tokenNotExpired.IsExpires(), "Token should not be expired") +} + +func TestRedisExpiredSec(t *testing.T) { + token := &Token{ + ExpiresIn: int(time.Now().Add(1 * time.Hour).Unix()), // 1小時後過期 + } + + expectedSec := int64(3600) // 1小時 + actualSec := token.RedisExpiredSec() + + // 確保時間在合理範圍內 + require.InDelta(t, expectedSec, actualSec, 1, "Redis expired seconds should be close to 3600 seconds") +} + +func TestRedisRefreshExpiredSec(t *testing.T) { + token := &Token{ + RefreshExpiresIn: int(time.Now().Add(2 * time.Hour).Unix()), // 2小時後過期 + } + + expectedSec := int64(7200) // 2小時 + actualSec := token.RedisRefreshExpiredSec() + + // 確保時間在合理範圍內 + require.InDelta(t, expectedSec, actualSec, 1, "Redis refresh expired seconds should be close to 7200 seconds") +} diff --git a/jwt/go.mod b/jwt/go.mod new file mode 100644 index 0000000..c2a705a --- /dev/null +++ b/jwt/go.mod @@ -0,0 +1,14 @@ +module code.30cm.net/digimon/library-go/jwt + +go 1.22.3 + +require ( + github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/jwt/token.go b/jwt/token.go new file mode 100644 index 0000000..a58e6fe --- /dev/null +++ b/jwt/token.go @@ -0,0 +1,91 @@ +package jwt + +import ( + "fmt" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +func GenerateAccessToken(token Token, data any, sign string, issuer string) (string, error) { + claim := Claims{ + Data: data, + RegisteredClaims: jwt.RegisteredClaims{ + ID: token.ID, + ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)), + Issuer: issuer, + }, + } + + accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim). + SignedString([]byte(sign)) + if err != nil { + return "", err + } + + return accessToken, nil +} + +func ParseToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) { + // 跳過驗證的解析 + var token *jwt.Token + var err error + + if validate { + token, err = jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("token unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(secret), nil + }) + if err != nil { + return jwt.MapClaims{}, err + } + } else { + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + token, err = parser.Parse(accessToken, func(_ *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }) + if err != nil { + return jwt.MapClaims{}, err + } + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok && token.Valid { + return jwt.MapClaims{}, fmt.Errorf("token valid error") + } + + return claims, nil +} + +func ParseClaims(accessToken string, secret string, validate bool) (DataClaims, error) { + claimMap, err := ParseToken(accessToken, secret, validate) + if err != nil { + return DataClaims{}, err + } + + claimsData, ok := claimMap["data"].(map[string]any) + if ok { + return convertMap(claimsData), nil + } + + return DataClaims{}, fmt.Errorf("get data from claim map error") +} + +func convertMap(input map[string]interface{}) map[string]string { + output := make(map[string]string) + for key, value := range input { + switch v := value.(type) { + case string: + output[key] = v + case fmt.Stringer: + output[key] = v.String() + default: + output[key] = fmt.Sprintf("%v", value) + } + } + + return output +} diff --git a/jwt/token_test.go b/jwt/token_test.go new file mode 100644 index 0000000..aeea248 --- /dev/null +++ b/jwt/token_test.go @@ -0,0 +1,87 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestGenerateAccessToken(t *testing.T) { + // 定義測試參數 + token := Token{ + ID: "12345", + ExpiresIn: int(time.Now().Add(1 * time.Hour).Unix()), + } + sign := "secret_sign" + data := map[string]string{ + "role": "admin", + "uid": "user123", + } + issuer := "test_issuer" + + // 調用生成 access token + accessToken, err := GenerateAccessToken(token, data, sign, issuer) + require.NoError(t, err) + require.NotEmpty(t, accessToken) + + // 檢查 access token 是否可以解析 + claims, err := ParseToken(accessToken, sign, true) + require.NoError(t, err) + + // 驗證 Claims 是否正確 + require.Equal(t, token.ID, claims["jti"]) + require.Equal(t, issuer, claims["iss"]) + require.Equal(t, "admin", claims["data"].(map[string]interface{})["role"]) + require.Equal(t, "user123", claims["data"].(map[string]interface{})["uid"]) +} + +func TestParseToken(t *testing.T) { + // 測試生成並解析 token + token := Token{ + ID: "67890", + ExpiresIn: int(time.Now().Add(2 * time.Hour).Unix()), + } + sign := "another_secret_sign" + data := map[string]string{ + "role": "user", + "uid": "user456", + } + + accessToken, err := GenerateAccessToken(token, data, sign, "example_issuer") + require.NoError(t, err) + require.NotEmpty(t, accessToken) + + // 測試有驗證的解析 + claims, err := ParseToken(accessToken, sign, true) + require.NoError(t, err) + require.Equal(t, "user", claims["data"].(map[string]interface{})["role"]) + require.Equal(t, "user456", claims["data"].(map[string]interface{})["uid"]) + + // 測試不驗證的解析 + claimsNoValidation, err := ParseToken(accessToken, sign, false) + require.NoError(t, err) + require.Equal(t, "user", claimsNoValidation["data"].(map[string]interface{})["role"]) +} + +func TestParseClaims(t *testing.T) { + // 測試生成並解析 claims + token := Token{ + ID: "54321", + ExpiresIn: int(time.Now().Add(3 * time.Hour).Unix()), + } + sign := "test_sign" + data := map[string]string{ + "role": "moderator", + "uid": "user789", + } + + accessToken, err := GenerateAccessToken(token, data, sign, "sample_issuer") + require.NoError(t, err) + + // 測試 claims 解析 + parsedClaims, err := ParseClaims(accessToken, sign, true) + require.NoError(t, err) + require.Equal(t, "moderator", parsedClaims["role"]) + require.Equal(t, "user789", parsedClaims["uid"]) +}