package logic import ( "context" "crypto/rand" "fmt" "math/big" "member/gen_result/pb/member" "member/internal/domain" ers "member/internal/lib/error" "member/internal/lib/required" "member/internal/svc" "strconv" "github.com/zeromicro/go-zero/core/logx" ) type GenerateRefreshCodeLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } func NewGenerateRefreshCodeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GenerateRefreshCodeLogic { return &GenerateRefreshCodeLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), } } type generateRefreshCodeReq struct { Account string `json:"account" validate:"account"` // CodeType 1 email 2 phone CodeType int32 `json:"code_type" validate:"required,oneof=1 2 3"` } var codeMap = map[int32]string{ 1: "email", 2: "phone", } func getCodeNameByCode(code int32) (string, bool) { res, ok := codeMap[code] if !ok { return "", false } return res, true } func generateVerifyCode(digits int) (string, error) { if digits <= 0 { // 預設為六位數 digits = 6 } // 計算最大值 (10^digits - 1) exp := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(digits)), nil) // 生成隨機數 randomNumber, err := rand.Int(rand.Reader, exp) if err != nil { return "", err } // 將隨機數轉換為 string verifyCode := strconv.Itoa(int(randomNumber.Int64())) // 如果隨機數的位數少於指定的位數,則補 0 if len(verifyCode) < digits { verifyCode = fmt.Sprintf("%0*d", digits, randomNumber) } return verifyCode, nil } // GenerateRefreshCode 這個帳號驗證碼(十分鐘),通用的 func (l *GenerateRefreshCodeLogic) GenerateRefreshCode(in *member.GenerateRefreshCodeReq) (*member.GenerateRefreshCodeResp, error) { err := required.ValidateAll(l.svcCtx.Validate, &generateRefreshCodeReq{ Account: in.GetAccount(), CodeType: in.GetCodeType(), }) if err != nil { return nil, ers.InvalidFormat(err.Error()) } checkType, status := getCodeNameByCode(in.GetCodeType()) if !status { return nil, ers.InvalidFormat(fmt.Errorf("failed to get correct code type").Error()) } rk := fmt.Sprintf("verify:%s:%s", checkType, in.GetAccount()) // 拿過就不要再拿了 get, err := l.svcCtx.Redis.Get(rk) if err != nil { return nil, ers.DBError("failed to connect to redis", err.Error()) } if get != "" { return &member.GenerateRefreshCodeResp{ Status: &member.BaseResp{ Code: domain.CodeOk.ToString(), Message: "success", Error: "", }, Data: &member.VerifyCode{ VerifyCode: get, }, }, nil } code, err := generateVerifyCode(6) if !status { return nil, ers.ArkInternal(err.Error()) } err = l.svcCtx.Redis.Setex(rk, code, 600) if err != nil { return nil, ers.ArkInternal(err.Error()) } return &member.GenerateRefreshCodeResp{ Status: &member.BaseResp{ Code: domain.CodeOk.ToString(), Message: "success", Error: "", }, Data: &member.VerifyCode{ VerifyCode: code, }, }, nil }