// Source: template-monorepo/internal/model/auth/usecase/invite_usecase.go
// (149 lines, 4.1 KiB, Go)

package usecase
import (
"context"
"strings"
"time"
authdomain "gateway/internal/model/auth/domain"
"gateway/internal/model/auth/domain/entity"
domrepo "gateway/internal/model/auth/domain/repository"
domusecase "gateway/internal/model/auth/domain/usecase"
"github.com/zeromicro/go-zero/core/logx"
)
// inviteUseCase is the concrete value that MustInviteUseCase returns as a
// domusecase.InviteUseCase. It holds the two collaborators used by the
// validate and consume flows.
type inviteUseCase struct {
// repo is used to look up invites by tenant and code hash and to record a
// consumed use.
repo domrepo.InviteRepository
// lock is acquired and released around Consume, keyed by tenant and code hash.
lock domrepo.InviteConsumeLock
}
// InviteUseCaseParam carries the dependencies handed to MustInviteUseCase.
// Both fields are mandatory: MustInviteUseCase panics if either one is nil.
type InviteUseCaseParam struct {
// Repo is the invite repository the use case reads from and writes to.
Repo domrepo.InviteRepository
// Lock is the lock the use case takes while consuming an invite.
Lock domrepo.InviteConsumeLock
}
// MustInviteUseCase builds the invite use case from param.
//
// It follows the Must* convention: a missing repository or a missing consume
// lock is treated as a wiring bug and causes a panic instead of an error.
func MustInviteUseCase(param InviteUseCaseParam) domusecase.InviteUseCase {
	// The repository is checked before the lock, so when both are missing the
	// repository message is the one reported.
	switch {
	case param.Repo == nil:
		panic("auth: invite repository is required")
	case param.Lock == nil:
		panic("auth: invite consume lock is required")
	}

	return &inviteUseCase{
		repo: param.Repo,
		lock: param.Lock,
	}
}
// Validate looks up the invite described by req and, if it is still active,
// returns a view of it. It does not take the consume lock and does not record
// a use.
//
// A nil req yields an errb.InputMissingRequired error; every other failure is
// whatever lookup returns.
func (uc *inviteUseCase) Validate(ctx context.Context, req *domusecase.ValidateInviteRequest) (*domusecase.InviteView, error) {
	if req == nil {
		return nil, errb.InputMissingRequired("invite request is required")
	}

	found, lookupErr := uc.lookup(ctx, req)
	if lookupErr != nil {
		return nil, lookupErr
	}

	view := toInviteView(found)
	return view, nil
}
// Consume redeems one use of the invite identified by req.
//
// Steps, in order:
//  1. normalize and validate the request fields;
//  2. try to take the consume lock for (tenant, code hash);
//  3. load the invite and check that it is still active;
//  4. record the use through the repository.
//
// The lock is released when the function returns, whatever the outcome.
//
// Errors:
//   - the validation error from normalizeInviteInput for a nil or incomplete
//     request (returned unwrapped);
//   - a wrapped lock error when TryLock itself fails;
//   - wrapped authdomain.ErrInviteLocked when the lock is not acquired;
//   - wrapped repository / activity-check errors otherwise.
func (uc *inviteUseCase) Consume(ctx context.Context, req *domusecase.ConsumeInviteRequest) (*domusecase.ConsumedInvite, error) {
	// Time budget for releasing the lock when the request context can no
	// longer be used for it.
	const unlockTimeout = 5 * time.Second

	tenantID, code, err := normalizeInviteInput(req)
	if err != nil {
		return nil, err
	}
	codeHash := authdomain.HashInviteCode(code)

	ok, err := uc.lock.TryLock(ctx, tenantID, codeHash)
	if err != nil {
		return nil, wrapRepoErr(err, "invite consume lock failed")
	}
	if !ok {
		return nil, wrapRepoErr(authdomain.ErrInviteLocked)
	}
	defer func() {
		// If the request context is already cancelled or past its deadline,
		// an Unlock issued with it is likely to fail and leave the lock held.
		// Fall back to a short-lived detached context in that case only, so
		// the normal path keeps using ctx unchanged.
		unlockCtx := ctx
		if ctx.Err() != nil {
			detached, cancel := context.WithTimeout(context.Background(), unlockTimeout)
			defer cancel()
			unlockCtx = detached
		}
		if err := uc.lock.Unlock(unlockCtx, tenantID, codeHash); err != nil {
			// Best effort: the caller's result is not changed by an unlock failure.
			logx.WithContext(ctx).Errorf("auth: invite unlock failed tenant=%s codeHash=%s: %v", tenantID, codeHash, err)
		}
	}()

	invite, err := uc.repo.GetByTenantAndCodeHash(ctx, tenantID, codeHash)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	// checkInviteActive also rejects a nil invite, so invite is non-nil below.
	if err := checkInviteActive(invite); err != nil {
		return nil, wrapRepoErr(err)
	}

	consumed, err := uc.repo.ConsumeOne(ctx, invite.ID)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	// NOTE(review): this assumes ConsumeOne never returns a nil invite together
	// with a nil error; confirm against the repository contract.
	return &domusecase.ConsumedInvite{
		ID:           consumed.ID.Hex(),
		TenantID:     consumed.TenantID,
		NewUsersOnly: consumed.NewUsersOnly,
		UsedCount:    consumed.UsedCount,
	}, nil
}
// lookup resolves req to a stored invite and verifies that it is still active.
//
// req must be non-nil; Validate checks this before calling. Field validation
// errors are returned as-is, while repository and activity-check errors go
// through wrapRepoErr.
func (uc *inviteUseCase) lookup(ctx context.Context, req *domusecase.ValidateInviteRequest) (*entity.InviteCode, error) {
	tenantID, code, err := normalizeInviteFields(req.TenantID, req.Code)
	if err != nil {
		return nil, err
	}

	// Invites are fetched by the hash of the normalized code.
	codeHash := authdomain.HashInviteCode(code)
	found, err := uc.repo.GetByTenantAndCodeHash(ctx, tenantID, codeHash)
	if err != nil {
		return nil, wrapRepoErr(err)
	}

	if inactive := checkInviteActive(found); inactive != nil {
		return nil, wrapRepoErr(inactive)
	}
	return found, nil
}
// normalizeInviteInput extracts the normalized tenant ID and invite code from
// a consume request. A nil req is rejected with an
// errb.InputMissingRequired error; otherwise the result is whatever
// normalizeInviteFields returns for the request's TenantID and Code.
func normalizeInviteInput(req *domusecase.ConsumeInviteRequest) (tenantID, code string, err error) {
	if req == nil {
		return "", "", errb.InputMissingRequired("invite request is required")
	}

	tenantID, code, err = normalizeInviteFields(req.TenantID, req.Code)
	return tenantID, code, err
}
// normalizeInviteFields cleans up the raw tenant ID and invite code and makes
// sure neither ends up empty.
//
// The tenant ID has surrounding whitespace trimmed; the code is passed through
// authdomain.NormalizeInviteCode. The tenant ID is checked first, so when both
// values are empty the tenant_id error is the one returned. On error both
// string results are empty.
func normalizeInviteFields(tenantIDRaw, codeRaw string) (tenantID, code string, err error) {
	tenantID = strings.TrimSpace(tenantIDRaw)
	code = authdomain.NormalizeInviteCode(codeRaw)

	switch {
	case tenantID == "":
		return "", "", errb.InputMissingRequired("tenant_id is required")
	case code == "":
		return "", "", errb.InputMissingRequired("invite_code is required")
	default:
		return tenantID, code, nil
	}
}
// checkInviteActive reports why invite cannot be used, or nil if it can.
//
// Checks run in this order:
//   - nil invite                       -> authdomain.ErrInviteNotFound
//   - ExpiresAt set (> 0) and not in
//     the future, compared against
//     the current Unix time in ms      -> authdomain.ErrInviteExpired
//   - UsedCount has reached MaxUses    -> authdomain.ErrInviteExhausted
func checkInviteActive(invite *entity.InviteCode) error {
	if invite == nil {
		return authdomain.ErrInviteNotFound
	}

	// UnixMilli does not depend on the Time's location, so no UTC conversion
	// is needed before taking it.
	nowMs := time.Now().UnixMilli()
	hasExpiry := invite.ExpiresAt > 0

	switch {
	case hasExpiry && invite.ExpiresAt <= nowMs:
		return authdomain.ErrInviteExpired
	case invite.UsedCount >= invite.MaxUses:
		return authdomain.ErrInviteExhausted
	}
	return nil
}
// toInviteView converts a stored invite into the view returned to callers.
// RemainingUses is MaxUses minus UsedCount, floored at zero. invite must be
// non-nil.
func toInviteView(invite *entity.InviteCode) *domusecase.InviteView {
	view := &domusecase.InviteView{
		ID:           invite.ID.Hex(),
		TenantID:     invite.TenantID,
		NewUsersOnly: invite.NewUsersOnly,
	}

	// RemainingUses starts at its zero value, so it only needs setting when
	// there is something left.
	if left := invite.MaxUses - invite.UsedCount; left > 0 {
		view.RemainingUses = left
	}
	return view
}