template-monorepo/internal/model/auth/usecase/token_usecase_test.go

166 lines
4.2 KiB
Go

package usecase_test
import (
"context"
"sync"
"testing"
"time"
authconfig "gateway/internal/model/auth/config"
domrepo "gateway/internal/model/auth/domain/repository"
domusecase "gateway/internal/model/auth/domain/usecase"
authusecase "gateway/internal/model/auth/usecase"
"github.com/stretchr/testify/require"
)
// Fixed identity shared by every test in this file: it is the subject the
// token pairs are issued for and the value the parsed claims are compared to.
const (
	// testTenantDev is the tenant ID placed in IssuePairRequest.TenantID.
	testTenantDev = "dev-tenant"
	// testUIDDev is the user ID placed in IssuePairRequest.UID.
	testUIDDev = "DEV-10000001"
)
// TestTokenUseCaseIssueAndRefresh walks the happy path without a revoke store:
// it issues a pair for the dev identity, checks that the access token parses
// back to the same tenant and UID, and then checks that refreshing produces a
// non-empty access token that differs from the first one.
func TestTokenUseCaseIssueAndRefresh(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	tokens := newTokenUC(t, nil)
	issueReq := &domusecase.IssuePairRequest{
		TenantID: testTenantDev,
		UID:      testUIDDev,
	}

	// Issue the initial pair.
	issued, err := tokens.IssuePair(ctx, issueReq)
	require.NoError(t, err)
	require.NotEmpty(t, issued.AccessToken)
	require.NotEmpty(t, issued.RefreshToken)
	// 900 mirrors the AccessExpire value configured in newTokenUC.
	require.Equal(t, int64(900), issued.ExpiresIn)

	// The access token must carry the identity it was issued for.
	parsed, err := tokens.ParseAccessToken(ctx, issued.AccessToken)
	require.NoError(t, err)
	require.Equal(t, testTenantDev, parsed.TenantID)
	require.Equal(t, testUIDDev, parsed.UID)

	// Refreshing must hand back a new, different access token.
	renewed, err := tokens.Refresh(ctx, issued.RefreshToken)
	require.NoError(t, err)
	require.NotEmpty(t, renewed.AccessToken)
	require.NotEqual(t, issued.AccessToken, renewed.AccessToken)
}
// TestTokenUseCaseInvalidRefresh checks that Refresh reports an error when it
// is handed a string that is not a JWT.
func TestTokenUseCaseInvalidRefresh(t *testing.T) {
	t.Parallel()

	const bogusToken = "not-a-jwt"

	_, err := newTokenUC(t, nil).Refresh(context.Background(), bogusToken)
	require.Error(t, err)
}
func TestTokenUseCaseLogoutRevokesPair(t *testing.T) {
t.Parallel()
store := newMemRevokeStore()
uc := newTokenUC(t, store)
pair, err := uc.IssuePair(context.Background(), &domusecase.IssuePairRequest{
TenantID: testTenantDev,
UID: testUIDDev,
})
require.NoError(t, err)
err = uc.Logout(context.Background(), &domusecase.LogoutRequest{AccessToken: pair.AccessToken})
require.NoError(t, err)
_, err = uc.ParseAccessToken(context.Background(), pair.AccessToken)
require.Error(t, err)
_, err = uc.Refresh(context.Background(), pair.RefreshToken)
require.Error(t, err)
}
func TestTokenUseCaseRefreshRotatesAndRevokesOldRefresh(t *testing.T) {
t.Parallel()
store := newMemRevokeStore()
uc := newTokenUC(t, store)
pair, err := uc.IssuePair(context.Background(), &domusecase.IssuePairRequest{
TenantID: testTenantDev,
UID: testUIDDev,
})
require.NoError(t, err)
refreshed, err := uc.Refresh(context.Background(), pair.RefreshToken)
require.NoError(t, err)
require.NotEqual(t, pair.RefreshToken, refreshed.RefreshToken)
_, err = uc.Refresh(context.Background(), pair.RefreshToken)
require.Error(t, err)
claims, err := uc.ParseAccessToken(context.Background(), refreshed.AccessToken)
require.NoError(t, err)
require.Equal(t, testUIDDev, claims.UID)
}
// newTokenUC builds the TokenUseCase under test from a fixed configuration and
// the supplied revoke store. Tests in this file pass nil when they do not
// exercise revocation.
//
// NOTE(review): MustTokenUseCase presumably panics on an invalid configuration
// (the usual Must- convention) — confirm against its definition.
func newTokenUC(t *testing.T, revoke domrepo.TokenRevokeStore) domusecase.TokenUseCase {
	t.Helper()

	cfg := authconfig.Config{
		AccessSecret:  "access-secret-32-bytes-minimum!!",
		RefreshSecret: "refresh-secret-32-bytes-minimum!",
		AccessExpire:  900,
		RefreshExpire: 604800,
		ActiveKID:     "v1",
	}
	param := authusecase.TokenUseCaseParam{
		Config: cfg,
		Revoke: revoke,
	}
	return authusecase.MustTokenUseCase(param)
}
// memRevokeStore is an in-memory revoke store for the tests in this file; it is
// what they pass to newTokenUC as the revoke dependency. Every method takes the
// mutex, and every method ignores its context argument.
type memRevokeStore struct {
	mu    sync.Mutex
	pairs map[string]string    // JTI -> its paired JTI, recorded in both directions
	bl    map[string]time.Time // blacklisted JTI -> instant the entry expires
}

// newMemRevokeStore returns an empty store with both maps initialised.
func newMemRevokeStore() *memRevokeStore {
	return &memRevokeStore{
		pairs: map[string]string{},
		bl:    map[string]time.Time{},
	}
}

// SavePair links the access and refresh JTIs to each other so either one can
// be used to look up the other. The two TTL arguments are ignored: pair
// entries in this fake never expire on their own.
func (m *memRevokeStore) SavePair(_ context.Context, accessJTI, refreshJTI string, _, _ time.Duration) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.pairs[accessJTI], m.pairs[refreshJTI] = refreshJTI, accessJTI
	return nil
}

// GetPairedJTI returns the JTI linked to jti, or the empty string when no link
// is recorded. It never returns an error.
func (m *memRevokeStore) GetPairedJTI(_ context.Context, jti string) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	paired := m.pairs[jti]
	return paired, nil
}

// DeletePair removes the link entries stored under both JTIs. Deleting a JTI
// that is not present is a no-op.
func (m *memRevokeStore) DeletePair(_ context.Context, accessJTI, refreshJTI string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	for _, key := range [...]string{accessJTI, refreshJTI} {
		delete(m.pairs, key)
	}
	return nil
}

// Blacklist marks jti as revoked until ttl has elapsed from now, overwriting
// any earlier entry for the same jti.
func (m *memRevokeStore) Blacklist(_ context.Context, jti string, ttl time.Duration) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	expiry := time.Now().Add(ttl)
	m.bl[jti] = expiry
	return nil
}

// IsBlacklisted reports whether jti currently has an unexpired blacklist
// entry. Expired entries are dropped lazily, at the moment they are looked up.
func (m *memRevokeStore) IsBlacklisted(_ context.Context, jti string) (bool, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	expiry, found := m.bl[jti]
	switch {
	case !found:
		return false, nil
	case time.Now().After(expiry):
		// Entry has outlived its TTL: forget it and report "not blacklisted".
		delete(m.bl, jti)
		return false, nil
	default:
		return true, nil
	}
}