package usecase

import (
	"context"
	"crypto/rand"
	"errors"
	"math/big"
	"time"

	"github.com/google/uuid"
	"golang.org/x/crypto/bcrypt"

	errs "gateway/internal/library/errors"
	"gateway/internal/library/errors/code"
	"gateway/internal/model/member"
	memberconfig "gateway/internal/model/member/config"
	domrepo "gateway/internal/model/member/domain/repository"
	domusecase "gateway/internal/model/member/domain/usecase"
)

// errb is the error builder used throughout this file, bound to code.Facade.
var errb = errs.For(code.Facade)

// otpUseCase implements domusecase.OTPUseCase on top of an OTPChallengeStore.
// Only the bcrypt hash of a generated code is persisted, never the code itself.
type otpUseCase struct {
	store  domrepo.OTPChallengeStore
	config memberconfig.Config
}

// OTPUseCaseParam wires OTPUseCase.
type OTPUseCaseParam struct {
	// Store persists, reads, updates and deletes OTP challenges.
	Store domrepo.OTPChallengeStore
	// Config supplies the OTP settings; MustOTPUseCase applies Defaults() to it.
	Config memberconfig.Config
}

// MustOTPUseCase constructs an OTPUseCase from param. The supplied config is
// passed through Defaults() before being stored.
func MustOTPUseCase(param OTPUseCaseParam) domusecase.OTPUseCase {
	return &otpUseCase{
		store:  param.Store,
		config: param.Config.Defaults(),
	}
}

// Generate creates a new OTP challenge and persists the bcrypt hash of its
// code with a TTL of config.OTP.TTLSeconds.
//
// TenantID and Purpose are required; UID and Target are stored as given.
// It returns the challenge DTO (new challenge ID plus ExpiresIn in seconds)
// and the plain-text code. The plain code is only returned to the caller and
// is not persisted.
func (uc *otpUseCase) Generate(ctx context.Context, req *domusecase.GenerateOTPRequest) (*domusecase.OTPChallengeDTO, string, error) {
	if req == nil || req.TenantID == "" || req.Purpose == "" {
		return nil, "", errb.InputMissingRequired("tenant_id and purpose are required")
	}

	plainCode, err := generateNumericOTP(uc.config.OTP.Length)
	if err != nil {
		return nil, "", errb.SysInternal("otp generation failed").WithCause(err)
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(plainCode), bcrypt.DefaultCost)
	if err != nil {
		return nil, "", errb.SysInternal("otp hash failed").WithCause(err)
	}

	challengeID := uuid.NewString()
	ch := &domrepo.OTPChallenge{
		TenantID: req.TenantID,
		UID:      req.UID,
		Purpose:  req.Purpose,
		Target:   req.Target,
		CodeHash: string(hash),
	}
	ttl := time.Duration(uc.config.OTP.TTLSeconds) * time.Second
	if err := uc.store.Save(ctx, challengeID, ch, ttl); err != nil {
		return nil, "", errb.SysInternal("otp persist failed").WithCause(err)
	}

	return &domusecase.OTPChallengeDTO{
		ChallengeID: challengeID,
		ExpiresIn:   uc.config.OTP.TTLSeconds,
	}, plainCode, nil
}

