149 lines
4.1 KiB
Go
149 lines
4.1 KiB
Go
package usecase
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"time"
|
|
|
|
authdomain "gateway/internal/model/auth/domain"
|
|
"gateway/internal/model/auth/domain/entity"
|
|
domrepo "gateway/internal/model/auth/domain/repository"
|
|
domusecase "gateway/internal/model/auth/domain/usecase"
|
|
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
)
|
|
|
|
type inviteUseCase struct {
|
|
repo domrepo.InviteRepository
|
|
lock domrepo.InviteConsumeLock
|
|
}
|
|
|
|
// InviteUseCaseParam wires InviteUseCase.
|
|
type InviteUseCaseParam struct {
|
|
Repo domrepo.InviteRepository
|
|
Lock domrepo.InviteConsumeLock
|
|
}
|
|
|
|
// MustInviteUseCase constructs InviteUseCase.
|
|
func MustInviteUseCase(param InviteUseCaseParam) domusecase.InviteUseCase {
|
|
if param.Repo == nil {
|
|
panic("auth: invite repository is required")
|
|
}
|
|
if param.Lock == nil {
|
|
panic("auth: invite consume lock is required")
|
|
}
|
|
return &inviteUseCase{repo: param.Repo, lock: param.Lock}
|
|
}
|
|
|
|
func (uc *inviteUseCase) Validate(ctx context.Context, req *domusecase.ValidateInviteRequest) (*domusecase.InviteView, error) {
|
|
if req == nil {
|
|
return nil, errb.InputMissingRequired("invite request is required")
|
|
}
|
|
invite, err := uc.lookup(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return toInviteView(invite), nil
|
|
}
|
|
|
|
func (uc *inviteUseCase) Consume(ctx context.Context, req *domusecase.ConsumeInviteRequest) (*domusecase.ConsumedInvite, error) {
|
|
tenantID, code, err := normalizeInviteInput(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
codeHash := authdomain.HashInviteCode(code)
|
|
|
|
ok, err := uc.lock.TryLock(ctx, tenantID, codeHash)
|
|
if err != nil {
|
|
return nil, wrapRepoErr(err, "invite consume lock failed")
|
|
}
|
|
if !ok {
|
|
return nil, wrapRepoErr(authdomain.ErrInviteLocked)
|
|
}
|
|
defer func() {
|
|
if err := uc.lock.Unlock(ctx, tenantID, codeHash); err != nil {
|
|
logx.WithContext(ctx).Errorf("auth: invite unlock failed tenant=%s codeHash=%s: %v", tenantID, codeHash, err)
|
|
}
|
|
}()
|
|
|
|
invite, err := uc.repo.GetByTenantAndCodeHash(ctx, tenantID, codeHash)
|
|
if err != nil {
|
|
return nil, wrapRepoErr(err)
|
|
}
|
|
if err := checkInviteActive(invite); err != nil {
|
|
return nil, wrapRepoErr(err)
|
|
}
|
|
|
|
consumed, err := uc.repo.ConsumeOne(ctx, invite.ID)
|
|
if err != nil {
|
|
return nil, wrapRepoErr(err)
|
|
}
|
|
return &domusecase.ConsumedInvite{
|
|
ID: consumed.ID.Hex(),
|
|
TenantID: consumed.TenantID,
|
|
NewUsersOnly: consumed.NewUsersOnly,
|
|
UsedCount: consumed.UsedCount,
|
|
}, nil
|
|
}
|
|
|
|
func (uc *inviteUseCase) lookup(ctx context.Context, req *domusecase.ValidateInviteRequest) (*entity.InviteCode, error) {
|
|
tenantID, code, err := normalizeInviteFields(req.TenantID, req.Code)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
invite, err := uc.repo.GetByTenantAndCodeHash(ctx, tenantID, authdomain.HashInviteCode(code))
|
|
if err != nil {
|
|
return nil, wrapRepoErr(err)
|
|
}
|
|
if err := checkInviteActive(invite); err != nil {
|
|
return nil, wrapRepoErr(err)
|
|
}
|
|
return invite, nil
|
|
}
|
|
|
|
func normalizeInviteInput(req *domusecase.ConsumeInviteRequest) (tenantID, code string, err error) {
|
|
if req == nil {
|
|
return "", "", errb.InputMissingRequired("invite request is required")
|
|
}
|
|
return normalizeInviteFields(req.TenantID, req.Code)
|
|
}
|
|
|
|
func normalizeInviteFields(tenantIDRaw, codeRaw string) (tenantID, code string, err error) {
|
|
tenantID = strings.TrimSpace(tenantIDRaw)
|
|
code = authdomain.NormalizeInviteCode(codeRaw)
|
|
if tenantID == "" {
|
|
return "", "", errb.InputMissingRequired("tenant_id is required")
|
|
}
|
|
if code == "" {
|
|
return "", "", errb.InputMissingRequired("invite_code is required")
|
|
}
|
|
return tenantID, code, nil
|
|
}
|
|
|
|
func checkInviteActive(invite *entity.InviteCode) error {
|
|
if invite == nil {
|
|
return authdomain.ErrInviteNotFound
|
|
}
|
|
now := time.Now().UTC().UnixMilli()
|
|
if invite.ExpiresAt > 0 && invite.ExpiresAt <= now {
|
|
return authdomain.ErrInviteExpired
|
|
}
|
|
if invite.UsedCount >= invite.MaxUses {
|
|
return authdomain.ErrInviteExhausted
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func toInviteView(invite *entity.InviteCode) *domusecase.InviteView {
|
|
remaining := invite.MaxUses - invite.UsedCount
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
return &domusecase.InviteView{
|
|
ID: invite.ID.Hex(),
|
|
TenantID: invite.TenantID,
|
|
NewUsersOnly: invite.NewUsersOnly,
|
|
RemainingUses: remaining,
|
|
}
|
|
}
|