package auth

import (
	"context"
	"strings"
	"time"

	"gateway/internal/library/zitadel"
	authmetaenum "gateway/internal/model/auth/domain/enum"
	domauth "gateway/internal/model/auth/domain/usecase"
	memberdom "gateway/internal/model/member/domain"
	dommember "gateway/internal/model/member/domain/usecase"
	notifenum "gateway/internal/model/notification/domain/enum"
	notifuc "gateway/internal/model/notification/domain/usecase"
	"gateway/internal/svc"
	"gateway/internal/types"

	"github.com/zeromicro/go-zero/core/logx"
)

// RegisterLogic handles the e-mail registration request. It carries the
// request context, a logger bound to that context and the shared service
// context holding the collaborators used during registration.
type RegisterLogic struct {
	logx.Logger
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

// NewRegisterLogic returns a RegisterLogic whose logger is bound to ctx.
func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RegisterLogic {
	return &RegisterLogic{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx,
	}
}

// Register runs the registration flow for req:
//
//  1. check the registration dependencies and resolve the tenant from
//     req.TenantSlug;
//  2. consume the invite code when the registration config requires one;
//  3. create the human user in Zitadel;
//  4. create the unverified platform member linked to that Zitadel user;
//  5. record the registration metadata;
//  6. generate and send the registration OTP.
//
// When a step after the Zitadel user creation fails, the resources created
// so far are rolled back on a best-effort basis (pending member aborted,
// Zitadel user deactivated). Rollback failures are only logged; the error
// returned is always the one that triggered the rollback.
//
// On success the returned data carries the OTP challenge information and the
// UID of the new member.
func (l *RegisterLogic) Register(req *types.RegisterReq) (*types.RegisterData, error) {
	if err := requireRegistrationDeps(l.svcCtx); err != nil {
		return nil, err
	}
	tenant, err := resolveTenant(l.ctx, l.svcCtx, req.TenantSlug)
	if err != nil {
		return nil, err
	}

	regCfg := l.svcCtx.Config.Member.Defaults().Registration
	var inviteCodeID string
	if regCfg.RequireInviteCode {
		if l.svcCtx.AuthInvite == nil {
			// NOTE(review): errb is not among this file's imports; confirm it
			// is declared at package level elsewhere.
			return nil, errb.SysNotImplemented("invite validation not configured")
		}
		// NOTE(review): the invite is consumed before the user exists and is
		// not released by the rollbacks below; confirm whether a later
		// failure should give the code back.
		consumed, err := l.svcCtx.AuthInvite.Consume(l.ctx, &domauth.ConsumeInviteRequest{
			TenantID: tenant.TenantID,
			Code:     req.InviteCode,
		})
		if err != nil {
			return nil, err
		}
		inviteCodeID = consumed.ID
	}

	// Normalize the user-supplied values once so Zitadel and the member
	// store receive exactly the same strings.
	email := strings.TrimSpace(strings.ToLower(req.Email))
	displayName := strings.TrimSpace(req.DisplayName)
	language := strings.TrimSpace(req.Language)

	zResult, err := l.svcCtx.Zitadel.CreateHumanUser(l.ctx, zitadel.CreateHumanUserRequest{
		OrgID:       tenant.OrgID,
		Email:       email,
		Password:    req.Password,
		DisplayName: displayName,
		Language:    language,
	})
	if err != nil {
		return nil, wrapZitadelErr(err)
	}

	// deactivateZitadelUser is the best-effort compensation for the Zitadel
	// user created above; action is the log prefix describing the failure
	// that triggered it.
	deactivateZitadelUser := func(action string) {
		if deactErr := l.svcCtx.Zitadel.DeactivateUser(l.ctx, zResult.UserID); deactErr != nil {
			logx.WithContext(l.ctx).Errorf("%s: %v", action, deactErr)
		}
	}

	memberDTO, err := l.svcCtx.MemberLifecycle.CreateUnverified(l.ctx, &dommember.CreatePlatformMemberRequest{
		TenantID:      tenant.TenantID,
		Email:         email,
		DisplayName:   displayName,
		Language:      language,
		ZitadelUserID: zResult.UserID,
	})
	if err != nil {
		deactivateZitadelUser("register: deactivate zitadel user after member failure")
		return nil, err
	}

	// abortPendingMember is the best-effort compensation for the unverified
	// member created above.
	abortPendingMember := func(action string) {
		if abortErr := l.svcCtx.MemberLifecycle.AbortPending(l.ctx, tenant.TenantID, memberDTO.UID); abortErr != nil {
			logx.WithContext(l.ctx).Errorf("%s: %v", action, abortErr)
		}
	}

	if err := recordRegistrationMeta(
		l.ctx, l.svcCtx, tenant.TenantID, memberDTO.UID, inviteCodeID,
		req.AcceptTermsVersion, req.MarketingOptIn, authmetaenum.RegistrationChannelEmail,
	); err != nil {
		abortPendingMember("register: abort pending member after metadata failure")
		deactivateZitadelUser("register: deactivate zitadel user after metadata failure")
		return nil, err
	}

	data, err := sendRegistrationOTP(l.ctx, l.svcCtx, tenant.TenantID, memberDTO.UID, email)
	if err != nil {
		abortPendingMember("register: abort pending member")
		deactivateZitadelUser("register: deactivate zitadel user after otp failure")
		return nil, err
	}

	data.UID = memberDTO.UID
	return data, nil
}

// sendRegistrationOTP generates a registration OTP for the member identified
// by tenantID/uid and sends it to email through the notifier.
//
// The resend rate check is applied first, using the configured
// OTP.ResendCooldownSeconds. If sending fails the generated challenge is
// invalidated; a failure of that invalidation is logged and the send error
// is returned, so the caller always sees the root cause.
//
// On success the returned data holds the challenge ID and its ExpiresIn
// value; the caller is responsible for filling in the UID.
func sendRegistrationOTP(
	ctx context.Context,
	sc *svc.ServiceContext,
	tenantID, uid, email string,
) (*types.RegisterData, error) {
	cfg := sc.Config.Member.Defaults()

	rateKey := memberdom.GetVerifyRateRedisKey(tenantID, uid, string(registrationPurpose()))
	cooldown := time.Duration(cfg.OTP.ResendCooldownSeconds) * time.Second
	if err := sc.MemberVerifyRate.AssertResendAllowed(ctx, rateKey, cooldown); err != nil {
		return nil, err
	}

	dto, plainCode, err := sc.MemberOTP.Generate(ctx, &dommember.GenerateOTPRequest{
		TenantID: tenantID,
		UID:      uid,
		Purpose:  registrationPurpose(),
		Target:   email,
	})
	if err != nil {
		return nil, err
	}

	// Fall back to "en-us" when no default notification locale is configured.
	locale := sc.Config.Notification.DefaultLocale
	if strings.TrimSpace(locale) == "" {
		locale = "en-us"
	}

	if _, sendErr := sc.Notifier.Send(ctx, &notifuc.SendRequest{
		TenantID: tenantID,
		UID:      uid,
		Channel:  notifenum.ChannelEmail,
		Kind:     notifenum.NotifyVerifyRegistrationEmail,
		Target:   email,
		Locale:   locale,
		Data:     map[string]any{"code": plainCode, "expires_in": dto.ExpiresIn},
		// The challenge ID doubles as the idempotency key of the notification.
		IdempotencyKey: dto.ChallengeID,
		// The body contains the plaintext code, so persistence is disabled.
		DoNotPersistBody: true,
		Severity:         notifenum.SeverityInfo,
	}); sendErr != nil {
		// The code never reached the user, so the challenge is invalidated.
		// An invalidation failure must not hide the send error.
		if invErr := sc.MemberOTP.Invalidate(ctx, dto.ChallengeID); invErr != nil {
			logx.WithContext(ctx).Errorf("register: invalidate otp challenge after send failure: %v", invErr)
		}
		return nil, sendErr
	}

	return &types.RegisterData{
		ChallengeID: dto.ChallengeID,
		ExpiresIn:   dto.ExpiresIn,
	}, nil
}