app-cloudep-permission-server/pkg/usecase/token.go

528 lines
15 KiB
Go
Raw Permalink Normal View History

2025-02-13 11:06:51 +00:00
package usecase
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"time"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/repository"
dt "code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/token"
"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/usecase"
ers "code.30cm.net/digimon/library-go/errs"
"github.com/golang-jwt/jwt/v4"
"github.com/segmentio/ksuid"
"github.com/zeromicro/go-zero/core/logx"
)
// TokenUseCaseParam holds the dependencies and configuration used to build a
// TokenUseCase via NewTokenUseCase.
type TokenUseCaseParam struct {
	// TokenRepo stores, looks up and deletes issued tokens.
	TokenRepo repository.TokenRepo
	// RefreshExpires is the default refresh-token lifetime, applied when a
	// request does not supply a positive RefreshExpires.
	RefreshExpires time.Duration
	// Expired is the default access-token lifetime, applied when a request
	// does not supply a positive Expires.
	Expired time.Duration
	// Secret is the HMAC key used to sign (HS256) and verify JWT access tokens.
	Secret string
}
// TokenUseCase implements usecase.TokenUseCase on top of a repository.TokenRepo
// and HS256-signed JWT access tokens.
type TokenUseCase struct {
	TokenUseCaseParam
	// Token is a copy of the expiry and secret settings from
	// TokenUseCaseParam, filled in by NewTokenUseCase; the methods of this
	// use case read their configuration from here.
	Token struct {
		RefreshExpires time.Duration
		Expired        time.Duration
		Secret         string
	}
}
// NewTokenUseCase builds a TokenUseCase from param. Besides embedding param,
// it copies the expiry and secret settings into the Token field, which is
// where the use-case methods read them from.
func NewTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
	uc := &TokenUseCase{TokenUseCaseParam: param}
	uc.Token.RefreshExpires = param.RefreshExpires
	uc.Token.Expired = param.Expired
	uc.Token.Secret = param.Secret
	return uc
}
// GenerateAccessToken issues a new access/refresh token pair for req and
// persists it through TokenRepo.
//
// Errors from building the token are returned exactly as produced by
// newToken; a persistence failure is logged and reported with
// domain.TokenCreateErrorCode.
func (use *TokenUseCase) GenerateAccessToken(ctx context.Context, req usecase.GenerateTokenRequest) (usecase.AccessTokenResponse, error) {
	token, err := use.newToken(ctx, &req)
	if err != nil {
		return usecase.AccessTokenResponse{}, err
	}

	if err := use.TokenRepo.Create(ctx, *token); err != nil {
		return usecase.AccessTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.Create",
			req:       req,
			err:       err,
			message:   "failed to create token",
			errorCode: domain.TokenCreateErrorCode,
		})
	}

	return usecase.AccessTokenResponse{
		AccessToken:  token.AccessToken,
		ExpiresIn:    token.ExpiresIn,
		RefreshToken: token.RefreshToken,
	}, nil
}
// RefreshAccessToken exchanges the refresh token in req.Token for a new token
// pair and removes the old one.
//
// Steps:
//  1. Look up the stored token by the refresh (one-time) token.
//  2. Read the claims data from its access token. Claims validation is
//     skipped (validate=false), so an expired access token can still be
//     refreshed.
//  3. Issue and persist a new token that carries over the old claims data,
//     role, UID, account and token type, with scope and device taken from req.
//  4. Delete the old token.
//
// Every failure is reported with domain.TokenRefreshErrorCode. If step 4
// fails, the new token has already been persisted.
//
// NOTE(review): the stored token's RefreshExpiresIn is not checked here;
// confirm that the repository stops returning expired refresh tokens.
func (use *TokenUseCase) RefreshAccessToken(ctx context.Context, req usecase.RefreshTokenRequest) (usecase.RefreshTokenResponse, error) {
	// Step 1: look up the stored token by its refresh token.
	token, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
	if err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokenByOneTimeToken",
			req:       req,
			err:       err,
			message:   "failed to get access token",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	// Step 2: extract the claims data from the old access token.
	claimsData, err := use.ParseSystemClaimsByAccessToken(token.AccessToken, use.Token.Secret, false)
	if err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "extractClaims",
			req:       req,
			err:       err,
			message:   "failed to extract claims",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	data := NewAdditional(claimsData)
	data.Set(dt.Scope, req.Scope)
	data.Set(dt.Device, req.DeviceID)

	// newToken overwrites dt.Type with GenerateTokenRequest.TokenType, so the
	// old type must be forwarded explicitly; otherwise the refreshed access
	// token would embed an empty type. Capture it before newToken runs.
	tokenType := data.Get(dt.Type)

	// Step 3: create and persist the new token.
	newToken, err := use.newToken(ctx, &usecase.GenerateTokenRequest{
		Scope:          req.Scope,
		DeviceID:       req.DeviceID,
		Expires:        req.Expires,
		RefreshExpires: req.RefreshExpires,
		Data:           data.GetAll(),
		Role:           data.Get(dt.Role),
		UID:            data.Get(dt.UID),
		Account:        data.Get(dt.Account),
		TokenType:      tokenType,
	})
	if err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "use.newToken",
			req:       req,
			err:       err,
			message:   "failed to create new token",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.Create",
			req:       req,
			err:       err,
			message:   "failed to create new token",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	// Step 4: the new token is stored; remove the old one.
	if err := use.TokenRepo.Delete(ctx, token); err != nil {
		return usecase.RefreshTokenResponse{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.Delete",
			req:       req,
			err:       err,
			message:   "failed to delete old token",
			errorCode: domain.TokenRefreshErrorCode,
		})
	}

	return usecase.RefreshTokenResponse{
		AccessToken:  newToken.AccessToken,
		RefreshToken: newToken.RefreshToken,
		ExpiresIn:    newToken.ExpiresIn,
		TokenType:    tokenType,
	}, nil
}
// RevokeToken deletes the stored token identified by the access token in
// req.Token.
//
// The token's claims are read with claims validation disabled
// (validate=false), so an already-expired token can still be revoked. The
// token ID from the claims is used to load and then delete the stored token.
// Failures are reported with domain.TokenCancelErrorCode.
func (use *TokenUseCase) RevokeToken(ctx context.Context, req usecase.TokenRequest) error {
	claims, parseErr := use.ParseSystemClaimsByAccessToken(req.Token, use.Token.Secret, false)
	if parseErr != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "CancelToken extractClaims",
			req:       req,
			err:       parseErr,
			message:   "failed to get token claims",
			errorCode: domain.TokenCancelErrorCode,
		})
	}

	additional := NewAdditional(claims)
	tokenID := additional.Get(dt.ID)

	stored, lookupErr := use.TokenRepo.GetAccessTokenByID(ctx, tokenID)
	if lookupErr != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo GetAccessTokenByID",
			req:       req,
			err:       lookupErr,
			message:   fmt.Sprintf("failed to get token claims :%s", tokenID),
			errorCode: domain.TokenCancelErrorCode,
		})
	}

	if deleteErr := use.TokenRepo.Delete(ctx, stored); deleteErr != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo Delete",
			req:       req,
			err:       deleteErr,
			message:   fmt.Sprintf("failed to delete token :%s", stored.ID),
			errorCode: domain.TokenCancelErrorCode,
		})
	}
	return nil
}
func (use *TokenUseCase) VerifyToken(ctx context.Context, req usecase.TokenRequest) (usecase.VerifyTokenResponse, error) {
claims, err := use.ParseSystemClaimsByAccessToken(req.Token, use.Token.Secret, true)
if err != nil {
return usecase.VerifyTokenResponse{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "parseClaims",
req: req,
err: err,
message: "validate token claims error",
errorCode: domain.TokenValidateErrorCode,
})
}
data := NewAdditional(claims)
token, err := use.TokenRepo.GetAccessTokenByID(ctx, data.Get(dt.ID))
if err != nil {
return usecase.VerifyTokenResponse{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.GetAccessTokenByID",
req: req,
err: err,
message: fmt.Sprintf("failed to get token :%s", data.Get(dt.ID)),
errorCode: domain.TokenValidateErrorCode,
})
}
return usecase.VerifyTokenResponse{
Token: token,
Data: data.GetAll(),
}, nil
}
// RevokeTokensByUID deletes tokens in bulk.
//
// When req.UID is set, every token of that user is deleted. Then, when
// req.IDs is non-empty, the tokens with those IDs are deleted. Both steps may
// run in one call. If the UID deletion fails, the ID deletion is not
// attempted. Failures are reported with domain.TokensCancelErrorCode.
func (use *TokenUseCase) RevokeTokensByUID(ctx context.Context, req usecase.RevokeTokensByUIDRequest) error {
	if req.UID != "" {
		if uidErr := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID); uidErr != nil {
			return use.wrapTokenError(ctx, wrapTokenErrorReq{
				funcName:  "TokenRepo.DeleteAccessTokensByUID",
				req:       req,
				err:       uidErr,
				message:   "failed to cancel tokens by uid",
				errorCode: domain.TokensCancelErrorCode,
			})
		}
	}

	if len(req.IDs) == 0 {
		return nil
	}

	if idsErr := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs); idsErr != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.DeleteAccessTokenByID",
			req:       req,
			err:       idsErr,
			message:   "failed to cancel tokens by token ids",
			errorCode: domain.TokensCancelErrorCode,
		})
	}
	return nil
}
// RevokeTokensByDeviceID deletes every stored token issued for deviceID.
// Failures are reported with domain.TokensCancelErrorCode.
func (use *TokenUseCase) RevokeTokensByDeviceID(ctx context.Context, deviceID string) error {
	if deleteErr := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, deviceID); deleteErr != nil {
		return use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.DeleteAccessTokensByDeviceID",
			req:       deviceID,
			err:       deleteErr,
			message:   "failed to cancel token by device id",
			errorCode: domain.TokensCancelErrorCode,
		})
	}
	return nil
}
// GetUserTokensByDeviceID returns the access/refresh token pairs stored for
// deviceID, in repository order. ExpiresIn is copied verbatim from the stored
// token. The result is a non-nil (possibly empty) slice on success. Lookup
// failures are reported with domain.TokenGetErrorCode.
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, deviceID string) ([]*usecase.AccessTokenResponse, error) {
	stored, lookupErr := use.TokenRepo.GetAccessTokensByDeviceID(ctx, deviceID)
	if lookupErr != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokensByDeviceID",
			req:       deviceID,
			err:       lookupErr,
			message:   "failed to get token by device id",
			errorCode: domain.TokenGetErrorCode,
		})
	}

	responses := make([]*usecase.AccessTokenResponse, 0, len(stored))
	for _, tok := range stored {
		item := &usecase.AccessTokenResponse{
			AccessToken:  tok.AccessToken,
			ExpiresIn:    tok.ExpiresIn,
			RefreshToken: tok.RefreshToken,
		}
		responses = append(responses, item)
	}
	return responses, nil
}
// GetUserTokensByUID returns the access/refresh token pairs stored for the
// user uid, in repository order. ExpiresIn is copied verbatim from the stored
// token. The result is a non-nil (possibly empty) slice on success. Lookup
// failures are reported with domain.TokenGetErrorCode.
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, uid string) ([]*usecase.AccessTokenResponse, error) {
	stored, lookupErr := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
	if lookupErr != nil {
		return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
			funcName:  "TokenRepo.GetAccessTokensByUID",
			req:       uid,
			err:       lookupErr,
			message:   "failed to get token by uid",
			errorCode: domain.TokenGetErrorCode,
		})
	}

	responses := make([]*usecase.AccessTokenResponse, 0, len(stored))
	for _, tok := range stored {
		item := &usecase.AccessTokenResponse{
			AccessToken:  tok.AccessToken,
			ExpiresIn:    tok.ExpiresIn,
			RefreshToken: tok.RefreshToken,
		}
		responses = append(responses, item)
	}
	return responses, nil
}
// ReadTokenBasicData returns the claims data embedded in the access token.
//
// Claims validation is disabled (validate=false), so the data of an expired
// token can still be read. The repository is not consulted. Parse failures
// are reported with domain.TokenValidateErrorCode.
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (usecase.Additional, error) {
	claims, parseErr := use.ParseSystemClaimsByAccessToken(token, use.Token.Secret, false)
	if parseErr == nil {
		return NewAdditional(claims), nil
	}

	return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
		funcName:  "parseClaims",
		req:       token,
		err:       parseErr,
		message:   "validate token claims error",
		errorCode: domain.TokenValidateErrorCode,
	})
}
// ======== JWT Token ========
// CreateAccessToken signs a JWT access token with HS256 using secretKey.
//
// The registered claims carry token.ID (jti), dt.Issuer (iss) and an expiry
// (exp) taken from token.ExpiresIn, interpreted as Unix nanoseconds. The
// caller-supplied data is embedded as the custom Data claim. On a signing
// failure it returns an empty string and the error.
func (use *TokenUseCase) CreateAccessToken(token entity.Token, data any, secretKey string) (string, error) {
	registered := jwt.RegisteredClaims{
		ID:        token.ID,
		Issuer:    dt.Issuer,
		ExpiresAt: jwt.NewNumericDate(time.Unix(0, token.ExpiresIn)),
	}
	signer := jwt.NewWithClaims(jwt.SigningMethodHS256, entity.Claims{
		Data:             data,
		RegisteredClaims: registered,
	})

	signed, signErr := signer.SignedString([]byte(secretKey))
	if signErr != nil {
		return "", signErr
	}
	return signed, nil
}
// CreateRefreshToken derives the refresh token from an access token: it is the
// lowercase hex encoding of the SHA-256 digest of accessToken.
func (use *TokenUseCase) CreateRefreshToken(accessToken string) string {
	digest := sha256.Sum256([]byte(accessToken))
	return hex.EncodeToString(digest[:])
}
// ParseJWTClaimsByAccessToken parses accessToken as a JWT signed with secret
// and returns its claims.
//
// The signature is always verified, and only HMAC signing methods are
// accepted. When validate is true the registered claims (for example exp) are
// validated as well. When validate is false, claims validation is skipped so
// that an expired token can still be read; this is used for refresh and revoke.
// On any failure it returns an empty MapClaims and a non-nil error.
func (use *TokenUseCase) ParseJWTClaimsByAccessToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) {
	// Reject non-HMAC algorithms before handing out the shared secret, in
	// both modes, to guard against algorithm-confusion tokens.
	keyFunc := func(token *jwt.Token) (any, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("token unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(secret), nil
	}

	var opts []jwt.ParserOption
	if !validate {
		opts = append(opts, jwt.WithoutClaimsValidation())
	}

	token, err := jwt.NewParser(opts...).Parse(accessToken, keyFunc)
	if err != nil {
		return jwt.MapClaims{}, err
	}

	// The claims must be a MapClaims AND the token must be valid. The previous
	// `!ok && token.Valid` check had the logic inverted.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok || !token.Valid {
		return jwt.MapClaims{}, fmt.Errorf("token valid error")
	}
	return claims, nil
}
// ParseSystemClaimsByAccessToken parses accessToken via
// ParseJWTClaimsByAccessToken and returns its custom "data" claim flattened to
// map[string]string. It returns an empty map and an error when parsing fails
// or when the "data" claim is missing or not a JSON object.
func (use *TokenUseCase) ParseSystemClaimsByAccessToken(accessToken string, secret string, validate bool) (map[string]string, error) {
	claimMap, parseErr := use.ParseJWTClaimsByAccessToken(accessToken, secret, validate)
	if parseErr != nil {
		return map[string]string{}, parseErr
	}

	raw, isObject := claimMap["data"].(map[string]any)
	if !isObject {
		return map[string]string{}, fmt.Errorf("get data from claim map error")
	}
	return convertMap(raw), nil
}
// ======== 工具 ========
func (use *TokenUseCase) newToken(ctx context.Context, req *usecase.GenerateTokenRequest) (*entity.Token, error) {
// 準備建立 Token 所需
now := time.Now().UTC()
expires := req.Expires
refreshExpires := req.RefreshExpires
if expires <= 0 {
// 將時間加上 n 秒 -> 系統內預設
sec := time.Duration(use.Token.Expired.Seconds()) * time.Second
// 獲取 Unix 時間戳
expires = now.Add(sec).UnixNano()
}
// Refresh Token 過期時間要比普通的Token 長
if req.RefreshExpires <= 0 {
// 獲取 Unix 時間戳
refresh := time.Duration(use.Token.RefreshExpires.Seconds()) * time.Second
refreshExpires = now.Add(refresh).UnixNano()
}
token := entity.Token{
ID: ksuid.New().String(),
DeviceID: req.DeviceID,
ExpiresIn: expires,
RefreshExpiresIn: refreshExpires,
AccessCreateAt: now.UnixNano(),
RefreshCreateAt: now.UnixNano(),
UID: req.UID,
}
// 故意 data 裡面不會有那些已經有的欄位資訊
data := NewAdditional(req.Data)
data.Set(dt.ID, token.ID)
data.Set(dt.Role, req.Role)
data.Set(dt.Scope, req.Scope)
data.Set(dt.Account, req.Account)
data.Set(dt.UID, req.UID)
data.Set(dt.Type, req.TokenType)
if req.DeviceID != "" {
data.Set(dt.Device, req.DeviceID)
}
var err error
token.AccessToken, err = use.CreateAccessToken(token, data.GetAll(), use.Token.Secret)
token.RefreshToken = use.CreateRefreshToken(token.AccessToken)
if err != nil {
// 錯誤代碼 20-201-02
e := domain.TokenErrorL(
domain.TokenClaimErrorCode,
logx.WithContext(ctx),
[]logx.LogField{
{Key: "req", Value: req},
{Key: "func", Value: "accessTokenGenerator"},
{Key: "err", Value: err.Error()},
},
"failed to generator access token").Wrap(err)
return nil, e
}
return &token, nil
}
// convertMap flattens a decoded JWT "data" claim into map[string]string.
//
// String values are copied as-is, and fmt.Stringer values use String().
// encoding/json decodes JSON numbers into float64 when the target is `any`,
// and %v would render an integral value such as 1234567 as "1.234567e+06".
// Integral float64 values that fit in an int64 are therefore printed as plain
// integers. Everything else is formatted with %v. A nil input yields an empty,
// non-nil map.
func convertMap(input map[string]any) map[string]string {
	output := make(map[string]string, len(input))
	for key, value := range input {
		switch v := value.(type) {
		case string:
			output[key] = v
		case float64:
			// Range-check first: converting an out-of-range float to int64 is
			// implementation-defined in Go. NaN fails every comparison.
			if v >= -(1<<63) && v < 1<<63 && v == float64(int64(v)) {
				output[key] = fmt.Sprintf("%d", int64(v))
			} else {
				output[key] = fmt.Sprintf("%v", v)
			}
		case fmt.Stringer:
			output[key] = v.String()
		default:
			output[key] = fmt.Sprintf("%v", value)
		}
	}
	return output
}
// wrapTokenErrorReq carries everything wrapTokenError needs to log and wrap a
// failure.
type wrapTokenErrorReq struct {
	// funcName names the call that failed; logged under the "func" key.
	funcName string
	// req is the request or argument being processed; logged under the "req" key.
	req any
	// err is the underlying error; its message is logged under "err", and it is
	// wrapped into the returned error.
	err error
	// message is the description passed to domain.TokenErrorL.
	message string
	// errorCode is the domain error code of the returned error.
	errorCode ers.ErrorCode
}
// wrapTokenError logs the failure described by param through
// domain.TokenErrorL and returns the resulting error with param.err wrapped
// inside it.
//
// The log entry carries the "req", "func" and "err" fields. A nil param.err is
// logged as an empty string instead of panicking on err.Error().
// NOTE(review): Wrap(nil) behavior depends on the errs library; confirm it
// is safe.
func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error {
	errMsg := ""
	if param.err != nil {
		errMsg = param.err.Error()
	}

	logFields := []logx.LogField{
		{Key: "req", Value: param.req},
		{Key: "func", Value: param.funcName},
		{Key: "err", Value: errMsg},
	}

	return domain.TokenErrorL(
		param.errorCode,
		logx.WithContext(ctx),
		logFields,
		param.message,
	).Wrap(param.err)
}