template-monorepo/internal/model/member/usecase/otp_usecase.go

142 lines
4.4 KiB
Go
Raw Normal View History

package usecase
import (
"context"
"crypto/rand"
"errors"
"math/big"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
errs "gateway/internal/library/errors"
"gateway/internal/library/errors/code"
"gateway/internal/model/member"
memberconfig "gateway/internal/model/member/config"
domrepo "gateway/internal/model/member/domain/repository"
domusecase "gateway/internal/model/member/domain/usecase"
)
// errb is this package's coded-error builder, created via errs.For(code.Facade).
// Every use case method below returns errors built from it
// (InputMissingRequired, SysInternal, ResNotFound, AuthForbidden, ResInvalidState).
// NOTE(review): code.Facade looks like a generic scope for a member use case — confirm it is the intended code family.
var errb = errs.For(code.Facade)
// otpUseCase is the domusecase.OTPUseCase implementation returned by
// MustOTPUseCase. Challenges are persisted through store with only a bcrypt
// hash of the code; the plaintext code is never stored.
type otpUseCase struct {
// store persists pending challenges keyed by challenge ID.
store domrepo.OTPChallengeStore
// config holds the OTP settings (Length, TTLSeconds, MaxAttempts), already
// passed through Config.Defaults() by MustOTPUseCase.
config memberconfig.Config
}
// OTPUseCaseParam carries the dependencies MustOTPUseCase needs to build the
// OTP use case.
type OTPUseCaseParam struct {
// Store persists OTP challenges (hashed code, attempt counter, TTL).
Store domrepo.OTPChallengeStore
// Config supplies the OTP settings; MustOTPUseCase applies Config.Defaults()
// to it before use.
Config memberconfig.Config
}
// MustOTPUseCase returns the OTPUseCase implementation backed by param.Store.
//
// param.Config is normalised with Config.Defaults() once, at construction
// time. This constructor does no validation of its own; in particular a nil
// param.Store is not detected here.
func MustOTPUseCase(param OTPUseCaseParam) domusecase.OTPUseCase {
	normalized := param.Config.Defaults()

	uc := &otpUseCase{store: param.Store}
	uc.config = normalized
	return uc
}
// Generate issues a new OTP challenge for req.TenantID and req.Purpose.
//
// A numeric code of uc.config.OTP.Length digits is generated, its bcrypt
// hash is saved under a fresh UUID challenge ID with a TTL of
// uc.config.OTP.TTLSeconds, and the challenge metadata is returned together
// with the plaintext code so the caller can deliver it. req.UID and
// req.Target are stored as given.
//
// Errors: InputMissingRequired when req is nil or TenantID/Purpose is empty;
// SysInternal when generation, hashing or persistence fails.
func (uc *otpUseCase) Generate(ctx context.Context, req *domusecase.GenerateOTPRequest) (*domusecase.OTPChallengeDTO, string, error) {
	if req == nil || req.TenantID == "" || req.Purpose == "" {
		return nil, "", errb.InputMissingRequired("tenant_id and purpose are required")
	}

	settings := uc.config.OTP

	plain, err := generateNumericOTP(settings.Length)
	if err != nil {
		return nil, "", errb.SysInternal("otp generation failed").WithCause(err)
	}

	digest, err := bcrypt.GenerateFromPassword([]byte(plain), bcrypt.DefaultCost)
	if err != nil {
		return nil, "", errb.SysInternal("otp hash failed").WithCause(err)
	}

	id := uuid.NewString()
	record := domrepo.OTPChallenge{
		TenantID: req.TenantID,
		UID:      req.UID,
		Purpose:  req.Purpose,
		Target:   req.Target,
		CodeHash: string(digest),
	}
	lifetime := time.Duration(settings.TTLSeconds) * time.Second
	if err = uc.store.Save(ctx, id, &record, lifetime); err != nil {
		return nil, "", errb.SysInternal("otp persist failed").WithCause(err)
	}

	dto := &domusecase.OTPChallengeDTO{ChallengeID: id, ExpiresIn: settings.TTLSeconds}
	return dto, plain, nil
}
// Verify checks req.Code against challenge req.ChallengeID. On success it
// deletes the challenge and returns the Target it was issued for.
//
// The challenge must match req.TenantID and req.Purpose. When the challenge
// was issued for a UID, req.UID must be supplied and equal it. Every
// verification that passes those checks consumes one attempt. After
// uc.config.OTP.MaxAttempts attempts the challenge reports "locked".
//
// Errors:
//   - InputMissingRequired: challenge_id, code or purpose missing, or uid
//     missing for a uid-bound challenge.
//   - ResNotFound: the challenge does not exist (cause member.ErrChallengeNotFound).
//   - AuthForbidden: tenant/uid/purpose mismatch, or a wrong code
//     (cause member.ErrInvalidOTP).
//   - ResInvalidState: attempts exhausted (cause member.ErrChallengeLocked).
//   - SysInternal: store failures or an unusable stored hash.
func (uc *otpUseCase) Verify(ctx context.Context, req *domusecase.VerifyOTPRequest) (string, error) {
	if req == nil || req.ChallengeID == "" || req.Code == "" || req.Purpose == "" {
		return "", errb.InputMissingRequired("challenge_id, code and purpose are required")
	}

	ch, err := uc.store.Get(ctx, req.ChallengeID)
	if err != nil {
		if errors.Is(err, member.ErrChallengeNotFound) {
			return "", errb.ResNotFound("otp challenge", req.ChallengeID).WithCause(err)
		}
		return "", errb.SysInternal("otp read failed").WithCause(err)
	}

	if ch.TenantID != req.TenantID {
		return "", errb.AuthForbidden("otp challenge tenant mismatch")
	}
	if ch.UID != "" {
		if req.UID == "" {
			return "", errb.InputMissingRequired("uid is required for this otp challenge")
		}
		if ch.UID != req.UID {
			return "", errb.AuthForbidden("otp challenge uid mismatch")
		}
	}
	if ch.Purpose != req.Purpose {
		return "", errb.AuthForbidden("otp challenge purpose mismatch")
	}

	maxAttempts := uc.config.OTP.MaxAttempts

	// Fast path: the snapshot already shows the challenge is exhausted.
	if ch.Attempts >= maxAttempts {
		return "", errb.ResInvalidState("otp challenge locked").WithCause(member.ErrChallengeLocked)
	}

	// Reserve an attempt BEFORE comparing the code. The Attempts value read by
	// Get is only a snapshot: concurrent requests would all pass the check
	// above and each get an uncounted guess. Counting first bounds the number
	// of comparisons to maxAttempts.
	// NOTE(review): this relies on IncrementAttempts being atomic in every
	// OTPChallengeStore implementation — confirm.
	attempts, incErr := uc.store.IncrementAttempts(ctx, req.ChallengeID)
	if incErr != nil {
		if errors.Is(incErr, member.ErrChallengeNotFound) {
			return "", errb.ResNotFound("otp challenge", req.ChallengeID).WithCause(incErr)
		}
		return "", errb.SysInternal("otp persist failed").WithCause(incErr)
	}
	if attempts > maxAttempts {
		return "", errb.ResInvalidState("otp challenge locked").WithCause(member.ErrChallengeLocked)
	}

	if cmpErr := bcrypt.CompareHashAndPassword([]byte(ch.CodeHash), []byte(req.Code)); cmpErr != nil {
		if !errors.Is(cmpErr, bcrypt.ErrMismatchedHashAndPassword) {
			// A malformed stored hash is a server-side fault, not a wrong guess.
			return "", errb.SysInternal("otp hash compare failed").WithCause(cmpErr)
		}
		if attempts >= maxAttempts {
			// This wrong guess used the last allowed attempt.
			return "", errb.ResInvalidState("otp challenge locked").WithCause(member.ErrChallengeLocked)
		}
		return "", errb.AuthForbidden("invalid otp code").WithCause(member.ErrInvalidOTP)
	}

	if delErr := uc.store.Delete(ctx, req.ChallengeID); delErr != nil {
		if errors.Is(delErr, member.ErrChallengeNotFound) {
			// If the store reports the challenge gone, it was consumed or
			// expired concurrently; refuse so a code cannot be redeemed twice.
			return "", errb.ResNotFound("otp challenge", req.ChallengeID).WithCause(delErr)
		}
		return "", errb.SysInternal("otp delete failed").WithCause(delErr)
	}
	return ch.Target, nil
}
// Invalidate deletes challenge challengeID so it can no longer be verified.
//
// Store failures are translated into coded errors, as in the other methods
// of this use case:
//   - member.ErrChallengeNotFound becomes ResNotFound.
//   - Any other failure becomes SysInternal.
//
// In both cases the store error is attached as the cause. An empty
// challengeID yields InputMissingRequired.
func (uc *otpUseCase) Invalidate(ctx context.Context, challengeID string) error {
	if challengeID == "" {
		return errb.InputMissingRequired("challenge_id is required")
	}

	err := uc.store.Delete(ctx, challengeID)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, member.ErrChallengeNotFound):
		return errb.ResNotFound("otp challenge", challengeID).WithCause(err)
	default:
		return errb.SysInternal("otp delete failed").WithCause(err)
	}
}
// generateNumericOTP returns a string of `length` decimal digits, each drawn
// independently and uniformly from crypto/rand. A non-positive length falls
// back to 6 digits. The only error returned is one from the random source.
func generateNumericOTP(length int) (string, error) {
	digits := length
	if digits < 1 {
		digits = 6
	}

	ten := big.NewInt(10) // rand.Int only reads its max argument, so it can be shared
	buf := make([]byte, 0, digits)
	for len(buf) < digits {
		d, err := rand.Int(rand.Reader, ten)
		if err != nil {
			return "", err
		}
		buf = append(buf, '0'+byte(d.Int64())) //nolint:gosec // d is in [0, 9]
	}
	return string(buf), nil
}