package usecase import ( "backend/pkg/library/errs/code" "backend/pkg/permission/utils" "context" "crypto/rand" "encoding/hex" "time" "backend/pkg/library/errs" "backend/pkg/permission/domain/config" "backend/pkg/permission/domain/entity" "backend/pkg/permission/domain/repository" "backend/pkg/permission/domain/usecase" "github.com/golang-jwt/jwt/v5" ) type AuthUseCaseParam struct { ClientRepo repository.ClientRepository TokenRepo repository.TokenRepository JWTConfig config.JWTConfig } type AuthUseCase struct { clientRepo repository.ClientRepository tokenRepo repository.TokenRepository jwtConfig config.JWTConfig } // MustAuthUseCase 創建認證用例實例 func MustAuthUseCase(param AuthUseCaseParam) usecase.AuthUseCase { return &AuthUseCase{ clientRepo: param.ClientRepo, tokenRepo: param.TokenRepo, jwtConfig: param.JWTConfig, } } func (uc *AuthUseCase) CreateToken(ctx context.Context, req usecase.CreateTokenRequest) (*usecase.TokenResponse, error) { // 驗證客戶端 client, err := uc.clientRepo.GetByClientID(ctx, req.ClientID) if err != nil { return nil, err } if !utils.IsActive(client.Status) { return nil, errs.UserSuspended(code.CloudEPPermission, "failed to get token since user has been suspended") } // 根據授權類型處理 var uid string switch req.GrantType { case "client_credentials": uid = "client_" + req.ClientID case "password": if req.Username == "" || req.Password == "" { return nil, errs.InvalidCredentials() } // 這裡應該驗證用戶名密碼,簡化處理 uid = req.Username default: return nil, errs.InvalidFormat("unsupported grant type: " + req.GrantType) } // 生成令牌 accessToken, err := uc.generateAccessToken(uid, req.ClientID, req.DeviceID) if err != nil { return nil, errs.SystemInternal("failed to generate access token: " + err.Error()) } refreshToken, err := uc.generateRefreshToken() if err != nil { return nil, errs.SystemInternal("failed to generate refresh token: " + err.Error()) } // 保存令牌 token := &entity.Token{ UID: uid, ClientID: req.ClientID, AccessToken: accessToken, RefreshToken: refreshToken, DeviceID: req.DeviceID, ExpiresAt: time.Now().Add(uc.jwtConfig.AccessExpires), } if err := uc.tokenRepo.Create(ctx, token); err != nil { return nil, err } return &usecase.TokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: "Bearer", ExpiresIn: int64(uc.jwtConfig.AccessExpires.Seconds()), }, nil } func (uc *AuthUseCase) RefreshToken(ctx context.Context, refreshToken string) (*usecase.TokenResponse, error) { // 查找刷新令牌 token, err := uc.tokenRepo.GetByRefreshToken(ctx, refreshToken) if err != nil { return nil, err } if token.IsExpired() { return nil, errs.TokenExpired() } // 生成新的訪問令牌 accessToken, err := uc.generateAccessToken(token.UID, token.ClientID, token.DeviceID) if err != nil { return nil, errs.SystemInternal("failed to generate access token: " + err.Error()) } // 更新令牌 token.AccessToken = accessToken token.ExpiresAt = time.Now().Add(uc.jwtConfig.AccessExpires) if err := uc.tokenRepo.Update(ctx, token); err != nil { return nil, err } return &usecase.TokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: "Bearer", ExpiresIn: int64(uc.jwtConfig.AccessExpires.Seconds()), }, nil } func (uc *AuthUseCase) ValidateToken(ctx context.Context, accessToken string) (*usecase.TokenClaims, error) { // 解析JWT令牌 token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, errs.TokenInvalid() } return []byte(uc.jwtConfig.Secret), nil }) if err != nil { return nil, errs.TokenInvalid() } if !token.Valid { return nil, errs.TokenInvalid() } claims, ok := token.Claims.(jwt.MapClaims) if !ok { return nil, errs.TokenInvalid() } uid, ok := claims["uid"].(string) if !ok { return nil, errs.TokenInvalid() } clientID, ok := claims["client_id"].(string) if !ok { return nil, errs.TokenInvalid() } deviceID, _ := claims["device_id"].(string) return &usecase.TokenClaims{ UID: uid, ClientID: clientID, DeviceID: deviceID, }, nil } func (uc *AuthUseCase) Logout(ctx context.Context, accessToken string) error { // 查找並刪除令牌 token, err := uc.tokenRepo.GetByAccessToken(ctx, accessToken) if err != nil { return err } return uc.tokenRepo.Delete(ctx, token.ID) } func (uc *AuthUseCase) LogoutAllByUserID(ctx context.Context, uid string) error { return uc.tokenRepo.DeleteByUserID(ctx, uid) } func (uc *AuthUseCase) generateAccessToken(uid, clientID, deviceID string) (string, error) { claims := jwt.MapClaims{ "uid": uid, "client_id": clientID, "device_id": deviceID, "exp": time.Now().Add(uc.jwtConfig.AccessExpires).Unix(), "iat": time.Now().Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString([]byte(uc.jwtConfig.Secret)) } func (uc *AuthUseCase) generateRefreshToken() (string, error) { bytes := make([]byte, 32) if _, err := rand.Read(bytes); err != nil { return "", err } return hex.EncodeToString(bytes), nil }