package usecase

import (
	"context"
	"errors"
	"fmt"
	"time"

	authconfig "gateway/internal/model/auth/config"
	domrepo "gateway/internal/model/auth/domain/repository"
	domusecase "gateway/internal/model/auth/domain/usecase"

	"github.com/golang-jwt/jwt/v4"
	"github.com/google/uuid"
)

// NOTE(review): errb is used throughout this file but does not appear in the
// import list above; presumably it is declared elsewhere in this package —
// confirm.

// tokenUseCase issues, refreshes, parses and revokes HS256-signed JWT pairs.
type tokenUseCase struct {
	// cfg holds the secrets, lifetimes (in seconds) and active key id.
	cfg authconfig.Config
	// revoke is optional; when nil, pairs are not persisted and blacklist
	// checks are skipped.
	revoke domrepo.TokenRevokeStore
}

// TokenUseCaseParam wires TokenUseCase.
type TokenUseCaseParam struct {
	Config authconfig.Config
	Revoke domrepo.TokenRevokeStore
}

// MustTokenUseCase constructs TokenUseCase.
//
// It applies param.Config.Defaults() and panics when the resulting config
// does not report Enabled().
func MustTokenUseCase(param TokenUseCaseParam) domusecase.TokenUseCase {
	cfg := param.Config.Defaults()
	if !cfg.Enabled() {
		panic("auth: JWT secrets are required")
	}
	return &tokenUseCase{cfg: cfg, revoke: param.Revoke}
}

// IssuePair signs a new access/refresh token pair for the tenant and user in
// req.
//
// req must be non-nil with non-empty TenantID and UID. When a revoke store is
// configured, the two token ids are saved as a pair, each with the time left
// until its token expires. The returned ExpiresIn is the configured access
// lifetime and TokenType is "Bearer".
func (uc *tokenUseCase) IssuePair(ctx context.Context, req *domusecase.IssuePairRequest) (*domusecase.TokenPair, error) {
	if req == nil || req.TenantID == "" || req.UID == "" {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}

	access, err := uc.sign(req, domusecase.TokenTypeAccess, uc.cfg.AccessExpire, uc.cfg.AccessSecret)
	if err != nil {
		return nil, errb.SysInternal("sign access token failed").WithCause(err)
	}
	refresh, err := uc.sign(req, domusecase.TokenTypeRefresh, uc.cfg.RefreshExpire, uc.cfg.RefreshSecret)
	if err != nil {
		return nil, errb.SysInternal("sign refresh token failed").WithCause(err)
	}

	if uc.revoke != nil {
		accessTTL := time.Until(access.expiresAt)
		refreshTTL := time.Until(refresh.expiresAt)
		if err := uc.revoke.SavePair(ctx, access.jti, refresh.jti, accessTTL, refreshTTL); err != nil {
			return nil, errb.DBError("save jwt pair failed").WithCause(err)
		}
	}

	return &domusecase.TokenPair{
		AccessToken:  access.raw,
		RefreshToken: refresh.raw,
		ExpiresIn:    uc.cfg.AccessExpire,
		TokenType:    "Bearer",
	}, nil
}

