234 lines
7.2 KiB
Go
234 lines
7.2 KiB
Go
|
|
package usecase
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"haixun-backend/internal/config"
|
||
|
|
app "haixun-backend/internal/library/errors"
|
||
|
|
"haixun-backend/internal/library/errors/code"
|
||
|
|
domrepo "haixun-backend/internal/model/auth/domain/repository"
|
||
|
|
domusecase "haixun-backend/internal/model/auth/domain/usecase"
|
||
|
|
|
||
|
|
"github.com/golang-jwt/jwt/v4"
|
||
|
|
"github.com/google/uuid"
|
||
|
|
)
|
||
|
|
|
||
|
|
type tokenUseCase struct {
|
||
|
|
cfg config.AuthConf
|
||
|
|
revoke domrepo.TokenRevokeStore
|
||
|
|
}
|
||
|
|
|
||
|
|
func NewTokenUseCase(cfg config.AuthConf, revoke domrepo.TokenRevokeStore) domusecase.TokenUseCase {
|
||
|
|
cfg = normalizeConfig(cfg)
|
||
|
|
return &tokenUseCase{cfg: cfg, revoke: revoke}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (u *tokenUseCase) IssuePair(ctx context.Context, req domusecase.IssuePairRequest) (*domusecase.TokenPair, error) {
|
||
|
|
if req.TenantID == "" || req.UID == "" {
|
||
|
|
return nil, app.For(code.Auth).InputMissingRequired("tenant_id and uid are required")
|
||
|
|
}
|
||
|
|
access, err := u.sign(req, domusecase.TokenTypeAccess, u.cfg.AccessExpireSeconds, u.cfg.AccessSecret)
|
||
|
|
if err != nil {
|
||
|
|
return nil, app.For(code.Auth).SysInternal("sign access token failed").WithCause(err)
|
||
|
|
}
|
||
|
|
refresh, err := u.sign(req, domusecase.TokenTypeRefresh, u.cfg.RefreshExpireSeconds, u.cfg.RefreshSecret)
|
||
|
|
if err != nil {
|
||
|
|
return nil, app.For(code.Auth).SysInternal("sign refresh token failed").WithCause(err)
|
||
|
|
}
|
||
|
|
if u.revoke != nil {
|
||
|
|
if err := u.revoke.SavePair(ctx, access.jti, refresh.jti, time.Until(access.expiresAt), time.Until(refresh.expiresAt)); err != nil {
|
||
|
|
return nil, app.For(code.Auth).DBError("save jwt pair failed").WithCause(err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return &domusecase.TokenPair{
|
||
|
|
AccessToken: access.raw,
|
||
|
|
RefreshToken: refresh.raw,
|
||
|
|
ExpiresIn: u.cfg.AccessExpireSeconds,
|
||
|
|
TokenType: "Bearer",
|
||
|
|
UID: req.UID,
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (u *tokenUseCase) Refresh(ctx context.Context, refreshToken string) (*domusecase.TokenPair, error) {
|
||
|
|
if refreshToken == "" {
|
||
|
|
return nil, app.For(code.Auth).InputMissingRequired("refresh_token is required")
|
||
|
|
}
|
||
|
|
claims, err := u.parse(refreshToken, domusecase.TokenTypeRefresh, u.cfg.RefreshSecret)
|
||
|
|
if err != nil {
|
||
|
|
return nil, app.For(code.Auth).AuthUnauthorized("invalid refresh token").WithCause(err)
|
||
|
|
}
|
||
|
|
if err := u.ensureNotBlacklisted(ctx, claims.ID); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
pair, err := u.IssuePair(ctx, domusecase.IssuePairRequest{
|
||
|
|
TenantID: claims.TenantID,
|
||
|
|
UID: claims.UID,
|
||
|
|
AuthGen: claims.AuthGen,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
if u.revoke != nil {
|
||
|
|
_ = u.revoke.Blacklist(ctx, claims.ID, remainingTTL(claims.expiresAt))
|
||
|
|
if accessJTI, err := u.revoke.GetPairedJTI(ctx, claims.ID); err == nil && accessJTI != "" {
|
||
|
|
_ = u.revoke.Blacklist(ctx, accessJTI, time.Duration(u.cfg.AccessExpireSeconds)*time.Second)
|
||
|
|
_ = u.revoke.DeletePair(ctx, accessJTI, claims.ID)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return pair, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (u *tokenUseCase) Logout(ctx context.Context, req domusecase.LogoutRequest) error {
|
||
|
|
if req.AccessToken == "" {
|
||
|
|
return app.For(code.Auth).InputMissingRequired("access token is required")
|
||
|
|
}
|
||
|
|
if u.revoke == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
claims, err := u.parse(req.AccessToken, domusecase.TokenTypeAccess, u.cfg.AccessSecret)
|
||
|
|
if err != nil {
|
||
|
|
return app.For(code.Auth).AuthUnauthorized("invalid access token").WithCause(err)
|
||
|
|
}
|
||
|
|
if err := u.revoke.Blacklist(ctx, claims.ID, remainingTTL(claims.expiresAt)); err != nil {
|
||
|
|
return app.For(code.Auth).DBError("blacklist access token failed").WithCause(err)
|
||
|
|
}
|
||
|
|
refreshJTI, err := u.revoke.GetPairedJTI(ctx, claims.ID)
|
||
|
|
if err != nil {
|
||
|
|
return app.For(code.Auth).DBError("read jwt pair failed").WithCause(err)
|
||
|
|
}
|
||
|
|
if refreshJTI != "" {
|
||
|
|
_ = u.revoke.Blacklist(ctx, refreshJTI, time.Duration(u.cfg.RefreshExpireSeconds)*time.Second)
|
||
|
|
}
|
||
|
|
return u.revoke.DeletePair(ctx, claims.ID, refreshJTI)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (u *tokenUseCase) ParseAccessToken(ctx context.Context, accessToken string) (*domusecase.AccessClaims, error) {
|
||
|
|
if accessToken == "" {
|
||
|
|
return nil, app.For(code.Auth).AuthUnauthorized("missing access token")
|
||
|
|
}
|
||
|
|
claims, err := u.parse(accessToken, domusecase.TokenTypeAccess, u.cfg.AccessSecret)
|
||
|
|
if err != nil {
|
||
|
|
return nil, app.For(code.Auth).AuthUnauthorized("invalid access token").WithCause(err)
|
||
|
|
}
|
||
|
|
if err := u.ensureNotBlacklisted(ctx, claims.ID); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return &domusecase.AccessClaims{
|
||
|
|
TenantID: claims.TenantID,
|
||
|
|
UID: claims.UID,
|
||
|
|
AuthGen: claims.AuthGen,
|
||
|
|
JTI: claims.ID,
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (u *tokenUseCase) ensureNotBlacklisted(ctx context.Context, jti string) error {
|
||
|
|
if u.revoke == nil || jti == "" {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
blacklisted, err := u.revoke.IsBlacklisted(ctx, jti)
|
||
|
|
if err != nil {
|
||
|
|
return app.For(code.Auth).DBError("check jwt blacklist failed").WithCause(err)
|
||
|
|
}
|
||
|
|
if blacklisted {
|
||
|
|
return app.For(code.Auth).AuthUnauthorized("token revoked")
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
type jwtClaims struct {
|
||
|
|
TenantID string `json:"tenant_id"`
|
||
|
|
UID string `json:"uid"`
|
||
|
|
Typ string `json:"typ"`
|
||
|
|
AuthGen int64 `json:"auth_gen"`
|
||
|
|
jwt.RegisteredClaims
|
||
|
|
}
|
||
|
|
|
||
|
|
// parsedClaims is the verified subset of a token's claims that the use case
// needs once parse has accepted the token.
type parsedClaims struct {
	// TenantID and UID identify who the token was issued for.
	TenantID, UID string
	// AuthGen is copied verbatim from the auth_gen claim.
	AuthGen int64
	// ID is the token's JTI, used as the blacklist and pairing key.
	ID string
	// expiresAt is the exp claim (zero if parse did not find one).
	expiresAt time.Time
}
|
||
|
|
|
||
|
|
// signedToken is what sign returns: the serialized JWT plus the metadata the
// caller needs to record it in the revocation store.
type signedToken struct {
	raw, jti  string    // raw is the compact signed JWT; jti is its ID claim
	expiresAt time.Time // exp claim; IssuePair derives store TTLs from it
}
|
||
|
|
|
||
|
|
func (u *tokenUseCase) sign(req domusecase.IssuePairRequest, typ domusecase.TokenType, expireSec int64, secret string) (*signedToken, error) {
|
||
|
|
now := time.Now().UTC()
|
||
|
|
expiresAt := now.Add(time.Duration(expireSec) * time.Second)
|
||
|
|
jti := uuid.NewString()
|
||
|
|
claims := jwtClaims{
|
||
|
|
TenantID: req.TenantID,
|
||
|
|
UID: req.UID,
|
||
|
|
Typ: string(typ),
|
||
|
|
AuthGen: req.AuthGen,
|
||
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
||
|
|
ID: jti,
|
||
|
|
IssuedAt: jwt.NewNumericDate(now),
|
||
|
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||
|
|
},
|
||
|
|
}
|
||
|
|
raw, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return &signedToken{raw: raw, jti: jti, expiresAt: expiresAt}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (u *tokenUseCase) parse(raw string, want domusecase.TokenType, secret string) (*parsedClaims, error) {
|
||
|
|
parsed, err := jwt.ParseWithClaims(raw, &jwtClaims{}, func(t *jwt.Token) (any, error) {
|
||
|
|
if t.Method != jwt.SigningMethodHS256 {
|
||
|
|
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||
|
|
}
|
||
|
|
return []byte(secret), nil
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: %w", errInvalidToken, err)
|
||
|
|
}
|
||
|
|
claims, ok := parsed.Claims.(*jwtClaims)
|
||
|
|
if !ok || !parsed.Valid || claims.Typ != string(want) || claims.TenantID == "" || claims.UID == "" {
|
||
|
|
return nil, errInvalidToken
|
||
|
|
}
|
||
|
|
expiresAt := time.Time{}
|
||
|
|
if claims.ExpiresAt != nil {
|
||
|
|
expiresAt = claims.ExpiresAt.Time
|
||
|
|
}
|
||
|
|
return &parsedClaims{TenantID: claims.TenantID, UID: claims.UID, AuthGen: claims.AuthGen, ID: claims.ID, expiresAt: expiresAt}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func normalizeConfig(cfg config.AuthConf) config.AuthConf {
|
||
|
|
if cfg.AccessExpireSeconds <= 0 {
|
||
|
|
cfg.AccessExpireSeconds = 900
|
||
|
|
}
|
||
|
|
if cfg.RefreshExpireSeconds <= 0 {
|
||
|
|
cfg.RefreshExpireSeconds = 2592000
|
||
|
|
}
|
||
|
|
if cfg.AccessSecret == "" {
|
||
|
|
cfg.AccessSecret = "haixun-dev-access-secret-change-me"
|
||
|
|
}
|
||
|
|
if cfg.RefreshSecret == "" {
|
||
|
|
cfg.RefreshSecret = "haixun-dev-refresh-secret-change-me"
|
||
|
|
}
|
||
|
|
return cfg
|
||
|
|
}
|
||
|
|
|
||
|
|
// remainingTTL returns the time left until expiresAt, with a floor of one
// second. A zero expiresAt (unknown expiry) also yields one second. The floor
// presumably exists so the revocation store always receives a positive TTL.
func remainingTTL(expiresAt time.Time) time.Duration {
	const floor = time.Second
	if !expiresAt.IsZero() {
		if left := time.Until(expiresAt); left >= floor {
			return left
		}
	}
	return floor
}
|
||
|
|
|
||
|
|
var (
	// errInvalidToken is the sentinel that parse returns, directly or
	// wrapped, for any token that fails verification. Wrapped forms can be
	// matched with errors.Is.
	errInvalidToken = errors.New("auth: invalid token")
)
|