template-monorepo/internal/model/member/usecase/totp_usecase.go

299 lines
9.5 KiB
Go
Raw Permalink Normal View History

2026-05-20 13:03:59 +00:00
package usecase
import (
"context"
"errors"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
"gateway/internal/library/crypto"
memberconfig "gateway/internal/model/member/config"
member "gateway/internal/model/member/domain"
domrepo "gateway/internal/model/member/domain/repository"
domusecase "gateway/internal/model/member/domain/usecase"
"gateway/internal/model/member/totp"
)
// TOTPUseCaseParam wires TOTPUseCase dependencies. Profile, Enroll, Replay and
// Cipher are mandatory: MustTOTPUseCase panics when any of them is nil. The
// factory itself does not inspect the KEK; the wiring layer is expected to
// validate it before building Cipher.
type TOTPUseCaseParam struct {
// Profile persists enrolled TOTP state (encrypted secret, backup-code hashes, enrollment time).
Profile domrepo.TOTPProfileRepository
// Enroll holds the encrypted secret of a pending (unconfirmed) enrollment.
Enroll domrepo.TOTPEnrollStore
// Replay records accepted TOTP time steps so a code is accepted only once.
Replay domrepo.TOTPReplayStore
// Cipher encrypts TOTP secrets before storage and decrypts them for verification.
Cipher *crypto.Cipher
// Config supplies the TOTP settings; MustTOTPUseCase stores Config.Defaults().
Config memberconfig.Config
// Now is the clock; nil falls back to time.Now.
Now func() time.Time
}
// MustTOTPUseCase builds the TOTP use case from param.
//
// It panics when the profile repository, enroll store, replay store or cipher
// is nil, because those indicate a wiring bug rather than a runtime condition
// (the wiring layer is responsible for validating SecretKEK before calling).
// A nil Now falls back to time.Now, and the stored config is
// param.Config.Defaults().
func MustTOTPUseCase(param TOTPUseCaseParam) domusecase.TOTPUseCase {
	switch {
	case param.Profile == nil:
		panic("member: totp profile repository is required")
	case param.Enroll == nil:
		panic("member: totp enroll store is required")
	case param.Replay == nil:
		panic("member: totp replay store is required")
	case param.Cipher == nil:
		panic("member: totp cipher is required")
	}

	clock := time.Now
	if param.Now != nil {
		clock = param.Now
	}

	return &totpUseCase{
		profile: param.Profile,
		enroll:  param.Enroll,
		replay:  param.Replay,
		cipher:  param.Cipher,
		config:  param.Config.Defaults(),
		now:     clock,
	}
}
// totpUseCase is the domusecase.TOTPUseCase returned by MustTOTPUseCase, which
// rejects nil collaborators and defaults now to time.Now.
type totpUseCase struct {
// profile stores enrolled TOTP state: encrypted secret, backup-code hashes, enrollment time.
profile domrepo.TOTPProfileRepository
// enroll holds the encrypted secret of an unconfirmed enrollment, saved with a TTL.
enroll domrepo.TOTPEnrollStore
// replay records accepted TOTP time steps so a code cannot be reused.
replay domrepo.TOTPReplayStore
// cipher encrypts secrets before they are stored and decrypts them for verification.
cipher *crypto.Cipher
// config is the Defaults()-applied configuration; TOTP settings are read from config.TOTP.
config memberconfig.Config
// now is the clock used for code verification and enrollment timestamps.
now func() time.Time
}
// StartEnroll begins TOTP enrollment for uid in tenantID.
//
// Members who are already enrolled are refused. Otherwise a fresh secret is
// generated, encrypted, and stored in the enroll store for EnrollTTLSeconds.
// The method then returns the otpauth:// URL plus display metadata for the
// authenticator app. When account is empty, uid is used as the account label.
func (uc *totpUseCase) StartEnroll(ctx context.Context, tenantID, uid, account string) (*domusecase.EnrollStartDTO, error) {
	if tenantID == "" || uid == "" {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}

	existing, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return nil, wrapRepoErr(err, "read totp profile failed")
	}
	if existing != nil && existing.Enrolled {
		return nil, errb.ResAlreadyExist("totp already enrolled").WithCause(member.ErrTOTPAlreadyEnroll)
	}

	secret, err := totp.GenerateSecret()
	if err != nil {
		return nil, errb.SysInternal("totp secret generation failed").WithCause(err)
	}
	sealed, err := uc.cipher.Encrypt(secret)
	if err != nil {
		return nil, errb.SysInternal("totp secret encrypt failed").WithCause(err)
	}

	cfg := uc.config.TOTP
	pendingTTL := time.Duration(cfg.EnrollTTLSeconds) * time.Second
	if err := uc.enroll.Save(ctx, tenantID, uid, sealed, pendingTTL); err != nil {
		return nil, wrapRepoErr(err, "totp enroll persist failed")
	}

	label := uid
	if account != "" {
		label = account
	}

	link, err := totp.BuildOtpauthURL(totp.OtpauthURLInput{
		Issuer:    cfg.Issuer,
		Account:   label,
		Secret:    secret,
		Algorithm: cfg.Algorithm,
		Digits:    cfg.Digits,
		Period:    time.Duration(cfg.PeriodSeconds) * time.Second,
	})
	if err != nil {
		return nil, errb.SysInternal("totp otpauth url build failed").WithCause(err)
	}

	return &domusecase.EnrollStartDTO{
		OtpauthURL: link,
		Issuer:     cfg.Issuer,
		Account:    label,
		Digits:     cfg.Digits,
		PeriodSec:  cfg.PeriodSeconds,
		ExpiresIn:  cfg.EnrollTTLSeconds,
	}, nil
}
// ConfirmEnroll completes a pending enrollment started by StartEnroll.
//
// Surrounding whitespace is trimmed from code, as VerifyCode does, before
// validation. The code is then verified against the pending encrypted secret.
// The accepted time step is recorded in the replay store so the confirmation
// code cannot be replayed through VerifyCode in the same window (RFC 6238
// §5.2). Next, the profile is saved as enrolled with freshly generated,
// bcrypt-hashed backup codes. The plaintext backup codes are returned; this
// return value is the only place they appear.
//
// Errors:
//   - InputMissingRequired for empty arguments.
//   - ResAlreadyExist when the member is already enrolled.
//   - ResNotFound when there is no pending enrollment or member.
//   - AuthForbidden for a wrong or already-used code.
//   - Wrapped repository or internal errors otherwise.
func (uc *totpUseCase) ConfirmEnroll(ctx context.Context, tenantID, uid, code string) ([]string, error) {
	code = strings.TrimSpace(code)
	if tenantID == "" || uid == "" || code == "" {
		return nil, errb.InputMissingRequired("tenant_id, uid and code are required")
	}

	rec, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return nil, wrapRepoErr(err, "read totp profile failed")
	}
	if rec != nil && rec.Enrolled {
		return nil, errb.ResAlreadyExist("totp already enrolled").WithCause(member.ErrTOTPAlreadyEnroll)
	}

	cipherBlob, err := uc.enroll.Get(ctx, tenantID, uid)
	if err != nil {
		if errors.Is(err, member.ErrTOTPEnrollMissing) {
			return nil, errb.ResNotFound("totp enroll", uid).WithCause(err)
		}
		return nil, wrapRepoErr(err, "read totp enroll failed")
	}
	secret, err := uc.cipher.Decrypt(cipherBlob)
	if err != nil {
		return nil, errb.SysInternal("totp secret decrypt failed").WithCause(err)
	}

	now := uc.now()
	step, ok := totp.Verify(secret, code, now, uc.period(), uc.config.TOTP.Digits, uc.config.TOTP.Window)
	if !ok {
		return nil, errb.AuthForbidden("invalid totp code").WithCause(member.ErrTOTPInvalidCode)
	}

	// Burn the accepted step before persisting anything, exactly like
	// VerifyCode does. Without this, the code used to confirm enrollment stays
	// valid for one VerifyCode call within the window. If a later step below
	// fails, the user must wait for the next code to retry.
	replayTTL := time.Duration(uc.config.TOTP.ReplayTTLSeconds) * time.Second
	fresh, err := uc.replay.MarkUsed(ctx, tenantID, uid, step, replayTTL)
	if err != nil {
		return nil, wrapRepoErr(err, "totp replay mark failed")
	}
	if !fresh {
		return nil, errb.AuthForbidden("totp code already used").WithCause(member.ErrTOTPCodeReplay)
	}

	plainCodes, hashes, err := uc.generateBackupCodes()
	if err != nil {
		return nil, err
	}
	if err := uc.profile.Save(ctx, tenantID, uid, &domrepo.TOTPProfileRecord{
		Enrolled:        true,
		SecretCipher:    cipherBlob,
		BackupCodesHash: hashes,
		EnrolledAt:      now.UnixMilli(),
	}); err != nil {
		if errors.Is(err, member.ErrNotFound) {
			return nil, errb.ResNotFound("member", uid).WithCause(err)
		}
		return nil, wrapRepoErr(err, "persist totp profile failed")
	}

	// Best-effort cleanup. The profile is already enrolled, so returning an
	// error here would throw away the only copy of the plaintext backup codes
	// while leaving the member enrolled. A retry would then hit
	// "already enrolled". The pending record was saved with a TTL, and both
	// StartEnroll and ConfirmEnroll short-circuit on an enrolled profile, so a
	// leftover record is left to expire.
	_ = uc.enroll.Delete(ctx, tenantID, uid)
	return plainCodes, nil
}
// VerifyCode authenticates code for an enrolled member.
//
// When the trimmed code has exactly the configured number of digits, it is
// first checked as a TOTP. An accepted time step is recorded in the replay
// store, and a step that was already recorded is rejected as a replay. Any
// code not accepted as a TOTP is then tried against the stored backup-code
// hashes; a matching backup code is consumed.
func (uc *totpUseCase) VerifyCode(ctx context.Context, tenantID, uid, code string) error {
	if tenantID == "" || uid == "" || code == "" {
		return errb.InputMissingRequired("tenant_id, uid and code are required")
	}

	rec, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return wrapRepoErr(err, "read totp profile failed")
	}
	if rec == nil || !rec.Enrolled {
		return errb.ResInvalidState("totp not enrolled").WithCause(member.ErrTOTPNotEnrolled)
	}

	secret, err := uc.cipher.Decrypt(rec.SecretCipher)
	if err != nil {
		return errb.SysInternal("totp secret decrypt failed").WithCause(err)
	}

	at := uc.now()
	trimmed := strings.TrimSpace(code)
	cfg := uc.config.TOTP

	if len(trimmed) == cfg.Digits {
		step, accepted := totp.Verify(secret, trimmed, at, uc.period(), cfg.Digits, cfg.Window)
		if accepted {
			fresh, markErr := uc.replay.MarkUsed(ctx, tenantID, uid, step, time.Duration(cfg.ReplayTTLSeconds)*time.Second)
			switch {
			case markErr != nil:
				return wrapRepoErr(markErr, "totp replay mark failed")
			case !fresh:
				return errb.AuthForbidden("totp code already used").WithCause(member.ErrTOTPCodeReplay)
			default:
				return nil
			}
		}
	}

	// Not a valid TOTP for this window: fall back to one-time backup codes.
	if !uc.tryBackupCode(ctx, tenantID, uid, rec, trimmed) {
		return errb.AuthForbidden("invalid totp code").WithCause(member.ErrTOTPInvalidCode)
	}
	return nil
}
// tryBackupCode reports whether code matches one of rec's bcrypt-hashed backup
// codes and that code was then consumed in the profile repository. Only the
// first matching hash is consumed; hashes are compared in stored order.
//
// NOTE(review): an error from ConsumeBackupCode is reported as "no match", so
// a storage failure surfaces to the caller as an invalid code. The first
// return value of ConsumeBackupCode is also ignored. Confirm whether it
// signals a concurrent double-use that should be rejected.
func (uc *totpUseCase) tryBackupCode(ctx context.Context, tenantID, uid string, rec *domrepo.TOTPProfileRecord, code string) bool {
	candidate := []byte(code)
	for i := range rec.BackupCodesHash {
		stored := rec.BackupCodesHash[i]
		if bcrypt.CompareHashAndPassword([]byte(stored), candidate) != nil {
			continue
		}
		_, consumeErr := uc.profile.ConsumeBackupCode(ctx, tenantID, uid, stored)
		return consumeErr == nil
	}
	return false
}
// Disable turns TOTP off for uid. It clears the stored profile first and then
// deletes any pending enrollment. A missing member is reported as ResNotFound;
// other repository failures are wrapped.
func (uc *totpUseCase) Disable(ctx context.Context, tenantID, uid string) error {
	if tenantID == "" || uid == "" {
		return errb.InputMissingRequired("tenant_id and uid are required")
	}

	err := uc.profile.Clear(ctx, tenantID, uid)
	switch {
	case errors.Is(err, member.ErrNotFound):
		return errb.ResNotFound("member", uid).WithCause(err)
	case err != nil:
		return wrapRepoErr(err, "clear totp profile failed")
	}

	if err = uc.enroll.Delete(ctx, tenantID, uid); err != nil {
		return wrapRepoErr(err, "clear totp enroll failed")
	}
	return nil
}
// RegenerateBackupCodes replaces an enrolled member's backup codes with a
// freshly generated set and returns the new plaintext codes. Only the bcrypt
// hashes are handed to the repository, so this return value is the only place
// the plaintext appears. Members who are not enrolled get ResInvalidState.
func (uc *totpUseCase) RegenerateBackupCodes(ctx context.Context, tenantID, uid string) ([]string, error) {
	if tenantID == "" || uid == "" {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}

	rec, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return nil, wrapRepoErr(err, "read totp profile failed")
	}
	enrolled := rec != nil && rec.Enrolled
	if !enrolled {
		return nil, errb.ResInvalidState("totp not enrolled").WithCause(member.ErrTOTPNotEnrolled)
	}

	fresh, hashed, err := uc.generateBackupCodes()
	if err != nil {
		return nil, err
	}

	switch err := uc.profile.ReplaceBackupCodes(ctx, tenantID, uid, hashed); {
	case err == nil:
		return fresh, nil
	case errors.Is(err, member.ErrNotFound):
		return nil, errb.ResNotFound("member", uid).WithCause(err)
	default:
		return nil, wrapRepoErr(err, "replace backup codes failed")
	}
}
// Status reports whether uid has TOTP enrolled, when enrollment happened, and
// how many backup-code hashes remain. A member without a profile record gets a
// zero-valued DTO.
func (uc *totpUseCase) Status(ctx context.Context, tenantID, uid string) (*domusecase.TOTPStatusDTO, error) {
	if tenantID == "" || uid == "" {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}

	rec, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return nil, wrapRepoErr(err, "read totp profile failed")
	}
	if rec == nil {
		return &domusecase.TOTPStatusDTO{}, nil
	}

	return &domusecase.TOTPStatusDTO{
		Enrolled:             rec.Enrolled,
		EnrolledAt:           rec.EnrolledAt,
		BackupCodesRemaining: len(rec.BackupCodesHash),
	}, nil
}
// period converts the configured TOTP step length (PeriodSeconds) to a Duration.
func (uc *totpUseCase) period() time.Duration {
	seconds := time.Duration(uc.config.TOTP.PeriodSeconds)
	return seconds * time.Second
}
// generateBackupCodes creates BackupCodeCount codes of BackupCodeLength using
// totp.GenerateBackupCodes and hashes each with bcrypt at DefaultCost.
//
// It returns the plaintext codes (for one-time display) and, index-aligned,
// their hashes (for persistence). Any failure is reported as a SysInternal
// error with both slices nil.
func (uc *totpUseCase) generateBackupCodes() (plain, hashes []string, err error) {
	plain, err = totp.GenerateBackupCodes(uc.config.TOTP.BackupCodeCount, uc.config.TOTP.BackupCodeLength)
	if err != nil {
		return nil, nil, errb.SysInternal("backup code generation failed").WithCause(err)
	}

	hashes = make([]string, len(plain))
	for i, c := range plain {
		digest, hashErr := bcrypt.GenerateFromPassword([]byte(c), bcrypt.DefaultCost)
		if hashErr != nil {
			return nil, nil, errb.SysInternal("backup code hash failed").WithCause(hashErr)
		}
		hashes[i] = string(digest)
	}
	return plain, hashes, nil
}