package usecase

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"time"

	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain"
	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/repository"
	dt "code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/token"
	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/usecase"
	ers "code.30cm.net/digimon/library-go/errs"
	"github.com/golang-jwt/jwt/v4"
	"github.com/segmentio/ksuid"
	"github.com/zeromicro/go-zero/core/logx"
)

// TokenUseCaseParam holds the dependencies and settings used to build a TokenUseCase.
type TokenUseCaseParam struct {
	TokenRepo      repository.TokenRepo
	RefreshExpires time.Duration // default refresh-token lifetime, used when a request supplies none
	Expired        time.Duration // default access-token lifetime, used when a request supplies none
	Secret         string        // HMAC key used to sign and verify access tokens
}

// TokenUseCase implements usecase.TokenUseCase on top of a repository.TokenRepo.
// Token mirrors the expiry and secret settings copied from TokenUseCaseParam.
type TokenUseCase struct {
	TokenUseCaseParam
	Token struct {
		RefreshExpires time.Duration
		Expired        time.Duration
		Secret         string
	}
}

// NewTokenUseCase returns a TokenUseCase configured from param.
func NewTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
	return &TokenUseCase{
		TokenUseCaseParam: param,
		Token: struct {
			RefreshExpires time.Duration
			Expired        time.Duration
			Secret         string
		}{
			RefreshExpires: param.RefreshExpires,
			Expired:        param.Expired,
			Secret:         param.Secret,
		},
	}
}

// GenerateAccessToken builds a new access/refresh token pair for req, persists it
// through TokenRepo.Create and returns the token strings plus the access-token expiry.
func (use *TokenUseCase) GenerateAccessToken(ctx context.Context, req usecase.GenerateTokenRequest) (usecase.AccessTokenResponse, error) {
	token, err := use.newToken(ctx, &req)
	if err != nil {
		return usecase.AccessTokenResponse{}, err
	}

	if err := use.TokenRepo.Create(ctx, *token); err != nil {
		return usecase.AccessTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.Create",
			req:       req,
			err:       err,
			message:   "failed to create token",
			errorCode: domain.TokenCreateErrorCode,
		})
	}

	return usecase.AccessTokenResponse{
		AccessToken:  token.AccessToken,
		ExpiresIn:    token.ExpiresIn,
		RefreshToken: token.RefreshToken,
	}, nil
}

// RefreshAccessToken exchanges the one-time (refresh) token in req.Token for a new
// token pair. The custom claims of the stored access token are carried over, with
// scope and device replaced by the values in req. The new token is persisted and
// the old one is deleted.
func (use *TokenUseCase) RefreshAccessToken(ctx context.Context, req usecase.RefreshTokenRequest) (usecase.RefreshTokenResponse, error) {
	// Step 1: look up the stored token by its refresh (one-time) token.
	token, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
	if err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokenByOneTimeToken",
			req:       req,
			err:       err,
			message:   "failed to get access token",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	// Step 2: read the custom claims of the old access token. Registered-claim
	// validation (exp, nbf, ...) is skipped, so an expired access token can still
	// be refreshed; the signature is still verified.
	claimsData, err := use.ParseSystemClaimsByAccessToken(token.AccessToken, use.Token.Secret, false)
	if err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "extractClaims",
			req:       req,
			err:       err,
			message:   "failed to extract claims",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	data := NewAdditional(claimsData)
	data.Set(dt.Scope, req.Scope)
	data.Set(dt.Device, req.DeviceID)

	// Step 3: build and persist the new token. TokenType must be passed explicitly.
	// newToken writes dt.Type from req.TokenType after copying Data, so leaving it
	// out would blank the token type in the refreshed JWT (assuming Set overwrites).
	newToken, err := use.newToken(ctx, &usecase.GenerateTokenRequest{
		Scope:          req.Scope,
		DeviceID:       req.DeviceID,
		Expires:        req.Expires,
		RefreshExpires: req.RefreshExpires,
		Data:           data.GetAll(),
		Role:           data.Get(dt.Role),
		UID:            data.Get(dt.UID),
		Account:        data.Get(dt.Account),
		TokenType:      data.Get(dt.Type),
	})
	if err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "use.newToken",
			req:       req,
			err:       err,
			message:   "failed to create new token",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.Create",
			req:       req,
			err:       err,
			message:   "failed to create new token",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	// Step 4: delete the old token.
	// NOTE(review): if this fails, the new token has already been persisted but is
	// not returned to the caller.
	if err := use.TokenRepo.Delete(ctx, token); err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.Delete",
			req:       req,
			err:       err,
			message:   "failed to delete old token",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	return usecase.RefreshTokenResponse{
		AccessToken:  newToken.AccessToken,
		RefreshToken: newToken.RefreshToken,
		ExpiresIn:    newToken.ExpiresIn,
		TokenType:    data.Get(dt.Type),
	}, nil
}