// Refresh exchanges a valid, non-blacklisted refresh token for a new pair and
// revokes the old pair.
//
// Order of store writes matters here: the presented refresh token is
// blacklisted last. Every earlier step (paired-id lookup, blacklisting the
// old access token, deleting the pair record) can fail without consuming the
// refresh token, so a transient store error does not leave the caller with
// neither a new pair nor a usable refresh token.
//
// NOTE(review): the blacklist check and the final blacklist write are
// separate store calls, so two concurrent requests presenting the same
// refresh token may both succeed; closing that window needs an atomic store
// operation, which is not visible from this file.
func (uc *tokenUseCase) Refresh(ctx context.Context, refreshToken string) (*domusecase.TokenPair, error) {
	if refreshToken == "" {
		return nil, errb.InputMissingRequired("refresh_token is required")
	}

	claims, err := uc.parse(refreshToken, domusecase.TokenTypeRefresh, uc.cfg.RefreshSecret)
	if err != nil {
		if errors.Is(err, errInvalidToken) {
			return nil, errb.AuthUnauthorized("invalid refresh token").WithCause(err)
		}
		return nil, errb.SysInternal("parse refresh token failed").WithCause(err)
	}
	if err := uc.ensureNotBlacklisted(ctx, claims.ID); err != nil {
		return nil, err
	}

	pair, err := uc.IssuePair(ctx, &domusecase.IssuePairRequest{
		TenantID: claims.TenantID,
		UID:      claims.UID,
		AuthGen:  claims.AuthGen,
	})
	if err != nil {
		return nil, err
	}
	if uc.revoke == nil {
		return pair, nil
	}

	// Revoke the access token that was issued together with this refresh
	// token, if the store still knows about it.
	accessJTI, err := uc.revoke.GetPairedJTI(ctx, claims.ID)
	if err != nil {
		return nil, errb.DBError("read jwt pair failed").WithCause(err)
	}
	if accessJTI != "" {
		// The access token's own expiry is unknown here, so the full
		// configured access lifetime is used as an upper bound.
		if err := uc.revoke.Blacklist(ctx, accessJTI, time.Duration(uc.cfg.AccessExpire)*time.Second); err != nil {
			return nil, errb.DBError("blacklist access token failed").WithCause(err)
		}
		if err := uc.revoke.DeletePair(ctx, accessJTI, claims.ID); err != nil {
			return nil, errb.DBError("delete jwt pair failed").WithCause(err)
		}
	}

	// Commit point: only now is the presented refresh token consumed.
	if err := uc.revoke.Blacklist(ctx, claims.ID, remainingTTL(claims.expiresAt)); err != nil {
		return nil, errb.DBError("blacklist refresh token failed").WithCause(err)
	}
	return pair, nil
}

// Logout revokes the presented access token and, when the store has one on
// record, the refresh token paired with it.
//
// It requires a configured revoke store and a token that parses as a valid
// access token. The pair record is deleted afterwards.
//
// NOTE(review): DeletePair is called even when no paired refresh id was
// found (refreshJTI == ""); confirm the store treats that as a no-op.
func (uc *tokenUseCase) Logout(ctx context.Context, req *domusecase.LogoutRequest) error {
	if req == nil || req.AccessToken == "" {
		return errb.InputMissingRequired("access token is required")
	}
	if uc.revoke == nil {
		return errb.SysNotImplemented("token revoke store not configured")
	}

	claims, err := uc.parse(req.AccessToken, domusecase.TokenTypeAccess, uc.cfg.AccessSecret)
	if err != nil {
		if errors.Is(err, errInvalidToken) {
			return errb.AuthUnauthorized("invalid access token").WithCause(err)
		}
		return errb.SysInternal("parse access token failed").WithCause(err)
	}

	if err := uc.revoke.Blacklist(ctx, claims.ID, remainingTTL(claims.expiresAt)); err != nil {
		return errb.DBError("blacklist access token failed").WithCause(err)
	}

	refreshJTI, err := uc.revoke.GetPairedJTI(ctx, claims.ID)
	if err != nil {
		return errb.DBError("read jwt pair failed").WithCause(err)
	}
	if refreshJTI != "" {
		// The refresh token's own expiry is unknown here, so the full
		// configured refresh lifetime is used as an upper bound.
		if err := uc.revoke.Blacklist(ctx, refreshJTI, time.Duration(uc.cfg.RefreshExpire)*time.Second); err != nil {
			return errb.DBError("blacklist refresh token failed").WithCause(err)
		}
	}

	if err := uc.revoke.DeletePair(ctx, claims.ID, refreshJTI); err != nil {
		return errb.DBError("delete jwt pair failed").WithCause(err)
	}
	return nil
}

// ParseAccessToken verifies accessToken as an access token, checks that its
// id is not blacklisted, and returns the claims callers need.
func (uc *tokenUseCase) ParseAccessToken(ctx context.Context, accessToken string) (*domusecase.AccessClaims, error) {
	if accessToken == "" {
		return nil, errb.AuthUnauthorized("missing access token")
	}

	claims, err := uc.parse(accessToken, domusecase.TokenTypeAccess, uc.cfg.AccessSecret)
	if err != nil {
		if errors.Is(err, errInvalidToken) {
			return nil, errb.AuthUnauthorized("invalid access token").WithCause(err)
		}
		return nil, errb.SysInternal("parse access token failed").WithCause(err)
	}
	if err := uc.ensureNotBlacklisted(ctx, claims.ID); err != nil {
		return nil, err
	}

	return &domusecase.AccessClaims{
		TenantID: claims.TenantID,
		UID:      claims.UID,
		AuthGen:  claims.AuthGen,
		JTI:      claims.ID,
	}, nil
}

