package usecase import ( "context" "strings" "time" authdomain "gateway/internal/model/auth/domain" "gateway/internal/model/auth/domain/entity" domrepo "gateway/internal/model/auth/domain/repository" domusecase "gateway/internal/model/auth/domain/usecase" "github.com/zeromicro/go-zero/core/logx" ) type inviteUseCase struct { repo domrepo.InviteRepository lock domrepo.InviteConsumeLock } // InviteUseCaseParam wires InviteUseCase. type InviteUseCaseParam struct { Repo domrepo.InviteRepository Lock domrepo.InviteConsumeLock } // MustInviteUseCase constructs InviteUseCase. func MustInviteUseCase(param InviteUseCaseParam) domusecase.InviteUseCase { if param.Repo == nil { panic("auth: invite repository is required") } if param.Lock == nil { panic("auth: invite consume lock is required") } return &inviteUseCase{repo: param.Repo, lock: param.Lock} } func (uc *inviteUseCase) Validate(ctx context.Context, req *domusecase.ValidateInviteRequest) (*domusecase.InviteView, error) { if req == nil { return nil, errb.InputMissingRequired("invite request is required") } invite, err := uc.lookup(ctx, req) if err != nil { return nil, err } return toInviteView(invite), nil } func (uc *inviteUseCase) Consume(ctx context.Context, req *domusecase.ConsumeInviteRequest) (*domusecase.ConsumedInvite, error) { tenantID, code, err := normalizeInviteInput(req) if err != nil { return nil, err } codeHash := authdomain.HashInviteCode(code) ok, err := uc.lock.TryLock(ctx, tenantID, codeHash) if err != nil { return nil, wrapRepoErr(err, "invite consume lock failed") } if !ok { return nil, wrapRepoErr(authdomain.ErrInviteLocked) } defer func() { if err := uc.lock.Unlock(ctx, tenantID, codeHash); err != nil { logx.WithContext(ctx).Errorf("auth: invite unlock failed tenant=%s codeHash=%s: %v", tenantID, codeHash, err) } }() invite, err := uc.repo.GetByTenantAndCodeHash(ctx, tenantID, codeHash) if err != nil { return nil, wrapRepoErr(err) } if err := checkInviteActive(invite); err != nil { return nil, wrapRepoErr(err) } consumed, err := uc.repo.ConsumeOne(ctx, invite.ID) if err != nil { return nil, wrapRepoErr(err) } return &domusecase.ConsumedInvite{ ID: consumed.ID.Hex(), TenantID: consumed.TenantID, NewUsersOnly: consumed.NewUsersOnly, UsedCount: consumed.UsedCount, }, nil } func (uc *inviteUseCase) lookup(ctx context.Context, req *domusecase.ValidateInviteRequest) (*entity.InviteCode, error) { tenantID, code, err := normalizeInviteFields(req.TenantID, req.Code) if err != nil { return nil, err } invite, err := uc.repo.GetByTenantAndCodeHash(ctx, tenantID, authdomain.HashInviteCode(code)) if err != nil { return nil, wrapRepoErr(err) } if err := checkInviteActive(invite); err != nil { return nil, wrapRepoErr(err) } return invite, nil } func normalizeInviteInput(req *domusecase.ConsumeInviteRequest) (tenantID, code string, err error) { if req == nil { return "", "", errb.InputMissingRequired("invite request is required") } return normalizeInviteFields(req.TenantID, req.Code) } func normalizeInviteFields(tenantIDRaw, codeRaw string) (tenantID, code string, err error) { tenantID = strings.TrimSpace(tenantIDRaw) code = authdomain.NormalizeInviteCode(codeRaw) if tenantID == "" { return "", "", errb.InputMissingRequired("tenant_id is required") } if code == "" { return "", "", errb.InputMissingRequired("invite_code is required") } return tenantID, code, nil } func checkInviteActive(invite *entity.InviteCode) error { if invite == nil { return authdomain.ErrInviteNotFound } now := time.Now().UTC().UnixMilli() if invite.ExpiresAt > 0 && invite.ExpiresAt <= now { return authdomain.ErrInviteExpired } if invite.UsedCount >= invite.MaxUses { return authdomain.ErrInviteExhausted } return nil } func toInviteView(invite *entity.InviteCode) *domusecase.InviteView { remaining := invite.MaxUses - invite.UsedCount if remaining < 0 { remaining = 0 } return &domusecase.InviteView{ ID: invite.ID.Hex(), TenantID: invite.TenantID, NewUsersOnly: invite.NewUsersOnly, RemainingUses: remaining, } }