// RevokeToken deletes the stored token whose ID is carried in the claims of the
// access token req.Token. The signature is verified but expiry is not, so an
// expired access token can still be revoked.
func (use *TokenUseCase) RevokeToken(ctx context.Context, req usecase.TokenRequest) error {
	claims, err := use.ParseSystemClaimsByAccessToken(req.Token, use.Token.Secret, false)
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "CancelToken extractClaims",
			req:       req,
			err:       err,
			message:   "failed to get token claims",
			errorCode: domain.TokenCancelErrorCode,
		})
	}

	data := NewAdditional(claims)
	token, err := use.TokenRepo.GetAccessTokenByID(ctx, data.Get(dt.ID))
	if err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo GetAccessTokenByID",
			req:       req,
			err:       err,
			message:   fmt.Sprintf("failed to get token claims :%s", data.Get(dt.ID)),
			errorCode: domain.TokenCancelErrorCode,
		})
	}

	if err := use.TokenRepo.Delete(ctx, token); err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo Delete",
			req:       req,
			err:       err,
			message:   fmt.Sprintf("failed to delete token :%s", token.ID),
			errorCode: domain.TokenCancelErrorCode,
		})
	}

	return nil
}

// VerifyToken fully validates req.Token (signature and registered claims such as
// exp), then confirms the token still exists in the repository. It returns the
// stored token together with the token's custom claims.
func (use *TokenUseCase) VerifyToken(ctx context.Context, req usecase.TokenRequest) (usecase.VerifyTokenResponse, error) {
	claims, err := use.ParseSystemClaimsByAccessToken(req.Token, use.Token.Secret, true)
	if err != nil {
		return usecase.VerifyTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "parseClaims",
			req:       req,
			err:       err,
			message:   "validate token claims error",
			errorCode: domain.TokenValidateErrorCode,
		})
	}

	data := NewAdditional(claims)
	token, err := use.TokenRepo.GetAccessTokenByID(ctx, data.Get(dt.ID))
	if err != nil {
		return usecase.VerifyTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokenByID",
			req:       req,
			err:       err,
			message:   fmt.Sprintf("failed to get token :%s", data.Get(dt.ID)),
			errorCode: domain.TokenValidateErrorCode,
		})
	}

	return usecase.VerifyTokenResponse{
		Token: token,
		Data:  data.GetAll(),
	}, nil
}

// RevokeTokensByUID deletes every token of req.UID (when non-empty) and then the
// tokens listed in req.IDs (when any). It stops at the first repository error.
func (use *TokenUseCase) RevokeTokensByUID(ctx context.Context, req usecase.RevokeTokensByUIDRequest) error {
	if req.UID != "" {
		if err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID); err != nil {
			return use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.DeleteAccessTokensByUID",
				req:       req,
				err:       err,
				message:   "failed to cancel tokens by uid",
				errorCode: domain.TokensCancelErrorCode,
			})
		}
	}

	if len(req.IDs) > 0 {
		if err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs); err != nil {
			return use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.DeleteAccessTokenByID",
				req:       req,
				err:       err,
				message:   "failed to cancel tokens by token ids",
				errorCode: domain.TokensCancelErrorCode,
			})
		}
	}

	return nil
}

// RevokeTokensByDeviceID deletes every token bound to deviceID.
func (use *TokenUseCase) RevokeTokensByDeviceID(ctx context.Context, deviceID string) error {
	if err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, deviceID); err != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.DeleteAccessTokensByDeviceID",
			req:       deviceID,
			err:       err,
			message:   "failed to cancel token by device id",
			errorCode: domain.TokensCancelErrorCode,
		})
	}

	return nil
}