// ensureNotBlacklisted returns an unauthorized error when jti is on the
// blacklist. It is a no-op when no revoke store is configured or jti is
// empty.
func (uc *tokenUseCase) ensureNotBlacklisted(ctx context.Context, jti string) error {
	if uc.revoke == nil || jti == "" {
		return nil
	}
	blacklisted, err := uc.revoke.IsBlacklisted(ctx, jti)
	if err != nil {
		return errb.DBError("check jwt blacklist failed").WithCause(err)
	}
	if blacklisted {
		return errb.AuthUnauthorized("token revoked")
	}
	return nil
}

// errInvalidToken marks every parse failure: bad signature or structure,
// wrong token type, or missing tenant/user claims.
var errInvalidToken = errors.New("auth: invalid token")

// jwtClaims is the JWT payload written by sign and read by parse.
type jwtClaims struct {
	TenantID string `json:"tenant_id"`
	UID      string `json:"uid"`
	Typ      string `json:"typ"`
	AuthGen  int64  `json:"auth_gen"`
	jwt.RegisteredClaims
}

// parsedClaims is the validated subset of jwtClaims handed back by parse.
type parsedClaims struct {
	TenantID string
	UID      string
	AuthGen  int64
	// ID is the token's jti.
	ID string
	// expiresAt is the zero time when the token carried no exp claim.
	expiresAt time.Time
}

// signedToken is the result of sign: the encoded token plus the id and
// expiry that were embedded in it.
type signedToken struct {
	raw       string
	jti       string
	expiresAt time.Time
}

// sign builds and HS256-signs a token of the given type for req.
//
// The token gets a fresh UUID as its id, is issued at the current UTC time
// and expires expireSec seconds later. The configured ActiveKID is written to
// the "kid" header.
func (uc *tokenUseCase) sign(req *domusecase.IssuePairRequest, typ domusecase.TokenType, expireSec int64, secret string) (*signedToken, error) {
	now := time.Now().UTC()
	expiresAt := now.Add(time.Duration(expireSec) * time.Second)
	jti := uuid.NewString()

	claims := jwtClaims{
		TenantID: req.TenantID,
		UID:      req.UID,
		Typ:      string(typ),
		AuthGen:  req.AuthGen,
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        jti,
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(expiresAt),
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	token.Header["kid"] = uc.cfg.ActiveKID

	raw, err := token.SignedString([]byte(secret))
	if err != nil {
		return nil, err
	}
	return &signedToken{raw: raw, jti: jti, expiresAt: expiresAt}, nil
}

// parse verifies raw against secret and checks that it is an HS256 token of
// type want with non-empty tenant and user claims.
//
// Every failure is reported as (or wraps) errInvalidToken.
//
// NOTE(review): the "kid" header is not inspected, and neither a jti nor an
// exp claim is required here; confirm that is intended.
func (uc *tokenUseCase) parse(raw string, want domusecase.TokenType, secret string) (*parsedClaims, error) {
	parsed, err := jwt.ParseWithClaims(raw, &jwtClaims{}, func(t *jwt.Token) (any, error) {
		// Pin the algorithm so a token declaring another method is rejected.
		if t.Method != jwt.SigningMethodHS256 {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return []byte(secret), nil
	})
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errInvalidToken, err)
	}

	claims, ok := parsed.Claims.(*jwtClaims)
	if !ok || !parsed.Valid {
		return nil, errInvalidToken
	}
	if claims.Typ != string(want) {
		return nil, errInvalidToken
	}
	if claims.TenantID == "" || claims.UID == "" {
		return nil, errInvalidToken
	}

	expiresAt := time.Time{}
	if claims.ExpiresAt != nil {
		expiresAt = claims.ExpiresAt.Time
	}
	return &parsedClaims{
		TenantID:  claims.TenantID,
		UID:       claims.UID,
		AuthGen:   claims.AuthGen,
		ID:        claims.ID,
		expiresAt: expiresAt,
	}, nil
}

// remainingTTL returns the time left until expiresAt, with a floor of one
// second. A zero expiresAt also yields one second.
func remainingTTL(expiresAt time.Time) time.Duration {
	if expiresAt.IsZero() {
		return time.Second
	}
	ttl := time.Until(expiresAt)
	if ttl < time.Second {
		return time.Second
	}
	return ttl
}