package usecase import ( "context" "errors" "fmt" "time" "haixun-backend/internal/config" app "haixun-backend/internal/library/errors" "haixun-backend/internal/library/errors/code" domrepo "haixun-backend/internal/model/auth/domain/repository" domusecase "haixun-backend/internal/model/auth/domain/usecase" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" ) type tokenUseCase struct { cfg config.AuthConf revoke domrepo.TokenRevokeStore } func NewTokenUseCase(cfg config.AuthConf, revoke domrepo.TokenRevokeStore) domusecase.TokenUseCase { cfg = normalizeConfig(cfg) return &tokenUseCase{cfg: cfg, revoke: revoke} } func (u *tokenUseCase) IssuePair(ctx context.Context, req domusecase.IssuePairRequest) (*domusecase.TokenPair, error) { if req.TenantID == "" || req.UID == "" { return nil, app.For(code.Auth).InputMissingRequired("tenant_id and uid are required") } access, err := u.sign(req, domusecase.TokenTypeAccess, u.cfg.AccessExpireSeconds, u.cfg.AccessSecret) if err != nil { return nil, app.For(code.Auth).SysInternal("sign access token failed").WithCause(err) } refresh, err := u.sign(req, domusecase.TokenTypeRefresh, u.cfg.RefreshExpireSeconds, u.cfg.RefreshSecret) if err != nil { return nil, app.For(code.Auth).SysInternal("sign refresh token failed").WithCause(err) } if u.revoke != nil { if err := u.revoke.SavePair(ctx, access.jti, refresh.jti, time.Until(access.expiresAt), time.Until(refresh.expiresAt)); err != nil { return nil, app.For(code.Auth).DBError("save jwt pair failed").WithCause(err) } } return &domusecase.TokenPair{ AccessToken: access.raw, RefreshToken: refresh.raw, ExpiresIn: u.cfg.AccessExpireSeconds, TokenType: "Bearer", UID: req.UID, }, nil } func (u *tokenUseCase) Refresh(ctx context.Context, refreshToken string) (*domusecase.TokenPair, error) { if refreshToken == "" { return nil, app.For(code.Auth).InputMissingRequired("refresh_token is required") } claims, err := u.parse(refreshToken, domusecase.TokenTypeRefresh, u.cfg.RefreshSecret) if err != nil { return nil, app.For(code.Auth).AuthUnauthorized("invalid refresh token").WithCause(err) } if err := u.ensureNotBlacklisted(ctx, claims.ID); err != nil { return nil, err } pair, err := u.IssuePair(ctx, domusecase.IssuePairRequest{ TenantID: claims.TenantID, UID: claims.UID, AuthGen: claims.AuthGen, }) if err != nil { return nil, err } if u.revoke != nil { _ = u.revoke.Blacklist(ctx, claims.ID, remainingTTL(claims.expiresAt)) if accessJTI, err := u.revoke.GetPairedJTI(ctx, claims.ID); err == nil && accessJTI != "" { _ = u.revoke.Blacklist(ctx, accessJTI, time.Duration(u.cfg.AccessExpireSeconds)*time.Second) _ = u.revoke.DeletePair(ctx, accessJTI, claims.ID) } } return pair, nil } func (u *tokenUseCase) Logout(ctx context.Context, req domusecase.LogoutRequest) error { if req.AccessToken == "" { return app.For(code.Auth).InputMissingRequired("access token is required") } if u.revoke == nil { return nil } claims, err := u.parse(req.AccessToken, domusecase.TokenTypeAccess, u.cfg.AccessSecret) if err != nil { return app.For(code.Auth).AuthUnauthorized("invalid access token").WithCause(err) } if err := u.revoke.Blacklist(ctx, claims.ID, remainingTTL(claims.expiresAt)); err != nil { return app.For(code.Auth).DBError("blacklist access token failed").WithCause(err) } refreshJTI, err := u.revoke.GetPairedJTI(ctx, claims.ID) if err != nil { return app.For(code.Auth).DBError("read jwt pair failed").WithCause(err) } if refreshJTI != "" { _ = u.revoke.Blacklist(ctx, refreshJTI, time.Duration(u.cfg.RefreshExpireSeconds)*time.Second) } return u.revoke.DeletePair(ctx, claims.ID, refreshJTI) } func (u *tokenUseCase) ParseAccessToken(ctx context.Context, accessToken string) (*domusecase.AccessClaims, error) { if accessToken == "" { return nil, app.For(code.Auth).AuthUnauthorized("missing access token") } claims, err := u.parse(accessToken, domusecase.TokenTypeAccess, u.cfg.AccessSecret) if err != nil { return nil, app.For(code.Auth).AuthUnauthorized("invalid access token").WithCause(err) } if err := u.ensureNotBlacklisted(ctx, claims.ID); err != nil { return nil, err } return &domusecase.AccessClaims{ TenantID: claims.TenantID, UID: claims.UID, AuthGen: claims.AuthGen, JTI: claims.ID, }, nil } func (u *tokenUseCase) ensureNotBlacklisted(ctx context.Context, jti string) error { if u.revoke == nil || jti == "" { return nil } blacklisted, err := u.revoke.IsBlacklisted(ctx, jti) if err != nil { return app.For(code.Auth).DBError("check jwt blacklist failed").WithCause(err) } if blacklisted { return app.For(code.Auth).AuthUnauthorized("token revoked") } return nil } type jwtClaims struct { TenantID string `json:"tenant_id"` UID string `json:"uid"` Typ string `json:"typ"` AuthGen int64 `json:"auth_gen"` jwt.RegisteredClaims } type parsedClaims struct { TenantID string UID string AuthGen int64 ID string expiresAt time.Time } type signedToken struct { raw string jti string expiresAt time.Time } func (u *tokenUseCase) sign(req domusecase.IssuePairRequest, typ domusecase.TokenType, expireSec int64, secret string) (*signedToken, error) { now := time.Now().UTC() expiresAt := now.Add(time.Duration(expireSec) * time.Second) jti := uuid.NewString() claims := jwtClaims{ TenantID: req.TenantID, UID: req.UID, Typ: string(typ), AuthGen: req.AuthGen, RegisteredClaims: jwt.RegisteredClaims{ ID: jti, IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(expiresAt), }, } raw, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret)) if err != nil { return nil, err } return &signedToken{raw: raw, jti: jti, expiresAt: expiresAt}, nil } func (u *tokenUseCase) parse(raw string, want domusecase.TokenType, secret string) (*parsedClaims, error) { parsed, err := jwt.ParseWithClaims(raw, &jwtClaims{}, func(t *jwt.Token) (any, error) { if t.Method != jwt.SigningMethodHS256 { return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) } return []byte(secret), nil }) if err != nil { return nil, fmt.Errorf("%w: %w", errInvalidToken, err) } claims, ok := parsed.Claims.(*jwtClaims) if !ok || !parsed.Valid || claims.Typ != string(want) || claims.TenantID == "" || claims.UID == "" { return nil, errInvalidToken } expiresAt := time.Time{} if claims.ExpiresAt != nil { expiresAt = claims.ExpiresAt.Time } return &parsedClaims{TenantID: claims.TenantID, UID: claims.UID, AuthGen: claims.AuthGen, ID: claims.ID, expiresAt: expiresAt}, nil } func normalizeConfig(cfg config.AuthConf) config.AuthConf { if cfg.AccessExpireSeconds <= 0 { cfg.AccessExpireSeconds = 900 } if cfg.RefreshExpireSeconds <= 0 { cfg.RefreshExpireSeconds = 2592000 } if cfg.AccessSecret == "" { cfg.AccessSecret = "haixun-dev-access-secret-change-me" } if cfg.RefreshSecret == "" { cfg.RefreshSecret = "haixun-dev-refresh-secret-change-me" } return cfg } func remainingTTL(expiresAt time.Time) time.Duration { if expiresAt.IsZero() { return time.Second } ttl := time.Until(expiresAt) if ttl < time.Second { return time.Second } return ttl } var errInvalidToken = errors.New("auth: invalid token")