backend/pkg/permission/usecase/auth.go

208 lines
5.2 KiB
Go
Raw Normal View History

2025-10-03 08:38:12 +00:00
package usecase
import (
"backend/pkg/library/errs/code"
"backend/pkg/permission/utils"
"context"
"crypto/rand"
"encoding/hex"
"time"
"backend/pkg/library/errs"
"backend/pkg/permission/domain/config"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/usecase"
"github.com/golang-jwt/jwt/v5"
)
// AuthUseCaseParam bundles the dependencies needed to build an AuthUseCase.
// MustAuthUseCase copies each field into the matching unexported field of
// AuthUseCase without validating it.
type AuthUseCaseParam struct {
	// ClientRepo is used to look up a client by its client ID.
	ClientRepo repository.ClientRepository
	// TokenRepo stores, looks up, updates and deletes issued token records.
	TokenRepo repository.TokenRepository
	// JWTConfig provides the HMAC signing secret (Secret) and the access-token
	// lifetime (AccessExpires).
	JWTConfig config.JWTConfig
}
// AuthUseCase issues, refreshes, validates and revokes tokens. MustAuthUseCase
// returns it behind the usecase.AuthUseCase interface.
type AuthUseCase struct {
	// clientRepo resolves clients by client ID so their status can be checked.
	clientRepo repository.ClientRepository
	// tokenRepo persists issued access/refresh token pairs.
	tokenRepo repository.TokenRepository
	// jwtConfig holds the signing secret and access-token lifetime.
	jwtConfig config.JWTConfig
}
// MustAuthUseCase creates an authentication use case from the supplied
// dependencies and returns it as a usecase.AuthUseCase.
//
// Despite the "Must" prefix, nothing is validated and the function never
// panics: nil repositories in param are stored as-is and will only fail when a
// method that uses them is called.
func MustAuthUseCase(param AuthUseCaseParam) usecase.AuthUseCase {
	uc := new(AuthUseCase)
	uc.clientRepo = param.ClientRepo
	uc.tokenRepo = param.TokenRepo
	uc.jwtConfig = param.JWTConfig
	return uc
}
// CreateToken issues a new access/refresh token pair for req.ClientID.
//
// The client is loaded from clientRepo (lookup errors are returned unchanged)
// and must be active. The token subject (uid) depends on req.GrantType:
//   - "client_credentials": uid is "client_" + req.ClientID.
//   - "password": uid is req.Username; Username and Password must both be
//     non-empty, otherwise errs.InvalidCredentials is returned.
//
// Any other grant type yields errs.InvalidFormat. The pair is stored through
// tokenRepo with ExpiresAt = now + jwtConfig.AccessExpires, and ExpiresIn in
// the response is that lifetime in whole seconds.
//
// SECURITY NOTE(review): the "password" grant only checks that the credentials
// are non-empty. The password is never verified, so anyone holding an active
// client ID can obtain a token for an arbitrary username. This function also
// checks no client secret for "client_credentials". Confirm both are enforced
// elsewhere before exposing this flow.
func (uc *AuthUseCase) CreateToken(ctx context.Context, req usecase.CreateTokenRequest) (*usecase.TokenResponse, error) {
	client, err := uc.clientRepo.GetByClientID(ctx, req.ClientID)
	if err != nil {
		return nil, err
	}
	if !utils.IsActive(client.Status) {
		return nil, errs.UserSuspended(code.CloudEPPermission, "failed to get token since user has been suspended")
	}

	var uid string
	switch grant := req.GrantType; grant {
	case "client_credentials":
		uid = "client_" + req.ClientID
	case "password":
		if req.Username == "" || req.Password == "" {
			return nil, errs.InvalidCredentials()
		}
		// Credentials are not checked against any user store here (see note above).
		uid = req.Username
	default:
		return nil, errs.InvalidFormat("unsupported grant type: " + grant)
	}

	access, err := uc.generateAccessToken(uid, req.ClientID, req.DeviceID)
	if err != nil {
		return nil, errs.SystemInternal("failed to generate access token: " + err.Error())
	}
	refresh, err := uc.generateRefreshToken()
	if err != nil {
		return nil, errs.SystemInternal("failed to generate refresh token: " + err.Error())
	}

	lifetime := uc.jwtConfig.AccessExpires
	record := &entity.Token{
		UID:          uid,
		ClientID:     req.ClientID,
		AccessToken:  access,
		RefreshToken: refresh,
		DeviceID:     req.DeviceID,
		ExpiresAt:    time.Now().Add(lifetime),
	}
	if err = uc.tokenRepo.Create(ctx, record); err != nil {
		return nil, err
	}

	return &usecase.TokenResponse{
		AccessToken:  access,
		RefreshToken: refresh,
		TokenType:    "Bearer",
		ExpiresIn:    int64(lifetime.Seconds()),
	}, nil
}
// RefreshToken issues a new access token for an existing refresh token.
//
// The stored record is looked up by refreshToken (lookup errors are returned
// unchanged) and must not be expired. The owning client must still exist and be
// active, which mirrors the check in CreateToken. Without that check a suspended
// client could keep its session alive indefinitely, because every successful
// refresh moves the record's ExpiresAt forward. This assumes IsExpired is based
// on ExpiresAt; confirm in entity.Token.
//
// The refresh token itself is not rotated: the same value is returned in the
// response. The record's AccessToken and ExpiresAt are updated in tokenRepo.
func (uc *AuthUseCase) RefreshToken(ctx context.Context, refreshToken string) (*usecase.TokenResponse, error) {
	// Reject blank input up front rather than asking the repository to match an
	// empty refresh token.
	if refreshToken == "" {
		return nil, errs.TokenInvalid()
	}

	token, err := uc.tokenRepo.GetByRefreshToken(ctx, refreshToken)
	if err != nil {
		return nil, err
	}
	if token.IsExpired() {
		return nil, errs.TokenExpired()
	}

	// Re-check the client on every refresh so suspension takes effect on
	// sessions that were opened before the client was suspended.
	client, err := uc.clientRepo.GetByClientID(ctx, token.ClientID)
	if err != nil {
		return nil, err
	}
	if !utils.IsActive(client.Status) {
		return nil, errs.UserSuspended(code.CloudEPPermission, "failed to refresh token since user has been suspended")
	}

	accessToken, err := uc.generateAccessToken(token.UID, token.ClientID, token.DeviceID)
	if err != nil {
		return nil, errs.SystemInternal("failed to generate access token: " + err.Error())
	}

	token.AccessToken = accessToken
	token.ExpiresAt = time.Now().Add(uc.jwtConfig.AccessExpires)
	if err := uc.tokenRepo.Update(ctx, token); err != nil {
		return nil, err
	}

	return &usecase.TokenResponse{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		TokenType:    "Bearer",
		ExpiresIn:    int64(uc.jwtConfig.AccessExpires.Seconds()),
	}, nil
}
// ValidateToken verifies accessToken and returns the identity claims it carries.
//
// The JWT must meet all of these conditions:
//   - it is signed with HS256 using jwtConfig.Secret;
//   - it carries an "exp" claim that has not passed;
//   - it has string "uid" and "client_id" claims.
//
// "device_id" is optional and defaults to "".
//
// Logout and LogoutAllByUserID revoke sessions by deleting the stored record
// from tokenRepo. The token must therefore still be present there; otherwise a
// logged-out token would stay usable until its JWT expiry. Every failure,
// including a repository error during that lookup, is reported as
// errs.TokenInvalid (fail closed).
func (uc *AuthUseCase) ValidateToken(ctx context.Context, accessToken string) (*usecase.TokenClaims, error) {
	token, err := jwt.Parse(
		accessToken,
		func(t *jwt.Token) (interface{}, error) {
			if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, errs.TokenInvalid()
			}
			return []byte(uc.jwtConfig.Secret), nil
		},
		// Accept only the algorithm generateAccessToken signs with.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		// jwt/v5 skips the expiry check when "exp" is missing; every token
		// issued here has one, so insist on it.
		jwt.WithExpirationRequired(),
	)
	if err != nil || !token.Valid {
		return nil, errs.TokenInvalid()
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errs.TokenInvalid()
	}
	uid, ok := claims["uid"].(string)
	if !ok {
		return nil, errs.TokenInvalid()
	}
	clientID, ok := claims["client_id"].(string)
	if !ok {
		return nil, errs.TokenInvalid()
	}
	// device_id is optional; a missing or non-string value becomes "".
	deviceID, _ := claims["device_id"].(string)

	// Revocation check: a token whose record was deleted by Logout or
	// LogoutAllByUserID must no longer be accepted.
	if _, err := uc.tokenRepo.GetByAccessToken(ctx, accessToken); err != nil {
		return nil, errs.TokenInvalid()
	}

	return &usecase.TokenClaims{
		UID:      uid,
		ClientID: clientID,
		DeviceID: deviceID,
	}, nil
}
// Logout ends a single session by deleting the stored token record whose
// access token equals accessToken. Errors from either the lookup or the delete
// are returned unchanged.
func (uc *AuthUseCase) Logout(ctx context.Context, accessToken string) error {
	record, lookupErr := uc.tokenRepo.GetByAccessToken(ctx, accessToken)
	if lookupErr != nil {
		return lookupErr
	}
	id := record.ID
	return uc.tokenRepo.Delete(ctx, id)
}
// LogoutAllByUserID ends every session of uid by deleting all of its stored
// token records via tokenRepo.DeleteByUserID. The repository error, if any, is
// returned unchanged.
func (uc *AuthUseCase) LogoutAllByUserID(ctx context.Context, uid string) error {
	if err := uc.tokenRepo.DeleteByUserID(ctx, uid); err != nil {
		return err
	}
	return nil
}
// generateAccessToken signs an HS256 JWT with jwtConfig.Secret.
//
// The claims are uid, client_id, device_id (may be empty), iat (issue time, Unix
// seconds) and exp (iat instant + jwtConfig.AccessExpires, Unix seconds).
func (uc *AuthUseCase) generateAccessToken(uid, clientID, deviceID string) (string, error) {
	// Sample the clock once so iat and exp are derived from the same instant.
	// Two separate time.Now() calls can straddle a second boundary and skew them.
	now := time.Now()
	claims := jwt.MapClaims{
		"uid":       uid,
		"client_id": clientID,
		"device_id": deviceID,
		"exp":       now.Add(uc.jwtConfig.AccessExpires).Unix(),
		"iat":       now.Unix(),
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString([]byte(uc.jwtConfig.Secret))
}
// generateRefreshToken returns an opaque refresh token: 32 bytes read from
// crypto/rand, hex-encoded into a 64-character lowercase string. The only error
// it can return is the one reported by rand.Read.
func (uc *AuthUseCase) generateRefreshToken() (string, error) {
	var raw [32]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}