package usecase import ( "context" "errors" "strings" "time" "golang.org/x/crypto/bcrypt" "gateway/internal/library/crypto" memberconfig "gateway/internal/model/member/config" member "gateway/internal/model/member/domain" domrepo "gateway/internal/model/member/domain/repository" domusecase "gateway/internal/model/member/domain/usecase" "gateway/internal/model/member/totp" ) // TOTPUseCaseParam wires TOTPUseCase dependencies. Cipher is mandatory; the // factory rejects construction when the KEK is missing or invalid. type TOTPUseCaseParam struct { Profile domrepo.TOTPProfileRepository Enroll domrepo.TOTPEnrollStore Replay domrepo.TOTPReplayStore Cipher *crypto.Cipher Config memberconfig.Config Now func() time.Time } // MustTOTPUseCase constructs a TOTPUseCase. All collaborators must be non-nil // (the wiring layer is responsible for validating SecretKEK before calling). func MustTOTPUseCase(param TOTPUseCaseParam) domusecase.TOTPUseCase { if param.Profile == nil { panic("member: totp profile repository is required") } if param.Enroll == nil { panic("member: totp enroll store is required") } if param.Replay == nil { panic("member: totp replay store is required") } if param.Cipher == nil { panic("member: totp cipher is required") } now := param.Now if now == nil { now = time.Now } return &totpUseCase{ profile: param.Profile, enroll: param.Enroll, replay: param.Replay, cipher: param.Cipher, config: param.Config.Defaults(), now: now, } } type totpUseCase struct { profile domrepo.TOTPProfileRepository enroll domrepo.TOTPEnrollStore replay domrepo.TOTPReplayStore cipher *crypto.Cipher config memberconfig.Config now func() time.Time } func (uc *totpUseCase) StartEnroll(ctx context.Context, tenantID, uid, account string) (*domusecase.EnrollStartDTO, error) { if tenantID == "" || uid == "" { return nil, errb.InputMissingRequired("tenant_id and uid are required") } rec, err := uc.profile.Get(ctx, tenantID, uid) if err != nil { return nil, errb.SysInternal("read totp profile failed").WithCause(err) } if rec != nil && rec.Enrolled { return nil, errb.ResAlreadyExist("totp already enrolled").WithCause(member.ErrTOTPAlreadyEnroll) } secret, err := totp.GenerateSecret() if err != nil { return nil, errb.SysInternal("totp secret generation failed").WithCause(err) } cipherBlob, err := uc.cipher.Encrypt(secret) if err != nil { return nil, errb.SysInternal("totp secret encrypt failed").WithCause(err) } ttl := time.Duration(uc.config.TOTP.EnrollTTLSeconds) * time.Second if err := uc.enroll.Save(ctx, tenantID, uid, cipherBlob, ttl); err != nil { return nil, errb.SysInternal("totp enroll persist failed").WithCause(err) } accountLabel := account if accountLabel == "" { accountLabel = uid } otpURL, err := totp.BuildOtpauthURL(totp.OtpauthURLInput{ Issuer: uc.config.TOTP.Issuer, Account: accountLabel, Secret: secret, Algorithm: uc.config.TOTP.Algorithm, Digits: uc.config.TOTP.Digits, Period: time.Duration(uc.config.TOTP.PeriodSeconds) * time.Second, }) if err != nil { return nil, errb.SysInternal("totp otpauth url build failed").WithCause(err) } return &domusecase.EnrollStartDTO{ OtpauthURL: otpURL, Issuer: uc.config.TOTP.Issuer, Account: accountLabel, Digits: uc.config.TOTP.Digits, PeriodSec: uc.config.TOTP.PeriodSeconds, ExpiresIn: uc.config.TOTP.EnrollTTLSeconds, }, nil } func (uc *totpUseCase) ConfirmEnroll(ctx context.Context, tenantID, uid, code string) ([]string, error) { if tenantID == "" || uid == "" || code == "" { return nil, errb.InputMissingRequired("tenant_id, uid and code are required") } rec, err := uc.profile.Get(ctx, tenantID, uid) if err != nil { return nil, errb.SysInternal("read totp profile failed").WithCause(err) } if rec != nil && rec.Enrolled { return nil, errb.ResAlreadyExist("totp already enrolled").WithCause(member.ErrTOTPAlreadyEnroll) } cipherBlob, err := uc.enroll.Get(ctx, tenantID, uid) if err != nil { if errors.Is(err, member.ErrTOTPEnrollMissing) { return nil, errb.ResNotFound("totp enroll", uid).WithCause(err) } return nil, errb.SysInternal("read totp enroll failed").WithCause(err) } secret, err := uc.cipher.Decrypt(cipherBlob) if err != nil { return nil, errb.SysInternal("totp secret decrypt failed").WithCause(err) } if _, ok := totp.Verify(secret, code, uc.now(), uc.period(), uc.config.TOTP.Digits, uc.config.TOTP.Window); !ok { return nil, errb.AuthForbidden("invalid totp code").WithCause(member.ErrTOTPInvalidCode) } plainCodes, hashes, err := uc.generateBackupCodes() if err != nil { return nil, err } if err := uc.profile.Save(ctx, tenantID, uid, &domrepo.TOTPProfileRecord{ Enrolled: true, SecretCipher: cipherBlob, BackupCodesHash: hashes, EnrolledAt: uc.now().UnixMilli(), }); err != nil { return nil, errb.SysInternal("persist totp profile failed").WithCause(err) } if delErr := uc.enroll.Delete(ctx, tenantID, uid); delErr != nil { return nil, errb.SysInternal("clear totp enroll failed").WithCause(delErr) } return plainCodes, nil } func (uc *totpUseCase) VerifyCode(ctx context.Context, tenantID, uid, code string) error { if tenantID == "" || uid == "" || code == "" { return errb.InputMissingRequired("tenant_id, uid and code are required") } rec, err := uc.profile.Get(ctx, tenantID, uid) if err != nil { return errb.SysInternal("read totp profile failed").WithCause(err) } if rec == nil || !rec.Enrolled { return errb.ResInvalidState("totp not enrolled").WithCause(member.ErrTOTPNotEnrolled) } secret, err := uc.cipher.Decrypt(rec.SecretCipher) if err != nil { return errb.SysInternal("totp secret decrypt failed").WithCause(err) } now := uc.now() cleanCode := strings.TrimSpace(code) digits := uc.config.TOTP.Digits if len(cleanCode) == digits { if step, ok := totp.Verify(secret, cleanCode, now, uc.period(), digits, uc.config.TOTP.Window); ok { ttl := time.Duration(uc.config.TOTP.ReplayTTLSeconds) * time.Second fresh, markErr := uc.replay.MarkUsed(ctx, tenantID, uid, step, ttl) if markErr != nil { return errb.SysInternal("totp replay mark failed").WithCause(markErr) } if !fresh { return errb.AuthForbidden("totp code already used").WithCause(member.ErrTOTPCodeReplay) } return nil } } if uc.tryBackupCode(ctx, tenantID, uid, rec, cleanCode) { return nil } return errb.AuthForbidden("invalid totp code").WithCause(member.ErrTOTPInvalidCode) } func (uc *totpUseCase) tryBackupCode(ctx context.Context, tenantID, uid string, rec *domrepo.TOTPProfileRecord, code string) bool { for _, hash := range rec.BackupCodesHash { if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(code)); err == nil { if _, err := uc.profile.ConsumeBackupCode(ctx, tenantID, uid, hash); err != nil { return false } return true } } return false } func (uc *totpUseCase) Disable(ctx context.Context, tenantID, uid string) error { if tenantID == "" || uid == "" { return errb.InputMissingRequired("tenant_id and uid are required") } if err := uc.profile.Clear(ctx, tenantID, uid); err != nil { return errb.SysInternal("clear totp profile failed").WithCause(err) } if err := uc.enroll.Delete(ctx, tenantID, uid); err != nil { return errb.SysInternal("clear totp enroll failed").WithCause(err) } return nil } func (uc *totpUseCase) RegenerateBackupCodes(ctx context.Context, tenantID, uid string) ([]string, error) { if tenantID == "" || uid == "" { return nil, errb.InputMissingRequired("tenant_id and uid are required") } rec, err := uc.profile.Get(ctx, tenantID, uid) if err != nil { return nil, errb.SysInternal("read totp profile failed").WithCause(err) } if rec == nil || !rec.Enrolled { return nil, errb.ResInvalidState("totp not enrolled").WithCause(member.ErrTOTPNotEnrolled) } plain, hashes, err := uc.generateBackupCodes() if err != nil { return nil, err } if err := uc.profile.ReplaceBackupCodes(ctx, tenantID, uid, hashes); err != nil { return nil, errb.SysInternal("replace backup codes failed").WithCause(err) } return plain, nil } func (uc *totpUseCase) Status(ctx context.Context, tenantID, uid string) (*domusecase.TOTPStatusDTO, error) { if tenantID == "" || uid == "" { return nil, errb.InputMissingRequired("tenant_id and uid are required") } rec, err := uc.profile.Get(ctx, tenantID, uid) if err != nil { return nil, errb.SysInternal("read totp profile failed").WithCause(err) } dto := &domusecase.TOTPStatusDTO{} if rec == nil { return dto, nil } dto.Enrolled = rec.Enrolled dto.EnrolledAt = rec.EnrolledAt dto.BackupCodesRemaining = len(rec.BackupCodesHash) return dto, nil } func (uc *totpUseCase) period() time.Duration { return time.Duration(uc.config.TOTP.PeriodSeconds) * time.Second } // generateBackupCodes returns the plaintext codes alongside their bcrypt // hashes ready for persistence. func (uc *totpUseCase) generateBackupCodes() (plain, hashes []string, err error) { codes, err := totp.GenerateBackupCodes(uc.config.TOTP.BackupCodeCount, uc.config.TOTP.BackupCodeLength) if err != nil { return nil, nil, errb.SysInternal("backup code generation failed").WithCause(err) } hashes = make([]string, 0, len(codes)) for _, c := range codes { h, hashErr := bcrypt.GenerateFromPassword([]byte(c), bcrypt.DefaultCost) if hashErr != nil { return nil, nil, errb.SysInternal("backup code hash failed").WithCause(hashErr) } hashes = append(hashes, string(h)) } return codes, hashes, nil }