200 lines
4.8 KiB
Go
200 lines
4.8 KiB
Go
|
|
package usecase_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
authdomain "gateway/internal/model/auth/domain"
|
||
|
|
"gateway/internal/model/auth/domain/entity"
|
||
|
|
domrepo "gateway/internal/model/auth/domain/repository"
|
||
|
|
domusecase "gateway/internal/model/auth/domain/usecase"
|
||
|
|
authusecase "gateway/internal/model/auth/usecase"
|
||
|
|
|
||
|
|
"github.com/stretchr/testify/require"
|
||
|
|
"go.mongodb.org/mongo-driver/v2/bson"
|
||
|
|
)
|
||
|
|
|
||
|
|
// testTenantAcme is the tenant ID that every invite in this file is seeded
// under and looked up with.
const (
	testTenantAcme = "acme"
)
|
||
|
|
|
||
|
|
func TestInviteUseCaseValidateAndConsume(t *testing.T) {
|
||
|
|
t.Parallel()
|
||
|
|
repo := newMemoryInviteRepo()
|
||
|
|
lock := newMemoryInviteLock()
|
||
|
|
uc := authusecase.MustInviteUseCase(authusecase.InviteUseCaseParam{
|
||
|
|
Repo: repo,
|
||
|
|
Lock: lock,
|
||
|
|
})
|
||
|
|
ctx := context.Background()
|
||
|
|
|
||
|
|
repo.seed(&entity.InviteCode{
|
||
|
|
ID: bson.NewObjectID(),
|
||
|
|
TenantID: testTenantAcme,
|
||
|
|
CodeHash: authdomain.HashInviteCode("BETA-2026-TEST"),
|
||
|
|
MaxUses: 2,
|
||
|
|
NewUsersOnly: true,
|
||
|
|
})
|
||
|
|
|
||
|
|
view, err := uc.Validate(ctx, &domusecase.ValidateInviteRequest{
|
||
|
|
TenantID: testTenantAcme,
|
||
|
|
Code: "beta-2026-test",
|
||
|
|
})
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Equal(t, testTenantAcme, view.TenantID)
|
||
|
|
require.Equal(t, int64(2), view.RemainingUses)
|
||
|
|
require.True(t, view.NewUsersOnly)
|
||
|
|
|
||
|
|
consumed, err := uc.Consume(ctx, &domusecase.ConsumeInviteRequest{
|
||
|
|
TenantID: testTenantAcme,
|
||
|
|
Code: "BETA-2026-TEST",
|
||
|
|
})
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Equal(t, int64(1), consumed.UsedCount)
|
||
|
|
|
||
|
|
view, err = uc.Validate(ctx, &domusecase.ValidateInviteRequest{
|
||
|
|
TenantID: testTenantAcme,
|
||
|
|
Code: "BETA-2026-TEST",
|
||
|
|
})
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Equal(t, int64(1), view.RemainingUses)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestInviteUseCaseExpired(t *testing.T) {
|
||
|
|
t.Parallel()
|
||
|
|
repo := newMemoryInviteRepo()
|
||
|
|
uc := authusecase.MustInviteUseCase(authusecase.InviteUseCaseParam{
|
||
|
|
Repo: repo,
|
||
|
|
Lock: newMemoryInviteLock(),
|
||
|
|
})
|
||
|
|
|
||
|
|
repo.seed(&entity.InviteCode{
|
||
|
|
ID: bson.NewObjectID(),
|
||
|
|
TenantID: testTenantAcme,
|
||
|
|
CodeHash: authdomain.HashInviteCode("EXPIRED"),
|
||
|
|
MaxUses: 1,
|
||
|
|
ExpiresAt: time.Now().UTC().Add(-time.Hour).UnixMilli(),
|
||
|
|
})
|
||
|
|
|
||
|
|
_, err := uc.Validate(context.Background(), &domusecase.ValidateInviteRequest{
|
||
|
|
TenantID: testTenantAcme,
|
||
|
|
Code: "EXPIRED",
|
||
|
|
})
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestInviteUseCaseConcurrentConsume(t *testing.T) {
|
||
|
|
t.Parallel()
|
||
|
|
repo := newMemoryInviteRepo()
|
||
|
|
lock := newMemoryInviteLock()
|
||
|
|
uc := authusecase.MustInviteUseCase(authusecase.InviteUseCaseParam{
|
||
|
|
Repo: repo,
|
||
|
|
Lock: lock,
|
||
|
|
})
|
||
|
|
|
||
|
|
repo.seed(&entity.InviteCode{
|
||
|
|
ID: bson.NewObjectID(),
|
||
|
|
TenantID: testTenantAcme,
|
||
|
|
CodeHash: authdomain.HashInviteCode("ONCE"),
|
||
|
|
MaxUses: 1,
|
||
|
|
})
|
||
|
|
|
||
|
|
var wg sync.WaitGroup
|
||
|
|
successes := make(chan struct{}, 2)
|
||
|
|
failures := make(chan struct{}, 2)
|
||
|
|
for i := 0; i < 2; i++ {
|
||
|
|
wg.Add(1)
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
_, err := uc.Consume(context.Background(), &domusecase.ConsumeInviteRequest{
|
||
|
|
TenantID: testTenantAcme,
|
||
|
|
Code: "ONCE",
|
||
|
|
})
|
||
|
|
if err == nil {
|
||
|
|
successes <- struct{}{}
|
||
|
|
return
|
||
|
|
}
|
||
|
|
failures <- struct{}{}
|
||
|
|
}()
|
||
|
|
}
|
||
|
|
wg.Wait()
|
||
|
|
close(successes)
|
||
|
|
close(failures)
|
||
|
|
require.Len(t, successes, 1)
|
||
|
|
require.Len(t, failures, 1)
|
||
|
|
}
|
||
|
|
|
||
|
|
type memoryInviteRepo struct {
|
||
|
|
mu sync.Mutex
|
||
|
|
items map[string]*entity.InviteCode
|
||
|
|
}
|
||
|
|
|
||
|
|
func newMemoryInviteRepo() *memoryInviteRepo {
|
||
|
|
return &memoryInviteRepo{items: make(map[string]*entity.InviteCode)}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *memoryInviteRepo) key(tenantID, codeHash string) string {
|
||
|
|
return tenantID + ":" + codeHash
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *memoryInviteRepo) seed(invite *entity.InviteCode) {
|
||
|
|
r.mu.Lock()
|
||
|
|
defer r.mu.Unlock()
|
||
|
|
cp := *invite
|
||
|
|
r.items[r.key(invite.TenantID, invite.CodeHash)] = &cp
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *memoryInviteRepo) GetByTenantAndCodeHash(_ context.Context, tenantID, codeHash string) (*entity.InviteCode, error) {
|
||
|
|
r.mu.Lock()
|
||
|
|
defer r.mu.Unlock()
|
||
|
|
invite, ok := r.items[r.key(tenantID, codeHash)]
|
||
|
|
if !ok {
|
||
|
|
return nil, authdomain.ErrInviteNotFound
|
||
|
|
}
|
||
|
|
cp := *invite
|
||
|
|
return &cp, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *memoryInviteRepo) ConsumeOne(_ context.Context, id bson.ObjectID) (*entity.InviteCode, error) {
|
||
|
|
r.mu.Lock()
|
||
|
|
defer r.mu.Unlock()
|
||
|
|
for _, invite := range r.items {
|
||
|
|
if invite.ID != id {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
now := time.Now().UTC().UnixMilli()
|
||
|
|
if invite.ExpiresAt > 0 && invite.ExpiresAt <= now {
|
||
|
|
return nil, authdomain.ErrInviteExpired
|
||
|
|
}
|
||
|
|
if invite.UsedCount >= invite.MaxUses {
|
||
|
|
return nil, authdomain.ErrInviteExhausted
|
||
|
|
}
|
||
|
|
invite.UsedCount++
|
||
|
|
cp := *invite
|
||
|
|
return &cp, nil
|
||
|
|
}
|
||
|
|
return nil, authdomain.ErrInviteNotFound
|
||
|
|
}
|
||
|
|
|
||
|
|
var _ domrepo.InviteRepository = (*memoryInviteRepo)(nil)
|
||
|
|
|
||
|
|
// memoryInviteLock is an in-memory stand-in for domrepo.InviteConsumeLock.
//
// It is a single, non-reentrant lock. The two string arguments passed to
// TryLock and Unlock are ignored, so every caller contends for the same lock
// regardless of which invite it targets. That is stricter than a per-invite
// lock and is sufficient for tests that exercise one invite at a time.
type memoryInviteLock struct {
	mu   sync.Mutex // guards held
	held bool       // true while some caller owns the lock
}

// newMemoryInviteLock returns a lock in the released state.
func newMemoryInviteLock() *memoryInviteLock {
	return &memoryInviteLock{}
}

// TryLock attempts to acquire the lock without blocking. It returns true when
// the lock was free and is now held by the caller, and false when another
// caller already holds it.
//
// A blocking sync.Mutex.Lock would always report success. That would hide
// whatever the use case does when acquisition fails, and would turn a leaked
// lock into a hung test. TryLock never returns an error.
func (l *memoryInviteLock) TryLock(_ context.Context, _, _ string) (bool, error) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if l.held {
		return false, nil
	}
	l.held = true
	return true, nil
}

// Unlock releases the lock. Releasing a lock that is not held is a no-op.
// A bare sync.Mutex would instead raise the unrecoverable "sync: unlock of
// unlocked mutex" fatal error and abort the whole test binary. Unlock never
// returns an error.
func (l *memoryInviteLock) Unlock(_ context.Context, _, _ string) error {
	l.mu.Lock()
	defer l.mu.Unlock()

	l.held = false
	return nil
}
|
||
|
|
|
||
|
|
// Compile-time check that memoryInviteLock satisfies the lock port that the
// invite use case is constructed with.
var (
	_ domrepo.InviteConsumeLock = (*memoryInviteLock)(nil)
)
|