208 lines
5.2 KiB
Go
208 lines
5.2 KiB
Go
|
package usecase
|
||
|
|
||
|
import (
|
||
|
"backend/pkg/library/errs/code"
|
||
|
"backend/pkg/permission/utils"
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"encoding/hex"
|
||
|
"time"
|
||
|
|
||
|
"backend/pkg/library/errs"
|
||
|
"backend/pkg/permission/domain/config"
|
||
|
"backend/pkg/permission/domain/entity"
|
||
|
"backend/pkg/permission/domain/repository"
|
||
|
"backend/pkg/permission/domain/usecase"
|
||
|
|
||
|
"github.com/golang-jwt/jwt/v5"
|
||
|
)
|
||
|
|
||
|
type AuthUseCaseParam struct {
|
||
|
ClientRepo repository.ClientRepository
|
||
|
TokenRepo repository.TokenRepository
|
||
|
JWTConfig config.JWTConfig
|
||
|
}
|
||
|
|
||
|
type AuthUseCase struct {
|
||
|
clientRepo repository.ClientRepository
|
||
|
tokenRepo repository.TokenRepository
|
||
|
jwtConfig config.JWTConfig
|
||
|
}
|
||
|
|
||
|
// MustAuthUseCase 創建認證用例實例
|
||
|
func MustAuthUseCase(param AuthUseCaseParam) usecase.AuthUseCase {
|
||
|
return &AuthUseCase{
|
||
|
clientRepo: param.ClientRepo,
|
||
|
tokenRepo: param.TokenRepo,
|
||
|
jwtConfig: param.JWTConfig,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (uc *AuthUseCase) CreateToken(ctx context.Context, req usecase.CreateTokenRequest) (*usecase.TokenResponse, error) {
|
||
|
// 驗證客戶端
|
||
|
client, err := uc.clientRepo.GetByClientID(ctx, req.ClientID)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if !utils.IsActive(client.Status) {
|
||
|
return nil, errs.UserSuspended(code.CloudEPPermission, "failed to get token since user has been suspended")
|
||
|
}
|
||
|
|
||
|
// 根據授權類型處理
|
||
|
var uid string
|
||
|
switch req.GrantType {
|
||
|
case "client_credentials":
|
||
|
uid = "client_" + req.ClientID
|
||
|
case "password":
|
||
|
if req.Username == "" || req.Password == "" {
|
||
|
return nil, errs.InvalidCredentials()
|
||
|
}
|
||
|
// 這裡應該驗證用戶名密碼,簡化處理
|
||
|
uid = req.Username
|
||
|
default:
|
||
|
return nil, errs.InvalidFormat("unsupported grant type: " + req.GrantType)
|
||
|
}
|
||
|
|
||
|
// 生成令牌
|
||
|
accessToken, err := uc.generateAccessToken(uid, req.ClientID, req.DeviceID)
|
||
|
if err != nil {
|
||
|
return nil, errs.SystemInternal("failed to generate access token: " + err.Error())
|
||
|
}
|
||
|
|
||
|
refreshToken, err := uc.generateRefreshToken()
|
||
|
if err != nil {
|
||
|
return nil, errs.SystemInternal("failed to generate refresh token: " + err.Error())
|
||
|
}
|
||
|
|
||
|
// 保存令牌
|
||
|
token := &entity.Token{
|
||
|
UID: uid,
|
||
|
ClientID: req.ClientID,
|
||
|
AccessToken: accessToken,
|
||
|
RefreshToken: refreshToken,
|
||
|
DeviceID: req.DeviceID,
|
||
|
ExpiresAt: time.Now().Add(uc.jwtConfig.AccessExpires),
|
||
|
}
|
||
|
|
||
|
if err := uc.tokenRepo.Create(ctx, token); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return &usecase.TokenResponse{
|
||
|
AccessToken: accessToken,
|
||
|
RefreshToken: refreshToken,
|
||
|
TokenType: "Bearer",
|
||
|
ExpiresIn: int64(uc.jwtConfig.AccessExpires.Seconds()),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (uc *AuthUseCase) RefreshToken(ctx context.Context, refreshToken string) (*usecase.TokenResponse, error) {
|
||
|
// 查找刷新令牌
|
||
|
token, err := uc.tokenRepo.GetByRefreshToken(ctx, refreshToken)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if token.IsExpired() {
|
||
|
return nil, errs.TokenExpired()
|
||
|
}
|
||
|
|
||
|
// 生成新的訪問令牌
|
||
|
accessToken, err := uc.generateAccessToken(token.UID, token.ClientID, token.DeviceID)
|
||
|
if err != nil {
|
||
|
return nil, errs.SystemInternal("failed to generate access token: " + err.Error())
|
||
|
}
|
||
|
|
||
|
// 更新令牌
|
||
|
token.AccessToken = accessToken
|
||
|
token.ExpiresAt = time.Now().Add(uc.jwtConfig.AccessExpires)
|
||
|
|
||
|
if err := uc.tokenRepo.Update(ctx, token); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return &usecase.TokenResponse{
|
||
|
AccessToken: accessToken,
|
||
|
RefreshToken: refreshToken,
|
||
|
TokenType: "Bearer",
|
||
|
ExpiresIn: int64(uc.jwtConfig.AccessExpires.Seconds()),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (uc *AuthUseCase) ValidateToken(ctx context.Context, accessToken string) (*usecase.TokenClaims, error) {
|
||
|
// 解析JWT令牌
|
||
|
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
|
return nil, errs.TokenInvalid()
|
||
|
}
|
||
|
return []byte(uc.jwtConfig.Secret), nil
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
return nil, errs.TokenInvalid()
|
||
|
}
|
||
|
|
||
|
if !token.Valid {
|
||
|
return nil, errs.TokenInvalid()
|
||
|
}
|
||
|
|
||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||
|
if !ok {
|
||
|
return nil, errs.TokenInvalid()
|
||
|
}
|
||
|
|
||
|
uid, ok := claims["uid"].(string)
|
||
|
if !ok {
|
||
|
return nil, errs.TokenInvalid()
|
||
|
}
|
||
|
|
||
|
clientID, ok := claims["client_id"].(string)
|
||
|
if !ok {
|
||
|
return nil, errs.TokenInvalid()
|
||
|
}
|
||
|
|
||
|
deviceID, _ := claims["device_id"].(string)
|
||
|
|
||
|
return &usecase.TokenClaims{
|
||
|
UID: uid,
|
||
|
ClientID: clientID,
|
||
|
DeviceID: deviceID,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (uc *AuthUseCase) Logout(ctx context.Context, accessToken string) error {
|
||
|
// 查找並刪除令牌
|
||
|
token, err := uc.tokenRepo.GetByAccessToken(ctx, accessToken)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return uc.tokenRepo.Delete(ctx, token.ID)
|
||
|
}
|
||
|
|
||
|
func (uc *AuthUseCase) LogoutAllByUserID(ctx context.Context, uid string) error {
|
||
|
return uc.tokenRepo.DeleteByUserID(ctx, uid)
|
||
|
}
|
||
|
|
||
|
func (uc *AuthUseCase) generateAccessToken(uid, clientID, deviceID string) (string, error) {
|
||
|
claims := jwt.MapClaims{
|
||
|
"uid": uid,
|
||
|
"client_id": clientID,
|
||
|
"device_id": deviceID,
|
||
|
"exp": time.Now().Add(uc.jwtConfig.AccessExpires).Unix(),
|
||
|
"iat": time.Now().Unix(),
|
||
|
}
|
||
|
|
||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||
|
return token.SignedString([]byte(uc.jwtConfig.Secret))
|
||
|
}
|
||
|
|
||
|
func (uc *AuthUseCase) generateRefreshToken() (string, error) {
|
||
|
bytes := make([]byte, 32)
|
||
|
if _, err := rand.Read(bytes); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return hex.EncodeToString(bytes), nil
|
||
|
}
|