// GetUserTokensByDeviceID lists the stored tokens bound to deviceID.
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, deviceID string) ([]*usecase.AccessTokenResponse, error) {
	tokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, deviceID)
	if err != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokensByDeviceID",
			req:       deviceID,
			err:       err,
			message:   "failed to get token by device id",
			errorCode: domain.TokenGetErrorCode,
		})
	}

	result := make([]*usecase.AccessTokenResponse, 0, len(tokens))
	for _, v := range tokens {
		result = append(result, &usecase.AccessTokenResponse{
			AccessToken:  v.AccessToken,
			ExpiresIn:    v.ExpiresIn,
			RefreshToken: v.RefreshToken,
		})
	}

	return result, nil
}

// GetUserTokensByUID lists the stored tokens belonging to uid.
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, uid string) ([]*usecase.AccessTokenResponse, error) {
	tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
	if err != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokensByUID",
			req:       uid,
			err:       err,
			message:   "failed to get token by uid",
			errorCode: domain.TokenGetErrorCode,
		})
	}

	result := make([]*usecase.AccessTokenResponse, 0, len(tokens))
	for _, v := range tokens {
		result = append(result, &usecase.AccessTokenResponse{
			AccessToken:  v.AccessToken,
			ExpiresIn:    v.ExpiresIn,
			RefreshToken: v.RefreshToken,
		})
	}

	return result, nil
}

// ReadTokenBasicData returns the custom claims of token. The signature is
// verified but registered claims (exp, ...) are not, and the repository is not
// consulted.
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (usecase.Additional, error) {
	claims, err := use.ParseSystemClaimsByAccessToken(token, use.Token.Secret, false)
	if err != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "parseClaims",
			req:       token,
			err:       err,
			message:   "validate token claims error",
			errorCode: domain.TokenValidateErrorCode,
		})
	}

	return NewAdditional(claims), nil
}

// ======== JWT Token ========

// CreateAccessToken signs token's ID and expiry, plus data as the custom "data"
// claim, with secretKey using HS256 and returns the compact JWT string.
// token.ExpiresIn is interpreted as Unix nanoseconds.
func (use *TokenUseCase) CreateAccessToken(token entity.Token, data any, secretKey string) (string, error) {
	claims := entity.Claims{
		Data: data,
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        token.ID,
			ExpiresAt: jwt.NewNumericDate(time.Unix(0, token.ExpiresIn)),
			Issuer:    dt.Issuer,
		},
	}

	accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secretKey))
	if err != nil {
		return "", err
	}

	return accessToken, nil
}

// CreateRefreshToken derives the refresh token as the hex-encoded SHA-256 digest
// of accessToken.
func (use *TokenUseCase) CreateRefreshToken(accessToken string) string {
	sum := sha256.Sum256([]byte(accessToken))
	return hex.EncodeToString(sum[:])
}

