528 lines
15 KiB
Go
528 lines
15 KiB
Go
|
package usecase
|
||
|
|
||
|
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strconv"
	"time"

	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain"
	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/entity"
	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/repository"
	dt "code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/token"
	"code.30cm.net/digimon/app-cloudep-permission-server/pkg/domain/usecase"
	ers "code.30cm.net/digimon/library-go/errs"
	"github.com/golang-jwt/jwt/v4"
	"github.com/segmentio/ksuid"
	"github.com/zeromicro/go-zero/core/logx"
)
|
||
|
|
||
|
type TokenUseCaseParam struct {
|
||
|
TokenRepo repository.TokenRepo
|
||
|
RefreshExpires time.Duration
|
||
|
Expired time.Duration
|
||
|
Secret string
|
||
|
}
|
||
|
|
||
|
type TokenUseCase struct {
|
||
|
TokenUseCaseParam
|
||
|
Token struct {
|
||
|
RefreshExpires time.Duration
|
||
|
Expired time.Duration
|
||
|
Secret string
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func NewTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
|
||
|
return &TokenUseCase{
|
||
|
TokenUseCaseParam: param,
|
||
|
Token: struct {
|
||
|
RefreshExpires time.Duration
|
||
|
Expired time.Duration
|
||
|
Secret string
|
||
|
}{
|
||
|
RefreshExpires: param.RefreshExpires,
|
||
|
Expired: param.Expired,
|
||
|
Secret: param.Secret,
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) GenerateAccessToken(ctx context.Context, req usecase.GenerateTokenRequest) (usecase.AccessTokenResponse, error) {
|
||
|
token, err := use.newToken(ctx, &req)
|
||
|
if err != nil {
|
||
|
return usecase.AccessTokenResponse{}, err
|
||
|
}
|
||
|
|
||
|
err = use.TokenRepo.Create(ctx, *token)
|
||
|
if err != nil {
|
||
|
// 錯誤代碼
|
||
|
e := domain.TokenErrorL(
|
||
|
domain.TokenCreateErrorCode,
|
||
|
logx.WithContext(ctx),
|
||
|
[]logx.LogField{
|
||
|
{Key: "req", Value: req},
|
||
|
{Key: "func", Value: "TokenRepo.Create"},
|
||
|
{Key: "err", Value: err.Error()},
|
||
|
},
|
||
|
"failed to create token").Wrap(err)
|
||
|
|
||
|
return usecase.AccessTokenResponse{}, e
|
||
|
}
|
||
|
|
||
|
return usecase.AccessTokenResponse{
|
||
|
AccessToken: token.AccessToken,
|
||
|
ExpiresIn: token.ExpiresIn,
|
||
|
RefreshToken: token.RefreshToken,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) RefreshAccessToken(ctx context.Context, req usecase.RefreshTokenRequest) (usecase.RefreshTokenResponse, error) {
|
||
|
// Step 1: 檢查 refresh token
|
||
|
token, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
|
||
|
if err != nil {
|
||
|
return usecase.RefreshTokenResponse{},
|
||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.GetAccessTokenByOneTimeToken",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "failed to get access token",
|
||
|
errorCode: domain.TokenRefreshErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// Step 2: 提取 Claims Data
|
||
|
claimsData, err := use.ParseSystemClaimsByAccessToken(token.AccessToken, use.Token.Secret, false)
|
||
|
if err != nil {
|
||
|
return usecase.RefreshTokenResponse{},
|
||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "extractClaims",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "failed to extract claims",
|
||
|
errorCode: domain.TokenRefreshErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
data := NewAdditional(claimsData)
|
||
|
data.Set(dt.Scope, req.Scope)
|
||
|
data.Set(dt.Device, req.DeviceID)
|
||
|
|
||
|
// Step 3: 創建新 token
|
||
|
newToken, err := use.newToken(ctx, &usecase.GenerateTokenRequest{
|
||
|
Scope: req.Scope,
|
||
|
DeviceID: req.DeviceID,
|
||
|
Expires: req.Expires,
|
||
|
RefreshExpires: req.RefreshExpires,
|
||
|
Data: data.GetAll(),
|
||
|
Role: data.Get(dt.Role),
|
||
|
UID: data.Get(dt.UID),
|
||
|
Account: data.Get(dt.Account),
|
||
|
})
|
||
|
if err != nil {
|
||
|
return usecase.RefreshTokenResponse{},
|
||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "use.newToken",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "failed to create new token",
|
||
|
errorCode: domain.TokenRefreshErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
|
||
|
return usecase.RefreshTokenResponse{},
|
||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.Create",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "failed to create new token",
|
||
|
errorCode: domain.TokenRefreshErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// Step 4: 刪除舊 token 並創建新 token
|
||
|
if err := use.TokenRepo.Delete(ctx, token); err != nil {
|
||
|
return usecase.RefreshTokenResponse{},
|
||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.Delete",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "failed to delete old token",
|
||
|
errorCode: domain.TokenRefreshErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// 返回新的 Token 響應
|
||
|
return usecase.RefreshTokenResponse{
|
||
|
AccessToken: newToken.AccessToken,
|
||
|
RefreshToken: newToken.RefreshToken,
|
||
|
ExpiresIn: newToken.ExpiresIn,
|
||
|
TokenType: data.Get(dt.Type),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) RevokeToken(ctx context.Context, req usecase.TokenRequest) error {
|
||
|
claims, err := use.ParseSystemClaimsByAccessToken(req.Token, use.Token.Secret, false)
|
||
|
if err != nil {
|
||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "CancelToken extractClaims",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "failed to get token claims",
|
||
|
errorCode: domain.TokenCancelErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
data := NewAdditional(claims)
|
||
|
token, err := use.TokenRepo.GetAccessTokenByID(ctx, data.Get(dt.ID))
|
||
|
if err != nil {
|
||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo GetAccessTokenByID",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: fmt.Sprintf("failed to get token claims :%s", data.Get(dt.ID)),
|
||
|
errorCode: domain.TokenCancelErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
err = use.TokenRepo.Delete(ctx, token)
|
||
|
if err != nil {
|
||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo Delete",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: fmt.Sprintf("failed to delete token :%s", token.ID),
|
||
|
errorCode: domain.TokenCancelErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) VerifyToken(ctx context.Context, req usecase.TokenRequest) (usecase.VerifyTokenResponse, error) {
|
||
|
claims, err := use.ParseSystemClaimsByAccessToken(req.Token, use.Token.Secret, true)
|
||
|
if err != nil {
|
||
|
return usecase.VerifyTokenResponse{},
|
||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "parseClaims",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "validate token claims error",
|
||
|
errorCode: domain.TokenValidateErrorCode,
|
||
|
})
|
||
|
}
|
||
|
data := NewAdditional(claims)
|
||
|
|
||
|
token, err := use.TokenRepo.GetAccessTokenByID(ctx, data.Get(dt.ID))
|
||
|
if err != nil {
|
||
|
return usecase.VerifyTokenResponse{},
|
||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.GetAccessTokenByID",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: fmt.Sprintf("failed to get token :%s", data.Get(dt.ID)),
|
||
|
errorCode: domain.TokenValidateErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return usecase.VerifyTokenResponse{
|
||
|
Token: token,
|
||
|
Data: data.GetAll(),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) RevokeTokensByUID(ctx context.Context, req usecase.RevokeTokensByUIDRequest) error {
|
||
|
if req.UID != "" {
|
||
|
err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID)
|
||
|
if err != nil {
|
||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.DeleteAccessTokensByUID",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "failed to cancel tokens by uid",
|
||
|
errorCode: domain.TokensCancelErrorCode,
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(req.IDs) > 0 {
|
||
|
err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs)
|
||
|
if err != nil {
|
||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.DeleteAccessTokenByID",
|
||
|
req: req,
|
||
|
err: err,
|
||
|
message: "failed to cancel tokens by token ids",
|
||
|
errorCode: domain.TokensCancelErrorCode,
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) RevokeTokensByDeviceID(ctx context.Context, deviceID string) error {
|
||
|
err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, deviceID)
|
||
|
if err != nil {
|
||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.DeleteAccessTokensByDeviceID",
|
||
|
req: deviceID,
|
||
|
err: err,
|
||
|
message: "failed to cancel token by device id",
|
||
|
errorCode: domain.TokensCancelErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, deviceID string) ([]*usecase.AccessTokenResponse, error) {
|
||
|
tokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, deviceID)
|
||
|
if err != nil {
|
||
|
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.GetAccessTokensByDeviceID",
|
||
|
req: deviceID,
|
||
|
err: err,
|
||
|
message: "failed to get token by device id",
|
||
|
errorCode: domain.TokenGetErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
result := make([]*usecase.AccessTokenResponse, 0, len(tokens))
|
||
|
for _, v := range tokens {
|
||
|
result = append(result, &usecase.AccessTokenResponse{
|
||
|
AccessToken: v.AccessToken,
|
||
|
ExpiresIn: v.ExpiresIn,
|
||
|
RefreshToken: v.RefreshToken,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return result, nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, uid string) ([]*usecase.AccessTokenResponse, error) {
|
||
|
tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
|
||
|
if err != nil {
|
||
|
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "TokenRepo.GetAccessTokensByUID",
|
||
|
req: uid,
|
||
|
err: err,
|
||
|
message: "failed to get token by uid",
|
||
|
errorCode: domain.TokenGetErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
result := make([]*usecase.AccessTokenResponse, 0, len(tokens))
|
||
|
for _, v := range tokens {
|
||
|
result = append(result, &usecase.AccessTokenResponse{
|
||
|
AccessToken: v.AccessToken,
|
||
|
ExpiresIn: v.ExpiresIn,
|
||
|
RefreshToken: v.RefreshToken,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return result, nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (usecase.Additional, error) {
|
||
|
claims, err := use.ParseSystemClaimsByAccessToken(token, use.Token.Secret, false)
|
||
|
if err != nil {
|
||
|
return nil,
|
||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||
|
funcName: "parseClaims",
|
||
|
req: token,
|
||
|
err: err,
|
||
|
message: "validate token claims error",
|
||
|
errorCode: domain.TokenValidateErrorCode,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return NewAdditional(claims), nil
|
||
|
}
|
||
|
|
||
|
// ======== JWT Token ========
|
||
|
|
||
|
// CreateAccessToken 會將基本 token 以及想要加入Token Claims 的Data 依照 secret key 加密之後變成 jwt access token
|
||
|
func (use *TokenUseCase) CreateAccessToken(token entity.Token, data any, secretKey string) (string, error) {
|
||
|
claims := entity.Claims{
|
||
|
Data: data,
|
||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||
|
ID: token.ID,
|
||
|
ExpiresAt: jwt.NewNumericDate(time.Unix(0, token.ExpiresIn)),
|
||
|
Issuer: dt.Issuer,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
|
||
|
SignedString([]byte(secretKey))
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
return accessToken, nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) CreateRefreshToken(accessToken string) string {
|
||
|
hash := sha256.New()
|
||
|
_, _ = hash.Write([]byte(accessToken))
|
||
|
|
||
|
return hex.EncodeToString(hash.Sum(nil))
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) ParseJWTClaimsByAccessToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) {
|
||
|
// 跳過驗證的解析
|
||
|
var token *jwt.Token
|
||
|
var err error
|
||
|
|
||
|
if validate {
|
||
|
token, err = jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
|
||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
|
return nil, fmt.Errorf("token unexpected signing method: %v", token.Header["alg"])
|
||
|
}
|
||
|
|
||
|
return []byte(secret), nil
|
||
|
})
|
||
|
if err != nil {
|
||
|
return jwt.MapClaims{}, err
|
||
|
}
|
||
|
} else {
|
||
|
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
|
||
|
token, err = parser.Parse(accessToken, func(_ *jwt.Token) (any, error) {
|
||
|
return []byte(secret), nil
|
||
|
})
|
||
|
if err != nil {
|
||
|
return jwt.MapClaims{}, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||
|
if !ok && token.Valid {
|
||
|
return jwt.MapClaims{}, fmt.Errorf("token valid error")
|
||
|
}
|
||
|
|
||
|
return claims, nil
|
||
|
}
|
||
|
|
||
|
func (use *TokenUseCase) ParseSystemClaimsByAccessToken(accessToken string, secret string, validate bool) (map[string]string, error) {
|
||
|
claimMap, err := use.ParseJWTClaimsByAccessToken(accessToken, secret, validate)
|
||
|
if err != nil {
|
||
|
return map[string]string{}, err
|
||
|
}
|
||
|
|
||
|
claimsData, ok := claimMap["data"].(map[string]any)
|
||
|
if ok {
|
||
|
return convertMap(claimsData), nil
|
||
|
}
|
||
|
|
||
|
return map[string]string{}, fmt.Errorf("get data from claim map error")
|
||
|
}
|
||
|
|
||
|
// ======== Helpers ========
|
||
|
|
||
|
func (use *TokenUseCase) newToken(ctx context.Context, req *usecase.GenerateTokenRequest) (*entity.Token, error) {
|
||
|
// 準備建立 Token 所需
|
||
|
now := time.Now().UTC()
|
||
|
expires := req.Expires
|
||
|
refreshExpires := req.RefreshExpires
|
||
|
|
||
|
if expires <= 0 {
|
||
|
// 將時間加上 n 秒 -> 系統內預設
|
||
|
sec := time.Duration(use.Token.Expired.Seconds()) * time.Second
|
||
|
// 獲取 Unix 時間戳
|
||
|
expires = now.Add(sec).UnixNano()
|
||
|
}
|
||
|
|
||
|
// Refresh Token 過期時間要比普通的Token 長
|
||
|
if req.RefreshExpires <= 0 {
|
||
|
// 獲取 Unix 時間戳
|
||
|
refresh := time.Duration(use.Token.RefreshExpires.Seconds()) * time.Second
|
||
|
refreshExpires = now.Add(refresh).UnixNano()
|
||
|
}
|
||
|
|
||
|
token := entity.Token{
|
||
|
ID: ksuid.New().String(),
|
||
|
DeviceID: req.DeviceID,
|
||
|
ExpiresIn: expires,
|
||
|
RefreshExpiresIn: refreshExpires,
|
||
|
AccessCreateAt: now.UnixNano(),
|
||
|
RefreshCreateAt: now.UnixNano(),
|
||
|
UID: req.UID,
|
||
|
}
|
||
|
// 故意 data 裡面不會有那些已經有的欄位資訊
|
||
|
data := NewAdditional(req.Data)
|
||
|
data.Set(dt.ID, token.ID)
|
||
|
data.Set(dt.Role, req.Role)
|
||
|
data.Set(dt.Scope, req.Scope)
|
||
|
data.Set(dt.Account, req.Account)
|
||
|
data.Set(dt.UID, req.UID)
|
||
|
data.Set(dt.Type, req.TokenType)
|
||
|
|
||
|
if req.DeviceID != "" {
|
||
|
data.Set(dt.Device, req.DeviceID)
|
||
|
}
|
||
|
|
||
|
var err error
|
||
|
token.AccessToken, err = use.CreateAccessToken(token, data.GetAll(), use.Token.Secret)
|
||
|
token.RefreshToken = use.CreateRefreshToken(token.AccessToken)
|
||
|
|
||
|
if err != nil {
|
||
|
// 錯誤代碼 20-201-02
|
||
|
e := domain.TokenErrorL(
|
||
|
domain.TokenClaimErrorCode,
|
||
|
logx.WithContext(ctx),
|
||
|
[]logx.LogField{
|
||
|
{Key: "req", Value: req},
|
||
|
{Key: "func", Value: "accessTokenGenerator"},
|
||
|
{Key: "err", Value: err.Error()},
|
||
|
},
|
||
|
"failed to generator access token").Wrap(err)
|
||
|
|
||
|
return nil, e
|
||
|
}
|
||
|
|
||
|
return &token, nil
|
||
|
}
|
||
|
|
||
|
// convertMap flattens a decoded JSON object into string values.
//
// Strings are kept as-is. float64 values (how encoding/json decodes JSON
// numbers) are formatted in plain decimal notation, so 1700000000 becomes
// "1700000000" rather than "1.7e+09". fmt.Stringer values use String(), and
// anything else uses fmt's %v. A nil or empty input yields an empty,
// non-nil map.
func convertMap(input map[string]any) map[string]string {
	output := make(map[string]string, len(input))
	for key, value := range input {
		switch v := value.(type) {
		case string:
			output[key] = v
		case float64:
			// %v would switch to exponent form for large/small magnitudes.
			output[key] = strconv.FormatFloat(v, 'f', -1, 64)
		case fmt.Stringer:
			output[key] = v.String()
		default:
			output[key] = fmt.Sprintf("%v", value)
		}
	}

	return output
}
|
||
|
|
||
|
// wrapTokenErrorReq carries everything wrapTokenError needs to log and wrap a
// failure.
type wrapTokenErrorReq struct {
	funcName  string        // name of the failing call; logged under "func"
	req       any           // request/input logged under "req"
	err       error         // underlying error; must be non-nil (wrapTokenError calls err.Error())
	message   string        // message passed to domain.TokenErrorL
	errorCode ers.ErrorCode // error code passed to domain.TokenErrorL
}
|
||
|
|
||
|
// wrapTokenError 將錯誤訊息封裝到 domain.TokenErrorL 中
|
||
|
func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error {
|
||
|
logFields := []logx.LogField{
|
||
|
{Key: "req", Value: param.req},
|
||
|
{Key: "func", Value: param.funcName},
|
||
|
{Key: "err", Value: param.err.Error()},
|
||
|
}
|
||
|
wrappedErr := domain.TokenErrorL(
|
||
|
param.errorCode,
|
||
|
logx.WithContext(ctx),
|
||
|
logFields,
|
||
|
param.message,
|
||
|
).Wrap(param.err)
|
||
|
|
||
|
return wrappedErr
|
||
|
}
|