218 lines
7.2 KiB
Go
218 lines
7.2 KiB
Go
|
|
package usecase_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/rand"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/alicebob/miniredis/v2"
|
||
|
|
"github.com/stretchr/testify/require"
|
||
|
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||
|
|
|
||
|
|
libcrypto "gateway/internal/library/crypto"
|
||
|
|
redislib "gateway/internal/library/redis"
|
||
|
|
memberconfig "gateway/internal/model/member/config"
|
||
|
|
member "gateway/internal/model/member/domain"
|
||
|
|
"gateway/internal/model/member/repository"
|
||
|
|
libtotp "gateway/internal/model/member/totp"
|
||
|
|
"gateway/internal/model/member/usecase"
|
||
|
|
)
|
||
|
|
|
||
|
|
func newTOTPFixture(t *testing.T) *usecase.TOTPUseCaseParam {
|
||
|
|
t.Helper()
|
||
|
|
mr := miniredis.RunT(t)
|
||
|
|
rds, err := redislib.NewClient(redis.RedisConf{Host: mr.Addr(), Type: testRedisTypeNode})
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
key := make([]byte, 32)
|
||
|
|
_, err = rand.Read(key)
|
||
|
|
require.NoError(t, err)
|
||
|
|
cipher, err := libcrypto.NewAESGCM(key)
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
now := time.Unix(1_716_000_000, 0)
|
||
|
|
return &usecase.TOTPUseCaseParam{
|
||
|
|
Profile: repository.NewMemoryTOTPProfileRepository(),
|
||
|
|
Enroll: repository.NewRedisTOTPEnrollStore(rds),
|
||
|
|
Replay: repository.NewRedisTOTPReplayStore(rds),
|
||
|
|
Cipher: cipher,
|
||
|
|
Config: memberconfig.Config{}.Defaults(),
|
||
|
|
Now: func() time.Time { return now },
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// currentCode reads the staged secret from the enroll store, decrypts it,
|
||
|
|
// and returns a freshly generated TOTP code for the configured `now`.
|
||
|
|
func currentCode(t *testing.T, param *usecase.TOTPUseCaseParam, tenantID, uid string) string {
|
||
|
|
t.Helper()
|
||
|
|
cipherBlob, err := param.Enroll.Get(context.Background(), tenantID, uid)
|
||
|
|
require.NoError(t, err)
|
||
|
|
secret, err := param.Cipher.Decrypt(cipherBlob)
|
||
|
|
require.NoError(t, err)
|
||
|
|
code, err := libtotp.Generate(secret, param.Now(), time.Duration(param.Config.TOTP.PeriodSeconds)*time.Second, param.Config.TOTP.Digits)
|
||
|
|
require.NoError(t, err)
|
||
|
|
return code
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_EnrollConfirmAndStatus(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
dto, err := uc.StartEnroll(context.Background(), "t1", "u1", "u1@example.com")
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.NotEmpty(t, dto.OtpauthURL)
|
||
|
|
require.Equal(t, 6, dto.Digits)
|
||
|
|
require.Equal(t, 30, dto.PeriodSec)
|
||
|
|
|
||
|
|
code := currentCode(t, param, "t1", "u1")
|
||
|
|
backup, err := uc.ConfirmEnroll(context.Background(), "t1", "u1", code)
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Len(t, backup, param.Config.TOTP.BackupCodeCount)
|
||
|
|
|
||
|
|
status, err := uc.Status(context.Background(), "t1", "u1")
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.True(t, status.Enrolled)
|
||
|
|
require.Equal(t, len(backup), status.BackupCodesRemaining)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_StartEnroll_AlreadyEnrolled(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
_, err := uc.StartEnroll(context.Background(), "t1", "u1", "")
|
||
|
|
require.NoError(t, err)
|
||
|
|
code := currentCode(t, param, "t1", "u1")
|
||
|
|
_, err = uc.ConfirmEnroll(context.Background(), "t1", "u1", code)
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
_, err = uc.StartEnroll(context.Background(), "t1", "u1", "")
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_ConfirmEnroll_MissingStash(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
_, err := uc.ConfirmEnroll(context.Background(), "t1", "u1", "123456")
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_ConfirmEnroll_BadCode(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
_, err := uc.StartEnroll(context.Background(), "t1", "u1", "")
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
_, err = uc.ConfirmEnroll(context.Background(), "t1", "u1", "000000")
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_VerifyCode_SuccessAndReplay(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
_, err := uc.StartEnroll(context.Background(), "t1", "u1", "")
|
||
|
|
require.NoError(t, err)
|
||
|
|
code := currentCode(t, param, "t1", "u1")
|
||
|
|
_, err = uc.ConfirmEnroll(context.Background(), "t1", "u1", code)
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
// Use a fresh TOTP code for verification; with `now` static the same code
|
||
|
|
// should validate then be rejected on replay.
|
||
|
|
verifyCode, err := libtotp.Generate(decryptStored(t, param, "t1", "u1"), param.Now(), 30*time.Second, 6)
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.NoError(t, uc.VerifyCode(context.Background(), "t1", "u1", verifyCode))
|
||
|
|
|
||
|
|
err = uc.VerifyCode(context.Background(), "t1", "u1", verifyCode)
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_VerifyCode_BackupCode(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
_, err := uc.StartEnroll(context.Background(), "t1", "u1", "")
|
||
|
|
require.NoError(t, err)
|
||
|
|
code := currentCode(t, param, "t1", "u1")
|
||
|
|
backup, err := uc.ConfirmEnroll(context.Background(), "t1", "u1", code)
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.NotEmpty(t, backup)
|
||
|
|
|
||
|
|
require.NoError(t, uc.VerifyCode(context.Background(), "t1", "u1", backup[0]))
|
||
|
|
|
||
|
|
// Same backup code cannot be reused.
|
||
|
|
require.Error(t, uc.VerifyCode(context.Background(), "t1", "u1", backup[0]))
|
||
|
|
|
||
|
|
status, err := uc.Status(context.Background(), "t1", "u1")
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Equal(t, len(backup)-1, status.BackupCodesRemaining)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_VerifyCode_NotEnrolled(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
err := uc.VerifyCode(context.Background(), "t1", "u1", "123456")
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_Disable(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
_, err := uc.StartEnroll(context.Background(), "t1", "u1", "")
|
||
|
|
require.NoError(t, err)
|
||
|
|
code := currentCode(t, param, "t1", "u1")
|
||
|
|
_, err = uc.ConfirmEnroll(context.Background(), "t1", "u1", code)
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
require.NoError(t, uc.Disable(context.Background(), "t1", "u1"))
|
||
|
|
status, err := uc.Status(context.Background(), "t1", "u1")
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.False(t, status.Enrolled)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_RegenerateBackupCodes(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
_, err := uc.StartEnroll(context.Background(), "t1", "u1", "")
|
||
|
|
require.NoError(t, err)
|
||
|
|
code := currentCode(t, param, "t1", "u1")
|
||
|
|
first, err := uc.ConfirmEnroll(context.Background(), "t1", "u1", code)
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
second, err := uc.RegenerateBackupCodes(context.Background(), "t1", "u1")
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Len(t, second, len(first))
|
||
|
|
require.NotEqual(t, first, second)
|
||
|
|
|
||
|
|
require.Error(t, uc.VerifyCode(context.Background(), "t1", "u1", first[0]))
|
||
|
|
require.NoError(t, uc.VerifyCode(context.Background(), "t1", "u1", second[0]))
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTOTPUseCase_RegenerateBackupCodes_NotEnrolled(t *testing.T) {
|
||
|
|
param := newTOTPFixture(t)
|
||
|
|
uc := usecase.MustTOTPUseCase(*param)
|
||
|
|
|
||
|
|
_, err := uc.RegenerateBackupCodes(context.Background(), "t1", "u1")
|
||
|
|
require.Error(t, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// decryptStored is a helper for tests to read the encrypted secret directly
|
||
|
|
// from the profile after enrollment to compute fresh verification codes.
|
||
|
|
func decryptStored(t *testing.T, param *usecase.TOTPUseCaseParam, tenantID, uid string) []byte {
|
||
|
|
t.Helper()
|
||
|
|
rec, err := param.Profile.Get(context.Background(), tenantID, uid)
|
||
|
|
require.NoError(t, err)
|
||
|
|
secret, err := param.Cipher.Decrypt(rec.SecretCipher)
|
||
|
|
require.NoError(t, err)
|
||
|
|
return secret
|
||
|
|
}
|
||
|
|
|
||
|
|
// Silence the unused `member` import; intentionally retained for symmetry with
|
||
|
|
// otp_usecase_test.go which references the package for sentinels.
|
||
|
|
var _ = member.ErrTOTPNotEnrolled // blank-identifier reference keeps the `member` import in use
|