// ParseJWTClaimsByAccessToken parses accessToken and verifies its HMAC signature
// with secret. Tokens signed with a non-HMAC method are rejected. When validate is
// false, registered-claim validation (exp, nbf, iat) is skipped; the signature is
// still checked.
func (use *TokenUseCase) ParseJWTClaimsByAccessToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) {
	keyFunc := func(token *jwt.Token) (any, error) {
		// Guard against algorithm confusion: only accept HMAC-signed tokens.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("token unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(secret), nil
	}

	var opts []jwt.ParserOption
	if !validate {
		opts = append(opts, jwt.WithoutClaimsValidation())
	}

	token, err := jwt.NewParser(opts...).Parse(accessToken, keyFunc)
	if err != nil {
		return jwt.MapClaims{}, err
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok || !token.Valid {
		return jwt.MapClaims{}, fmt.Errorf("token valid error")
	}

	return claims, nil
}

// ParseSystemClaimsByAccessToken parses accessToken (see ParseJWTClaimsByAccessToken)
// and returns its "data" claim with every value converted to a string.
func (use *TokenUseCase) ParseSystemClaimsByAccessToken(accessToken string, secret string, validate bool) (map[string]string, error) {
	claimMap, err := use.ParseJWTClaimsByAccessToken(accessToken, secret, validate)
	if err != nil {
		return map[string]string{}, err
	}

	claimsData, ok := claimMap["data"].(map[string]any)
	if !ok {
		return map[string]string{}, fmt.Errorf("get data from claim map error")
	}

	return convertMap(claimsData), nil
}

// ======== Helpers ========

// newToken builds, but does not persist, a token entity for req. It:
//   - assigns a fresh KSUID;
//   - falls back to the configured default lifetimes when req supplies no
//     expiry (<= 0);
//   - signs the access token and derives the refresh token from it.
//
// Timestamps produced here are Unix nanoseconds. A caller-supplied
// req.Expires/RefreshExpires is used as-is (presumably also nanoseconds, since
// CreateAccessToken reads ExpiresIn that way; confirm with callers).
func (use *TokenUseCase) newToken(ctx context.Context, req *usecase.GenerateTokenRequest) (*entity.Token, error) {
	now := time.Now().UTC()

	expires := req.Expires
	if expires <= 0 {
		// Configured default lifetime, truncated to whole seconds.
		expires = now.Add(use.Token.Expired.Truncate(time.Second)).UnixNano()
	}

	refreshExpires := req.RefreshExpires
	if refreshExpires <= 0 {
		// Configured default refresh lifetime, truncated to whole seconds.
		refreshExpires = now.Add(use.Token.RefreshExpires.Truncate(time.Second)).UnixNano()
	}

	token := entity.Token{
		ID:               ksuid.New().String(),
		DeviceID:         req.DeviceID,
		ExpiresIn:        expires,
		RefreshExpiresIn: refreshExpires,
		AccessCreateAt:   now.UnixNano(),
		RefreshCreateAt:  now.UnixNano(),
		UID:              req.UID,
	}

	// The explicit request fields are written after req.Data is copied, so they
	// take precedence over same-named keys in req.Data (assuming Set overwrites).
	data := NewAdditional(req.Data)
	data.Set(dt.ID, token.ID)
	data.Set(dt.Role, req.Role)
	data.Set(dt.Scope, req.Scope)
	data.Set(dt.Account, req.Account)
	data.Set(dt.UID, req.UID)
	data.Set(dt.Type, req.TokenType)
	if req.DeviceID != "" {
		data.Set(dt.Device, req.DeviceID)
	}

	accessToken, err := use.CreateAccessToken(token, data.GetAll(), use.Token.Secret)
	if err != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "accessTokenGenerator",
			req:       req,
			err:       err,
			message:   "failed to generator access token",
			errorCode: domain.TokenClaimErrorCode,
		})
	}

	// Derive the refresh token only once the access token was signed successfully.
	token.AccessToken = accessToken
	token.RefreshToken = use.CreateRefreshToken(accessToken)

	return &token, nil
}

// convertMap flattens claim values into strings. Strings are kept as-is,
// fmt.Stringer values use String(), and everything else is formatted with %v.
func convertMap(input map[string]any) map[string]string {
	output := make(map[string]string, len(input))
	for key, value := range input {
		switch v := value.(type) {
		case string:
			output[key] = v
		case fmt.Stringer:
			output[key] = v.String()
		default:
			output[key] = fmt.Sprintf("%v", value)
		}
	}

	return output
}

// wrapTokenErrorReq carries the context logged and wrapped by wrapTokenError.
type wrapTokenErrorReq struct {
	funcName  string        // name of the failing call, logged under "func"
	req       any           // request value, logged under "req"
	err       error         // underlying error; must be non-nil
	message   string        // human-readable message passed to domain.TokenErrorL
	errorCode ers.ErrorCode // domain error code
}

// wrapTokenError logs param through domain.TokenErrorL (with the request, failing
// function name and error text as fields) and returns the result wrapping param.err.
// param.err must be non-nil.
func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error {
	logFields := []logx.LogField{
		{Key: "req", Value: param.req},
		{Key: "func", Value: param.funcName},
		{Key: "err", Value: param.err.Error()},
	}

	return domain.TokenErrorL(
		param.errorCode,
		logx.WithContext(ctx),
		logFields,
		param.message,
	).Wrap(param.err)
}