166 lines
4.2 KiB
Go
166 lines
4.2 KiB
Go
|
|
package usecase_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
authconfig "gateway/internal/model/auth/config"
|
||
|
|
domrepo "gateway/internal/model/auth/domain/repository"
|
||
|
|
domusecase "gateway/internal/model/auth/domain/usecase"
|
||
|
|
authusecase "gateway/internal/model/auth/usecase"
|
||
|
|
|
||
|
|
"github.com/stretchr/testify/require"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Fixture identity shared by the token use-case tests in this file.
const (
	// testTenantDev is the tenant ID put into every IssuePairRequest and
	// expected back in the parsed access-token claims.
	testTenantDev = "dev-tenant"

	// testUIDDev is the user ID put into every IssuePairRequest and
	// expected back in the parsed access-token claims.
	testUIDDev = "DEV-10000001"
)
|
||
|
|
|
||
|
|
func TestTokenUseCaseIssueAndRefresh(t *testing.T) {
|
||
|
|
t.Parallel()
|
||
|
|
uc := newTokenUC(t, nil)
|
||
|
|
|
||
|
|
pair, err := uc.IssuePair(context.Background(), &domusecase.IssuePairRequest{
|
||
|
|
TenantID: testTenantDev,
|
||
|
|
UID: testUIDDev,
|
||
|
|
})
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.NotEmpty(t, pair.AccessToken)
|
||
|
|
require.NotEmpty(t, pair.RefreshToken)
|
||
|
|
require.Equal(t, int64(900), pair.ExpiresIn)
|
||
|
|
|
||
|
|
claims, err := uc.ParseAccessToken(context.Background(), pair.AccessToken)
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Equal(t, testTenantDev, claims.TenantID)
|
||
|
|
require.Equal(t, testUIDDev, claims.UID)
|
||
|
|
|
||
|
|
refreshed, err := uc.Refresh(context.Background(), pair.RefreshToken)
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.NotEmpty(t, refreshed.AccessToken)
|
||
|
|
require.NotEqual(t, pair.AccessToken, refreshed.AccessToken)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTokenUseCaseInvalidRefresh(t *testing.T) {
|
||
|
|
t.Parallel()
|
||
|
|
uc := newTokenUC(t, nil)
|
||
|
|
_, err := uc.Refresh(context.Background(), "not-a-jwt")
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTokenUseCaseLogoutRevokesPair(t *testing.T) {
|
||
|
|
t.Parallel()
|
||
|
|
store := newMemRevokeStore()
|
||
|
|
uc := newTokenUC(t, store)
|
||
|
|
|
||
|
|
pair, err := uc.IssuePair(context.Background(), &domusecase.IssuePairRequest{
|
||
|
|
TenantID: testTenantDev,
|
||
|
|
UID: testUIDDev,
|
||
|
|
})
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
err = uc.Logout(context.Background(), &domusecase.LogoutRequest{AccessToken: pair.AccessToken})
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
_, err = uc.ParseAccessToken(context.Background(), pair.AccessToken)
|
||
|
|
require.Error(t, err)
|
||
|
|
|
||
|
|
_, err = uc.Refresh(context.Background(), pair.RefreshToken)
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTokenUseCaseRefreshRotatesAndRevokesOldRefresh(t *testing.T) {
|
||
|
|
t.Parallel()
|
||
|
|
store := newMemRevokeStore()
|
||
|
|
uc := newTokenUC(t, store)
|
||
|
|
|
||
|
|
pair, err := uc.IssuePair(context.Background(), &domusecase.IssuePairRequest{
|
||
|
|
TenantID: testTenantDev,
|
||
|
|
UID: testUIDDev,
|
||
|
|
})
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
refreshed, err := uc.Refresh(context.Background(), pair.RefreshToken)
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.NotEqual(t, pair.RefreshToken, refreshed.RefreshToken)
|
||
|
|
|
||
|
|
_, err = uc.Refresh(context.Background(), pair.RefreshToken)
|
||
|
|
require.Error(t, err)
|
||
|
|
|
||
|
|
claims, err := uc.ParseAccessToken(context.Background(), refreshed.AccessToken)
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Equal(t, testUIDDev, claims.UID)
|
||
|
|
}
|
||
|
|
|
||
|
|
func newTokenUC(t *testing.T, revoke domrepo.TokenRevokeStore) domusecase.TokenUseCase {
|
||
|
|
t.Helper()
|
||
|
|
return authusecase.MustTokenUseCase(authusecase.TokenUseCaseParam{
|
||
|
|
Config: authconfig.Config{
|
||
|
|
AccessSecret: "access-secret-32-bytes-minimum!!",
|
||
|
|
RefreshSecret: "refresh-secret-32-bytes-minimum!",
|
||
|
|
AccessExpire: 900,
|
||
|
|
RefreshExpire: 604800,
|
||
|
|
ActiveKID: "v1",
|
||
|
|
},
|
||
|
|
Revoke: revoke,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// memRevokeStore is an in-memory revoke store that the tests pass to
// newTokenUC as its domrepo.TokenRevokeStore argument. Every method takes mu
// before touching the maps.
type memRevokeStore struct {
	mu sync.Mutex
	// pairs maps each JTI of a token pair to the other one, in both
	// directions (access -> refresh and refresh -> access).
	pairs map[string]string
	// bl maps a blacklisted JTI to the instant its blacklist entry expires.
	bl map[string]time.Time
}

// newMemRevokeStore returns an empty store with both maps initialised.
func newMemRevokeStore() *memRevokeStore {
	store := new(memRevokeStore)
	store.pairs = map[string]string{}
	store.bl = map[string]time.Time{}
	return store
}

// SavePair links accessJTI and refreshJTI to each other. The context and the
// two durations are ignored. It always returns nil.
func (s *memRevokeStore) SavePair(_ context.Context, accessJTI, refreshJTI string, _, _ time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Index both directions so either JTI can be used to find its partner.
	s.pairs[accessJTI], s.pairs[refreshJTI] = refreshJTI, accessJTI
	return nil
}

// GetPairedJTI returns the JTI linked to jti, or the empty string when no
// link is recorded. The error is always nil.
func (s *memRevokeStore) GetPairedJTI(_ context.Context, jti string) (string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	partner := s.pairs[jti]
	return partner, nil
}

// DeletePair removes the entries keyed by accessJTI and by refreshJTI.
// Deleting keys that are not present is a no-op. It always returns nil.
func (s *memRevokeStore) DeletePair(_ context.Context, accessJTI, refreshJTI string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	for _, key := range [...]string{accessJTI, refreshJTI} {
		delete(s.pairs, key)
	}
	return nil
}

// Blacklist records jti as blacklisted until ttl from now, replacing any
// earlier entry for the same jti. It always returns nil.
func (s *memRevokeStore) Blacklist(_ context.Context, jti string, ttl time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	expiry := time.Now().Add(ttl)
	s.bl[jti] = expiry
	return nil
}

// IsBlacklisted reports whether jti has a blacklist entry that has not yet
// expired. An entry whose expiry has passed is removed as a side effect and
// reported as not blacklisted. The error is always nil.
func (s *memRevokeStore) IsBlacklisted(_ context.Context, jti string) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	deadline, found := s.bl[jti]
	switch {
	case !found:
		return false, nil
	case time.Now().After(deadline):
		// Lazily evict the stale entry on lookup.
		delete(s.bl, jti)
		return false, nil
	default:
		return true, nil
	}
}
|