template-monorepo/internal/logic/auth/register_logic.go

164 lines
5.3 KiB
Go

package auth
import (
"context"
"errors"
"strings"
"time"
"gateway/internal/library/zitadel"
authmetaenum "gateway/internal/model/auth/domain/enum"
domauth "gateway/internal/model/auth/domain/usecase"
memberdom "gateway/internal/model/member/domain"
dommember "gateway/internal/model/member/domain/usecase"
notifenum "gateway/internal/model/notification/domain/enum"
notifuc "gateway/internal/model/notification/domain/usecase"
"gateway/internal/svc"
"gateway/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
// RegisterLogic handles a single self-registration request. It holds the
// request context, a logger derived from that context, and the shared
// service dependencies used by the registration flow.
type RegisterLogic struct {
	logx.Logger
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

// NewRegisterLogic returns a RegisterLogic bound to ctx and svcCtx, with its
// embedded logger scoped to ctx.
func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RegisterLogic {
	logic := &RegisterLogic{ctx: ctx, svcCtx: svcCtx}
	logic.Logger = logx.WithContext(ctx)
	return logic
}
// Register runs the email self-registration flow for req:
//
//  1. checks registration dependencies and resolves the tenant from
//     req.TenantSlug;
//  2. creates the human user in Zitadel (if Zitadel reports the user already
//     exists, the request is handed to recoverPendingRegistration instead);
//  3. consumes req.InviteCode when the registration config requires invites;
//  4. creates an unverified platform member linked to the Zitadel user;
//  5. records registration metadata (invite, accepted terms version,
//     marketing opt-in, email channel);
//  6. issues and emails a registration OTP.
//
// If a step after the Zitadel user is created fails, the earlier steps are
// compensated best-effort (pending member aborted, Zitadel user deactivated).
// Compensation failures are logged and the original error is returned. A
// consumed invite code is not restored by this function.
//
// On success it returns the OTP challenge data with UID set to the new
// member's UID.
func (l *RegisterLogic) Register(req *types.RegisterReq) (*types.RegisterData, error) {
	if err := requireRegistrationDeps(l.svcCtx); err != nil {
		return nil, err
	}
	tenant, err := resolveTenant(l.ctx, l.svcCtx, req.TenantSlug)
	if err != nil {
		return nil, err
	}

	regCfg := l.svcCtx.Config.Member.Defaults().Registration
	// Reject invite misconfiguration before creating anything in Zitadel.
	// Checking it afterwards returned early without deactivating the freshly
	// created user, leaving an active identity with no member record.
	if regCfg.RequireInviteCode && l.svcCtx.AuthInvite == nil {
		return nil, errb.SysNotImplemented("invite validation not configured")
	}

	email := normalizeLoginEmail(req.Email)
	displayName := strings.TrimSpace(req.DisplayName)
	language := strings.TrimSpace(req.Language)

	zResult, err := l.svcCtx.Zitadel.CreateHumanUser(l.ctx, zitadel.CreateHumanUserRequest{
		OrgID:       tenant.OrgID,
		Email:       email,
		Password:    req.Password,
		DisplayName: displayName,
		Language:    language,
	})
	if err != nil {
		if errors.Is(err, zitadel.ErrUserAlreadyExists) {
			// NOTE(review): no invite code is consumed on this path; confirm
			// recoverPendingRegistration enforces regCfg.RequireInviteCode.
			return recoverPendingRegistration(l.ctx, l.svcCtx, tenant, req)
		}
		return nil, wrapZitadelErr(err)
	}

	// NOTE(review): compensations below run on l.ctx; if the request context
	// is already cancelled they may fail and only be logged.

	// deactivateUser disables the Zitadel user created above after a later
	// step fails. Best effort: a failure is logged so the caller still sees
	// the original error.
	deactivateUser := func(stage string) {
		if err := l.svcCtx.Zitadel.DeactivateUser(l.ctx, zResult.UserID); err != nil {
			l.Errorf("register: deactivate zitadel user after %s failure: %v", stage, err)
		}
	}

	var inviteCodeID string
	if regCfg.RequireInviteCode {
		consumed, err := l.svcCtx.AuthInvite.Consume(l.ctx, &domauth.ConsumeInviteRequest{
			TenantID: tenant.TenantID,
			Code:     req.InviteCode,
		})
		if err != nil {
			deactivateUser("invite")
			return nil, err
		}
		inviteCodeID = consumed.ID
	}

	memberDTO, err := l.svcCtx.MemberLifecycle.CreateUnverified(l.ctx, &dommember.CreatePlatformMemberRequest{
		TenantID:      tenant.TenantID,
		Email:         email,
		DisplayName:   displayName,
		Language:      language,
		ZitadelUserID: zResult.UserID,
	})
	if err != nil {
		deactivateUser("member")
		return nil, err
	}

	// rollback aborts the pending member and then deactivates the Zitadel
	// user, logging (not returning) any compensation failure.
	rollback := func(stage string) {
		if err := l.svcCtx.MemberLifecycle.AbortPending(l.ctx, tenant.TenantID, memberDTO.UID); err != nil {
			l.Errorf("register: abort pending member after %s failure: %v", stage, err)
		}
		deactivateUser(stage)
	}

	if err := recordRegistrationMeta(l.ctx, l.svcCtx, tenant.TenantID, memberDTO.UID, inviteCodeID, req.AcceptTermsVersion, req.MarketingOptIn, authmetaenum.RegistrationChannelEmail); err != nil {
		rollback("metadata")
		return nil, err
	}

	data, err := sendRegistrationOTP(l.ctx, l.svcCtx, tenant.TenantID, memberDTO.UID, email)
	if err != nil {
		rollback("otp")
		return nil, err
	}
	data.UID = memberDTO.UID
	return data, nil
}
// sendRegistrationOTP issues a registration OTP for the member (tenantID, uid)
// and emails the plain code to email.
//
// It first enforces the resend cooldown (cfg.OTP.ResendCooldownSeconds) for
// the member's registration-purpose rate key. It then generates the OTP
// challenge and sends the verify-registration email. The send uses the
// configured default notification locale, falling back to "en-us" when that
// is blank. The challenge ID serves as the idempotency key, and
// DoNotPersistBody is set (NOTE(review): presumably keeps the plain code out
// of stored notifications — confirm against the notifier).
//
// If the email cannot be sent, the challenge is invalidated so a code the user
// never received cannot be redeemed. The send error is returned; an
// invalidation failure is only logged so that it does not mask the root cause.
//
// On success the returned RegisterData has ChallengeID and ExpiresIn set; UID
// is left for the caller to fill in.
func sendRegistrationOTP(
	ctx context.Context,
	sc *svc.ServiceContext,
	tenantID, uid, email string,
) (*types.RegisterData, error) {
	cfg := sc.Config.Member.Defaults()
	purpose := registrationPurpose()

	rateKey := memberdom.GetVerifyRateRedisKey(tenantID, uid, string(purpose))
	cooldown := time.Duration(cfg.OTP.ResendCooldownSeconds) * time.Second
	if err := sc.MemberVerifyRate.AssertResendAllowed(ctx, rateKey, cooldown); err != nil {
		return nil, err
	}

	dto, plainCode, err := sc.MemberOTP.Generate(ctx, &dommember.GenerateOTPRequest{
		TenantID: tenantID,
		UID:      uid,
		Purpose:  purpose,
		Target:   email,
	})
	if err != nil {
		return nil, err
	}

	locale := sc.Config.Notification.DefaultLocale
	if strings.TrimSpace(locale) == "" {
		locale = "en-us"
	}

	if _, sendErr := sc.Notifier.Send(ctx, &notifuc.SendRequest{
		TenantID:         tenantID,
		UID:              uid,
		Channel:          notifenum.ChannelEmail,
		Kind:             notifenum.NotifyVerifyRegistrationEmail,
		Target:           email,
		Locale:           locale,
		Data:             map[string]any{"code": plainCode, "expires_in": dto.ExpiresIn},
		IdempotencyKey:   dto.ChallengeID,
		DoNotPersistBody: true,
		Severity:         notifenum.SeverityInfo,
	}); sendErr != nil {
		// Previously an invalidation failure replaced sendErr, hiding why the
		// registration actually failed. Report the send failure; log the rest.
		if invErr := sc.MemberOTP.Invalidate(ctx, dto.ChallengeID); invErr != nil {
			logx.WithContext(ctx).Errorf("register: invalidate otp challenge %v after send failure: %v", dto.ChallengeID, invErr)
		}
		return nil, sendErr
	}

	return &types.RegisterData{
		ChallengeID: dto.ChallengeID,
		ExpiresIn:   dto.ExpiresIn,
	}, nil
}