// Verify checks req.Code against challenge req.ChallengeID. On success it
// deletes the challenge (making the code single-use) and returns its Target.
//
// The request must match the challenge's tenant and purpose, and its UID when
// the challenge was created with one. A wrong code increments the attempt
// counter. Once config.OTP.MaxAttempts is reached, the challenge is reported
// as locked (ResInvalidState wrapping member.ErrChallengeLocked). Below the
// limit, a wrong code is reported as AuthForbidden wrapping
// member.ErrInvalidOTP.
//
// NOTE(review): Get, compare and Delete are separate store calls, not one
// atomic operation. Whether two concurrent Verify calls with the correct code
// can both succeed depends on the store's Delete semantics — confirm.
func (uc *otpUseCase) Verify(ctx context.Context, req *domusecase.VerifyOTPRequest) (string, error) {
	if req == nil || req.ChallengeID == "" || req.Code == "" || req.Purpose == "" {
		return "", errb.InputMissingRequired("challenge_id, code and purpose are required")
	}

	ch, err := uc.store.Get(ctx, req.ChallengeID)
	if err != nil {
		return "", mapOTPStoreErr(err, req.ChallengeID, "otp read failed")
	}

	// Binding checks run before the hash compare so a foreign caller cannot
	// consume attempts on someone else's challenge.
	if ch.TenantID != req.TenantID {
		return "", errb.AuthForbidden("otp challenge tenant mismatch")
	}
	if ch.UID != "" {
		if req.UID == "" {
			return "", errb.InputMissingRequired("uid is required for this otp challenge")
		}
		if ch.UID != req.UID {
			return "", errb.AuthForbidden("otp challenge uid mismatch")
		}
	}
	if ch.Purpose != req.Purpose {
		return "", errb.AuthForbidden("otp challenge purpose mismatch")
	}
	if ch.Attempts >= uc.config.OTP.MaxAttempts {
		return "", errb.ResInvalidState("otp challenge locked").WithCause(member.ErrChallengeLocked)
	}

	if err := bcrypt.CompareHashAndPassword([]byte(ch.CodeHash), []byte(req.Code)); err != nil {
		if !errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
			// Any other bcrypt error means the stored hash itself is unusable
			// (server-side fault), not a wrong guess: don't burn an attempt.
			return "", errb.SysInternal("otp hash compare failed").WithCause(err)
		}
		return "", uc.recordFailedAttempt(ctx, req.ChallengeID)
	}

	target := ch.Target
	if err := uc.store.Delete(ctx, req.ChallengeID); err != nil {
		// If the store signals member.ErrChallengeNotFound here, the challenge
		// vanished since Get (e.g. consumed concurrently); report not found.
		// Either way this caller does not succeed.
		return "", mapOTPStoreErr(err, req.ChallengeID, "otp delete failed")
	}
	return target, nil
}

// recordFailedAttempt increments the attempt counter of challengeID after a
// wrong code. It returns the error Verify should report: locked once the new
// count reaches config.OTP.MaxAttempts, invalid code otherwise, or a mapped
// store error. It never returns nil.
func (uc *otpUseCase) recordFailedAttempt(ctx context.Context, challengeID string) error {
	attempts, err := uc.store.IncrementAttempts(ctx, challengeID)
	if err != nil {
		return mapOTPStoreErr(err, challengeID, "otp persist failed")
	}
	if attempts >= uc.config.OTP.MaxAttempts {
		return errb.ResInvalidState("otp challenge locked").WithCause(member.ErrChallengeLocked)
	}
	return errb.AuthForbidden("invalid otp code").WithCause(member.ErrInvalidOTP)
}

// Invalidate deletes challenge challengeID so it can no longer be verified.
// Store failures are mapped like Verify's: member.ErrChallengeNotFound becomes
// ResNotFound, anything else SysInternal, with the store error as the cause.
func (uc *otpUseCase) Invalidate(ctx context.Context, challengeID string) error {
	if challengeID == "" {
		return errb.InputMissingRequired("challenge_id is required")
	}
	if err := uc.store.Delete(ctx, challengeID); err != nil {
		return mapOTPStoreErr(err, challengeID, "otp delete failed")
	}
	return nil
}

// mapOTPStoreErr translates a challenge-store error into a facade error.
// member.ErrChallengeNotFound becomes ResNotFound; any other error becomes
// SysInternal(failMsg). The original error is attached as the cause.
func mapOTPStoreErr(err error, challengeID, failMsg string) error {
	if errors.Is(err, member.ErrChallengeNotFound) {
		return errb.ResNotFound("otp challenge", challengeID).WithCause(err)
	}
	return errb.SysInternal(failMsg).WithCause(err)
}

// generateNumericOTP returns a string of length decimal digits, each drawn
// uniformly from crypto/rand. A non-positive length falls back to 6 digits.
func generateNumericOTP(length int) (string, error) {
	if length <= 0 {
		length = 6
	}
	ten := big.NewInt(10) // rand.Int only reads max, so one instance suffices
	out := make([]byte, length)
	for i := range out {
		n, err := rand.Int(rand.Reader, ten)
		if err != nil {
			return "", err
		}
		out[i] = byte('0' + n.Uint64()) //nolint:gosec // digit 0-9
	}
	return string(out), nil
}