290 lines
9.4 KiB
Go
290 lines
9.4 KiB
Go
|
|
package usecase
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"golang.org/x/crypto/bcrypt"
|
||
|
|
|
||
|
|
"gateway/internal/library/crypto"
|
||
|
|
memberconfig "gateway/internal/model/member/config"
|
||
|
|
member "gateway/internal/model/member/domain"
|
||
|
|
domrepo "gateway/internal/model/member/domain/repository"
|
||
|
|
domusecase "gateway/internal/model/member/domain/usecase"
|
||
|
|
"gateway/internal/model/member/totp"
|
||
|
|
)
|
||
|
|
|
||
|
|
// TOTPUseCaseParam wires TOTPUseCase dependencies. Cipher is mandatory; the
|
||
|
|
// factory rejects construction when the KEK is missing or invalid.
|
||
|
|
type TOTPUseCaseParam struct {
|
||
|
|
Profile domrepo.TOTPProfileRepository
|
||
|
|
Enroll domrepo.TOTPEnrollStore
|
||
|
|
Replay domrepo.TOTPReplayStore
|
||
|
|
Cipher *crypto.Cipher
|
||
|
|
Config memberconfig.Config
|
||
|
|
Now func() time.Time
|
||
|
|
}
|
||
|
|
|
||
|
|
// MustTOTPUseCase builds the TOTP use case from param. It panics when a
// mandatory collaborator (Profile, Enroll, Replay, Cipher) is nil; a nil Now
// falls back to time.Now, and Config is normalized via Defaults().
//
// The KEK behind Cipher is not inspected here: the wiring layer is
// responsible for validating SecretKEK before calling.
func MustTOTPUseCase(param TOTPUseCaseParam) domusecase.TOTPUseCase {
	switch {
	case param.Profile == nil:
		panic("member: totp profile repository is required")
	case param.Enroll == nil:
		panic("member: totp enroll store is required")
	case param.Replay == nil:
		panic("member: totp replay store is required")
	case param.Cipher == nil:
		panic("member: totp cipher is required")
	}

	clock := param.Now
	if clock == nil {
		clock = time.Now
	}

	return &totpUseCase{
		profile: param.Profile,
		enroll:  param.Enroll,
		replay:  param.Replay,
		cipher:  param.Cipher,
		config:  param.Config.Defaults(),
		now:     clock,
	}
}
|
||
|
|
|
||
|
|
type totpUseCase struct {
|
||
|
|
profile domrepo.TOTPProfileRepository
|
||
|
|
enroll domrepo.TOTPEnrollStore
|
||
|
|
replay domrepo.TOTPReplayStore
|
||
|
|
cipher *crypto.Cipher
|
||
|
|
config memberconfig.Config
|
||
|
|
now func() time.Time
|
||
|
|
}
|
||
|
|
|
||
|
|
// StartEnroll begins TOTP enrollment for (tenantID, uid). It generates a fresh
// secret, stores it encrypted as a pending enrollment with a TTL of
// TOTP.EnrollTTLSeconds, and returns the otpauth:// URL plus the display
// parameters the client needs to provision an authenticator app.
//
// account is the label shown in the authenticator; it falls back to uid when
// empty. Returns an InputMissingRequired error for empty tenantID/uid and a
// ResAlreadyExist error (cause member.ErrTOTPAlreadyEnroll) when the user is
// already enrolled.
func (uc *totpUseCase) StartEnroll(ctx context.Context, tenantID, uid, account string) (*domusecase.EnrollStartDTO, error) {
	if tenantID == "" || uid == "" {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}
	rec, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return nil, errb.SysInternal("read totp profile failed").WithCause(err)
	}
	if rec != nil && rec.Enrolled {
		return nil, errb.ResAlreadyExist("totp already enrolled").WithCause(member.ErrTOTPAlreadyEnroll)
	}

	secret, err := totp.GenerateSecret()
	if err != nil {
		return nil, errb.SysInternal("totp secret generation failed").WithCause(err)
	}

	accountLabel := account
	if accountLabel == "" {
		accountLabel = uid
	}

	// Build everything returned to the caller BEFORE writing to the enroll
	// store: failing after Save would leave a pending enrollment (possibly
	// replacing an earlier one) for a secret the user never received.
	otpURL, err := totp.BuildOtpauthURL(totp.OtpauthURLInput{
		Issuer:    uc.config.TOTP.Issuer,
		Account:   accountLabel,
		Secret:    secret,
		Algorithm: uc.config.TOTP.Algorithm,
		Digits:    uc.config.TOTP.Digits,
		Period:    uc.period(),
	})
	if err != nil {
		return nil, errb.SysInternal("totp otpauth url build failed").WithCause(err)
	}

	cipherBlob, err := uc.cipher.Encrypt(secret)
	if err != nil {
		return nil, errb.SysInternal("totp secret encrypt failed").WithCause(err)
	}
	ttl := time.Duration(uc.config.TOTP.EnrollTTLSeconds) * time.Second
	if err := uc.enroll.Save(ctx, tenantID, uid, cipherBlob, ttl); err != nil {
		return nil, errb.SysInternal("totp enroll persist failed").WithCause(err)
	}

	return &domusecase.EnrollStartDTO{
		OtpauthURL: otpURL,
		Issuer:     uc.config.TOTP.Issuer,
		Account:    accountLabel,
		Digits:     uc.config.TOTP.Digits,
		PeriodSec:  uc.config.TOTP.PeriodSeconds,
		ExpiresIn:  uc.config.TOTP.EnrollTTLSeconds,
	}, nil
}
|
||
|
|
|
||
|
|
// ConfirmEnroll completes a pending enrollment created by StartEnroll. It
// verifies code against the pending secret, persists the enrolled profile with
// freshly generated backup codes, and returns those codes in plaintext — only
// their bcrypt hashes are stored, so this is the one chance to show them.
//
// Errors:
//   - InputMissingRequired when tenantID, uid or the trimmed code is empty;
//   - ResAlreadyExist (cause member.ErrTOTPAlreadyEnroll) if already enrolled;
//   - ResNotFound when no pending enrollment exists (member.ErrTOTPEnrollMissing);
//   - AuthForbidden (cause member.ErrTOTPInvalidCode) for a wrong code;
//   - SysInternal for store, crypto or code-generation failures.
func (uc *totpUseCase) ConfirmEnroll(ctx context.Context, tenantID, uid, code string) ([]string, error) {
	// Normalize the same way VerifyCode does, so pasted whitespace does not
	// fail confirmation and whitespace-only input counts as missing.
	code = strings.TrimSpace(code)
	if tenantID == "" || uid == "" || code == "" {
		return nil, errb.InputMissingRequired("tenant_id, uid and code are required")
	}
	rec, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return nil, errb.SysInternal("read totp profile failed").WithCause(err)
	}
	if rec != nil && rec.Enrolled {
		return nil, errb.ResAlreadyExist("totp already enrolled").WithCause(member.ErrTOTPAlreadyEnroll)
	}

	cipherBlob, err := uc.enroll.Get(ctx, tenantID, uid)
	if err != nil {
		if errors.Is(err, member.ErrTOTPEnrollMissing) {
			return nil, errb.ResNotFound("totp enroll", uid).WithCause(err)
		}
		return nil, errb.SysInternal("read totp enroll failed").WithCause(err)
	}
	secret, err := uc.cipher.Decrypt(cipherBlob)
	if err != nil {
		return nil, errb.SysInternal("totp secret decrypt failed").WithCause(err)
	}

	step, ok := totp.Verify(secret, code, uc.now(), uc.period(), uc.config.TOTP.Digits, uc.config.TOTP.Window)
	if !ok {
		return nil, errb.AuthForbidden("invalid totp code").WithCause(member.ErrTOTPInvalidCode)
	}
	// Burn the verified step so the confirmation code cannot be replayed
	// against VerifyCode while it is still inside the window. Freshness is
	// deliberately not enforced: the replay key is per (tenant, uid, step), not
	// per secret, so an entry left by a previous enrollment must not block this
	// one — confirmation itself can only succeed once (guarded by rec.Enrolled).
	replayTTL := time.Duration(uc.config.TOTP.ReplayTTLSeconds) * time.Second
	if _, err := uc.replay.MarkUsed(ctx, tenantID, uid, step, replayTTL); err != nil {
		return nil, errb.SysInternal("totp replay mark failed").WithCause(err)
	}

	plainCodes, hashes, err := uc.generateBackupCodes()
	if err != nil {
		return nil, err
	}

	if err := uc.profile.Save(ctx, tenantID, uid, &domrepo.TOTPProfileRecord{
		Enrolled:        true,
		SecretCipher:    cipherBlob,
		BackupCodesHash: hashes,
		EnrolledAt:      uc.now().UnixMilli(),
	}); err != nil {
		return nil, errb.SysInternal("persist totp profile failed").WithCause(err)
	}

	// The profile is now committed as enrolled, so the plaintext backup codes
	// must reach the caller: returning an error here would lose them for good
	// (a retry hits "already enrolled"). A lingering pending entry is harmless —
	// StartEnroll saved it with a TTL, and both StartEnroll and ConfirmEnroll
	// refuse to run for an enrolled user — so a Delete failure is ignored.
	_ = uc.enroll.Delete(ctx, tenantID, uid)
	return plainCodes, nil
}
|
||
|
|
|
||
|
|
// VerifyCode checks a second-factor code for an enrolled user. The input is
// trimmed, then accepted if it is either:
//   - a TOTP code (only tried when it has exactly TOTP.Digits characters) whose
//     time step has not been used before — the step is recorded in the replay
//     store with a TTL of TOTP.ReplayTTLSeconds; or
//   - one of the user's backup codes, which is consumed on success.
//
// Returns InputMissingRequired for empty tenantID/uid/code, ResInvalidState
// (cause member.ErrTOTPNotEnrolled) when not enrolled, AuthForbidden with
// member.ErrTOTPCodeReplay for a reused step or member.ErrTOTPInvalidCode for
// a non-matching code, and SysInternal for store/crypto failures.
func (uc *totpUseCase) VerifyCode(ctx context.Context, tenantID, uid, code string) error {
	// Trim before the presence check so whitespace-only input is rejected as
	// missing instead of falling through to a bcrypt comparison against every
	// stored backup-code hash.
	cleanCode := strings.TrimSpace(code)
	if tenantID == "" || uid == "" || cleanCode == "" {
		return errb.InputMissingRequired("tenant_id, uid and code are required")
	}
	rec, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return errb.SysInternal("read totp profile failed").WithCause(err)
	}
	if rec == nil || !rec.Enrolled {
		return errb.ResInvalidState("totp not enrolled").WithCause(member.ErrTOTPNotEnrolled)
	}
	secret, err := uc.cipher.Decrypt(rec.SecretCipher)
	if err != nil {
		return errb.SysInternal("totp secret decrypt failed").WithCause(err)
	}

	now := uc.now()
	digits := uc.config.TOTP.Digits
	// Only TOTP-length input is tried as a time-based code; anything else (or
	// a TOTP-length input that does not verify) falls through to backup codes.
	if len(cleanCode) == digits {
		if step, ok := totp.Verify(secret, cleanCode, now, uc.period(), digits, uc.config.TOTP.Window); ok {
			ttl := time.Duration(uc.config.TOTP.ReplayTTLSeconds) * time.Second
			fresh, markErr := uc.replay.MarkUsed(ctx, tenantID, uid, step, ttl)
			if markErr != nil {
				return errb.SysInternal("totp replay mark failed").WithCause(markErr)
			}
			// fresh == false: this step was already consumed for this user.
			if !fresh {
				return errb.AuthForbidden("totp code already used").WithCause(member.ErrTOTPCodeReplay)
			}
			return nil
		}
	}

	if uc.tryBackupCode(ctx, tenantID, uid, rec, cleanCode) {
		return nil
	}
	return errb.AuthForbidden("invalid totp code").WithCause(member.ErrTOTPInvalidCode)
}
|
||
|
|
|
||
|
|
// tryBackupCode reports whether code matches one of rec's bcrypt-hashed backup
// codes and that hash was then consumed through ConsumeBackupCode. Only the
// first matching hash is consumed. Each call costs one bcrypt comparison per
// stored hash until a match is found.
//
// NOTE(review): the first return value of ConsumeBackupCode is discarded. If
// it reports whether the hash was actually removed (e.g. a concurrent request
// consumed it first), it should be checked here to prevent double use —
// confirm against the repository contract.
// NOTE(review): a ConsumeBackupCode error is folded into "no match", so the
// caller reports a store failure as an invalid code.
func (uc *totpUseCase) tryBackupCode(ctx context.Context, tenantID, uid string, rec *domrepo.TOTPProfileRecord, code string) bool {
	candidate := []byte(code)
	for _, stored := range rec.BackupCodesHash {
		if bcrypt.CompareHashAndPassword([]byte(stored), candidate) != nil {
			continue
		}
		_, consumeErr := uc.profile.ConsumeBackupCode(ctx, tenantID, uid, stored)
		return consumeErr == nil
	}
	return false
}
|
||
|
|
|
||
|
|
// Disable removes TOTP for (tenantID, uid): it clears the stored profile, then
// deletes any pending enrollment. It does not check enrollment state first.
// Store failures are reported as SysInternal errors.
func (uc *totpUseCase) Disable(ctx context.Context, tenantID, uid string) error {
	if tenantID == "" || uid == "" {
		return errb.InputMissingRequired("tenant_id and uid are required")
	}
	if clearErr := uc.profile.Clear(ctx, tenantID, uid); clearErr != nil {
		return errb.SysInternal("clear totp profile failed").WithCause(clearErr)
	}
	delErr := uc.enroll.Delete(ctx, tenantID, uid)
	if delErr == nil {
		return nil
	}
	return errb.SysInternal("clear totp enroll failed").WithCause(delErr)
}
|
||
|
|
|
||
|
|
// RegenerateBackupCodes issues a new set of backup codes for an enrolled user,
// stores their bcrypt hashes via ReplaceBackupCodes, and returns the new
// plaintext codes. Returns ResInvalidState (cause member.ErrTOTPNotEnrolled)
// when the user is not enrolled.
func (uc *totpUseCase) RegenerateBackupCodes(ctx context.Context, tenantID, uid string) ([]string, error) {
	if tenantID == "" || uid == "" {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}
	current, err := uc.profile.Get(ctx, tenantID, uid)
	if err != nil {
		return nil, errb.SysInternal("read totp profile failed").WithCause(err)
	}
	if enrolled := current != nil && current.Enrolled; !enrolled {
		return nil, errb.ResInvalidState("totp not enrolled").WithCause(member.ErrTOTPNotEnrolled)
	}

	codes, hashed, genErr := uc.generateBackupCodes()
	if genErr != nil {
		return nil, genErr
	}
	if replaceErr := uc.profile.ReplaceBackupCodes(ctx, tenantID, uid, hashed); replaceErr != nil {
		return nil, errb.SysInternal("replace backup codes failed").WithCause(replaceErr)
	}
	return codes, nil
}
|
||
|
|
|
||
|
|
// Status reports the TOTP state for (tenantID, uid): enrollment flag,
// enrollment timestamp and the number of stored backup-code hashes. A user
// with no stored profile yields a zero-valued DTO.
func (uc *totpUseCase) Status(ctx context.Context, tenantID, uid string) (*domusecase.TOTPStatusDTO, error) {
	if tenantID == "" || uid == "" {
		return nil, errb.InputMissingRequired("tenant_id and uid are required")
	}
	rec, err := uc.profile.Get(ctx, tenantID, uid)
	switch {
	case err != nil:
		return nil, errb.SysInternal("read totp profile failed").WithCause(err)
	case rec == nil:
		return &domusecase.TOTPStatusDTO{}, nil
	}
	return &domusecase.TOTPStatusDTO{
		Enrolled:             rec.Enrolled,
		EnrolledAt:           rec.EnrolledAt,
		BackupCodesRemaining: len(rec.BackupCodesHash),
	}, nil
}
|
||
|
|
|
||
|
|
// period returns the configured TOTP time-step length (TOTP.PeriodSeconds)
// as a time.Duration.
func (uc *totpUseCase) period() time.Duration {
	seconds := time.Duration(uc.config.TOTP.PeriodSeconds)
	return seconds * time.Second
}
|
||
|
|
|
||
|
|
// generateBackupCodes returns the plaintext codes alongside their bcrypt
|
||
|
|
// hashes ready for persistence.
|
||
|
|
func (uc *totpUseCase) generateBackupCodes() (plain, hashes []string, err error) {
|
||
|
|
codes, err := totp.GenerateBackupCodes(uc.config.TOTP.BackupCodeCount, uc.config.TOTP.BackupCodeLength)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, errb.SysInternal("backup code generation failed").WithCause(err)
|
||
|
|
}
|
||
|
|
hashes = make([]string, 0, len(codes))
|
||
|
|
for _, c := range codes {
|
||
|
|
h, hashErr := bcrypt.GenerateFromPassword([]byte(c), bcrypt.DefaultCost)
|
||
|
|
if hashErr != nil {
|
||
|
|
return nil, nil, errb.SysInternal("backup code hash failed").WithCause(hashErr)
|
||
|
|
}
|
||
|
|
hashes = append(hashes, string(h))
|
||
|
|
}
|
||
|
|
return codes, hashes, nil
|
||
|
|
}
|